# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A SQLite database assistant built on the Qwen-Agent Assistant class."""

import argparse
import asyncio
import json
import os
from typing import Dict, List, Optional, Union

from qwen_agent.agents import Assistant
from qwen_agent.llm.oai import TextChatAtOAI
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
from qwen_agent.utils.output_beautify import typewriter_print

ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource")


def read_mcp_settings():
    """Read the MCP tool configuration from the local settings file."""
    with open("./mcp/mcp_settings.json", "r") as f:
        mcp_settings_json = json.load(f)
    return mcp_settings_json

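# The settings file is expected to be a list of tool-config entries, each optionally
# carrying an "mcpServers" mapping (shape inferred from how the loop below consumes
# it; the server names and commands are illustrative, not the shipped config):
# [
#     {
#         "mcpServers": {
#             "sqlite": {"command": "uvx", "args": ["mcp-server-sqlite", "--db-path", "test.db"]},
#             "json-reader": {"command": "npx", "args": ["json-reader-server"]}
#         }
#     }
# ]
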
def read_mcp_settings_with_project_restriction(project_data_dir: str):
    """Read the MCP settings and restrict the json-reader server to a project directory."""
    with open("./mcp/mcp_settings.json", "r") as f:
        mcp_settings_json = json.load(f)

    # Restrict the json-reader server to the given project directory
    for server_config in mcp_settings_json:
        if "mcpServers" in server_config:
            for server_name, server_info in server_config["mcpServers"].items():
                if server_name == "json-reader":
                    # Pass the restriction to the server process via environment variables
                    if "env" not in server_info:
                        server_info["env"] = {}
                    server_info["env"]["PROJECT_DATA_DIR"] = project_data_dir
                    server_info["env"]["PROJECT_ID"] = project_data_dir.split("/")[-2] if "/" in project_data_dir else "default"
                    break

    return mcp_settings_json

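# Note on the PROJECT_ID derivation above: split("/")[-2] takes the second-to-last
# path component, so it assumes a layout like "<root>/<project_id>/<subdir>" or a
# path with a trailing slash. For example (illustrative paths):
#     "data/projects/p42/files" -> PROJECT_ID "p42"
#     "data/p42/"               -> PROJECT_ID "p42"
#     "p42" (no "/")            -> PROJECT_ID "default"
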
def read_system_prompt():
    """Read the generic, stateless system prompt."""
    with open("./system_prompt.md", "r", encoding="utf-8") as f:
        return f.read().strip()


def init_agent_service():
    """Default initializer, kept for backward compatibility."""
    return init_agent_service_universal("qwen3-next")


def init_agent_service_with_project(project_id: str, project_data_dir: str, model_name: str = "qwen3-next"):
    """Agent initializer scoped to a project directory; kept for backward compatibility."""
    # Read the generic (stateless) system prompt
    system = read_system_prompt()

    # Read the MCP tool configuration, restricted to the project directory
    tools = read_mcp_settings_with_project_restriction(project_data_dir)

    # Build a default LLM config (can be updated later via update_agent_llm)
    llm_config = {
        "model": model_name,
        "model_server": "https://openrouter.ai/api/v1",  # default server
        "api_key": "default-key"  # placeholder; a real key must be supplied via the API
    }

    # Create the LLM instance
    llm_instance = TextChatAtOAI(llm_config)

    bot = Assistant(
        llm=llm_instance,  # initialized with the default LLM; swap via update_agent_llm
        name=f"Database assistant - {project_id}",
        description=f"Database queries for project {project_id}",
        system_message=system,
        function_list=tools,
    )

    return bot

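# Example (illustrative identifiers): scope an agent to one project's data directory
#     bot = init_agent_service_with_project("p42", "data/p42/")
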
def init_agent_service_universal(model_name: str = "qwen3-next"):
    """Create a stateless, general-purpose assistant (default LLM, swappable at runtime)."""
    return init_agent_service_with_files(model_name=model_name)


def init_agent_service_with_files(rag_cfg: Optional[Dict] = None,
                                  model_name: str = "qwen3-next", api_key: Optional[str] = None,
                                  model_server: Optional[str] = None, generate_cfg: Optional[Dict] = None,
                                  system_prompt: Optional[str] = None, mcp: Optional[List[Dict]] = None):
    """Create an assistant instance with optional preloaded-file (RAG) support.

    Args:
        rag_cfg: RAG configuration options
        model_name: model name
        api_key: API key
        model_server: model server URL
        generate_cfg: generation settings
        system_prompt: system prompt; falls back to the local prompt file if omitted
        mcp: MCP configuration; falls back to the local mcp_settings.json if omitted
    """
    # Use the given system_prompt, or fall back to the local generic one
    system = system_prompt if system_prompt else read_system_prompt()

    # Use the given MCP config, or the base config (without project restrictions)
    tools = mcp if mcp else read_mcp_settings()

    # Build the LLM config from the given parameters
    llm_config = {
        "model": model_name,
        "api_key": api_key,
        "model_server": model_server,
        "generate_cfg": generate_cfg if generate_cfg else {}
    }

    # Create the LLM instance
    llm_instance = TextChatAtOAI(llm_config)

    # RAG settings tuned for handling large numbers of files
    default_rag_cfg = {
        'max_ref_token': 8000,  # raise the reference-token limit
        'parser_page_size': 1000,  # larger parser page size
        'rag_keygen_strategy': 'SplitQueryThenGenKeyword',  # keyword-generation strategy
        'rag_searchers': ['keyword_search', 'front_page_search']  # hybrid search strategy
    }

    # Merge in any user-provided RAG config (user values take precedence)
    final_rag_cfg = {**default_rag_cfg, **(rag_cfg or {})}

    bot = Assistant(
        llm=llm_instance,  # initialized with the default LLM; swap via update_agent_llm
        name="Data retrieval assistant",
        description="Data retrieval assistant with preloaded-file support",
        system_message=system,
        function_list=tools,
        # files=files,  # preloaded file list (disabled; no files parameter at present)
        # rag_cfg=final_rag_cfg,  # RAG configuration (currently disabled)
    )

    return bot

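# Example (illustrative values; OPENROUTER_API_KEY is an assumed env var): build an
# assistant against a custom endpoint with a tweaked sampling temperature
#     bot = init_agent_service_with_files(
#         model_name="qwen3-next",
#         api_key=os.getenv("OPENROUTER_API_KEY"),
#         model_server="https://openrouter.ai/api/v1",
#         generate_cfg={"temperature": 0.2},
#     )
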
def update_agent_llm(agent, model_name: str, api_key: Optional[str] = None, model_server: Optional[str] = None,
                     generate_cfg: Optional[Dict] = None, system_prompt: Optional[str] = None,
                     mcp_settings: Optional[List[Dict]] = None):
    """Dynamically update an assistant's LLM and configuration from API-supplied parameters."""
    # Build the base LLM config
    llm_config = {
        "model": model_name,
        "api_key": api_key,
        "model_server": model_server,
        "generate_cfg": generate_cfg if generate_cfg else {}
    }

    # Create the LLM instance (an instance is required here, not a config dict)
    # if "llm_class" in llm_config:
    #     llm_instance = llm_config.get("llm_class", TextChatAtOAI)(llm_config)
    # else:
    #     fall back to the default TextChatAtOAI class
    llm_instance = TextChatAtOAI(llm_config)

    # Swap in the new LLM
    agent.llm = llm_instance

    # Update the system message (if provided)
    if system_prompt:
        agent.system_message = system_prompt

    # Update the MCP settings (if provided)
    if mcp_settings:
        agent.function_list = mcp_settings

    return agent

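# Example (illustrative model name and env var): point an existing agent at a
# different backend without rebuilding it
#     bot = update_agent_llm(bot, "qwen3-next",
#                            api_key=os.getenv("OPENROUTER_API_KEY"),
#                            model_server="https://openrouter.ai/api/v1")
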
def test(query="How many tables are in the database?"):
    # Define the agent using the universal initializer
    bot = init_agent_service_universal()

    # Chat
    messages = []

    messages.append({"role": "user", "content": query})

    responses = []
    for response in bot.run(messages):
        responses.append(response)

    # Print only the final result, skipping intermediate steps
    if responses:
        final_response = responses[-1][-1]  # the last message of the last response
        print("Answer:", final_response["content"])
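

# Minimal manual entry point (an assumption; the original file may instead wire up
# an argparse-based CLI or another runner):
if __name__ == "__main__":
    test()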