catalog-agent/gbase_agent.py
2025-10-07 14:35:07 +08:00

236 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A sqlite database assistant implemented by assistant"""
import argparse
import asyncio
import json
import os
from typing import Dict, List, Optional, Union
from qwen_agent.agents import Assistant
from qwen_agent.gui import WebUI
from qwen_agent.llm.oai import TextChatAtOAI
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
from qwen_agent.utils.output_beautify import typewriter_print
ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource")
def read_mcp_settings():
    """Load the raw MCP server configuration from ./mcp/mcp_settings.json.

    Returns:
        The parsed JSON value (a list of server-config dicts, as consumed
        by the agent-initialization helpers in this module).
    """
    # Explicit utf-8 keeps this consistent with read_system_prompt() and
    # independent of the platform's default locale encoding.
    with open("./mcp/mcp_settings.json", "r", encoding="utf-8") as f:
        return json.load(f)
def read_mcp_settings_with_project_restriction(project_data_dir: str):
    """Load the MCP configuration and restrict the json-reader server to one project.

    Args:
        project_data_dir: Directory the json-reader MCP server may access,
            e.g. "projects/<project_id>/data".

    Returns:
        The parsed MCP settings, with PROJECT_DATA_DIR and PROJECT_ID
        injected into the json-reader server's environment.
    """
    # Explicit utf-8 for consistency with the other file readers here.
    with open("./mcp/mcp_settings.json", "r", encoding="utf-8") as f:
        mcp_settings_json = json.load(f)
    # Add the project-directory restriction for the json-reader server.
    for server_config in mcp_settings_json:
        if "mcpServers" in server_config:
            for server_name, server_info in server_config["mcpServers"].items():
                if server_name == "json-reader":
                    # Pass the restriction to the server via environment variables.
                    server_info.setdefault("env", {})
                    server_info["env"]["PROJECT_DATA_DIR"] = project_data_dir
                    # Second-to-last path segment is treated as the project id
                    # (e.g. "projects/<id>/data" -> "<id>"); falls back to
                    # "default" for paths with no separator.
                    server_info["env"]["PROJECT_ID"] = (
                        project_data_dir.split("/")[-2]
                        if "/" in project_data_dir
                        else "default"
                    )
                    break
    return mcp_settings_json
def read_system_prompt():
    """Load the shared, stateless system prompt from ./agent_prompt.txt."""
    with open("./agent_prompt.txt", "r", encoding="utf-8") as f:
        prompt_text = f.read()
    return prompt_text.strip()
def init_agent_service():
    """Default initialization, kept for backward compatibility.

    Returns:
        A stateless universal assistant using the default model
        ("qwen3-next").
    """
    # Bug fix: the original called init_agent_service_universal("qwen3-next"),
    # but that function takes no arguments, so the call raised TypeError.
    # It already defaults to "qwen3-next", so the argument is unnecessary.
    return init_agent_service_universal()
def init_agent_service_with_project(project_id: str, project_data_dir: str, model_name: str = "qwen3-next"):
    """Build a project-scoped database assistant (kept for backward compatibility).

    Args:
        project_id: Identifier embedded in the assistant's name/description.
        project_data_dir: Directory the json-reader MCP server is restricted to.
        model_name: LLM model identifier; defaults to "qwen3-next".

    Returns:
        A configured Assistant whose LLM can be swapped later via
        update_agent_llm().
    """
    # Stateless system prompt shared by every agent instance.
    system_prompt = read_system_prompt()
    # MCP tool configuration with the project-directory restriction applied.
    tool_list = read_mcp_settings_with_project_restriction(project_data_dir)
    # Default LLM settings; replaceable at runtime through update_agent_llm().
    default_llm = TextChatAtOAI({
        "model": model_name,
        "model_server": "https://openrouter.ai/api/v1",  # default server
        "api_key": "default-key",  # placeholder; the real key arrives via the API
    })
    return Assistant(
        llm=default_llm,
        name=f"数据库助手-{project_id}",
        description=f"项目 {project_id} 数据库查询",
        system_message=system_prompt,
        function_list=tool_list,
    )
def init_agent_service_universal(model_name: str = "qwen3-next"):
    """Create a stateless, general-purpose data-retrieval assistant.

    Args:
        model_name: LLM model identifier. Defaults to "qwen3-next", which
            was previously hard-coded — existing zero-argument callers see
            identical behavior, while callers may now select a model.

    Returns:
        A configured Assistant whose LLM can be swapped later via
        update_agent_llm().
    """
    # Stateless, shared system prompt.
    system = read_system_prompt()
    # Base MCP tool configuration (no per-project restriction).
    tools = read_mcp_settings()
    # Default LLM configuration; replaceable at runtime via update_agent_llm().
    llm_config = {
        "model": model_name,
        "model_server": "https://openrouter.ai/api/v1",  # default server
        "api_key": "default-key"  # placeholder; the real key arrives via the API
    }
    llm_instance = TextChatAtOAI(llm_config)
    bot = Assistant(
        llm=llm_instance,
        name="通用数据检索助手",
        description="无状态通用数据检索助手",
        system_message=system,
        function_list=tools,
    )
    return bot
def update_agent_llm(agent, model_name: str, api_key: Optional[str] = None, model_server: Optional[str] = None):
    """Swap the assistant's LLM at runtime with parameters supplied via the API.

    Args:
        agent: Assistant instance whose ``llm`` attribute is replaced in place.
        model_name: Model identifier for the new LLM.
        api_key: API key; ``None`` is forwarded as-is (original behavior) —
            callers are expected to supply a real key.
        model_server: Base URL of the model server; ``None`` is forwarded as-is.

    Returns:
        The same ``agent``, to allow call chaining.
    """
    llm_config = {
        "model": model_name,
        "api_key": api_key,
        "model_server": model_server,
    }
    # Always build a concrete TextChatAtOAI instance so agent.llm is never a
    # bare config dict. (Dead commented-out "llm_class" dispatch removed.)
    agent.llm = TextChatAtOAI(llm_config)
    return agent
def test(query="数据库里有几张表"):
# Define the agent - 使用通用初始化
bot = init_agent_service_universal()
# Chat
messages = []
messages.append({"role": "user", "content": query})
responses = []
for response in bot.run(messages):
responses.append(response)
# 只输出最终结果,不显示中间过程
if responses:
final_response = responses[-1][-1] # 取最后一个响应作为最终结果
print("Answer:", final_response["content"])
def app_tui():
    """Interactive terminal loop: read a question (plus optional file) and stream replies."""
    bot = init_agent_service_universal()
    history = []
    while True:
        # Query example: 数据库里有几张表
        question = input("user question: ")
        # File example: resource/poem.pdf
        attachment = input("file url (press enter if no file): ").strip()
        if not question:
            print("user question cannot be empty")
            continue
        if attachment:
            content = [{"text": question}, {"file": attachment}]
        else:
            content = question
        history.append({"role": "user", "content": content})
        reply = []
        for reply in bot.run(history):
            print("bot response:", reply)
        # Keep the assistant's final reply in the running conversation.
        history.extend(reply)
def app_gui():
    """Launch the browser-based WebUI for the universal assistant."""
    bot = init_agent_service_universal()
    suggestions = [
        "数据库里有几张表",
        "创建一个学生表包括学生的姓名、年龄",
        "增加一个学生名字叫韩梅梅今年6岁",
    ]
    WebUI(bot, chatbot_config={"prompt.suggestions": suggestions}).run()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="数据库助手")
parser.add_argument(
"--query", type=str, default="数据库里有几张表", help="用户问题"
)
parser.add_argument(
"--mode",
type=str,
choices=["test", "tui", "gui"],
default="test",
help="运行模式",
)
args = parser.parse_args()
if args.mode == "test":
test(args.query)
elif args.mode == "tui":
app_tui()
elif args.mode == "gui":
app_gui()