228 lines
7.4 KiB
Python
228 lines
7.4 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
RAG retrieval MCP server.
|
|
Calls the RAG API at BACKEND_HOST to retrieve documents for the bot ID given as the first command-line argument.
|
|
"""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import sys
|
|
import os
|
|
from typing import Any, Dict
|
|
|
|
try:
|
|
import requests
|
|
except ImportError:
|
|
print("Error: requests module is required. Please install it with: pip install requests")
|
|
sys.exit(1)
|
|
|
|
from mcp_common import (
|
|
create_error_response,
|
|
create_success_response,
|
|
create_initialize_response,
|
|
create_ping_response,
|
|
create_tools_list_response,
|
|
load_tools_from_json,
|
|
handle_mcp_streaming
|
|
)
|
|
# Base URL of the RAG backend; overridable through the BACKEND_HOST env var.
BACKEND_HOST = os.environ.get("BACKEND_HOST", "https://api-dev.gptbase.ai")
# Secret that rag_retrieve combines with the bot ID to derive the bearer token.
# NOTE(review): silently falls back to the literal "master" when MASTERKEY is
# unset — confirm every deployment sets it explicitly.
MASTERKEY = os.environ.get("MASTERKEY", "master")
|
|
|
|
# Citation instruction prefix injected into tool results: rag_retrieve
# prepends this string to the "markdown" field of a successful API response
# and returns the concatenation to the MCP client verbatim. Any edit to the
# text below changes what the client (and the model reading it) receives.
DOCUMENT_CITATION_INSTRUCTIONS = """<CITATION_INSTRUCTIONS>
When using the retrieved knowledge below, you MUST add XML citation tags for factual claims.

## Document Knowledge
Format: `<CITATION file="file_uuid" filename="name.pdf" page=3 />`
- Use `file` attribute with the UUID from document markers
- Use `filename` attribute with the actual filename from document markers
- Use `page` attribute (singular) with the page number
- `page` MUST be 0-based and must match the `pages:` values shown in the learned knowledge context

## Web Page Knowledge
Format: `<CITATION url="https://example.com/page" />`
- Use `url` attribute with the web page URL from the source metadata
- Do not use `file`, `filename`, or `page` attributes for web sources
- If content is grounded in a web source, prefer a web citation with `url` over a file citation

## Placement Rules
- Citations MUST appear IMMEDIATELY AFTER the paragraph or bullet list that uses the knowledge
- NEVER collect all citations and place them at the end of your response
- Limit to 1-2 citations per paragraph/bullet list
- If your answer uses learned knowledge, you MUST generate at least 1 `<CITATION ... />` in the response
</CITATION_INSTRUCTIONS>

"""
|
|
|
|
def _text_result(text: str) -> Dict[str, Any]:
    """Wrap *text* in the single-text-item MCP tool-result shape."""
    return {"content": [{"type": "text", "text": text}]}


def rag_retrieve(query: str, top_k: int = 100) -> Dict[str, Any]:
    """Call the RAG retrieval API and return an MCP tool result.

    The bot ID is read from the first command-line argument and appended to
    ``{BACKEND_HOST}/v1/rag_retrieve/``. The request carries a bearer token
    equal to ``md5(f"{MASTERKEY}:{bot_id}")``.

    Args:
        query: Retrieval query text sent as ``query`` in the JSON body.
        top_k: Value forwarded unchanged as ``top_k`` in the JSON body.

    Returns:
        ``{"content": [{"type": "text", "text": ...}]}``. On success the text
        is ``DOCUMENT_CITATION_INSTRUCTIONS`` followed by the API's
        ``markdown`` field; on any failure it is an ``Error: ...`` message.
        Failures are reported in the result rather than raised.
    """
    try:
        bot_id = sys.argv[1] if len(sys.argv) > 1 else ""
        # The previous guard tested the URL, which is a non-empty f-string and
        # so never falsy; the value that can actually be missing is the bot ID.
        if not bot_id:
            return _text_result(
                "Error: bot ID not provided. Please provide the bot ID as the first command line argument."
            )

        url = f"{BACKEND_HOST}/v1/rag_retrieve/{bot_id}"
        # md5("<masterkey>:<bot_id>") is the token format the backend expects;
        # the hash choice belongs to that protocol, not to this client.
        auth_token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}",
        }
        payload = {"query": query, "top_k": top_k}

        response = requests.post(url, json=payload, headers=headers, timeout=30)

        if response.status_code != 200:
            return _text_result(
                f"Error: RAG API returned status code {response.status_code}. Response: {response.text}"
            )

        try:
            response_data = response.json()
        except ValueError as e:
            # Depending on the requests version and whether simplejson is
            # installed, Response.json() raises json.JSONDecodeError,
            # simplejson.JSONDecodeError or requests.JSONDecodeError. All of
            # them subclass ValueError, while catching only
            # json.JSONDecodeError let some escape to the RequestException
            # handler and be mis-reported as a connection failure.
            return _text_result(
                f"Error: Failed to parse API response as JSON. Error: {str(e)}, Raw response: {response.text}"
            )

        # Guard against a valid-JSON body that is not an object (list, null...).
        if isinstance(response_data, dict) and "markdown" in response_data:
            markdown_content = response_data["markdown"]
            return _text_result(DOCUMENT_CITATION_INSTRUCTIONS + markdown_content)

        return _text_result(
            f"Error: 'markdown' field not found in API response. Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}"
        )

    except requests.exceptions.RequestException as e:
        return _text_result(f"Error: Failed to connect to RAG API. {str(e)}")
    except Exception as e:
        # Boundary handler: the tool result must always be returned to the
        # MCP client as text rather than propagating as an exception.
        return _text_result(f"Error: {str(e)}")
|
|
|
|
|
|
def _fallback_tool_definitions() -> list:
    """Return the built-in tool list used when rag_retrieve_tools.json yields nothing."""
    return [
        {
            "name": "rag_retrieve",
            "description": "Call the RAG retrieval API to retrieve relevant documents for the query and return markdown-formatted results.",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Retrieval query content"
                    }
                },
                "required": ["query"]
            }
        }
    ]


async def _handle_tools_call(request_id: Any, params: Dict[str, Any]) -> Dict[str, Any]:
    """Execute a ``tools/call`` request; only the ``rag_retrieve`` tool exists."""
    tool_name = params.get("name")
    if tool_name != "rag_retrieve":
        # The MCP spec reports unknown tools as JSON-RPC "Invalid params"
        # (-32602); -32601 is reserved for unknown *methods*.
        return create_error_response(request_id, -32602, f"Unknown tool: {tool_name}")

    # "arguments" may be missing or explicitly null; `or {}` covers both.
    arguments = params.get("arguments") or {}
    query = arguments.get("query", "")
    top_k = arguments.get("top_k", 100)

    if not query:
        # Returned as a successful tool result (not a protocol error) so the
        # caller sees the message and can retry with a query.
        return create_success_response(request_id, {
            "content": [{
                "type": "text",
                "text": "Error: missing required parameter 'query'. Please call this tool again with a non-empty 'query' argument describing what you want to retrieve."
            }]
        })

    # rag_retrieve performs blocking HTTP I/O (requests, 30 s timeout); run it
    # in the default executor so the event loop is not stalled meanwhile.
    loop = asyncio.get_running_loop()
    result = await loop.run_in_executor(None, rag_retrieve, query, top_k)

    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "result": result
    }


async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch one MCP JSON-RPC request and return its response dict.

    Handles ``initialize``, ``ping``, ``tools/list`` and ``tools/call``;
    any other method yields a -32601 error response. Exceptions raised while
    handling are converted into a -32603 internal-error response.

    Args:
        request: Decoded JSON-RPC request object.

    Returns:
        The JSON-RPC response dict to send back.
    """
    # Resolved up front so the error path below cannot itself raise when the
    # request is not a dict.
    request_id = request.get("id") if isinstance(request, dict) else None
    try:
        method = request.get("method")
        # "params" may be missing or explicitly null in the JSON.
        params = request.get("params") or {}

        if method == "initialize":
            return create_initialize_response(request_id, "rag-retrieve")

        if method == "ping":
            return create_ping_response(request_id)

        if method == "tools/list":
            tools = load_tools_from_json("rag_retrieve_tools.json") or _fallback_tool_definitions()
            return create_tools_list_response(request_id, tools)

        if method == "tools/call":
            return await _handle_tools_call(request_id, params)

        return create_error_response(request_id, -32601, f"Unknown method: {method}")

    except Exception as e:
        return create_error_response(request_id, -32603, f"Internal error: {str(e)}")
|
|
|
|
|
|
async def main() -> None:
    """Start the server by handing handle_request to mcp_common.handle_mcp_streaming."""
    request_handler = handle_request
    await handle_mcp_streaming(request_handler)
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the async server only when executed as a script, never on import.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)
|