add deep_agent

This commit is contained in:
朱潮 2025-12-12 18:41:52 +08:00
parent eb17dff54a
commit 720db80ae9
16 changed files with 1164 additions and 1969 deletions

15
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,15 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "FastAPI: Debug",
"type": "python",
"request": "launch",
"program": "fastapi_app.py",
"console": "integratedTerminal",
"cwd": "${workspaceFolder}",
"envFile": "${workspaceFolder}/.env",
"python": "${workspaceFolder}/.venv/bin/python"
}
]
}

9
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,9 @@
{
"python.languageServer": "Pylance",
"python.analysis.indexing": true,
"python.analysis.autoSearchPaths": true,
"python.analysis.diagnosticMode": "workspace",
"python.analysis.extraPaths": [
"${workspaceFolder}/.venv"
]
}

View File

@ -1,3 +1,28 @@
# Python environment
This project's Python environment is managed with Poetry. To run a .py file, execute it with `poetry run python xxx.py`. (A Python client sketch equivalent to the curl test follows the test script below.)
Startup script:
```
poetry run uvicorn fastapi_app:app --host 0.0.0.0 --port 8001
```
Test script:
```
curl --request POST \
--url http://localhost:8001/api/v2/chat/completions \
--header 'authorization: Bearer a21c99620a8ef61d69563afe05ccce89' \
--header 'content-type: application/json' \
--header 'x-trace-id: 123123123' \
--data '{
"messages": [
{
"role": "user",
"content": "咖啡多少钱一杯"
}
],
"stream": true,
"model": "whatever",
"language": "ja",
"bot_id": "63069654-7750-409d-9a58-a0960d899a20",
"tool_response": true,
"user_identifier": "及川"
}'
```
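The same request can also be issued from Python, which fits the `poetry run python` workflow described above. This is a minimal sketch, assuming the server from the startup script is running on localhost:8001; it uses httpx (pinned in requirements.txt) and reuses the placeholder token, bot_id, and trace id from the curl example.
```
import httpx

# Minimal SSE client sketch for the streaming /api/v2/chat/completions endpoint.
# The bearer token, bot_id and x-trace-id below are the placeholder values from the curl example.
payload = {
    "messages": [{"role": "user", "content": "咖啡多少钱一杯"}],
    "stream": True,
    "model": "whatever",
    "language": "ja",
    "bot_id": "63069654-7750-409d-9a58-a0960d899a20",
    "tool_response": True,
    "user_identifier": "及川",
}
headers = {
    "authorization": "Bearer a21c99620a8ef61d69563afe05ccce89",
    "content-type": "application/json",
    "x-trace-id": "123123123",
}

with httpx.Client(timeout=None) as client:
    with client.stream("POST", "http://localhost:8001/api/v2/chat/completions",
                       json=payload, headers=headers) as response:
        for line in response.iter_lines():
            # Each streamed line is an SSE event of the form "data: {...json chunk...}"
            if line.startswith("data: "):
                print(line[len("data: "):])
```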

View File

@ -1,502 +0,0 @@
# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import atexit
import datetime
import json
import threading
import time
import uuid
from contextlib import AsyncExitStack
from typing import Dict, Optional, Union
from dotenv import load_dotenv
import logging
logger = logging.getLogger('app')
from qwen_agent.tools.base import BaseTool
class CustomMCPManager:
_instance = None # Private class variable to store the unique instance
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(CustomMCPManager, cls).__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self):
if not hasattr(self, 'clients'): # The singleton should only be initialized once
"""Set a new event loop in a separate thread"""
try:
import mcp # noqa
except ImportError as e:
raise ImportError('Could not import mcp. Please install mcp with `pip install -U mcp`.') from e
load_dotenv() # Load environment variables from .env file
self.clients: dict = {}
self.loop = asyncio.new_event_loop()
self.loop_thread = threading.Thread(target=self.start_loop, daemon=True)
self.loop_thread.start()
# A fallback way to terminate MCP tool processes after Qwen-Agent exits
self.processes = []
self.monkey_patch_mcp_create_platform_compatible_process()
def monkey_patch_mcp_create_platform_compatible_process(self):
try:
import mcp.client.stdio
target = mcp.client.stdio._create_platform_compatible_process
except (ModuleNotFoundError, AttributeError) as e:
raise ImportError('Qwen-Agent needs to monkey patch MCP for process cleanup. '
'Please upgrade MCP to a higher version with `pip install -U mcp`.') from e
async def _monkey_patched_create_platform_compatible_process(*args, **kwargs):
process = await target(*args, **kwargs)
self.processes.append(process)
return process
mcp.client.stdio._create_platform_compatible_process = _monkey_patched_create_platform_compatible_process
def start_loop(self):
asyncio.set_event_loop(self.loop)
# Set a global exception handler to silently handle cross-task exceptions from MCP SSE connections
def exception_handler(loop, context):
exception = context.get('exception')
if exception:
# Silently handle cross-task exceptions from MCP SSE connections
if (isinstance(exception, RuntimeError) and
'Attempted to exit cancel scope in a different task' in str(exception)):
return # Silently ignore this type of exception
if (isinstance(exception, BaseExceptionGroup) and # noqa
'Attempted to exit cancel scope in a different task' in str(exception)): # noqa
return # Silently ignore this type of exception
# Other exceptions are handled normally
loop.default_exception_handler(context)
self.loop.set_exception_handler(exception_handler)
self.loop.run_forever()
def is_valid_mcp_servers(self, config: dict):
"""Example of mcp servers configuration:
{
"mcpServers": {
"memory": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-memory"]
},
"filesystem": {
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/allowed/files"]
}
}
}
"""
# Check if the top-level key "mcpServers" exists and its value is a dictionary
if not isinstance(config, dict) or 'mcpServers' not in config or not isinstance(config['mcpServers'], dict):
return False
mcp_servers = config['mcpServers']
# Check each sub-item under "mcpServers"
for key in mcp_servers:
server = mcp_servers[key]
# Each sub-item must be a dictionary
if not isinstance(server, dict):
return False
if 'command' in server:
# "command" must be a string
if not isinstance(server['command'], str):
return False
# "args" must be a list
if 'args' not in server or not isinstance(server['args'], list):
return False
if 'url' in server:
# "url" must be a string
if not isinstance(server['url'], str):
return False
# "headers" must be a dictionary
if 'headers' in server and not isinstance(server['headers'], dict):
return False
# If the "env" key exists, it must be a dictionary
if 'env' in server and not isinstance(server['env'], dict):
return False
return True
def initConfig(self, config: Dict):
if not self.is_valid_mcp_servers(config):
raise ValueError('Config of mcpServers is not valid')
logger.info(f'Initializing MCP tools from mcp servers: {list(config["mcpServers"].keys())}')
# Submit coroutine to the event loop and wait for the result
future = asyncio.run_coroutine_threadsafe(self.init_config_async(config), self.loop)
try:
result = future.result() # You can specify a timeout if desired
return result
except Exception as e:
logger.info(f'Failed in initializing MCP tools: {e}')
raise e
async def init_config_async(self, config: Dict):
tools: list = []
mcp_servers = config['mcpServers']
# Connect to all MCP servers concurrently
connection_tasks = []
for server_name in mcp_servers:
client = CustomMCPClient()
server = mcp_servers[server_name]
# Create a connection task
task = self._connect_and_store_client(client, server_name, server)
connection_tasks.append(task)
# Run all connection tasks concurrently
connected_clients = await asyncio.gather(*connection_tasks, return_exceptions=True)
# Process connection results and create tools for each client
for result in connected_clients:
if isinstance(result, Exception):
logger.error(f'Failed to connect MCP server: {result}')
continue
client, server_name = result
client_id = client.client_id
for tool in client.tools:
"""MCP tool example:
{
"name": "read_query",
"description": "Execute a SELECT query on the SQLite database",
"inputSchema": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "SELECT SQL query to execute"
}
},
"required": ["query"]
}
"""
parameters = tool.inputSchema
# The required field in inputSchema may be empty and needs to be initialized.
if 'required' not in parameters:
parameters['required'] = []
# Remove keys from parameters that do not conform to the standard OpenAI schema
# Check if the required fields exist
required_fields = {'type', 'properties', 'required'}
missing_fields = required_fields - parameters.keys()
if missing_fields:
raise ValueError(f'Missing required fields in schema: {missing_fields}')
# Keep only the necessary fields
cleaned_parameters = {
'type': parameters['type'],
'properties': parameters['properties'],
'required': parameters['required']
}
register_name = server_name + '-' + tool.name
agent_tool = self.create_tool_class(register_name=register_name,
register_client_id=client_id,
tool_name=tool.name,
tool_desc=tool.description,
tool_parameters=cleaned_parameters)
tools.append(agent_tool)
if client.resources:
"""MCP resource example:
{
uri: string; // Unique identifier for the resource
name: string; // Human-readable name
description?: string; // Optional description
mimeType?: string; // Optional MIME type
}
"""
# List resources
list_resources_tool_name = server_name + '-' + 'list_resources'
list_resources_params = {'type': 'object', 'properties': {}, 'required': []}
list_resources_agent_tool = self.create_tool_class(
register_name=list_resources_tool_name,
register_client_id=client_id,
tool_name='list_resources',
tool_desc='Servers expose a list of concrete resources through this tool. '
'By invoking it, you can discover the available resources and obtain resource templates, which help clients understand how to construct valid URIs. '
'These URI formats will be used as input parameters for the read_resource function. ',
tool_parameters=list_resources_params)
tools.append(list_resources_agent_tool)
# Read resource
resources_template_str = '' # Check if there are resource templates
try:
list_resource_templates = await client.session.list_resource_templates(
) # Check if the server has resource templates
if list_resource_templates.resourceTemplates:
resources_template_str = '\n'.join(
str(template) for template in list_resource_templates.resourceTemplates)
except Exception as e:
logger.info(f'Failed in listing MCP resource templates: {e}')
read_resource_tool_name = server_name + '-' + 'read_resource'
read_resource_params = {
'type': 'object',
'properties': {
'uri': {
'type': 'string',
'description': 'The URI identifying the specific resource to access'
}
},
'required': ['uri']
}
original_tool_desc = 'Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information.'
if resources_template_str:
tool_desc = original_tool_desc + '\nResource Templates:\n' + resources_template_str
else:
tool_desc = original_tool_desc
read_resource_agent_tool = self.create_tool_class(register_name=read_resource_tool_name,
register_client_id=client_id,
tool_name='read_resource',
tool_desc=tool_desc,
tool_parameters=read_resource_params)
tools.append(read_resource_agent_tool)
return tools
async def _connect_and_store_client(self, client, server_name, server):
"""辅助方法连接MCP服务器并存储客户端"""
try:
await client.connection_server(mcp_server_name=server_name,
mcp_server=server) # Attempt to connect to the server
client_id = server_name + '_' + str(
uuid.uuid4()) # To allow the same server name be used across different running agents
client.client_id = client_id # Ensure client_id is set on the client instance
self.clients[client_id] = client # Add to clients dict after successful connection
return client, server_name
except Exception as e:
logger.error(f'Failed to connect MCP server {server_name}: {e}')
raise e
def create_tool_class(self, register_name, register_client_id, tool_name, tool_desc, tool_parameters):
class ToolClass(BaseTool):
name = register_name
description = tool_desc
parameters = tool_parameters
client_id = register_client_id
def call(self, params: Union[str, dict], **kwargs) -> str:
tool_args = json.loads(params)
# Submit coroutine to the event loop and wait for the result
manager = CustomMCPManager()
client = manager.clients[self.client_id]
future = asyncio.run_coroutine_threadsafe(client.execute_function(tool_name, tool_args), manager.loop)
try:
result = future.result()
return result
except Exception as e:
logger.info(f'Failed in executing MCP tool: {e}')
raise e
ToolClass.__name__ = f'{register_name}_Class'
return ToolClass()
def shutdown(self):
futures = []
for client_id in list(self.clients.keys()):
client: CustomMCPClient = self.clients[client_id]
future = asyncio.run_coroutine_threadsafe(client.cleanup(), self.loop)
futures.append(future)
del self.clients[client_id]
time.sleep(1) # Wait for the graceful cleanups, otherwise fall back
# fallback
if asyncio.all_tasks(self.loop):
logger.info(
'There are still tasks in `CustomMCPManager().loop`, force terminating the MCP tool processes. There may be some exceptions.'
)
for process in self.processes:
try:
process.terminate()
except ProcessLookupError:
pass # it's ok, the process may exit earlier
self.loop.call_soon_threadsafe(self.loop.stop)
self.loop_thread.join()
class CustomMCPClient:
def __init__(self):
from mcp import ClientSession
self.session: Optional[ClientSession] = None
self.tools: list = None
self.exit_stack = AsyncExitStack()
self.resources: bool = False
self._last_mcp_server_name = None
self._last_mcp_server = None
self.client_id = None # For replacing in MCPManager.clients
async def connection_server(self, mcp_server_name, mcp_server):
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
"""Connect to an MCP server and retrieve the available tools."""
# Save parameters
self._last_mcp_server_name = mcp_server_name
self._last_mcp_server = mcp_server
try:
if 'url' in mcp_server:
url = mcp_server.get('url')
sse_read_timeout = mcp_server.get('sse_read_timeout', 300)
logger.info(f'{mcp_server_name} sse_read_timeout: {sse_read_timeout}s')
if mcp_server.get('type', 'sse') == 'streamable-http':
# streamable-http mode
"""streamable-http mode mcp example:
{"mcpServers": {
"streamable-mcp-server": {
"type": "streamable-http",
"url":"http://0.0.0.0:8000/mcp"
}
}
}
"""
headers = mcp_server.get('headers', {})
self._streams_context = streamablehttp_client(
url=url, headers=headers, sse_read_timeout=datetime.timedelta(seconds=sse_read_timeout))
read_stream, write_stream, get_session_id = await self.exit_stack.enter_async_context(
self._streams_context)
self._session_context = ClientSession(read_stream, write_stream)
self.session = await self.exit_stack.enter_async_context(self._session_context)
else:
# sse mode
headers = mcp_server.get('headers', {'Accept': 'text/event-stream'})
self._streams_context = sse_client(url, headers, sse_read_timeout=sse_read_timeout)
streams = await self.exit_stack.enter_async_context(self._streams_context)
self._session_context = ClientSession(*streams)
self.session = await self.exit_stack.enter_async_context(self._session_context)
else:
server_params = StdioServerParameters(command=mcp_server['command'],
args=mcp_server['args'],
env=mcp_server.get('env', None))
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
logger.info(
f'Initializing an MCP stdio_client; if this hangs, please check the config of this MCP server: {mcp_server_name}'
)
await self.session.initialize()
list_tools = await self.session.list_tools()
self.tools = list_tools.tools
try:
list_resources = await self.session.list_resources() # Check if the server has resources
if list_resources.resources:
self.resources = True
except Exception:
# logger.info(f"No list resources: {e}")
pass
except Exception as e:
logger.warning(f'Failed in connecting to MCP server: {e}')
raise e
async def reconnect(self):
# Create a new MCPClient and connect
if self.client_id is None:
raise RuntimeError(
'Cannot reconnect: client_id is None. This usually means the client was not properly registered in MCPManager.'
)
new_client = CustomMCPClient()
new_client.client_id = self.client_id
await new_client.connection_server(self._last_mcp_server_name, self._last_mcp_server)
return new_client
async def execute_function(self, tool_name, tool_args: dict):
from mcp.types import TextResourceContents
# Check if session is alive
try:
await self.session.send_ping()
except Exception as e:
logger.info(f"Session is not alive, please increase 'sse_read_timeout' in the config, try reconnect: {e}")
# Auto reconnect
try:
manager = CustomMCPManager()
if self.client_id is not None:
manager.clients[self.client_id] = await self.reconnect()
return await manager.clients[self.client_id].execute_function(tool_name, tool_args)
else:
logger.info('Reconnect failed: client_id is None')
return 'Session reconnect (client creation) exception: client_id is None'
except Exception as e3:
logger.info(f'Reconnect (client creation) exception type: {type(e3)}, value: {repr(e3)}')
return f'Session reconnect (client creation) exception: {e3}'
if tool_name == 'list_resources':
try:
list_resources = await self.session.list_resources()
if list_resources.resources:
resources_str = '\n\n'.join(str(resource) for resource in list_resources.resources)
else:
resources_str = 'No resources found'
return resources_str
except Exception as e:
logger.info(f'No list resources: {e}')
return f'Error: {e}'
elif tool_name == 'read_resource':
try:
uri = tool_args.get('uri')
if not uri:
raise ValueError('URI is required for read_resource')
read_resource = await self.session.read_resource(uri)
texts = []
for resource in read_resource.contents:
if isinstance(resource, TextResourceContents):
texts.append(resource.text)
# if isinstance(resource, BlobResourceContents):
# texts.append(resource.blob)
if texts:
return '\n\n'.join(texts)
else:
return 'Failed to read resource'
except Exception as e:
logger.info(f'Failed to read resource: {e}')
return f'Error: {e}'
else:
response = await self.session.call_tool(tool_name, tool_args)
texts = []
for content in response.content:
if content.type == 'text':
texts.append(content.text)
if texts:
return '\n\n'.join(texts)
else:
return 'execute error'
async def cleanup(self):
await self.exit_stack.aclose()
def _cleanup_mcp(_sig_num=None, _frame=None):
if CustomMCPManager._instance is None:
return
manager = CustomMCPManager()
manager.shutdown()
# Make sure all subprocesses are terminated even if killed abnormally
if threading.current_thread() is threading.main_thread():
atexit.register(_cleanup_mcp)

63
agent/deep_assistant.py Normal file
View File

@ -0,0 +1,63 @@
import json
from langchain.chat_models import init_chat_model
from deepagents import create_deep_agent
from langchain.agents import create_agent
from langchain_mcp_adapters.client import MultiServerMCPClient
from utils.fastapi_utils import detect_provider
# Utility functions
def read_system_prompt():
"""读取通用的无状态系统prompt"""
with open("./prompt/system_prompt_default.md", "r", encoding="utf-8") as f:
return f.read().strip()
def read_mcp_settings():
"""读取MCP工具配置"""
with open("./mcp/mcp_settings.json", "r") as f:
mcp_settings_json = json.load(f)
return mcp_settings_json
async def init_agent(model_name="qwen3-next", api_key=None,
model_server=None, generate_cfg=None,
system_prompt=None, mcp=None):
system = system_prompt if system_prompt else read_system_prompt()
mcp = mcp if mcp else read_mcp_settings()
# Normalize the mcp[0]["mcpServers"] entries: rename the "type" field to "transport"; if absent, default to transport: stdio
if mcp and len(mcp) > 0 and "mcpServers" in mcp[0]:
for server_name, server_config in mcp[0]["mcpServers"].items():
if isinstance(server_config, dict):
if "type" in server_config and "transport" not in server_config:
# Has a "type" field but no "transport" field: rename "type" to "transport"
type_value = server_config.pop("type")
# Special case: map 'streamable-http' to 'http'
if type_value == "streamable-http":
server_config["transport"] = "http"
else:
server_config["transport"] = type_value
elif "transport" not in server_config:
# Neither "type" nor "transport" present: add the default transport: stdio
server_config["transport"] = "stdio"
mcp_client = MultiServerMCPClient(mcp[0]["mcpServers"])
mcp_tools = await mcp_client.get_tools()
# Detect the provider, or use the one specified
model_provider, base_url = detect_provider(model_name, model_server)
# Build the model kwargs
model_kwargs = {
"model": model_name,
"model_provider": model_provider,
"temperature": 0.8,
"base_url":base_url,
"api_key":api_key
}
llm_instance = init_chat_model(**model_kwargs)
agent = create_agent(
model=llm_instance,
system_prompt=system,
tools=mcp_tools
)
return agent
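As a quick orientation for the new module, here is a minimal usage sketch, assuming the default prompt and MCP settings files (./prompt/system_prompt_default.md, ./mcp/mcp_settings.json) are present or that system_prompt/mcp are passed explicitly. It mirrors how the router later in this commit awaits init_agent and then calls agent.ainvoke with a messages dict; the model name, API key, and message are placeholders.
```
import asyncio
from agent.deep_assistant import init_agent

async def main():
    # Placeholder model/key; in the service these come from the request and bot config.
    agent = await init_agent(model_name="qwen3-next", api_key="sk-placeholder")
    # Non-streaming call, mirroring the router's `await agent.ainvoke({"messages": ...})`.
    # The result carries the full message history; the new messages follow the inputs.
    result = await agent.ainvoke({"messages": [{"role": "user", "content": "hello"}]})
    for msg in result["messages"]:
        print(type(msg).__name__, getattr(msg, "content", ""))

if __name__ == "__main__":
    asyncio.run(main())
```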

View File

@ -1,285 +0,0 @@
# Copyright 2023
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""文件预加载助手管理器 - 管理基于unique_id的助手实例缓存"""
import hashlib
import time
import json
import asyncio
from typing import Dict, List, Optional
from qwen_agent.agents import Assistant
import logging
logger = logging.getLogger('app')
from agent.modified_assistant import init_modified_agent_service_with_files, update_agent_llm
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
class FileLoadedAgentManager:
"""文件预加载助手管理器
基于 unique_id 缓存助手实例避免重复创建和文件解析
"""
def __init__(self, max_cached_agents: int = 20):
self.agents: Dict[str, Assistant] = {} # {cache_key: assistant_instance}
self.unique_ids: Dict[str, str] = {} # {cache_key: unique_id}
self.access_times: Dict[str, float] = {} # LRU access-time tracking
self.creation_times: Dict[str, float] = {} # Creation-time records
self.max_cached_agents = max_cached_agents
self._creation_locks: Dict[str, asyncio.Lock] = {} # Locks to prevent concurrent creation of the same agent
def _get_cache_key(self, bot_id: str, model_name: str = None, api_key: str = None,
model_server: str = None, generate_cfg: Dict = None,
system_prompt: str = None, mcp_settings: List[Dict] = None) -> str:
"""获取包含所有相关参数的哈希值作为缓存键
Args:
bot_id: 机器人项目ID
model_name: 模型名称
api_key: API密钥
model_server: 模型服务器地址
generate_cfg: 生成配置
system_prompt: 系统提示词
mcp_settings: MCP设置列表
Returns:
str: 缓存键的哈希值
"""
# Build a string containing all relevant parameters
cache_data = {
'bot_id': bot_id,
'model_name': model_name or '',
'api_key': api_key or '',
'model_server': model_server or '',
'generate_cfg': json.dumps(generate_cfg or {}, sort_keys=True),
'system_prompt': system_prompt or '',
'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True)
}
# Convert the dict to a JSON string and compute its hash
cache_str = json.dumps(cache_data, sort_keys=True)
return hashlib.md5(cache_str.encode('utf-8')).hexdigest()[:16]
def _update_access_time(self, cache_key: str):
"""更新访问时间LRU 管理)"""
self.access_times[cache_key] = time.time()
def _cleanup_old_agents(self):
"""清理旧的助手实例,基于 LRU 策略"""
if len(self.agents) <= self.max_cached_agents:
return
# Sort by access time (LRU order) and remove the least recently used instances
sorted_keys = sorted(self.access_times.keys(), key=lambda k: self.access_times[k])
keys_to_remove = sorted_keys[:-self.max_cached_agents]
removed_count = 0
for cache_key in keys_to_remove:
try:
del self.agents[cache_key]
del self.unique_ids[cache_key]
del self.access_times[cache_key]
del self.creation_times[cache_key]
removed_count += 1
logger.info(f"清理过期的助手实例缓存: {cache_key}")
except KeyError:
continue
if removed_count > 0:
logger.info(f"已清理 {removed_count} 个过期的助手实例缓存")
async def get_or_create_agent(self,
bot_id: str,
project_dir: Optional[str],
model_name: str = "qwen3-next",
api_key: Optional[str] = None,
model_server: Optional[str] = None,
generate_cfg: Optional[Dict] = None,
language: Optional[str] = None,
system_prompt: Optional[str] = None,
mcp_settings: Optional[List[Dict]] = None,
robot_type: Optional[str] = "general_agent",
user_identifier: Optional[str] = None) -> Assistant:
"""获取或创建文件预加载的助手实例
Args:
bot_id: 项目的唯一标识符
project_dir: 项目目录路径用于读取README.md可以为None
model_name: 模型名称
api_key: API 密钥
model_server: 模型服务器地址
generate_cfg: 生成配置
language: 语言代码用于选择对应的系统提示词
system_prompt: 可选的系统提示词优先级高于项目配置
mcp_settings: 可选的MCP设置优先级高于项目配置
robot_type: 机器人类型取值 agent/catalog_agent
user_identifier: 用户标识符
Returns:
Assistant: 配置好的助手实例
"""
import os
# Load the configuration files asynchronously (with caching)
final_system_prompt = await load_system_prompt_async(
project_dir, language, system_prompt, robot_type, bot_id, user_identifier
)
final_mcp_settings = await load_mcp_settings_async(
project_dir, mcp_settings, bot_id, robot_type
)
cache_key = self._get_cache_key(bot_id, model_name, api_key, model_server,
generate_cfg, final_system_prompt, final_mcp_settings)
# Use an async lock to prevent concurrent creation of the same agent
creation_lock = self._creation_locks.setdefault(cache_key, asyncio.Lock())
async with creation_lock:
# Check again whether the instance already exists (another request may have created it while we waited for the lock)
if cache_key in self.agents:
self._update_access_time(cache_key)
agent = self.agents[cache_key]
# Dynamically update the LLM configuration and system settings (if parameters changed)
update_agent_llm(agent, model_name, api_key, model_server, generate_cfg)
logger.info(f"复用现有的助手实例缓存: {cache_key} (bot_id: {bot_id})")
return agent
# Evict stale instances
self._cleanup_old_agents()
# Create a new assistant instance with preloaded files
logger.info(f"Creating new assistant instance cache: {cache_key}, bot_id: {bot_id}")
current_time = time.time()
agent = init_modified_agent_service_with_files(
model_name=model_name,
api_key=api_key,
model_server=model_server,
generate_cfg=generate_cfg,
system_prompt=final_system_prompt,
mcp=final_mcp_settings
)
# Cache the instance
self.agents[cache_key] = agent
self.unique_ids[cache_key] = bot_id
self.access_times[cache_key] = current_time
self.creation_times[cache_key] = current_time
# Clean up the creation lock
self._creation_locks.pop(cache_key, None)
logger.info(f"Assistant instance cache created: {cache_key}")
return agent
def get_cache_stats(self) -> Dict:
"""获取缓存统计信息"""
current_time = time.time()
stats = {
"total_cached_agents": len(self.agents),
"max_cached_agents": self.max_cached_agents,
"agents": {}
}
for cache_key, agent in self.agents.items():
stats["agents"][cache_key] = {
"unique_id": self.unique_ids.get(cache_key, "unknown"),
"created_at": self.creation_times.get(cache_key, 0),
"last_accessed": self.access_times.get(cache_key, 0),
"age_seconds": int(current_time - self.creation_times.get(cache_key, current_time)),
"idle_seconds": int(current_time - self.access_times.get(cache_key, current_time))
}
return stats
def clear_cache(self) -> int:
"""清空所有缓存
Returns:
int: 清理的实例数量
"""
cache_count = len(self.agents)
self.agents.clear()
self.unique_ids.clear()
self.access_times.clear()
self.creation_times.clear()
logger.info(f"已清空所有助手实例缓存,共清理 {cache_count} 个实例")
return cache_count
def remove_cache_by_unique_id(self, unique_id: str) -> int:
"""根据 unique_id 移除所有相关的缓存
由于缓存key现在包含 system_prompt mcp_settings
一个 unique_id 可能对应多个缓存实例
Args:
unique_id: 项目的唯一标识符
Returns:
int: 成功移除的实例数量
"""
keys_to_remove = []
# Find all cache keys matching this unique_id
for cache_key, stored_unique_id in self.unique_ids.items():
if stored_unique_id == unique_id:
keys_to_remove.append(cache_key)
# Remove the caches that were found
removed_count = 0
for cache_key in keys_to_remove:
try:
del self.agents[cache_key]
del self.unique_ids[cache_key]
del self.access_times[cache_key]
del self.creation_times[cache_key]
self._creation_locks.pop(cache_key, None) # Clean up the creation lock
removed_count += 1
logger.info(f"已移除助手实例缓存: {cache_key} (unique_id: {unique_id})")
except KeyError:
continue
if removed_count > 0:
logger.info(f"已移除 unique_id={unique_id}{removed_count} 个助手实例缓存")
else:
logger.warning(f"未找到 unique_id={unique_id} 的缓存实例")
return removed_count
# Global file-preloaded assistant manager instance
_global_agent_manager: Optional[FileLoadedAgentManager] = None
def get_global_agent_manager() -> FileLoadedAgentManager:
"""获取全局文件预加载助手管理器实例"""
global _global_agent_manager
if _global_agent_manager is None:
_global_agent_manager = FileLoadedAgentManager()
return _global_agent_manager
def init_global_agent_manager(max_cached_agents: int = 20):
"""初始化全局文件预加载助手管理器"""
global _global_agent_manager
_global_agent_manager = FileLoadedAgentManager(max_cached_agents)
return _global_agent_manager

View File

@ -1,287 +0,0 @@
# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import logging
import os
import time
from typing import Dict, Iterator, List, Literal, Optional, Union
from qwen_agent.agents import Assistant
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
from qwen_agent.llm.oai import TextChatAtOAI
from qwen_agent.tools import BaseTool
from agent.custom_mcp_manager import CustomMCPManager
import logging
logger = logging.getLogger('app')
# Set up the tool logger
tool_logger = logging.getLogger('app')
class ModifiedAssistant(Assistant):
"""
修改后的 Assistant 子类改变循环判断逻辑
- 原始逻辑如果没有使用工具立即退出循环
- 修改后逻辑如果没有使用工具调用模型判断回答是否完整如果不完整则继续循环
"""
def _is_retryable_error(self, error: Exception) -> bool:
"""判断错误是否可重试
Args:
error: 异常对象
Returns:
bool: 是否可重试
"""
error_str = str(error).lower()
retryable_indicators = [
'502', '500', '503', '504', # HTTP error codes
'internal server error', # internal server error
'timeout', # timeout
'connection', # connection errors
'network', # network errors
'rate', # rate limiting and related errors
'quota', # quota limits
'service unavailable', # service unavailable
'provider returned error', # provider errors
'model service error', # model service errors
'temporary', # temporary errors
'retry' # explicit retry hints
]
return any(indicator in error_str for indicator in retryable_indicators)
def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwargs) -> str:
"""重写工具调用方法,添加调试信息"""
if tool_name not in self.function_map:
error_msg = f'Tool {tool_name} does not exist. Available tools: {list(self.function_map.keys())}'
tool_logger.error(error_msg)
return error_msg
tool = self.function_map[tool_name]
try:
tool_logger.info(f"开始调用工具: {tool_name} {tool_args}")
start_time = time.time()
# Call the parent class's _call_tool method
tool_result = super()._call_tool(tool_name, tool_args, **kwargs)
end_time = time.time()
tool_logger.info(f"工具 {tool_name} 执行完成,耗时: {end_time - start_time:.2f}秒 结果长度: {len(tool_result) if tool_result else 0}")
# 打印部分结果内容(避免过长)
if tool_result and len(tool_result) > 0:
preview = tool_result[:200] if len(tool_result) > 200 else tool_result
tool_logger.debug(f"工具 {tool_name} 结果预览: {preview}...")
return tool_result
except Exception as ex:
end_time = time.time()
tool_logger.error(f"工具调用异常,耗时: {end_time - start_time:.2f}秒 异常类型: {type(ex).__name__} {str(ex)}")
# 打印完整的堆栈跟踪
import traceback
tool_logger.error(f"堆栈跟踪:\n{traceback.format_exc()}")
# 返回详细的错误信息
error_message = f'An error occurred when calling tool {tool_name}: {type(ex).__name__}: {str(ex)}'
return error_message
def _init_tool(self, tool: Union[str, Dict, BaseTool]):
"""重写工具初始化方法使用CustomMCPManager处理MCP服务器配置"""
if isinstance(tool, BaseTool):
# Handle BaseTool instances
tool_name = tool.name
if tool_name in self.function_map:
tool_logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
self.function_map[tool_name] = tool
elif isinstance(tool, dict) and 'mcpServers' in tool:
# Use CustomMCPManager to handle MCP server configs (supports headers)
tools = CustomMCPManager().initConfig(tool)
for tool in tools:
tool_name = tool.name
if tool_name in self.function_map:
tool_logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
self.function_map[tool_name] = tool
else:
# Fall back to the parent class's handling
super()._init_tool(tool)
def _call_llm_with_retry(self, messages: List[Message], functions=None, extra_generate_cfg=None, max_retries: int = 5) -> Iterator:
"""带重试机制的LLM调用
Args:
messages: 消息列表
functions: 函数列表
extra_generate_cfg: 额外生成配置
max_retries: 最大重试次数
Returns:
LLM响应流
Raises:
Exception: 重试次数耗尽后重新抛出原始异常
"""
for attempt in range(max_retries):
try:
return self._call_llm(messages=messages, functions=functions, extra_generate_cfg=extra_generate_cfg)
except Exception as e:
# Check whether the error is retryable
if self._is_retryable_error(e) and attempt < max_retries - 1:
delay = 2 ** attempt # Exponential backoff: 1s, 2s, 4s
tool_logger.warning(f"LLM call failed (attempt {attempt + 1}/{max_retries}), retrying in {delay}s: {str(e)}")
time.sleep(delay)
continue
else:
# Non-retryable error, or the maximum number of retries has been reached
if attempt > 0:
tool_logger.error(f"LLM call retries failed; reached the maximum of {max_retries} retries")
raise
def _run(self, messages: List[Message], lang: Literal['en', 'zh', 'ja'] = 'en', **kwargs) -> Iterator[List[Message]]:
message_list = copy.deepcopy(messages)
response = []
# Keep the original limit on the maximum number of LLM calls
total_num_llm_calls_available = self.MAX_LLM_CALL_PER_RUN if hasattr(self, 'MAX_LLM_CALL_PER_RUN') else 100
num_llm_calls_available = total_num_llm_calls_available
while num_llm_calls_available > 0:
num_llm_calls_available -= 1
extra_generate_cfg = {'lang': lang}
if kwargs.get('seed') is not None:
extra_generate_cfg['seed'] = kwargs['seed']
output_stream = self._call_llm_with_retry(messages=message_list,
functions=[func.function for func in self.function_map.values()],
extra_generate_cfg=extra_generate_cfg)
output: List[Message] = []
for output in output_stream:
if output:
yield response + output
if output:
response.extend(output)
message_list.extend(output)
# Handle tool calls
used_any_tool = False
for out in output:
use_tool, tool_name, tool_args, _ = self._detect_tool(out)
if use_tool:
tool_result = self._call_tool(tool_name, tool_args, messages=message_list, **kwargs)
# Validate the tool result
if not tool_result:
tool_logger.warning(f"工具 {tool_name} 返回空结果")
tool_result = f"Tool {tool_name} completed execution but returned empty result"
elif tool_result.startswith('An error occurred when calling tool') or tool_result.startswith('工具调用失败'):
tool_logger.error(f"工具 {tool_name} 调用失败: {tool_result}")
fn_msg = Message(role=FUNCTION,
name=tool_name,
content=tool_result,
extra={'function_id': out.extra.get('function_id', '1')})
message_list.append(fn_msg)
response.append(fn_msg)
yield response
used_any_tool = True
# If any tool was used, continue the loop
if not used_any_tool:
break
# Check whether the loop exited because the call budget was exhausted
if num_llm_calls_available == 0:
# Choose the error message based on language
if lang == 'zh':
error_message = "工具调用超出限制"
elif lang == 'ja':
error_message = "ツール呼び出しが制限を超えました。"
else:
error_message = "Tool calls exceeded limit"
tool_logger.error(error_message)
error_msg = Message(
role=ASSISTANT,
content=error_message,
)
response.append(error_msg)
yield response
# Utility functions
def read_system_prompt():
"""读取通用的无状态系统prompt"""
with open("./prompt/system_prompt_default.md", "r", encoding="utf-8") as f:
return f.read().strip()
def read_mcp_settings():
"""读取MCP工具配置"""
with open("./mcp/mcp_settings.json", "r") as f:
mcp_settings_json = json.load(f)
return mcp_settings_json
def update_agent_llm(agent, model_name: str, api_key: str = None, model_server: str = None, generate_cfg: Dict = None):
"""动态更新助手实例的LLM和配置支持从接口传入参数"""
# 获取基础配置
llm_config = {
"model": model_name,
"api_key": api_key,
"model_server": model_server,
"generate_cfg": generate_cfg if generate_cfg else {}
}
# Create the LLM instance
llm_instance = TextChatAtOAI(llm_config)
# Set the LLM dynamically
agent.llm = llm_instance
return agent
# Backward compatibility: keep the original initialization function interface
def init_modified_agent_service_with_files(rag_cfg=None,
model_name="qwen3-next", api_key=None,
model_server=None, generate_cfg=None,
system_prompt=None, mcp=None):
"""创建支持预加载文件的修改版助手实例"""
system = system_prompt if system_prompt else read_system_prompt()
tools = mcp if mcp else read_mcp_settings()
llm_config = {
"model": model_name,
"api_key": api_key,
"model_server": model_server,
"generate_cfg": generate_cfg if generate_cfg else {}
}
# Create the LLM instance
llm_instance = TextChatAtOAI(llm_config)
bot = ModifiedAssistant(
llm=llm_instance,
name="修改版数据检索助手",
description="基于智能判断循环终止的助手",
system_message=system,
function_list=tools,
)
return bot

View File

@ -18,17 +18,13 @@ import hashlib
import time
import json
import asyncio
from typing import Dict, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional
import threading
from collections import defaultdict
from qwen_agent.agents import Assistant
import logging
logger = logging.getLogger('app')
from agent.modified_assistant import init_modified_agent_service_with_files, update_agent_llm
from agent.deep_assistant import init_agent
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
@ -131,7 +127,7 @@ class ShardedAgentManager:
system_prompt: Optional[str] = None,
mcp_settings: Optional[List[Dict]] = None,
robot_type: Optional[str] = "general_agent",
user_identifier: Optional[str] = None) -> Assistant:
user_identifier: Optional[str] = None):
"""获取或创建文件预加载的助手实例"""
# 更新请求统计
@ -159,10 +155,6 @@ class ShardedAgentManager:
if cache_key in shard['agents']:
self._update_access_time(shard, cache_key)
agent = shard['agents'][cache_key]
# Dynamically update the LLM configuration and system settings
update_agent_llm(agent, model_name, api_key, model_server, generate_cfg)
# Update cache-hit statistics
with self._stats_lock:
self._global_stats['cache_hits'] += 1
@ -184,7 +176,6 @@ class ShardedAgentManager:
if cache_key in shard['agents']:
self._update_access_time(shard, cache_key)
agent = shard['agents'][cache_key]
update_agent_llm(agent, model_name, api_key, model_server, generate_cfg)
with self._stats_lock:
self._global_stats['cache_hits'] += 1
@ -199,7 +190,7 @@ class ShardedAgentManager:
logger.info(f"分片创建新的助手实例缓存: {cache_key}, bot_id: {bot_id}, shard: {shard_index}")
current_time = time.time()
agent = init_modified_agent_service_with_files(
agent = await init_agent(
model_name=model_name,
api_key=api_key,
model_server=model_server,

1371
poetry.lock generated

File diff suppressed because it is too large

View File

@ -6,12 +6,11 @@ authors = [
{name = "朱潮",email = "zhuchaowe@users.noreply.github.com"}
]
readme = "README.md"
requires-python = ">=3.12.0"
requires-python = ">=3.12,<4.0"
dependencies = [
"fastapi==0.116.1",
"uvicorn==0.35.0",
"requests==2.32.5",
"qwen-agent[mcp,rag]==0.0.31",
"pydantic==2.10.5",
"python-dateutil==2.8.2",
"torch==2.2.0",
@ -27,6 +26,9 @@ dependencies = [
"chardet>=5.0.0",
"psutil (>=7.1.3,<8.0.0)",
"uvloop (>=0.22.1,<0.23.0)",
"deepagents (>=0.3.0,<0.4.0)",
"langchain-mcp-adapters (>=0.2.1,<0.3.0)",
"langchain-openai (>=1.1.1,<2.0.0)",
]
[tool.poetry.requires-plugins]

View File

@ -1,116 +1,117 @@
aiofiles==25.1.0 ; python_version >= "3.12"
aiohappyeyeballs==2.6.1 ; python_version >= "3.12"
aiohttp==3.13.1 ; python_version >= "3.12"
aiosignal==1.4.0 ; python_version >= "3.12"
annotated-types==0.7.0 ; python_version >= "3.12"
anyio==4.11.0 ; python_version >= "3.12"
attrs==25.4.0 ; python_version >= "3.12"
beautifulsoup4==4.14.2 ; python_full_version >= "3.12.0"
certifi==2025.10.5 ; python_version >= "3.12"
cffi==2.0.0 ; python_version >= "3.12" and platform_python_implementation != "PyPy"
chardet==5.2.0 ; python_version >= "3.12"
charset-normalizer==3.4.4 ; python_version >= "3.12"
click==8.3.0 ; python_version >= "3.12"
colorama==0.4.6 ; python_version >= "3.12" and platform_system == "Windows"
cryptography==46.0.3 ; python_version >= "3.12"
dashscope==1.24.6 ; python_full_version >= "3.12.0"
distro==1.9.0 ; python_version >= "3.12"
dotenv==0.9.9 ; python_full_version >= "3.12.0"
et-xmlfile==2.0.0 ; python_version >= "3.12"
eval-type-backport==0.2.2 ; python_version >= "3.12"
fastapi==0.116.1 ; python_version >= "3.12"
filelock==3.20.0 ; python_version >= "3.12"
frozenlist==1.8.0 ; python_version >= "3.12"
fsspec==2025.9.0 ; python_version >= "3.12"
h11==0.16.0 ; python_version >= "3.12"
hf-xet==1.1.10 ; python_version >= "3.12" and (platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "arm64" or platform_machine == "aarch64")
httpcore==1.0.9 ; python_version >= "3.12"
httpx-sse==0.4.3 ; python_version >= "3.12"
httpx==0.28.1 ; python_version >= "3.12"
huey==2.5.3 ; python_full_version >= "3.12.0"
huggingface-hub==0.35.3 ; python_full_version >= "3.12.0"
idna==3.11 ; python_version >= "3.12"
jieba==0.42.1 ; python_full_version >= "3.12.0"
jinja2==3.1.6 ; python_version >= "3.12"
jiter==0.11.1 ; python_version >= "3.12"
joblib==1.5.2 ; python_version >= "3.12"
json5==0.12.1 ; python_full_version >= "3.12.0"
jsonlines==4.0.0 ; python_version >= "3.12"
jsonschema-specifications==2025.9.1 ; python_version >= "3.12"
jsonschema==4.25.1 ; python_version >= "3.12"
lxml==6.0.2 ; python_version >= "3.12"
markupsafe==3.0.3 ; python_version >= "3.12"
mcp==1.12.4 ; python_version >= "3.12"
mpmath==1.3.0 ; python_full_version >= "3.12.0"
multidict==6.7.0 ; python_version >= "3.12"
networkx==3.5 ; python_version >= "3.12"
numpy==1.26.4 ; python_version >= "3.12"
nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
nvidia-cudnn-cu12==8.9.2.26 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
nvidia-nccl-cu12==2.19.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
nvidia-nvjitlink-cu12==12.9.86 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
openai==2.5.0 ; python_version >= "3.12"
openpyxl==3.1.5 ; python_version >= "3.12"
packaging==25.0 ; python_version >= "3.12"
pandas==2.3.3 ; python_version >= "3.12"
pdfminer-six==20250506 ; python_version >= "3.12"
pdfplumber==0.11.7 ; python_version >= "3.12"
pillow==12.0.0 ; python_version >= "3.12"
propcache==0.4.1 ; python_version >= "3.12"
psutil==7.1.3 ; python_version >= "3.12"
pycparser==2.23 ; platform_python_implementation != "PyPy" and implementation_name != "PyPy" and python_version >= "3.12"
pydantic-core==2.27.2 ; python_version >= "3.12"
pydantic-settings==2.11.0 ; python_version >= "3.12"
pydantic==2.10.5 ; python_version >= "3.12"
pypdfium2==4.30.0 ; python_version >= "3.12"
python-dateutil==2.8.2 ; python_version >= "3.12"
python-docx==1.2.0 ; python_version >= "3.12"
python-dotenv==1.1.1 ; python_version >= "3.12"
python-multipart==0.0.20 ; python_version >= "3.12"
python-pptx==1.0.2 ; python_version >= "3.12"
pytz==2025.2 ; python_full_version >= "3.12.0"
pywin32==311 ; python_full_version >= "3.12.0" and sys_platform == "win32"
pyyaml==6.0.3 ; python_version >= "3.12"
qwen-agent==0.0.31 ; python_full_version >= "3.12.0"
rank-bm25==0.2.2 ; python_full_version >= "3.12.0"
referencing==0.37.0 ; python_version >= "3.12"
regex==2025.9.18 ; python_version >= "3.12"
requests==2.32.5 ; python_version >= "3.12"
rpds-py==0.27.1 ; python_version >= "3.12"
safetensors==0.6.2 ; python_version >= "3.12"
scikit-learn==1.7.2 ; python_version >= "3.12"
scipy==1.16.2 ; python_version >= "3.12"
sentence-transformers==5.1.1 ; python_version >= "3.12"
six==1.17.0 ; python_version >= "3.12"
sniffio==1.3.1 ; python_version >= "3.12"
snowballstemmer==3.0.1 ; python_version >= "3.12"
soupsieve==2.8 ; python_version >= "3.12"
sse-starlette==3.0.2 ; python_version >= "3.12"
starlette==0.47.3 ; python_version >= "3.12"
sympy==1.14.0 ; python_version >= "3.12"
tabulate==0.9.0 ; python_version >= "3.12"
threadpoolctl==3.6.0 ; python_version >= "3.12"
tiktoken==0.12.0 ; python_version >= "3.12"
tokenizers==0.22.1 ; python_version >= "3.12"
torch==2.2.0 ; python_full_version >= "3.12.0"
tqdm==4.67.1 ; python_version >= "3.12"
transformers==4.57.1 ; python_full_version >= "3.12.0"
triton==2.2.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0"
typing-extensions==4.15.0 ; python_version >= "3.12"
typing-inspection==0.4.2 ; python_version >= "3.12"
tzdata==2025.2 ; python_version >= "3.12"
urllib3==2.5.0 ; python_version >= "3.12"
uvicorn==0.35.0 ; python_version >= "3.12"
uvloop==0.22.1 ; python_full_version >= "3.12.0"
websocket-client==1.9.0 ; python_version >= "3.12"
xlrd==2.0.2 ; python_version >= "3.12"
xlsxwriter==3.2.9 ; python_version >= "3.12"
yarl==1.22.0 ; python_version >= "3.12"
aiofiles==25.1.0 ; python_version >= "3.12" and python_version < "4.0"
aiohappyeyeballs==2.6.1 ; python_version >= "3.12" and python_version < "4.0"
aiohttp==3.13.1 ; python_version >= "3.12" and python_version < "4.0"
aiosignal==1.4.0 ; python_version >= "3.12" and python_version < "4.0"
annotated-types==0.7.0 ; python_version >= "3.12" and python_version < "4.0"
anthropic==0.75.0 ; python_version >= "3.12" and python_version < "4.0"
anyio==4.11.0 ; python_version >= "3.12" and python_version < "4.0"
attrs==25.4.0 ; python_version >= "3.12" and python_version < "4.0"
bracex==2.6 ; python_version >= "3.12" and python_version < "4.0"
certifi==2025.10.5 ; python_version >= "3.12" and python_version < "4.0"
chardet==5.2.0 ; python_version >= "3.12" and python_version < "4.0"
charset-normalizer==3.4.4 ; python_version >= "3.12" and python_version < "4.0"
click==8.3.0 ; python_version >= "3.12" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Windows"
deepagents==0.3.0 ; python_version >= "3.12" and python_version < "4.0"
distro==1.9.0 ; python_version >= "3.12" and python_version < "4.0"
docstring-parser==0.17.0 ; python_version >= "3.12" and python_version < "4.0"
et-xmlfile==2.0.0 ; python_version >= "3.12" and python_version < "4.0"
fastapi==0.116.1 ; python_version >= "3.12" and python_version < "4.0"
filelock==3.20.0 ; python_version >= "3.12" and python_version < "4.0"
frozenlist==1.8.0 ; python_version >= "3.12" and python_version < "4.0"
fsspec==2025.9.0 ; python_version >= "3.12" and python_version < "4.0"
h11==0.16.0 ; python_version >= "3.12" and python_version < "4.0"
hf-xet==1.1.10 ; python_version >= "3.12" and python_version < "4.0" and (platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "arm64" or platform_machine == "aarch64")
httpcore==1.0.9 ; python_version >= "3.12" and python_version < "4.0"
httpx-sse==0.4.3 ; python_version >= "3.12" and python_version < "4.0"
httpx==0.28.1 ; python_version >= "3.12" and python_version < "4.0"
huey==2.5.3 ; python_version >= "3.12" and python_version < "4.0"
huggingface-hub==0.35.3 ; python_version >= "3.12" and python_version < "4.0"
idna==3.11 ; python_version >= "3.12" and python_version < "4.0"
jinja2==3.1.6 ; python_version >= "3.12" and python_version < "4.0"
jiter==0.11.1 ; python_version >= "3.12" and python_version < "4.0"
joblib==1.5.2 ; python_version >= "3.12" and python_version < "4.0"
jsonpatch==1.33 ; python_version >= "3.12" and python_version < "4.0"
jsonpointer==3.0.0 ; python_version >= "3.12" and python_version < "4.0"
jsonschema-specifications==2025.9.1 ; python_version >= "3.12" and python_version < "4.0"
jsonschema==4.25.1 ; python_version >= "3.12" and python_version < "4.0"
langchain-anthropic==1.2.0 ; python_version >= "3.12" and python_version < "4.0"
langchain-core==1.1.3 ; python_version >= "3.12" and python_version < "4.0"
langchain-mcp-adapters==0.2.1 ; python_version >= "3.12" and python_version < "4.0"
langchain-openai==1.1.1 ; python_version >= "3.12" and python_version < "4.0"
langchain==1.1.3 ; python_version >= "3.12" and python_version < "4.0"
langgraph-checkpoint==3.0.1 ; python_version >= "3.12" and python_version < "4.0"
langgraph-prebuilt==1.0.5 ; python_version >= "3.12" and python_version < "4.0"
langgraph-sdk==0.2.15 ; python_version >= "3.12" and python_version < "4.0"
langgraph==1.0.4 ; python_version >= "3.12" and python_version < "4.0"
langsmith==0.4.59 ; python_version >= "3.12" and python_version < "4.0"
markupsafe==3.0.3 ; python_version >= "3.12" and python_version < "4.0"
mcp==1.12.4 ; python_version >= "3.12" and python_version < "4.0"
mpmath==1.3.0 ; python_version >= "3.12" and python_version < "4.0"
multidict==6.7.0 ; python_version >= "3.12" and python_version < "4.0"
networkx==3.5 ; python_version >= "3.12" and python_version < "4.0"
numpy==1.26.4 ; python_version >= "3.12" and python_version < "4.0"
nvidia-cublas-cu12==12.1.3.1 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-cupti-cu12==12.1.105 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-nvrtc-cu12==12.1.105 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-runtime-cu12==12.1.105 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cudnn-cu12==8.9.2.26 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cufft-cu12==11.0.2.54 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-curand-cu12==10.3.2.106 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusolver-cu12==11.4.5.107 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparse-cu12==12.1.0.106 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nccl-cu12==2.19.3 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvjitlink-cu12==12.9.86 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvtx-cu12==12.1.105 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
openai==2.5.0 ; python_version >= "3.12" and python_version < "4.0"
openpyxl==3.1.5 ; python_version >= "3.12" and python_version < "4.0"
orjson==3.11.5 ; python_version >= "3.12" and python_version < "4.0"
ormsgpack==1.12.0 ; python_version >= "3.12" and python_version < "4.0"
packaging==25.0 ; python_version >= "3.12" and python_version < "4.0"
pandas==2.3.3 ; python_version >= "3.12" and python_version < "4.0"
pillow==12.0.0 ; python_version >= "3.12" and python_version < "4.0"
propcache==0.4.1 ; python_version >= "3.12" and python_version < "4.0"
psutil==7.1.3 ; python_version >= "3.12" and python_version < "4.0"
pydantic-core==2.27.2 ; python_version >= "3.12" and python_version < "4.0"
pydantic-settings==2.11.0 ; python_version >= "3.12" and python_version < "4.0"
pydantic==2.10.5 ; python_version >= "3.12" and python_version < "4.0"
python-dateutil==2.8.2 ; python_version >= "3.12" and python_version < "4.0"
python-dotenv==1.1.1 ; python_version >= "3.12" and python_version < "4.0"
python-multipart==0.0.20 ; python_version >= "3.12" and python_version < "4.0"
pytz==2025.2 ; python_version >= "3.12" and python_version < "4.0"
pywin32==311 ; python_version >= "3.12" and python_version < "4.0" and sys_platform == "win32"
pyyaml==6.0.3 ; python_version >= "3.12" and python_version < "4.0"
referencing==0.37.0 ; python_version >= "3.12" and python_version < "4.0"
regex==2025.9.18 ; python_version >= "3.12" and python_version < "4.0"
requests-toolbelt==1.0.0 ; python_version >= "3.12" and python_version < "4.0"
requests==2.32.5 ; python_version >= "3.12" and python_version < "4.0"
rpds-py==0.27.1 ; python_version >= "3.12" and python_version < "4.0"
safetensors==0.6.2 ; python_version >= "3.12" and python_version < "4.0"
scikit-learn==1.7.2 ; python_version >= "3.12" and python_version < "4.0"
scipy==1.16.2 ; python_version >= "3.12" and python_version < "4.0"
sentence-transformers==5.1.1 ; python_version >= "3.12" and python_version < "4.0"
six==1.17.0 ; python_version >= "3.12" and python_version < "4.0"
sniffio==1.3.1 ; python_version >= "3.12" and python_version < "4.0"
sse-starlette==3.0.2 ; python_version >= "3.12" and python_version < "4.0"
starlette==0.47.3 ; python_version >= "3.12" and python_version < "4.0"
sympy==1.14.0 ; python_version >= "3.12" and python_version < "4.0"
tenacity==9.1.2 ; python_version >= "3.12" and python_version < "4.0"
threadpoolctl==3.6.0 ; python_version >= "3.12" and python_version < "4.0"
tiktoken==0.12.0 ; python_version >= "3.12" and python_version < "4.0"
tokenizers==0.22.1 ; python_version >= "3.12" and python_version < "4.0"
torch==2.2.0 ; python_version >= "3.12" and python_version < "4.0"
tqdm==4.67.1 ; python_version >= "3.12" and python_version < "4.0"
transformers==4.57.1 ; python_version >= "3.12" and python_version < "4.0"
triton==2.2.0 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
typing-extensions==4.15.0 ; python_version >= "3.12" and python_version < "4.0"
typing-inspection==0.4.2 ; python_version >= "3.12" and python_version < "4.0"
tzdata==2025.2 ; python_version >= "3.12" and python_version < "4.0"
urllib3==2.5.0 ; python_version >= "3.12" and python_version < "4.0"
uuid-utils==0.12.0 ; python_version >= "3.12" and python_version < "4.0"
uvicorn==0.35.0 ; python_version >= "3.12" and python_version < "4.0"
uvloop==0.22.1 ; python_version >= "3.12" and python_version < "4.0"
wcmatch==10.1 ; python_version >= "3.12" and python_version < "4.0"
xlrd==2.0.2 ; python_version >= "3.12" and python_version < "4.0"
xxhash==3.6.0 ; python_version >= "3.12" and python_version < "4.0"
yarl==1.22.0 ; python_version >= "3.12" and python_version < "4.0"
zstandard==0.25.0 ; python_version >= "3.12" and python_version < "4.0"

View File

@ -17,9 +17,10 @@ from agent.prompt_loader import load_guideline_prompt
from utils.fastapi_utils import (
process_messages, extract_block_from_system_prompt, format_messages_to_chat_history,
create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
_get_optimal_batch_size, process_guideline, get_content_from_messages, call_preamble_llm, get_preamble_text, get_language_text,
process_guideline, call_preamble_llm, get_preamble_text, get_language_text,
create_stream_chunk
)
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
router = APIRouter()
@ -180,9 +181,17 @@ Action: Provide concise, friendly, and personified natural responses.
all_results = await asyncio.gather(*tasks, return_exceptions=True)
# Process the results
agent = all_results[0] if len(all_results) >0 else None # result of agent creation
agent = all_results[0] if len(all_results) > 0 else None # result of agent creation
guideline_reasoning = all_results[1] if len(all_results) >1 else ""
# Check whether agent is an exception object
if isinstance(agent, Exception):
logger.error(f"Error creating agent: {agent}")
raise agent
guideline_reasoning = all_results[1] if len(all_results) > 1 else ""
if isinstance(guideline_reasoning, Exception):
logger.error(f"Error in guideline processing: {guideline_reasoning}")
guideline_reasoning = ""
if guideline_prompt or guideline_reasoning:
logger.info("Guideline Prompt: %s, Reasoning: %s",
guideline_prompt.replace('\n', '\\n') if guideline_prompt else "None",
@ -243,7 +252,7 @@ async def enhanced_generate_stream_response(
preamble_text = await preamble_task
# Only emit when preamble_text is non-empty and not "<empty>"
if preamble_text and preamble_text.strip() and preamble_text != "<empty>":
preamble_content = get_content_from_messages([{"role": "preamble","content": preamble_text + "\n"}], tool_response=tool_response)
preamble_content = f"[PREAMBLE]\n{preamble_text}\n"
chunk_data = create_stream_chunk(f"chatcmpl-preamble", model_name, preamble_content)
yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
logger.info(f"Stream mode: Generated preamble text ({len(preamble_text)} chars)")
@ -257,7 +266,7 @@ async def enhanced_generate_stream_response(
# Send guideline_reasoning immediately
if guideline_reasoning:
guideline_content = get_content_from_messages([{"role": "assistant","reasoning_content": guideline_reasoning+ "\n"}], tool_response=tool_response)
guideline_content = f"[THINK]\n{guideline_reasoning}\n"
chunk_data = create_stream_chunk(f"chatcmpl-guideline", model_name, guideline_content)
yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
@@ -270,20 +279,41 @@ async def enhanced_generate_stream_response(
else:
final_messages = append_assistant_last_message(final_messages, f"\n\nlanguage:{get_language_text(language)}")
logger.debug(f"Final messages for agent (showing first 2): {final_messages[:2]}")
# Phase 3: stream the agent response
logger.info(f"Starting agent stream response")
accumulated_content = ""
chunk_id = 0
for response in agent.run(messages=final_messages):
previous_content = accumulated_content
accumulated_content = get_content_from_messages(response, tool_response=tool_response)
# Compute the newly added content
if accumulated_content.startswith(previous_content):
new_content = accumulated_content[len(previous_content):]
else:
new_content = accumulated_content
previous_content = ""
message_tag = ""
function_name = ""
tool_args = ""
async for msg,metadata in agent.astream({"messages": final_messages}, stream_mode="messages"):
new_content = ""
if isinstance(msg, AIMessageChunk):
# Check whether there is a tool call
if msg.tool_call_chunks: # check the tool-call chunks
if message_tag != "TOOL_CALL":
message_tag = "TOOL_CALL"
if msg.tool_call_chunks[0]["name"]:
function_name = msg.tool_call_chunks[0]["name"]
if msg.tool_call_chunks[0]["args"]:
tool_args += msg.tool_call_chunks[0]["args"]
elif len(msg.content)>0:
if message_tag != "ANSWER":
message_tag = "ANSWER"
new_content = f"[{message_tag}]\n{msg.text}"
elif message_tag == "ANSWER":
new_content = msg.text
elif message_tag == "TOOL_CALL" and \
(
("finish_reason" in msg.response_metadata and msg.response_metadata["finish_reason"] == "tool_calls") or \
("stop_reason" in msg.response_metadata and msg.response_metadata["stop_reason"] == "tool_use")
):
new_content = f"[{message_tag}] {function_name}\n{tool_args}"
message_tag = "TOOL_CALL"
elif isinstance(msg, ToolMessage) and len(msg.content)>0:
message_tag = "TOOL_RESPONSE"
new_content = f"[{message_tag}] {msg.name}\n{msg.text}"
# Only send a chunk when there is new content
if new_content:
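Because the stream interleaves tagged segments ([PREAMBLE], [THINK], [ANSWER], [TOOL_CALL], [TOOL_RESPONSE]), a consumer has to split on those markers. A minimal client-side sketch under that assumption; the tool name and texts are invented, and this parsing helper is not part of the commit:

```
import re

# Markers emitted by the streaming code above.
TAG_RE = re.compile(r"\[(PREAMBLE|THINK|ANSWER|TOOL_CALL|TOOL_RESPONSE)\]")

def split_tagged_stream(full_text: str) -> list[tuple[str, str]]:
    """Split accumulated stream text into (tag, body) pairs."""
    matches = list(TAG_RE.finditer(full_text))
    segments = []
    for i, m in enumerate(matches):
        end = matches[i + 1].start() if i + 1 < len(matches) else len(full_text)
        segments.append((m.group(1), full_text[m.end():end].strip()))
    return segments

# Example: all streamed delta contents concatenated into one string.
accumulated = (
    "[THINK]\nLook up the coffee price first.\n"
    "[TOOL_CALL] get_menu_price\n{\"item\": \"coffee\"}\n"
    "[ANSWER]\nA cup of coffee costs 300 yen.\n"
)
for tag, body in split_tagged_stream(accumulated):
    print(tag, "->", body)
```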
@@ -379,36 +409,44 @@ async def create_agent_and_generate_response(
# Prepare the final messages
final_messages = messages.copy()
pre_message_list = []
if guideline_reasoning:
# Split guideline_reasoning on ### and take the last segment as the Guidelines
guidelines_text = guideline_reasoning.split('###')[-1].strip() if guideline_reasoning else ""
final_messages = append_assistant_last_message(final_messages, f"language:{get_language_text(language)}\n\nGuidelines:\n{guidelines_text}\n I will follow these guidelines step by step.")
pre_message_list.append({"role": "assistant","reasoning_content": guideline_reasoning+ "\n"})
else:
final_messages = append_assistant_last_message(final_messages, f"\n\nlanguage:{get_language_text(language)}")
# Non-streaming response
agent_responses = agent.run_nonstream(final_messages)
final_responses = pre_message_list + agent_responses
if final_responses and len(final_responses) > 0:
# Use get_content_from_messages to process the response, with support for the tool_response parameter
content = get_content_from_messages(final_responses, tool_response=tool_response)
agent_responses = await agent.ainvoke({"messages": final_messages})
append_messages = agent_responses["messages"][len(final_messages):]
# agent_responses = agent.run_nonstream(final_messages)
response_text = ""
if guideline_reasoning:
response_text += "[THINK]\n"+guideline_reasoning+ "\n"
for msg in append_messages:
if isinstance(msg,AIMessage):
if len(msg.text)>0:
response_text += "[ANSWER]\n"+msg.text+ "\n"
if len(msg.tool_calls)>0:
response_text += "".join([f"[TOOL_CALL] {tool['name']}\n{json.dumps(tool["args"]) if isinstance(tool["args"],dict) else tool["args"]}\n" for tool in msg.tool_calls])
elif isinstance(msg,ToolMessage) and tool_response:
response_text += f"[TOOL_RESPONSE] {msg.name}\n{msg.text}\n"
if len(response_text) > 0:
# Build an OpenAI-format response
return ChatResponse(
choices=[{
"index": 0,
"message": {
"role": "assistant",
"content": content
"content": response_text
},
"finish_reason": "stop"
}],
usage={
"prompt_tokens": sum(len(msg.get("content", "")) for msg in messages),
"completion_tokens": len(content),
"total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(content)
"completion_tokens": len(response_text),
"total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(response_text)
}
)
else:
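To make the tagged non-streaming output concrete, a hedged example of the response_text assembled above for one tool-using turn (the tool name, arguments, and texts are invented for illustration):

```
# Purely illustrative; only the [TAG] layout mirrors the code above.
response_text = (
    "[THINK]\nThe user asks for a price; call the menu tool first.\n"
    "[TOOL_CALL] get_menu_price\n{\"item\": \"coffee\"}\n"
    "[TOOL_RESPONSE] get_menu_price\n{\"price\": 300, \"currency\": \"JPY\"}\n"
    "[ANSWER]\nA cup of coffee costs 300 yen.\n"
)
```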

View File

@@ -16,7 +16,7 @@ try:
except ImportError:
def apply_optimization_profile(profile):
return {"profile": profile, "status": "system_optimizer not available"}
from utils.fastapi_utils import get_content_from_messages
from embedding import get_model_manager
from pydantic import BaseModel
import logging

View File

@@ -22,7 +22,6 @@ from .dataset_manager import (
)
from .project_manager import (
get_content_from_messages,
generate_project_readme,
save_project_readme,
get_project_status,
@@ -31,20 +30,7 @@ from .project_manager import (
get_project_stats
)
# Import agent management modules
# Note: These have been moved to agent package
# from .file_loaded_agent_manager import (
# get_global_agent_manager,
# init_global_agent_manager
# )
# Import optimized modules
# Note: These have been moved to agent package
# from .sharded_agent_manager import (
# ShardedAgentManager,
# get_global_sharded_agent_manager,
# init_global_sharded_agent_manager
# )
from .connection_pool import (
HTTPConnectionPool,
@@ -153,7 +139,6 @@ __all__ = [
'remove_dataset_directory_by_key',
# project_manager
'get_content_from_messages',
'generate_project_readme',
'save_project_readme',
'get_project_status',
@@ -161,10 +146,6 @@ __all__ = [
'list_projects',
'get_project_stats',
# file_loaded_agent_manager (moved to agent package)
# 'get_global_agent_manager',
# 'init_global_agent_manager',
# agent_pool
'AgentPool',
'get_agent_pool',

View File

@@ -6,10 +6,14 @@ import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Optional, Union, Any
import aiohttp
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
from qwen_agent.llm.oai import TextChatAtOAI
from fastapi import HTTPException
import logging
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain.chat_models import init_chat_model
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
logger = logging.getLogger('app')
@@ -19,6 +23,16 @@ thread_pool = ThreadPoolExecutor(max_workers=10)
# Create a concurrency semaphore to limit the number of simultaneous API calls
api_semaphore = asyncio.Semaphore(8) # at most 8 concurrent API calls
def detect_provider(model_name,model_server):
"""根据模型名称检测提供商类型"""
model_name_lower = model_name.lower()
if any(claude_model in model_name_lower for claude_model in ["claude", "anthropic"]):
return "anthropic",model_server.replace("/v1","")
elif any(openai_model in model_name_lower for openai_model in ["gpt", "openai", "o1"]):
return "openai",model_server
else:
# Default to the OpenAI-compatible format
return "openai",model_server
def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]:
"""
@@ -93,50 +107,50 @@ def create_stream_chunk(chunk_id: str, model_name: str, content: str = None, fin
}
return chunk_data
def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
"""Extract content from qwen-agent messages with special formatting"""
full_text = ''
content = []
TOOL_CALL_S = '[TOOL_CALL]'
TOOL_RESULT_S = '[TOOL_RESPONSE]'
THOUGHT_S = '[THINK]'
ANSWER_S = '[ANSWER]'
PREAMBLE_S = '[PREAMBLE]'
# def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
# """Extract content from qwen-agent messages with special formatting"""
# full_text = ''
# content = []
# TOOL_CALL_S = '[TOOL_CALL]'
# TOOL_RESULT_S = '[TOOL_RESPONSE]'
# THOUGHT_S = '[THINK]'
# ANSWER_S = '[ANSWER]'
# PREAMBLE_S = '[PREAMBLE]'
for msg in messages:
if msg['role'] == ASSISTANT:
if msg.get('reasoning_content'):
assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
if msg.get('content'):
assert isinstance(msg['content'], str), 'Now only supports text messages'
# Filter out incomplete tool_call text from the streaming output
content_text = msg["content"]
# for msg in messages:
# if msg['role'] == ASSISTANT:
# if msg.get('reasoning_content'):
# assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
# content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
# if msg.get('content'):
# assert isinstance(msg['content'], str), 'Now only supports text messages'
# # Filter out incomplete tool_call text from the streaming output
# content_text = msg["content"]
# Use a regex to replace incomplete tool_call patterns with an empty string
# # Use a regex to replace incomplete tool_call patterns with an empty string
# Match and strip incomplete tool_call patterns
content_text = re.sub(r'<t?o?o?l?_?c?a?l?l?$', '', content_text)
# Only append when the processed content is non-empty
if content_text.strip():
content.append(f'{ANSWER_S}\n{content_text}')
if msg.get('function_call'):
content_text = msg["function_call"]["arguments"]
content_text = re.sub(r'}\n<\/?t?o?o?l?_?c?a?l?l?$', '', content_text)
if content_text.strip():
content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{content_text}')
elif msg['role'] == FUNCTION:
if tool_response:
content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
elif msg['role'] == "preamble":
content.append(f'{PREAMBLE_S}\n{msg["content"]}')
else:
raise TypeError
# # Match and strip incomplete tool_call patterns
# content_text = re.sub(r'<t?o?o?l?_?c?a?l?l?$', '', content_text)
# # Only append when the processed content is non-empty
# if content_text.strip():
# content.append(f'{ANSWER_S}\n{content_text}')
# if msg.get('function_call'):
# content_text = msg["function_call"]["arguments"]
# content_text = re.sub(r'}\n<\/?t?o?o?l?_?c?a?l?l?$', '', content_text)
# if content_text.strip():
# content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{content_text}')
# elif msg['role'] == FUNCTION:
# if tool_response:
# content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
# elif msg['role'] == "preamble":
# content.append(f'{PREAMBLE_S}\n{msg["content"]}')
# else:
# raise TypeError
if content:
full_text = '\n'.join(content)
# if content:
# full_text = '\n'.join(content)
return full_text
# return full_text
def process_messages(messages: List[Dict], language: Optional[str] = None) -> List[Dict[str, str]]:
@@ -156,13 +170,13 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
processed_messages = []
# Collect the indices of all ASSISTANT messages
assistant_indices = [i for i, msg in enumerate(messages) if msg.role == "assistant"]
assistant_indices = [i for i, msg in enumerate(messages) if msg.role == ASSISTANT]
total_assistant_messages = len(assistant_indices)
cutoff_point = max(0, total_assistant_messages - 5)
# Process each message
for i, msg in enumerate(messages):
if msg.role == "assistant":
if msg.role == ASSISTANT:
# Determine the position of the current ASSISTANT message among all ASSISTANT messages (0-based)
assistant_position = assistant_indices.index(i)
@@ -236,7 +250,7 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
assistant_content = ""
function_calls = []
tool_responses = []
tool_id = ""
for i in range(0, len(parts)):
if i % 2 == 0: # text content
text = parts[i].strip()
@@ -259,8 +273,10 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
break
if should_include:
# Wrap the TOOL_RESPONSE as a tool_result message that immediately follows the corresponding tool_use
final_messages.append({
"role": FUNCTION,
"role": TOOL,
"tool_call_id": tool_id, # keep consistent with the id of the preceding tool_use
"name": function_name,
"content": response_content
})
@@ -279,13 +295,17 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
break
if should_include:
tool_id = f"tool_id_{i}"
final_messages.append({
"role": ASSISTANT,
"content": "",
"function_call": {
"name": function_name,
"arguments": arguments
}
"tool_calls": [{
"id":tool_id,
"function": {
"name": function_name,
"arguments": arguments
}
}]
})
elif current_tag != "THINK" and current_tag != "PREAMBLE":
final_messages.append({
@ -297,7 +317,6 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
else:
# Non-assistant messages, or messages without [TOOL_RESPONSE], are appended directly
final_messages.append(msg)
print(final_messages)
return final_messages
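For orientation, a hedged sketch of one reconstructed assistant/tool pair in final_messages after this change; the id, tool name, and arguments are invented, but the keys mirror the structure built above:

```
# Illustrative only; values are made up.
example_pair = [
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [{
            "id": "tool_id_1",
            "function": {"name": "get_menu_price", "arguments": "{\"item\": \"coffee\"}"},
        }],
    },
    {
        "role": "tool",
        "tool_call_id": "tool_id_1",
        "name": "get_menu_price",
        "content": "{\"price\": 300, \"currency\": \"JPY\"}",
    },
]
```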
@@ -317,16 +336,19 @@ def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
for message in messages:
role = message.get('role', '')
content = message.get('content', '')
if role == 'user':
name = message.get('name', '')
if role == USER:
chat_history.append(f"user: {content}")
elif role == FUNCTION:
chat_history.append(f"function_response: {content}")
elif role == TOOL:
chat_history.append(f"{name} response: {content}")
elif role == ASSISTANT:
if len(content) >0:
chat_history.append(f"assistant: {content}")
if message.get('function_call'):
chat_history.append(f"function_call: {message.get('function_call').get('name')} ")
chat_history.append(f"{message.get('function_call').get('arguments')}")
if message.get('tool_calls'):
for tool_call in message.get('tool_calls'):
function_name = tool_call.get('function').get('name')
arguments = tool_call.get('function').get('arguments')
chat_history.append(f"{function_name} call: {arguments}")
recent_chat_history = chat_history[-15:] if len(chat_history) > 15 else chat_history
print(f"recent_chat_history:{recent_chat_history}")
@@ -411,30 +433,44 @@ async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
)
def _sync_call_llm(llm_config, messages) -> str:
"""同步调用LLM的辅助函数在线程池中执行"""
llm_instance = TextChatAtOAI(llm_config)
async def _sync_call_llm(llm_config, messages) -> str:
"""同步调用LLM的辅助函数在线程池中执行 - 使用LangChain"""
try:
# Set stream=False to get a non-streaming response
response = llm_instance.chat(messages=messages, stream=False)
# Create a LangChain LLM instance
model_name = llm_config.get('model')
model_server = llm_config.get('model_server')
api_key = llm_config.get('api_key')
# Detect the provider, or use the one specified
model_provider,base_url = detect_provider(model_name,model_server)
# Process the response
if isinstance(response, list) and response:
# If a list of Messages is returned, extract the content
if hasattr(response[0], 'content'):
return response[0].content
elif isinstance(response[0], dict) and 'content' in response[0]:
return response[0]['content']
# Build the model parameters
model_kwargs = {
"model": model_name,
"model_provider": model_provider,
"temperature": 0.8,
"base_url":base_url,
"api_key":api_key
}
llm_instance = init_chat_model(**model_kwargs)
# If it is a string, return it directly
if isinstance(response, str):
return response
# Convert the messages to LangChain format
langchain_messages = []
for msg in messages:
if msg['role'] == 'system':
langchain_messages.append(SystemMessage(content=msg['content']))
elif msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
# Handle other types
return str(response) if response else ""
# Call the LangChain model
response = await llm_instance.ainvoke(langchain_messages)
# Return the response content
return response.content if response.content else ""
except Exception as e:
logger.error(f"Error calling guideline LLM: {e}")
logger.error(f"Error calling guideline LLM with LangChain: {e}")
return ""
def get_language_text(language: str):
@@ -550,9 +586,8 @@ async def call_preamble_llm(chat_history: str, last_message: str, preamble_choic
try:
# Use the semaphore to limit the number of concurrent API calls
async with api_semaphore:
# Run the synchronous HTTP call in the thread pool to avoid blocking the event loop
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(thread_pool, _sync_call_llm, llm_config, messages)
# Call the async LLM function directly
response = await _sync_call_llm(llm_config, messages)
# Extract the content wrapped in ```json ... ``` from the response
json_pattern = r'```json\s*\n(.*?)\n```'
@@ -605,9 +640,8 @@ async def call_guideline_llm(chat_history: str, guidelines_prompt: str, model_na
try:
# Use the semaphore to limit the number of concurrent API calls
async with api_semaphore:
# Run the synchronous HTTP call in the thread pool to avoid blocking the event loop
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(thread_pool, _sync_call_llm, llm_config, messages)
# Call the async LLM function directly
response = await _sync_call_llm(llm_config, messages)
return response
except Exception as e:

View File

@@ -15,15 +15,6 @@ logger = logging.getLogger('app')
from utils.file_utils import get_document_preview, load_processed_files_log
def get_content_from_messages(messages: List[dict]) -> str:
"""Extract content from messages list"""
content = ""
for message in messages:
if message.get("role") == "user":
content += message.get("content", "")
return content
def generate_directory_tree(project_dir: str, unique_id: str, max_depth: int = 3) -> str:
"""Generate dataset directory tree structure for the project"""
def _build_tree(path: str, prefix: str = "", is_last: bool = True, depth: int = 0) -> List[str]: