#!/usr/bin/env python3
"""
Semantic search MCP server.

Performs semantic similarity search over precomputed embedding vectors.
Follows the same structure as multi_keyword_search_server.py.
"""

import asyncio
import os
import pickle
import sys
from typing import Any, Dict

import numpy as np
from sentence_transformers import SentenceTransformer, util

from mcp_common import (
    load_tools_from_json,
    resolve_file_path,
    create_error_response,
    create_initialize_response,
    create_ping_response,
    create_tools_list_response,
    handle_mcp_streaming
)

# Lazily loaded model instance, shared across requests.
embedder = None


def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
    """Return the model instance, loading it on first use.

    Args:
        model_name_or_path (str): Model name or local path.
            - A HuggingFace model name, or
            - a path to a local model directory.

    Note: the first model loaded is cached in the module-level ``embedder``;
    later calls with a different name reuse the cached instance.
    """
    global embedder
    if embedder is None:
        # Prefer a local copy of the model if one is available.
        local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"

        # Read the device configuration from the environment, defaulting to CPU.
        device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
        if device not in ('cpu', 'cuda', 'mps'):
            # Diagnostics go to stderr: stdout carries the MCP protocol stream.
            print(f"Warning: unsupported device '{device}', falling back to CPU", file=sys.stderr)
            device = 'cpu'

        if os.path.exists(local_model_path):
            print(f"Using local model: {local_model_path}", file=sys.stderr)
            embedder = SentenceTransformer(local_model_path, device=device)
        else:
            print(f"Local model not found, using HuggingFace model: {model_name_or_path}", file=sys.stderr)
            embedder = SentenceTransformer(model_name_or_path, device=device)
    return embedder


def semantic_search(query: str, embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Run a semantic similarity search against a precomputed embeddings file."""
    if not query.strip():
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: Query cannot be empty"
                }
            ]
        }

    try:
        # Resolve the path; both folder/document.txt and document.txt forms are accepted.
        resolved_embeddings_file = resolve_file_path(embeddings_file)

        # Load the precomputed embeddings. Note that pickle can execute
        # arbitrary code on load, so the file must come from a trusted source.
        with open(resolved_embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)

        # Support both the new and the legacy data layouts.
        if 'chunks' in embedding_data:
            # New layout: text stored under 'chunks', with an optional model path.
            sentences = embedding_data['chunks']
            sentence_embeddings = embedding_data['embeddings']
            model_path = embedding_data.get(
                'model_path',
                'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
            )
            model = get_model(model_path)
        else:
            # Legacy layout: text stored under 'sentences'.
            sentences = embedding_data['sentences']
            sentence_embeddings = embedding_data['embeddings']
            model = get_model()

        # Encode the query and score it against every stored embedding.
        query_embedding = model.encode(query, convert_to_tensor=True)
        cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0]

        # Take the indices of the top_k highest-scoring chunks.
        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]

        results = []
        for i, idx in enumerate(top_results):
            results.append({
                'rank': i + 1,
                'content': sentences[idx],
                'similarity_score': cos_scores[idx].item(),
                'file_path': embeddings_file
            })

        if not results:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": "No matching results found"
                    }
                ]
            }

        formatted_output = "\n".join([
            f"#{result['rank']} [similarity:{result['similarity_score']:.4f}]: {result['content']}"
            for result in results
        ])

        return {
            "content": [
                {
                    "type": "text",
                    "text": formatted_output
                }
            ]
        }
    except FileNotFoundError:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error: embeddings file {embeddings_file} not found"
                }
            ]
        }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Search error: {str(e)}"
                }
            ]
        }
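

# The embeddings file consumed above is a pickle whose layout is implied by the
# loading code: a dict with either a 'chunks' key (new layout) or a 'sentences'
# key (legacy layout), an 'embeddings' key, and optionally 'model_path'.
# A minimal sketch of how such a file might be produced; the file name and the
# example chunks are illustrative, not part of this server:
#
#     from sentence_transformers import SentenceTransformer
#     import pickle
#
#     model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
#     chunks = ["first text chunk", "second text chunk"]
#     data = {
#         'chunks': chunks,
#         'embeddings': model.encode(chunks, convert_to_tensor=True),
#         'model_path': 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
#     }
#     with open('document_embeddings.pkl', 'wb') as f:
#         pickle.dump(data, f)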


def get_model_info() -> Dict[str, Any]:
    """Report which model the server will use and on which device."""
    try:
        local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
        device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')

        if os.path.exists(local_model_path):
            return {
                "content": [
                    {
                        "type": "text",
                        "text": f"✅ Using local model: {local_model_path}\n"
                                f"Model status: available locally\n"
                                f"Device: {device}\n"
                                f"Note: avoids downloading from HuggingFace, improving response time"
                    }
                ]
            }
        else:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": f"⚠️ Local model not found: {local_model_path}\n"
                                f"Will use HuggingFace model: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2\n"
                                f"Tip: download the model locally to improve response time\n"
                                f"Device: {device}"
                    }
                ]
            }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"❌ Failed to get model info: {str(e)}"
                }
            ]
        }


async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle an MCP request."""
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")

        if method == "initialize":
            return create_initialize_response(request_id, "semantic-search")
        elif method == "ping":
            return create_ping_response(request_id)
        elif method == "tools/list":
            # Tool definitions live in a JSON file alongside the server.
            tools = load_tools_from_json("semantic_search_tools.json")
            return create_tools_list_response(request_id, tools)
        elif method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})

            if tool_name == "semantic_search":
                query = arguments.get("query", "")
                embeddings_file = arguments.get("embeddings_file", "")
                top_k = arguments.get("top_k", 20)
                result = semantic_search(query, embeddings_file, top_k)
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": result
                }
            elif tool_name == "get_model_info":
                result = get_model_info()
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": result
                }
            else:
                return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")
        else:
            return create_error_response(request_id, -32601, f"Unknown method: {method}")
    except Exception as e:
        return create_error_response(request.get("id"), -32603, f"Internal error: {str(e)}")


async def main():
    """Main entry point."""
    await handle_mcp_streaming(handle_request)


if __name__ == "__main__":
    asyncio.run(main())
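
# Example of a request this server handles. The transport is provided by
# mcp_common.handle_mcp_streaming; the JSON-RPC shape below follows the fields
# handle_request actually reads. The query and file name are illustrative:
#
#     {
#         "jsonrpc": "2.0",
#         "id": 1,
#         "method": "tools/call",
#         "params": {
#             "name": "semantic_search",
#             "arguments": {
#                 "query": "how to configure the device",
#                 "embeddings_file": "docs/document_embeddings.pkl",
#                 "top_k": 5
#             }
#         }
#     }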