#!/usr/bin/env python3
"""Semantic search MCP server.

Exposes an MCP (JSON-RPC over stdio) server with a ``semantic_search`` tool
that ranks pre-computed sentence/chunk embeddings (pickle files) by cosine
similarity to a query, plus a ``get_model_info`` diagnostic tool.
Implementation mirrors the structure of multi_keyword_search_server.py.
"""

import asyncio
import json
import os
import pickle
import sys
from typing import Any, Dict, List, Optional

import numpy as np

# NOTE: sentence_transformers is imported lazily (inside get_model() and
# semantic_search()) so server startup and protocol handshakes do not pay
# the heavy import/model-load cost.

# Preferred local model copy; used when present to avoid HuggingFace downloads.
LOCAL_MODEL_PATH = "./models/paraphrase-multilingual-MiniLM-L12-v2"
# Fallback HuggingFace model identifier.
DEFAULT_MODEL = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'

# Lazily created SentenceTransformer singleton (see get_model()).
embedder = None


def get_model(model_name_or_path=DEFAULT_MODEL):
    """Return the shared SentenceTransformer instance, loading it on first use.

    Args:
        model_name_or_path: HuggingFace model name or local model path, used
            only when the preferred local copy (LOCAL_MODEL_PATH) is absent.

    Returns:
        The module-wide SentenceTransformer instance (CPU device).
    """
    global embedder
    if embedder is None:
        from sentence_transformers import SentenceTransformer

        if os.path.exists(LOCAL_MODEL_PATH):
            # BUG FIX: diagnostics must go to stderr -- stdout carries the
            # JSON-RPC stream and any stray print corrupts the protocol.
            print(f"使用本地模型: {LOCAL_MODEL_PATH}", file=sys.stderr)
            embedder = SentenceTransformer(LOCAL_MODEL_PATH, device='cpu')
        else:
            print(f"本地模型不存在,使用HuggingFace模型: {model_name_or_path}", file=sys.stderr)
            embedder = SentenceTransformer(model_name_or_path, device='cpu')
    return embedder


def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Validate that *file_path* lies inside *allowed_dir*.

    Returns:
        The absolute form of *file_path*.

    Raises:
        ValueError: if the path contains a traversal component or resolves
            outside *allowed_dir*.
    """
    # BUG FIX: reject ".." components *before* normalization; os.path.abspath
    # resolves them, so the original post-abspath check could never fire.
    if ".." in file_path.replace("\\", "/").split("/"):
        raise ValueError("访问被拒绝: 检测到路径遍历攻击尝试")
    if not os.path.isabs(file_path):
        file_path = os.path.abspath(file_path)
    allowed_dir = os.path.abspath(allowed_dir)
    # BUG FIX: a bare startswith() would accept sibling directories such as
    # "/data-evil" for allowed dir "/data"; commonpath is prefix-safe.
    if os.path.commonpath([file_path, allowed_dir]) != allowed_dir:
        raise ValueError(f"访问被拒绝: 路径 {file_path} 不在允许的目录 {allowed_dir} 内")
    return file_path


def get_allowed_directory():
    """Return the absolute directory tools are allowed to read from.

    Configured via the PROJECT_DATA_DIR environment variable; defaults to
    ./projects.
    """
    project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
    return os.path.abspath(project_dir)


def _text_result(text: str) -> Dict[str, Any]:
    """Wrap plain text in the MCP tool-result content envelope."""
    return {"content": [{"type": "text", "text": text}]}


def _is_within(path: str, directory: str) -> bool:
    """True if absolute *path* is inside absolute *directory* (prefix-safe)."""
    try:
        return os.path.commonpath([path, directory]) == directory
    except ValueError:
        # commonpath raises on paths with no common prefix (e.g. different
        # drives on Windows); treat as "outside".
        return False


def _resolve_embeddings_path(embeddings_file: str, project_data_dir: str):
    """Resolve *embeddings_file* to an existing file inside the project dir.

    Returns:
        ``(resolved_path, None)`` on success, or ``(None, error_result)``
        where error_result is an MCP error envelope ready for the client.
    """
    if not os.path.isabs(embeddings_file):
        # Strip a leading projects/ or ./projects/ prefix if the caller
        # already included it.
        clean_path = embeddings_file
        if clean_path.startswith('projects/'):
            clean_path = clean_path[len('projects/'):]
        elif clean_path.startswith('./projects/'):
            clean_path = clean_path[len('./projects/'):]
        full_path = os.path.join(project_data_dir, clean_path.lstrip('./'))
        if os.path.exists(full_path):
            return full_path, None
        # Direct path missing: fall back to a recursive search by file name.
        found = find_file_in_project(clean_path, project_data_dir)
        if found:
            return found, None
        return None, _text_result(
            f"错误:在项目目录 {project_data_dir} 中未找到embeddings文件 {embeddings_file}"
        )
    # Absolute paths must stay inside the sandbox directory.
    # BUG FIX: prefix-safe containment test instead of bare startswith(),
    # which would accept sibling directories like "<dir>-evil".
    if not _is_within(embeddings_file, project_data_dir):
        return None, _text_result(
            f"错误:embeddings文件路径必须在项目目录 {project_data_dir} 内"
        )
    if not os.path.exists(embeddings_file):
        return None, _text_result(f"错误:embeddings文件 {embeddings_file} 不存在")
    return embeddings_file, None


def semantic_search(query: str, embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Rank stored embeddings by cosine similarity to *query*.

    Args:
        query: Free-text query (must be non-blank).
        embeddings_file: Pickle file containing either
            ``{'chunks', 'embeddings'[, 'model_path']}`` (new layout) or
            ``{'sentences', 'embeddings'}`` (legacy layout).
        top_k: Maximum number of results to return.

    Returns:
        MCP content envelope with one text entry formatted as
        ``#<rank> [similarity:<score>]: <content>`` per match, or an
        error message.
    """
    if not query.strip():
        return _text_result("错误:查询不能为空")

    project_data_dir = get_allowed_directory()
    try:
        resolved, error = _resolve_embeddings_path(embeddings_file, project_data_dir)
        if error is not None:
            return error
        embeddings_file = resolved
    except Exception as e:
        return _text_result(f"错误:embeddings文件路径验证失败 - {str(e)}")

    try:
        from sentence_transformers import util

        # SECURITY NOTE: pickle.load executes arbitrary code from the file;
        # tolerated only because the path is confined to the project dir.
        with open(embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)

        if 'chunks' in embedding_data:
            # New data layout: chunked text plus an optional model path.
            sentences = embedding_data['chunks']
            sentence_embeddings = embedding_data['embeddings']
            model_path = embedding_data.get('model_path', DEFAULT_MODEL)
            model = get_model(model_path)
        else:
            # Legacy layout keyed by 'sentences'.
            sentences = embedding_data['sentences']
            sentence_embeddings = embedding_data['embeddings']
            model = get_model()

        query_embedding = model.encode(query, convert_to_tensor=True)
        cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0]
        # Indices of the top_k highest-similarity entries (descending).
        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]

        results = []
        for rank, idx in enumerate(top_results, start=1):
            results.append({
                'rank': rank,
                'content': sentences[idx],
                'similarity_score': cos_scores[idx].item(),
                'file_path': embeddings_file,
            })

        if not results:
            return _text_result("未找到匹配的结果")

        formatted_output = "\n".join(
            f"#{result['rank']} [similarity:{result['similarity_score']:.4f}]: {result['content']}"
            for result in results
        )
        return _text_result(formatted_output)
    except FileNotFoundError:
        return _text_result(f"错误:找不到embeddings文件 {embeddings_file}")
    except Exception as e:
        return _text_result(f"搜索时出错:{str(e)}")


def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Recursively search *project_dir* for a file named *filename*.

    Returns the first match (os.walk order) or None.
    """
    for root, _dirs, files in os.walk(project_dir):
        if filename in files:
            return os.path.join(root, filename)
    return None


def get_model_info() -> Dict[str, Any]:
    """Report whether the local model copy is available (MCP tool result)."""
    try:
        if os.path.exists(LOCAL_MODEL_PATH):
            return _text_result(
                f"✅ 使用本地模型: {LOCAL_MODEL_PATH}\n"
                f"模型状态: 已加载\n"
                f"设备: CPU\n"
                f"说明: 避免从HuggingFace下载,提高响应速度"
            )
        return _text_result(
            f"⚠️ 本地模型不存在: {LOCAL_MODEL_PATH}\n"
            f"将使用HuggingFace模型: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2\n"
            f"建议: 下载模型到本地以提高响应速度\n"
            f"设备: CPU"
        )
    except Exception as e:
        return _text_result(f"❌ 获取模型信息失败: {str(e)}")


def _rpc_result(request_id, result) -> Dict[str, Any]:
    """Build a JSON-RPC 2.0 success response."""
    return {"jsonrpc": "2.0", "id": request_id, "result": result}


def _rpc_error(request_id, code: int, message: str) -> Dict[str, Any]:
    """Build a JSON-RPC 2.0 error response."""
    return {"jsonrpc": "2.0", "id": request_id,
            "error": {"code": code, "message": message}}


# Tool schemas advertised by tools/list.
_TOOL_DEFINITIONS = [
    {
        "name": "semantic_search",
        "description": "语义搜索工具,基于向量嵌入进行相似度搜索。格式:#[排名] [相似度分数]: [匹配内容]",
        "inputSchema": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "搜索查询文本"
                },
                "embeddings_file": {
                    "type": "string",
                    "description": "embeddings pickle文件路径"
                },
                "top_k": {
                    "type": "integer",
                    "description": "返回结果的最大数量,默认20",
                    "default": 20
                }
            },
            "required": ["query", "embeddings_file"]
        }
    },
    {
        "name": "get_model_info",
        "description": "获取当前使用的模型信息,包括模型路径、加载状态等",
        "inputSchema": {
            "type": "object",
            "properties": {},
            "required": []
        }
    },
]


async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle one MCP JSON-RPC request and return the response dict.

    Supported methods: initialize, ping, tools/list, tools/call.
    Unknown methods/tools yield a -32601 error; unexpected exceptions a
    -32603 error.
    """
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")

        if method == "initialize":
            return _rpc_result(request_id, {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {}},
                "serverInfo": {"name": "semantic-search", "version": "1.0.0"},
            })
        if method == "ping":
            return _rpc_result(request_id, {"pong": True})
        if method == "tools/list":
            return _rpc_result(request_id, {"tools": _TOOL_DEFINITIONS})
        if method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})
            if tool_name == "semantic_search":
                return _rpc_result(request_id, semantic_search(
                    arguments.get("query", ""),
                    arguments.get("embeddings_file", ""),
                    arguments.get("top_k", 20),
                ))
            if tool_name == "get_model_info":
                return _rpc_result(request_id, get_model_info())
            return _rpc_error(request_id, -32601, f"Unknown tool: {tool_name}")
        return _rpc_error(request_id, -32601, f"Unknown method: {method}")
    except Exception as e:
        return _rpc_error(request.get("id"), -32603, f"Internal error: {str(e)}")


def _write_response(response: Dict[str, Any]) -> None:
    """Serialize one JSON-RPC response to stdout (newline-delimited)."""
    sys.stdout.write(json.dumps(response) + "\n")
    sys.stdout.flush()


async def main():
    """Main entry point: serve newline-delimited JSON-RPC over stdio until EOF."""
    try:
        loop = asyncio.get_running_loop()
        while True:
            # stdin.readline blocks, so run it in the default executor to
            # keep the event loop responsive.
            line = await loop.run_in_executor(None, sys.stdin.readline)
            if not line:
                break  # EOF: client closed the pipe.
            line = line.strip()
            if not line:
                continue
            try:
                request = json.loads(line)
                response = await handle_request(request)
                _write_response(response)
            except json.JSONDecodeError:
                _write_response({
                    "jsonrpc": "2.0",
                    "error": {"code": -32700, "message": "Parse error"},
                })
            except Exception as e:
                _write_response({
                    "jsonrpc": "2.0",
                    "error": {"code": -32603, "message": f"Internal error: {str(e)}"},
                })
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    asyncio.run(main())