catalog-agent/gbase_agent.py
2025-10-17 16:16:41 +08:00

213 lines
7.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright 2023 The Qwen team, Alibaba Group. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A sqlite database assistant implemented by assistant"""
import argparse
import asyncio
import json
import os
from typing import Dict, List, Optional, Union
from qwen_agent.agents import Assistant
from qwen_agent.llm.oai import TextChatAtOAI
from qwen_agent.llm.schema import ASSISTANT, FUNCTION, Message
from qwen_agent.utils.output_beautify import typewriter_print
ROOT_RESOURCE = os.path.join(os.path.dirname(__file__), "resource")
def read_mcp_settings():
with open("./mcp/mcp_settings.json", "r") as f:
mcp_settings_json = json.load(f)
return mcp_settings_json
def read_mcp_settings_with_project_restriction(project_data_dir: str):
"""读取MCP配置并添加项目目录限制"""
with open("./mcp/mcp_settings.json", "r") as f:
mcp_settings_json = json.load(f)
# 为json-reader添加项目目录限制
for server_config in mcp_settings_json:
if "mcpServers" in server_config:
for server_name, server_info in server_config["mcpServers"].items():
if server_name == "json-reader":
# 添加环境变量来传递项目目录限制
if "env" not in server_info:
server_info["env"] = {}
server_info["env"]["PROJECT_DATA_DIR"] = project_data_dir
server_info["env"]["PROJECT_ID"] = project_data_dir.split("/")[-2] if "/" in project_data_dir else "default"
break
return mcp_settings_json
def read_system_prompt():
"""读取通用的无状态系统prompt"""
with open("./system_prompt.md", "r", encoding="utf-8") as f:
return f.read().strip()
def init_agent_service():
"""默认初始化函数,保持向后兼容"""
return init_agent_service_universal("qwen3-next")
def init_agent_service_with_project(project_id: str, project_data_dir: str, model_name: str = "qwen3-next"):
"""支持项目目录的agent初始化函数 - 保持向后兼容"""
# 读取通用的系统prompt无状态
system = read_system_prompt()
# 读取MCP工具配置
tools = read_mcp_settings_with_project_restriction(project_data_dir)
# 创建默认的LLM配置可以通过update_agent_llm动态更新
llm_config = {
"model": model_name,
"model_server": "https://openrouter.ai/api/v1", # 默认服务器
"api_key": "default-key" # 默认密钥实际使用时需要通过API传入
}
# 创建LLM实例
llm_instance = TextChatAtOAI(llm_config)
bot = Assistant(
llm=llm_instance, # 使用默认LLM初始化可通过update_agent_llm动态更新
name=f"数据库助手-{project_id}",
description=f"项目 {project_id} 数据库查询",
system_message=system,
function_list=tools,
)
return bot
def init_agent_service_universal():
"""创建无状态的通用助手实例使用默认LLM可动态切换"""
return init_agent_service_with_files(files=None)
def init_agent_service_with_files(rag_cfg: Optional[Dict] = None,
model_name: str = "qwen3-next", api_key: Optional[str] = None,
model_server: Optional[str] = None, generate_cfg: Optional[Dict] = None,
system_prompt: Optional[str] = None, mcp: Optional[List[Dict]] = None):
"""创建支持预加载文件的助手实例
Args:
files: 预加载的文件路径列表
rag_cfg: RAG配置参数
model_name: 模型名称
api_key: API 密钥
model_server: 模型服务器地址
generate_cfg: 生成配置
system_prompt: 系统提示词,如果未提供则使用本地提示词
mcp: MCP配置如果未提供则使用本地mcp_settings.json文件
"""
# 使用传入的system_prompt或读取本地通用的系统prompt
system = system_prompt if system_prompt else read_system_prompt()
# 使用传入的mcp配置或读取基础的MCP工具配置不包含项目限制
tools = mcp if mcp else read_mcp_settings()
# 创建LLM配置使用传入的参数
llm_config = {
"model": model_name,
"api_key": api_key,
"model_server": model_server,
"generate_cfg": generate_cfg if generate_cfg else {}
}
# 创建LLM实例
llm_instance = TextChatAtOAI(llm_config)
# 配置RAG参数以优化大量文件处理
default_rag_cfg = {
'max_ref_token': 8000, # 增加引用token限制
'parser_page_size': 1000, # 增加解析页面大小
'rag_keygen_strategy': 'SplitQueryThenGenKeyword', # 使用关键词生成策略
'rag_searchers': ['keyword_search', 'front_page_search'] # 混合搜索策略
}
# 合并用户提供的RAG配置
final_rag_cfg = {**default_rag_cfg, **(rag_cfg or {})}
bot = Assistant(
llm=llm_instance, # 使用默认LLM初始化可通过update_agent_llm动态更新
name="数据检索助手",
description="支持预加载文件的数据检索助手",
system_message=system,
function_list=tools,
#files=files, # 预加载文件列表
#rag_cfg=final_rag_cfg, # RAG配置
)
return bot
def update_agent_llm(agent, model_name: str, api_key: str = None, model_server: str = None, generate_cfg: Dict = None, system_prompt: str = None, mcp_settings: List[Dict] = None):
"""动态更新助手实例的LLM和配置支持从接口传入参数"""
# 获取基础配置
llm_config = {
"model": model_name,
"api_key": api_key,
"model_server": model_server,
"generate_cfg": generate_cfg if generate_cfg else {}
}
# 创建LLM实例确保不是字典
#if "llm_class" in llm_config:
# llm_instance = llm_config.get("llm_class", TextChatAtOAI)(llm_config)
#else:
# 使用默认的 TextChatAtOAI 类
llm_instance = TextChatAtOAI(llm_config)
# 动态设置LLM
agent.llm = llm_instance
# 更新系统消息(如果提供)
if system_prompt:
agent.system_message = system_prompt
# 更新MCP设置如果提供
if mcp_settings:
agent.function_list = mcp_settings
return agent
def test(query="数据库里有几张表"):
# Define the agent - 使用通用初始化
bot = init_agent_service_universal()
# Chat
messages = []
messages.append({"role": "user", "content": query})
responses = []
for response in bot.run(messages):
responses.append(response)
# 只输出最终结果,不显示中间过程
if responses:
final_response = responses[-1][-1] # 取最后一个响应作为最终结果
print("Answer:", final_response["content"])