#!/usr/bin/env python3
"""
Semantic search MCP server.

Performs semantic similarity search over embedding vectors by delegating to a
FastAPI backend. Modeled on the implementation in multi_keyword_search_server.py.
"""

import asyncio
import json
import os
import sys
from typing import Any, Dict, List, Optional, Union

import requests

from mcp_common import (
    get_allowed_directory,
    load_tools_from_json,
    resolve_file_path,
    find_file_in_project,
    create_error_response,
    create_success_response,
    create_initialize_response,
    create_ping_response,
    create_tools_list_response,
    handle_mcp_streaming
)


def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Run a semantic search by calling the FastAPI backend."""
    # Normalize the query input to a list
    if isinstance(queries, str):
        queries = [queries]

    # Validate the query list
    if not queries or not any(q.strip() for q in queries):
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: Queries cannot be empty"
                }
            ]
        }

    # Drop empty queries
    queries = [q.strip() for q in queries if q.strip()]

    try:
        # FastAPI service address (override via the FASTAPI_URL environment variable)
        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
        api_endpoint = f"{fastapi_url}/api/v1/semantic-search"

        # Collect results across all queries
        all_results = []
        resolved_embeddings_file = resolve_file_path(embeddings_file)

        for query in queries:
            # Call the FastAPI endpoint for this query
            request_data = {
                "embedding_file": resolved_embeddings_file,
                "query": query,
                "top_k": top_k,
                "min_score": 0.0
            }

            response = requests.post(
                api_endpoint,
                json=request_data,
                timeout=30
            )

            if response.status_code == 200:
                result_data = response.json()
                if result_data.get("success"):
                    for res in result_data.get("results", []):
                        all_results.append({
                            'query': query,
                            'rank': res["rank"],
                            'content': res["content"],
                            'similarity_score': res["score"],
                            'file_path': embeddings_file
                        })
                else:
                    # Log to stderr so stdout stays reserved for MCP JSON-RPC traffic
                    print(f"Search failed: {result_data.get('error', 'unknown error')}", file=sys.stderr)
            else:
                print(f"API call failed: {response.status_code} - {response.text}", file=sys.stderr)

        if not all_results:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": "No matching results found"
                    }
                ]
            }

        # Sort all results by similarity score, highest first
        all_results.sort(key=lambda x: x['similarity_score'], reverse=True)

        # Format the output
        formatted_lines = []
        formatted_lines.append(f"Found {len(all_results)} results for {len(queries)} queries:")
        formatted_lines.append("")

        for i, result in enumerate(all_results):
            formatted_lines.append(
                f"#{i+1} [query: '{result['query']}'] "
                f"[similarity:{result['similarity_score']:.4f}]: {result['content']}"
            )

        formatted_output = "\n".join(formatted_lines)

        return {
            "content": [
                {
                    "type": "text",
                    "text": formatted_output
                }
            ]
        }

    except requests.exceptions.RequestException as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"API request failed: {str(e)}"
                }
            ]
        }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Search error: {str(e)}"
                }
            ]
        }
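# NOTE (assumption): the /api/v1/semantic-search endpoint is not defined in this file.
# Judging only from the request built above and the fields read from its response, it is
# assumed to accept a JSON body shaped like
#     {"embedding_file": "<path>", "query": "<text>", "top_k": 20, "min_score": 0.0}
# and to return
#     {"success": true, "results": [{"rank": 1, "content": "...", "score": 0.87}], "error": null}
# If the backend's actual schema differs, adjust request_data and the result parsing
# in semantic_search() accordingly.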
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle a single MCP request."""
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")

        if method == "initialize":
            return create_initialize_response(request_id, "semantic-search")
        elif method == "ping":
            return create_ping_response(request_id)
        elif method == "tools/list":
            # Load tool definitions from the JSON file
            tools = load_tools_from_json("semantic_search_tools.json")
            return create_tools_list_response(request_id, tools)
        elif method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})

            if tool_name == "semantic_search":
                queries = arguments.get("queries", [])
                # Backward compatibility with the legacy "query" parameter
                if not queries and "query" in arguments:
                    queries = arguments.get("query", "")
                embeddings_file = arguments.get("embeddings_file", "")
                top_k = arguments.get("top_k", 20)

                result = semantic_search(queries, embeddings_file, top_k)
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": result
                }
            else:
                return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")
        else:
            return create_error_response(request_id, -32601, f"Unknown method: {method}")

    except Exception as e:
        return create_error_response(request.get("id"), -32603, f"Internal error: {str(e)}")


async def main():
    """Main entry point."""
    await handle_mcp_streaming(handle_request)


if __name__ == "__main__":
    asyncio.run(main())
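# Illustrative MCP "tools/call" request handled by this server (the embeddings file
# name below is a placeholder, not a real project file):
# {
#   "jsonrpc": "2.0",
#   "id": 1,
#   "method": "tools/call",
#   "params": {
#     "name": "semantic_search",
#     "arguments": {"queries": ["error handling"], "embeddings_file": "docs_embeddings.pkl", "top_k": 10}
#   }
# }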