From e36787fb63d7261e8b4814318742ead0be1e55a2 Mon Sep 17 00:00:00 2001
From: 朱潮
Date: Tue, 16 Dec 2025 21:26:20 +0800
Subject: [PATCH] Modify agent_config
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 {utils => agent}/agent_config.py | 47 +++++++++++++++----
 agent/deep_assistant.py          | 78 +++-----------------------------
 agent/guideline_middleware.py    | 13 +++--
 agent/logging_handler.py         | 57 +++++++++++++++++++++++
 agent/sharded_agent_manager.py   |  2 +-
 routes/chat.py                   | 39 ++++------------
 utils/fastapi_utils.py           |  8 ++--
 7 files changed, 122 insertions(+), 122 deletions(-)
 rename {utils => agent}/agent_config.py (72%)
 create mode 100644 agent/logging_handler.py

diff --git a/utils/agent_config.py b/agent/agent_config.py
similarity index 72%
rename from utils/agent_config.py
rename to agent/agent_config.py
index 05e439a..edee4e1 100644
--- a/utils/agent_config.py
+++ b/agent/agent_config.py
@@ -1,7 +1,7 @@
 """Agent configuration class for managing all agent-related parameters"""
-from typing import Optional, List, Dict, Any
-from dataclasses import dataclass
+from typing import Optional, List, Dict, Any, TYPE_CHECKING
+from dataclasses import dataclass, field
 import logging
 import json


@@ -20,7 +20,7 @@ class AgentConfig:

     # Configuration parameters
     system_prompt: Optional[str] = None
-    mcp_settings: Optional[List[Dict]] = None
+    mcp_settings: Optional[List[Dict]] = field(default_factory=list)
     robot_type: Optional[str] = "general_agent"
     generate_cfg: Optional[Dict] = None
     enable_thinking: bool = False
@@ -34,6 +34,9 @@ class AgentConfig:
     stream: bool = False
     tool_response: bool = True
     preamble_text: Optional[str] = None
+    messages: Optional[List] = field(default_factory=list)
+
+    logging_handler: Optional['LoggingCallbackHandler'] = None

     def to_dict(self) -> Dict[str, Any]:
         """Convert to a dict, for passing to functions that take **kwargs"""
@@ -53,7 +56,8 @@ class AgentConfig:
             'session_id': self.session_id,
             'stream': self.stream,
             'tool_response': self.tool_response,
-            'preamble_text': self.preamble_text
+            'preamble_text': self.preamble_text,
+            'messages': self.messages,
         }

     def safe_print(self):
@@ -64,8 +68,14 @@ class AgentConfig:
         logger.info(f"config={json.dumps(safe_dict, ensure_ascii=False)}")

     @classmethod
-    def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None):
+    def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None, messages: Optional[List] = None):
         """Create a config from a v1 request"""
+        # Lazy import to avoid a circular dependency
+        from .logging_handler import LoggingCallbackHandler
+
+        if messages is None:
+            messages = []
+
         return cls(
             bot_id=request.bot_id,
             api_key=api_key,
@@ -81,12 +91,20 @@ class AgentConfig:
             project_dir=project_dir,
             stream=request.stream,
             tool_response=request.tool_response,
-            generate_cfg=generate_cfg
+            generate_cfg=generate_cfg,
+            logging_handler=LoggingCallbackHandler(),
+            messages=messages
         )

     @classmethod
-    def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None):
+    def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None):
         """Create a config from a v2 request"""
+        # Lazy import to avoid a circular dependency
+        from .logging_handler import LoggingCallbackHandler
+
+        if messages is None:
+            messages = []
+
         return cls(
             bot_id=request.bot_id,
             api_key=bot_config.get("api_key"),
@@ -102,5 +120,16 @@ class AgentConfig:
             project_dir=project_dir,
             stream=request.stream,
             tool_response=request.tool_response,
-            generate_cfg={}  # The v2 API does not pass an extra generate_cfg
-        )
\ No newline at end of file
+            generate_cfg={},  # The v2 API does not pass an extra generate_cfg
+            logging_handler=LoggingCallbackHandler(),
+            messages=messages
+        )
+
+    def invoke_config(self):
+        """Return the config dict that LangChain expects"""
+        config = {}
+        if self.logging_handler:
+            config["callbacks"] = [self.logging_handler]
+        if self.session_id:
+            config["configurable"] = {"thread_id": self.session_id}
+        return config
\ No newline at end of file
diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py
index 4e1481f..5cac35e 100644
--- a/agent/deep_assistant.py
+++ b/agent/deep_assistant.py
@@ -8,70 +8,13 @@ from langchain.chat_models import init_chat_model
 from langchain.agents import create_agent
 from langchain.agents.middleware import SummarizationMiddleware
 from langchain_mcp_adapters.client import MultiServerMCPClient
-from langchain_core.callbacks import BaseCallbackHandler
 from langgraph.checkpoint.memory import MemorySaver

 from utils.fastapi_utils import detect_provider
 from .guideline_middleware import GuidelineMiddleware
 from .tool_output_length_middleware import ToolOutputLengthMiddleware
 from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
-from utils.agent_config import AgentConfig
-
-
-class LoggingCallbackHandler(BaseCallbackHandler):
-    """Custom CallbackHandler that logs via the project's logger"""
-
-    def __init__(self, logger_name: str = 'app'):
-        self.logger = logging.getLogger(logger_name)
-
-    def on_llm_end(self, response, **kwargs: Any) -> None:
-        """Called when the LLM finishes"""
-        self.logger.info("✅ LLM End - Output:")
-
-        # Log the generated text
-        if hasattr(response, 'generations') and response.generations:
-            for gen_idx, generation_list in enumerate(response.generations):
-                for msg_idx, generation in enumerate(generation_list):
-                    if hasattr(generation, 'text'):
-                        output_list = generation.text.split("\n")
-                        for i, output in enumerate(output_list):
-                            if output.strip():
-                                self.logger.info(f"{output}")
-                    elif hasattr(generation, 'message'):
-                        output_list = generation.message.split("\n")
-                        for i, output in enumerate(output_list):
-                            if output.strip():
-                                self.logger.info(f"{output}")
-
-    def on_llm_error(
-        self, error: Exception, **kwargs: Any
-    ) -> None:
-        """Called when the LLM errors"""
-        self.logger.error(f"❌ LLM Error: {error}")
-
-    def on_tool_start(
-        self, serialized: Optional[Dict[str, Any]], input_str: str, **kwargs: Any
-    ) -> None:
-        """Called when a tool starts running"""
-        if serialized is None:
-            tool_name = 'unknown_tool'
-        else:
-            tool_name = serialized.get('name', 'unknown_tool')
-        self.logger.info(f"🔧 Tool Start - {tool_name} with input: {str(input_str)[:100]}")
-
-    def on_tool_end(self, output: str, **kwargs: Any) -> None:
-        """Called when a tool call finishes"""
-        self.logger.info(f"✅ Tool End Output: {output}")
-
-    def on_tool_error(
-        self, error: Exception, **kwargs: Any
-    ) -> None:
-        """Called when a tool call errors"""
-        self.logger.error(f"❌ Tool Error: {error}")
-
-    def on_agent_action(self, action, **kwargs: Any) -> None:
-        """Called when the agent takes an action"""
-        self.logger.info(f"🎯 Agent Action: {action.log}")
+from agent.agent_config import AgentConfig


 # Utility functions
@@ -124,9 +67,9 @@ async def init_agent(config: AgentConfig):
         mcp: MCP configuration (if None, fall back to mcp_settings in the config)
     """
     # If no mcp is provided, use mcp_settings from the config
-    mcp = config.mcp_settings if config.mcp_settings else read_mcp_settings()
-    system = config.system_prompt if config.system_prompt else read_system_prompt()
-    mcp_tools = await get_tools_from_mcp(mcp)
+    mcp_settings = config.mcp_settings if config.mcp_settings else read_mcp_settings()
+    system_prompt = config.system_prompt if config.system_prompt else read_system_prompt()
+    mcp_tools = await get_tools_from_mcp(mcp_settings)

     # Detect the provider, or use the one specified
     model_provider,base_url = detect_provider(config.model_name, config.model_server)
@@ -143,15 +86,11 @@ async def init_agent(config: AgentConfig):
         model_kwargs.update(config.generate_cfg)
     llm_instance = init_chat_model(**model_kwargs)

-    # Create the custom logging handler
-    logging_handler = LoggingCallbackHandler()
-
     # Build the middleware list
     middleware = []
-    # Only add GuidelineMiddleware when enable_thinking is True
     if config.enable_thinking:
-        middleware.append(GuidelineMiddleware(config.bot_id, llm_instance, system, config.robot_type, config.language, config.user_identifier))
+        middleware.append(GuidelineMiddleware(llm_instance, config, system_prompt))

     # Add the tool-output length control middleware
     tool_output_middleware = ToolOutputLengthMiddleware(
@@ -179,15 +118,10 @@ async def init_agent(config: AgentConfig):

     agent = create_agent(
         model=llm_instance,
-        system_prompt=system,
+        system_prompt=system_prompt,
         tools=mcp_tools,
         middleware=middleware,
         checkpointer=checkpointer  # Pass the checkpointer to enable persistence
     )

-    # Store the handler and checkpointer on the agent for convenient access at call time
-    agent.logging_handler = logging_handler
-    agent.checkpointer = checkpointer
-    agent.bot_id = config.bot_id
-    agent.session_id = config.session_id
     return agent
\ No newline at end of file
diff --git a/agent/guideline_middleware.py b/agent/guideline_middleware.py
index 36df44c..e76edea 100644
--- a/agent/guideline_middleware.py
+++ b/agent/guideline_middleware.py
@@ -9,15 +9,18 @@
 from langchain_core.messages import SystemMessage
 from typing import Any, Callable
 from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.outputs import LLMResult
+from .agent_config import AgentConfig
 import logging
 import re

 logger = logging.getLogger('app')

+
 class GuidelineMiddleware(AgentMiddleware):
-    def __init__(self, bot_id: str, model:BaseChatModel, prompt: str, robot_type: str, language: str, user_identifier: str):
+    def __init__(self, model: BaseChatModel, config: AgentConfig, prompt: str):
         self.model = model
-        self.bot_id = bot_id
+        self.bot_id = config.bot_id
+
         processed_system_prompt, guidelines, tool_description, scenarios, terms_list = extract_block_from_system_prompt(prompt)

@@ -25,10 +28,10 @@ class GuidelineMiddleware(AgentMiddleware):
         self.tool_description = tool_description
         self.scenarios = scenarios

-        self.language = language
-        self.user_identifier = user_identifier
+        self.language = config.language
+        self.user_identifier = config.user_identifier

-        self.robot_type = robot_type
+        self.robot_type = config.robot_type
         self.terms_list = terms_list

         if self.robot_type == "general_agent":
diff --git a/agent/logging_handler.py b/agent/logging_handler.py
new file mode 100644
index 0000000..96eea99
--- /dev/null
+++ b/agent/logging_handler.py
@@ -0,0 +1,57 @@
+"""Logging callback handler module"""
+
+import logging
+from typing import Any, Optional, Dict
+from langchain_core.callbacks import BaseCallbackHandler
+
+
+class LoggingCallbackHandler(BaseCallbackHandler):
+    """Custom CallbackHandler that logs via the project's logger"""
+
+    def __init__(self, logger_name: str = 'app'):
+        self.logger = logging.getLogger(logger_name)
+
+    def on_llm_end(self, response, **kwargs: Any) -> None:
+        """Called when the LLM finishes"""
+        self.logger.info("✅ LLM End - Output:")
+
+        # Log the generated text
+        if hasattr(response, 'generations') and response.generations:
+            for gen_idx, generation_list in enumerate(response.generations):
+                for msg_idx, generation in enumerate(generation_list):
+                    if hasattr(generation, 'text'):
+                        output_list = generation.text.split("\n")
+                        for i, output in enumerate(output_list):
+                            if output.strip():
+                                self.logger.info(f"{output}")
+                    elif hasattr(generation, 'message'):
+                        output_list = generation.message.split("\n")
+                        for i, output in enumerate(output_list):
+                            if output.strip():
+                                self.logger.info(f"{output}")
+
+    def on_llm_error(
+        self, error: Exception, **kwargs: Any
+    ) -> None:
+        """Called when the LLM errors"""
+        self.logger.error(f"❌ LLM Error: {error}")
+
+    def on_tool_start(
+        self, serialized: Optional[Dict[str, Any]], input_str: str, **kwargs: Any
+    ) -> None:
+        """Called when a tool starts running"""
+        if serialized is None:
+            tool_name = 'unknown_tool'
+        else:
+            tool_name = serialized.get('name', 'unknown_tool')
+        self.logger.info(f"🔧 Tool Start - {tool_name} with input: {str(input_str)[:100]}")
+
+    def on_tool_end(self, output: str, **kwargs: Any) -> None:
+        """Called when a tool call finishes"""
+        self.logger.info(f"✅ Tool End Output: {output}")
+
+    def on_tool_error(
+        self, error: Exception, **kwargs: Any
+    ) -> None:
+        """Called when a tool call errors"""
+        self.logger.error(f"❌ Tool Error: {error}")
\ No newline at end of file
diff --git a/agent/sharded_agent_manager.py b/agent/sharded_agent_manager.py
index 09b489c..0ae424f 100644
--- a/agent/sharded_agent_manager.py
+++ b/agent/sharded_agent_manager.py
@@ -26,7 +26,7 @@ logger = logging.getLogger('app')

 from agent.deep_assistant import init_agent
 from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
-from utils.agent_config import AgentConfig
+from agent.agent_config import AgentConfig


 class ShardedAgentManager:
diff --git a/routes/chat.py b/routes/chat.py
index 70aa40e..709622a 100644
--- a/routes/chat.py
+++ b/routes/chat.py
@@ -21,7 +21,7 @@ from utils.fastapi_utils import (
 )
 from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
 from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT
-from utils.agent_config import AgentConfig
+from agent.agent_config import AgentConfig

 router = APIRouter()

@@ -98,14 +98,12 @@ def format_messages_to_chat_history(messages: list) -> str:

 async def enhanced_generate_stream_response(
     agent_manager,
-    messages: list,
     config: AgentConfig
 ):
     """Enhanced progressive streaming response generator - concurrency-optimized version

     Args:
         agent_manager: the agent manager
-        messages: the list of messages
         config: AgentConfig object that carries all parameters
     """
     try:
@@ -116,7 +114,7 @@
         # Preamble task
         async def preamble_task():
             try:
-                preamble_result = await call_preamble_llm(messages,config)
+                preamble_result = await call_preamble_llm(config)
                 # Only output when preamble_text is non-empty and not ""
                 if preamble_result and preamble_result.strip() and preamble_result != "":
                     preamble_content = f"[PREAMBLE]\n{preamble_result}\n"
@@ -147,12 +145,7 @@
         chunk_id = 0
         message_tag = ""

-        stream_config = {}
-        if config.session_id:
-            stream_config["configurable"] = {"thread_id": config.session_id}
-        if hasattr(agent, 'logging_handler'):
-            stream_config["callbacks"] = [agent.logging_handler]
-        async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=stream_config, max_tokens=MAX_OUTPUT_TOKENS):
+        async for msg, metadata in agent.astream({"messages": config.messages}, stream_mode="messages", config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS):
             new_content = ""

             if isinstance(msg, AIMessageChunk):
@@ -270,17 +263,14 @@ async def enhanced_generate_stream_response(


 async def create_agent_and_generate_response(
-    messages: list,
     config: AgentConfig
 ) -> Union[ChatResponse, StreamingResponse]:
     """Shared logic for creating the agent and generating the response

     Args:
-        messages: the list of messages
         config: AgentConfig object that carries all parameters
     """
     config.safe_print()
-    logger.info(f"messages={json.dumps(messages, ensure_ascii=False)}")
     config.preamble_text, config.system_prompt = get_preamble_text(config.language, config.system_prompt)

     # In streaming mode, use the enhanced streaming response generator
@@ -288,28 +278,17 @@
         return StreamingResponse(
             enhanced_generate_stream_response(
                 agent_manager=agent_manager,
-                messages=messages,
                 config=config
             ),
             media_type="text/event-stream",
             headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
         )
-
+    messages = config.messages
     # Use the shared helper to handle all the logic
     agent = await agent_manager.get_or_create_agent(config)
-
-    # Prepare the final messages
-    final_messages = messages.copy()
-
-    # Non-streaming response
-    agent_config = {}
-    if config.session_id:
-        agent_config["configurable"] = {"thread_id": config.session_id}
-    if hasattr(agent, 'logging_handler'):
-        agent_config["callbacks"] = [agent.logging_handler]
-    agent_responses = await agent.ainvoke({"messages": final_messages}, config=agent_config, max_tokens=MAX_OUTPUT_TOKENS)
-    append_messages = agent_responses["messages"][len(final_messages):]
+    agent_responses = await agent.ainvoke({"messages": messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
+    append_messages = agent_responses["messages"][len(messages):]
     response_text = ""
     for msg in append_messages:
         if isinstance(msg,AIMessage):
@@ -394,10 +373,9 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =

     # Process the messages
     messages = process_messages(request.messages, request.language)
     # Create the AgentConfig object
-    config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg)
+    config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg, messages)
     # Invoke the shared agent-creation and response-generation logic
     return await create_agent_and_generate_response(
-        messages=messages,
         config=config
     )
@@ -471,10 +449,9 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st

     # Process the messages
     messages = process_messages(request.messages, request.language)
     # Create the AgentConfig object
-    config = AgentConfig.from_v2_request(request, bot_config, project_dir)
+    config = AgentConfig.from_v2_request(request, bot_config, project_dir, messages)
     # Invoke the shared agent-creation and response-generation logic
     return await create_agent_and_generate_response(
-        messages=messages,
         config=config
     )
diff --git a/utils/fastapi_utils.py b/utils/fastapi_utils.py
index f5de94a..282ad51 100644
--- a/utils/fastapi_utils.py
+++ b/utils/fastapi_utils.py
@@ -11,7 +11,7 @@ import logging
 from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
 from langchain.chat_models import init_chat_model
 from utils.settings import MASTERKEY, BACKEND_HOST
-from utils.agent_config import AgentConfig
+from agent.agent_config import AgentConfig

 USER = "user"
 ASSISTANT = "assistant"
@@ -561,7 +561,7 @@ def get_preamble_text(language: str, system_prompt: str):
     return default_preamble, system_prompt  # Return the default preamble and the original system_prompt


-async def call_preamble_llm(messages: list, config: AgentConfig) -> str:
+async def call_preamble_llm(config: AgentConfig) -> str:
     """Call the LLM to handle the guideline analysis

     Args:
@@ -587,8 +587,8 @@ ...
     model_server = config.model_server
     language = config.language
     preamble_choices_text = config.preamble_text
-    last_message = get_user_last_message_content(messages)
-    chat_history = format_messages_to_chat_history(messages)
+    last_message = get_user_last_message_content(config.messages)
+    chat_history = format_messages_to_chat_history(config.messages)

     # Replace the placeholders in the template
     system_prompt = preamble_template.replace('{preamble_choices_text}', preamble_choices_text).replace('{chat_history}', chat_history).replace('{last_message}', last_message).replace('{language}', get_language_text(language))
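
Reviewer note: below is a minimal usage sketch (not part of the patch) of the refactored flow. It constructs an AgentConfig directly, whereas the real entry points are the from_v1_request / from_v2_request classmethods, then builds an agent with init_agent and passes config.invoke_config() to ainvoke. invoke_config() centralizes the callbacks and thread_id plumbing that routes/chat.py previously rebuilt by hand at each call site. All field values are placeholders, and any AgentConfig fields not shown in the diff are assumed to be optional with defaults.

    import asyncio

    from agent.agent_config import AgentConfig
    from agent.deep_assistant import init_agent
    from agent.logging_handler import LoggingCallbackHandler


    async def demo():
        # Placeholder values; real configs come from from_v1_request/from_v2_request.
        config = AgentConfig(
            bot_id="bot-123",
            api_key="sk-...",
            model_name="qwen-plus",
            session_id="session-456",
            messages=[{"role": "user", "content": "hello"}],
            logging_handler=LoggingCallbackHandler(),
        )

        agent = await init_agent(config)

        # invoke_config() wires in the LoggingCallbackHandler as a callback and,
        # when session_id is set, pins the run to that checkpointer thread.
        result = await agent.ainvoke(
            {"messages": config.messages},
            config=config.invoke_config(),
        )
        print(result["messages"][-1].content)


    if __name__ == "__main__":
        asyncio.run(demo())

A side effect of this design worth noting in review: because both the streaming and non-streaming paths now call config.invoke_config(), the old per-call-site divergence (stream_config vs agent_config) can no longer drift apart.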