# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A SQLite database assistant implemented with the Qwen-Agent Assistant."""

import json
import os
from typing import Dict, Optional

from qwen_agent.agents import Assistant
from qwen_agent.llm.oai import TextChatAtOAI

ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource")


def read_mcp_settings():
    with open("./mcp/mcp_settings.json", "r") as f:
        mcp_settings_json = json.load(f)
    return mcp_settings_json
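

# A hypothetical sketch of what ./mcp/mcp_settings.json might contain. The code
# in this file only assumes a list of dicts with an "mcpServers" key (Qwen-Agent's
# MCP tool format); the server names, commands, and args below are illustrative,
# except "json-reader", which read_mcp_settings_with_project_restriction targets.
#
# [
#     {
#         "mcpServers": {
#             "sqlite": {
#                 "command": "uvx",
#                 "args": ["mcp-server-sqlite", "--db-path", "example.db"]
#             },
#             "json-reader": {
#                 "command": "python",
#                 "args": ["json_reader_server.py"]
#             }
#         }
#     }
# ]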


def read_mcp_settings_with_project_restriction(project_data_dir: str):
    """Read the MCP config and restrict the json-reader server to a project directory."""
    with open("./mcp/mcp_settings.json", "r") as f:
        mcp_settings_json = json.load(f)

    # Add the project-directory restriction for json-reader.
    for server_config in mcp_settings_json:
        if "mcpServers" in server_config:
            for server_name, server_info in server_config["mcpServers"].items():
                if server_name == "json-reader":
                    # Pass the restriction to the server process via environment variables.
                    if "env" not in server_info:
                        server_info["env"] = {}
                    server_info["env"]["PROJECT_DATA_DIR"] = project_data_dir
                    # Expects paths like "data/<project_id>/" (trailing slash); otherwise
                    # the second-to-last path component is used as the project ID.
                    server_info["env"]["PROJECT_ID"] = project_data_dir.split("/")[-2] if "/" in project_data_dir else "default"
                    break

    return mcp_settings_json
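

# Worked example (hypothetical path layout): with
# project_data_dir = "data/projects/p42/", the json-reader entry gains
#     "env": {"PROJECT_DATA_DIR": "data/projects/p42/", "PROJECT_ID": "p42"}
# since "data/projects/p42/".split("/")[-2] == "p42".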


def read_system_prompt():
    """Read the generic, stateless system prompt."""
    with open("./agent_prompt.txt", "r", encoding="utf-8") as f:
        return f.read().strip()


def init_agent_service():
    """Default initializer, kept for backward compatibility."""
    return init_agent_service_universal("qwen3-next")


def init_agent_service_with_project(project_id: str, project_data_dir: str, model_name: str = "qwen3-next"):
    """Project-scoped agent initializer, kept for backward compatibility."""
    # Read the generic (stateless) system prompt.
    system = read_system_prompt()

    # Read the MCP tool config with the project restriction applied.
    tools = read_mcp_settings_with_project_restriction(project_data_dir)

    # Default LLM config; it can be replaced at runtime via update_agent_llm.
    llm_config = {
        "model": model_name,
        "model_server": "https://openrouter.ai/api/v1",  # default server
        "api_key": "default-key",  # placeholder; pass a real key in through the API
    }

    # Instantiate the LLM.
    llm_instance = TextChatAtOAI(llm_config)

    bot = Assistant(
        llm=llm_instance,  # initialized with the default LLM; swap via update_agent_llm
        name=f"Database assistant-{project_id}",
        description=f"Database queries for project {project_id}",
        system_message=system,
        function_list=tools,
    )

    return bot


def init_agent_service_universal(model_name: str = "qwen3-next"):
    """Create a stateless, general-purpose assistant (default LLM, switchable at runtime)."""
    # Read the generic (stateless) system prompt.
    system = read_system_prompt()

    # Read the base MCP tool config (no project restriction).
    tools = read_mcp_settings()

    # Default LLM config; it can be replaced at runtime via update_agent_llm.
    llm_config = {
        "model": model_name,  # default model
        "model_server": "https://openrouter.ai/api/v1",  # default server
        "api_key": "default-key",  # placeholder; pass a real key in through the API
    }

    # Instantiate the LLM.
    llm_instance = TextChatAtOAI(llm_config)

    bot = Assistant(
        llm=llm_instance,  # initialized with the default LLM; swap via update_agent_llm
        name="General data retrieval assistant",
        description="Stateless, general-purpose data retrieval assistant",
        system_message=system,
        function_list=tools,
    )

    return bot


def update_agent_llm(agent, model_name: str, api_key: Optional[str] = None, model_server: Optional[str] = None, generate_cfg: Optional[Dict] = None):
    """Swap the agent's LLM at runtime using parameters passed in from the API."""
    # Build the LLM config from the supplied parameters.
    llm_config = {
        "model": model_name,
        "api_key": api_key,
        "model_server": model_server,
        "generate_cfg": generate_cfg if generate_cfg else {},
    }

    # Instantiate the LLM so that agent.llm is a model object, not a config dict.
    llm_instance = TextChatAtOAI(llm_config)

    # Attach the new LLM to the agent.
    agent.llm = llm_instance

    return agent
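

# Hypothetical usage sketch: switch a running agent to another OpenAI-compatible
# endpoint without rebuilding its tool setup. The environment variable name and
# the generate_cfg keys below are illustrative assumptions, not part of this module.
#
#     bot = init_agent_service_universal()
#     bot = update_agent_llm(
#         bot,
#         model_name="qwen3-next",
#         api_key=os.getenv("OPENROUTER_API_KEY"),
#         model_server="https://openrouter.ai/api/v1",
#         generate_cfg={"temperature": 0.2},
#     )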


def test(query="How many tables are in the database?"):
    # Define the agent using the universal initializer.
    bot = init_agent_service_universal()

    # Chat
    messages = []

    messages.append({"role": "user", "content": query})

    responses = []
    for response in bot.run(messages):
        responses.append(response)

    # Print only the final result, not the intermediate steps.
    if responses:
        final_response = responses[-1][-1]  # last message of the last response batch
        print("Answer:", final_response["content"])