import json

from langchain.chat_models import init_chat_model
from deepagents import create_deep_agent
from langchain.agents import create_agent
from langchain_mcp_adapters.client import MultiServerMCPClient

from utils.fastapi_utils import detect_provider


# Utility functions
def read_system_prompt():
    """Read the generic, stateless system prompt."""
    with open("./prompt/system_prompt_default.md", "r", encoding="utf-8") as f:
        return f.read().strip()


def read_mcp_settings():
    """Read the MCP tool configuration."""
    with open("./mcp/mcp_settings.json", "r") as f:
        return json.load(f)
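
# Illustrative shape of ./mcp/mcp_settings.json, inferred from how it is
# consumed below (an assumption; the real file may carry more fields):
#
# [
#   {
#     "mcpServers": {
#       "local-tools": {"command": "python", "args": ["server.py"]},
#       "remote-tools": {"type": "streamable-http", "url": "http://localhost:8000/mcp"}
#     }
#   }
# ]
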
async def init_agent(model_name="qwen3-next", api_key=None,
                     model_server=None, generate_cfg=None,
                     system_prompt=None, mcp=None):
    """Create a LangChain agent wired to the configured MCP tools."""
    system = system_prompt if system_prompt else read_system_prompt()
    mcp = mcp if mcp else read_mcp_settings()

    # Normalize mcp[0]["mcpServers"]: rename the "type" field to "transport",
    # defaulting to transport="stdio" when neither field is present.
    if mcp and len(mcp) > 0 and "mcpServers" in mcp[0]:
        for server_name, server_config in mcp[0]["mcpServers"].items():
            if isinstance(server_config, dict):
                if "type" in server_config and "transport" not in server_config:
                    # "type" is present but "transport" is not: rename it.
                    type_value = server_config.pop("type")
                    # Special case: "streamable-http" becomes "http".
                    if type_value == "streamable-http":
                        server_config["transport"] = "http"
                    else:
                        server_config["transport"] = type_value
                elif "transport" not in server_config:
                    # Neither "type" nor "transport": default to "stdio".
                    server_config["transport"] = "stdio"
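
    # For example (illustrative): {"type": "streamable-http", "url": ...}
    # becomes {"transport": "http", "url": ...}, while a bare
    # {"command": ...} entry gains "transport": "stdio".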

    mcp_client = MultiServerMCPClient(mcp[0]["mcpServers"])
    mcp_tools = await mcp_client.get_tools()

    # Detect the provider, or use the one specified via model_server.
    model_provider, base_url = detect_provider(model_name, model_server)

    # Build the model parameters.
    model_kwargs = {
        "model": model_name,
        "model_provider": model_provider,
        "temperature": 0.8,
        "base_url": base_url,
        "api_key": api_key,
    }
    llm_instance = init_chat_model(**model_kwargs)

    agent = create_agent(
        model=llm_instance,
        system_prompt=system,
        tools=mcp_tools,
    )
    return agent
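

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module):
    # build the agent and run it once over a single user message. Assumes a
    # reachable model endpoint, the MCP servers configured above, and an API
    # key in the API_KEY environment variable (hypothetical variable name).
    import asyncio
    import os

    async def _demo():
        agent = await init_agent(api_key=os.environ.get("API_KEY"))
        result = await agent.ainvoke(
            {"messages": [{"role": "user", "content": "Hello"}]}
        )
        print(result["messages"][-1].content)

    asyncio.run(_demo())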