qwen_agent/mcp/rag_retrieve_server.py
朱潮 6300eea61d refactor: 将 citation 详细提示词从 system prompt 移至 RAG tool result 按需注入
system prompt 中的 citation 规则(document/table/web 三类约80行)占用大量 token,
现将详细格式要求移到 rag_retrieve_server.py 中作为工具返回前缀按需注入,
system prompt 仅保留精简版通用 placement rules。

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 12:30:20 +08:00

345 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
RAG检索MCP服务器
调用本地RAG API进行文档检索
"""
import asyncio
import hashlib
import json
import sys
import os
from typing import Any, Dict, List
try:
import requests
except ImportError:
print("Error: requests module is required. Please install it with: pip install requests")
sys.exit(1)
from mcp_common import (
create_error_response,
create_success_response,
create_initialize_response,
create_ping_response,
create_tools_list_response,
load_tools_from_json,
handle_mcp_streaming
)
# Base URL of the RAG backend; the BACKEND_HOST environment variable overrides
# the built-in default.
BACKEND_HOST = os.environ.get("BACKEND_HOST", "https://api-dev.gptbase.ai")
# Secret combined with the bot id to derive the bearer token sent to the
# backend; the MASTERKEY environment variable overrides the built-in default.
MASTERKEY = os.environ.get("MASTERKEY", "master")
# Citation instruction prefixes injected into tool results.
# Each constant is prepended verbatim to the markdown returned by the matching
# retrieval function in this file, so the text inside the literals is runtime
# data: edit it only when the citation format itself is meant to change.
#
# Prefix used by rag_retrieve(): citation formats for document and web-page
# knowledge, followed by placement rules.
DOCUMENT_CITATION_INSTRUCTIONS = """<CITATION_INSTRUCTIONS>
When using the retrieved knowledge below, you MUST add XML citation tags for factual claims.
## Document Knowledge
Format: `<CITATION file="file_uuid" filename="name.pdf" page=3 />`
- Use `file` attribute with the UUID from document markers
- Use `filename` attribute with the actual filename from document markers
- Use `page` attribute (singular) with the page number
- `page` MUST be 0-based and must match the `pages:` values shown in the learned knowledge context
## Web Page Knowledge
Format: `<CITATION url="https://example.com/page" />`
- Use `url` attribute with the web page URL from the source metadata
- Do not use `file`, `filename`, or `page` attributes for web sources
- If content is grounded in a web source, prefer a web citation with `url` over a file citation
## Placement Rules
- Citations MUST appear IMMEDIATELY AFTER the paragraph or bullet list that uses the knowledge
- NEVER collect all citations and place them at the end of your response
- Limit to 1-2 citations per paragraph/bullet list
- If your answer uses learned knowledge, you MUST generate at least 1 `<CITATION ... />` in the response
</CITATION_INSTRUCTIONS>
"""
# Prefix used by table_rag_retrieve(): citation format for table knowledge
# (file / sheet / rows) and the rules for decoding the `__src` column.
TABLE_CITATION_INSTRUCTIONS = """<CITATION_INSTRUCTIONS>
When using the retrieved table knowledge below, you MUST add XML citation tags for factual claims.
Format: `<CITATION file="file_id" filename="name.xlsx" sheet=1 rows=[2, 4] />`
- Parse `__src`: `F1S2R5` = file_ref F1, sheet 2, row 5
- Look up file_id in `file_ref_table`
- Combine same-sheet rows into one citation: `rows=[2, 4, 6]`
- MANDATORY: Create SEPARATE citation for EACH (file, sheet) combination
- NEVER put <CITATION> on the same line as a bullet point or table row
- Citations MUST be on separate lines AFTER the complete list/table
- NEVER include the `__src` column in your response - it is internal metadata only
- Citations MUST appear IMMEDIATELY AFTER the paragraph or bullet list that uses the knowledge
- NEVER collect all citations and place them at the end of your response
</CITATION_INSTRUCTIONS>
"""
def rag_retrieve(query: str, top_k: int = 100) -> Dict[str, Any]:
    """Query the document RAG retrieval API and wrap the outcome as MCP text content.

    The bot id is taken from the first command-line argument (empty string when
    none is given). It is used in the endpoint path and, joined with
    ``MASTERKEY`` as ``"<masterkey>:<bot_id>"``, hashed with MD5 to build the
    bearer token.

    Args:
        query: Search text sent as ``query`` in the request body.
        top_k: Value sent as ``top_k`` in the request body.

    Returns:
        ``{"content": [{"type": "text", "text": ...}]}``. On success the text
        is ``DOCUMENT_CITATION_INSTRUCTIONS`` followed by the response's
        ``markdown`` field; on failure it is a message starting with
        ``"Error:"``. Exceptions derived from ``Exception`` are converted into
        such an error message instead of propagating.
    """

    def _text_result(text: str) -> Dict[str, Any]:
        # Every outcome, success or failure, is reported as a single text item.
        return {"content": [{"type": "text", "text": text}]}

    try:
        bot_id = sys.argv[1] if len(sys.argv) > 1 else ""
        # The URL always contains the literal path segment, so it can never be
        # empty; no emptiness check is needed here.
        url = f"{BACKEND_HOST}/v1/rag_retrieve/{bot_id}"

        # Derive the auth token from the master key and the bot id.
        auth_token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}",
        }
        payload = {
            "query": query,
            "top_k": top_k,
        }

        response = requests.post(url, json=payload, headers=headers, timeout=30)
        if response.status_code != 200:
            return _text_result(
                f"Error: RAG API returned status code {response.status_code}. "
                f"Response: {response.text}"
            )

        try:
            response_data = response.json()
        except ValueError as e:
            # ValueError is the common base of the stdlib and simplejson decode
            # errors, so a malformed body is reported as a parse failure
            # regardless of which JSON backend requests uses, rather than
            # falling through to the connection-error handler below.
            return _text_result(
                f"Error: Failed to parse API response as JSON. Error: {str(e)}, "
                f"Raw response: {response.text}"
            )

        # Only a JSON object can carry the expected "markdown" field.
        if isinstance(response_data, dict) and "markdown" in response_data:
            return _text_result(
                DOCUMENT_CITATION_INSTRUCTIONS + response_data["markdown"]
            )
        return _text_result(
            "Error: 'markdown' field not found in API response. "
            f"Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}"
        )
    except requests.exceptions.RequestException as e:
        return _text_result(f"Error: Failed to connect to RAG API. {str(e)}")
    except Exception as e:
        return _text_result(f"Error: {str(e)}")
def table_rag_retrieve(query: str) -> Dict[str, Any]:
    """Query the table RAG retrieval API and wrap the outcome as MCP text content.

    The bot id is taken from the first command-line argument (empty string when
    none is given). It is used in the endpoint path and, joined with
    ``MASTERKEY`` as ``"<masterkey>:<bot_id>"``, hashed with MD5 to build the
    bearer token.

    Args:
        query: Search text sent as ``query`` in the request body.

    Returns:
        ``{"content": [{"type": "text", "text": ...}]}``. On success the text
        is ``TABLE_CITATION_INSTRUCTIONS`` followed by the response's
        ``markdown`` field; on failure it is a message starting with
        ``"Error:"``. Exceptions derived from ``Exception`` are converted into
        such an error message instead of propagating.
    """

    def _text_result(text: str) -> Dict[str, Any]:
        # Every outcome, success or failure, is reported as a single text item.
        return {"content": [{"type": "text", "text": text}]}

    try:
        bot_id = sys.argv[1] if len(sys.argv) > 1 else ""
        url = f"{BACKEND_HOST}/v1/table_rag_retrieve/{bot_id}"

        # Derive the auth token from the master key and the bot id.
        auth_token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}",
        }
        payload = {
            "query": query,
        }

        # This endpoint is called with a 300-second timeout.
        response = requests.post(url, json=payload, headers=headers, timeout=300)
        if response.status_code != 200:
            return _text_result(
                f"Error: Table RAG API returned status code {response.status_code}. "
                f"Response: {response.text}"
            )

        try:
            response_data = response.json()
        except ValueError as e:
            # ValueError is the common base of the stdlib and simplejson decode
            # errors, so a malformed body is reported as a parse failure
            # regardless of which JSON backend requests uses, rather than
            # falling through to the connection-error handler below.
            return _text_result(
                f"Error: Failed to parse API response as JSON. Error: {str(e)}, "
                f"Raw response: {response.text}"
            )

        # Only a JSON object can carry the expected "markdown" field.
        if isinstance(response_data, dict) and "markdown" in response_data:
            return _text_result(
                TABLE_CITATION_INSTRUCTIONS + response_data["markdown"]
            )
        return _text_result(
            "Error: 'markdown' field not found in API response. "
            f"Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}"
        )
    except requests.exceptions.RequestException as e:
        return _text_result(f"Error: Failed to connect to Table RAG API. {str(e)}")
    except Exception as e:
        return _text_result(f"Error: {str(e)}")
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch one MCP JSON-RPC request and build its response.

    Handled methods: ``initialize``, ``ping``, ``tools/list`` and
    ``tools/call`` (tools ``rag_retrieve`` and ``table_rag_retrieve``).

    Args:
        request: Decoded request; ``method``, ``params`` and ``id`` are read
            from it.

    Returns:
        The response dict. Unknown methods and unknown tools produce error
        code -32601, a missing/empty ``query`` produces -32602, and any
        exception raised while handling produces -32603.
    """
    try:
        method = request.get("method")
        # ``or {}`` also covers an explicit JSON null, which would otherwise
        # fail on ``.get`` below and be reported as an internal error.
        params = request.get("params") or {}
        request_id = request.get("id")

        if method == "initialize":
            return create_initialize_response(request_id, "rag-retrieve")

        if method == "ping":
            return create_ping_response(request_id)

        if method == "tools/list":
            # Load the tool definitions from the JSON file.
            tools = load_tools_from_json("rag_retrieve_tools.json")
            if not tools:
                # Nothing was loaded from the JSON file: use a built-in default.
                # NOTE(review): this fallback advertises only rag_retrieve,
                # while tools/call below also accepts table_rag_retrieve and a
                # top_k argument -- confirm whether that is intended.
                tools = [
                    {
                        "name": "rag_retrieve",
                        "description": "调用RAG检索API根据查询内容检索相关文档。返回包含相关内容的markdown格式结果。",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "检索查询内容"
                                }
                            },
                            "required": ["query"]
                        }
                    }
                ]
            return create_tools_list_response(request_id, tools)

        if method == "tools/call":
            tool_name = params.get("name")
            # Same null-tolerance as for ``params``: a null ``arguments`` is
            # treated as "no arguments" and yields the missing-query error.
            arguments = params.get("arguments") or {}

            if tool_name not in ("rag_retrieve", "table_rag_retrieve"):
                return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")

            query = arguments.get("query", "")
            if not query:
                return create_error_response(request_id, -32602, "Missing required parameter: query")

            if tool_name == "rag_retrieve":
                result = rag_retrieve(query, arguments.get("top_k", 100))
            else:
                result = table_rag_retrieve(query)

            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": result
            }

        return create_error_response(request_id, -32601, f"Unknown method: {method}")
    except Exception as e:
        return create_error_response(request.get("id"), -32603, f"Internal error: {str(e)}")
async def main():
    """Run the MCP streaming loop with ``handle_request`` as the request handler."""
    await handle_mcp_streaming(handle_request)


if __name__ == "__main__":
    # Script entry point: run the async main() to completion.
    asyncio.run(main())