343 lines
13 KiB
Python
343 lines
13 KiB
Python
# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
"""A sqlite database assistant implemented by assistant"""
|
||
|
||
import argparse
|
||
import asyncio
|
||
import copy
|
||
import json
|
||
import os
|
||
from typing import Dict, List, Optional, Union
|
||
|
||
from qwen_agent.agents import Assistant
|
||
from qwen_agent.gui import WebUI
|
||
from qwen_agent.llm.oai import TextChatAtOAI
|
||
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
|
||
from qwen_agent.utils.output_beautify import typewriter_print
|
||
|
||
ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource")
|
||
|
||
|
||
class GPT4OChat(TextChatAtOAI):
    """Custom GPT-4o chat class that fixes the ``tool_call_id`` problem.

    OpenAI-compatible endpoints require every ``role == 'tool'`` message to
    carry a ``tool_call_id`` matching a preceding assistant ``tool_calls``
    entry. This subclass post-processes the message list produced by the
    parent class so that requirement is satisfied.
    """

    def convert_messages_to_dicts(self, messages: List[Message]) -> List[dict]:
        # Use the parent class for the base conversion first.
        messages = super().convert_messages_to_dicts(messages)

        # Then apply the fixed message transformation.
        messages = self._fixed_conv_qwen_agent_messages_to_oai(messages)

        return messages

    @staticmethod
    def _fixed_conv_qwen_agent_messages_to_oai(messages: List[Union[Message, Dict]]) -> List[dict]:
        """Fixed message conversion: ensure tool messages carry ``tool_call_id``.

        Walks the message list with a manual index because assistant
        messages that declare a function call are paired with the tool
        message that immediately follows them (consumed together, i += 2).

        NOTE(review): only the first tool_call / first following tool message
        is paired; parallel tool calls are not handled — confirm that the
        upstream agent never emits more than one function_call per message.
        """
        new_messages = []
        i = 0

        while i < len(messages):
            msg = messages[i]

            if msg['role'] == ASSISTANT:
                # Handle assistant messages.
                assistant_msg = {'role': 'assistant'}

                # Normalize content: structured payloads become JSON text,
                # None becomes the empty string (OpenAI rejects null content).
                content = msg.get('content', '')
                if isinstance(content, (list, dict)):
                    assistant_msg['content'] = json.dumps(content, ensure_ascii=False)
                elif content is None:
                    assistant_msg['content'] = ''
                else:
                    assistant_msg['content'] = content

                # Carry over reasoning_content when present.
                if msg.get('reasoning_content'):
                    assistant_msg['reasoning_content'] = msg['reasoning_content']

                # Decide whether a tool_calls list must be constructed.
                has_tool_call = False
                tool_calls = []

                # Case 1: the current message declares a function_call.
                if msg.get('function_call'):
                    has_tool_call = True
                    tool_calls.append({
                        # Fall back to '1' when no function_id was recorded.
                        'id': msg.get('extra', {}).get('function_id', '1'),
                        'type': 'function',
                        'function': {
                            'name': msg['function_call']['name'],
                            'arguments': msg['function_call']['arguments']
                        }
                    })

                # Note: no longer fabricate a fake tool_call for orphaned
                # tool messages (see the 'tool' branch below instead).

                if has_tool_call:
                    assistant_msg['tool_calls'] = tool_calls
                    new_messages.append(assistant_msg)

                    # Check whether the next message is the matching tool reply.
                    if i + 1 < len(messages) and messages[i + 1]['role'] == 'tool':
                        tool_msg = copy.deepcopy(messages[i + 1])
                        # Force the tool_call_id to match the assistant's call.
                        tool_msg['tool_call_id'] = tool_calls[0]['id']
                        # Strip fields the OpenAI API does not accept on tool messages.
                        for field in ['id', 'extra', 'function_call']:
                            if field in tool_msg:
                                del tool_msg[field]
                        # Ensure content is a valid string.
                        content = tool_msg.get('content', '')
                        if isinstance(content, (list, dict)):
                            tool_msg['content'] = json.dumps(content, ensure_ascii=False)
                        elif content is None:
                            tool_msg['content'] = ''
                        new_messages.append(tool_msg)
                        # Consumed both the assistant message and its tool reply.
                        i += 2
                    else:
                        i += 1
                else:
                    new_messages.append(assistant_msg)
                    i += 1

            elif msg['role'] == 'tool':
                # Orphaned tool message (no preceding assistant tool_call):
                # convert it into an assistant + user message pair so the
                # API never sees a tool message without a tool_call_id.
                # First, an assistant message wrapping the tool result.
                assistant_result = {'role': 'assistant'}
                content = msg.get('content', '')
                if isinstance(content, (list, dict)):
                    content = json.dumps(content, ensure_ascii=False)
                assistant_result['content'] = f"工具查询结果: {content}"
                new_messages.append(assistant_result)

                # Then a user message to keep the conversation going.
                new_messages.append({'role': 'user', 'content': '请继续分析以上结果'})
                i += 1

            else:
                # All other roles (system/user): pass through, content normalized.
                new_msg = copy.deepcopy(msg)

                # Ensure content is a valid string.
                content = new_msg.get('content', '')
                if isinstance(content, (list, dict)):
                    new_msg['content'] = json.dumps(content, ensure_ascii=False)
                elif content is None:
                    new_msg['content'] = ''

                new_messages.append(new_msg)
                i += 1

        return new_messages
|
||
|
||
|
||
def read_mcp_settings():
    """Load the MCP tool configuration used as the agent's function list.

    Returns:
        The parsed JSON object from ``./mcp/mcp_settings.json``.

    Raises:
        FileNotFoundError: if the settings file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Read as UTF-8 explicitly so non-ASCII tool descriptions parse the
    # same on every platform (read_system_prompt already does this; the
    # original relied on the locale default encoding here).
    with open("./mcp/mcp_settings.json", "r", encoding="utf-8") as f:
        return json.load(f)
|
||
|
||
|
||
def read_system_prompt():
    """Return the agent system prompt from ``./agent_prompt.txt``.

    The file is read as UTF-8 and surrounding whitespace is stripped.
    """
    with open("./agent_prompt.txt", "r", encoding="utf-8") as prompt_file:
        raw_prompt = prompt_file.read()
    return raw_prompt.strip()
|
||
|
||
|
||
def init_agent_service():
    """Build and return the database Assistant.

    Selects one LLM configuration from a catalogue of named presets,
    instantiates a custom LLM class when the preset requests one, and
    wires in the system prompt and MCP tools.

    Returns:
        An ``Assistant`` configured with the chosen LLM and MCP tools.
    """
    # NOTE(security): API keys are hardcoded below. They should be moved to
    # environment variables / a secrets store and rotated — treat every key
    # in this file as compromised.
    llm_cfg = {
        "llama-33": {
            "model": "gbase-llama-33",
            "model_server": "http://llmapi:9009/v1",
            "api_key": "any",
        },
        "gpt-oss-120b": {
            "model": "openai/gpt-oss-120b",
            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
            "api_key": "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",
            "generate_cfg": {
                "use_raw_api": True,  # GPT-OSS true, Qwen false
                "fncall_prompt_type": "nous",  # use nous-style function-call prompting
            },
        },
        "claude-3.7": {
            "model": "claude-3-7-sonnet-20250219",
            "model_server": "https://one.felo.me/v1",
            "api_key": "sk-9gtHriq7C3jAvepq5dA0092a5cC24a54Aa83FbC99cB88b21-2",
            "generate_cfg": {
                "use_raw_api": True,  # GPT-OSS true, Qwen false
            },
        },
        "gpt-4o": {
            "model": "gpt-4o",
            "model_server": "https://one-dev.felo.me/v1",
            "api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4",
            "generate_cfg": {
                "use_raw_api": True,  # raw_api on; the custom class fixes tool_call_id
                "fncall_prompt_type": "nous",  # use nous-style function-call prompting
            },
        },
        "Gpt-4o-back": {
            "model_type": "oai",  # oai type so the custom class can be used
            "model": "gpt-4o",
            "model_server": "https://one-dev.felo.me/v1",
            "api_key": "sk-hsKClH0Z695EkK5fDdB2Ec2fE13f4fC1B627BdBb8e554b5b-4",
            "generate_cfg": {
                "use_raw_api": True,  # raw_api on; the custom class fixes tool_call_id
                "fncall_prompt_type": "nous",  # use nous-style function-call prompting
            },
            # Use the custom GPT4OChat class (stripped before construction below).
            "llm_class": GPT4OChat,
        },
        "glm-45": {
            "model_server": "https://open.bigmodel.cn/api/paas/v4",
            "api_key": "0c9cbaca9d2bbf864990f1e1decdf340.dXRMsZCHTUbPQ0rm",
            "model": "glm-4.5",
            "generate_cfg": {
                "use_raw_api": True,  # GPT-OSS true, Qwen false
            },
        },
        "qwen3-next": {
            "model": "qwen/qwen3-next-80b-a3b-instruct",
            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
            "api_key": "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",
        },
        "deepresearch": {
            "model": "alibaba/tongyi-deepresearch-30b-a3b",
            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
            "api_key": "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",
        },
        "qwen3-coder": {
            "model": "Qwen/Qwen3-Coder-30B-A3B-Instruct",
            "model_server": "https://api-inference.modelscope.cn/v1",  # base_url, also known as api_base
            "api_key": "ms-92027446-2787-4fd6-af01-f002459ec556",
        },
        "openrouter-gpt4o": {
            "model": "openai/gpt-4o",
            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
            "api_key": "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",
            "generate_cfg": {
                "use_raw_api": True,  # GPT-OSS true, Qwen false
                "fncall_prompt_type": "nous",  # use nous-style function-call prompting
            },
        },
    }
    system = read_system_prompt()

    # MCP tool definitions (function list for the assistant).
    tools = read_mcp_settings()

    # Active preset selection.
    llm_instance = llm_cfg["qwen3-next"]
    if "llm_class" in llm_instance:
        # BUGFIX: strip the non-config "llm_class" key before passing the
        # dict to the constructor — the original passed the whole dict,
        # llm_class included, into the LLM class.
        llm_class = llm_instance["llm_class"]
        config = {k: v for k, v in llm_instance.items() if k != "llm_class"}
        llm_instance = llm_class(config)

    bot = Assistant(
        llm=llm_instance,  # either a raw config dict or a custom LLM instance
        name="数据库助手",
        description="数据库查询",
        system_message=system,
        function_list=tools,
    )

    return bot
|
||
|
||
|
||
def test(query="数据库里有几张表"):
    """Run a single query through the agent and print only the final answer.

    Args:
        query: The user question to send to the database assistant.
    """
    # Build the agent.
    bot = init_agent_service()

    # Single-turn conversation history.
    history = [{"role": "user", "content": query}]

    # Drain the streaming generator; each yielded item is a message batch.
    batches = list(bot.run(history))

    # Print only the final result, skipping intermediate streaming output.
    if batches:
        final_message = batches[-1][-1]  # last message of the last batch
        print("Answer:", final_message["content"])
|
||
|
||
|
||
def app_tui():
    """Interactive terminal chat loop for the database assistant.

    Prompts for a question (e.g. 数据库里有几张表) and an optional file
    reference on each turn, streams the bot's responses, and appends them
    to the running conversation history.
    """
    bot = init_agent_service()

    history = []
    while True:
        query = input("user question: ")
        file = input("file url (press enter if no file): ").strip()

        # Reject empty questions and re-prompt.
        if not query:
            print("user question cannot be empty!")
            continue

        # Attach the file reference as structured content when given.
        if file:
            history.append(
                {"role": "user", "content": [{"text": query}, {"file": file}]}
            )
        else:
            history.append({"role": "user", "content": query})

        # Stream the reply; after the loop, `response` holds the final batch.
        response = []
        for response in bot.run(history):
            print("bot response:", response)
        history.extend(response)
|
||
|
||
|
||
def app_gui():
    """Launch the web UI for the database assistant with canned suggestions."""
    bot = init_agent_service()

    # Prompt suggestions shown in the chat UI.
    suggested_prompts = [
        "数据库里有几张表",
        "创建一个学生表包括学生的姓名、年龄",
        "增加一个学生名字叫韩梅梅,今年6岁",
    ]
    ui = WebUI(bot, chatbot_config={"prompt.suggestions": suggested_prompts})
    ui.run()
|
||
|
||
|
||
if __name__ == "__main__":
    # CLI entry point: choose between one-shot test, terminal chat, and web UI.
    parser = argparse.ArgumentParser(description="数据库助手")
    parser.add_argument(
        "--query", type=str, default="数据库里有几张表", help="用户问题"
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["test", "tui", "gui"],
        default="test",
        help="运行模式",
    )
    args = parser.parse_args()

    # argparse's `choices` guarantees mode is one of the three values.
    if args.mode == "tui":
        app_tui()
    elif args.mode == "gui":
        app_gui()
    elif args.mode == "test":
        test(args.query)