diff --git a/agent/agent_config.py b/agent/agent_config.py
index 8351be8..d8693af 100644
--- a/agent/agent_config.py
+++ b/agent/agent_config.py
@@ -46,6 +46,9 @@ class AgentConfig:
     memori_semantic_search_top_k: int = 20
     _mem0_context: Optional[str] = None  # Memory context recalled by Mem0, passed between middlewares
 
+    # Custom shell environment variables
+    shell_env: Optional[Dict[str, str]] = field(default_factory=dict)
+
     # Checkpointer session history
     _session_history: Optional[List] = field(default_factory=list)  # Historical chat messages read from the checkpointer
 
@@ -72,6 +75,7 @@ class AgentConfig:
             'enable_memori': self.enable_memori,
             'memori_semantic_search_top_k': self.memori_semantic_search_top_k,
             'trace_id': self.trace_id,
+            'shell_env': self.shell_env,
         }
 
     def safe_print(self):
@@ -130,6 +134,7 @@ class AgentConfig:
             enable_memori=request.enable_memory,
             memori_semantic_search_top_k=getattr(request, 'memori_semantic_search_top_k', None) or MEM0_SEMANTIC_SEARCH_TOP_K,
             trace_id=trace_id,
+            shell_env=getattr(request, 'shell_env', None) or {},
         )
 
         # Prepare checkpoint messages as early as possible when creating the config
@@ -198,6 +203,7 @@ class AgentConfig:
             enable_memori=enable_memori,
             memori_semantic_search_top_k=bot_config.get("memori_semantic_search_top_k", MEM0_SEMANTIC_SEARCH_TOP_K),
             trace_id=trace_id,
+            shell_env=getattr(request, 'shell_env', None) or bot_config.get("shell_env") or {},
         )
 
         # Prepare checkpoint messages as early as possible when creating the config
diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py
index d1751ad..7dfb662 100644
--- a/agent/deep_assistant.py
+++ b/agent/deep_assistant.py
@@ -289,7 +289,8 @@ async def init_agent(config: AgentConfig):
         shell_env={
             "ASSISTANT_ID": config.bot_id,
             "USER_IDENTIFIER": config.user_identifier,
-            "TRACE_ID": config.trace_id
+            "TRACE_ID": config.trace_id,
+            **(config.shell_env or {}),
         }
     )
 
diff --git a/routes/chat.py b/routes/chat.py
index cc52310..f59e33a 100644
--- a/routes/chat.py
+++ b/routes/chat.py
@@ -482,7 +482,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
     project_dir = create_project_directory(request.dataset_ids, bot_id, request.skills)
 
     # Collect extra parameters as generate_cfg
-    exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking', 'skills', 'enable_memory', 'n'}
+    exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking', 'skills', 'enable_memory', 'n', 'shell_env'}
     generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
     # Process messages
     messages = process_messages(request.messages, request.language)
@@ -532,7 +532,7 @@ async def chat_warmup_v1(request: ChatRequest, authorization: Optional[str] = He
     project_dir = create_project_directory(request.dataset_ids, bot_id, request.skills)
 
     # Collect extra parameters as generate_cfg
-    exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking', 'skills', 'enable_memory', 'n'}
+    exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings' ,'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking', 'skills', 'enable_memory', 'n', 'shell_env'}
     generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
 
     # Create an empty message list for warmup (the actual messages are not processed during warmup)
@@ -636,7 +636,7 @@ async def chat_warmup_v2(request: ChatRequestV2, authorization: Optional[str] =
     messages = process_messages(empty_messages, request.language or "ja")
 
     # Collect extra parameters as generate_cfg
-    exclude_fields = {'messages', 'stream', 'tool_response', 'bot_id', 'language', 'user_identifier', 'session_id', 'n', 'model', 'model_server', 'api_key'}
+    exclude_fields = {'messages', 'stream', 'tool_response', 'bot_id', 'language', 'user_identifier', 'session_id', 'n', 'model', 'model_server', 'api_key', 'shell_env'}
     generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
     # Extract model/model_server/api_key from the request, taking priority over bot_config (excluding "whatever" and empty values)
     req_data = request.model_dump()
@@ -743,7 +743,7 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
     # Process messages
     messages = process_messages(request.messages, request.language)
     # Collect extra parameters as generate_cfg
-    exclude_fields = {'messages', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking', 'skills', 'enable_memory', 'n', 'model', 'model_server', 'api_key'}
+    exclude_fields = {'messages', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id', 'enable_thinking', 'skills', 'enable_memory', 'n', 'model', 'model_server', 'api_key', 'shell_env'}
     generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
     # Extract model/model_server/api_key from the request, taking priority over bot_config (excluding "whatever" and empty values)
     req_data = request.model_dump()
diff --git a/utils/api_models.py b/utils/api_models.py
index 7a85947..386c6f9 100644
--- a/utils/api_models.py
+++ b/utils/api_models.py
@@ -55,6 +55,7 @@ class ChatRequest(BaseModel):
     enable_thinking: Optional[bool] = DEFAULT_THINKING_ENABLE
     skills: Optional[List[str]] = None
     enable_memory: Optional[bool] = False
+    shell_env: Optional[Dict[str, str]] = None
 
     model_config = ConfigDict(extra='allow')
 
@@ -67,6 +68,7 @@ class ChatRequestV2(BaseModel):
     language: Optional[str] = "zh"
     user_identifier: Optional[str] = ""
     session_id: Optional[str] = None
+    shell_env: Optional[Dict[str, str]] = None
 
     model_config = ConfigDict(extra='allow')
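
Usage sketch (not part of the patch): a minimal example of how a caller might exercise the new shell_env request field end to end. The base URL, port, endpoint path, and field values below are illustrative assumptions, not confirmed by this diff.

# Minimal usage sketch, illustrative only.
# Assumptions: the service listens on http://localhost:8000 and the
# chat_completions route above is exposed as /v1/chat/completions.
import requests

payload = {
    "bot_id": "demo-bot",                 # hypothetical bot id
    "user_identifier": "user-123",        # hypothetical user
    "messages": [{"role": "user", "content": "echo $MY_FLAG"}],
    "stream": False,
    # New in this patch: extra shell environment variables, accepted via
    # ChatRequest.shell_env, carried on AgentConfig, and merged in
    # init_agent() after ASSISTANT_ID / USER_IDENTIFIER / TRACE_ID.
    "shell_env": {"MY_FLAG": "1", "HTTP_PROXY": "http://127.0.0.1:7890"},
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
print(resp.status_code, resp.text)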