catalog-agent/gbase_agent.py
2025-10-06 19:51:39 +08:00

343 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A sqlite database assistant implemented by assistant"""
import argparse
import asyncio
import copy
import json
import os
from typing import Dict, List, Optional, Union
from qwen_agent.agents import Assistant
from qwen_agent.gui import WebUI
from qwen_agent.llm.oai import TextChatAtOAI
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
from qwen_agent.utils.output_beautify import typewriter_print
# Path of the "resource" directory located next to this module file.
ROOT_RESOURCE = os.path.join(
    os.path.split(__file__)[0],
    "resource",
)
class GPT4OChat(TextChatAtOAI):
    """Custom GPT-4o chat class that repairs ``tool_call_id`` handling.

    The parent class performs the base message conversion; the result is then
    post-processed so that a ``tool`` message directly following an assistant
    function call carries the ``tool_call_id`` of that call.
    """

    def convert_messages_to_dicts(self, messages: List[Message]) -> List[dict]:
        """Convert *messages* with the parent class, then apply the repair pass.

        Args:
            messages: Conversation history to convert.

        Returns:
            The output of ``_fixed_conv_qwen_agent_messages_to_oai`` applied
            to the parent's conversion result.
        """
        # Base conversion is delegated to the parent class.
        converted = super().convert_messages_to_dicts(messages)
        # Post-process so tool messages carry a matching tool_call_id.
        return self._fixed_conv_qwen_agent_messages_to_oai(converted)

    @staticmethod
    def _content_to_text(content):
        """Normalise a message ``content`` value.

        Lists and dicts are serialised to JSON (non-ASCII characters are kept
        verbatim), ``None`` becomes an empty string, and any other value is
        returned unchanged.
        """
        if isinstance(content, (list, dict)):
            return json.dumps(content, ensure_ascii=False)
        if content is None:
            return ''
        return content

    @staticmethod
    def _fixed_conv_qwen_agent_messages_to_oai(messages: List[Union[Message, Dict]]):
        """Rewrite *messages* so every tool message has a ``tool_call_id``.

        Handling per role:

        * assistant with ``function_call``: emitted with a one-element
          ``tool_calls`` list; if the next message has role ``tool`` it is
          copied, given the same id as ``tool_call_id`` and stripped of the
          ``id``/``extra``/``function_call`` fields.
        * assistant without ``function_call``: emitted with normalised content.
        * ``tool`` message not consumed by a preceding function call: replaced
          by an assistant message holding the result text plus a user message
          asking to continue.
        * any other role: deep-copied with normalised content.

        Args:
            messages: Messages supporting ``msg['role']`` and ``msg.get(...)``.

        Returns:
            A new list of message dicts; the input list is not modified.
        """
        to_text = GPT4OChat._content_to_text
        new_messages = []
        i = 0
        while i < len(messages):
            msg = messages[i]
            role = msg['role']

            if role == ASSISTANT:
                assistant_msg = {
                    'role': 'assistant',
                    'content': to_text(msg.get('content', '')),
                }
                if msg.get('reasoning_content'):
                    assistant_msg['reasoning_content'] = msg['reasoning_content']

                function_call = msg.get('function_call')
                if not function_call:
                    # Plain assistant message: nothing to pair with.
                    new_messages.append(assistant_msg)
                    i += 1
                    continue

                # `extra` may be present but None, so `.get('extra', {})` is
                # not enough to guarantee a dict here.
                extra = msg.get('extra') or {}
                call_id = extra.get('function_id', '1')
                assistant_msg['tool_calls'] = [{
                    'id': call_id,
                    'type': 'function',
                    'function': {
                        'name': function_call['name'],
                        'arguments': function_call['arguments'],
                    },
                }]
                new_messages.append(assistant_msg)
                i += 1

                # Pair the call with the tool result that immediately follows.
                if i < len(messages) and messages[i]['role'] == 'tool':
                    tool_msg = copy.deepcopy(messages[i])
                    # Make the tool result reference the call emitted above.
                    tool_msg['tool_call_id'] = call_id
                    # Drop fields that are not wanted on the tool message.
                    for field in ['id', 'extra', 'function_call']:
                        if field in tool_msg:
                            del tool_msg[field]
                    # Only overwrite the content when it had to be converted.
                    raw_content = tool_msg.get('content', '')
                    text = to_text(raw_content)
                    if text is not raw_content:
                        tool_msg['content'] = text
                    new_messages.append(tool_msg)
                    i += 1

            elif role == 'tool':
                # Orphaned tool result: no fake tool_call is fabricated.
                # It is turned into an assistant message carrying the result,
                # followed by a user message that keeps the dialogue going.
                result_text = to_text(msg.get('content', ''))
                new_messages.append({
                    'role': 'assistant',
                    'content': f"工具查询结果: {result_text}",
                })
                new_messages.append({'role': 'user', 'content': '请继续分析以上结果'})
                i += 1

            else:
                # Any other role: copy and make sure the content is valid.
                new_msg = copy.deepcopy(msg)
                raw_content = new_msg.get('content', '')
                text = to_text(raw_content)
                if text is not raw_content:
                    new_msg['content'] = text
                new_messages.append(new_msg)
                i += 1

        return new_messages
def read_mcp_settings():
    """Load the MCP tool settings from ``./mcp/mcp_settings.json``.

    The path is relative to the current working directory. The file is read
    as UTF-8 so that non-ASCII content parses the same way on every platform,
    independent of the locale's default encoding.

    Returns:
        The parsed JSON document.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open("./mcp/mcp_settings.json", "r", encoding="utf-8") as f:
        return json.load(f)
def read_system_prompt():
    """Return the agent's system prompt.

    Reads ``./agent_prompt.txt`` (relative to the current working directory)
    as UTF-8 and strips leading and trailing whitespace.
    """
    with open("./agent_prompt.txt", mode="r", encoding="utf-8") as prompt_file:
        raw_prompt = prompt_file.read()
    return raw_prompt.strip()
def init_agent_service(model_name: str = "qwen3-next"):
    """Build the database assistant agent.

    API keys are never stored in source. Each model entry names the
    environment variable that holds its key (``api_key_env``); only the
    variable of the selected model is read, and a missing value is reported
    immediately instead of failing later with an authentication error.

    Environment variables used, depending on the selected model:
    ``OPENROUTER_API_KEY``, ``FELO_API_KEY``, ``FELO_DEV_API_KEY``,
    ``ZHIPU_API_KEY``, ``MODELSCOPE_API_KEY``.

    Args:
        model_name: Key of the model configuration to use. Defaults to
            ``"qwen3-next"``, the model previously hard-wired here.

    Returns:
        An ``Assistant`` configured with the system prompt from
        ``read_system_prompt()`` and the tools from ``read_mcp_settings()``.

    Raises:
        ValueError: If *model_name* is not a known configuration.
        RuntimeError: If the API-key environment variable of the selected
            model is unset or empty.
    """
    llm_cfg = {
        "llama-33": {
            "model": "gbase-llama-33",
            "model_server": "http://llmapi:9009/v1",
            "api_key": "any",
        },
        "gpt-oss-120b": {
            "model": "openai/gpt-oss-120b",
            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
            "api_key_env": "OPENROUTER_API_KEY",
            "generate_cfg": {
                "use_raw_api": True,  # True for GPT-OSS, False for Qwen
                "fncall_prompt_type": "nous",  # nous-style function-call prompts
            },
        },
        "claude-3.7": {
            "model": "claude-3-7-sonnet-20250219",
            "model_server": "https://one.felo.me/v1",
            "api_key_env": "FELO_API_KEY",
            "generate_cfg": {
                "use_raw_api": True,  # True for GPT-OSS, False for Qwen
            },
        },
        "gpt-4o": {
            "model": "gpt-4o",
            "model_server": "https://one-dev.felo.me/v1",
            "api_key_env": "FELO_DEV_API_KEY",
            "generate_cfg": {
                "use_raw_api": True,
                "fncall_prompt_type": "nous",  # nous-style function-call prompts
            },
        },
        "Gpt-4o-back": {
            "model_type": "oai",  # "oai" type so the custom class can be used
            "model": "gpt-4o",
            "model_server": "https://one-dev.felo.me/v1",
            "api_key_env": "FELO_DEV_API_KEY",
            "generate_cfg": {
                # Raw API enabled; the custom class repairs tool_call_id.
                "use_raw_api": True,
                "fncall_prompt_type": "nous",  # nous-style function-call prompts
            },
            # Instantiate the custom GPT4OChat class for this entry.
            "llm_class": GPT4OChat,
        },
        "glm-45": {
            "model_server": "https://open.bigmodel.cn/api/paas/v4",
            "api_key_env": "ZHIPU_API_KEY",
            "model": "glm-4.5",
            "generate_cfg": {
                "use_raw_api": True,  # True for GPT-OSS, False for Qwen
            },
        },
        "qwen3-next": {
            "model": "qwen/qwen3-next-80b-a3b-instruct",
            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
            "api_key_env": "OPENROUTER_API_KEY",
        },
        "deepresearch": {
            "model": "alibaba/tongyi-deepresearch-30b-a3b",
            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
            "api_key_env": "OPENROUTER_API_KEY",
        },
        "qwen3-coder": {
            "model": "Qwen/Qwen3-Coder-30B-A3B-Instruct",
            "model_server": "https://api-inference.modelscope.cn/v1",  # base_url, also known as api_base
            "api_key_env": "MODELSCOPE_API_KEY",
        },
        "openrouter-gpt4o": {
            "model": "openai/gpt-4o",
            "model_server": "https://openrouter.ai/api/v1",  # base_url, also known as api_base
            "api_key_env": "OPENROUTER_API_KEY",
            "generate_cfg": {
                "use_raw_api": True,  # True for GPT-OSS, False for Qwen
                "fncall_prompt_type": "nous",  # nous-style function-call prompts
            },
        },
    }

    if model_name not in llm_cfg:
        known = ", ".join(sorted(llm_cfg))
        raise ValueError(f"Unknown model '{model_name}'; expected one of: {known}")

    # Work on a copy so the helper key can be swapped for the real api_key.
    llm_instance = dict(llm_cfg[model_name])
    key_variable = llm_instance.pop("api_key_env", None)
    if key_variable is not None:
        api_key = os.environ.get(key_variable)
        if not api_key:
            raise RuntimeError(
                f"Environment variable {key_variable} must hold the API key "
                f"for model '{model_name}'"
            )
        llm_instance["api_key"] = api_key

    system = read_system_prompt()
    # MCP tool definitions handed to the agent as its function list.
    tools = read_mcp_settings()

    # Entries that declare a custom LLM class are instantiated here; all
    # others are passed to the agent as a plain configuration dict.
    llm_class = llm_instance.get("llm_class")
    if llm_class is not None:
        llm_instance = llm_class(llm_instance)

    bot = Assistant(
        llm=llm_instance,
        name="数据库助手",
        description="数据库查询",
        system_message=system,
        function_list=tools,
    )
    return bot
def test(query="数据库里有几张表"):
    """Run a single query through the agent and print the final answer.

    Every response yielded by ``bot.run`` is collected; only the content of
    the last message of the last response is printed, so intermediate steps
    are not shown. Nothing is printed when the agent yields no response.

    Args:
        query: The user question sent to the agent.
    """
    # Define the agent
    bot = init_agent_service()

    # Chat: a one-message conversation containing just the user query.
    messages = [{"role": "user", "content": query}]
    responses = list(bot.run(messages))

    if not responses:
        return

    # The last message of the last response is taken as the final result.
    final_response = responses[-1][-1]
    print("Answer:", final_response["content"])
def app_tui():
    """Run an interactive terminal chat loop with the agent.

    Each turn reads a question and an optional file URL, appends the user
    message to the running history, prints every response yielded by
    ``bot.run`` and finally extends the history with the last one.

    The question is validated before the file prompt is shown, so an empty
    question is rejected immediately. The loop ends cleanly when standard
    input is closed (EOF) or the user presses Ctrl-C at a prompt.
    """
    # Define the agent
    bot = init_agent_service()

    # Chat
    messages = []
    while True:
        try:
            # Query example: 数据库里有几张表
            query = input("user question: ")
            if not query:
                print("user question cannot be empty")
                continue
            # File example: resource/poem.pdf
            file = input("file url (press enter if no file): ").strip()
        except (EOFError, KeyboardInterrupt):
            # No more input available or the user asked to quit.
            print()
            break

        if not file:
            messages.append({"role": "user", "content": query})
        else:
            messages.append(
                {"role": "user", "content": [{"text": query}, {"file": file}]}
            )

        # `response` keeps the last yielded value; it stays empty when the
        # agent yields nothing, so the history is left unchanged in that case.
        response = []
        for response in bot.run(messages):
            print("bot response:", response)
        messages.extend(response)
def app_gui():
    """Start the web UI for the agent with a few suggested prompts."""
    # Define the agent
    bot = init_agent_service()

    # Example questions offered to the user by the UI.
    suggestions = [
        "数据库里有几张表",
        "创建一个学生表包括学生的姓名、年龄",
        "增加一个学生名字叫韩梅梅今年6岁",
    ]
    ui = WebUI(bot, chatbot_config={"prompt.suggestions": suggestions})
    ui.run()
if __name__ == "__main__":
    # Command-line entry point: parse the options and start the chosen mode.
    cli = argparse.ArgumentParser(description="数据库助手")
    cli.add_argument("--query", type=str, default="数据库里有几张表", help="用户问题")
    cli.add_argument(
        "--mode",
        type=str,
        choices=["test", "tui", "gui"],
        default="test",
        help="运行模式",
    )
    options = cli.parse_args()

    # One launcher per mode; argparse's `choices` guarantees the key exists.
    launchers = {
        "test": lambda: test(options.query),
        "tui": app_tui,
        "gui": app_gui,
    }
    launchers[options.mode]()