system prompt 中的 citation 规则(document/table/web 三类约80行)占用大量 token, 现将详细格式要求移到 rag_retrieve_server.py 中作为工具返回前缀按需注入, system prompt 仅保留精简版通用 placement rules。 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
345 lines
11 KiB
Python
345 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
RAG检索MCP服务器
|
||
调用本地RAG API进行文档检索
|
||
"""
|
||
|
||
import asyncio
|
||
import hashlib
|
||
import json
|
||
import sys
|
||
import os
|
||
from typing import Any, Dict, List
|
||
|
||
try:
    import requests
except ImportError:
    # Report the missing dependency on stderr rather than stdout.
    # NOTE(review): stdout presumably carries the MCP message stream handled
    # by mcp_common.handle_mcp_streaming - confirm; diagnostics must not be
    # mixed into it either way.
    print(
        "Error: requests module is required. Please install it with: pip install requests",
        file=sys.stderr,
    )
    sys.exit(1)
|
||
|
||
from mcp_common import (
|
||
create_error_response,
|
||
create_success_response,
|
||
create_initialize_response,
|
||
create_ping_response,
|
||
create_tools_list_response,
|
||
load_tools_from_json,
|
||
handle_mcp_streaming
|
||
)
|
||
# Base URL that the retrieval endpoint paths are appended to; can be
# overridden with the BACKEND_HOST environment variable.
BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
|
||
# Secret that is hashed together with the bot ID to build the bearer token;
# can be overridden with the MASTERKEY environment variable.
# NOTE(review): the "master" fallback looks like a development default -
# confirm it is never relied on in production.
MASTERKEY = os.getenv("MASTERKEY", "master")
|
||
|
||
# Citation instruction prefixes injected into tool results.
# DOCUMENT_CITATION_INSTRUCTIONS is prepended to the markdown returned by
# rag_retrieve; it describes the citation tag formats for document and web
# page sources and where citations must be placed.
|
||
DOCUMENT_CITATION_INSTRUCTIONS = """<CITATION_INSTRUCTIONS>
|
||
When using the retrieved knowledge below, you MUST add XML citation tags for factual claims.
|
||
|
||
## Document Knowledge
|
||
Format: `<CITATION file="file_uuid" filename="name.pdf" page=3 />`
|
||
- Use `file` attribute with the UUID from document markers
|
||
- Use `filename` attribute with the actual filename from document markers
|
||
- Use `page` attribute (singular) with the page number
|
||
- `page` MUST be 0-based and must match the `pages:` values shown in the learned knowledge context
|
||
|
||
## Web Page Knowledge
|
||
Format: `<CITATION url="https://example.com/page" />`
|
||
- Use `url` attribute with the web page URL from the source metadata
|
||
- Do not use `file`, `filename`, or `page` attributes for web sources
|
||
- If content is grounded in a web source, prefer a web citation with `url` over a file citation
|
||
|
||
## Placement Rules
|
||
- Citations MUST appear IMMEDIATELY AFTER the paragraph or bullet list that uses the knowledge
|
||
- NEVER collect all citations and place them at the end of your response
|
||
- Limit to 1-2 citations per paragraph/bullet list
|
||
- If your answer uses learned knowledge, you MUST generate at least 1 `<CITATION ... />` in the response
|
||
</CITATION_INSTRUCTIONS>
|
||
|
||
"""
|
||
|
||
# Prefix prepended to the markdown returned by table_rag_retrieve; it
# describes how to turn the `__src` markers of table rows into citation tags
# and where those citations must be placed.
TABLE_CITATION_INSTRUCTIONS = """<CITATION_INSTRUCTIONS>
|
||
When using the retrieved table knowledge below, you MUST add XML citation tags for factual claims.
|
||
|
||
Format: `<CITATION file="file_id" filename="name.xlsx" sheet=1 rows=[2, 4] />`
|
||
- Parse `__src`: `F1S2R5` = file_ref F1, sheet 2, row 5
|
||
- Look up file_id in `file_ref_table`
|
||
- Combine same-sheet rows into one citation: `rows=[2, 4, 6]`
|
||
- MANDATORY: Create SEPARATE citation for EACH (file, sheet) combination
|
||
- NEVER put <CITATION> on the same line as a bullet point or table row
|
||
- Citations MUST be on separate lines AFTER the complete list/table
|
||
- NEVER include the `__src` column in your response - it is internal metadata only
|
||
- Citations MUST appear IMMEDIATELY AFTER the paragraph or bullet list that uses the knowledge
|
||
- NEVER collect all citations and place them at the end of your response
|
||
</CITATION_INSTRUCTIONS>
|
||
|
||
"""
|
||
|
||
_MISSING_BOT_ID_ERROR = (
    "Error: bot ID not provided. "
    "Please provide the bot ID as the first command line argument."
)


def _text_result(text: str) -> Dict[str, Any]:
    """Wrap *text* as a tool result holding a single text content item."""
    return {"content": [{"type": "text", "text": text}]}


def _bot_id_from_argv() -> str:
    """Return the bot ID given as the first command line argument, or ""."""
    return sys.argv[1] if len(sys.argv) > 1 else ""


def _post_retrieve(
    endpoint: str,
    bot_id: str,
    payload: Dict[str, Any],
    timeout: int,
    api_label: str,
    citation_prefix: str,
) -> Dict[str, Any]:
    """POST *payload* to a retrieval endpoint and wrap its markdown answer.

    Args:
        endpoint: Path segment placed after ``/v1/`` (e.g. ``"rag_retrieve"``).
        bot_id: Bot the request is made for; used in the URL and the token.
        payload: JSON body sent with the request.
        timeout: Request timeout in seconds.
        api_label: API name used in error messages (e.g. ``"RAG API"``).
        citation_prefix: Instructions prepended to the returned markdown.

    Returns:
        A tool result whose text is ``citation_prefix`` followed by the
        response's ``markdown`` field on success, or a message starting with
        ``"Error:"`` on failure. Exceptions are caught and reported as error
        text rather than propagated.
    """
    try:
        # The bearer token is md5("<masterkey>:<bot_id>") as a hex digest.
        auth_token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
        response = requests.post(
            f"{BACKEND_HOST}/v1/{endpoint}/{bot_id}",
            json=payload,
            headers={
                "content-type": "application/json",
                "authorization": f"Bearer {auth_token}",
            },
            timeout=timeout,
        )

        if response.status_code != 200:
            return _text_result(
                f"Error: {api_label} returned status code {response.status_code}. Response: {response.text}"
            )

        try:
            response_data = response.json()
        except ValueError as exc:
            # ValueError is the common base of every JSON decode error that
            # response.json() can raise (json, simplejson, and the requests
            # wrapper), so no decoder-specific failure slips through.
            return _text_result(
                f"Error: Failed to parse API response as JSON. Error: {exc}, Raw response: {response.text}"
            )

        # Valid JSON is not necessarily an object (it may be null, a list or
        # a string), so check the type before looking up the field.
        if not isinstance(response_data, dict) or "markdown" not in response_data:
            return _text_result(
                f"Error: 'markdown' field not found in API response. Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}"
            )

        return _text_result(citation_prefix + response_data["markdown"])

    except requests.exceptions.RequestException as exc:
        return _text_result(f"Error: Failed to connect to {api_label}. {exc}")
    except Exception as exc:
        # Last-resort boundary: the tool must answer with text, not raise.
        return _text_result(f"Error: {exc}")


def rag_retrieve(query: str, top_k: int = 100) -> Dict[str, Any]:
    """Call the document RAG retrieval API for the bot named on the command line.

    Args:
        query: Search text sent as ``query``.
        top_k: Value sent as ``top_k`` in the request body.

    Returns:
        A tool result whose text is the document citation instructions
        followed by the retrieved markdown, or an ``"Error: ..."`` message.
    """
    bot_id = _bot_id_from_argv()
    if not bot_id:
        # Fail fast: without a bot ID both the URL and the token are invalid.
        return _text_result(_MISSING_BOT_ID_ERROR)

    return _post_retrieve(
        endpoint="rag_retrieve",
        bot_id=bot_id,
        payload={"query": query, "top_k": top_k},
        timeout=30,
        api_label="RAG API",
        citation_prefix=DOCUMENT_CITATION_INSTRUCTIONS,
    )


def table_rag_retrieve(query: str) -> Dict[str, Any]:
    """Call the table RAG retrieval API for the bot named on the command line.

    Args:
        query: Search text sent as ``query``.

    Returns:
        A tool result whose text is the table citation instructions followed
        by the retrieved markdown, or an ``"Error: ..."`` message.
    """
    bot_id = _bot_id_from_argv()
    if not bot_id:
        # Fail fast: without a bot ID both the URL and the token are invalid.
        return _text_result(_MISSING_BOT_ID_ERROR)

    return _post_retrieve(
        endpoint="table_rag_retrieve",
        bot_id=bot_id,
        payload={"query": query},
        # Table retrieval is given a much longer timeout than document retrieval.
        timeout=300,
        api_label="Table RAG API",
        citation_prefix=TABLE_CITATION_INSTRUCTIONS,
    )
|
||
|
||
|
||
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch one MCP request by its ``method`` and return the response.

    Supported methods are ``initialize``, ``ping``, ``tools/list`` and
    ``tools/call`` (tools ``rag_retrieve`` and ``table_rag_retrieve``).
    Unknown methods or tools yield a -32601 error, a missing ``query`` a
    -32602 error, and any unexpected exception a -32603 error.
    """
    try:
        rpc_method = request.get("method")
        rpc_params = request.get("params", {})
        rpc_id = request.get("id")

        if rpc_method == "initialize":
            return create_initialize_response(rpc_id, "rag-retrieve")

        if rpc_method == "ping":
            return create_ping_response(rpc_id)

        if rpc_method == "tools/list":
            # Tool definitions are loaded from a JSON file.
            tool_defs = load_tools_from_json("rag_retrieve_tools.json")
            if not tool_defs:
                # Nothing was loaded, so fall back to the built-in definition.
                query_schema = {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "检索查询内容",
                        },
                    },
                    "required": ["query"],
                }
                tool_defs = [
                    {
                        "name": "rag_retrieve",
                        "description": "调用RAG检索API,根据查询内容检索相关文档。返回包含相关内容的markdown格式结果。",
                        "inputSchema": query_schema,
                    },
                ]
            return create_tools_list_response(rpc_id, tool_defs)

        if rpc_method != "tools/call":
            return create_error_response(rpc_id, -32601, f"Unknown method: {rpc_method}")

        tool_name = rpc_params.get("name")
        tool_args = rpc_params.get("arguments", {})

        if tool_name == "rag_retrieve":
            search_text = tool_args.get("query", "")
            limit = tool_args.get("top_k", 100)
            if not search_text:
                return create_error_response(rpc_id, -32602, "Missing required parameter: query")
            tool_result = rag_retrieve(search_text, limit)
        elif tool_name == "table_rag_retrieve":
            search_text = tool_args.get("query", "")
            if not search_text:
                return create_error_response(rpc_id, -32602, "Missing required parameter: query")
            tool_result = table_rag_retrieve(search_text)
        else:
            return create_error_response(rpc_id, -32601, f"Unknown tool: {tool_name}")

        # Both tools share the same JSON-RPC success envelope.
        return {
            "jsonrpc": "2.0",
            "id": rpc_id,
            "result": tool_result,
        }

    except Exception as exc:
        return create_error_response(request.get("id"), -32603, f"Internal error: {str(exc)}")
|
||
|
||
|
||
async def main():
    """Pass ``handle_request`` to ``handle_mcp_streaming`` and await it."""
    await handle_mcp_streaming(handle_request)
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the async entry point only when executed as a script.
    asyncio.run(main())
|