qwen_agent/agent/deep_assistant.py
2025-12-24 11:05:10 +08:00

201 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import time
import copy
from typing import Any, Dict
from langchain.chat_models import init_chat_model
# from deepagents import create_deep_agent
from deepagents_cli.agent import create_cli_agent
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware
from langchain_mcp_adapters.client import MultiServerMCPClient
from utils.fastapi_utils import detect_provider
from .guideline_middleware import GuidelineMiddleware
from .tool_output_length_middleware import ToolOutputLengthMiddleware
from .tool_use_cleanup_middleware import ToolUseCleanupMiddleware
from utils.settings import SUMMARIZATION_MAX_TOKENS, SUMMARIZATION_MESSAGES_TO_KEEP, TOOL_OUTPUT_MAX_LENGTH, MCP_HTTP_TIMEOUT, MCP_SSE_READ_TIMEOUT
from agent.agent_config import AgentConfig
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
from agent.agent_memory_cache import get_memory_cache_manager
from .checkpoint_utils import prepare_checkpoint_message
import aiosqlite
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
import os
# 全局 MemorySaver 实例
# from langgraph.checkpoint.memory import MemorySaver
# _global_checkpointer = MemorySaver()
# Module logger; shares the application-wide 'app' logger name instead of __name__.
logger = logging.getLogger(name='app')
# Utility functions
def read_system_prompt():
    """Return the generic, stateless default system prompt.

    The prompt is read as UTF-8 from ``./prompt/system_prompt_default.md``
    (resolved against the current working directory), and leading/trailing
    whitespace is stripped.
    """
    prompt_path = "./prompt/system_prompt_default.md"
    with open(prompt_path, "r", encoding="utf-8") as prompt_file:
        raw_prompt = prompt_file.read()
    return raw_prompt.strip()
def read_mcp_settings():
    """Load the fallback MCP tool configuration.

    Reads ``./mcp/mcp_settings.json`` (resolved against the current working
    directory) and returns the parsed JSON document.

    Returns:
        Whatever top-level JSON value the file contains.

    Raises:
        FileNotFoundError: if the settings file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Decode explicitly as UTF-8, like read_system_prompt, so the result does not
    # depend on the platform's locale encoding. JSON text is UTF-8 per RFC 8259.
    with open("./mcp/mcp_settings.json", "r", encoding="utf-8") as f:
        return json.load(f)
async def get_tools_from_mcp(mcp):
    """Extract LangChain tools from an MCP configuration, using an in-memory cache.

    Args:
        mcp: MCP settings. Expected to be a non-empty list whose first element is
            a dict with an ``"mcpServers"`` dict that maps server names to
            per-server config dicts.

    Returns:
        The tools returned by ``MultiServerMCPClient.get_tools()``, or an empty
        list when the configuration is malformed or loading fails. Errors are
        logged, not raised, so callers can always proceed without MCP tools.
    """
    start_time = time.time()
    # Defensive shape check. mcp must be a non-empty list whose first element is a
    # dict holding "mcpServers". The isinstance(dict) test matters because `in` on
    # a str element is a substring test, and the indexing below would then raise.
    if (
        not isinstance(mcp, list)
        or len(mcp) == 0
        or not isinstance(mcp[0], dict)
        or "mcpServers" not in mcp[0]
    ):
        logger.info(f"get_tools_from_mcp: invalid mcp config, elapsed: {time.time() - start_time:.3f}s")
        return []
    # Validate the server mapping BEFORE normalizing it. Calling `.values()` on a
    # non-dict (or item assignment on a non-dict entry) would raise instead of
    # returning [].
    servers = mcp[0]["mcpServers"]
    if not isinstance(servers, dict):
        logger.info(f"get_tools_from_mcp: mcpServers is not dict, elapsed: {time.time() - start_time:.3f}s")
        return []
    if not all(isinstance(server_cfg, dict) for server_cfg in servers.values()):
        logger.info(f"get_tools_from_mcp: mcpServers entry is not dict, elapsed: {time.time() - start_time:.3f}s")
        return []

    # Try the cache first.
    cache_manager = get_memory_cache_manager()
    cached_tools = cache_manager.get_mcp_tools(mcp)
    if cached_tools is not None:
        logger.info(f"get_tools_from_mcp: cached {len(cached_tools)} tools, elapsed: {time.time() - start_time:.3f}s")
        return cached_tools

    # Deep-copy so that the normalization below does not mutate the caller's
    # config, which is also used as the cache key.
    mcp_copy = copy.deepcopy(mcp)
    server_configs = mcp_copy[0]["mcpServers"]
    for cfg in server_configs.values():
        # NOTE(review): the `type` value is discarded rather than used as the
        # transport, and the transport is inferred from `url` below. For example, a
        # server with `"type": "sse"` and a `url` becomes "http". Confirm this is
        # intended.
        cfg.pop("type", None)
        if "transport" not in cfg:
            cfg["transport"] = "http" if "url" in cfg else "stdio"
        # HTTP/SSE servers get the global default timeouts unless they set their own.
        if cfg["transport"] in ("http", "sse"):
            cfg.setdefault("timeout", MCP_HTTP_TIMEOUT)
            cfg.setdefault("sse_read_timeout", MCP_SSE_READ_TIMEOUT)

    try:
        mcp_client = MultiServerMCPClient(server_configs)
        mcp_tools = await mcp_client.get_tools()
        # Cache the successful result.
        cache_manager.set_mcp_tools(mcp, mcp_tools)
    except Exception as e:
        # Best effort: return no tools rather than failing agent creation. Log at
        # warning level with a traceback so the failure is not lost in info noise.
        logger.warning(
            f"get_tools_from_mcp: error {e}, elapsed: {time.time() - start_time:.3f}s",
            exc_info=True,
        )
        return []
    logger.info(f"get_tools_from_mcp: loaded {len(mcp_tools)} tools, elapsed: {time.time() - start_time:.3f}s")
    return mcp_tools
# Keys in ``generate_cfg`` that configure ToolOutputLengthMiddleware rather than
# the chat model. init_agent reads them for the middleware and does not forward
# them to init_chat_model.
_TOOL_OUTPUT_CFG_KEYS = frozenset({
    "tool_output_max_length",
    "tool_output_truncation_strategy",
    "tool_output_filters",
    "tool_output_exclude",
    "preserve_code_blocks",
    "preserve_json",
})


def _generate_cfg_value(generate_cfg, key, default):
    """Return ``generate_cfg[key]`` (or attribute ``key``), else ``default``.

    ``generate_cfg`` is merged into the chat-model kwargs via ``dict.update`` in
    init_agent, so it is normally a dict. ``getattr`` on a dict never sees its
    keys, so item lookup is used for dicts. Attribute lookup is kept as a
    fallback for object-style configs. A falsy ``generate_cfg`` yields ``default``.
    """
    if not generate_cfg:
        return default
    if isinstance(generate_cfg, dict):
        return generate_cfg.get(key, default)
    return getattr(generate_cfg, key, default)


async def init_agent(config: AgentConfig):
    """Initialize an agent with persistent memory and conversation summarization.

    Agents are not cached; only the MCP tools are (inside get_tools_from_mcp).

    Args:
        config: AgentConfig holding all initialization parameters. Its
            ``system_prompt`` and ``mcp_settings`` attributes are overwritten
            with the resolved values.

    Returns:
        An ``(agent, checkpointer)`` tuple. ``checkpointer`` is None for the
        "deep_agent" robot type or when ``config.session_id`` is falsy. Otherwise
        it was acquired from the checkpointer pool, and the caller must return it
        after use.
    """
    # Resolve the system prompt and MCP settings for this bot/project.
    final_system_prompt = await load_system_prompt_async(
        config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
    )
    final_mcp_settings = await load_mcp_settings_async(
        config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
    )
    # Fall back to the on-disk defaults when nothing was resolved.
    mcp_settings = final_mcp_settings if final_mcp_settings else read_mcp_settings()
    system_prompt = final_system_prompt if final_system_prompt else read_system_prompt()
    # Store each resolved value on its matching attribute. Previously these two
    # assignments were crossed, which put the MCP settings into system_prompt.
    config.system_prompt = system_prompt
    config.mcp_settings = mcp_settings

    # Load the MCP tools (caching is handled inside get_tools_from_mcp).
    mcp_tools = await get_tools_from_mcp(mcp_settings)

    # Detect the provider, or use the one that was specified.
    model_provider, base_url = detect_provider(config.model_name, config.model_server)
    model_kwargs = {
        "model": config.model_name,
        "model_provider": model_provider,
        "temperature": 0.8,
        "base_url": base_url,
        "api_key": config.api_key
    }
    if config.generate_cfg:
        # Forward generation options (which may override the defaults above) to
        # the model. Exclude the tool-output middleware settings; otherwise they
        # would reach init_chat_model as unknown model parameters.
        generation_options = dict(config.generate_cfg)
        model_kwargs.update(
            {k: v for k, v in generation_options.items() if k not in _TOOL_OUTPUT_CFG_KEYS}
        )
    llm_instance = init_chat_model(**model_kwargs)

    # Always build a fresh agent (no agent caching).
    logger.info(f"Creating new agent for session: {getattr(config, 'session_id', 'no-session')}")
    checkpointer = None
    create_start = time.time()
    if config.robot_type == "deep_agent":
        # Build the agent with the deepagents CLI factory. The second return value
        # (the backend) is not used here.
        agent, _backend = create_cli_agent(
            model=llm_instance,
            assistant_id=config.bot_id,
            tools=mcp_tools,
            auto_approve=True,
        )
    else:
        # Middleware runs in list order.
        middleware = []
        # First, clean up orphaned tool_use entries.
        middleware.append(ToolUseCleanupMiddleware())
        # Add GuidelineMiddleware only when enable_thinking is set.
        if config.enable_thinking:
            middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))

        # Tool-output length control. Per-bot overrides come from generate_cfg.
        cfg = config.generate_cfg
        tool_output_middleware = ToolOutputLengthMiddleware(
            # `or` falls back to the global default when the key is absent or None,
            # whether or not generate_cfg is set.
            max_length=_generate_cfg_value(cfg, "tool_output_max_length", None) or TOOL_OUTPUT_MAX_LENGTH,
            truncation_strategy=_generate_cfg_value(cfg, "tool_output_truncation_strategy", "smart"),
            tool_filters=_generate_cfg_value(cfg, "tool_output_filters", None),
            exclude_tools=_generate_cfg_value(cfg, "tool_output_exclude", []),
            preserve_code_blocks=_generate_cfg_value(cfg, "preserve_code_blocks", True),
            preserve_json=_generate_cfg_value(cfg, "preserve_json", True),
        )
        middleware.append(tool_output_middleware)

        # Acquire a checkpointer from the pool for session-scoped memory.
        if config.session_id:
            from .checkpoint_manager import get_checkpointer_manager
            manager = get_checkpointer_manager()
            checkpointer = await manager.acquire_for_agent(config.bot_id, config.session_id)
            # NOTE(review): if anything below raises, the acquired checkpointer never
            # reaches the caller and is not returned to the pool. Consider releasing
            # it on failure once the manager's release API is confirmed.
            await prepare_checkpoint_message(config, checkpointer)
            summarization_middleware = SummarizationMiddleware(
                model=llm_instance,
                max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
                messages_to_keep=SUMMARIZATION_MESSAGES_TO_KEEP,
                summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。"
            )
            middleware.append(summarization_middleware)

        agent = create_agent(
            model=llm_instance,
            system_prompt=system_prompt,
            tools=mcp_tools,
            middleware=middleware,
            checkpointer=checkpointer
        )
    logger.info(f"create {config.robot_type} elapsed: {time.time() - create_start:.3f}s")
    return agent, checkpointer