qwen_agent/skills/onprem/rag-retrieve-only/rag_retrieve_server.py
2026-06-05 14:49:54 +08:00

228 lines
7.4 KiB
Python

#!/usr/bin/env python3
"""
RAG retrieval MCP server.
Exposes a `rag_retrieve` tool that queries the RAG backend HTTP API (BACKEND_HOST) and returns the retrieved markdown.
"""
import asyncio
import hashlib
import json
import sys
import os
from typing import Any, Dict
try:
import requests
except ImportError:
print("Error: requests module is required. Please install it with: pip install requests")
sys.exit(1)
from mcp_common import (
create_error_response,
create_success_response,
create_initialize_response,
create_ping_response,
create_tools_list_response,
load_tools_from_json,
handle_mcp_streaming
)
# Base URL of the RAG backend; override with the BACKEND_HOST environment variable.
BACKEND_HOST = os.environ.get("BACKEND_HOST", "https://api-dev.gptbase.ai")

# Secret combined with the bot ID to derive the bearer token sent to the backend;
# override with the MASTERKEY environment variable.
# NOTE(review): the fallback value "master" looks like a development default — confirm
# it is never relied on in production.
MASTERKEY = os.environ.get("MASTERKEY", "master")

# Prompt prefix prepended to every successful retrieval result so the consuming
# model knows how to format citations for the returned knowledge.
DOCUMENT_CITATION_INSTRUCTIONS = """<CITATION_INSTRUCTIONS>
When using the retrieved knowledge below, you MUST add XML citation tags for factual claims.
## Document Knowledge
Format: `<CITATION file="file_uuid" filename="name.pdf" page=3 />`
- Use `file` attribute with the UUID from document markers
- Use `filename` attribute with the actual filename from document markers
- Use `page` attribute (singular) with the page number
- `page` MUST be 0-based and must match the `pages:` values shown in the learned knowledge context
## Web Page Knowledge
Format: `<CITATION url="https://example.com/page" />`
- Use `url` attribute with the web page URL from the source metadata
- Do not use `file`, `filename`, or `page` attributes for web sources
- If content is grounded in a web source, prefer a web citation with `url` over a file citation
## Placement Rules
- Citations MUST appear IMMEDIATELY AFTER the paragraph or bullet list that uses the knowledge
- NEVER collect all citations and place them at the end of your response
- Limit to 1-2 citations per paragraph/bullet list
- If your answer uses learned knowledge, you MUST generate at least 1 `<CITATION ... />` in the response
</CITATION_INSTRUCTIONS>
"""
def rag_retrieve(query: str, top_k: int = 100) -> Dict[str, Any]:
    """Query the RAG retrieval API and wrap the outcome as MCP text content.

    The bot ID is taken from the first command-line argument. It selects the
    endpoint (``{BACKEND_HOST}/v1/rag_retrieve/{bot_id}``) and, together with
    ``MASTERKEY``, determines the bearer token sent with the request.

    Args:
        query: Retrieval query text, sent as ``query`` in the JSON body.
        top_k: Sent as ``top_k`` in the JSON body.

    Returns:
        A dict of the form ``{"content": [{"type": "text", "text": ...}]}``.
        On success the text is ``DOCUMENT_CITATION_INSTRUCTIONS`` followed by
        the ``markdown`` field of the API response. On any failure the text is
        a message starting with ``"Error:"``; failures are reported in the
        payload rather than raised.
    """
    try:
        bot_id = sys.argv[1] if len(sys.argv) > 1 else ""
        if not bot_id:
            # Without a bot ID neither the endpoint path nor the auth token is
            # valid, so fail before sending anything to the backend.
            return _text_result(
                "Error: bot ID not provided. Please provide the bot ID as the first command line argument."
            )

        url = f"{BACKEND_HOST}/v1/rag_retrieve/{bot_id}"

        # Token is the MD5 hex digest of "<masterkey>:<bot_id>".
        # NOTE(review): presumably this scheme is dictated by the backend — confirm
        # before changing; MD5 is not a secure hash.
        auth_token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}",
        }
        payload = {"query": query, "top_k": top_k}

        response = requests.post(url, json=payload, headers=headers, timeout=30)
        if response.status_code != 200:
            return _text_result(
                f"Error: RAG API returned status code {response.status_code}. Response: {response.text}"
            )

        try:
            response_data = response.json()
        except ValueError as e:
            # ValueError is the common base of the stdlib and simplejson decode
            # errors, so an invalid body is reported here regardless of which
            # JSON backend requests is using.
            return _text_result(
                f"Error: Failed to parse API response as JSON. Error: {str(e)}, Raw response: {response.text}"
            )

        # A valid JSON body is not necessarily an object (it may be a list or a
        # scalar); treat anything without a "markdown" key as a malformed reply.
        if not isinstance(response_data, dict) or "markdown" not in response_data:
            return _text_result(
                f"Error: 'markdown' field not found in API response. Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}"
            )

        return _text_result(DOCUMENT_CITATION_INSTRUCTIONS + response_data["markdown"])
    except requests.exceptions.RequestException as e:
        return _text_result(f"Error: Failed to connect to RAG API. {str(e)}")
    except Exception as e:
        # Boundary handler: the tool reports every failure as text content.
        return _text_result(f"Error: {str(e)}")


def _text_result(text: str) -> Dict[str, Any]:
    """Wrap *text* in the single-item text content payload used for tool results."""
    return {"content": [{"type": "text", "text": text}]}
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch one MCP request to the matching handler and build its response.

    Supported methods are ``initialize``, ``ping``, ``tools/list`` and
    ``tools/call`` (for the ``rag_retrieve`` tool only).

    Args:
        request: The decoded request; ``method``, ``params`` and ``id`` are read
            from it.

    Returns:
        The response dict for the request. Unknown methods and unknown tools
        produce an error response with code -32601; any exception raised while
        handling the request produces an error response with code -32603.
    """
    try:
        method = request.get("method")
        # An explicit JSON null for "params" must behave like a missing key,
        # otherwise the .get() calls below fail with AttributeError.
        params = request.get("params") or {}
        request_id = request.get("id")

        if method == "initialize":
            return create_initialize_response(request_id, "rag-retrieve")

        if method == "ping":
            return create_ping_response(request_id)

        if method == "tools/list":
            tools = load_tools_from_json("rag_retrieve_tools.json")
            if not tools:
                # Built-in definition used when the JSON file yields no tools.
                tools = [
                    {
                        "name": "rag_retrieve",
                        "description": "Call the RAG retrieval API to retrieve relevant documents for the query and return markdown-formatted results.",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "Retrieval query content"
                                }
                            },
                            "required": ["query"]
                        }
                    }
                ]
            return create_tools_list_response(request_id, tools)

        if method == "tools/call":
            tool_name = params.get("name")
            if tool_name != "rag_retrieve":
                return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")

            # Same null-tolerance as for "params" above.
            arguments = params.get("arguments") or {}
            query = arguments.get("query", "")
            top_k = arguments.get("top_k")
            if top_k is None:
                top_k = 100

            if not query:
                # Reported as a successful tool result so the caller sees the
                # guidance text and can retry with a proper query.
                return create_success_response(request_id, {
                    "content": [{
                        "type": "text",
                        "text": "Error: missing required parameter 'query'. Please call this tool again with a non-empty 'query' argument describing what you want to retrieve."
                    }]
                })

            # rag_retrieve performs a blocking HTTP request (up to 30 s); run it
            # in the default executor so the event loop is not stalled meanwhile.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(None, rag_retrieve, query, top_k)
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": result
            }

        return create_error_response(request_id, -32601, f"Unknown method: {method}")
    except Exception as e:
        return create_error_response(request.get("id"), -32603, f"Internal error: {str(e)}")
async def main():
    """Hand ``handle_request`` to the shared MCP streaming driver and wait for it to finish."""
    await handle_mcp_streaming(handle_request)


# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())