diff --git a/mcp/rag_retrieve_server.py b/mcp/rag_retrieve_server.py index 4ee994d..f1ac5df 100644 --- a/mcp/rag_retrieve_server.py +++ b/mcp/rag_retrieve_server.py @@ -36,7 +36,7 @@ def rag_retrieve(query: str, top_k: int = 100) -> Dict[str, Any]: if len(sys.argv) > 1: bot_id = sys.argv[1] - url = f"{BACKEND_HOST}/v1/table_rag_retrieve/{bot_id}" + url = f"{BACKEND_HOST}/v1/rag_retrieve/{bot_id}" if not url: return { "content": [ @@ -128,6 +128,91 @@ def rag_retrieve(query: str, top_k: int = 100) -> Dict[str, Any]: } +def table_rag_retrieve(query: str) -> Dict[str, Any]: + """调用Table RAG检索API""" + try: + bot_id = "" + if len(sys.argv) > 1: + bot_id = sys.argv[1] + + url = f"{BACKEND_HOST}/v1/table_rag_retrieve/{bot_id}" + + masterkey = MASTERKEY + token_input = f"{masterkey}:{bot_id}" + auth_token = hashlib.md5(token_input.encode()).hexdigest() + + headers = { + "content-type": "application/json", + "authorization": f"Bearer {auth_token}" + } + data = { + "query": query, + } + + response = requests.post(url, json=data, headers=headers, timeout=60) + + if response.status_code != 200: + return { + "content": [ + { + "type": "text", + "text": f"Error: Table RAG API returned status code {response.status_code}. Response: {response.text}" + } + ] + } + + try: + response_data = response.json() + except json.JSONDecodeError as e: + return { + "content": [ + { + "type": "text", + "text": f"Error: Failed to parse API response as JSON. Error: {str(e)}, Raw response: {response.text}" + } + ] + } + + if "markdown" in response_data: + markdown_content = response_data["markdown"] + return { + "content": [ + { + "type": "text", + "text": markdown_content + } + ] + } + else: + return { + "content": [ + { + "type": "text", + "text": f"Error: 'markdown' field not found in API response. Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}" + } + ] + } + + except requests.exceptions.RequestException as e: + return { + "content": [ + { + "type": "text", + "text": f"Error: Failed to connect to Table RAG API. {str(e)}" + } + ] + } + except Exception as e: + return { + "content": [ + { + "type": "text", + "text": f"Error: {str(e)}" + } + ] + } + + async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]: """Handle MCP request""" try: @@ -182,6 +267,20 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]: "id": request_id, "result": result } + + elif tool_name == "table_rag_retrieve": + query = arguments.get("query", "") + + if not query: + return create_error_response(request_id, -32602, "Missing required parameter: query") + + result = table_rag_retrieve(query) + + return { + "jsonrpc": "2.0", + "id": request_id, + "result": result + } else: return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}") diff --git a/mcp/tools/rag_retrieve_tools.json b/mcp/tools/rag_retrieve_tools.json index 0cde52a..5cb271d 100644 --- a/mcp/tools/rag_retrieve_tools.json +++ b/mcp/tools/rag_retrieve_tools.json @@ -17,5 +17,19 @@ }, "required": ["query"] } + }, + { + "name": "table_rag_retrieve", + "description": "Call Table RAG retrieval API to retrieve relevant data from Excel/spreadsheet files in the knowledge base. Returns markdown format results containing table data analysis.", + "inputSchema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Retrieval query content for table data" + } + }, + "required": ["query"] + } } ]