193 lines
7.7 KiB
Python
193 lines
7.7 KiB
Python
import json
|
||
import logging
|
||
import os
|
||
import sqlite3
|
||
from typing import Any, Dict, Optional, List
|
||
from langchain.chat_models import init_chat_model
|
||
# from deepagents import create_deep_agent
|
||
from langchain.agents import create_agent
|
||
from langchain.agents.middleware import SummarizationMiddleware
|
||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||
from langchain_core.callbacks import BaseCallbackHandler
|
||
from langgraph.checkpoint.memory import MemorySaver
|
||
from utils.fastapi_utils import detect_provider
|
||
|
||
from .guideline_middleware import GuidelineMiddleware
|
||
from .tool_output_length_middleware import ToolOutputLengthMiddleware
|
||
from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
|
||
from utils.agent_config import AgentConfig
|
||
|
||
|
||
class LoggingCallbackHandler(BaseCallbackHandler):
    """Callback handler that writes LLM, tool and agent events to a project logger."""

    def __init__(self, logger_name: str = 'app'):
        """Bind the handler to the logger named ``logger_name``."""
        self.logger = logging.getLogger(logger_name)

    def _log_non_blank_lines(self, text: str) -> None:
        """Log each line of ``text`` that is not whitespace-only, one record per line."""
        for line in text.split("\n"):
            if line.strip():
                self.logger.info(line)

    def on_llm_end(self, response, **kwargs: Any) -> None:
        """Log the generated output when an LLM call finishes.

        Args:
            response: The LLM result. Only its ``generations`` attribute
                (a list of lists of generation objects) is read.
            **kwargs: Extra callback arguments; ignored.
        """
        self.logger.info("✅ LLM End - Output:")

        if not (hasattr(response, 'generations') and response.generations):
            return

        for generation_list in response.generations:
            for generation in generation_list:
                if hasattr(generation, 'text'):
                    self._log_non_blank_lines(generation.text)
                elif hasattr(generation, 'message'):
                    # ``message`` is normally a message object rather than a
                    # string, so read its ``content``; fall back to the value
                    # itself in case a plain string was supplied.
                    message = generation.message
                    content = getattr(message, 'content', message)
                    if not isinstance(content, str):
                        # Content may be structured (e.g. a list of parts).
                        content = str(content)
                    self._log_non_blank_lines(content)

    def on_llm_error(
        self, error: Exception, **kwargs: Any
    ) -> None:
        """Log an error raised by the LLM call."""
        self.logger.error(f"❌ LLM Error: {error}")

    def on_tool_start(
        self, serialized: Optional[Dict[str, Any]], input_str: str, **kwargs: Any
    ) -> None:
        """Log the tool name and the first 100 characters of its input.

        Args:
            serialized: Serialized tool description; its ``name`` entry is
                used when present. May be ``None``.
            input_str: The tool input; converted with ``str()`` before
                truncation.
            **kwargs: Extra callback arguments; ignored.
        """
        if serialized is None:
            tool_name = 'unknown_tool'
        else:
            tool_name = serialized.get('name', 'unknown_tool')
        self.logger.info(f"🔧 Tool Start - {tool_name} with input: {str(input_str)[:100]}")

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Log the complete output returned by a tool."""
        self.logger.info(f"✅ Tool End Output: {output}")

    def on_tool_error(
        self, error: Exception, **kwargs: Any
    ) -> None:
        """Log an error raised by a tool call."""
        self.logger.error(f"❌ Tool Error: {error}")

    def on_agent_action(self, action, **kwargs: Any) -> None:
        """Log the ``log`` attribute of the action the agent is taking."""
        self.logger.info(f"🎯 Agent Action: {action.log}")
|
||
|
||
|
||
# Utility functions
|
||
def read_system_prompt():
    """Return the default, stateless system prompt with surrounding whitespace removed.

    The prompt is read as UTF-8 from ``./prompt/system_prompt_default.md``,
    resolved against the current working directory.
    """
    prompt_path = "./prompt/system_prompt_default.md"
    with open(prompt_path, "r", encoding="utf-8") as prompt_file:
        raw_prompt = prompt_file.read()
    return raw_prompt.strip()
|
||
|
||
|
||
def read_mcp_settings():
    """Load the default MCP tool configuration.

    Reads ``./mcp/mcp_settings.json`` relative to the current working
    directory. The file is decoded as UTF-8 explicitly so that non-ASCII
    values are parsed the same way on every platform instead of depending
    on the locale's default encoding.

    Returns:
        The parsed JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open("./mcp/mcp_settings.json", "r", encoding="utf-8") as f:
        return json.load(f)
|
||
|
||
|
||
async def get_tools_from_mcp(mcp):
    """Load tools from the MCP servers described by ``mcp``.

    Args:
        mcp: Expected to be a non-empty list whose first element is a dict
            with an ``"mcpServers"`` key mapping server names to connection
            configs (dicts). Only the first element is used.

    Returns:
        The tools returned by ``MultiServerMCPClient.get_tools()``, or an
        empty list when the configuration is malformed or the client fails.
        Failures are logged rather than raised so callers can continue
        without MCP tools.
    """
    logger = logging.getLogger('app')

    # Validate the whole shape up front, before touching any nested value.
    if not isinstance(mcp, list) or not mcp:
        return []
    first_entry = mcp[0]
    if not isinstance(first_entry, dict) or "mcpServers" not in first_entry:
        return []
    servers = first_entry["mcpServers"]
    if not isinstance(servers, dict):
        return []

    # Build normalized copies instead of editing the caller's config in place.
    connections = {}
    for server_name, server_cfg in servers.items():
        if not isinstance(server_cfg, dict):
            logger.warning("Skipping MCP server %r: config is not a dict", server_name)
            continue
        # The "type" field is dropped; when no explicit "transport" is given,
        # it defaults to "http" if a "url" is present and to "stdio" otherwise.
        # NOTE(review): the value of "type" is discarded, not mapped onto
        # "transport" - confirm this is intended for e.g. "sse" servers.
        connection = {key: value for key, value in server_cfg.items() if key != "type"}
        if "transport" not in connection:
            connection["transport"] = "http" if "url" in connection else "stdio"
        connections[server_name] = connection

    try:
        mcp_client = MultiServerMCPClient(connections)
        return await mcp_client.get_tools()
    except Exception:
        # Best effort: the agent can still run without MCP tools, but the
        # failure is recorded so it is diagnosable.
        logger.exception("Failed to load MCP tools; continuing without them")
        return []
|
||
|
||
|
||
# Keys of ``generate_cfg`` that configure ToolOutputLengthMiddleware rather
# than the chat model; they are kept out of the model keyword arguments.
_TOOL_OUTPUT_CFG_KEYS = (
    'tool_output_max_length',
    'tool_output_truncation_strategy',
    'tool_output_filters',
    'tool_output_exclude',
    'preserve_code_blocks',
    'preserve_json',
)


def _generate_cfg_option(generate_cfg, key, default=None):
    """Return option ``key`` from ``generate_cfg`` (dict key or attribute), else ``default``."""
    if not generate_cfg:
        return default
    if isinstance(generate_cfg, dict) and key in generate_cfg:
        return generate_cfg[key]
    return getattr(generate_cfg, key, default)


async def init_agent(config: AgentConfig):
    """Build an agent with MCP tools, optional guideline and summarization middleware.

    Args:
        config: Agent settings. The fields read here are ``mcp_settings``,
            ``system_prompt``, ``model_name``, ``model_server``, ``api_key``,
            ``generate_cfg``, ``enable_thinking``, ``bot_id``, ``robot_type``,
            ``language``, ``user_identifier`` and ``session_id``. When
            ``mcp_settings`` or ``system_prompt`` is empty, the defaults from
            ``read_mcp_settings()`` / ``read_system_prompt()`` are used.

    Returns:
        The agent returned by ``create_agent``, with ``logging_handler``,
        ``checkpointer``, ``bot_id`` and ``session_id`` attached as
        attributes for use at invocation time.
    """
    # Fall back to the on-disk defaults when the config leaves these empty.
    mcp = config.mcp_settings if config.mcp_settings else read_mcp_settings()
    system = config.system_prompt if config.system_prompt else read_system_prompt()
    mcp_tools = await get_tools_from_mcp(mcp)

    # Resolve the provider and base URL from the model name / server.
    model_provider, base_url = detect_provider(config.model_name, config.model_server)

    # Base model arguments; entries in generate_cfg override them.
    model_kwargs = {
        "model": config.model_name,
        "model_provider": model_provider,
        "temperature": 0.8,
        "base_url": base_url,
        "api_key": config.api_key,
    }
    generate_cfg = config.generate_cfg
    if generate_cfg:
        # Middleware-only options must not be forwarded to the chat model.
        model_kwargs.update(
            (key, value)
            for key, value in dict(generate_cfg).items()
            if key not in _TOOL_OUTPUT_CFG_KEYS
        )
    llm_instance = init_chat_model(**model_kwargs)

    # Callback handler that logs through the project logger.
    logging_handler = LoggingCallbackHandler()

    middleware = []

    # GuidelineMiddleware is only added when enable_thinking is truthy.
    if config.enable_thinking:
        middleware.append(
            GuidelineMiddleware(
                config.bot_id,
                llm_instance,
                system,
                config.robot_type,
                config.language,
                config.user_identifier,
            )
        )

    # Tool-output length control. A missing or falsy max length falls back to
    # the project-wide default regardless of whether generate_cfg is set.
    max_length = (
        _generate_cfg_option(generate_cfg, 'tool_output_max_length')
        or TOOL_OUTPUT_MAX_LENGTH
    )
    tool_output_middleware = ToolOutputLengthMiddleware(
        max_length=max_length,
        truncation_strategy=_generate_cfg_option(
            generate_cfg, 'tool_output_truncation_strategy', 'smart'
        ),
        # Optional restriction to specific tools.
        tool_filters=_generate_cfg_option(generate_cfg, 'tool_output_filters', None),
        # Tools excluded from length control.
        exclude_tools=_generate_cfg_option(generate_cfg, 'tool_output_exclude', []),
        preserve_code_blocks=_generate_cfg_option(generate_cfg, 'preserve_code_blocks', True),
        preserve_json=_generate_cfg_option(generate_cfg, 'preserve_json', True),
    )
    middleware.append(tool_output_middleware)

    # A checkpointer and conversation summarization are only enabled when a
    # session id is present. A fresh MemorySaver is created on every call.
    checkpointer = None
    if config.session_id:
        checkpointer = MemorySaver()
        summarization_middleware = SummarizationMiddleware(
            model=llm_instance,
            max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
            messages_to_keep=20,  # keep the 20 most recent messages after summarizing
            summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。"
        )
        middleware.append(summarization_middleware)

    agent = create_agent(
        model=llm_instance,
        system_prompt=system,
        tools=mcp_tools,
        middleware=middleware,
        checkpointer=checkpointer,
    )

    # Expose the handler, checkpointer and identifiers on the agent so that
    # callers can use them when invoking it.
    agent.logging_handler = logging_handler
    agent.checkpointer = checkpointer
    agent.bot_id = config.bot_id
    agent.session_id = config.session_id
    return agent