qwen_agent/agent/deep_assistant.py
2025-12-15 23:54:32 +08:00

209 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import os
import sqlite3
from typing import Any, Dict, Optional
from langchain.chat_models import init_chat_model
# from deepagents import create_deep_agent
from langchain.agents import create_agent
from langchain.agents.middleware import SummarizationMiddleware
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_core.callbacks import BaseCallbackHandler
from langgraph.checkpoint.memory import MemorySaver
from utils.fastapi_utils import detect_provider
from .guideline_middleware import GuidelineMiddleware
from .tool_output_length_middleware import ToolOutputLengthMiddleware
from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
class LoggingCallbackHandler(BaseCallbackHandler):
    """Custom callback handler that routes LLM/tool lifecycle events to the
    project's logger instead of stdout."""

    def __init__(self, logger_name: str = 'app'):
        # Reuse the application-wide logger so callback output lands in the
        # same log stream as the rest of the app.
        self.logger = logging.getLogger(logger_name)

    def on_llm_end(self, response, **kwargs: Any) -> None:
        """Called when the LLM finishes; logs every non-empty output line."""
        self.logger.info("✅ LLM End - Output:")
        if hasattr(response, 'generations') and response.generations:
            for generation_list in response.generations:
                for generation in generation_list:
                    if hasattr(generation, 'text'):
                        text = generation.text
                    elif hasattr(generation, 'message'):
                        # Fix: a ChatGeneration's `message` is a message
                        # object, not a str — calling .split() on it raised
                        # AttributeError. The textual payload is `.content`
                        # (str() guards the multimodal list-content case).
                        text = str(generation.message.content)
                    else:
                        continue
                    for line in text.split("\n"):
                        if line.strip():
                            self.logger.info(f"{line}")

    def on_llm_error(
        self, error: Exception, **kwargs: Any
    ) -> None:
        """Called when the LLM raises an error."""
        self.logger.error(f"❌ LLM Error: {error}")

    def on_tool_start(
        self, serialized: Optional[Dict[str, Any]], input_str: str, **kwargs: Any
    ) -> None:
        """Called when a tool invocation starts; logs the tool name and a
        truncated (100-char) preview of its input."""
        if serialized is None:
            tool_name = 'unknown_tool'
        else:
            tool_name = serialized.get('name', 'unknown_tool')
        self.logger.info(f"🔧 Tool Start - {tool_name} with input: {str(input_str)[:100]}")

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Called when a tool invocation finishes."""
        self.logger.info(f"✅ Tool End Output: {output}")

    def on_tool_error(
        self, error: Exception, **kwargs: Any
    ) -> None:
        """Called when a tool invocation raises an error."""
        self.logger.error(f"❌ Tool Error: {error}")

    def on_agent_action(self, action, **kwargs: Any) -> None:
        """Called when the agent decides on an action."""
        self.logger.info(f"🎯 Agent Action: {action.log}")
# Utility functions
def read_system_prompt():
    """Load the generic, stateless system prompt from disk."""
    prompt_path = "./prompt/system_prompt_default.md"
    with open(prompt_path, "r", encoding="utf-8") as fh:
        raw = fh.read()
    return raw.strip()
def read_mcp_settings():
    """Load the MCP tool configuration from ./mcp/mcp_settings.json.

    Returns:
        The parsed JSON document (expected: a list whose first element holds
        an "mcpServers" mapping — see get_tools_from_mcp).
    """
    # Explicit utf-8 keeps parsing platform-independent (consistent with
    # read_system_prompt); the platform default encoding can break on
    # non-ASCII config values, e.g. under Windows.
    with open("./mcp/mcp_settings.json", "r", encoding="utf-8") as f:
        return json.load(f)
async def get_tools_from_mcp(mcp):
    """Extract LangChain tools from an MCP configuration.

    Args:
        mcp: Parsed MCP settings; expected shape is a non-empty list whose
            first element is a dict containing an "mcpServers" mapping.

    Returns:
        The tools reported by the MCP client, or [] when the config is
        malformed or the client fails (best-effort: bad config must not
        crash the caller).
    """
    # Defensive checks up front. Fix: the original validated that
    # mcpServers is a dict only AFTER iterating .values() over it, so a
    # non-dict value crashed before the guard; it also raised TypeError
    # when mcp[0] was not a dict at all.
    if not isinstance(mcp, list) or not mcp or not isinstance(mcp[0], dict):
        return []
    servers = mcp[0].get("mcpServers")
    if not isinstance(servers, dict):
        return []
    # Normalize each server config in place: drop the legacy "type" field
    # and default "transport" to http when a url is present, else stdio.
    for cfg in servers.values():
        cfg.pop("type", None)
        if "transport" not in cfg:
            cfg["transport"] = "http" if "url" in cfg else "stdio"
    try:
        mcp_client = MultiServerMCPClient(servers)
        mcp_tools = await mcp_client.get_tools()
        return mcp_tools
    except Exception:
        # Best-effort: swallow client errors so upstream callers never fail.
        return []
async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
                     model_server=None, generate_cfg=None,
                     system_prompt=None, mcp=None, robot_type=None, language="jp", user_identifier=None,
                     session_id=None):
    """
    Initialize an agent with optional persistent memory and conversation summarization.

    Args:
        bot_id: Bot ID
        model_name: Model name
        api_key: API key
        model_server: Model server address
        generate_cfg: Generation config; a plain dict (it is merged into the
            model kwargs) or an attribute-style object
        system_prompt: System prompt (falls back to the default prompt file)
        mcp: MCP config (falls back to ./mcp/mcp_settings.json)
        robot_type: Robot type
        language: Language code
        user_identifier: User identifier
        session_id: Session ID; when None, persistent memory is disabled
    """
    system = system_prompt if system_prompt else read_system_prompt()
    mcp = mcp if mcp else read_mcp_settings()
    mcp_tools = await get_tools_from_mcp(mcp)

    def _cfg(key, default=None):
        # Fix: generate_cfg is used as a dict below (model_kwargs.update),
        # but the original read middleware settings with getattr(), which
        # never finds dict keys — user-supplied tool-output settings were
        # silently dropped. Support both dict and attribute-style configs.
        if isinstance(generate_cfg, dict):
            return generate_cfg.get(key, default)
        if generate_cfg is None:
            return default
        return getattr(generate_cfg, key, default)

    # Detect (or honour the explicitly supplied) model provider / endpoint.
    model_provider, base_url = detect_provider(model_name, model_server)
    model_kwargs = {
        "model": model_name,
        "model_provider": model_provider,
        "temperature": 0.8,
        "base_url": base_url,
        "api_key": api_key
    }
    if generate_cfg:
        # NOTE(review): this also forwards middleware-only keys (e.g.
        # tool_output_max_length) to init_chat_model — preserved as-is,
        # confirm init_chat_model tolerates extra kwargs.
        model_kwargs.update(generate_cfg)
    llm_instance = init_chat_model(**model_kwargs)

    # Custom handler that routes callback events into the app logger.
    logging_handler = LoggingCallbackHandler()

    # Middleware pipeline: guidelines first, then tool-output length control.
    middleware = [GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)]
    tool_output_middleware = ToolOutputLengthMiddleware(
        max_length=_cfg('tool_output_max_length') or TOOL_OUTPUT_MAX_LENGTH,
        truncation_strategy=_cfg('tool_output_truncation_strategy', 'smart'),
        tool_filters=_cfg('tool_output_filters'),  # optionally limit to specific tools
        exclude_tools=_cfg('tool_output_exclude', []),  # tools exempt from truncation
        preserve_code_blocks=_cfg('preserve_code_blocks', True),
        preserve_json=_cfg('preserve_json', True)
    )
    middleware.append(tool_output_middleware)

    # Persistence + summarization are only enabled for stateful sessions.
    checkpointer = None
    if session_id:
        checkpointer = MemorySaver()
        summarization_middleware = SummarizationMiddleware(
            model=llm_instance,
            max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
            messages_to_keep=20,  # keep the 20 most recent messages after summarizing
            summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。"
        )
        middleware.append(summarization_middleware)

    agent = create_agent(
        model=llm_instance,
        system_prompt=system,
        tools=mcp_tools,
        middleware=middleware,
        checkpointer=checkpointer  # enables persistence when a session exists
    )
    # Stash handler / ids on the agent so callers can reach them at invoke time.
    agent.logging_handler = logging_handler
    agent.checkpointer = checkpointer
    agent.bot_id = bot_id
    agent.session_id = session_id
    return agent