add table_rag
This commit is contained in:
parent
4adf62afb7
commit
c1d2d48979
@ -36,7 +36,7 @@ def rag_retrieve(query: str, top_k: int = 100) -> Dict[str, Any]:
|
||||
if len(sys.argv) > 1:
|
||||
bot_id = sys.argv[1]
|
||||
|
||||
url = f"{BACKEND_HOST}/v1/table_rag_retrieve/{bot_id}"
|
||||
url = f"{BACKEND_HOST}/v1/rag_retrieve/{bot_id}"
|
||||
if not url:
|
||||
return {
|
||||
"content": [
|
||||
@ -128,6 +128,91 @@ def rag_retrieve(query: str, top_k: int = 100) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def table_rag_retrieve(query: str) -> Dict[str, Any]:
|
||||
"""调用Table RAG检索API"""
|
||||
try:
|
||||
bot_id = ""
|
||||
if len(sys.argv) > 1:
|
||||
bot_id = sys.argv[1]
|
||||
|
||||
url = f"{BACKEND_HOST}/v1/table_rag_retrieve/{bot_id}"
|
||||
|
||||
masterkey = MASTERKEY
|
||||
token_input = f"{masterkey}:{bot_id}"
|
||||
auth_token = hashlib.md5(token_input.encode()).hexdigest()
|
||||
|
||||
headers = {
|
||||
"content-type": "application/json",
|
||||
"authorization": f"Bearer {auth_token}"
|
||||
}
|
||||
data = {
|
||||
"query": query,
|
||||
}
|
||||
|
||||
response = requests.post(url, json=data, headers=headers, timeout=60)
|
||||
|
||||
if response.status_code != 200:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Error: Table RAG API returned status code {response.status_code}. Response: {response.text}"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
try:
|
||||
response_data = response.json()
|
||||
except json.JSONDecodeError as e:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Error: Failed to parse API response as JSON. Error: {str(e)}, Raw response: {response.text}"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
if "markdown" in response_data:
|
||||
markdown_content = response_data["markdown"]
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": markdown_content
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Error: 'markdown' field not found in API response. Response: {json.dumps(response_data, indent=2, ensure_ascii=False)}"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Error: Failed to connect to Table RAG API. {str(e)}"
|
||||
}
|
||||
]
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Error: {str(e)}"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Handle MCP request"""
|
||||
try:
|
||||
@ -183,6 +268,20 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"result": result
|
||||
}
|
||||
|
||||
elif tool_name == "table_rag_retrieve":
|
||||
query = arguments.get("query", "")
|
||||
|
||||
if not query:
|
||||
return create_error_response(request_id, -32602, "Missing required parameter: query")
|
||||
|
||||
result = table_rag_retrieve(query)
|
||||
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": result
|
||||
}
|
||||
|
||||
else:
|
||||
return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")
|
||||
|
||||
|
||||
@ -17,5 +17,19 @@
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "table_rag_retrieve",
|
||||
"description": "Call Table RAG retrieval API to retrieve relevant data from Excel/spreadsheet files in the knowledge base. Returns markdown format results containing table data analysis.",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Retrieval query content for table data"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user