Merge branch 'master' into prod
This commit is contained in:
commit
1a0d520bd4
198
mcp/rag_retrieve_server.py
Normal file
198
mcp/rag_retrieve_server.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
RAG检索MCP服务器
|
||||||
|
调用本地RAG API进行文档检索
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
except ImportError:
|
||||||
|
print("Error: requests module is required. Please install it with: pip install requests")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
from mcp_common import (
|
||||||
|
create_error_response,
|
||||||
|
create_success_response,
|
||||||
|
create_initialize_response,
|
||||||
|
create_ping_response,
|
||||||
|
create_tools_list_response,
|
||||||
|
load_tools_from_json,
|
||||||
|
handle_mcp_streaming
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rag_retrieve(query: str, top_k: int = 50) -> Dict[str, Any]:
|
||||||
|
"""调用RAG检索API"""
|
||||||
|
try:
|
||||||
|
url = ""
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
url = sys.argv[1]
|
||||||
|
|
||||||
|
if not url:
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Error: RAG API URL not provided. Please provide URL as command line argument."
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"content-type": "application/json"
|
||||||
|
}
|
||||||
|
data = {
|
||||||
|
"query": query,
|
||||||
|
"top_k": top_k
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送POST请求
|
||||||
|
response = requests.post(url, json=data, headers=headers, timeout=30)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"Error: RAG API returned status code {response.status_code}. Response: {response.text}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 解析响应
|
||||||
|
try:
|
||||||
|
response_data = response.json()
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"Error: Failed to parse API response as JSON. Error: {str(e)}, Raw response: {response.text}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# 提取markdown字段
|
||||||
|
if "markdown" in response_data:
|
||||||
|
markdown_content = response_data["markdown"]
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": markdown_content
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"Error: 'markdown' field not found in API response. Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"Error: Failed to connect to RAG API. {str(e)}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"Error: {str(e)}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Handle MCP request"""
|
||||||
|
try:
|
||||||
|
method = request.get("method")
|
||||||
|
params = request.get("params", {})
|
||||||
|
request_id = request.get("id")
|
||||||
|
|
||||||
|
if method == "initialize":
|
||||||
|
return create_initialize_response(request_id, "rag-retrieve")
|
||||||
|
|
||||||
|
elif method == "ping":
|
||||||
|
return create_ping_response(request_id)
|
||||||
|
|
||||||
|
elif method == "tools/list":
|
||||||
|
# 从 JSON 文件加载工具定义
|
||||||
|
tools = load_tools_from_json("rag_retrieve_tools.json")
|
||||||
|
if not tools:
|
||||||
|
# 如果 JSON 文件不存在,使用默认定义
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"name": "rag_retrieve",
|
||||||
|
"description": "调用RAG检索API,根据查询内容检索相关文档。返回包含相关内容的markdown格式结果。",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "检索查询内容"
|
||||||
|
},
|
||||||
|
"top_k": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "返回结果的最大数量,默认50",
|
||||||
|
"default": 50
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return create_tools_list_response(request_id, tools)
|
||||||
|
|
||||||
|
elif method == "tools/call":
|
||||||
|
tool_name = params.get("name")
|
||||||
|
arguments = params.get("arguments", {})
|
||||||
|
|
||||||
|
if tool_name == "rag_retrieve":
|
||||||
|
query = arguments.get("query", "")
|
||||||
|
top_k = arguments.get("top_k", 50)
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
return create_error_response(request_id, -32602, "Missing required parameter: query")
|
||||||
|
|
||||||
|
result = rag_retrieve(query, top_k)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request_id,
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
return create_error_response(request_id, -32601, f"Unknown method: {method}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return create_error_response(request.get("id"), -32603, f"Internal error: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main entry point."""
|
||||||
|
await handle_mcp_streaming(handle_request)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
21
mcp/rag_retrieve_tools.json
Normal file
21
mcp/rag_retrieve_tools.json
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "rag_retrieve",
|
||||||
|
"description": "调用RAG检索API,根据查询内容检索相关文档。返回包含相关内容的markdown格式结果。",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "检索查询内容"
|
||||||
|
},
|
||||||
|
"top_k": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "返回结果的最大数量,默认50",
|
||||||
|
"default": 50
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
@ -103,8 +103,8 @@ class FileLoadedAgentManager:
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
# 实现参数优先级逻辑:传入参数 > 项目配置 > 默认配置
|
# 实现参数优先级逻辑:传入参数 > 项目配置 > 默认配置
|
||||||
final_system_prompt = load_system_prompt(project_dir, language, system_prompt, robot_type)
|
final_system_prompt = load_system_prompt(project_dir, language, system_prompt, robot_type, unique_id)
|
||||||
final_mcp_settings = load_mcp_settings(project_dir, mcp_settings)
|
final_mcp_settings = load_mcp_settings(project_dir, mcp_settings, unique_id)
|
||||||
|
|
||||||
cache_key = self._get_cache_key(unique_id)
|
cache_key = self._get_cache_key(unique_id)
|
||||||
|
|
||||||
|
|||||||
@ -7,9 +7,18 @@ import json
|
|||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
def load_system_prompt(project_dir: str, language: str = None, system_prompt: str=None, robot_type: str = "agent") -> str:
|
def load_system_prompt(project_dir: str, language: str = None, system_prompt: str=None, robot_type: str = "agent", unique_id: str="") -> str:
|
||||||
|
# 获取语言显示名称
|
||||||
|
language_display_map = {
|
||||||
|
'zh': '中文',
|
||||||
|
'en': 'English',
|
||||||
|
'ja': '日本語',
|
||||||
|
'jp': '日本語'
|
||||||
|
}
|
||||||
|
language_display = language_display_map.get(language, language if language else 'English')
|
||||||
|
|
||||||
if robot_type == "agent":
|
if robot_type == "agent":
|
||||||
return system_prompt or ""
|
return system_prompt.replace("{language}", language_display).replace('{unique_id}', unique_id) or ""
|
||||||
if robot_type == "catalog_agent":
|
if robot_type == "catalog_agent":
|
||||||
"""
|
"""
|
||||||
优先使用项目目录的system_prompt,没有才使用默认的system_prompt_default.md
|
优先使用项目目录的system_prompt,没有才使用默认的system_prompt_default.md
|
||||||
@ -52,18 +61,9 @@ def load_system_prompt(project_dir: str, language: str = None, system_prompt: st
|
|||||||
with open(readme_path, "r", encoding="utf-8") as f:
|
with open(readme_path, "r", encoding="utf-8") as f:
|
||||||
readme = f.read().strip()
|
readme = f.read().strip()
|
||||||
|
|
||||||
# 获取语言显示名称
|
return system_prompt_default.replace("{readme}", str(readme)).replace("{language}", language_display).replace("{extra_prompt}", system_prompt or "").replace('{unique_id}', unique_id) or ""
|
||||||
language_display_map = {
|
|
||||||
'zh': '中文',
|
|
||||||
'en': 'English',
|
|
||||||
'ja': '日本語',
|
|
||||||
'jp': '日本語'
|
|
||||||
}
|
|
||||||
language_display = language_display_map.get(language, language if language else 'English')
|
|
||||||
|
|
||||||
return system_prompt_default.replace("{readme}", str(readme)).replace("{language}", language_display).replace("{extra_prompt}", system_prompt or "") or ""
|
|
||||||
else:
|
else:
|
||||||
return system_prompt or ""
|
return system_prompt.replace("{language}", language_display).replace('{unique_id}', unique_id) or ""
|
||||||
|
|
||||||
def get_available_prompt_languages() -> list:
|
def get_available_prompt_languages() -> list:
|
||||||
"""
|
"""
|
||||||
@ -85,7 +85,7 @@ def get_available_prompt_languages() -> list:
|
|||||||
return available_languages
|
return available_languages
|
||||||
|
|
||||||
|
|
||||||
def replace_mcp_placeholders(mcp_settings: List[Dict], dataset_dir: str) -> List[Dict]:
|
def replace_mcp_placeholders(mcp_settings: List[Dict], dataset_dir: str, unique_id: str) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
替换 MCP 配置中的占位符
|
替换 MCP 配置中的占位符
|
||||||
"""
|
"""
|
||||||
@ -98,21 +98,21 @@ def replace_mcp_placeholders(mcp_settings: List[Dict], dataset_dir: str) -> List
|
|||||||
for key, value in obj.items():
|
for key, value in obj.items():
|
||||||
if key == 'args' and isinstance(value, list):
|
if key == 'args' and isinstance(value, list):
|
||||||
# 特别处理 args 列表
|
# 特别处理 args 列表
|
||||||
obj[key] = [item.replace('{dataset_dir}', dataset_dir) if isinstance(item, str) else item
|
obj[key] = [item.replace('{dataset_dir}', dataset_dir).replace('{unique_id}', unique_id) if isinstance(item, str) else item
|
||||||
for item in value]
|
for item in value]
|
||||||
elif isinstance(value, (dict, list)):
|
elif isinstance(value, (dict, list)):
|
||||||
obj[key] = replace_placeholders_in_obj(value)
|
obj[key] = replace_placeholders_in_obj(value)
|
||||||
elif isinstance(value, str):
|
elif isinstance(value, str):
|
||||||
obj[key] = value.replace('{dataset_dir}', dataset_dir)
|
obj[key] = value.replace('{dataset_dir}', dataset_dir).replace('{unique_id}', unique_id)
|
||||||
elif isinstance(obj, list):
|
elif isinstance(obj, list):
|
||||||
return [replace_placeholders_in_obj(item) if isinstance(item, (dict, list)) else
|
return [replace_placeholders_in_obj(item) if isinstance(item, (dict, list)) else
|
||||||
item.replace('{dataset_dir}', dataset_dir) if isinstance(item, str) else item
|
item.replace('{dataset_dir}', dataset_dir).replace('{unique_id}', unique_id) if isinstance(item, str) else item
|
||||||
for item in obj]
|
for item in obj]
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
return replace_placeholders_in_obj(mcp_settings)
|
return replace_placeholders_in_obj(mcp_settings)
|
||||||
|
|
||||||
def load_mcp_settings(project_dir: str, mcp_settings: list=None) -> List[Dict]:
|
def load_mcp_settings(project_dir: str, mcp_settings: list=None, unique_id: str="") -> List[Dict]:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
优先使用项目目录的mcp_settings.json,没有才使用默认的mcp/mcp_settings.json
|
优先使用项目目录的mcp_settings.json,没有才使用默认的mcp/mcp_settings.json
|
||||||
@ -164,7 +164,7 @@ def load_mcp_settings(project_dir: str, mcp_settings: list=None) -> List[Dict]:
|
|||||||
# 计算 dataset_dir 用于替换 MCP 配置中的占位符
|
# 计算 dataset_dir 用于替换 MCP 配置中的占位符
|
||||||
dataset_dir = os.path.join(project_dir, "dataset")
|
dataset_dir = os.path.join(project_dir, "dataset")
|
||||||
# 替换 MCP 配置中的 {dataset_dir} 占位符
|
# 替换 MCP 配置中的 {dataset_dir} 占位符
|
||||||
mcp_settings = replace_mcp_placeholders(mcp_settings, dataset_dir)
|
mcp_settings = replace_mcp_placeholders(mcp_settings, dataset_dir, unique_id)
|
||||||
return mcp_settings
|
return mcp_settings
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user