# qwen_agent/modified_assistant.py
# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import time
from typing import Dict, Iterator, List, Literal, Optional

from qwen_agent.agents import Assistant
from qwen_agent.llm.oai import TextChatAtOAI
from qwen_agent.llm.schema import FUNCTION, Message


class ModifiedAssistant(Assistant):
    """An Assistant subclass with modified loop-exit logic.

    - Original logic: exit the loop as soon as no tool call is made.
    - Modified logic: when no tool call is made, ask the model whether the
      answer is complete, and continue the loop if it is not.
    """
    def _is_retryable_error(self, error: Exception) -> bool:
        """Decide whether an error is worth retrying.

        Args:
            error: The exception that was raised.

        Returns:
            bool: True if the error looks transient and can be retried.
        """
        error_str = str(error).lower()
        retryable_indicators = [
            '502', '500', '503', '504',  # HTTP status codes
            'internal server error',
            'timeout',
            'connection',
            'network',
            'rate',  # rate limiting and related errors
            'quota',
            'service unavailable',
            'provider returned error',
            'model service error',
            'temporary',
            'retry',  # the error message explicitly suggests retrying
        ]
        return any(indicator in error_str for indicator in retryable_indicators)
    def _call_llm_with_retry(self,
                             messages: List[Message],
                             functions=None,
                             extra_generate_cfg=None,
                             max_retries: int = 5) -> Iterator:
        """Call the LLM with an exponential-backoff retry policy.

        Note that only the initial call is retried: errors raised while
        iterating over the returned stream propagate to the caller.

        Args:
            messages: The message list.
            functions: The function (tool) list.
            extra_generate_cfg: Extra generation config.
            max_retries: Maximum number of attempts.

        Returns:
            The LLM response stream.

        Raises:
            Exception: The original exception, re-raised once retries are exhausted.
        """
        for attempt in range(max_retries):
            try:
                return self._call_llm(messages=messages, functions=functions, extra_generate_cfg=extra_generate_cfg)
            except Exception as e:
                # Check whether the error is retryable
                if self._is_retryable_error(e) and attempt < max_retries - 1:
                    delay = 2 ** attempt  # exponential backoff: 1s, 2s, 4s, 8s
                    print(f"LLM call failed (attempt {attempt + 1}/{max_retries}), retrying in {delay}s: {str(e)}")
                    time.sleep(delay)
                    continue
                else:
                    # Non-retryable error, or the retry budget is exhausted
                    if attempt > 0:
                        print(f"LLM call still failing after the maximum of {max_retries} attempts")
                    raise
    def _run(self, messages: List[Message], lang: Literal['en', 'zh', 'ja'] = 'en', **kwargs) -> Iterator[List[Message]]:
        message_list = copy.deepcopy(messages)
        response = []
        # Keep the original cap on the number of LLM calls per run
        total_num_llm_calls_available = self.MAX_LLM_CALL_PER_RUN if hasattr(self, 'MAX_LLM_CALL_PER_RUN') else 10
        num_llm_calls_available = total_num_llm_calls_available
        while num_llm_calls_available > 0:
            num_llm_calls_available -= 1
            extra_generate_cfg = {'lang': lang}
            if kwargs.get('seed') is not None:
                extra_generate_cfg['seed'] = kwargs['seed']
            output_stream = self._call_llm_with_retry(messages=message_list,
                                                      functions=[func.function for func in self.function_map.values()],
                                                      extra_generate_cfg=extra_generate_cfg)
            output: List[Message] = []
            for output in output_stream:
                if output:
                    yield response + output
            if output:
                response.extend(output)
                message_list.extend(output)
            # Handle tool calls
            used_any_tool = False
            for out in output:
                use_tool, tool_name, tool_args, _ = self._detect_tool(out)
                print(out, lang, use_tool, tool_name, tool_args)  # debug trace
                if use_tool:
                    tool_result = self._call_tool(tool_name, tool_args, messages=message_list, **kwargs)
                    fn_msg = Message(role=FUNCTION,
                                     name=tool_name,
                                     content=tool_result,
                                     extra={'function_id': out.extra.get('function_id', '1')})
                    message_list.append(fn_msg)
                    response.append(fn_msg)
                    yield response
                    used_any_tool = True
            # Exit the loop once the model stops calling tools
            if not used_any_tool:
                break
        yield response
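

# A minimal, self-contained sketch of the same exponential-backoff pattern that
# _call_llm_with_retry uses, shown on a plain callable so it can be tested in
# isolation. The names below (retry_with_backoff, fn) are illustrative only and
# are not used anywhere else in this module.
def retry_with_backoff(fn, max_retries: int = 5):
    """Call fn(), sleeping 2**attempt seconds between failed attempts."""
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception:
            if attempt == max_retries - 1:
                raise  # retry budget exhausted; surface the last error
            time.sleep(2 ** attempt)  # 1s, 2s, 4s, 8s between the 5 attempts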


# Utility functions
def read_system_prompt():
    """Read the generic, stateless system prompt."""
    with open("./prompt/system_prompt_default.md", "r", encoding="utf-8") as f:
        return f.read().strip()


def read_mcp_settings():
    """Read the MCP tool configuration."""
    with open("./mcp/mcp_settings.json", "r") as f:
        mcp_settings_json = json.load(f)
    return mcp_settings_json
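
# For reference, a minimal mcp_settings.json in the list-of-dicts shape that
# qwen_agent accepts as a function_list entry. The concrete server ("time" via
# uvx) is an assumption for illustration, not this project's actual config:
#
# [
#     {
#         "mcpServers": {
#             "time": {
#                 "command": "uvx",
#                 "args": ["mcp-server-time"]
#             }
#         }
#     }
# ]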


def update_agent_llm(agent,
                     model_name: str,
                     api_key: Optional[str] = None,
                     model_server: Optional[str] = None,
                     generate_cfg: Optional[Dict] = None,
                     system_prompt: Optional[str] = None,
                     mcp_settings: Optional[List[Dict]] = None):
    """Dynamically update an assistant instance's LLM and configuration,
    with parameters supplied by the API layer."""
    # Build the base LLM config
    llm_config = {
        "model": model_name,
        "api_key": api_key,
        "model_server": model_server,
        "generate_cfg": generate_cfg if generate_cfg else {}
    }
    # Create the LLM instance
    llm_instance = TextChatAtOAI(llm_config)
    # Swap the agent's LLM in place
    agent.llm = llm_instance
    # Update the system message, if provided
    if system_prompt:
        agent.system_message = system_prompt
    # Update the MCP settings, if provided
    if mcp_settings:
        agent.function_list = mcp_settings
    return agent
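
# Example (all values hypothetical): re-point an existing agent at a different
# OpenAI-compatible endpoint at request time.
#
# agent = update_agent_llm(
#     agent,
#     model_name="qwen3-next",
#     api_key="sk-...",
#     model_server="https://example.com/v1",
#     generate_cfg={"temperature": 0.7},
# )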


# Backward compatibility: keep the original initializer's interface
def init_modified_agent_service_with_files(rag_cfg=None,
                                           model_name="qwen3-next", api_key=None,
                                           model_server=None, generate_cfg=None,
                                           system_prompt=None, mcp=None):
    """Create a modified assistant instance that supports preloaded files."""
    system = system_prompt if system_prompt else read_system_prompt()
    tools = mcp if mcp else read_mcp_settings()
    llm_config = {
        "model": model_name,
        "api_key": api_key,
        "model_server": model_server,
        "generate_cfg": generate_cfg if generate_cfg else {}
    }
    # Create the LLM instance
    llm_instance = TextChatAtOAI(llm_config)
    bot = ModifiedAssistant(
        llm=llm_instance,
        name="Modified data-retrieval assistant",
        description="An assistant whose loop termination is decided by the model",
        system_message=system,
        function_list=tools,
    )
    return bot
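

if __name__ == "__main__":
    # Minimal smoke test, assuming ./prompt/system_prompt_default.md and
    # ./mcp/mcp_settings.json exist; the key and endpoint are placeholders.
    bot = init_modified_agent_service_with_files(
        model_name="qwen3-next",
        api_key="YOUR_API_KEY",
        model_server="https://example.com/v1",
    )
    responses = []
    for responses in bot.run(messages=[{"role": "user", "content": "Hello"}]):
        pass  # run() streams cumulative lists of Message objects
    if responses:
        print(responses[-1].content)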