qwen_agent/skills/autoload/support/rag-retrieve/rag_retrieve_server.py
2026-06-05 14:49:54 +08:00

368 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
RAG retrieval MCP server.
Calls the local RAG API to retrieve documents.
"""
import asyncio
import hashlib
import json
import re
import sys
import os
from typing import Any, Dict, List
try:
import requests
except ImportError:
print("Error: requests module is required. Please install it with: pip install requests")
sys.exit(1)
from mcp_common import (
create_error_response,
create_success_response,
create_initialize_response,
create_ping_response,
create_tools_list_response,
load_tools_from_json,
handle_mcp_streaming
)
# Base URL of the backend that serves the retrieval endpoints; can be
# overridden through the BACKEND_HOST environment variable.
BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
# Secret that is combined with the bot ID to derive the bearer token sent to
# the backend; can be overridden through the MASTERKEY environment variable.
# NOTE(review): the fallback value "master" looks like a development default —
# confirm that real deployments always set MASTERKEY.
MASTERKEY = os.getenv("MASTERKEY", "master")
# Citation instruction prefixes injected into tool results.
# DOCUMENT_CITATION_INSTRUCTIONS is prepended to the markdown returned by
# rag_retrieve. The text is consumed at runtime, so it must not be reworded
# casually.
DOCUMENT_CITATION_INSTRUCTIONS = """<CITATION_INSTRUCTIONS>
When using the retrieved knowledge below, you MUST add XML citation tags for factual claims.
## Document Knowledge
Format: `<CITATION file="file_uuid" filename="name.pdf" page=3 />`
- Use `file` attribute with the UUID from document markers
- Use `filename` attribute with the actual filename from document markers
- Use `page` attribute (singular) with the page number
- `page` MUST be 0-based and must match the `pages:` values shown in the learned knowledge context
## Web Page Knowledge
Format: `<CITATION url="https://example.com/page" />`
- Use `url` attribute with the web page URL from the source metadata
- Do not use `file`, `filename`, or `page` attributes for web sources
- If content is grounded in a web source, prefer a web citation with `url` over a file citation
## Placement Rules
- Citations MUST appear IMMEDIATELY AFTER the paragraph or bullet list that uses the knowledge
- NEVER collect all citations and place them at the end of your response
- Limit to 1-2 citations per paragraph/bullet list
- If your answer uses learned knowledge, you MUST generate at least 1 `<CITATION ... />` in the response
</CITATION_INSTRUCTIONS>
"""
# TABLE_CITATION_INSTRUCTIONS is prepended to the markdown returned by
# table_rag_retrieve (but not to its rag_retrieve fallback, which carries the
# document instructions instead).
TABLE_CITATION_INSTRUCTIONS = """<CITATION_INSTRUCTIONS>
When using the retrieved table knowledge below, you MUST add XML citation tags for factual claims.
Format: `<CITATION file="file_id" filename="name.xlsx" sheet=1 rows=[2, 4] />`
- Parse `__src`: `F1S2R5` = file_ref F1, sheet 2, row 5
- Look up file_id in `file_ref_table`
- Combine same-sheet rows into one citation: `rows=[2, 4, 6]`
- MANDATORY: Create SEPARATE citation for EACH (file, sheet) combination
- NEVER put <CITATION> on the same line as a bullet point or table row
- Citations MUST be on separate lines AFTER the complete list/table
- NEVER include the `__src` column in your response - it is internal metadata only
- Citations MUST appear IMMEDIATELY AFTER the paragraph or bullet list that uses the knowledge
- NEVER collect all citations and place them at the end of your response
</CITATION_INSTRUCTIONS>
"""
def rag_retrieve(query: str, top_k: int = 100, trace_id: str = "") -> Dict[str, Any]:
    """Call the RAG retrieval API and wrap the outcome as MCP text content.

    The bot ID is taken from the first command-line argument (``sys.argv[1]``).
    The request is a POST to ``{BACKEND_HOST}/v1/rag_retrieve/{bot_id}`` with a
    bearer token equal to the MD5 hex digest of ``"{MASTERKEY}:{bot_id}"``.

    Args:
        query: Retrieval query text sent to the API.
        top_k: Value sent as ``top_k`` in the request body.
        trace_id: When non-empty, forwarded as the ``X-Request-ID`` header.

    Returns:
        A dict shaped ``{"content": [{"type": "text", "text": ...}]}``. On
        success the text is ``DOCUMENT_CITATION_INSTRUCTIONS`` followed by the
        ``markdown`` field of the API response. On failure the text is a
        message starting with ``"Error:"``; exceptions are caught and reported
        this way instead of being raised.
    """

    def text_result(text: str) -> Dict[str, Any]:
        # Every outcome, success or error, uses the same single-text envelope.
        return {"content": [{"type": "text", "text": text}]}

    try:
        bot_id = sys.argv[1] if len(sys.argv) > 1 else ""
        if not bot_id:
            # The guard used to test the formatted URL, which can never be
            # empty; the value that can actually be missing is the bot ID.
            return text_result(
                "Error: bot ID not provided. Please provide the bot ID as the first command line argument."
            )
        url = f"{BACKEND_HOST}/v1/rag_retrieve/{bot_id}"

        # Derive the authentication token from the master key and the bot ID.
        # NOTE(review): MD5 is presumably dictated by the backend's token
        # scheme — confirm before changing the hash.
        auth_token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}",
        }
        if trace_id:
            headers["X-Request-ID"] = trace_id

        payload = {"query": query, "top_k": top_k}
        response = requests.post(url, json=payload, headers=headers, timeout=30)
        if response.status_code != 200:
            return text_result(
                f"Error: RAG API returned status code {response.status_code}. Response: {response.text}"
            )

        try:
            response_data = response.json()
        except ValueError as e:
            # ValueError is the common base of json.JSONDecodeError,
            # simplejson's decode error and requests' own JSONDecodeError, so
            # a bad body is reported as a parse failure whichever JSON backend
            # requests is using (instead of falling through to the
            # "failed to connect" handler below).
            return text_result(
                f"Error: Failed to parse API response as JSON. Error: {str(e)}, Raw response: {response.text}"
            )

        if "markdown" in response_data:
            return text_result(DOCUMENT_CITATION_INSTRUCTIONS + response_data["markdown"])
        return text_result(
            f"Error: 'markdown' field not found in API response. Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}"
        )
    except requests.exceptions.RequestException as e:
        return text_result(f"Error: Failed to connect to RAG API. {str(e)}")
    except Exception as e:
        # Last-resort boundary: the tool result must always be a content dict.
        return text_result(f"Error: {str(e)}")
def table_rag_retrieve(query: str, trace_id: str = "") -> Dict[str, Any]:
    """Call the Table RAG retrieval API and wrap the outcome as MCP text content.

    The bot ID is taken from the first command-line argument (``sys.argv[1]``,
    empty when absent). The request is a POST to
    ``{BACKEND_HOST}/v1/table_rag_retrieve/{bot_id}`` with a bearer token equal
    to the MD5 hex digest of ``"{MASTERKEY}:{bot_id}"``.

    When the returned markdown starts with "no excel files found"
    (case-insensitive), the query is retried through ``rag_retrieve`` and that
    result is returned with a notice prepended to its first text item.

    Args:
        query: Retrieval query text sent to the API.
        trace_id: When non-empty, forwarded as the ``X-Request-ID`` header
            (and passed on to the ``rag_retrieve`` fallback).

    Returns:
        A dict shaped ``{"content": [{"type": "text", "text": ...}]}``. On
        success the text is ``TABLE_CITATION_INSTRUCTIONS`` followed by the
        ``markdown`` field of the API response. On failure the text is a
        message starting with ``"Error:"``; exceptions are caught and reported
        this way instead of being raised.
    """

    def text_result(text: str) -> Dict[str, Any]:
        # Every outcome, success or error, uses the same single-text envelope.
        return {"content": [{"type": "text", "text": text}]}

    try:
        bot_id = sys.argv[1] if len(sys.argv) > 1 else ""
        url = f"{BACKEND_HOST}/v1/table_rag_retrieve/{bot_id}"

        auth_token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}",
        }
        if trace_id:
            headers["X-Request-ID"] = trace_id

        # The table endpoint is given a much longer timeout than rag_retrieve.
        response = requests.post(url, json={"query": query}, headers=headers, timeout=300)
        if response.status_code != 200:
            return text_result(
                f"Error: Table RAG API returned status code {response.status_code}. Response: {response.text}"
            )

        try:
            response_data = response.json()
        except ValueError as e:
            # ValueError is the common base of json.JSONDecodeError,
            # simplejson's decode error and requests' own JSONDecodeError, so
            # a bad body is reported as a parse failure whichever JSON backend
            # requests is using (instead of falling through to the
            # "failed to connect" handler below).
            return text_result(
                f"Error: Failed to parse API response as JSON. Error: {str(e)}, Raw response: {response.text}"
            )

        if "markdown" not in response_data:
            return text_result(
                f"Error: 'markdown' field not found in API response. Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}"
            )

        markdown_content = response_data["markdown"]
        if re.search(r"^no excel files found", markdown_content, re.IGNORECASE):
            # No table data for this bot: fall back to document retrieval and
            # tell the model that the content came from the fallback.
            rag_result = rag_retrieve(query, trace_id=trace_id)
            content = rag_result.get("content", [])
            if content and content[0].get("type") == "text":
                content[0]["text"] = (
                    "No table_rag_retrieve results were found. The content below is the fallback result from rag_retrieve\n\n"
                    + content[0]["text"]
                )
            return rag_result

        return text_result(TABLE_CITATION_INSTRUCTIONS + markdown_content)
    except requests.exceptions.RequestException as e:
        return text_result(f"Error: Failed to connect to Table RAG API. {str(e)}")
    except Exception as e:
        # Last-resort boundary: the tool result must always be a content dict.
        return text_result(f"Error: {str(e)}")
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch one MCP JSON-RPC request and return its response.

    Handles the methods ``initialize``, ``ping``, ``tools/list`` and
    ``tools/call`` (tools ``rag_retrieve`` and ``table_rag_retrieve``).

    Args:
        request: Decoded request; ``method``, ``params`` and ``id`` are read
            from it. A missing or ``null`` ``params`` / ``arguments`` is
            treated as an empty mapping.

    Returns:
        The response built for the request. Unknown methods and unknown tools
        produce an error response with code -32601; any exception raised while
        handling the request produces an error response with code -32603.
    """
    try:
        method = request.get("method")
        # `or {}` also covers an explicit JSON null, which .get()'s default
        # alone would let through and crash the lookups below.
        params = request.get("params") or {}
        request_id = request.get("id")

        if method == "initialize":
            return create_initialize_response(request_id, "rag-retrieve")

        if method == "ping":
            return create_ping_response(request_id)

        if method == "tools/list":
            # Load tool definitions from the JSON file
            tools = load_tools_from_json("rag_retrieve_tools.json")
            if not tools:
                # Nothing was loaded from the JSON file: use the default definition
                tools = [
                    {
                        "name": "rag_retrieve",
                        "description": "Call the RAG retrieval API to retrieve relevant documents based on the query. Returns markdown-formatted results containing the relevant content.",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "Retrieval query content"
                                }
                            },
                            "required": ["query"]
                        }
                    }
                ]
            return create_tools_list_response(request_id, tools)

        if method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments") or {}
            # The trace ID may arrive under either "_meta" or "meta".
            meta = params.get("_meta") or params.get("meta") or {}
            trace_id = meta.get("trace_id", "") if isinstance(meta, dict) else ""

            if tool_name not in ("rag_retrieve", "table_rag_retrieve"):
                return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")

            query = arguments.get("query", "")
            if not query:
                # Reported as a normal tool result (not a protocol error) so
                # the calling model sees the message and can retry.
                return create_success_response(request_id, {
                    "content": [{
                        "type": "text",
                        "text": "Error: missing required parameter 'query'. Please call this tool again with a non-empty 'query' argument describing what you want to retrieve."
                    }]
                })

            # NOTE(review): both retrieval functions perform a blocking HTTP
            # call inside this coroutine — confirm that is acceptable for the
            # streaming loop that drives this handler.
            if tool_name == "rag_retrieve":
                result = rag_retrieve(query, arguments.get("top_k", 100), trace_id)
            else:
                result = table_rag_retrieve(query, trace_id)
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": result
            }

        return create_error_response(request_id, -32601, f"Unknown method: {method}")
    except Exception as e:
        # A non-dict request has no .get(); do not let the error path itself
        # raise while trying to recover the request ID.
        failed_id = request.get("id") if isinstance(request, dict) else None
        return create_error_response(failed_id, -32603, f"Internal error: {str(e)}")
async def main():
    """Await ``handle_mcp_streaming`` with ``handle_request`` as the handler."""
    streaming_task = handle_mcp_streaming(handle_request)
    await streaming_task


if __name__ == "__main__":
    # Script entry point: run the async main() to completion.
    asyncio.run(main())