#!/usr/bin/env python3
"""RAG retrieval MCP server.

Exposes two MCP tools, ``rag_retrieve`` and ``table_rag_retrieve``. Each one
forwards a query to the backend RAG HTTP API and returns the markdown from the
response, prefixed with citation instructions.

The bot id is read from the first command-line argument; the backend host and
master key are read from the ``BACKEND_HOST`` and ``MASTERKEY`` environment
variables.
"""

import asyncio
import hashlib
import json
import os
import sys
from typing import Any, Dict, List, Optional

try:
    import requests
except ImportError:
    # Diagnostics go to stderr so they cannot be mistaken for protocol output.
    print(
        "Error: requests module is required. Please install it with: pip install requests",
        file=sys.stderr,
    )
    sys.exit(1)

from mcp_common import (
    create_error_response,
    create_success_response,
    create_initialize_response,
    create_ping_response,
    create_tools_list_response,
    load_tools_from_json,
    handle_mcp_streaming,
)

BACKEND_HOST = os.getenv("BACKEND_HOST", "https://api-dev.gptbase.ai")
# NOTE(review): the fallback key is a hard-coded default; set MASTERKEY in any
# real deployment.
MASTERKEY = os.getenv("MASTERKEY", "master")

# Default number of results requested from the document RAG endpoint.
DEFAULT_TOP_K = 100

# Citation instruction prefixes injected into tool results.
# NOTE(review): the empty inline-code spans (``) in both prompts look like
# places where an example citation tag was lost, and the line breaks below
# follow the visible markdown structure — confirm both against the upstream
# prompt text.
DOCUMENT_CITATION_INSTRUCTIONS = """
When using the retrieved knowledge below, you MUST add XML citation tags for factual claims.

## Document Knowledge
Format: ``
- Use `file` attribute with the UUID from document markers
- Use `filename` attribute with the actual filename from document markers
- Use `page` attribute (singular) with the page number
- `page` MUST be 0-based and must match the `pages:` values shown in the learned knowledge context

## Web Page Knowledge
Format: ``
- Use `url` attribute with the web page URL from the source metadata
- Do not use `file`, `filename`, or `page` attributes for web sources
- If content is grounded in a web source, prefer a web citation with `url` over a file citation

## Placement Rules
- Citations MUST appear IMMEDIATELY AFTER the paragraph or bullet list that uses the knowledge
- NEVER collect all citations and place them at the end of your response
- Limit to 1-2 citations per paragraph/bullet list
- If your answer uses learned knowledge, you MUST generate at least 1 `` in the response

"""

TABLE_CITATION_INSTRUCTIONS = """
When using the retrieved table knowledge below, you MUST add XML citation tags for factual claims.

Format: ``
- Parse `__src`: `F1S2R5` = file_ref F1, sheet 2, row 5
- Look up file_id in `file_ref_table`
- Combine same-sheet rows into one citation: `rows=[2, 4, 6]`
- MANDATORY: Create SEPARATE citation for EACH (file, sheet) combination
- NEVER put on the same line as a bullet point or table row
- Citations MUST be on separate lines AFTER the complete list/table
- NEVER include the `__src` column in your response - it is internal metadata only
- Citations MUST appear IMMEDIATELY AFTER the paragraph or bullet list that uses the knowledge
- NEVER collect all citations and place them at the end of your response

"""


def _text_result(text: str) -> Dict[str, Any]:
    """Wrap *text* in the single-item text content structure returned by tools."""
    return {"content": [{"type": "text", "text": text}]}


def _get_bot_id() -> str:
    """Return the bot id from the first command-line argument, or ``""``."""
    return sys.argv[1] if len(sys.argv) > 1 else ""


def _auth_headers(bot_id: str) -> Dict[str, str]:
    """Build the request headers, including the bearer token for *bot_id*."""
    # NOTE(review): the token is an unsalted MD5 of "masterkey:bot_id" —
    # presumably the scheme the backend expects; confirm before changing it.
    auth_token = hashlib.md5(f"{MASTERKEY}:{bot_id}".encode()).hexdigest()
    return {
        "content-type": "application/json",
        "authorization": f"Bearer {auth_token}",
    }


def _retrieve(
    endpoint: str,
    api_label: str,
    payload: Dict[str, Any],
    instructions: str,
    timeout: int,
) -> Dict[str, Any]:
    """POST *payload* to a retrieval endpoint and wrap the outcome as a tool result.

    Args:
        endpoint: Path segment after ``/v1/`` (e.g. ``"rag_retrieve"``).
        api_label: Human-readable API name used in error messages.
        payload: JSON body sent to the API.
        instructions: Citation instructions prepended to the returned markdown.
        timeout: Request timeout in seconds.

    Returns:
        A text tool result. Failures are reported as ``"Error: ..."`` text
        rather than raised, so this function never propagates an exception.
    """
    try:
        bot_id = _get_bot_id()
        # Strip a trailing slash so a host like "https://x/" does not give "//v1".
        url = f"{BACKEND_HOST.rstrip('/')}/v1/{endpoint}/{bot_id}"

        response = requests.post(
            url, json=payload, headers=_auth_headers(bot_id), timeout=timeout
        )

        if response.status_code != 200:
            return _text_result(
                f"Error: {api_label} API returned status code {response.status_code}. Response: {response.text}"
            )

        try:
            response_data = response.json()
        except ValueError as e:
            # ValueError covers both json.JSONDecodeError and the simplejson
            # variant that requests raises when simplejson is installed.
            return _text_result(
                f"Error: Failed to parse API response as JSON. Error: {str(e)}, Raw response: {response.text}"
            )

        # A non-object body (list, null, ...) cannot carry a markdown field.
        if not isinstance(response_data, dict) or "markdown" not in response_data:
            return _text_result(
                f"Error: 'markdown' field not found in API response. Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}"
            )

        return _text_result(instructions + response_data["markdown"])

    except requests.exceptions.RequestException as e:
        return _text_result(f"Error: Failed to connect to {api_label} API. {str(e)}")
    except Exception as e:
        # Boundary handler: a tool call must always yield a result.
        return _text_result(f"Error: {str(e)}")


def rag_retrieve(query: str, top_k: int = 100) -> Dict[str, Any]:
    """Call the document RAG retrieval API.

    Args:
        query: Search query text.
        top_k: Maximum number of results to request from the API.

    Returns:
        A text tool result containing the document citation instructions
        followed by the API's markdown, or an ``"Error: ..."`` message.
    """
    return _retrieve(
        endpoint="rag_retrieve",
        api_label="RAG",
        payload={"query": query, "top_k": top_k},
        instructions=DOCUMENT_CITATION_INSTRUCTIONS,
        timeout=30,
    )


def table_rag_retrieve(query: str) -> Dict[str, Any]:
    """Call the table RAG retrieval API.

    Args:
        query: Search query text.

    Returns:
        A text tool result containing the table citation instructions
        followed by the API's markdown, or an ``"Error: ..."`` message.
    """
    return _retrieve(
        endpoint="table_rag_retrieve",
        api_label="Table RAG",
        payload={"query": query},
        instructions=TABLE_CITATION_INSTRUCTIONS,
        # Table retrieval is given a much longer timeout than document retrieval.
        timeout=300,
    )


def _default_tools() -> List[Dict[str, Any]]:
    """Return built-in tool definitions used when the JSON definition file yields none.

    Both tools accepted by ``tools/call`` are listed so the fallback stays
    consistent with what the server can actually execute.
    """
    return [
        {
            "name": "rag_retrieve",
            "description": "调用RAG检索API,根据查询内容检索相关文档。返回包含相关内容的markdown格式结果。",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "检索查询内容",
                    },
                    "top_k": {
                        "type": "integer",
                        "description": "返回的最大结果数量",
                        "default": DEFAULT_TOP_K,
                    },
                },
                "required": ["query"],
            },
        },
        {
            "name": "table_rag_retrieve",
            "description": "调用Table RAG检索API,根据查询内容检索相关表格数据。返回包含相关内容的markdown格式结果。",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "检索查询内容",
                    },
                },
                "required": ["query"],
            },
        },
    ]


def _coerce_top_k(value: Any, default: int = DEFAULT_TOP_K) -> Optional[int]:
    """Normalize a client-supplied ``top_k``.

    Returns *default* when the value is absent, the positive integer value
    when it can be interpreted as one, and ``None`` when it is invalid.
    """
    if value is None:
        return default
    if isinstance(value, bool):
        # bool is an int subclass; True/False is never a meaningful count.
        return None
    try:
        top_k = int(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return top_k if top_k > 0 else None


async def _handle_tool_call(request_id: Any, params: Dict[str, Any]) -> Dict[str, Any]:
    """Validate and execute a ``tools/call`` request, returning the JSON-RPC response."""
    tool_name = params.get("name")
    # An explicit null "arguments" is treated like a missing one.
    arguments = params.get("arguments") or {}

    if tool_name not in ("rag_retrieve", "table_rag_retrieve"):
        return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")

    query = arguments.get("query", "")
    if not query:
        return create_error_response(request_id, -32602, "Missing required parameter: query")

    # The retrieval functions block on HTTP (up to 300 s), so they are run in
    # the default executor instead of directly inside the coroutine.
    loop = asyncio.get_running_loop()
    if tool_name == "rag_retrieve":
        top_k = _coerce_top_k(arguments.get("top_k"))
        if top_k is None:
            return create_error_response(
                request_id, -32602, "Invalid parameter: top_k must be a positive integer"
            )
        result = await loop.run_in_executor(None, rag_retrieve, query, top_k)
    else:
        result = await loop.run_in_executor(None, table_rag_retrieve, query)

    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "result": result,
    }


async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle one MCP JSON-RPC request and return its response.

    Supports ``initialize``, ``ping``, ``tools/list`` and ``tools/call``; any
    other method yields a -32601 error response. Unexpected exceptions are
    converted into a -32603 error response instead of propagating.
    """
    try:
        method = request.get("method")
        # An explicit null "params" is treated like a missing one.
        params = request.get("params") or {}
        request_id = request.get("id")

        if method == "initialize":
            return create_initialize_response(request_id, "rag-retrieve")

        if method == "ping":
            return create_ping_response(request_id)

        if method == "tools/list":
            # Prefer the JSON definition file; fall back to built-in definitions.
            tools = load_tools_from_json("rag_retrieve_tools.json")
            if not tools:
                tools = _default_tools()
            return create_tools_list_response(request_id, tools)

        if method == "tools/call":
            return await _handle_tool_call(request_id, params)

        return create_error_response(request_id, -32601, f"Unknown method: {method}")

    except Exception as e:
        # Guard the id lookup so a malformed (non-dict) request cannot raise again.
        failed_id = request.get("id") if isinstance(request, dict) else None
        return create_error_response(failed_id, -32603, f"Internal error: {str(e)}")


async def main():
    """Run the MCP streaming loop with ``handle_request`` as the request handler."""
    await handle_mcp_streaming(handle_request)


if __name__ == "__main__":
    asyncio.run(main())