From 9525c0f883f5c79424eca290c48296201ebbac0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Tue, 16 Dec 2025 16:06:47 +0800 Subject: [PATCH] add AgentConfig --- agent/deep_assistant.py | 83 +++++----- agent/sharded_agent_manager.py | 65 +++----- routes/chat.py | 270 +++++++++++++-------------------- utils/agent_config.py | 97 ++++++++++++ utils/api_models.py | 2 + utils/fastapi_utils.py | 22 ++- 6 files changed, 282 insertions(+), 257 deletions(-) create mode 100644 utils/agent_config.py diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py index aa6a00c..4e1481f 100644 --- a/agent/deep_assistant.py +++ b/agent/deep_assistant.py @@ -2,7 +2,7 @@ import json import logging import os import sqlite3 -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List from langchain.chat_models import init_chat_model # from deepagents import create_deep_agent from langchain.agents import create_agent @@ -15,6 +15,8 @@ from utils.fastapi_utils import detect_provider from .guideline_middleware import GuidelineMiddleware from .tool_output_length_middleware import ToolOutputLengthMiddleware from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH +from utils.agent_config import AgentConfig + class LoggingCallbackHandler(BaseCallbackHandler): """自定义的 CallbackHandler,使用项目的 logger 来记录日志""" @@ -22,17 +24,6 @@ class LoggingCallbackHandler(BaseCallbackHandler): def __init__(self, logger_name: str = 'app'): self.logger = logging.getLogger(logger_name) - # def on_llm_start( - # self, serialized: Optional[Dict[str, Any]], prompts: list[str], **kwargs: Any - # ) -> None: - # """当 LLM 开始时调用""" - # self.logger.info("🤖 LLM Start - Input Messages:") - # if prompts: - # for i, prompt in enumerate(prompts): - # self.logger.info(f" Message {i+1}:\n{prompt}") - # else: - # self.logger.info(" No prompts") - def on_llm_end(self, response, **kwargs: Any) -> None: """当 LLM 结束时调用""" self.logger.info("✅ LLM End - Output:") @@ -78,7 +69,6 @@ class LoggingCallbackHandler(BaseCallbackHandler): """当工具调用出错时调用""" self.logger.error(f"❌ Tool Error: {error}") - def on_agent_action(self, action, **kwargs: Any) -> None: """当 Agent 执行动作时调用""" self.logger.info(f"🎯 Agent Action: {action.log}") @@ -97,6 +87,7 @@ def read_mcp_settings(): mcp_settings_json = json.load(f) return mcp_settings_json + async def get_tools_from_mcp(mcp): """从MCP配置中提取工具""" # 防御式处理:确保 mcp 是列表且长度大于 0,且包含 mcpServers @@ -123,75 +114,69 @@ async def get_tools_from_mcp(mcp): # 发生异常时返回空列表,避免上层调用报错 return [] -async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None, - model_server=None, generate_cfg=None, - system_prompt=None, mcp=None, robot_type=None, language="jp", user_identifier=None, - session_id=None): + +async def init_agent(config: AgentConfig): """ 初始化 Agent,支持持久化内存和对话摘要 Args: - bot_id: Bot ID - model_name: 模型名称 - api_key: API密钥 - model_server: 模型服务器地址 - generate_cfg: 生成配置 - system_prompt: 系统提示 - mcp: MCP配置 - robot_type: 机器人类型 - language: 语言 - user_identifier: 用户标识 - session_id: 会话ID(如果为None,则不启用持久化内存) + config: AgentConfig 对象,包含所有初始化参数 + mcp: MCP配置(如果为None则使用配置中的mcp_settings) """ - system = system_prompt if system_prompt else read_system_prompt() - mcp = mcp if mcp else read_mcp_settings() + # 如果没有提供mcp,使用config中的mcp_settings + mcp = config.mcp_settings if config.mcp_settings else read_mcp_settings() + system = config.system_prompt if config.system_prompt else read_system_prompt() mcp_tools = await get_tools_from_mcp(mcp) # 检测或使用指定的提供商 - model_provider,base_url = detect_provider(model_name,model_server) + model_provider,base_url = detect_provider(config.model_name, config.model_server) # 构建模型参数 model_kwargs = { - "model": model_name, + "model": config.model_name, "model_provider": model_provider, "temperature": 0.8, "base_url": base_url, - "api_key": api_key + "api_key": config.api_key } - if generate_cfg: - model_kwargs.update(generate_cfg) + if config.generate_cfg: + model_kwargs.update(config.generate_cfg) llm_instance = init_chat_model(**model_kwargs) # 创建自定义的日志处理器 logging_handler = LoggingCallbackHandler() # 构建中间件列表 - middleware = [GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)] + middleware = [] + + # 只有在 enable_thinking 为 True 时才添加 GuidelineMiddleware + if config.enable_thinking: + middleware.append(GuidelineMiddleware(config.bot_id, llm_instance, system, config.robot_type, config.language, config.user_identifier)) # 添加工具输出长度控制中间件 tool_output_middleware = ToolOutputLengthMiddleware( - max_length=getattr(generate_cfg, 'tool_output_max_length', None) or TOOL_OUTPUT_MAX_LENGTH, - truncation_strategy=getattr(generate_cfg, 'tool_output_truncation_strategy', 'smart'), - tool_filters=getattr(generate_cfg, 'tool_output_filters', None), # 可配置特定工具 - exclude_tools=getattr(generate_cfg, 'tool_output_exclude', []), # 排除的工具 - preserve_code_blocks=getattr(generate_cfg, 'preserve_code_blocks', True), - preserve_json=getattr(generate_cfg, 'preserve_json', True) + max_length=getattr(config.generate_cfg, 'tool_output_max_length', None) if config.generate_cfg else None or TOOL_OUTPUT_MAX_LENGTH, + truncation_strategy=getattr(config.generate_cfg, 'tool_output_truncation_strategy', 'smart') if config.generate_cfg else 'smart', + tool_filters=getattr(config.generate_cfg, 'tool_output_filters', None) if config.generate_cfg else None, # 可配置特定工具 + exclude_tools=getattr(config.generate_cfg, 'tool_output_exclude', []) if config.generate_cfg else [], # 排除的工具 + preserve_code_blocks=getattr(config.generate_cfg, 'preserve_code_blocks', True) if config.generate_cfg else True, + preserve_json=getattr(config.generate_cfg, 'preserve_json', True) if config.generate_cfg else True ) middleware.append(tool_output_middleware) # 初始化 checkpointer 和中间件 checkpointer = None - if session_id: + if config.session_id: checkpointer = MemorySaver() summarization_middleware = SummarizationMiddleware( - model=llm_instance, - max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS, + model=llm_instance, + max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS, messages_to_keep=20, # 摘要后保留最近 20 条消息 summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。" ) middleware.append(summarization_middleware) - + agent = create_agent( model=llm_instance, system_prompt=system, @@ -203,6 +188,6 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None, # 将 handler 和 checkpointer 存储在 agent 的属性中,方便在调用时使用 agent.logging_handler = logging_handler agent.checkpointer = checkpointer - agent.bot_id = bot_id - agent.session_id = session_id - return agent + agent.bot_id = config.bot_id + agent.session_id = config.session_id + return agent \ No newline at end of file diff --git a/agent/sharded_agent_manager.py b/agent/sharded_agent_manager.py index c71c7d7..09b489c 100644 --- a/agent/sharded_agent_manager.py +++ b/agent/sharded_agent_manager.py @@ -26,6 +26,7 @@ logger = logging.getLogger('app') from agent.deep_assistant import init_agent from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async +from utils.agent_config import AgentConfig class ShardedAgentManager: @@ -67,7 +68,8 @@ class ShardedAgentManager: def _get_cache_key(self, bot_id: str, model_name: str = None, api_key: str = None, model_server: str = None, generate_cfg: Dict = None, - system_prompt: str = None, mcp_settings: List[Dict] = None) -> str: + system_prompt: str = None, mcp_settings: List[Dict] = None, + enable_thinking: bool = False) -> str: """获取包含所有相关参数的哈希值作为缓存键""" cache_data = { 'bot_id': bot_id, @@ -76,7 +78,8 @@ class ShardedAgentManager: 'model_server': model_server or '', 'generate_cfg': json.dumps(generate_cfg or {}, sort_keys=True), 'system_prompt': system_prompt or '', - 'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True) + 'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True), + 'enable_thinking': enable_thinking } cache_str = json.dumps(cache_data, sort_keys=True) @@ -116,35 +119,29 @@ class ShardedAgentManager: if removed_count > 0: logger.info(f"分片已清理 {removed_count} 个过期的助手实例缓存") - async def get_or_create_agent(self, - bot_id: str, - project_dir: Optional[str], - model_name: str = "qwen3-next", - api_key: Optional[str] = None, - model_server: Optional[str] = None, - generate_cfg: Optional[Dict] = None, - language: Optional[str] = None, - system_prompt: Optional[str] = None, - mcp_settings: Optional[List[Dict]] = None, - robot_type: Optional[str] = "general_agent", - user_identifier: Optional[str] = None, - session_id: Optional[str] = None): - """获取或创建文件预加载的助手实例""" - + async def get_or_create_agent(self, config: AgentConfig): + """获取或创建文件预加载的助手实例 + + Args: + config: AgentConfig 对象,包含所有初始化参数 + """ + # 更新请求统计 with self._stats_lock: self._global_stats['total_requests'] += 1 - + # 异步加载配置文件(带缓存) final_system_prompt = await load_system_prompt_async( - project_dir, language, system_prompt, robot_type, bot_id, user_identifier + config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier ) final_mcp_settings = await load_mcp_settings_async( - project_dir, mcp_settings, bot_id, robot_type + config.project_dir, config.mcp_settings, config.bot_id, config.robot_type ) - - cache_key = self._get_cache_key(bot_id, model_name, api_key, model_server, - generate_cfg, final_system_prompt, final_mcp_settings) + config.system_prompt = final_system_prompt + config.mcp_settings = final_mcp_settings + cache_key = self._get_cache_key(config.bot_id, config.model_name, config.api_key, config.model_server, + config.generate_cfg, final_system_prompt, final_mcp_settings, + config.enable_thinking) # 获取分片 shard_index = self._get_shard_index(cache_key) @@ -160,7 +157,7 @@ class ShardedAgentManager: with self._stats_lock: self._global_stats['cache_hits'] += 1 - logger.info(f"分片复用现有的助手实例缓存: {cache_key} (bot_id: {bot_id}, shard: {shard_index})") + logger.info(f"分片复用现有的助手实例缓存: {cache_key} (bot_id: {config.bot_id}, shard: {shard_index})") return agent # 更新缓存未命中统计 @@ -188,27 +185,15 @@ class ShardedAgentManager: self._cleanup_old_agents(shard) # 创建新的助手实例 - logger.info(f"分片创建新的助手实例缓存: {cache_key}, bot_id: {bot_id}, shard: {shard_index}") + logger.info(f"分片创建新的助手实例缓存: {cache_key}, bot_id: {config.bot_id}, shard: {shard_index}") current_time = time.time() - - agent = await init_agent( - bot_id=bot_id, - model_name=model_name, - api_key=api_key, - model_server=model_server, - generate_cfg=generate_cfg, - system_prompt=final_system_prompt, - mcp=final_mcp_settings, - robot_type=robot_type, - language=language, - user_identifier=user_identifier, - session_id=session_id - ) + + agent = await init_agent(config) # 缓存实例 async with shard['lock']: shard['agents'][cache_key] = agent - shard['unique_ids'][cache_key] = bot_id + shard['unique_ids'][cache_key] = config.bot_id shard['access_times'][cache_key] = current_time shard['creation_times'][cache_key] = current_time diff --git a/routes/chat.py b/routes/chat.py index d695d39..bdfad4c 100644 --- a/routes/chat.py +++ b/routes/chat.py @@ -14,13 +14,14 @@ from utils import ( from agent.sharded_agent_manager import init_global_sharded_agent_manager from utils.api_models import ChatRequestV2 from utils.fastapi_utils import ( - process_messages, format_messages_to_chat_history, + process_messages, create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config, - call_preamble_llm, get_preamble_text, get_user_last_message_content, + call_preamble_llm, get_preamble_text, create_stream_chunk ) from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT +from utils.agent_config import AgentConfig router = APIRouter() @@ -69,89 +70,91 @@ def append_assistant_last_message(messages: list, content: str) -> bool: messages.append({"role":"assistant","content":content}) return messages + +def get_user_last_message_content(messages: list) -> str: + """获取最后一条用户消息的内容""" + if not messages: + return "" + + for msg in reversed(messages): + if msg.get('role') == 'user': + return msg.get('content', '') + + return "" + + +def format_messages_to_chat_history(messages: list) -> str: + """将消息格式化为聊天历史字符串""" + chat_history = "" + for msg in messages: + role = msg.get('role', '') + content = msg.get('content', '') + if role == 'user': + chat_history += f"用户: {content}\n" + elif role == 'assistant': + chat_history += f"助手: {content}\n" + return chat_history + + async def enhanced_generate_stream_response( agent_manager, - bot_id: str, - api_key: str, messages: list, - tool_response: bool, - model_name: str, - model_server: str, - language: str, - system_prompt: str, - mcp_settings: Optional[list], - robot_type: str, - project_dir: Optional[str], - generate_cfg: Optional[dict], - user_identifier: Optional[str], - session_id: Optional[str] = None, + config: AgentConfig ): - """增强的渐进式流式响应生成器 - 并发优化版本""" + """增强的渐进式流式响应生成器 - 并发优化版本 + + Args: + agent_manager: agent管理器 + messages: 消息列表 + config: AgentConfig 对象,包含所有参数 + """ try: - # 准备参数 - query_text = get_user_last_message_content(messages) - chat_history = format_messages_to_chat_history(messages) - preamble_text, system_prompt = get_preamble_text(language, system_prompt) - # 创建输出队列和控制事件 output_queue = asyncio.Queue() preamble_completed = asyncio.Event() - + # Preamble 任务 async def preamble_task(): try: - preamble_result = await call_preamble_llm(chat_history, query_text, preamble_text, language, model_name, api_key, model_server) + preamble_result = await call_preamble_llm(messages,config) # 只有当preamble_text不为空且不为""时才输出 if preamble_result and preamble_result.strip() and preamble_result != "": preamble_content = f"[PREAMBLE]\n{preamble_result}\n" - chunk_data = create_stream_chunk(f"chatcmpl-preamble", model_name, preamble_content) + chunk_data = create_stream_chunk(f"chatcmpl-preamble", config.model_name, preamble_content) await output_queue.put(("preamble", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n")) logger.info(f"Stream mode: Generated preamble text ({len(preamble_result)} chars)") else: logger.info("Stream mode: Skipped empty preamble text") - + # 标记 preamble 完成 preamble_completed.set() await output_queue.put(("preamble_done", None)) - + except Exception as e: logger.error(f"Error generating preamble text: {e}") # 即使出错也要标记完成,避免阻塞 preamble_completed.set() await output_queue.put(("preamble_done", None)) - + # Agent 任务(准备 + 流式处理) async def agent_task(): try: # 准备 agent - agent = await agent_manager.get_or_create_agent( - bot_id=bot_id, - project_dir=project_dir, - model_name=model_name, - api_key=api_key, - model_server=model_server, - generate_cfg=generate_cfg, - language=language, - system_prompt=system_prompt, - mcp_settings=mcp_settings, - robot_type=robot_type, - user_identifier=user_identifier, - session_id=session_id, - ) - + agent = await agent_manager.get_or_create_agent(config) + # 开始流式处理 logger.info(f"Starting agent stream response") chunk_id = 0 message_tag = "" - - config = {} - if session_id: - config["configurable"] = {"thread_id": session_id} + + stream_config = {} + if config.session_id: + stream_config["configurable"] = {"thread_id": config.session_id} if hasattr(agent, 'logging_handler'): - config["callbacks"] = [agent.logging_handler] - async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config, max_tokens=MAX_OUTPUT_TOKENS): + stream_config["callbacks"] = [agent.logging_handler] + async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=stream_config, max_tokens=MAX_OUTPUT_TOKENS): new_content = "" - + if isinstance(msg, AIMessageChunk): # 处理工具调用 if msg.tool_call_chunks: @@ -161,7 +164,7 @@ async def enhanced_generate_stream_response( new_content = f"[{message_tag}] {tool_call_chunk['name']}\n" if tool_call_chunk['args']: new_content += tool_call_chunk['args'] - + # 处理文本内容 elif msg.content: preamble_completed.set() @@ -172,9 +175,9 @@ async def enhanced_generate_stream_response( new_content = f"[{meta_message_tag}]\n" if msg.text: new_content += msg.text - + # 处理工具响应 - elif isinstance(msg, ToolMessage) and tool_response and msg.content: + elif isinstance(msg, ToolMessage) and config.tool_response and msg.content: message_tag = "TOOL_RESPONSE" new_content = f"[{message_tag}] {msg.name}\n{msg.text}\n" @@ -183,40 +186,48 @@ async def enhanced_generate_stream_response( if chunk_id == 0: logger.info(f"Agent首个Token已生成, 开始流式输出") chunk_id += 1 - chunk_data = create_stream_chunk(f"chatcmpl-{chunk_id}", model_name, new_content) + chunk_data = create_stream_chunk(f"chatcmpl-{chunk_id}", config.model_name, new_content) await output_queue.put(("agent", f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n")) - + # 发送最终chunk - final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", model_name, finish_reason="stop") + final_chunk = create_stream_chunk(f"chatcmpl-{chunk_id + 1}", config.model_name, finish_reason="stop") await output_queue.put(("agent", f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n")) await output_queue.put(("agent_done", None)) - + except Exception as e: logger.error(f"Error in agent task: {e}") await output_queue.put(("agent_done", None)) - + # 并发执行任务 - preamble_task_handle = asyncio.create_task(preamble_task()) + # 只有在 enable_thinking 为 True 时才执行 preamble 任务 + if config.enable_thinking: + preamble_task_handle = asyncio.create_task(preamble_task()) + else: + # 如果不启用 thinking,创建一个空的已完成任务 + preamble_task_handle = asyncio.create_task(asyncio.sleep(0)) + # 直接标记 preamble 完成 + preamble_completed.set() + agent_task_handle = asyncio.create_task(agent_task()) - + # 输出控制器:确保 preamble 先输出,然后是 agent stream preamble_output_done = False - + while True: try: # 设置超时避免无限等待 item_type, item_data = await asyncio.wait_for(output_queue.get(), timeout=1.0) - + if item_type == "preamble": # 立即输出 preamble 内容 if item_data: yield item_data preamble_output_done = True - + elif item_type == "preamble_done": - # Preamble 已完成,标记并继续处理 + # Preamble 已完成,标记并继续 preamble_output_done = True - + elif item_type == "agent": # Agent stream 内容,需要等待 preamble 输出完成 if preamble_output_done: @@ -227,18 +238,18 @@ async def enhanced_generate_stream_response( # 等待 preamble 完成 await preamble_completed.wait() preamble_output_done = True - + elif item_type == "agent_done": # Agent stream 完成,结束循环 break - + except asyncio.TimeoutError: # 检查是否还有任务在运行 if all(task.done() for task in [preamble_task_handle, agent_task_handle]): # 所有任务都完成了,退出循环 break continue - + # 发送结束标记 yield "data: [DONE]\n\n" logger.info(f"Enhanced stream response completed") @@ -259,77 +270,43 @@ async def enhanced_generate_stream_response( async def create_agent_and_generate_response( - bot_id: str, - api_key: str, messages: list, - stream: bool, - tool_response: bool, - model_name: str, - model_server: str, - language: str, - system_prompt: Optional[str], - mcp_settings: Optional[list], - robot_type: str, - project_dir: Optional[str] = None, - generate_cfg: Optional[dict] = None, - user_identifier: Optional[str] = None, - session_id: Optional[str] = None + config: AgentConfig ) -> Union[ChatResponse, StreamingResponse]: - """创建agent并生成响应的公共逻辑""" - if generate_cfg is None: - generate_cfg = {} + """创建agent并生成响应的公共逻辑 + + Args: + messages: 消息列表 + config: AgentConfig 对象,包含所有参数 + """ + config.preamble_text, config.system_prompt = get_preamble_text(config.language, config.system_prompt) # 如果是流式模式,使用增强的流式响应生成器 - if stream: + if config.stream: return StreamingResponse( enhanced_generate_stream_response( agent_manager=agent_manager, - bot_id=bot_id, - api_key=api_key, messages=messages, - tool_response=tool_response, - model_name=model_name, - model_server=model_server, - language=language, - system_prompt=system_prompt or "", - mcp_settings=mcp_settings, - robot_type=robot_type, - project_dir=project_dir, - generate_cfg=generate_cfg, - user_identifier=user_identifier, - session_id=session_id + config=config ), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"} ) - _, system_prompt = get_preamble_text(language, system_prompt) + # 使用公共函数处理所有逻辑 - agent = await agent_manager.get_or_create_agent( - bot_id=bot_id, - project_dir=project_dir, - model_name=model_name, - api_key=api_key, - model_server=model_server, - generate_cfg=generate_cfg, - language=language, - system_prompt=system_prompt, - mcp_settings=mcp_settings, - robot_type=robot_type, - user_identifier=user_identifier, - session_id=session_id, - ) + agent = await agent_manager.get_or_create_agent(config) # 准备最终的消息 final_messages = messages.copy() # 非流式响应 - config = {} - if session_id: - config["configurable"] = {"thread_id": session_id} + agent_config = {} + if config.session_id: + agent_config["configurable"] = {"thread_id": config.session_id} if hasattr(agent, 'logging_handler'): - config["callbacks"] = [agent.logging_handler] - agent_responses = await agent.ainvoke({"messages": final_messages}, config=config, max_tokens=MAX_OUTPUT_TOKENS) + agent_config["callbacks"] = [agent.logging_handler] + agent_responses = await agent.ainvoke({"messages": final_messages}, config=agent_config, max_tokens=MAX_OUTPUT_TOKENS) append_messages = agent_responses["messages"][len(final_messages):] response_text = "" for msg in append_messages: @@ -340,7 +317,7 @@ async def create_agent_and_generate_response( response_text += f"[{meta_message_tag}]\n"+output_text+ "\n" if len(msg.tool_calls)>0: response_text += "".join([f"[TOOL_CALL] {tool['name']}\n{json.dumps(tool["args"]) if isinstance(tool["args"],dict) else tool["args"]}\n" for tool in msg.tool_calls]) - elif isinstance(msg,ToolMessage) and tool_response: + elif isinstance(msg,ToolMessage) and config.tool_response: response_text += f"[TOOL_RESPONSE] {msg.name}\n{msg.text}\n" if len(response_text) > 0: @@ -368,11 +345,11 @@ async def create_agent_and_generate_response( async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)): """ Chat completions API similar to OpenAI, supports both streaming and non-streaming - + Args: request: ChatRequest containing messages, model, optional dataset_ids list, required bot_id, system_prompt, mcp_settings, and files authorization: Authorization header containing API key (Bearer ) - + Returns: Union[ChatResponse, StreamingResponse]: Chat completion response or stream @@ -400,7 +377,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] = try: # v1接口:从Authorization header中提取API key作为模型API密钥 api_key = extract_api_key_from_auth(authorization) - + # 获取bot_id(必需参数) bot_id = request.bot_id if not bot_id: @@ -408,31 +385,18 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] = # 创建项目目录(如果有dataset_ids且不是agent类型) project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type) - + # 收集额外参数作为 generate_cfg - exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id'} + exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking'} generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields} - # 处理消息 messages = process_messages(request.messages, request.language) - + # 创建 AgentConfig 对象 + config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg) # 调用公共的agent创建和响应生成逻辑 return await create_agent_and_generate_response( - bot_id=bot_id, - api_key=api_key, messages=messages, - stream=request.stream, - tool_response=request.tool_response, - model_name=request.model, - model_server=request.model_server, - language=request.language, - system_prompt=request.system_prompt, - mcp_settings=request.mcp_settings, - robot_type=request.robot_type, - project_dir=project_dir, - generate_cfg=generate_cfg, - user_identifier=request.user_identifier, - session_id=request.session_id + config=config ) except Exception as e: @@ -496,38 +460,20 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st # 从后端API获取机器人配置(使用v2的鉴权方式) bot_config = await fetch_bot_config(bot_id) - - # v2接口:API密钥优先从后端配置获取,其次才从Authorization header获取 - # 注意:这里的Authorization header已经用于鉴权,不再作为API key使用 - api_key = bot_config.get("api_key") - # 创建项目目录(从后端配置获取dataset_ids) project_dir = create_project_directory( bot_config.get("dataset_ids", []), bot_id, bot_config.get("robot_type", "general_agent") ) - # 处理消息 messages = process_messages(request.messages, request.language) - + # 创建 AgentConfig 对象 + config = AgentConfig.from_v2_request(request, bot_config, project_dir) # 调用公共的agent创建和响应生成逻辑 return await create_agent_and_generate_response( - bot_id=bot_id, - api_key=api_key, messages=messages, - stream=request.stream, - tool_response=request.tool_response, - model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"), - model_server=bot_config.get("model_server", ""), - language=request.language or bot_config.get("language", "zh"), - system_prompt=bot_config.get("system_prompt"), - mcp_settings=bot_config.get("mcp_settings", []), - robot_type=bot_config.get("robot_type", "general_agent"), - project_dir=project_dir, - generate_cfg={}, # v2接口不传递额外的generate_cfg - user_identifier=request.user_identifier, - session_id=request.session_id + config=config ) except HTTPException: @@ -537,4 +483,4 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st error_details = traceback.format_exc() logger.error(f"Error in chat_completions_v2: {str(e)}") logger.error(f"Full traceback: {error_details}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") \ No newline at end of file diff --git a/utils/agent_config.py b/utils/agent_config.py new file mode 100644 index 0000000..d20a11a --- /dev/null +++ b/utils/agent_config.py @@ -0,0 +1,97 @@ +"""Agent配置类,用于管理所有Agent相关的参数""" + +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field +from dataclasses import dataclass + + +@dataclass +class AgentConfig: + """Agent配置类,包含创建和管理Agent所需的所有参数""" + + # 基础参数 + bot_id: str + api_key: Optional[str] = None + model_name: str = "qwen3-next" + model_server: Optional[str] = None + language: Optional[str] = "jp" + + # 配置参数 + system_prompt: Optional[str] = None + mcp_settings: Optional[List[Dict]] = None + robot_type: Optional[str] = "general_agent" + generate_cfg: Optional[Dict] = None + enable_thinking: bool = True + + # 上下文参数 + project_dir: Optional[str] = None + user_identifier: Optional[str] = None + session_id: Optional[str] = None + + # 响应控制参数 + stream: bool = False + tool_response: bool = True + preamble_text: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式,用于传递给需要**kwargs的函数""" + return { + 'bot_id': self.bot_id, + 'api_key': self.api_key, + 'model_name': self.model_name, + 'model_server': self.model_server, + 'language': self.language, + 'system_prompt': self.system_prompt, + 'mcp_settings': self.mcp_settings, + 'robot_type': self.robot_type, + 'generate_cfg': self.generate_cfg, + 'enable_thinking': self.enable_thinking, + 'project_dir': self.project_dir, + 'user_identifier': self.user_identifier, + 'session_id': self.session_id, + 'stream': self.stream, + 'tool_response': self.tool_response, + 'preamble_text': self.preamble_text + } + + @classmethod + def from_v1_request(cls, request, api_key: str, project_dir: Optional[str] = None, generate_cfg: Optional[Dict] = None): + """从v1请求创建配置""" + return cls( + bot_id=request.bot_id, + api_key=api_key, + model_name=request.model, + model_server=request.model_server, + language=request.language, + system_prompt=request.system_prompt, + mcp_settings=request.mcp_settings, + robot_type=request.robot_type, + user_identifier=request.user_identifier, + session_id=request.session_id, + enable_thinking=request.enable_thinking, + project_dir=project_dir, + stream=request.stream, + tool_response=request.tool_response, + generate_cfg=generate_cfg + ) + + @classmethod + def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None): + """从v2请求创建配置""" + return cls( + bot_id=request.bot_id, + api_key=bot_config.get("api_key"), + model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"), + model_server=bot_config.get("model_server", ""), + language=request.language or bot_config.get("language", "zh"), + system_prompt=bot_config.get("system_prompt"), + mcp_settings=bot_config.get("mcp_settings", []), + robot_type=bot_config.get("robot_type", "general_agent"), + user_identifier=request.user_identifier, + session_id=request.session_id, + enable_thinking=request.enable_thinking, + project_dir=project_dir, + stream=request.stream, + tool_response=request.tool_response, + generate_cfg={} # v2接口不传递额外的generate_cfg + ) \ No newline at end of file diff --git a/utils/api_models.py b/utils/api_models.py index 6656506..ec61836 100644 --- a/utils/api_models.py +++ b/utils/api_models.py @@ -53,6 +53,7 @@ class ChatRequest(BaseModel): robot_type: Optional[str] = "general_agent" user_identifier: Optional[str] = "" session_id: Optional[str] = None + enable_thinking: Optional[bool] = False class ChatRequestV2(BaseModel): @@ -63,6 +64,7 @@ class ChatRequestV2(BaseModel): language: Optional[str] = "zh" user_identifier: Optional[str] = "" session_id: Optional[str] = None + enable_thinking: Optional[bool] = False class FileProcessRequest(BaseModel): diff --git a/utils/fastapi_utils.py b/utils/fastapi_utils.py index 1fcec31..f5de94a 100644 --- a/utils/fastapi_utils.py +++ b/utils/fastapi_utils.py @@ -11,6 +11,7 @@ import logging from langchain_core.messages import HumanMessage, AIMessage, SystemMessage from langchain.chat_models import init_chat_model from utils.settings import MASTERKEY, BACKEND_HOST +from utils.agent_config import AgentConfig USER = "user" ASSISTANT = "assistant" @@ -500,9 +501,10 @@ def get_preamble_text(language: str, system_prompt: str): if preamble_matches: # 提取preamble内容 preamble_content = preamble_matches[0].strip() - # 从system_prompt中删除preamble代码块 - cleaned_system_prompt = re.sub(preamble_pattern, '', system_prompt, flags=re.DOTALL) - return preamble_content, cleaned_system_prompt + if preamble_content: + # 从system_prompt中删除preamble代码块 + cleaned_system_prompt = re.sub(preamble_pattern, '', system_prompt, flags=re.DOTALL) + return preamble_content, cleaned_system_prompt # 如果没有找到preamble代码块,使用默认的preamble选择 if language == "jp": @@ -559,12 +561,12 @@ def get_preamble_text(language: str, system_prompt: str): return default_preamble, system_prompt # 返回默认preamble和原始system_prompt -async def call_preamble_llm(chat_history: str, last_message: str, preamble_choices_text: str, language: str, model_name: str, api_key: str, model_server: str) -> str: +async def call_preamble_llm(messages: list, config: AgentConfig) -> str: """调用大语言模型处理guideline分析 Args: - chat_history: 聊天历史记录 - guidelines_text: 指导原则文本 + messages: 消息列表 + preamble_choices_text: 指导原则文本 model_name: 模型名称 api_key: API密钥 model_server: 模型服务器地址 @@ -579,6 +581,14 @@ async def call_preamble_llm(chat_history: str, last_message: str, preamble_choic except Exception as e: logger.error(f"Error reading guideline prompt template: {e}") return "" + + api_key = config.api_key + model_name = config.model_name + model_server = config.model_server + language = config.language + preamble_choices_text = config.preamble_text + last_message = get_user_last_message_content(messages) + chat_history = format_messages_to_chat_history(messages) # 替换模板中的占位符 system_prompt = preamble_template.replace('{preamble_choices_text}', preamble_choices_text).replace('{chat_history}', chat_history).replace('{last_message}', last_message).replace('{language}', get_language_text(language))