diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py index af16fe2..b596755 100644 --- a/agent/deep_assistant.py +++ b/agent/deep_assistant.py @@ -1,13 +1,84 @@ import json +import logging +from typing import Any, Dict, Optional from langchain.chat_models import init_chat_model # from deepagents import create_deep_agent from langchain.agents import create_agent from langchain_mcp_adapters.client import MultiServerMCPClient +from langchain_core.callbacks import BaseCallbackHandler from utils.fastapi_utils import detect_provider from .guideline_middleware import GuidelineMiddleware +class LoggingCallbackHandler(BaseCallbackHandler): + """自定义的 CallbackHandler,使用项目的 logger 来记录日志""" + + def __init__(self, logger_name: str = 'app'): + self.logger = logging.getLogger(logger_name) + + # def on_llm_start( + # self, serialized: Optional[Dict[str, Any]], prompts: list[str], **kwargs: Any + # ) -> None: + # """当 LLM 开始时调用""" + # self.logger.info("🤖 LLM Start - Input Messages:") + # if prompts: + # for i, prompt in enumerate(prompts): + # self.logger.info(f" Message {i+1}:\n{prompt}") + # else: + # self.logger.info(" No prompts") + + def on_llm_end(self, response, **kwargs: Any) -> None: + """当 LLM 结束时调用""" + self.logger.info("✅ LLM End - Output:") + + # 打印生成的文本 + if hasattr(response, 'generations') and response.generations: + for gen_idx, generation_list in enumerate(response.generations): + for msg_idx, generation in enumerate(generation_list): + if hasattr(generation, 'text'): + output_list = generation.text.split("\n") + for i, output in enumerate(output_list): + if output.strip(): + self.logger.info(f"{output}") + elif hasattr(generation, 'message'): + output_list = generation.message.split("\n") + for i, output in enumerate(output_list): + if output.strip(): + self.logger.info(f"{output}") + + def on_llm_error( + self, error: Exception, **kwargs: Any + ) -> None: + """当 LLM 出错时调用""" + self.logger.error(f"❌ LLM Error: {error}") + + def on_tool_start( + self, serialized: Optional[Dict[str, Any]], input_str: str, **kwargs: Any + ) -> None: + """当工具开始调用时调用""" + if serialized is None: + tool_name = 'unknown_tool' + else: + tool_name = serialized.get('name', 'unknown_tool') + self.logger.info(f"🔧 Tool Start - {tool_name} with input: {str(input_str)[:100]}") + + def on_tool_end(self, output: str, **kwargs: Any) -> None: + """当工具调用结束时调用""" + self.logger.info(f"✅ Tool End Output: {output}") + + def on_tool_error( + self, error: Exception, **kwargs: Any + ) -> None: + """当工具调用出错时调用""" + self.logger.error(f"❌ Tool Error: {error}") + + + def on_agent_action(self, action, **kwargs: Any) -> None: + """当 Agent 执行动作时调用""" + self.logger.info(f"🎯 Agent Action: {action.log}") + + # Utility functions def read_system_prompt(): """读取通用的无状态系统prompt""" @@ -56,7 +127,7 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None, # 检测或使用指定的提供商 model_provider,base_url = detect_provider(model_name,model_server) - + # 构建模型参数 model_kwargs = { "model": model_name, @@ -69,10 +140,17 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None, model_kwargs.update(generate_cfg) llm_instance = init_chat_model(**model_kwargs) + # 创建自定义的日志处理器 + logging_handler = LoggingCallbackHandler() + agent = create_agent( model=llm_instance, system_prompt=system, tools=mcp_tools, middleware=[GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)] ) + + # 将 handler 存储在 agent 的属性中,方便在调用时使用 + agent.logging_handler = logging_handler + return agent diff --git a/routes/chat.py b/routes/chat.py index 1e25bca..b97bd58 100644 --- a/routes/chat.py +++ b/routes/chat.py @@ -141,7 +141,8 @@ async def enhanced_generate_stream_response( chunk_id = 0 message_tag = "" - async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages"): + config = {"callbacks": [agent.logging_handler]} if hasattr(agent, 'logging_handler') else {} + async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config): new_content = "" if isinstance(msg, AIMessageChunk): @@ -311,7 +312,8 @@ async def create_agent_and_generate_response( final_messages = messages.copy() # 非流式响应 - agent_responses = await agent.ainvoke({"messages": final_messages}) + config = {"callbacks": [agent.logging_handler]} if hasattr(agent, 'logging_handler') else {} + agent_responses = await agent.ainvoke({"messages": final_messages}, config=config) append_messages = agent_responses["messages"][len(final_messages):] response_text = "" for msg in append_messages: