#!/usr/bin/env python3
"""
Semantic search MCP server.

Performs semantic similarity search over embedding vectors.
Follows the implementation pattern of multi_keyword_search_server.py.
"""
import asyncio
import os
import pickle
import sys
from typing import Any, Dict, List, Optional, Union

import numpy as np
import requests

from mcp_common import (
    load_tools_from_json,
    resolve_file_path,
    create_error_response,
    create_initialize_response,
    create_ping_response,
    create_tools_list_response,
    handle_mcp_streaming
)


def encode_query_via_api(query: str, fastapi_url: Optional[str] = None) -> Optional[np.ndarray]:
    """Encode a single query via the embedding API; return None on failure."""
    if not fastapi_url:
        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')

    api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"

    try:
        # Call the encode endpoint
        request_data = {
            "texts": [query],
            "batch_size": 1
        }
        response = requests.post(
            api_endpoint,
            json=request_data,
            timeout=30
        )

        if response.status_code == 200:
            result_data = response.json()
            if result_data.get("success"):
                embeddings_list = result_data.get("embeddings", [])
                if embeddings_list:
                    return np.array(embeddings_list[0])

        # Log to stderr so diagnostics never interleave with the MCP
        # JSON-RPC stream carried on stdout.
        print(f"API encoding failed: {response.status_code} - {response.text}", file=sys.stderr)
        return None

    except Exception as e:
        print(f"API encoding error: {e}", file=sys.stderr)
        return None


def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Run a semantic search: load the local embeddings file and score by cosine similarity."""
    # Normalize the query input to a list
    if isinstance(queries, str):
        queries = [queries]

    # Reject an empty query list
    if not queries or not any(q.strip() for q in queries):
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: Queries cannot be empty"
                }
            ]
        }

    # Filter out empty queries
    queries = [q.strip() for q in queries if q.strip()]

    try:
        # Resolve the embeddings file path
        resolved_embeddings_file = resolve_file_path(embeddings_file)

        # Load the embeddings file
        with open(resolved_embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)

        # Support both the new and the legacy data layout
        if 'chunks' in embedding_data:
            # New layout (chunk-based)
            chunks = embedding_data['chunks']
            chunk_embeddings = np.asarray(embedding_data['embeddings'])
            chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
        else:
            # Legacy layout (sentence-based)
            chunks = embedding_data['sentences']
            chunk_embeddings = np.asarray(embedding_data['embeddings'])
            chunking_strategy = 'line'

        all_results = []

        # Process each query
        for query in queries:
            # Encode the query via the API
            print(f"Encoding query: {query}", file=sys.stderr)
            query_embedding = encode_query_via_api(query)

            if query_embedding is None:
                print(f"Failed to encode query: {query}", file=sys.stderr)
                continue

            # Compute cosine similarity against all chunk embeddings
            if len(chunk_embeddings.shape) > 1:
                cos_scores = np.dot(chunk_embeddings, query_embedding) / (
                    np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8
                )
            else:
                cos_scores = np.array([0.0] * len(chunks))

            # Take the top_k results
            top_indices = np.argsort(-cos_scores)[:top_k]
            for rank, idx in enumerate(top_indices):
                score = cos_scores[idx]
                if score > 0:  # Keep only results with some relevance
                    all_results.append({
                        'query': query,
                        'rank': rank + 1,
                        'content': chunks[idx],
                        'similarity_score': float(score),
                        'file_path': embeddings_file
                    })

        if not all_results:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": "No matching results found"
                    }
                ]
            }

        # Sort all results by similarity score
        all_results.sort(key=lambda x: x['similarity_score'], reverse=True)

        # Format the output
        formatted_lines = []
        formatted_lines.append(f"Found {len(all_results)} results for {len(queries)} queries:")
        formatted_lines.append("")
        for i, result in enumerate(all_results):
            formatted_lines.append(
                f"#{i+1} [query: '{result['query']}'] "
                f"[similarity:{result['similarity_score']:.4f}]: {result['content']}"
            )
        formatted_output = "\n".join(formatted_lines)

        return {
            "content": [
                {
                    "type": "text",
                    "text": formatted_output
                }
            ]
        }

    except FileNotFoundError:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error: Embeddings file not found: {embeddings_file}"
                }
            ]
        }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Search error: {str(e)}"
                }
            ]
        }
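
# For reference, the embeddings pickle layout that semantic_search expects,
# inferred from the compatibility branch above (a sketch; any key beyond
# these is an assumption about the script that produces the file):
#
#   embedding_data = {
#       "chunks": ["chunk text", ...],   # new layout; the legacy layout uses "sentences"
#       "embeddings": np.ndarray,        # shape (n_chunks, dim), row-aligned with chunks
#       "chunking_strategy": "line",     # optional; treated as "unknown" when absent
#   }
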
"content": [ { "type": "text", "text": f"Error: Embeddings file not found: {embeddings_file}" } ] } except Exception as e: return { "content": [ { "type": "text", "text": f"Search error: {str(e)}" } ] } async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]: """Handle MCP request""" try: method = request.get("method") params = request.get("params", {}) request_id = request.get("id") if method == "initialize": return create_initialize_response(request_id, "semantic-search") elif method == "ping": return create_ping_response(request_id) elif method == "tools/list": # 从 JSON 文件加载工具定义 tools = load_tools_from_json("semantic_search_tools.json") return create_tools_list_response(request_id, tools) elif method == "tools/call": tool_name = params.get("name") arguments = params.get("arguments", {}) if tool_name == "semantic_search": queries = arguments.get("queries", []) # 兼容旧的query参数 if not queries and "query" in arguments: queries = arguments.get("query", "") embeddings_file = arguments.get("embeddings_file", "") top_k = arguments.get("top_k", 20) result = semantic_search(queries, embeddings_file, top_k) return { "jsonrpc": "2.0", "id": request_id, "result": result } else: return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}") else: return create_error_response(request_id, -32601, f"Unknown method: {method}") except Exception as e: return create_error_response(request.get("id"), -32603, f"Internal error: {str(e)}") async def main(): """Main entry point.""" await handle_mcp_streaming(handle_request) if __name__ == "__main__": asyncio.run(main())