# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import logging
import os
import time
from typing import Dict, Iterator, List, Literal, Optional, Union

from qwen_agent.agents import Assistant
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
from qwen_agent.llm.oai import TextChatAtOAI
from qwen_agent.tools import BaseTool
from agent.custom_mcp_manager import CustomMCPManager

logger = logging.getLogger('app')
# Logger used for tool-call tracing (points at the same 'app' logger)
tool_logger = logging.getLogger('app')


class ModifiedAssistant(Assistant):
    """
    A modified Assistant subclass that changes the loop-termination logic:
    - Original logic: exit the loop immediately when no tool was used
    - Modified logic: when no tool was used, ask the model to judge whether the answer is complete, and keep looping if it is not
    """

    def _is_retryable_error(self, error: Exception) -> bool:
        """Decide whether an error is retryable.

        Args:
            error: The exception object.

        Returns:
            bool: Whether the call can be retried.
        """
        error_str = str(error).lower()
        retryable_indicators = [
            '502', '500', '503', '504',  # HTTP error codes
            'internal server error',
            'timeout',
            'connection',  # connection errors
            'network',  # network errors
            'rate',  # rate limiting and related errors
            'quota',  # quota limits
            'service unavailable',
            'provider returned error',  # provider errors
            'model service error',
            'temporary',  # temporary errors
            'retry'  # explicit hint to retry
        ]
        return any(indicator in error_str for indicator in retryable_indicators)

    def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwargs) -> str:
        """Override the tool-call method to add debug information."""
        if tool_name not in self.function_map:
            error_msg = f'Tool {tool_name} does not exist. Available tools: {list(self.function_map.keys())}'
            tool_logger.error(error_msg)
            return error_msg

        tool = self.function_map[tool_name]

        start_time = time.time()
        try:
            tool_logger.info(f"Calling tool: {tool_name} {tool_args}")

            # Delegate to the parent class's _call_tool
            tool_result = super()._call_tool(tool_name, tool_args, **kwargs)

            end_time = time.time()
            tool_logger.info(f"Tool {tool_name} finished in {end_time - start_time:.2f}s, result length: {len(tool_result) if tool_result else 0}")

            # Log a preview of the result (avoid overly long output)
            if tool_result and len(tool_result) > 0:
                preview = tool_result[:200] if len(tool_result) > 200 else tool_result
                tool_logger.debug(f"Tool {tool_name} result preview: {preview}...")

            return tool_result

        except Exception as ex:
            end_time = time.time()
            tool_logger.error(f"Tool call failed after {end_time - start_time:.2f}s, exception type: {type(ex).__name__} {str(ex)}")

            # Log the full stack trace
            import traceback
            tool_logger.error(f"Stack trace:\n{traceback.format_exc()}")

            # Return a detailed error message
            error_message = f'An error occurred when calling tool {tool_name}: {type(ex).__name__}: {str(ex)}'
            return error_message

    def _init_tool(self, tool: Union[str, Dict, BaseTool]):
        """Override tool initialization to handle MCP server configs via CustomMCPManager."""
        if isinstance(tool, BaseTool):
            # Handle a BaseTool instance
            tool_name = tool.name
            if tool_name in self.function_map:
                tool_logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
            self.function_map[tool_name] = tool
        elif isinstance(tool, dict) and 'mcpServers' in tool:
            # Use CustomMCPManager to handle MCP server configs, with support for headers
            tools = CustomMCPManager().initConfig(tool)
            for mcp_tool in tools:
                tool_name = mcp_tool.name
                if tool_name in self.function_map:
                    tool_logger.warning(f'Repeatedly adding tool {tool_name}, will use the newest tool in function list')
                self.function_map[tool_name] = mcp_tool
        else:
            # Fall back to the parent class's handling
            super()._init_tool(tool)

    def _call_llm_with_retry(self, messages: List[Message], functions=None, extra_generate_cfg=None, max_retries: int = 5) -> Iterator:
        """Call the LLM with a retry mechanism.

        Args:
            messages: List of messages.
            functions: List of functions.
            extra_generate_cfg: Extra generation config.
            max_retries: Maximum number of attempts.

        Returns:
            The LLM response stream.

        Raises:
            Exception: Re-raises the original exception once retries are exhausted.
        """
        for attempt in range(max_retries):
            try:
                return self._call_llm(messages=messages, functions=functions, extra_generate_cfg=extra_generate_cfg)
            except Exception as e:
                # Check whether the error is retryable
                if self._is_retryable_error(e) and attempt < max_retries - 1:
                    delay = 2 ** attempt  # exponential backoff: 1s, 2s, 4s, ...
                    tool_logger.warning(f"LLM call failed (attempt {attempt + 1}/{max_retries}), retrying in {delay}s: {str(e)}")
                    time.sleep(delay)
                    continue
                else:
                    # Non-retryable error, or the maximum number of retries was reached
                    if attempt > 0:
                        tool_logger.error(f"LLM call retries failed, reached the maximum of {max_retries} attempts")
                    raise

    def _run(self, messages: List[Message], lang: Literal['en', 'zh', 'ja'] = 'en', **kwargs) -> Iterator[List[Message]]:

        message_list = copy.deepcopy(messages)
        response = []

        # Keep the original limit on the maximum number of LLM calls
        total_num_llm_calls_available = self.MAX_LLM_CALL_PER_RUN if hasattr(self, 'MAX_LLM_CALL_PER_RUN') else 10
        num_llm_calls_available = total_num_llm_calls_available
        while num_llm_calls_available > 0:
            num_llm_calls_available -= 1
            extra_generate_cfg = {'lang': lang}
            if kwargs.get('seed') is not None:
                extra_generate_cfg['seed'] = kwargs['seed']

            output_stream = self._call_llm_with_retry(messages=message_list,
                                                      functions=[func.function for func in self.function_map.values()],
                                                      extra_generate_cfg=extra_generate_cfg)
            output: List[Message] = []
            for output in output_stream:
                if output:
                    yield response + output

            if output:
                response.extend(output)
                message_list.extend(output)

            # Handle tool calls
            used_any_tool = False
            for out in output:
                use_tool, tool_name, tool_args, _ = self._detect_tool(out)
                if use_tool:
                    tool_result = self._call_tool(tool_name, tool_args, messages=message_list, **kwargs)

                    # Validate the tool result
                    if not tool_result:
                        tool_logger.warning(f"Tool {tool_name} returned an empty result")
                        tool_result = f"Tool {tool_name} completed execution but returned empty result"
                    elif tool_result.startswith('An error occurred when calling tool') or tool_result.startswith('工具调用失败'):
                        tool_logger.error(f"Tool {tool_name} call failed: {tool_result}")

                    fn_msg = Message(role=FUNCTION,
                                     name=tool_name,
                                     content=tool_result,
                                     extra={'function_id': out.extra.get('function_id', '1')})
                    message_list.append(fn_msg)
                    response.append(fn_msg)
                    yield response
                    used_any_tool = True

            # Exit the loop when no tool was used; otherwise keep looping
            if not used_any_tool:
                break

        # Check whether the loop ended because the call budget was exhausted
        if num_llm_calls_available == 0:
            # Pick the error message based on language
            if lang == 'zh':
                error_message = "工具调用超出限制"
            elif lang == 'ja':
                error_message = "ツール呼び出しが制限を超えました。"
            else:
                error_message = "Tool calls exceeded limit"
            tool_logger.error(error_message)

            error_msg = Message(
                role=ASSISTANT,
                content=error_message,
            )
            response.append(error_msg)

            yield response


# Utility functions
def read_system_prompt():
    """Read the generic, stateless system prompt."""
    with open("./prompt/system_prompt_default.md", "r", encoding="utf-8") as f:
        return f.read().strip()


def read_mcp_settings():
    """Read the MCP tool configuration."""
    with open("./mcp/mcp_settings.json", "r") as f:
        mcp_settings_json = json.load(f)
    return mcp_settings_json


def update_agent_llm(agent, model_name: str, api_key: str = None, model_server: str = None, generate_cfg: Dict = None):
    """Dynamically update the assistant instance's LLM and config with parameters passed in from the API."""
    # Assemble the base config
    llm_config = {
        "model": model_name,
        "api_key": api_key,
        "model_server": model_server,
        "generate_cfg": generate_cfg if generate_cfg else {}
    }

    # Create the LLM instance
    llm_instance = TextChatAtOAI(llm_config)

    # Dynamically swap in the new LLM
    agent.llm = llm_instance
    return agent

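# Illustrative sketch of swapping the model on an existing agent at runtime via
# update_agent_llm, assuming an OpenAI-compatible endpoint; the URL and the
# environment variable below are placeholders, not values from this repository.
#
#     agent = update_agent_llm(
#         agent,
#         model_name="qwen3-next",
#         api_key=os.getenv("OPENAI_API_KEY"),        # placeholder env var
#         model_server="https://example.com/v1",      # placeholder endpoint
#         generate_cfg={"temperature": 0.2},
#     )

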
# Backward compatibility: keep the original initialization function interface
def init_modified_agent_service_with_files(rag_cfg=None,
                                           model_name="qwen3-next", api_key=None,
                                           model_server=None, generate_cfg=None,
                                           system_prompt=None, mcp=None):
    """Create a modified assistant instance that supports preloaded files."""
    system = system_prompt if system_prompt else read_system_prompt()
    tools = mcp if mcp else read_mcp_settings()

    llm_config = {
        "model": model_name,
        "api_key": api_key,
        "model_server": model_server,
        "generate_cfg": generate_cfg if generate_cfg else {}
    }

    # Create the LLM instance
    llm_instance = TextChatAtOAI(llm_config)

    bot = ModifiedAssistant(
        llm=llm_instance,
        name="Modified data retrieval assistant",
        description="Assistant that decides loop termination via model judgment",
        system_message=system,
        function_list=tools,
    )

    return bot
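

# Minimal usage sketch, assuming the prompt and MCP config files referenced above
# exist and that the endpoint accepts the OpenAI-compatible parameters passed
# through TextChatAtOAI; the endpoint URL and environment variable are placeholders.
if __name__ == "__main__":
    bot = init_modified_agent_service_with_files(
        model_name="qwen3-next",
        api_key=os.getenv("OPENAI_API_KEY"),        # placeholder env var
        model_server="https://example.com/v1",      # placeholder endpoint
        generate_cfg={"temperature": 0.2},
    )

    # Agent.run streams incremental snapshots of the response message list;
    # keep the last snapshot as the final answer.
    last = []
    for rsp in bot.run(messages=[{"role": "user", "content": "What tools do you have?"}]):
        last = rsp
    print(last)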