catalog-agent/mcp/semantic_search_server.py

#!/usr/bin/env python3
"""
语义搜索MCP服务器
基于embedding向量进行语义相似度搜索
参考multi_keyword_search_server.py的实现方式
"""
import asyncio
import json
import os
import pickle
import sys
from typing import Any, Dict, List, Optional
import numpy as np
from sentence_transformers import SentenceTransformer, util

# Lazily loaded SentenceTransformer instance
embedder = None

def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
    """Return the model instance (lazily loaded).

    Args:
        model_name_or_path (str): Model name or local path.
            - May be a HuggingFace model name.
            - May be a path to a local model directory.
    """
    global embedder
    if embedder is None:
        # Prefer a local copy of the model if one exists
        local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
        if os.path.exists(local_model_path):
            # Log to stderr: stdout is reserved for JSON-RPC messages
            print(f"Using local model: {local_model_path}", file=sys.stderr)
            embedder = SentenceTransformer(local_model_path, device='cpu')
        else:
            print(f"Local model not found, using HuggingFace model: {model_name_or_path}", file=sys.stderr)
            embedder = SentenceTransformer(model_name_or_path, device='cpu')
    return embedder

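
# Minimal usage sketch (illustrative only; the query string is made up):
#   model = get_model()
#   query_vec = model.encode("example query", convert_to_tensor=True)
# The resulting vector is compared against the stored chunk embeddings with
# util.cos_sim(), as done in semantic_search() below.
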
def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Validate that a file path lies inside the allowed directory."""
    # Normalize both paths to absolute paths
    if not os.path.isabs(file_path):
        file_path = os.path.abspath(file_path)
    allowed_dir = os.path.abspath(allowed_dir)
    # Reject paths outside the allowed directory
    if not file_path.startswith(allowed_dir):
        raise ValueError(f"Access denied: path {file_path} is not inside the allowed directory {allowed_dir}")
    # Reject path traversal attempts
    if ".." in file_path:
        raise ValueError("Access denied: path traversal attempt detected")
    return file_path

def get_allowed_directory():
    """Return the directory the server is allowed to access."""
    # The project data directory is taken from an environment variable
    project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
    return os.path.abspath(project_dir)

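
# Expected layout of an embeddings pickle file, reconstructed from the keys read in
# semantic_search() below (the script that produces these files is not part of this
# module, so the producer sketch here is an assumption):
#
#   {
#       "chunks": ["text chunk 1", "text chunk 2", ...],          # new-style key
#       "embeddings": <array/tensor of shape (num_chunks, dim)>,  # get_model().encode(chunks)
#       "model_path": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
#   }
#
# Old-style files store the text under "sentences" instead of "chunks" and omit
# "model_path". A minimal producer sketch:
#
#   data = {"chunks": chunks,
#           "embeddings": get_model().encode(chunks),
#           "model_path": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"}
#   with open("embeddings.pkl", "wb") as f:
#       pickle.dump(data, f)
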
def semantic_search(query: str, embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Run a semantic similarity search against a precomputed embeddings file."""
    if not query.strip():
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: query must not be empty"
                }
            ]
        }
    # Restrict access to the project data directory
    project_data_dir = get_allowed_directory()
    # Validate the embeddings file path
    try:
        # Resolve relative paths
        if not os.path.isabs(embeddings_file):
            # Strip a leading "projects/" prefix if present
            clean_path = embeddings_file
            if clean_path.startswith('projects/'):
                clean_path = clean_path[9:]   # drop 'projects/'
            elif clean_path.startswith('./projects/'):
                clean_path = clean_path[11:]  # drop './projects/'
            # Look for the file directly under the project directory
            full_path = os.path.join(project_data_dir, clean_path.lstrip('./'))
            if not os.path.exists(full_path):
                # Fall back to a recursive search of the project directory
                found = find_file_in_project(clean_path, project_data_dir)
                if found:
                    embeddings_file = found
                else:
                    return {
                        "content": [
                            {
                                "type": "text",
                                "text": f"Error: embeddings file {embeddings_file} was not found in project directory {project_data_dir}"
                            }
                        ]
                    }
            else:
                embeddings_file = full_path
        else:
            if not embeddings_file.startswith(project_data_dir):
                return {
                    "content": [
                        {
                            "type": "text",
                            "text": f"Error: the embeddings file path must be inside the project directory {project_data_dir}"
                        }
                    ]
                }
        if not os.path.exists(embeddings_file):
            return {
                "content": [
                    {
                        "type": "text",
                        "text": f"Error: embeddings file {embeddings_file} does not exist"
                    }
                ]
            }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error: embeddings file path validation failed - {str(e)}"
                }
            ]
        }
    try:
        # Load the precomputed embedding data
        with open(embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)
        # Support both the new and the old data layout
        if 'chunks' in embedding_data:
            # New layout: text is stored under 'chunks'
            sentences = embedding_data['chunks']
            sentence_embeddings = embedding_data['embeddings']
            # Use the model path recorded in the embedding data, if any
            model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
            model = get_model(model_path)
        else:
            # Old layout: text is stored under 'sentences'
            sentences = embedding_data['sentences']
            sentence_embeddings = embedding_data['embeddings']
            model = get_model()
        # Encode the query
        query_embedding = model.encode(query, convert_to_tensor=True)
        # Compute cosine similarities between the query and every chunk
        cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0]
        # Take the top_k most similar chunks
        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]
        # Collect the results
        results = []
        for i, idx in enumerate(top_results):
            sentence = sentences[idx]
            score = cos_scores[idx].item()
            results.append({
                'rank': i + 1,
                'content': sentence,
                'similarity_score': score,
                'file_path': embeddings_file
            })
        if not results:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": "No matching results found"
                    }
                ]
            }
        # Build the text output
        formatted_output = "\n".join([
            f"#{result['rank']} [similarity:{result['similarity_score']:.4f}]: {result['content']}"
            for result in results
        ])
        return {
            "content": [
                {
                    "type": "text",
                    "text": formatted_output
                }
            ]
        }
    except FileNotFoundError:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error: embeddings file {embeddings_file} not found"
                }
            ]
        }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Search failed: {str(e)}"
                }
            ]
        }

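
# Example of the formatted search output (scores and text are illustrative):
#   #1 [similarity:0.8123]: <most similar chunk>
#   #2 [similarity:0.7990]: <next most similar chunk>
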
def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Recursively search the project directory for a file."""
    # Match on the base name so callers may pass a relative path
    target = os.path.basename(filename)
    for root, dirs, files in os.walk(project_dir):
        if target in files:
            return os.path.join(root, target)
    return None

def get_model_info() -> Dict[str, Any]:
    """Return information about the model currently in use."""
    try:
        # Check for a local copy of the model
        local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
        if os.path.exists(local_model_path):
            return {
                "content": [
                    {
                        "type": "text",
                        "text": f"✅ Using local model: {local_model_path}\n"
                                "Model status: loaded\n"
                                "Device: CPU\n"
                                "Note: avoids downloading from HuggingFace and speeds up responses"
                    }
                ]
            }
        else:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": f"⚠️ Local model not found: {local_model_path}\n"
                                "Falling back to HuggingFace model: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2\n"
                                "Suggestion: download the model locally to speed up responses\n"
                                "Device: CPU"
                    }
                ]
            }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"❌ Failed to get model info: {str(e)}"
                }
            ]
        }

async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle an MCP request."""
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")
        if method == "initialize":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {
                        "tools": {}
                    },
                    "serverInfo": {
                        "name": "semantic-search",
                        "version": "1.0.0"
                    }
                }
            }
        elif method == "ping":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "pong": True
                }
            }
        elif method == "tools/list":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "tools": [
                        {
                            "name": "semantic_search",
                            "description": "Semantic search tool; performs similarity search over vector embeddings. Output format: #[rank] [similarity score]: [matched content]",
                            "inputSchema": {
                                "type": "object",
                                "properties": {
                                    "query": {
                                        "type": "string",
                                        "description": "Search query text"
                                    },
                                    "embeddings_file": {
                                        "type": "string",
                                        "description": "Path to the embeddings pickle file"
                                    },
                                    "top_k": {
                                        "type": "integer",
                                        "description": "Maximum number of results to return (default 20)",
                                        "default": 20
                                    }
                                },
                                "required": ["query", "embeddings_file"]
                            }
                        },
                        {
                            "name": "get_model_info",
                            "description": "Get information about the model currently in use, including model path and load status",
                            "inputSchema": {
                                "type": "object",
                                "properties": {},
                                "required": []
                            }
                        }
                    ]
                }
            }
        elif method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})
            if tool_name == "semantic_search":
                query = arguments.get("query", "")
                embeddings_file = arguments.get("embeddings_file", "")
                top_k = arguments.get("top_k", 20)
                result = semantic_search(query, embeddings_file, top_k)
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": result
                }
            elif tool_name == "get_model_info":
                result = get_model_info()
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": result
                }
            else:
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {
                        "code": -32601,
                        "message": f"Unknown tool: {tool_name}"
                    }
                }
        else:
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {
                    "code": -32601,
                    "message": f"Unknown method: {method}"
                }
            }
    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "id": request.get("id"),
            "error": {
                "code": -32603,
                "message": f"Internal error: {str(e)}"
            }
        }

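
# Example exchange handled by handle_request() (a sketch; the file name and query
# below are made up). Each JSON-RPC message occupies a single line on stdin/stdout;
# the request is wrapped here only for readability:
#
#   {"jsonrpc": "2.0", "id": 2, "method": "tools/call",
#    "params": {"name": "semantic_search",
#               "arguments": {"query": "power supply parameters",
#                             "embeddings_file": "embeddings.pkl", "top_k": 5}}}
#
# Typical response written to stdout:
#
#   {"jsonrpc": "2.0", "id": 2,
#    "result": {"content": [{"type": "text", "text": "#1 [similarity:0.8123]: ..."}]}}
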
async def main():
    """Main entry point."""
    try:
        while True:
            # Read one line from stdin
            line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
            if not line:
                break
            line = line.strip()
            if not line:
                continue
            try:
                request = json.loads(line)
                response = await handle_request(request)
                # Write the response to stdout
                sys.stdout.write(json.dumps(response) + "\n")
                sys.stdout.flush()
            except json.JSONDecodeError:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()
            except Exception as e:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32603,
                        "message": f"Internal error: {str(e)}"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()
    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    asyncio.run(main())
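
# Running the server (a sketch; the paths and environment values are assumptions):
#
#   PROJECT_DATA_DIR=./projects python semantic_search_server.py
#
# Quick smoke test from a shell, sending an initialize request on stdin:
#
#   echo '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}' \
#       | PROJECT_DATA_DIR=./projects python semantic_search_server.py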