add deep_agent

parent eb17dff54a
commit 720db80ae9

15  .vscode/launch.json  vendored  Normal file
@@ -0,0 +1,15 @@
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "FastAPI: Debug",
            "type": "python",
            "request": "launch",
            "program": "fastapi_app.py",
            "console": "integratedTerminal",
            "cwd": "${workspaceFolder}",
            "envFile": "${workspaceFolder}/.env",
            "python": "${workspaceFolder}/.venv/bin/python"
        }
    ]
}
9  .vscode/settings.json  vendored  Normal file
@@ -0,0 +1,9 @@
{
    "python.languageServer": "Pylance",
    "python.analysis.indexing": true,
    "python.analysis.autoSearchPaths": true,
    "python.analysis.diagnosticMode": "workspace",
    "python.analysis.extraPaths": [
        "${workspaceFolder}/.venv"
    ]
}
27  CLAUDE.md
@@ -1,3 +1,28 @@
# Python environment
This project's Python environment is created with Poetry. To run a .py file, execute it as `poetry run python xxx.py`.
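For example, assuming dependencies have already been installed with `poetry install`, any script in the repo can be run through the Poetry-managed interpreter (the file name below is only a placeholder):

```
poetry run python your_script.py
```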

Startup script:
```
poetry run uvicorn fastapi_app:app --host 0.0.0.0 --port 8001
```
Test script:
```
curl --request POST \
  --url http://localhost:8001/api/v2/chat/completions \
  --header 'authorization: Bearer a21c99620a8ef61d69563afe05ccce89' \
  --header 'content-type: application/json' \
  --header 'x-trace-id: 123123123' \
  --data '{
    "messages": [
      {
        "role": "user",
        "content": "咖啡多少钱一杯"
      }
    ],
    "stream": true,
    "model": "whatever",
    "language": "ja",
    "bot_id": "63069654-7750-409d-9a58-a0960d899a20",
    "tool_response": true,
    "user_identifier": "及川"
  }'
```

@ -1,502 +0,0 @@
|
||||
# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import atexit
|
||||
import datetime
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
from qwen_agent.tools.base import BaseTool
|
||||
|
||||
|
||||
class CustomMCPManager:
|
||||
_instance = None # Private class variable to store the unique instance
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(CustomMCPManager, cls).__new__(cls, *args, **kwargs)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, 'clients'): # The singleton should only be inited once
|
||||
"""Set a new event loop in a separate thread"""
|
||||
try:
|
||||
import mcp # noqa
|
||||
except ImportError as e:
|
||||
raise ImportError('Could not import mcp. Please install mcp with `pip install -U mcp`.') from e
|
||||
|
||||
load_dotenv() # Load environment variables from .env file
|
||||
self.clients: dict = {}
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.loop_thread = threading.Thread(target=self.start_loop, daemon=True)
|
||||
self.loop_thread.start()
|
||||
|
||||
# A fallback way to terminate MCP tool processes after Qwen-Agent exits
|
||||
self.processes = []
|
||||
self.monkey_patch_mcp_create_platform_compatible_process()
|
||||
|
||||
def monkey_patch_mcp_create_platform_compatible_process(self):
|
||||
try:
|
||||
import mcp.client.stdio
|
||||
target = mcp.client.stdio._create_platform_compatible_process
|
||||
except (ModuleNotFoundError, AttributeError) as e:
|
||||
raise ImportError('Qwen-Agent needs to monkey patch MCP for process cleanup. '
|
||||
'Please upgrade MCP to a higher version with `pip install -U mcp`.') from e
|
||||
|
||||
async def _monkey_patched_create_platform_compatible_process(*args, **kwargs):
|
||||
process = await target(*args, **kwargs)
|
||||
self.processes.append(process)
|
||||
return process
|
||||
|
||||
mcp.client.stdio._create_platform_compatible_process = _monkey_patched_create_platform_compatible_process
|
||||
|
||||
def start_loop(self):
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
# Set a global exception handler to silently handle cross-task exceptions from MCP SSE connections
|
||||
def exception_handler(loop, context):
|
||||
exception = context.get('exception')
|
||||
if exception:
|
||||
# Silently handle cross-task exceptions from MCP SSE connections
|
||||
if (isinstance(exception, RuntimeError) and
|
||||
'Attempted to exit cancel scope in a different task' in str(exception)):
|
||||
return # Silently ignore this type of exception
|
||||
if (isinstance(exception, BaseExceptionGroup) and # noqa
|
||||
'Attempted to exit cancel scope in a different task' in str(exception)): # noqa
|
||||
return # Silently ignore this type of exception
|
||||
|
||||
# Other exceptions are handled normally
|
||||
loop.default_exception_handler(context)
|
||||
|
||||
self.loop.set_exception_handler(exception_handler)
|
||||
self.loop.run_forever()
|
||||
|
||||
def is_valid_mcp_servers(self, config: dict):
|
||||
"""Example of mcp servers configuration:
|
||||
{
|
||||
"mcpServers": {
|
||||
"memory": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-memory"]
|
||||
},
|
||||
"filesystem": {
|
||||
"command": "npx",
|
||||
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/allowed/files"]
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# Check if the top-level key "mcpServers" exists and its value is a dictionary
|
||||
if not isinstance(config, dict) or 'mcpServers' not in config or not isinstance(config['mcpServers'], dict):
|
||||
return False
|
||||
mcp_servers = config['mcpServers']
|
||||
# Check each sub-item under "mcpServers"
|
||||
for key in mcp_servers:
|
||||
server = mcp_servers[key]
|
||||
# Each sub-item must be a dictionary
|
||||
if not isinstance(server, dict):
|
||||
return False
|
||||
if 'command' in server:
|
||||
# "command" must be a string
|
||||
if not isinstance(server['command'], str):
|
||||
return False
|
||||
# "args" must be a list
|
||||
if 'args' not in server or not isinstance(server['args'], list):
|
||||
return False
|
||||
if 'url' in server:
|
||||
# "url" must be a string
|
||||
if not isinstance(server['url'], str):
|
||||
return False
|
||||
# "headers" must be a dictionary
|
||||
if 'headers' in server and not isinstance(server['headers'], dict):
|
||||
return False
|
||||
# If the "env" key exists, it must be a dictionary
|
||||
if 'env' in server and not isinstance(server['env'], dict):
|
||||
return False
|
||||
return True
|
||||
|
||||
def initConfig(self, config: Dict):
|
||||
if not self.is_valid_mcp_servers(config):
|
||||
raise ValueError('Config of mcpservers is not valid')
|
||||
logger.info(f'Initializing MCP tools from mcp servers: {list(config["mcpServers"].keys())}')
|
||||
# Submit coroutine to the event loop and wait for the result
|
||||
future = asyncio.run_coroutine_threadsafe(self.init_config_async(config), self.loop)
|
||||
try:
|
||||
result = future.result() # You can specify a timeout if desired
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.info(f'Failed in initializing MCP tools: {e}')
|
||||
raise e
|
||||
|
||||
async def init_config_async(self, config: Dict):
|
||||
tools: list = []
|
||||
mcp_servers = config['mcpServers']
|
||||
|
||||
# 并发连接所有MCP服务器
|
||||
connection_tasks = []
|
||||
for server_name in mcp_servers:
|
||||
client = CustomMCPClient()
|
||||
server = mcp_servers[server_name]
|
||||
# 创建连接任务
|
||||
task = self._connect_and_store_client(client, server_name, server)
|
||||
connection_tasks.append(task)
|
||||
|
||||
# 并发执行所有连接任务
|
||||
connected_clients = await asyncio.gather(*connection_tasks, return_exceptions=True)
|
||||
|
||||
# 处理连接结果并为每个客户端创建工具
|
||||
for result in connected_clients:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f'Failed to connect MCP server: {result}')
|
||||
continue
|
||||
|
||||
client, server_name = result
|
||||
client_id = client.client_id
|
||||
for tool in client.tools:
|
||||
"""MCP tool example:
|
||||
{
|
||||
"name": "read_query",
|
||||
"description": "Execute a SELECT query on the SQLite database",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "SELECT SQL query to execute"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
"""
|
||||
parameters = tool.inputSchema
|
||||
# The required field in inputSchema may be empty and needs to be initialized.
|
||||
if 'required' not in parameters:
|
||||
parameters['required'] = []
|
||||
# Remove keys from parameters that do not conform to the standard OpenAI schema
|
||||
# Check if the required fields exist
|
||||
required_fields = {'type', 'properties', 'required'}
|
||||
missing_fields = required_fields - parameters.keys()
|
||||
if missing_fields:
|
||||
raise ValueError(f'Missing required fields in schema: {missing_fields}')
|
||||
|
||||
# Keep only the necessary fields
|
||||
cleaned_parameters = {
|
||||
'type': parameters['type'],
|
||||
'properties': parameters['properties'],
|
||||
'required': parameters['required']
|
||||
}
|
||||
register_name = server_name + '-' + tool.name
|
||||
agent_tool = self.create_tool_class(register_name=register_name,
|
||||
register_client_id=client_id,
|
||||
tool_name=tool.name,
|
||||
tool_desc=tool.description,
|
||||
tool_parameters=cleaned_parameters)
|
||||
tools.append(agent_tool)
|
||||
|
||||
if client.resources:
|
||||
"""MCP resource example:
|
||||
{
|
||||
uri: string; // Unique identifier for the resource
|
||||
name: string; // Human-readable name
|
||||
description?: string; // Optional description
|
||||
mimeType?: string; // Optional MIME type
|
||||
}
|
||||
"""
|
||||
# List resources
|
||||
list_resources_tool_name = server_name + '-' + 'list_resources'
|
||||
list_resources_params = {'type': 'object', 'properties': {}, 'required': []}
|
||||
list_resources_agent_tool = self.create_tool_class(
|
||||
register_name=list_resources_tool_name,
|
||||
register_client_id=client_id,
|
||||
tool_name='list_resources',
|
||||
tool_desc='Servers expose a list of concrete resources through this tool. '
|
||||
'By invoking it, you can discover the available resources and obtain resource templates, which help clients understand how to construct valid URIs. '
|
||||
'These URI formats will be used as input parameters for the read_resource function. ',
|
||||
tool_parameters=list_resources_params)
|
||||
tools.append(list_resources_agent_tool)
|
||||
|
||||
# Read resource
|
||||
resources_template_str = '' # Check if there are resource templates
|
||||
try:
|
||||
list_resource_templates = await client.session.list_resource_templates(
|
||||
) # Check if the server has resources tesmplate
|
||||
if list_resource_templates.resourceTemplates:
|
||||
resources_template_str = '\n'.join(
|
||||
str(template) for template in list_resource_templates.resourceTemplates)
|
||||
|
||||
except Exception as e:
|
||||
logger.info(f'Failed in listing MCP resource templates: {e}')
|
||||
|
||||
read_resource_tool_name = server_name + '-' + 'read_resource'
|
||||
read_resource_params = {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'uri': {
|
||||
'type': 'string',
|
||||
'description': 'The URI identifying the specific resource to access'
|
||||
}
|
||||
},
|
||||
'required': ['uri']
|
||||
}
|
||||
original_tool_desc = 'Request to access a resource provided by a connected MCP server. Resources represent data sources that can be used as context, such as files, API responses, or system information.'
|
||||
if resources_template_str:
|
||||
tool_desc = original_tool_desc + '\nResource Templates:\n' + resources_template_str
|
||||
else:
|
||||
tool_desc = original_tool_desc
|
||||
read_resource_agent_tool = self.create_tool_class(register_name=read_resource_tool_name,
|
||||
register_client_id=client_id,
|
||||
tool_name='read_resource',
|
||||
tool_desc=tool_desc,
|
||||
tool_parameters=read_resource_params)
|
||||
tools.append(read_resource_agent_tool)
|
||||
|
||||
return tools
|
||||
|
||||
async def _connect_and_store_client(self, client, server_name, server):
|
||||
"""辅助方法:连接MCP服务器并存储客户端"""
|
||||
try:
|
||||
await client.connection_server(mcp_server_name=server_name,
|
||||
mcp_server=server) # Attempt to connect to the server
|
||||
|
||||
client_id = server_name + '_' + str(
|
||||
uuid.uuid4()) # To allow the same server name be used across different running agents
|
||||
client.client_id = client_id # Ensure client_id is set on the client instance
|
||||
self.clients[client_id] = client # Add to clients dict after successful connection
|
||||
return client, server_name
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to connect MCP server {server_name}: {e}')
|
||||
raise e
|
||||
|
||||
def create_tool_class(self, register_name, register_client_id, tool_name, tool_desc, tool_parameters):
|
||||
|
||||
class ToolClass(BaseTool):
|
||||
name = register_name
|
||||
description = tool_desc
|
||||
parameters = tool_parameters
|
||||
client_id = register_client_id
|
||||
|
||||
def call(self, params: Union[str, dict], **kwargs) -> str:
|
||||
tool_args = json.loads(params)
|
||||
# Submit coroutine to the event loop and wait for the result
|
||||
manager = CustomMCPManager()
|
||||
client = manager.clients[self.client_id]
|
||||
future = asyncio.run_coroutine_threadsafe(client.execute_function(tool_name, tool_args), manager.loop)
|
||||
try:
|
||||
result = future.result()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.info(f'Failed in executing MCP tool: {e}')
|
||||
raise e
|
||||
|
||||
ToolClass.__name__ = f'{register_name}_Class'
|
||||
return ToolClass()
|
||||
|
||||
def shutdown(self):
|
||||
futures = []
|
||||
for client_id in list(self.clients.keys()):
|
||||
client: CustomMCPClient = self.clients[client_id]
|
||||
future = asyncio.run_coroutine_threadsafe(client.cleanup(), self.loop)
|
||||
futures.append(future)
|
||||
del self.clients[client_id]
|
||||
time.sleep(1) # Wait for the graceful cleanups, otherwise fall back
|
||||
|
||||
# fallback
|
||||
if asyncio.all_tasks(self.loop):
|
||||
logger.info(
|
||||
'There are still tasks in `CustomMCPManager().loop`, force terminating the MCP tool processes. There may be some exceptions.'
|
||||
)
|
||||
for process in self.processes:
|
||||
try:
|
||||
process.terminate()
|
||||
except ProcessLookupError:
|
||||
pass # it's ok, the process may exit earlier
|
||||
|
||||
self.loop.call_soon_threadsafe(self.loop.stop)
|
||||
self.loop_thread.join()
|
||||
|
||||
|
||||
class CustomMCPClient:
|
||||
|
||||
def __init__(self):
|
||||
from mcp import ClientSession
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.tools: list = None
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.resources: bool = False
|
||||
self._last_mcp_server_name = None
|
||||
self._last_mcp_server = None
|
||||
self.client_id = None # For replacing in MCPManager.clients
|
||||
|
||||
async def connection_server(self, mcp_server_name, mcp_server):
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
"""Connect to an MCP server and retrieve the available tools."""
|
||||
# Save parameters
|
||||
self._last_mcp_server_name = mcp_server_name
|
||||
self._last_mcp_server = mcp_server
|
||||
|
||||
try:
|
||||
if 'url' in mcp_server:
|
||||
url = mcp_server.get('url')
|
||||
sse_read_timeout = mcp_server.get('sse_read_timeout', 300)
|
||||
logger.info(f'{mcp_server_name} sse_read_timeout: {sse_read_timeout}s')
|
||||
if mcp_server.get('type', 'sse') == 'streamable-http':
|
||||
# streamable-http mode
|
||||
"""streamable-http mode mcp example:
|
||||
{"mcpServers": {
|
||||
"streamable-mcp-server": {
|
||||
"type": "streamable-http",
|
||||
"url":"http://0.0.0.0:8000/mcp"
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
headers = mcp_server.get('headers', {})
|
||||
self._streams_context = streamablehttp_client(
|
||||
url=url, headers=headers, sse_read_timeout=datetime.timedelta(seconds=sse_read_timeout))
|
||||
read_stream, write_stream, get_session_id = await self.exit_stack.enter_async_context(
|
||||
self._streams_context)
|
||||
self._session_context = ClientSession(read_stream, write_stream)
|
||||
self.session = await self.exit_stack.enter_async_context(self._session_context)
|
||||
else:
|
||||
# sse mode
|
||||
headers = mcp_server.get('headers', {'Accept': 'text/event-stream'})
|
||||
self._streams_context = sse_client(url, headers, sse_read_timeout=sse_read_timeout)
|
||||
streams = await self.exit_stack.enter_async_context(self._streams_context)
|
||||
self._session_context = ClientSession(*streams)
|
||||
self.session = await self.exit_stack.enter_async_context(self._session_context)
|
||||
else:
|
||||
server_params = StdioServerParameters(command=mcp_server['command'],
|
||||
args=mcp_server['args'],
|
||||
env=mcp_server.get('env', None))
|
||||
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
|
||||
logger.info(
|
||||
f'Initializing a MCP stdio_client, if this takes forever, please check the config of this mcp server: {mcp_server_name}'
|
||||
)
|
||||
|
||||
await self.session.initialize()
|
||||
list_tools = await self.session.list_tools()
|
||||
self.tools = list_tools.tools
|
||||
try:
|
||||
list_resources = await self.session.list_resources() # Check if the server has resources
|
||||
if list_resources.resources:
|
||||
self.resources = True
|
||||
except Exception:
|
||||
# logger.info(f"No list resources: {e}")
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f'Failed in connecting to MCP server: {e}')
|
||||
raise e
|
||||
|
||||
async def reconnect(self):
|
||||
# Create a new MCPClient and connect
|
||||
if self.client_id is None:
|
||||
raise RuntimeError(
|
||||
'Cannot reconnect: client_id is None. This usually means the client was not properly registered in MCPManager.'
|
||||
)
|
||||
new_client = CustomMCPClient()
|
||||
new_client.client_id = self.client_id
|
||||
await new_client.connection_server(self._last_mcp_server_name, self._last_mcp_server)
|
||||
return new_client
|
||||
|
||||
async def execute_function(self, tool_name, tool_args: dict):
|
||||
from mcp.types import TextResourceContents
|
||||
|
||||
# Check if session is alive
|
||||
try:
|
||||
await self.session.send_ping()
|
||||
except Exception as e:
|
||||
logger.info(f"Session is not alive, please increase 'sse_read_timeout' in the config, try reconnect: {e}")
|
||||
# Auto reconnect
|
||||
try:
|
||||
manager = CustomMCPManager()
|
||||
if self.client_id is not None:
|
||||
manager.clients[self.client_id] = await self.reconnect()
|
||||
return await manager.clients[self.client_id].execute_function(tool_name, tool_args)
|
||||
else:
|
||||
logger.info('Reconnect failed: client_id is None')
|
||||
return 'Session reconnect (client creation) exception: client_id is None'
|
||||
except Exception as e3:
|
||||
logger.info(f'Reconnect (client creation) exception type: {type(e3)}, value: {repr(e3)}')
|
||||
return f'Session reconnect (client creation) exception: {e3}'
|
||||
if tool_name == 'list_resources':
|
||||
try:
|
||||
list_resources = await self.session.list_resources()
|
||||
if list_resources.resources:
|
||||
resources_str = '\n\n'.join(str(resource) for resource in list_resources.resources)
|
||||
else:
|
||||
resources_str = 'No resources found'
|
||||
return resources_str
|
||||
except Exception as e:
|
||||
logger.info(f'No list resources: {e}')
|
||||
return f'Error: {e}'
|
||||
elif tool_name == 'read_resource':
|
||||
try:
|
||||
uri = tool_args.get('uri')
|
||||
if not uri:
|
||||
raise ValueError('URI is required for read_resource')
|
||||
read_resource = await self.session.read_resource(uri)
|
||||
texts = []
|
||||
for resource in read_resource.contents:
|
||||
if isinstance(resource, TextResourceContents):
|
||||
texts.append(resource.text)
|
||||
# if isinstance(resource, BlobResourceContents):
|
||||
# texts.append(resource.blob)
|
||||
if texts:
|
||||
return '\n\n'.join(texts)
|
||||
else:
|
||||
return 'Failed to read resource'
|
||||
except Exception as e:
|
||||
logger.info(f'Failed to read resource: {e}')
|
||||
return f'Error: {e}'
|
||||
else:
|
||||
response = await self.session.call_tool(tool_name, tool_args)
|
||||
texts = []
|
||||
for content in response.content:
|
||||
if content.type == 'text':
|
||||
texts.append(content.text)
|
||||
if texts:
|
||||
return '\n\n'.join(texts)
|
||||
else:
|
||||
return 'execute error'
|
||||
|
||||
async def cleanup(self):
|
||||
await self.exit_stack.aclose()
|
||||
|
||||
|
||||
def _cleanup_mcp(_sig_num=None, _frame=None):
|
||||
if CustomMCPManager._instance is None:
|
||||
return
|
||||
manager = CustomMCPManager()
|
||||
manager.shutdown()
|
||||
|
||||
|
||||
# Make sure all subprocesses are terminated even if killed abnormally
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
atexit.register(_cleanup_mcp)
|
||||
63  agent/deep_assistant.py  Normal file
@@ -0,0 +1,63 @@
import json
from langchain.chat_models import init_chat_model
from deepagents import create_deep_agent
from langchain.agents import create_agent
from langchain_mcp_adapters.client import MultiServerMCPClient
from utils.fastapi_utils import detect_provider


# Utility functions
def read_system_prompt():
    """Read the generic stateless system prompt."""
    with open("./prompt/system_prompt_default.md", "r", encoding="utf-8") as f:
        return f.read().strip()


def read_mcp_settings():
    """Read the MCP tool configuration."""
    with open("./mcp/mcp_settings.json", "r") as f:
        mcp_settings_json = json.load(f)
    return mcp_settings_json


async def init_agent(model_name="qwen3-next", api_key=None,
                     model_server=None, generate_cfg=None,
                     system_prompt=None, mcp=None):
    system = system_prompt if system_prompt else read_system_prompt()
    mcp = mcp if mcp else read_mcp_settings()
    # Rewrite the mcp[0]["mcpServers"] entries: rename the "type" field to "transport";
    # if neither is present, default to transport: stdio
    if mcp and len(mcp) > 0 and "mcpServers" in mcp[0]:
        for server_name, server_config in mcp[0]["mcpServers"].items():
            if isinstance(server_config, dict):
                if "type" in server_config and "transport" not in server_config:
                    # There is a "type" field but no "transport" field: rename "type" to "transport"
                    type_value = server_config.pop("type")
                    # Special case: 'streamable-http' becomes 'http'
                    if type_value == "streamable-http":
                        server_config["transport"] = "http"
                    else:
                        server_config["transport"] = type_value
                elif "transport" not in server_config:
                    # Neither "type" nor "transport" is present: add the default transport: stdio
                    server_config["transport"] = "stdio"

    mcp_client = MultiServerMCPClient(mcp[0]["mcpServers"])
    mcp_tools = await mcp_client.get_tools()

    # Detect the provider, or use the one that was specified
    model_provider, base_url = detect_provider(model_name, model_server)

    # Build the model parameters
    model_kwargs = {
        "model": model_name,
        "model_provider": model_provider,
        "temperature": 0.8,
        "base_url": base_url,
        "api_key": api_key
    }
    llm_instance = init_chat_model(**model_kwargs)

    agent = create_agent(
        model=llm_instance,
        system_prompt=system,
        tools=mcp_tools
    )
    return agent
@ -1,285 +0,0 @@
|
||||
# Copyright 2023
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""文件预加载助手管理器 - 管理基于unique_id的助手实例缓存"""
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from qwen_agent.agents import Assistant
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
from agent.modified_assistant import init_modified_agent_service_with_files, update_agent_llm
|
||||
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
||||
|
||||
|
||||
class FileLoadedAgentManager:
|
||||
"""文件预加载助手管理器
|
||||
|
||||
基于 unique_id 缓存助手实例,避免重复创建和文件解析
|
||||
"""
|
||||
|
||||
def __init__(self, max_cached_agents: int = 20):
|
||||
self.agents: Dict[str, Assistant] = {} # {cache_key: assistant_instance}
|
||||
self.unique_ids: Dict[str, str] = {} # {cache_key: unique_id}
|
||||
self.access_times: Dict[str, float] = {} # LRU 访问时间管理
|
||||
self.creation_times: Dict[str, float] = {} # 创建时间记录
|
||||
self.max_cached_agents = max_cached_agents
|
||||
self._creation_locks: Dict[str, asyncio.Lock] = {} # 防止并发创建相同agent的锁
|
||||
|
||||
def _get_cache_key(self, bot_id: str, model_name: str = None, api_key: str = None,
|
||||
model_server: str = None, generate_cfg: Dict = None,
|
||||
system_prompt: str = None, mcp_settings: List[Dict] = None) -> str:
|
||||
"""获取包含所有相关参数的哈希值作为缓存键
|
||||
|
||||
Args:
|
||||
bot_id: 机器人项目ID
|
||||
model_name: 模型名称
|
||||
api_key: API密钥
|
||||
model_server: 模型服务器地址
|
||||
generate_cfg: 生成配置
|
||||
system_prompt: 系统提示词
|
||||
mcp_settings: MCP设置列表
|
||||
|
||||
Returns:
|
||||
str: 缓存键的哈希值
|
||||
"""
|
||||
# 构建包含所有相关参数的字符串
|
||||
cache_data = {
|
||||
'bot_id': bot_id,
|
||||
'model_name': model_name or '',
|
||||
'api_key': api_key or '',
|
||||
'model_server': model_server or '',
|
||||
'generate_cfg': json.dumps(generate_cfg or {}, sort_keys=True),
|
||||
'system_prompt': system_prompt or '',
|
||||
'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True)
|
||||
}
|
||||
|
||||
# 将字典转换为JSON字符串并计算哈希值
|
||||
cache_str = json.dumps(cache_data, sort_keys=True)
|
||||
return hashlib.md5(cache_str.encode('utf-8')).hexdigest()[:16]
|
||||
|
||||
def _update_access_time(self, cache_key: str):
|
||||
"""更新访问时间(LRU 管理)"""
|
||||
self.access_times[cache_key] = time.time()
|
||||
|
||||
def _cleanup_old_agents(self):
|
||||
"""清理旧的助手实例,基于 LRU 策略"""
|
||||
if len(self.agents) <= self.max_cached_agents:
|
||||
return
|
||||
|
||||
# 按 LRU 顺序排序,删除最久未访问的实例
|
||||
sorted_keys = sorted(self.access_times.keys(), key=lambda k: self.access_times[k])
|
||||
|
||||
keys_to_remove = sorted_keys[:-self.max_cached_agents]
|
||||
removed_count = 0
|
||||
|
||||
for cache_key in keys_to_remove:
|
||||
try:
|
||||
del self.agents[cache_key]
|
||||
del self.unique_ids[cache_key]
|
||||
del self.access_times[cache_key]
|
||||
del self.creation_times[cache_key]
|
||||
removed_count += 1
|
||||
logger.info(f"清理过期的助手实例缓存: {cache_key}")
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"已清理 {removed_count} 个过期的助手实例缓存")
|
||||
|
||||
async def get_or_create_agent(self,
|
||||
bot_id: str,
|
||||
project_dir: Optional[str],
|
||||
model_name: str = "qwen3-next",
|
||||
api_key: Optional[str] = None,
|
||||
model_server: Optional[str] = None,
|
||||
generate_cfg: Optional[Dict] = None,
|
||||
language: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
mcp_settings: Optional[List[Dict]] = None,
|
||||
robot_type: Optional[str] = "general_agent",
|
||||
user_identifier: Optional[str] = None) -> Assistant:
|
||||
"""获取或创建文件预加载的助手实例
|
||||
|
||||
Args:
|
||||
bot_id: 项目的唯一标识符
|
||||
project_dir: 项目目录路径,用于读取README.md,可以为None
|
||||
model_name: 模型名称
|
||||
api_key: API 密钥
|
||||
model_server: 模型服务器地址
|
||||
generate_cfg: 生成配置
|
||||
language: 语言代码,用于选择对应的系统提示词
|
||||
system_prompt: 可选的系统提示词,优先级高于项目配置
|
||||
mcp_settings: 可选的MCP设置,优先级高于项目配置
|
||||
robot_type: 机器人类型,取值 agent/catalog_agent
|
||||
user_identifier: 用户标识符
|
||||
|
||||
Returns:
|
||||
Assistant: 配置好的助手实例
|
||||
"""
|
||||
import os
|
||||
|
||||
# 使用异步加载配置文件(带缓存)
|
||||
final_system_prompt = await load_system_prompt_async(
|
||||
project_dir, language, system_prompt, robot_type, bot_id, user_identifier
|
||||
)
|
||||
final_mcp_settings = await load_mcp_settings_async(
|
||||
project_dir, mcp_settings, bot_id, robot_type
|
||||
)
|
||||
|
||||
cache_key = self._get_cache_key(bot_id, model_name, api_key, model_server,
|
||||
generate_cfg, final_system_prompt, final_mcp_settings)
|
||||
|
||||
# 使用异步锁防止并发创建相同的agent
|
||||
creation_lock = self._creation_locks.setdefault(cache_key, asyncio.Lock())
|
||||
async with creation_lock:
|
||||
# 再次检查是否已存在该助手实例(获取锁后可能有其他请求已创建)
|
||||
if cache_key in self.agents:
|
||||
self._update_access_time(cache_key)
|
||||
agent = self.agents[cache_key]
|
||||
|
||||
# 动态更新 LLM 配置和系统设置(如果参数有变化)
|
||||
update_agent_llm(agent, model_name, api_key, model_server, generate_cfg)
|
||||
|
||||
logger.info(f"复用现有的助手实例缓存: {cache_key} (bot_id: {bot_id})")
|
||||
return agent
|
||||
|
||||
# 清理过期实例
|
||||
self._cleanup_old_agents()
|
||||
|
||||
# 创建新的助手实例,预加载文件
|
||||
logger.info(f"创建新的助手实例缓存: {cache_key}, bot_id: {bot_id}")
|
||||
current_time = time.time()
|
||||
|
||||
agent = init_modified_agent_service_with_files(
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
model_server=model_server,
|
||||
generate_cfg=generate_cfg,
|
||||
system_prompt=final_system_prompt,
|
||||
mcp=final_mcp_settings
|
||||
)
|
||||
|
||||
# 缓存实例
|
||||
self.agents[cache_key] = agent
|
||||
self.unique_ids[cache_key] = bot_id
|
||||
self.access_times[cache_key] = current_time
|
||||
self.creation_times[cache_key] = current_time
|
||||
|
||||
# 清理创建锁
|
||||
self._creation_locks.pop(cache_key, None)
|
||||
|
||||
logger.info(f"助手实例缓存创建完成: {cache_key}")
|
||||
return agent
|
||||
|
||||
def get_cache_stats(self) -> Dict:
|
||||
"""获取缓存统计信息"""
|
||||
current_time = time.time()
|
||||
stats = {
|
||||
"total_cached_agents": len(self.agents),
|
||||
"max_cached_agents": self.max_cached_agents,
|
||||
"agents": {}
|
||||
}
|
||||
|
||||
for cache_key, agent in self.agents.items():
|
||||
stats["agents"][cache_key] = {
|
||||
"unique_id": self.unique_ids.get(cache_key, "unknown"),
|
||||
"created_at": self.creation_times.get(cache_key, 0),
|
||||
"last_accessed": self.access_times.get(cache_key, 0),
|
||||
"age_seconds": int(current_time - self.creation_times.get(cache_key, current_time)),
|
||||
"idle_seconds": int(current_time - self.access_times.get(cache_key, current_time))
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def clear_cache(self) -> int:
|
||||
"""清空所有缓存
|
||||
|
||||
Returns:
|
||||
int: 清理的实例数量
|
||||
"""
|
||||
cache_count = len(self.agents)
|
||||
|
||||
self.agents.clear()
|
||||
self.unique_ids.clear()
|
||||
self.access_times.clear()
|
||||
self.creation_times.clear()
|
||||
|
||||
logger.info(f"已清空所有助手实例缓存,共清理 {cache_count} 个实例")
|
||||
return cache_count
|
||||
|
||||
def remove_cache_by_unique_id(self, unique_id: str) -> int:
|
||||
"""根据 unique_id 移除所有相关的缓存
|
||||
|
||||
由于缓存key现在包含 system_prompt 和 mcp_settings,
|
||||
一个 unique_id 可能对应多个缓存实例。
|
||||
|
||||
Args:
|
||||
unique_id: 项目的唯一标识符
|
||||
|
||||
Returns:
|
||||
int: 成功移除的实例数量
|
||||
"""
|
||||
keys_to_remove = []
|
||||
|
||||
# 找到所有匹配的 unique_id 的缓存键
|
||||
for cache_key, stored_unique_id in self.unique_ids.items():
|
||||
if stored_unique_id == unique_id:
|
||||
keys_to_remove.append(cache_key)
|
||||
|
||||
# 移除找到的缓存
|
||||
removed_count = 0
|
||||
for cache_key in keys_to_remove:
|
||||
try:
|
||||
del self.agents[cache_key]
|
||||
del self.unique_ids[cache_key]
|
||||
del self.access_times[cache_key]
|
||||
del self.creation_times[cache_key]
|
||||
self._creation_locks.pop(cache_key, None) # 清理创建锁
|
||||
removed_count += 1
|
||||
logger.info(f"已移除助手实例缓存: {cache_key} (unique_id: {unique_id})")
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"已移除 unique_id={unique_id} 的 {removed_count} 个助手实例缓存")
|
||||
else:
|
||||
logger.warning(f"未找到 unique_id={unique_id} 的缓存实例")
|
||||
|
||||
return removed_count
|
||||
|
||||
|
||||
# 全局文件预加载助手管理器实例
|
||||
_global_agent_manager: Optional[FileLoadedAgentManager] = None
|
||||
|
||||
|
||||
def get_global_agent_manager() -> FileLoadedAgentManager:
|
||||
"""获取全局文件预加载助手管理器实例"""
|
||||
global _global_agent_manager
|
||||
if _global_agent_manager is None:
|
||||
_global_agent_manager = FileLoadedAgentManager()
|
||||
return _global_agent_manager
|
||||
|
||||
|
||||
def init_global_agent_manager(max_cached_agents: int = 20):
|
||||
"""初始化全局文件预加载助手管理器"""
|
||||
global _global_agent_manager
|
||||
_global_agent_manager = FileLoadedAgentManager(max_cached_agents)
|
||||
return _global_agent_manager
|
||||
@ -1,287 +0,0 @@
|
||||
# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Iterator, List, Literal, Optional, Union
|
||||
|
||||
from qwen_agent.agents import Assistant
|
||||
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
|
||||
from qwen_agent.llm.oai import TextChatAtOAI
|
||||
from qwen_agent.tools import BaseTool
|
||||
from agent.custom_mcp_manager import CustomMCPManager
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
# 设置工具日志记录器
|
||||
tool_logger = logging.getLogger('app')
|
||||
|
||||
class ModifiedAssistant(Assistant):
|
||||
"""
|
||||
修改后的 Assistant 子类,改变循环判断逻辑:
|
||||
- 原始逻辑:如果没有使用工具,立即退出循环
|
||||
- 修改后逻辑:如果没有使用工具,调用模型判断回答是否完整,如果不完整则继续循环
|
||||
"""
|
||||
|
||||
def _is_retryable_error(self, error: Exception) -> bool:
|
||||
"""判断错误是否可重试
|
||||
|
||||
Args:
|
||||
error: 异常对象
|
||||
|
||||
Returns:
|
||||
bool: 是否可重试
|
||||
"""
|
||||
error_str = str(error).lower()
|
||||
retryable_indicators = [
|
||||
'502', '500', '503', '504', # HTTP错误代码
|
||||
'internal server error', # 内部服务器错误
|
||||
'timeout', # 超时
|
||||
'connection', # 连接错误
|
||||
'network', # 网络错误
|
||||
'rate', # 速率限制和相关错误
|
||||
'quota', # 配额限制
|
||||
'service unavailable', # 服务不可用
|
||||
'provider returned error', # Provider错误
|
||||
'model service error', # 模型服务错误
|
||||
'temporary', # 临时错误
|
||||
'retry' # 明确提示重试
|
||||
]
|
||||
return any(indicator in error_str for indicator in retryable_indicators)
|
||||
|
||||
def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwargs) -> str:
|
||||
"""重写工具调用方法,添加调试信息"""
|
||||
if tool_name not in self.function_map:
|
||||
error_msg = f'Tool {tool_name} does not exist. Available tools: {list(self.function_map.keys())}'
|
||||
tool_logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
tool = self.function_map[tool_name]
|
||||
|
||||
try:
|
||||
tool_logger.info(f"开始调用工具: {tool_name} {tool_args}")
|
||||
start_time = time.time()
|
||||
|
||||
# 调用父类的_call_tool方法
|
||||
tool_result = super()._call_tool(tool_name, tool_args, **kwargs)
|
||||
|
||||
end_time = time.time()
|
||||
tool_logger.info(f"工具 {tool_name} 执行完成,耗时: {end_time - start_time:.2f}秒 结果长度: {len(tool_result) if tool_result else 0}")
|
||||
|
||||
# 打印部分结果内容(避免过长)
|
||||
if tool_result and len(tool_result) > 0:
|
||||
preview = tool_result[:200] if len(tool_result) > 200 else tool_result
|
||||
tool_logger.debug(f"工具 {tool_name} 结果预览: {preview}...")
|
||||
|
||||
return tool_result
|
||||
|
||||
except Exception as ex:
|
||||
end_time = time.time()
|
||||
tool_logger.error(f"工具调用异常,耗时: {end_time - start_time:.2f}秒 异常类型: {type(ex).__name__} {str(ex)}")
|
||||
|
||||
# 打印完整的堆栈跟踪
|
||||
import traceback
|
||||
tool_logger.error(f"堆栈跟踪:\n{traceback.format_exc()}")
|
||||
|
||||
# 返回详细的错误信息
|
||||
error_message = f'An error occurred when calling tool {tool_name}: {type(ex).__name__}: {str(ex)}'
|
||||
return error_message
|
||||
|
||||
def _init_tool(self, tool: Union[str, Dict, BaseTool]):
|
||||
"""重写工具初始化方法,使用CustomMCPManager处理MCP服务器配置"""
|
||||
if isinstance(tool, BaseTool):
|
||||
# 处理BaseTool实例
|
||||
tool_name = tool.name
|
||||
if tool_name in self.function_map:
|
||||
tool_logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
|
||||
self.function_map[tool_name] = tool
|
||||
elif isinstance(tool, dict) and 'mcpServers' in tool:
|
||||
# 使用CustomMCPManager处理MCP服务器配置,支持headers
|
||||
tools = CustomMCPManager().initConfig(tool)
|
||||
for tool in tools:
|
||||
tool_name = tool.name
|
||||
if tool_name in self.function_map:
|
||||
tool_logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
|
||||
self.function_map[tool_name] = tool
|
||||
else:
|
||||
# 调用父类的处理方法
|
||||
super()._init_tool(tool)
|
||||
|
||||
def _call_llm_with_retry(self, messages: List[Message], functions=None, extra_generate_cfg=None, max_retries: int = 5) -> Iterator:
|
||||
"""带重试机制的LLM调用
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
functions: 函数列表
|
||||
extra_generate_cfg: 额外生成配置
|
||||
max_retries: 最大重试次数
|
||||
|
||||
Returns:
|
||||
LLM响应流
|
||||
|
||||
Raises:
|
||||
Exception: 重试次数耗尽后重新抛出原始异常
|
||||
"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return self._call_llm(messages=messages, functions=functions, extra_generate_cfg=extra_generate_cfg)
|
||||
except Exception as e:
|
||||
# 检查是否为可重试的错误
|
||||
if self._is_retryable_error(e) and attempt < max_retries - 1:
|
||||
delay = 2 ** attempt # 指数退避: 1s, 2s, 4s
|
||||
tool_logger.warning(f"LLM调用失败 (尝试 {attempt + 1}/{max_retries}),{delay}秒后重试: {str(e)}")
|
||||
time.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
# 不可重试的错误或已达到最大重试次数
|
||||
if attempt > 0:
|
||||
tool_logger.error(f"LLM调用重试失败,已达到最大重试次数 {max_retries}")
|
||||
raise
|
||||
|
||||
def _run(self, messages: List[Message], lang: Literal['en', 'zh', 'ja'] = 'en', **kwargs) -> Iterator[List[Message]]:
|
||||
|
||||
message_list = copy.deepcopy(messages)
|
||||
response = []
|
||||
|
||||
# 保持原有的最大调用次数限制
|
||||
total_num_llm_calls_available = self.MAX_LLM_CALL_PER_RUN if hasattr(self, 'MAX_LLM_CALL_PER_RUN') else 100
|
||||
num_llm_calls_available = total_num_llm_calls_available
|
||||
while num_llm_calls_available > 0:
|
||||
num_llm_calls_available -= 1
|
||||
extra_generate_cfg = {'lang': lang}
|
||||
if kwargs.get('seed') is not None:
|
||||
extra_generate_cfg['seed'] = kwargs['seed']
|
||||
|
||||
output_stream = self._call_llm_with_retry(messages=message_list,
|
||||
functions=[func.function for func in self.function_map.values()],
|
||||
extra_generate_cfg=extra_generate_cfg)
|
||||
output: List[Message] = []
|
||||
for output in output_stream:
|
||||
if output:
|
||||
yield response + output
|
||||
|
||||
if output:
|
||||
response.extend(output)
|
||||
message_list.extend(output)
|
||||
|
||||
# 处理工具调用
|
||||
used_any_tool = False
|
||||
for out in output:
|
||||
use_tool, tool_name, tool_args, _ = self._detect_tool(out)
|
||||
if use_tool:
|
||||
tool_result = self._call_tool(tool_name, tool_args, messages=message_list, **kwargs)
|
||||
|
||||
# 验证工具结果
|
||||
if not tool_result:
|
||||
tool_logger.warning(f"工具 {tool_name} 返回空结果")
|
||||
tool_result = f"Tool {tool_name} completed execution but returned empty result"
|
||||
elif tool_result.startswith('An error occurred when calling tool') or tool_result.startswith('工具调用失败'):
|
||||
tool_logger.error(f"工具 {tool_name} 调用失败: {tool_result}")
|
||||
|
||||
fn_msg = Message(role=FUNCTION,
|
||||
name=tool_name,
|
||||
content=tool_result,
|
||||
extra={'function_id': out.extra.get('function_id', '1')})
|
||||
message_list.append(fn_msg)
|
||||
response.append(fn_msg)
|
||||
yield response
|
||||
used_any_tool = True
|
||||
|
||||
# 如果使用了工具,继续循环
|
||||
if not used_any_tool:
|
||||
break
|
||||
|
||||
# 检查是否因为调用次数用完而退出循环
|
||||
if num_llm_calls_available == 0:
|
||||
# 根据语言选择错误消息
|
||||
if lang == 'zh':
|
||||
error_message = "工具调用超出限制"
|
||||
elif lang == 'ja':
|
||||
error_message = "ツール呼び出しが制限を超えました。"
|
||||
else:
|
||||
error_message = "Tool calls exceeded limit"
|
||||
tool_logger.error(error_message)
|
||||
|
||||
error_msg = Message(
|
||||
role=ASSISTANT,
|
||||
content=error_message,
|
||||
)
|
||||
response.append(error_msg)
|
||||
|
||||
yield response
|
||||
|
||||
|
||||
# Utility functions
|
||||
def read_system_prompt():
|
||||
"""读取通用的无状态系统prompt"""
|
||||
with open("./prompt/system_prompt_default.md", "r", encoding="utf-8") as f:
|
||||
return f.read().strip()
|
||||
|
||||
|
||||
def read_mcp_settings():
|
||||
"""读取MCP工具配置"""
|
||||
with open("./mcp/mcp_settings.json", "r") as f:
|
||||
mcp_settings_json = json.load(f)
|
||||
return mcp_settings_json
|
||||
|
||||
|
||||
def update_agent_llm(agent, model_name: str, api_key: str = None, model_server: str = None, generate_cfg: Dict = None):
|
||||
"""动态更新助手实例的LLM和配置,支持从接口传入参数"""
|
||||
# 获取基础配置
|
||||
llm_config = {
|
||||
"model": model_name,
|
||||
"api_key": api_key,
|
||||
"model_server": model_server,
|
||||
"generate_cfg": generate_cfg if generate_cfg else {}
|
||||
}
|
||||
|
||||
# 创建LLM实例
|
||||
llm_instance = TextChatAtOAI(llm_config)
|
||||
|
||||
# 动态设置LLM
|
||||
agent.llm = llm_instance
|
||||
return agent
|
||||
|
||||
|
||||
# 向后兼容:保持原有的初始化函数接口
|
||||
def init_modified_agent_service_with_files(rag_cfg=None,
|
||||
model_name="qwen3-next", api_key=None,
|
||||
model_server=None, generate_cfg=None,
|
||||
system_prompt=None, mcp=None):
|
||||
"""创建支持预加载文件的修改版助手实例"""
|
||||
system = system_prompt if system_prompt else read_system_prompt()
|
||||
tools = mcp if mcp else read_mcp_settings()
|
||||
|
||||
llm_config = {
|
||||
"model": model_name,
|
||||
"api_key": api_key,
|
||||
"model_server": model_server,
|
||||
"generate_cfg": generate_cfg if generate_cfg else {}
|
||||
}
|
||||
|
||||
# 创建LLM实例
|
||||
llm_instance = TextChatAtOAI(llm_config)
|
||||
|
||||
bot = ModifiedAssistant(
|
||||
llm=llm_instance,
|
||||
name="修改版数据检索助手",
|
||||
description="基于智能判断循环终止的助手",
|
||||
system_message=system,
|
||||
function_list=tools,
|
||||
)
|
||||
|
||||
return bot
|
||||
@@ -18,17 +18,13 @@ import hashlib
import time
import json
import asyncio
from typing import Dict, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional
import threading
from collections import defaultdict

from qwen_agent.agents import Assistant
import logging

logger = logging.getLogger('app')

from agent.modified_assistant import init_modified_agent_service_with_files, update_agent_llm
from agent.deep_assistant import init_agent
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async


@@ -131,7 +127,7 @@ class ShardedAgentManager:
                                  system_prompt: Optional[str] = None,
                                  mcp_settings: Optional[List[Dict]] = None,
                                  robot_type: Optional[str] = "general_agent",
                                  user_identifier: Optional[str] = None) -> Assistant:
                                  user_identifier: Optional[str] = None):
        """Get or create a file-preloaded assistant instance."""

        # Update request statistics
@@ -159,10 +155,6 @@ class ShardedAgentManager:
            if cache_key in shard['agents']:
                self._update_access_time(shard, cache_key)
                agent = shard['agents'][cache_key]

                # Dynamically update the LLM configuration and system settings
                update_agent_llm(agent, model_name, api_key, model_server, generate_cfg)

                # Update cache-hit statistics
                with self._stats_lock:
                    self._global_stats['cache_hits'] += 1
@@ -184,7 +176,6 @@ class ShardedAgentManager:
            if cache_key in shard['agents']:
                self._update_access_time(shard, cache_key)
                agent = shard['agents'][cache_key]
                update_agent_llm(agent, model_name, api_key, model_server, generate_cfg)

                with self._stats_lock:
                    self._global_stats['cache_hits'] += 1
@@ -199,7 +190,7 @@ class ShardedAgentManager:
        logger.info(f"分片创建新的助手实例缓存: {cache_key}, bot_id: {bot_id}, shard: {shard_index}")
        current_time = time.time()

        agent = init_modified_agent_service_with_files(
        agent = await init_agent(
            model_name=model_name,
            api_key=api_key,
            model_server=model_server,

1371  poetry.lock  generated
File diff suppressed because it is too large
@@ -6,12 +6,11 @@ authors = [
    {name = "朱潮",email = "zhuchaowe@users.noreply.github.com"}
]
readme = "README.md"
requires-python = ">=3.12.0"
requires-python = ">=3.12,<4.0"
dependencies = [
    "fastapi==0.116.1",
    "uvicorn==0.35.0",
    "requests==2.32.5",
    "qwen-agent[mcp,rag]==0.0.31",
    "pydantic==2.10.5",
    "python-dateutil==2.8.2",
    "torch==2.2.0",
@@ -27,6 +26,9 @@ dependencies = [
    "chardet>=5.0.0",
    "psutil (>=7.1.3,<8.0.0)",
    "uvloop (>=0.22.1,<0.23.0)",
    "deepagents (>=0.3.0,<0.4.0)",
    "langchain-mcp-adapters (>=0.2.1,<0.3.0)",
    "langchain-openai (>=1.1.1,<2.0.0)",
]

[tool.poetry.requires-plugins]

233  requirements.txt
@ -1,116 +1,117 @@
|
||||
aiofiles==25.1.0 ; python_version >= "3.12"
|
||||
aiohappyeyeballs==2.6.1 ; python_version >= "3.12"
|
||||
aiohttp==3.13.1 ; python_version >= "3.12"
|
||||
aiosignal==1.4.0 ; python_version >= "3.12"
|
||||
annotated-types==0.7.0 ; python_version >= "3.12"
|
||||
anyio==4.11.0 ; python_version >= "3.12"
|
||||
attrs==25.4.0 ; python_version >= "3.12"
|
||||
beautifulsoup4==4.14.2 ; python_full_version >= "3.12.0"
|
||||
certifi==2025.10.5 ; python_version >= "3.12"
|
||||
cffi==2.0.0 ; python_version >= "3.12" and platform_python_implementation != "PyPy"
|
||||
chardet==5.2.0 ; python_version >= "3.12"
|
||||
charset-normalizer==3.4.4 ; python_version >= "3.12"
|
||||
click==8.3.0 ; python_version >= "3.12"
|
||||
colorama==0.4.6 ; python_version >= "3.12" and platform_system == "Windows"
|
||||
cryptography==46.0.3 ; python_version >= "3.12"
|
||||
dashscope==1.24.6 ; python_full_version >= "3.12.0"
|
||||
distro==1.9.0 ; python_version >= "3.12"
|
||||
dotenv==0.9.9 ; python_full_version >= "3.12.0"
|
||||
et-xmlfile==2.0.0 ; python_version >= "3.12"
|
||||
eval-type-backport==0.2.2 ; python_version >= "3.12"
|
||||
fastapi==0.116.1 ; python_version >= "3.12"
|
||||
filelock==3.20.0 ; python_version >= "3.12"
|
||||
frozenlist==1.8.0 ; python_version >= "3.12"
|
||||
fsspec==2025.9.0 ; python_version >= "3.12"
|
||||
h11==0.16.0 ; python_version >= "3.12"
|
||||
hf-xet==1.1.10 ; python_version >= "3.12" and (platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "arm64" or platform_machine == "aarch64")
|
||||
httpcore==1.0.9 ; python_version >= "3.12"
|
||||
httpx-sse==0.4.3 ; python_version >= "3.12"
|
||||
httpx==0.28.1 ; python_version >= "3.12"
|
||||
huey==2.5.3 ; python_full_version >= "3.12.0"
|
||||
huggingface-hub==0.35.3 ; python_full_version >= "3.12.0"
|
||||
idna==3.11 ; python_version >= "3.12"
|
||||
jieba==0.42.1 ; python_full_version >= "3.12.0"
|
||||
jinja2==3.1.6 ; python_version >= "3.12"
|
||||
jiter==0.11.1 ; python_version >= "3.12"
|
||||
joblib==1.5.2 ; python_version >= "3.12"
|
||||
json5==0.12.1 ; python_full_version >= "3.12.0"
|
||||
jsonlines==4.0.0 ; python_version >= "3.12"
|
||||
jsonschema-specifications==2025.9.1 ; python_version >= "3.12"
|
||||
jsonschema==4.25.1 ; python_version >= "3.12"
|
||||
lxml==6.0.2 ; python_version >= "3.12"
|
||||
markupsafe==3.0.3 ; python_version >= "3.12"
|
||||
mcp==1.12.4 ; python_version >= "3.12"
|
||||
mpmath==1.3.0 ; python_full_version >= "3.12.0"
|
||||
multidict==6.7.0 ; python_version >= "3.12"
|
||||
networkx==3.5 ; python_version >= "3.12"
|
||||
numpy==1.26.4 ; python_version >= "3.12"
|
||||
nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
|
||||
nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
|
||||
nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
|
||||
nvidia-cudnn-cu12==8.9.2.26 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
|
||||
nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
|
||||
nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
|
||||
nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
|
||||
nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
|
||||
nvidia-nccl-cu12==2.19.3 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
|
||||
nvidia-nvjitlink-cu12==12.9.86 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
|
||||
nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.12"
|
||||
openai==2.5.0 ; python_version >= "3.12"
|
||||
openpyxl==3.1.5 ; python_version >= "3.12"
|
||||
packaging==25.0 ; python_version >= "3.12"
|
||||
pandas==2.3.3 ; python_version >= "3.12"
|
||||
pdfminer-six==20250506 ; python_version >= "3.12"
|
||||
pdfplumber==0.11.7 ; python_version >= "3.12"
|
||||
pillow==12.0.0 ; python_version >= "3.12"
|
||||
propcache==0.4.1 ; python_version >= "3.12"
|
||||
psutil==7.1.3 ; python_version >= "3.12"
|
||||
pycparser==2.23 ; platform_python_implementation != "PyPy" and implementation_name != "PyPy" and python_version >= "3.12"
|
||||
pydantic-core==2.27.2 ; python_version >= "3.12"
|
||||
pydantic-settings==2.11.0 ; python_version >= "3.12"
|
||||
pydantic==2.10.5 ; python_version >= "3.12"
|
||||
pypdfium2==4.30.0 ; python_version >= "3.12"
|
||||
python-dateutil==2.8.2 ; python_version >= "3.12"
|
||||
python-docx==1.2.0 ; python_version >= "3.12"
|
||||
python-dotenv==1.1.1 ; python_version >= "3.12"
|
||||
python-multipart==0.0.20 ; python_version >= "3.12"
|
||||
python-pptx==1.0.2 ; python_version >= "3.12"
|
||||
pytz==2025.2 ; python_full_version >= "3.12.0"
|
||||
pywin32==311 ; python_full_version >= "3.12.0" and sys_platform == "win32"
|
||||
pyyaml==6.0.3 ; python_version >= "3.12"
|
||||
qwen-agent==0.0.31 ; python_full_version >= "3.12.0"
|
||||
rank-bm25==0.2.2 ; python_full_version >= "3.12.0"
|
||||
referencing==0.37.0 ; python_version >= "3.12"
|
||||
regex==2025.9.18 ; python_version >= "3.12"
|
||||
requests==2.32.5 ; python_version >= "3.12"
|
||||
rpds-py==0.27.1 ; python_version >= "3.12"
|
||||
safetensors==0.6.2 ; python_version >= "3.12"
|
||||
scikit-learn==1.7.2 ; python_version >= "3.12"
|
||||
scipy==1.16.2 ; python_version >= "3.12"
|
||||
sentence-transformers==5.1.1 ; python_version >= "3.12"
|
||||
six==1.17.0 ; python_version >= "3.12"
|
||||
sniffio==1.3.1 ; python_version >= "3.12"
|
||||
snowballstemmer==3.0.1 ; python_version >= "3.12"
|
||||
soupsieve==2.8 ; python_version >= "3.12"
|
||||
sse-starlette==3.0.2 ; python_version >= "3.12"
|
||||
starlette==0.47.3 ; python_version >= "3.12"
|
||||
sympy==1.14.0 ; python_version >= "3.12"
|
||||
tabulate==0.9.0 ; python_version >= "3.12"
|
||||
threadpoolctl==3.6.0 ; python_version >= "3.12"
|
||||
tiktoken==0.12.0 ; python_version >= "3.12"
|
||||
tokenizers==0.22.1 ; python_version >= "3.12"
|
||||
torch==2.2.0 ; python_full_version >= "3.12.0"
|
||||
tqdm==4.67.1 ; python_version >= "3.12"
|
||||
transformers==4.57.1 ; python_full_version >= "3.12.0"
|
||||
triton==2.2.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_full_version >= "3.12.0"
|
||||
typing-extensions==4.15.0 ; python_version >= "3.12"
|
||||
typing-inspection==0.4.2 ; python_version >= "3.12"
|
||||
tzdata==2025.2 ; python_version >= "3.12"
|
||||
urllib3==2.5.0 ; python_version >= "3.12"
|
||||
uvicorn==0.35.0 ; python_version >= "3.12"
|
||||
uvloop==0.22.1 ; python_full_version >= "3.12.0"
|
||||
websocket-client==1.9.0 ; python_version >= "3.12"
|
||||
xlrd==2.0.2 ; python_version >= "3.12"
|
||||
xlsxwriter==3.2.9 ; python_version >= "3.12"
|
||||
yarl==1.22.0 ; python_version >= "3.12"
|
||||
aiofiles==25.1.0 ; python_version >= "3.12" and python_version < "4.0"
aiohappyeyeballs==2.6.1 ; python_version >= "3.12" and python_version < "4.0"
aiohttp==3.13.1 ; python_version >= "3.12" and python_version < "4.0"
aiosignal==1.4.0 ; python_version >= "3.12" and python_version < "4.0"
annotated-types==0.7.0 ; python_version >= "3.12" and python_version < "4.0"
anthropic==0.75.0 ; python_version >= "3.12" and python_version < "4.0"
anyio==4.11.0 ; python_version >= "3.12" and python_version < "4.0"
attrs==25.4.0 ; python_version >= "3.12" and python_version < "4.0"
bracex==2.6 ; python_version >= "3.12" and python_version < "4.0"
certifi==2025.10.5 ; python_version >= "3.12" and python_version < "4.0"
chardet==5.2.0 ; python_version >= "3.12" and python_version < "4.0"
charset-normalizer==3.4.4 ; python_version >= "3.12" and python_version < "4.0"
click==8.3.0 ; python_version >= "3.12" and python_version < "4.0"
colorama==0.4.6 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Windows"
deepagents==0.3.0 ; python_version >= "3.12" and python_version < "4.0"
distro==1.9.0 ; python_version >= "3.12" and python_version < "4.0"
docstring-parser==0.17.0 ; python_version >= "3.12" and python_version < "4.0"
et-xmlfile==2.0.0 ; python_version >= "3.12" and python_version < "4.0"
fastapi==0.116.1 ; python_version >= "3.12" and python_version < "4.0"
filelock==3.20.0 ; python_version >= "3.12" and python_version < "4.0"
frozenlist==1.8.0 ; python_version >= "3.12" and python_version < "4.0"
fsspec==2025.9.0 ; python_version >= "3.12" and python_version < "4.0"
h11==0.16.0 ; python_version >= "3.12" and python_version < "4.0"
hf-xet==1.1.10 ; python_version >= "3.12" and python_version < "4.0" and (platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "arm64" or platform_machine == "aarch64")
httpcore==1.0.9 ; python_version >= "3.12" and python_version < "4.0"
httpx-sse==0.4.3 ; python_version >= "3.12" and python_version < "4.0"
httpx==0.28.1 ; python_version >= "3.12" and python_version < "4.0"
huey==2.5.3 ; python_version >= "3.12" and python_version < "4.0"
huggingface-hub==0.35.3 ; python_version >= "3.12" and python_version < "4.0"
idna==3.11 ; python_version >= "3.12" and python_version < "4.0"
jinja2==3.1.6 ; python_version >= "3.12" and python_version < "4.0"
jiter==0.11.1 ; python_version >= "3.12" and python_version < "4.0"
joblib==1.5.2 ; python_version >= "3.12" and python_version < "4.0"
jsonpatch==1.33 ; python_version >= "3.12" and python_version < "4.0"
jsonpointer==3.0.0 ; python_version >= "3.12" and python_version < "4.0"
jsonschema-specifications==2025.9.1 ; python_version >= "3.12" and python_version < "4.0"
jsonschema==4.25.1 ; python_version >= "3.12" and python_version < "4.0"
langchain-anthropic==1.2.0 ; python_version >= "3.12" and python_version < "4.0"
langchain-core==1.1.3 ; python_version >= "3.12" and python_version < "4.0"
langchain-mcp-adapters==0.2.1 ; python_version >= "3.12" and python_version < "4.0"
langchain-openai==1.1.1 ; python_version >= "3.12" and python_version < "4.0"
langchain==1.1.3 ; python_version >= "3.12" and python_version < "4.0"
langgraph-checkpoint==3.0.1 ; python_version >= "3.12" and python_version < "4.0"
langgraph-prebuilt==1.0.5 ; python_version >= "3.12" and python_version < "4.0"
langgraph-sdk==0.2.15 ; python_version >= "3.12" and python_version < "4.0"
langgraph==1.0.4 ; python_version >= "3.12" and python_version < "4.0"
langsmith==0.4.59 ; python_version >= "3.12" and python_version < "4.0"
markupsafe==3.0.3 ; python_version >= "3.12" and python_version < "4.0"
mcp==1.12.4 ; python_version >= "3.12" and python_version < "4.0"
mpmath==1.3.0 ; python_version >= "3.12" and python_version < "4.0"
multidict==6.7.0 ; python_version >= "3.12" and python_version < "4.0"
networkx==3.5 ; python_version >= "3.12" and python_version < "4.0"
numpy==1.26.4 ; python_version >= "3.12" and python_version < "4.0"
nvidia-cublas-cu12==12.1.3.1 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-cupti-cu12==12.1.105 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-nvrtc-cu12==12.1.105 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cuda-runtime-cu12==12.1.105 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cudnn-cu12==8.9.2.26 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cufft-cu12==11.0.2.54 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-curand-cu12==10.3.2.106 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusolver-cu12==11.4.5.107 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-cusparse-cu12==12.1.0.106 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nccl-cu12==2.19.3 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvjitlink-cu12==12.9.86 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
nvidia-nvtx-cu12==12.1.105 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
openai==2.5.0 ; python_version >= "3.12" and python_version < "4.0"
openpyxl==3.1.5 ; python_version >= "3.12" and python_version < "4.0"
orjson==3.11.5 ; python_version >= "3.12" and python_version < "4.0"
ormsgpack==1.12.0 ; python_version >= "3.12" and python_version < "4.0"
packaging==25.0 ; python_version >= "3.12" and python_version < "4.0"
pandas==2.3.3 ; python_version >= "3.12" and python_version < "4.0"
pillow==12.0.0 ; python_version >= "3.12" and python_version < "4.0"
propcache==0.4.1 ; python_version >= "3.12" and python_version < "4.0"
psutil==7.1.3 ; python_version >= "3.12" and python_version < "4.0"
pydantic-core==2.27.2 ; python_version >= "3.12" and python_version < "4.0"
pydantic-settings==2.11.0 ; python_version >= "3.12" and python_version < "4.0"
pydantic==2.10.5 ; python_version >= "3.12" and python_version < "4.0"
python-dateutil==2.8.2 ; python_version >= "3.12" and python_version < "4.0"
python-dotenv==1.1.1 ; python_version >= "3.12" and python_version < "4.0"
python-multipart==0.0.20 ; python_version >= "3.12" and python_version < "4.0"
pytz==2025.2 ; python_version >= "3.12" and python_version < "4.0"
pywin32==311 ; python_version >= "3.12" and python_version < "4.0" and sys_platform == "win32"
pyyaml==6.0.3 ; python_version >= "3.12" and python_version < "4.0"
referencing==0.37.0 ; python_version >= "3.12" and python_version < "4.0"
regex==2025.9.18 ; python_version >= "3.12" and python_version < "4.0"
requests-toolbelt==1.0.0 ; python_version >= "3.12" and python_version < "4.0"
requests==2.32.5 ; python_version >= "3.12" and python_version < "4.0"
rpds-py==0.27.1 ; python_version >= "3.12" and python_version < "4.0"
safetensors==0.6.2 ; python_version >= "3.12" and python_version < "4.0"
scikit-learn==1.7.2 ; python_version >= "3.12" and python_version < "4.0"
scipy==1.16.2 ; python_version >= "3.12" and python_version < "4.0"
sentence-transformers==5.1.1 ; python_version >= "3.12" and python_version < "4.0"
six==1.17.0 ; python_version >= "3.12" and python_version < "4.0"
sniffio==1.3.1 ; python_version >= "3.12" and python_version < "4.0"
sse-starlette==3.0.2 ; python_version >= "3.12" and python_version < "4.0"
starlette==0.47.3 ; python_version >= "3.12" and python_version < "4.0"
sympy==1.14.0 ; python_version >= "3.12" and python_version < "4.0"
tenacity==9.1.2 ; python_version >= "3.12" and python_version < "4.0"
threadpoolctl==3.6.0 ; python_version >= "3.12" and python_version < "4.0"
tiktoken==0.12.0 ; python_version >= "3.12" and python_version < "4.0"
tokenizers==0.22.1 ; python_version >= "3.12" and python_version < "4.0"
torch==2.2.0 ; python_version >= "3.12" and python_version < "4.0"
tqdm==4.67.1 ; python_version >= "3.12" and python_version < "4.0"
transformers==4.57.1 ; python_version >= "3.12" and python_version < "4.0"
triton==2.2.0 ; python_version >= "3.12" and python_version < "4.0" and platform_system == "Linux" and platform_machine == "x86_64"
typing-extensions==4.15.0 ; python_version >= "3.12" and python_version < "4.0"
typing-inspection==0.4.2 ; python_version >= "3.12" and python_version < "4.0"
tzdata==2025.2 ; python_version >= "3.12" and python_version < "4.0"
urllib3==2.5.0 ; python_version >= "3.12" and python_version < "4.0"
uuid-utils==0.12.0 ; python_version >= "3.12" and python_version < "4.0"
uvicorn==0.35.0 ; python_version >= "3.12" and python_version < "4.0"
uvloop==0.22.1 ; python_version >= "3.12" and python_version < "4.0"
wcmatch==10.1 ; python_version >= "3.12" and python_version < "4.0"
xlrd==2.0.2 ; python_version >= "3.12" and python_version < "4.0"
xxhash==3.6.0 ; python_version >= "3.12" and python_version < "4.0"
yarl==1.22.0 ; python_version >= "3.12" and python_version < "4.0"
zstandard==0.25.0 ; python_version >= "3.12" and python_version < "4.0"

@@ -17,9 +17,10 @@ from agent.prompt_loader import load_guideline_prompt
from utils.fastapi_utils import (
process_messages, extract_block_from_system_prompt, format_messages_to_chat_history,
create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
_get_optimal_batch_size, process_guideline, get_content_from_messages, call_preamble_llm, get_preamble_text, get_language_text,
process_guideline, call_preamble_llm, get_preamble_text, get_language_text,
create_stream_chunk
)
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage

router = APIRouter()

@@ -180,9 +181,17 @@ Action: Provide concise, friendly, and personified natural responses.
all_results = await asyncio.gather(*tasks, return_exceptions=True)

# Handle the results
agent = all_results[0] if len(all_results) >0 else None # result of agent creation
agent = all_results[0] if len(all_results) > 0 else None # result of agent creation

guideline_reasoning = all_results[1] if len(all_results) >1 else ""
# Check whether the agent is an exception object
if isinstance(agent, Exception):
logger.error(f"Error creating agent: {agent}")
raise agent

guideline_reasoning = all_results[1] if len(all_results) > 1 else ""
if isinstance(guideline_reasoning, Exception):
logger.error(f"Error in guideline processing: {guideline_reasoning}")
guideline_reasoning = ""
if guideline_prompt or guideline_reasoning:
logger.info("Guideline Prompt: %s, Reasoning: %s",
guideline_prompt.replace('\n', '\\n') if guideline_prompt else "None",
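
For reference, a minimal standalone sketch of the gather-with-exceptions pattern used in this hunk: agent creation and guideline reasoning run concurrently, an exception from the agent task is re-raised, and a failed guideline task degrades to an empty string. The coroutine names below are illustrative stand-ins, not the project's real ones.

```python
import asyncio

async def create_agent():
    # stand-in for the real agent factory
    return "agent"

async def run_guideline():
    # stand-in for the guideline LLM call; fails on purpose
    raise RuntimeError("guideline LLM unavailable")

async def main():
    results = await asyncio.gather(create_agent(), run_guideline(), return_exceptions=True)
    agent = results[0] if len(results) > 0 else None
    if isinstance(agent, Exception):
        raise agent  # agent creation failures are fatal
    reasoning = results[1] if len(results) > 1 else ""
    if isinstance(reasoning, Exception):
        reasoning = ""  # guideline failures degrade gracefully
    print(agent, repr(reasoning))

asyncio.run(main())
```
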
@@ -243,7 +252,7 @@ async def enhanced_generate_stream_response(
preamble_text = await preamble_task
# Only emit when preamble_text is non-empty and not "<empty>"
if preamble_text and preamble_text.strip() and preamble_text != "<empty>":
preamble_content = get_content_from_messages([{"role": "preamble","content": preamble_text + "\n"}], tool_response=tool_response)
preamble_content = f"[PREAMBLE]\n{preamble_text}\n"
chunk_data = create_stream_chunk(f"chatcmpl-preamble", model_name, preamble_content)
yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"
logger.info(f"Stream mode: Generated preamble text ({len(preamble_text)} chars)")
@@ -257,7 +266,7 @@ async def enhanced_generate_stream_response(

# Send guideline_reasoning immediately
if guideline_reasoning:
guideline_content = get_content_from_messages([{"role": "assistant","reasoning_content": guideline_reasoning+ "\n"}], tool_response=tool_response)
guideline_content = f"[THINK]\n{guideline_reasoning}\n"
chunk_data = create_stream_chunk(f"chatcmpl-guideline", model_name, guideline_content)
yield f"data: {json.dumps(chunk_data, ensure_ascii=False)}\n\n"

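The preamble and guideline text are pushed to the client as Server-Sent Events in OpenAI chat-completion-chunk form. A rough, self-contained approximation of that framing (the real create_stream_chunk in utils/fastapi_utils.py may fill the fields slightly differently):

```python
import json
import time
import uuid

def make_chunk(model: str, content: str) -> dict:
    # approximate OpenAI-style streaming chunk; field details are assumed
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "delta": {"content": content}, "finish_reason": None}],
    }

chunk = make_chunk("whatever", "[PREAMBLE]\nOne moment, let me check the menu.\n")
print(f"data: {json.dumps(chunk, ensure_ascii=False)}\n\n")
```
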
@@ -270,20 +279,41 @@ async def enhanced_generate_stream_response(
else:
final_messages = append_assistant_last_message(final_messages, f"\n\nlanguage:{get_language_text(language)}")

logger.debug(f"Final messages for agent (showing first 2): {final_messages[:2]}")

# Phase 3: stream the agent response
logger.info(f"Starting agent stream response")
accumulated_content = ""
chunk_id = 0
for response in agent.run(messages=final_messages):
previous_content = accumulated_content
accumulated_content = get_content_from_messages(response, tool_response=tool_response)

# Compute the newly added content
if accumulated_content.startswith(previous_content):
new_content = accumulated_content[len(previous_content):]
else:
new_content = accumulated_content
previous_content = ""
message_tag = ""
function_name = ""
tool_args = ""
async for msg,metadata in agent.astream({"messages": final_messages}, stream_mode="messages"):
new_content = ""
if isinstance(msg, AIMessageChunk):
# Check whether there is a tool call
if msg.tool_call_chunks: # check the tool-call chunks
if message_tag != "TOOL_CALL":
message_tag = "TOOL_CALL"
if msg.tool_call_chunks[0]["name"]:
function_name = msg.tool_call_chunks[0]["name"]
if msg.tool_call_chunks[0]["args"]:
tool_args += msg.tool_call_chunks[0]["args"]
elif len(msg.content)>0:
if message_tag != "ANSWER":
message_tag = "ANSWER"
new_content = f"[{message_tag}]\n{msg.text}"
elif message_tag == "ANSWER":
new_content = msg.text
elif message_tag == "TOOL_CALL" and \
(
("finish_reason" in msg.response_metadata and msg.response_metadata["finish_reason"] == "tool_calls") or \
("stop_reason" in msg.response_metadata and msg.response_metadata["stop_reason"] == "tool_use")
):
new_content = f"[{message_tag}] {function_name}\n{tool_args}"
message_tag = "TOOL_CALL"
elif isinstance(msg, ToolMessage) and len(msg.content)>0:
message_tag = "TOOL_RESPONSE"
new_content = f"[{message_tag}] {msg.name}\n{msg.text}"

# Only send a chunk when there is new content
if new_content:
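
An offline sketch of the tag-switching logic in this hunk, driven by hand-built message chunks instead of a live agent.astream(..., stream_mode="messages") call. It only needs langchain_core and keeps the same [ANSWER] / [TOOL_CALL] / [TOOL_RESPONSE] framing; the tool name and contents are invented.

```python
from langchain_core.messages import AIMessageChunk, ToolMessage

def tag_chunk(msg, state):
    # simplified version of the branching above: buffer tool-call chunks,
    # tag plain text as ANSWER, tag tool outputs as TOOL_RESPONSE
    if isinstance(msg, AIMessageChunk) and msg.tool_call_chunks:
        chunk = msg.tool_call_chunks[0]
        state["name"] = chunk["name"] or state["name"]
        state["args"] += chunk["args"] or ""
        return ""  # emitted once the tool call is complete
    if isinstance(msg, AIMessageChunk) and msg.content:
        return f"[ANSWER]\n{msg.content}"
    if isinstance(msg, ToolMessage) and msg.content:
        return f"[TOOL_RESPONSE] {msg.name}\n{msg.content}"
    return ""

state = {"name": "", "args": ""}
call_chunk = AIMessageChunk(content="", tool_call_chunks=[
    {"name": "get_menu", "args": '{"item": "coffee"}', "id": "call_1",
     "index": 0, "type": "tool_call_chunk"}])
tag_chunk(call_chunk, state)                               # buffers the tool call
print(f"[TOOL_CALL] {state['name']}\n{state['args']}")     # flushed when the model stops for tools
print(tag_chunk(ToolMessage(content="A cup of coffee is 300 yen.",
                            name="get_menu", tool_call_id="call_1"), state))
print(tag_chunk(AIMessageChunk(content="One cup is 300 yen."), state))
```
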
@@ -379,36 +409,44 @@ async def create_agent_and_generate_response(

# Prepare the final messages
final_messages = messages.copy()
pre_message_list = []
if guideline_reasoning:
# Split guideline_reasoning on ### and take the last segment as the Guidelines
guidelines_text = guideline_reasoning.split('###')[-1].strip() if guideline_reasoning else ""
final_messages = append_assistant_last_message(final_messages, f"language:{get_language_text(language)}\n\nGuidelines:\n{guidelines_text}\n I will follow these guidelines step by step.")
pre_message_list.append({"role": "assistant","reasoning_content": guideline_reasoning+ "\n"})
else:
final_messages = append_assistant_last_message(final_messages, f"\n\nlanguage:{get_language_text(language)}")

# Non-streaming response
agent_responses = agent.run_nonstream(final_messages)
final_responses = pre_message_list + agent_responses
if final_responses and len(final_responses) > 0:
# Process the responses with get_content_from_messages, honoring the tool_response flag
content = get_content_from_messages(final_responses, tool_response=tool_response)
agent_responses = await agent.ainvoke({"messages": final_messages})
append_messages = agent_responses["messages"][len(final_messages):]
# agent_responses = agent.run_nonstream(final_messages)
response_text = ""
if guideline_reasoning:
response_text += "[THINK]\n"+guideline_reasoning+ "\n"
for msg in append_messages:
if isinstance(msg,AIMessage):
if len(msg.text)>0:
response_text += "[ANSWER]\n"+msg.text+ "\n"
if len(msg.tool_calls)>0:
response_text += "".join([f"[TOOL_CALL] {tool['name']}\n{json.dumps(tool['args']) if isinstance(tool['args'],dict) else tool['args']}\n" for tool in msg.tool_calls])
elif isinstance(msg,ToolMessage) and tool_response:
response_text += f"[TOOL_RESPONSE] {msg.name}\n{msg.text}\n"

if len(response_text) > 0:
# Build the response in OpenAI format
return ChatResponse(
choices=[{
"index": 0,
"message": {
"role": "assistant",
"content": content
"content": response_text
},
"finish_reason": "stop"
}],
usage={
"prompt_tokens": sum(len(msg.get("content", "")) for msg in messages),
"completion_tokens": len(content),
"total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(content)
"completion_tokens": len(response_text),
"total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(response_text)
}
)
else:

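A self-contained sketch of how the appended messages are flattened into the tagged response_text used in this hunk, with hand-built langchain_core messages standing in for the output of agent.ainvoke; the names, id, and contents are invented for illustration.

```python
import json

from langchain_core.messages import AIMessage, ToolMessage

append_messages = [
    AIMessage(content="", tool_calls=[
        {"name": "get_menu", "args": {"item": "coffee"}, "id": "call_1"}]),
    ToolMessage(content="A cup of coffee is 300 yen.", name="get_menu", tool_call_id="call_1"),
    AIMessage(content="One cup of coffee is 300 yen."),
]

response_text = ""
for msg in append_messages:
    if isinstance(msg, AIMessage):
        if msg.content:
            response_text += f"[ANSWER]\n{msg.content}\n"
        for tool in msg.tool_calls:
            args = json.dumps(tool["args"]) if isinstance(tool["args"], dict) else tool["args"]
            response_text += f"[TOOL_CALL] {tool['name']}\n{args}\n"
    elif isinstance(msg, ToolMessage):
        response_text += f"[TOOL_RESPONSE] {msg.name}\n{msg.content}\n"

print(response_text)
```
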
@@ -16,7 +16,7 @@ try:
except ImportError:
def apply_optimization_profile(profile):
return {"profile": profile, "status": "system_optimizer not available"}
from utils.fastapi_utils import get_content_from_messages

from embedding import get_model_manager
from pydantic import BaseModel
import logging

@@ -22,7 +22,6 @@ from .dataset_manager import (
)

from .project_manager import (
get_content_from_messages,
generate_project_readme,
save_project_readme,
get_project_status,
@@ -31,20 +30,7 @@ from .project_manager import (
get_project_stats
)

# Import agent management modules
# Note: These have been moved to agent package
# from .file_loaded_agent_manager import (
# get_global_agent_manager,
# init_global_agent_manager
# )

# Import optimized modules
# Note: These have been moved to agent package
# from .sharded_agent_manager import (
# ShardedAgentManager,
# get_global_sharded_agent_manager,
# init_global_sharded_agent_manager
# )

from .connection_pool import (
HTTPConnectionPool,
@@ -153,7 +139,6 @@ __all__ = [
'remove_dataset_directory_by_key',

# project_manager
'get_content_from_messages',
'generate_project_readme',
'save_project_readme',
'get_project_status',
@@ -161,10 +146,6 @@ __all__ = [
'list_projects',
'get_project_stats',

# file_loaded_agent_manager (moved to agent package)
# 'get_global_agent_manager',
# 'init_global_agent_manager',

# agent_pool
'AgentPool',
'get_agent_pool',

@@ -6,10 +6,14 @@ import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Optional, Union, Any
import aiohttp
from qwen_agent.llm.schema import ASSISTANT, FUNCTION
from qwen_agent.llm.oai import TextChatAtOAI
from fastapi import HTTPException
import logging
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain.chat_models import init_chat_model

USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"

logger = logging.getLogger('app')

@@ -19,6 +23,16 @@ thread_pool = ThreadPoolExecutor(max_workers=10)
# Create a concurrency semaphore to limit the number of simultaneous API calls
api_semaphore = asyncio.Semaphore(8) # at most 8 concurrent API calls

def detect_provider(model_name,model_server):
"""Detect the provider type based on the model name"""
model_name_lower = model_name.lower()
if any(claude_model in model_name_lower for claude_model in ["claude", "anthropic"]):
return "anthropic",model_server.replace("/v1","")
elif any(openai_model in model_name_lower for openai_model in ["gpt", "openai", "o1"]):
return "openai",model_server
else:
# Default to the OpenAI-compatible format
return "openai",model_server

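A quick usage check of detect_provider as defined above; the endpoints are placeholders and only the name-based matching matters here.

```python
# assumes the detect_provider defined above is in scope
# (e.g. from utils.fastapi_utils import detect_provider)
print(detect_provider("claude-sonnet-4-5", "https://api.anthropic.com/v1"))
# ('anthropic', 'https://api.anthropic.com')
print(detect_provider("gpt-4o-mini", "https://api.openai.com/v1"))
# ('openai', 'https://api.openai.com/v1')
print(detect_provider("qwen-max", "https://dashscope.example.com/v1"))
# ('openai', 'https://dashscope.example.com/v1')  -> falls back to the OpenAI-compatible path
```
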
def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extension: str) -> tuple[str, int]:
"""
@@ -93,50 +107,50 @@ def create_stream_chunk(chunk_id: str, model_name: str, content: str = None, fin
}
return chunk_data

def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
"""Extract content from qwen-agent messages with special formatting"""
full_text = ''
content = []
TOOL_CALL_S = '[TOOL_CALL]'
TOOL_RESULT_S = '[TOOL_RESPONSE]'
THOUGHT_S = '[THINK]'
ANSWER_S = '[ANSWER]'
PREAMBLE_S = '[PREAMBLE]'
# def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
# """Extract content from qwen-agent messages with special formatting"""
# full_text = ''
# content = []
# TOOL_CALL_S = '[TOOL_CALL]'
# TOOL_RESULT_S = '[TOOL_RESPONSE]'
# THOUGHT_S = '[THINK]'
# ANSWER_S = '[ANSWER]'
# PREAMBLE_S = '[PREAMBLE]'

for msg in messages:
if msg['role'] == ASSISTANT:
if msg.get('reasoning_content'):
assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
if msg.get('content'):
assert isinstance(msg['content'], str), 'Now only supports text messages'
# Filter out incomplete tool_call text from the streamed output
content_text = msg["content"]
# for msg in messages:
# if msg['role'] == ASSISTANT:
# if msg.get('reasoning_content'):
# assert isinstance(msg['reasoning_content'], str), 'Now only supports text messages'
# content.append(f'{THOUGHT_S}\n{msg["reasoning_content"]}')
# if msg.get('content'):
# assert isinstance(msg['content'], str), 'Now only supports text messages'
# # Filter out incomplete tool_call text from the streamed output
# content_text = msg["content"]

# Use a regex to replace incomplete tool_call patterns with an empty string
# # Use a regex to replace incomplete tool_call patterns with an empty string

# Match and replace incomplete tool_call patterns
content_text = re.sub(r'<t?o?o?l?_?c?a?l?l?$', '', content_text)
# Only append when the processed content is non-empty
if content_text.strip():
content.append(f'{ANSWER_S}\n{content_text}')
if msg.get('function_call'):
content_text = msg["function_call"]["arguments"]
content_text = re.sub(r'}\n<\/?t?o?o?l?_?c?a?l?l?$', '', content_text)
if content_text.strip():
content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{content_text}')
elif msg['role'] == FUNCTION:
if tool_response:
content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
elif msg['role'] == "preamble":
content.append(f'{PREAMBLE_S}\n{msg["content"]}')
else:
raise TypeError
# # Match and replace incomplete tool_call patterns
# content_text = re.sub(r'<t?o?o?l?_?c?a?l?l?$', '', content_text)
# # Only append when the processed content is non-empty
# if content_text.strip():
# content.append(f'{ANSWER_S}\n{content_text}')
# if msg.get('function_call'):
# content_text = msg["function_call"]["arguments"]
# content_text = re.sub(r'}\n<\/?t?o?o?l?_?c?a?l?l?$', '', content_text)
# if content_text.strip():
# content.append(f'{TOOL_CALL_S} {msg["function_call"]["name"]}\n{content_text}')
# elif msg['role'] == FUNCTION:
# if tool_response:
# content.append(f'{TOOL_RESULT_S} {msg["name"]}\n{msg["content"]}')
# elif msg['role'] == "preamble":
# content.append(f'{PREAMBLE_S}\n{msg["content"]}')
# else:
# raise TypeError

if content:
full_text = '\n'.join(content)
# if content:
# full_text = '\n'.join(content)

return full_text
# return full_text

def process_messages(messages: List[Dict], language: Optional[str] = None) -> List[Dict[str, str]]:
@@ -156,13 +170,13 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
processed_messages = []

# Collect the indices of all ASSISTANT messages
assistant_indices = [i for i, msg in enumerate(messages) if msg.role == "assistant"]
assistant_indices = [i for i, msg in enumerate(messages) if msg.role == ASSISTANT]
total_assistant_messages = len(assistant_indices)
cutoff_point = max(0, total_assistant_messages - 5)

# Process each message
for i, msg in enumerate(messages):
if msg.role == "assistant":
if msg.role == ASSISTANT:
# Determine this ASSISTANT message's position among all ASSISTANT messages (0-based)
assistant_position = assistant_indices.index(i)

@@ -236,7 +250,7 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
assistant_content = ""
function_calls = []
tool_responses = []

tool_id = ""
for i in range(0, len(parts)):
if i % 2 == 0: # text content
text = parts[i].strip()
@@ -259,8 +273,10 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
break

if should_include:
# Wrap the TOOL_RESPONSE as a tool_result message, placed right after the matching tool_use
final_messages.append({
"role": FUNCTION,
"role": TOOL,
"tool_call_id": tool_id, # keep the same id as the preceding tool_use
"name": function_name,
"content": response_content
})
@@ -279,13 +295,17 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
break

if should_include:
tool_id = f"tool_id_{i}"
final_messages.append({
"role": ASSISTANT,
"content": "",
"function_call": {
"name": function_name,
"arguments": arguments
}
"tool_calls": [{
"id":tool_id,
"function": {
"name": function_name,
"arguments": arguments
}
}]
})
elif current_tag != "THINK" and current_tag != "PREAMBLE":
final_messages.append({
@@ -297,7 +317,6 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
else:
# Messages that are not from the assistant, or contain no [TOOL_RESPONSE], are appended as-is
final_messages.append(msg)
print(final_messages)
return final_messages

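For orientation, these are the shapes process_messages now emits when it re-expands the tagged history: an assistant message carrying tool_calls, immediately followed by the matching tool message. All names, ids, and contents below are made up.

```python
tool_id = "tool_id_3"

assistant_msg = {
    "role": "assistant",
    "content": "",
    "tool_calls": [{
        "id": tool_id,
        "function": {"name": "get_menu", "arguments": '{"item": "coffee"}'},
    }],
}

tool_msg = {
    "role": "tool",
    "tool_call_id": tool_id,  # must match the id of the preceding tool call
    "name": "get_menu",
    "content": "A cup of coffee is 300 yen.",
}

print(assistant_msg)
print(tool_msg)
```
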
@@ -317,16 +336,19 @@ def format_messages_to_chat_history(messages: List[Dict[str, str]]) -> str:
for message in messages:
role = message.get('role', '')
content = message.get('content', '')
if role == 'user':
name = message.get('name', '')
if role == USER:
chat_history.append(f"user: {content}")
elif role == FUNCTION:
chat_history.append(f"function_response: {content}")
elif role == TOOL:
chat_history.append(f"{name} response: {content}")
elif role == ASSISTANT:
if len(content) >0:
chat_history.append(f"assistant: {content}")
if message.get('function_call'):
chat_history.append(f"function_call: {message.get('function_call').get('name')} ")
chat_history.append(f"{message.get('function_call').get('arguments')}")
if message.get('tool_calls'):
for tool_call in message.get('tool_calls'):
function_name = tool_call.get('function').get('name')
arguments = tool_call.get('function').get('arguments')
chat_history.append(f"{function_name} call: {arguments}")

recent_chat_history = chat_history[-15:] if len(chat_history) > 15 else chat_history
print(f"recent_chat_history:{recent_chat_history}")
@@ -411,30 +433,44 @@ async def fetch_bot_config(bot_id: str) -> Dict[str, Any]:
)


def _sync_call_llm(llm_config, messages) -> str:
"""Helper that calls the LLM synchronously, executed in the thread pool"""
llm_instance = TextChatAtOAI(llm_config)
async def _sync_call_llm(llm_config, messages) -> str:
"""Helper that calls the LLM, now implemented with LangChain"""
try:
# Set stream=False to get a non-streaming response
response = llm_instance.chat(messages=messages, stream=False)
# Create the LangChain LLM instance
model_name = llm_config.get('model')
model_server = llm_config.get('model_server')
api_key = llm_config.get('api_key')
# Detect the provider, or use the specified one
model_provider,base_url = detect_provider(model_name,model_server)

# Handle the response
if isinstance(response, list) and response:
# If a list of Message objects is returned, extract the content
if hasattr(response[0], 'content'):
return response[0].content
elif isinstance(response[0], dict) and 'content' in response[0]:
return response[0]['content']
# Build the model parameters
model_kwargs = {
"model": model_name,
"model_provider": model_provider,
"temperature": 0.8,
"base_url":base_url,
"api_key":api_key
}
llm_instance = init_chat_model(**model_kwargs)

# If it is a string, return it directly
if isinstance(response, str):
return response
# Convert the messages to LangChain format
langchain_messages = []
for msg in messages:
if msg['role'] == 'system':
langchain_messages.append(SystemMessage(content=msg['content']))
elif msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))

# Handle other types
return str(response) if response else ""
# Call the LangChain model
response = await llm_instance.ainvoke(langchain_messages)

# Return the response content
return response.content if response.content else ""

except Exception as e:
logger.error(f"Error calling guideline LLM: {e}")
logger.error(f"Error calling guideline LLM with LangChain: {e}")
return ""

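A standalone sketch of the LangChain path introduced here: pick a provider, construct the model with init_chat_model, convert plain dict messages to langchain_core message objects, and await one completion. The model name, endpoint, and key are placeholders.

```python
import asyncio

from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage, SystemMessage

async def main():
    llm = init_chat_model(
        model="gpt-4o-mini",                    # placeholder model name
        model_provider="openai",
        temperature=0.8,
        base_url="https://api.openai.com/v1",   # placeholder endpoint
        api_key="sk-...",                        # placeholder key
    )
    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="How much is a cup of coffee?"),
    ]
    response = await llm.ainvoke(messages)
    print(response.content)

# asyncio.run(main())  # needs a reachable endpoint and a valid key
```
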
def get_language_text(language: str):
@@ -550,9 +586,8 @@ async def call_preamble_llm(chat_history: str, last_message: str, preamble_choic
try:
# Use the semaphore to limit concurrent API calls
async with api_semaphore:
# Run the synchronous HTTP call in the thread pool to avoid blocking the event loop
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(thread_pool, _sync_call_llm, llm_config, messages)
# Call the async LLM function directly
response = await _sync_call_llm(llm_config, messages)

# Extract the content wrapped in ```json ... ``` from the response
json_pattern = r'```json\s*\n(.*?)\n```'
@@ -605,9 +640,8 @@ async def call_guideline_llm(chat_history: str, guidelines_prompt: str, model_na
try:
# Use the semaphore to limit concurrent API calls
async with api_semaphore:
# Run the synchronous HTTP call in the thread pool to avoid blocking the event loop
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(thread_pool, _sync_call_llm, llm_config, messages)
# Call the async LLM function directly
response = await _sync_call_llm(llm_config, messages)
return response

except Exception as e:

@@ -15,15 +15,6 @@ logger = logging.getLogger('app')
from utils.file_utils import get_document_preview, load_processed_files_log


def get_content_from_messages(messages: List[dict]) -> str:
"""Extract content from messages list"""
content = ""
for message in messages:
if message.get("role") == "user":
content += message.get("content", "")
return content


def generate_directory_tree(project_dir: str, unique_id: str, max_depth: int = 3) -> str:
"""Generate dataset directory tree structure for the project"""
def _build_tree(path: str, prefix: str = "", is_last: bool = True, depth: int = 0) -> List[str]: