qwen_agent/agent/deep_assistant.py
2025-12-18 00:38:04 +08:00

170 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import time
from typing import Any, Dict
from langchain.chat_models import init_chat_model
# from deepagents import create_deep_agent
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware
from langchain_mcp_adapters.client import MultiServerMCPClient
from langgraph.checkpoint.memory import MemorySaver
from utils.fastapi_utils import detect_provider
from .guideline_middleware import GuidelineMiddleware
from .tool_output_length_middleware import ToolOutputLengthMiddleware
from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
from agent.agent_config import AgentConfig
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
from agent.agent_memory_cache import get_memory_cache_manager
from .checkpoint_utils import prepare_checkpoint_message
# Shared application logger (the 'app' logger, not a per-module one).
logger = logging.getLogger('app')

# One MemorySaver for the whole process: init_agent hands this same instance to
# every agent it builds for a config with a session_id.
_global_checkpointer = MemorySaver()

# ---------------------------------------------------------------------------
# Utility functions
# ---------------------------------------------------------------------------
def read_system_prompt():
    """Return the default, stateless system prompt.

    Reads ``./prompt/system_prompt_default.md`` (relative to the current working
    directory) as UTF-8 and returns its contents with surrounding whitespace
    stripped.
    """
    prompt_path = "./prompt/system_prompt_default.md"
    with open(prompt_path, "r", encoding="utf-8") as prompt_file:
        contents = prompt_file.read()
    return contents.strip()
def read_mcp_settings():
    """Load the fallback MCP tool configuration.

    Reads and parses ``./mcp/mcp_settings.json`` (relative to the current working
    directory).

    Returns:
        The parsed JSON value (whatever the file contains).

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # JSON text is UTF-8; decode explicitly instead of relying on the platform
    # locale, which breaks non-ASCII settings on e.g. Windows/GBK systems.
    with open("./mcp/mcp_settings.json", "r", encoding="utf-8") as settings_file:
        return json.load(settings_file)
async def get_tools_from_mcp(mcp):
    """Load tools from the MCP servers described in ``mcp``.

    Args:
        mcp: MCP settings; expected to be a non-empty list whose first element is
            a dict with an ``"mcpServers"`` dict mapping server name -> server
            config dict. Only ``mcp[0]`` is used. The object is not modified.

    Returns:
        The tools returned by ``MultiServerMCPClient.get_tools()``, or ``[]`` when
        the settings are missing/malformed, no usable server is configured, or
        loading fails (failures are logged, not raised).
    """
    # Defensive shape checks: anything other than
    # [{"mcpServers": {name: {...}, ...}}, ...] means "no MCP tools".
    if not isinstance(mcp, list) or not mcp:
        return []
    first = mcp[0]
    if not isinstance(first, dict):
        return []
    servers = first.get("mcpServers")
    if not isinstance(servers, dict) or not servers:
        return []

    # Build normalized copies so the caller's settings are left untouched.
    connections = {}
    for name, cfg in servers.items():
        if not isinstance(cfg, dict):
            logger.warning("Skipping MCP server %r: its config is not an object", name)
            continue
        conn = dict(cfg)
        # The "type" key is dropped (not renamed). The transport comes from an
        # explicit "transport" key, else defaults to "http" when a "url" is
        # present and "stdio" otherwise.
        # NOTE(review): a server declared with e.g. "type": "sse" therefore ends up
        # on "http" — confirm this is intended.
        conn.pop("type", None)
        conn.setdefault("transport", "http" if "url" in conn else "stdio")
        connections[name] = conn
    if not connections:
        return []

    try:
        mcp_client = MultiServerMCPClient(connections)
        return await mcp_client.get_tools()
    except Exception:
        # Best effort: an unreachable/misconfigured MCP server must not prevent
        # the agent from being built, but the failure should be visible.
        logger.exception("Failed to load tools from MCP servers: %s", list(connections))
        return []
# Keys in ``config.generate_cfg`` that configure ToolOutputLengthMiddleware rather
# than the chat model. They are stripped before generate_cfg is merged into the
# ``init_chat_model`` kwargs so they are not passed to the model as parameters.
_TOOL_OUTPUT_CFG_KEYS = (
    "tool_output_max_length",
    "tool_output_truncation_strategy",
    "tool_output_filters",
    "tool_output_exclude",
    "preserve_code_blocks",
    "preserve_json",
)


def _cfg_value(cfg, key, default=None):
    """Return ``key`` from ``cfg`` (dict item or attribute), else ``default``.

    ``generate_cfg`` is merged into the model kwargs like a dict, and
    ``getattr`` alone never sees dict keys, so both access styles are supported.
    """
    if not cfg:
        return default
    if isinstance(cfg, dict):
        return cfg.get(key, default)
    return getattr(cfg, key, default)


def _build_tool_output_middleware(generate_cfg):
    """Create the ToolOutputLengthMiddleware configured from optional ``generate_cfg``."""
    return ToolOutputLengthMiddleware(
        # Fall back to the global limit whenever no (truthy) limit is configured.
        max_length=_cfg_value(generate_cfg, "tool_output_max_length") or TOOL_OUTPUT_MAX_LENGTH,
        truncation_strategy=_cfg_value(generate_cfg, "tool_output_truncation_strategy", "smart"),
        tool_filters=_cfg_value(generate_cfg, "tool_output_filters"),  # optionally target specific tools
        exclude_tools=_cfg_value(generate_cfg, "tool_output_exclude", []),  # tools to leave alone
        preserve_code_blocks=_cfg_value(generate_cfg, "preserve_code_blocks", True),
        preserve_json=_cfg_value(generate_cfg, "preserve_json", True),
    )


async def init_agent(config: AgentConfig):
    """Build the agent described by ``config``, or return a cached one.

    1. If ``config.session_id`` is set, use the process-wide
       ``_global_checkpointer`` and run ``prepare_checkpoint_message`` on it.
    2. If ``config.get_unique_cache_id()`` yields a key already held by the
       memory cache manager, return that cached agent.
    3. Otherwise:
       - resolve the system prompt and MCP settings, falling back to
         ``read_system_prompt()`` / ``read_mcp_settings()``, and store them back
         on ``config``;
       - load the MCP tools and build the chat model and middleware;
       - create the agent and cache it under the key (if any).

    Args:
        config: AgentConfig carrying all initialization parameters (model,
            prompts, MCP settings, session/bot identifiers, generate_cfg, ...).

    Returns:
        The agent created by ``create_agent`` (or the cached instance).
    """
    # Persistence (and therefore summarization) is only enabled for sessions.
    checkpointer = None
    if config.session_id:
        checkpointer = _global_checkpointer
        await prepare_checkpoint_message(config, checkpointer)

    cache_manager = get_memory_cache_manager()
    cache_key = config.get_unique_cache_id()
    if cache_key:
        cached_agent = cache_manager.get(cache_key)
        if cached_agent is not None:
            logger.info("Using cached agent for session: %s, cache_key: %s", config.session_id, cache_key)
            return cached_agent
        logger.info("Cache miss for session: %s, cache_key: %s", config.session_id, cache_key)

    # No cache key, or nothing cached: build a new agent.
    logger.info("Creating new agent for session: %s", getattr(config, 'session_id', 'no-session'))
    final_system_prompt = await load_system_prompt_async(
        config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
    )
    final_mcp_settings = await load_mcp_settings_async(
        config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
    )
    # Fall back to the on-disk defaults when nothing was resolved.
    mcp_settings = final_mcp_settings or read_mcp_settings()
    system_prompt = final_system_prompt or read_system_prompt()
    # Record the resolved values on config (which GuidelineMiddleware also
    # receives). These two were previously assigned crosswise.
    config.system_prompt = system_prompt
    config.mcp_settings = mcp_settings

    mcp_tools = await get_tools_from_mcp(mcp_settings)

    # Detect the provider (or use the one implied by model_server).
    model_provider, base_url = detect_provider(config.model_name, config.model_server)
    model_kwargs = {
        "model": config.model_name,
        "model_provider": model_provider,
        "temperature": 0.8,
        "base_url": base_url,
        "api_key": config.api_key,
    }
    if config.generate_cfg:
        # generate_cfg overrides the defaults above, minus the middleware-only keys.
        extra_kwargs = dict(config.generate_cfg)
        for key in _TOOL_OUTPUT_CFG_KEYS:
            extra_kwargs.pop(key, None)
        model_kwargs.update(extra_kwargs)
    llm_instance = init_chat_model(**model_kwargs)

    middleware = []
    # GuidelineMiddleware is only added when thinking is enabled.
    if config.enable_thinking:
        middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))
    # Length-limit tool outputs (settings come from generate_cfg).
    middleware.append(_build_tool_output_middleware(config.generate_cfg))
    if checkpointer:
        middleware.append(
            SummarizationMiddleware(
                model=llm_instance,
                max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
                messages_to_keep=20,  # keep the 20 most recent messages after summarizing
                summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。"
            )
        )

    agent = create_agent(
        model=llm_instance,
        system_prompt=system_prompt,
        tools=mcp_tools,
        middleware=middleware,
        checkpointer=checkpointer,  # None disables persistence
    )

    if cache_key:
        cache_manager.set(cache_key, agent)
        logger.info("Cached agent for session: %s, cache_key: %s", config.session_id, cache_key)
    return agent