368 lines
13 KiB
Python
368 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
RAG retrieval MCP server.
|
|
Calls the backend RAG APIs (base URL taken from BACKEND_HOST) to retrieve document and table knowledge.
|
|
"""
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import json
|
|
import re
|
|
import sys
|
|
import os
|
|
from typing import Any, Dict, List
|
|
|
|
try:
|
|
import requests
|
|
except ImportError:
|
|
print("Error: requests module is required. Please install it with: pip install requests")
|
|
sys.exit(1)
|
|
|
|
from mcp_common import (
|
|
create_error_response,
|
|
create_success_response,
|
|
create_initialize_response,
|
|
create_ping_response,
|
|
create_tools_list_response,
|
|
load_tools_from_json,
|
|
handle_mcp_streaming
|
|
)
|
|
# Base URL of the retrieval backend; override with the BACKEND_HOST env var.
BACKEND_HOST = os.environ.get("BACKEND_HOST", "https://api-dev.gptbase.ai")
# Shared secret combined with the bot ID to derive the bearer token sent with
# each retrieval request. NOTE(review): falls back to the placeholder "master"
# when MASTERKEY is unset — confirm every deployment sets it explicitly.
MASTERKEY = os.environ.get("MASTERKEY", "master")
|
|
|
|
# Citation instruction prefixes injected into tool results.
# Each prompt is concatenated directly in front of the backend's markdown, so
# the text below is exactly what the model sees. Both strings end with a blank
# line so the instructions are visually separated from the retrieved content.
#
# DOCUMENT_CITATION_INSTRUCTIONS: prepended to successful rag_retrieve results
# (document and web-page knowledge).
DOCUMENT_CITATION_INSTRUCTIONS = """<CITATION_INSTRUCTIONS>
When using the retrieved knowledge below, you MUST add XML citation tags for factual claims.

## Document Knowledge
Format: `<CITATION file="file_uuid" filename="name.pdf" page=3 />`
- Use `file` attribute with the UUID from document markers
- Use `filename` attribute with the actual filename from document markers
- Use `page` attribute (singular) with the page number
- `page` MUST be 0-based and must match the `pages:` values shown in the learned knowledge context

## Web Page Knowledge
Format: `<CITATION url="https://example.com/page" />`
- Use `url` attribute with the web page URL from the source metadata
- Do not use `file`, `filename`, or `page` attributes for web sources
- If content is grounded in a web source, prefer a web citation with `url` over a file citation

## Placement Rules
- Citations MUST appear IMMEDIATELY AFTER the paragraph or bullet list that uses the knowledge
- NEVER collect all citations and place them at the end of your response
- Limit to 1-2 citations per paragraph/bullet list
- If your answer uses learned knowledge, you MUST generate at least 1 `<CITATION ... />` in the response
</CITATION_INSTRUCTIONS>

"""

# TABLE_CITATION_INSTRUCTIONS: prepended to successful table_rag_retrieve
# results. The `__src` / `file_ref_table` conventions it describes are assumed
# to match the backend's table markdown format — confirm against the API.
TABLE_CITATION_INSTRUCTIONS = """<CITATION_INSTRUCTIONS>
When using the retrieved table knowledge below, you MUST add XML citation tags for factual claims.

Format: `<CITATION file="file_id" filename="name.xlsx" sheet=1 rows=[2, 4] />`
- Parse `__src`: `F1S2R5` = file_ref F1, sheet 2, row 5
- Look up file_id in `file_ref_table`
- Combine same-sheet rows into one citation: `rows=[2, 4, 6]`
- MANDATORY: Create SEPARATE citation for EACH (file, sheet) combination
- NEVER put <CITATION> on the same line as a bullet point or table row
- Citations MUST be on separate lines AFTER the complete list/table
- NEVER include the `__src` column in your response - it is internal metadata only
- Citations MUST appear IMMEDIATELY AFTER the paragraph or bullet list that uses the knowledge
- NEVER collect all citations and place them at the end of your response
</CITATION_INSTRUCTIONS>

"""
|
|
|
|
def rag_retrieve(query: str, top_k: int = 100, trace_id: str = "") -> Dict[str, Any]:
    """Call the document RAG retrieval API and wrap the answer as an MCP tool result.

    The bot ID is taken from the first command-line argument (``sys.argv[1]``);
    it selects the endpoint path and is mixed into the bearer token.

    Args:
        query: Retrieval query text, sent as ``"query"`` in the JSON body.
        top_k: Forwarded unchanged as ``"top_k"`` in the JSON body.
        trace_id: Optional request ID; when non-empty it is sent as the
            ``X-Request-ID`` header.

    Returns:
        ``{"content": [{"type": "text", "text": ...}]}``. On success the text is
        ``DOCUMENT_CITATION_INSTRUCTIONS`` followed by the response's
        ``markdown`` string. On failure the text starts with ``"Error:"``;
        exceptions derived from ``Exception`` are caught and reported this way
        rather than raised.
    """

    def _text_result(text: str) -> Dict[str, Any]:
        # Wrap one text payload in the MCP tool-result shape.
        return {"content": [{"type": "text", "text": text}]}

    try:
        bot_id = sys.argv[1] if len(sys.argv) > 1 else ""
        if not bot_id:
            # Without a bot ID both the endpoint path and the auth token would
            # be wrong, so fail fast with an actionable message.
            return _text_result(
                "Error: bot ID not provided. Please provide the bot ID as the "
                "first command line argument."
            )

        url = f"{BACKEND_HOST}/v1/rag_retrieve/{bot_id}"

        # Bearer token is md5("<MASTERKEY>:<bot_id>") — presumably the scheme
        # the backend verifies; it cannot be changed on this side alone.
        auth_token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}",
        }
        if trace_id:
            headers["X-Request-ID"] = trace_id

        response = requests.post(
            url, json={"query": query, "top_k": top_k}, headers=headers, timeout=30
        )

        if response.status_code != 200:
            return _text_result(
                f"Error: RAG API returned status code {response.status_code}. Response: {response.text}"
            )

        try:
            response_data = response.json()
        except ValueError as e:
            # ValueError covers json.JSONDecodeError, simplejson's decode error
            # and requests.exceptions.JSONDecodeError, whichever JSON backend
            # requests is using.
            return _text_result(
                f"Error: Failed to parse API response as JSON. Error: {str(e)}, Raw response: {response.text}"
            )

        # Guard against non-object JSON bodies and a non-string "markdown"
        # value, both of which would otherwise fail on concatenation below.
        markdown_content = (
            response_data.get("markdown") if isinstance(response_data, dict) else None
        )
        if not isinstance(markdown_content, str):
            return _text_result(
                f"Error: 'markdown' field not found in API response. Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}"
            )

        return _text_result(DOCUMENT_CITATION_INSTRUCTIONS + markdown_content)

    except requests.exceptions.RequestException as e:
        return _text_result(f"Error: Failed to connect to RAG API. {str(e)}")
    except Exception as e:
        return _text_result(f"Error: {str(e)}")
|
|
|
|
|
|
def table_rag_retrieve(query: str, trace_id: str = "") -> Dict[str, Any]:
    """Call the table RAG retrieval API and wrap the answer as an MCP tool result.

    The bot ID is taken from the first command-line argument (``sys.argv[1]``);
    it selects the endpoint path and is mixed into the bearer token. When the
    backend's markdown starts with "no excel files found" (case-insensitive),
    the result of ``rag_retrieve`` for the same query is returned instead, with
    an explanatory prefix added to its text.

    Args:
        query: Retrieval query text, sent as ``"query"`` in the JSON body.
        trace_id: Optional request ID; when non-empty it is sent as the
            ``X-Request-ID`` header (and passed on to the fallback call).

    Returns:
        ``{"content": [{"type": "text", "text": ...}]}``. On success the text is
        ``TABLE_CITATION_INSTRUCTIONS`` followed by the response's ``markdown``
        string. On failure the text starts with ``"Error:"``; exceptions derived
        from ``Exception`` are caught and reported this way rather than raised.
    """

    def _text_result(text: str) -> Dict[str, Any]:
        # Wrap one text payload in the MCP tool-result shape.
        return {"content": [{"type": "text", "text": text}]}

    try:
        bot_id = sys.argv[1] if len(sys.argv) > 1 else ""
        if not bot_id:
            # Without a bot ID both the endpoint path and the auth token would
            # be wrong, so fail fast with an actionable message.
            return _text_result(
                "Error: bot ID not provided. Please provide the bot ID as the "
                "first command line argument."
            )

        url = f"{BACKEND_HOST}/v1/table_rag_retrieve/{bot_id}"

        # Bearer token is md5("<MASTERKEY>:<bot_id>") — presumably the scheme
        # the backend verifies; it cannot be changed on this side alone.
        auth_token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
        headers = {
            "content-type": "application/json",
            "authorization": f"Bearer {auth_token}",
        }
        if trace_id:
            headers["X-Request-ID"] = trace_id

        # Table retrieval gets a much longer timeout than document retrieval.
        response = requests.post(url, json={"query": query}, headers=headers, timeout=300)

        if response.status_code != 200:
            return _text_result(
                f"Error: Table RAG API returned status code {response.status_code}. Response: {response.text}"
            )

        try:
            response_data = response.json()
        except ValueError as e:
            # ValueError covers json.JSONDecodeError, simplejson's decode error
            # and requests.exceptions.JSONDecodeError, whichever JSON backend
            # requests is using.
            return _text_result(
                f"Error: Failed to parse API response as JSON. Error: {str(e)}, Raw response: {response.text}"
            )

        # Guard against non-object JSON bodies and a non-string "markdown"
        # value, both of which would otherwise fail in the regex/concatenation.
        markdown_content = (
            response_data.get("markdown") if isinstance(response_data, dict) else None
        )
        if not isinstance(markdown_content, str):
            return _text_result(
                f"Error: 'markdown' field not found in API response. Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}"
            )

        if re.search(r"^no excel files found", markdown_content, re.IGNORECASE):
            # No spreadsheets for this bot: fall back to document retrieval and
            # tell the model where the content came from.
            rag_result = rag_retrieve(query, trace_id=trace_id)
            content = rag_result.get("content", [])
            if content and content[0].get("type") == "text":
                content[0]["text"] = "No table_rag_retrieve results were found. The content below is the fallback result from rag_retrieve:\n\n" + content[0]["text"]
            return rag_result

        return _text_result(TABLE_CITATION_INSTRUCTIONS + markdown_content)

    except requests.exceptions.RequestException as e:
        return _text_result(f"Error: Failed to connect to Table RAG API. {str(e)}")
    except Exception as e:
        return _text_result(f"Error: {str(e)}")
|
|
|
|
|
|
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch one MCP JSON-RPC request and return the response message.

    Handles ``initialize``, ``ping``, ``tools/list`` and ``tools/call`` (tools
    ``rag_retrieve`` and ``table_rag_retrieve``). Unknown methods or tools get
    a -32601 error response; any other ``Exception`` is turned into a -32603
    error response instead of propagating.
    """
    try:
        method = request.get("method")
        # "params" may be absent or JSON null; treat any non-object as empty.
        params = request.get("params")
        if not isinstance(params, dict):
            params = {}
        request_id = request.get("id")

        if method == "initialize":
            return create_initialize_response(request_id, "rag-retrieve")

        if method == "ping":
            return create_ping_response(request_id)

        if method == "tools/list":
            # Load tool definitions from the JSON file.
            tools = load_tools_from_json("rag_retrieve_tools.json")
            if not tools:
                # Fall back to a built-in definition when the JSON file is
                # missing or empty. NOTE(review): this fallback advertises only
                # rag_retrieve although tools/call also accepts
                # table_rag_retrieve — confirm that is intended.
                tools = [
                    {
                        "name": "rag_retrieve",
                        "description": "Call the RAG retrieval API to retrieve relevant documents for the query and return markdown-formatted results.",
                        "inputSchema": {
                            "type": "object",
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "Retrieval query content"
                                }
                            },
                            "required": ["query"]
                        }
                    }
                ]
            return create_tools_list_response(request_id, tools)

        if method == "tools/call":
            tool_name = params.get("name")
            # "arguments" may also be absent, null or not an object.
            arguments = params.get("arguments")
            if not isinstance(arguments, dict):
                arguments = {}
            meta = params.get("_meta") or params.get("meta") or {}
            trace_id = meta.get("trace_id", "") if isinstance(meta, dict) else ""

            if tool_name not in ("rag_retrieve", "table_rag_retrieve"):
                return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")

            query = arguments.get("query", "")
            if not query:
                # Reported as a successful tool result so the model can retry.
                return create_success_response(request_id, {
                    "content": [{
                        "type": "text",
                        "text": "Error: missing required parameter 'query'. Please call this tool again with a non-empty 'query' argument describing what you want to retrieve."
                    }]
                })

            # Both tools make blocking requests.post calls (30 s / 300 s
            # timeouts); run them in the default executor so the event loop
            # is not blocked while waiting on the backend.
            loop = asyncio.get_running_loop()
            if tool_name == "rag_retrieve":
                top_k = arguments.get("top_k", 100)
                result = await loop.run_in_executor(None, rag_retrieve, query, top_k, trace_id)
            else:
                result = await loop.run_in_executor(None, table_rag_retrieve, query, trace_id)

            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": result
            }

        return create_error_response(request_id, -32601, f"Unknown method: {method}")

    except Exception as e:
        # Avoid a second failure here if the request itself was not a dict.
        error_id = request.get("id") if isinstance(request, dict) else None
        return create_error_response(error_id, -32603, f"Internal error: {str(e)}")
|
|
|
|
|
|
async def main() -> None:
    """Serve MCP via mcp_common's streaming loop, using handle_request as the handler."""
    streaming_loop = handle_mcp_streaming(handle_request)
    await streaming_loop


if __name__ == "__main__":
    entry_coroutine = main()
    asyncio.run(entry_coroutine)
|