#!/usr/bin/env python3
"""
Semantic search MCP server.

Performs semantic similarity search using embedding vectors.
Implementation modelled on multi_keyword_search_server.py.
"""
import asyncio
import json
import os
import pickle
import sys
from typing import Any, Dict, List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer, util
# Lazily-loaded shared SentenceTransformer instance (created on first use).
embedder = None


def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
    """Return the shared SentenceTransformer instance, loading it on first call.

    Args:
        model_name_or_path (str): HuggingFace model name or local model path.
            Only consulted when the bundled local model directory is absent.

    Returns:
        SentenceTransformer: the cached embedding model.
    """
    global embedder
    if embedder is None:
        # Prefer the bundled local copy to avoid a HuggingFace download.
        local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"

        # Device comes from the environment; default to CPU.
        device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
        if device not in ('cpu', 'cuda', 'mps'):
            # Diagnostics must go to stderr: stdout carries the MCP JSON-RPC
            # stream and any stray text would corrupt the protocol.
            print(f"警告: 不支持的设备类型 '{device}',使用默认 CPU", file=sys.stderr)
            device = 'cpu'

        # Check whether the local model directory exists.
        if os.path.exists(local_model_path):
            print(f"使用本地模型: {local_model_path}", file=sys.stderr)
            embedder = SentenceTransformer(local_model_path, device=device)
        else:
            print(f"本地模型不存在,使用HuggingFace模型: {model_name_or_path}", file=sys.stderr)
            embedder = SentenceTransformer(model_name_or_path, device=device)

    return embedder
def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Validate that *file_path* lies inside *allowed_dir*.

    Args:
        file_path: Absolute or relative path to check.
        allowed_dir: Directory the path must be contained in.

    Returns:
        The absolute form of *file_path*.

    Raises:
        ValueError: If the path escapes the allowed directory.
    """
    # Normalize both sides to absolute paths before comparing.
    if not os.path.isabs(file_path):
        file_path = os.path.abspath(file_path)

    allowed_dir = os.path.abspath(allowed_dir)

    # A bare startswith() check would wrongly accept sibling directories
    # such as "/data/projects_evil" for allowed_dir "/data/projects", so
    # require an exact match or a real path-separator boundary.
    if file_path != allowed_dir and not file_path.startswith(allowed_dir + os.sep):
        raise ValueError(f"Access denied: path {file_path} is not within allowed directory {allowed_dir}")

    # os.path.abspath() already collapses ".." components, but keep an
    # explicit check as defence in depth against traversal attempts.
    if ".." in file_path:
        raise ValueError(f"Access denied: path traversal attack detected")

    return file_path
def get_allowed_directory():
    """Return the absolute path of the directory the server may access.

    The location is taken from the PROJECT_DATA_DIR environment variable,
    falling back to "./projects" when it is unset.
    """
    return os.path.abspath(os.getenv("PROJECT_DATA_DIR", "./projects"))
def load_tools_from_json() -> List[Dict[str, Any]]:
    """Load the MCP tool definitions from tools/semantic_search_tools.json.

    Returns:
        The parsed list of tool-definition dicts, or an empty list when the
        JSON file is missing or unreadable.
    """
    try:
        tools_file = os.path.join(os.path.dirname(__file__), "tools", "semantic_search_tools.json")
        if os.path.exists(tools_file):
            with open(tools_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        else:
            # Fall back to an empty tool list when the definition file is absent.
            return []
    except Exception as e:
        # Warn on stderr: stdout is reserved for the MCP JSON-RPC stream,
        # and a stray warning line there would corrupt the protocol.
        print(f"Warning: Unable to load tool definition JSON file: {str(e)}", file=sys.stderr)
        return []
def semantic_search(query: str, embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Run a semantic similarity search against a pre-computed embeddings file.

    Args:
        query: Natural-language query text.
        embeddings_file: Path (absolute, or relative to the project data
            directory) of a pickle file produced by the indexing step.
        top_k: Maximum number of results to return.

    Returns:
        An MCP-style content dict holding either the formatted ranked
        matches or an error message.
    """
    def _text(message: str) -> Dict[str, Any]:
        # All responses use the MCP single-text-content shape.
        return {"content": [{"type": "text", "text": message}]}

    if not query.strip():
        return _text("Error: Query cannot be empty")

    # Restrict file access to the configured project directory.
    project_data_dir = get_allowed_directory()

    # Validate / resolve the embeddings file path.
    try:
        if not os.path.isabs(embeddings_file):
            # Strip a leading projects/ prefix if the caller included one.
            clean_path = embeddings_file
            if clean_path.startswith('projects/'):
                clean_path = clean_path[9:]  # drop 'projects/'
            elif clean_path.startswith('./projects/'):
                clean_path = clean_path[11:]  # drop './projects/'

            # Try the direct location inside the project directory first.
            full_path = os.path.join(project_data_dir, clean_path.lstrip('./'))
            if not os.path.exists(full_path):
                # Fall back to a recursive search of the project tree.
                found = find_file_in_project(clean_path, project_data_dir)
                if found:
                    embeddings_file = found
                else:
                    return _text(f"Error: embeddings file {embeddings_file} not found in project directory {project_data_dir}")
            else:
                embeddings_file = full_path
        else:
            # Boundary-safe containment check: a plain startswith() would
            # wrongly accept sibling paths such as "<project_dir>_evil/...".
            if embeddings_file != project_data_dir and not embeddings_file.startswith(project_data_dir + os.sep):
                return _text(f"Error: embeddings file path must be within project directory {project_data_dir}")
            if not os.path.exists(embeddings_file):
                return _text(f"Error: embeddings file {embeddings_file} does not exist")
    except Exception as e:
        return _text(f"Error: embeddings file path validation failed - {str(e)}")

    try:
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # tolerable only because files are confined to the project directory.
        with open(embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)

        # Support both the new ('chunks') and old ('sentences') layouts.
        if 'chunks' in embedding_data:
            sentences = embedding_data['chunks']
            sentence_embeddings = embedding_data['embeddings']
            # Newer files may record which model produced the embeddings.
            model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
            model = get_model(model_path)
        else:
            sentences = embedding_data['sentences']
            sentence_embeddings = embedding_data['embeddings']
            model = get_model()

        # Embed the query and rank all chunks by cosine similarity.
        query_embedding = model.encode(query, convert_to_tensor=True)
        cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0]
        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]

        # Build ranked result records.
        results = []
        for rank, idx in enumerate(top_results, start=1):
            results.append({
                'rank': rank,
                'content': sentences[idx],
                'similarity_score': cos_scores[idx].item(),
                'file_path': embeddings_file,
            })

        if not results:
            return _text("No matching results found")

        # One line per hit: "#rank [similarity:score]: chunk text".
        formatted_output = "\n".join(
            f"#{r['rank']} [similarity:{r['similarity_score']:.4f}]: {r['content']}"
            for r in results
        )
        return _text(formatted_output)

    except FileNotFoundError:
        return _text(f"Error: embeddings file {embeddings_file} not found")
    except Exception as e:
        return _text(f"Search error: {str(e)}")
def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Recursively search *project_dir* for a file named *filename*.

    Returns the full path of the first match encountered in os.walk order,
    or None when no file with that exact name exists in the tree.
    """
    matches = (
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(project_dir)
        if filename in filenames
    )
    return next(matches, None)
def get_model_info() -> Dict[str, Any]:
    """Report whether the bundled local embedding model is available.

    Returns:
        An MCP-style content dict describing which model (local copy or
        HuggingFace download) the server will use.
    """
    try:
        # Same local path that get_model() probes before downloading.
        local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
        if os.path.exists(local_model_path):
            text = (
                f"✅ 使用本地模型: {local_model_path}\n"
                "模型状态: 已加载\n"
                "设备: CPU\n"
                "说明: 避免从HuggingFace下载,提高响应速度"
            )
        else:
            text = (
                f"⚠️ 本地模型不存在: {local_model_path}\n"
                "将使用HuggingFace模型: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2\n"
                "建议: 下载模型到本地以提高响应速度\n"
                "设备: CPU"
            )
    except Exception as e:
        text = f"❌ 获取模型信息失败: {str(e)}"
    return {"content": [{"type": "text", "text": text}]}
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch a single MCP JSON-RPC request and return the response dict."""
    def _result(req_id, payload):
        # Successful JSON-RPC envelope.
        return {"jsonrpc": "2.0", "id": req_id, "result": payload}

    def _error(req_id, code, message):
        # Error JSON-RPC envelope.
        return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}}

    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")

        if method == "initialize":
            return _result(request_id, {
                "protocolVersion": "2024-11-05",
                "capabilities": {"tools": {}},
                "serverInfo": {"name": "semantic-search", "version": "1.0.0"},
            })

        if method == "ping":
            return _result(request_id, {"pong": True})

        if method == "tools/list":
            # Tool definitions live in a JSON file next to this server.
            return _result(request_id, {"tools": load_tools_from_json()})

        if method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})

            if tool_name == "semantic_search":
                return _result(request_id, semantic_search(
                    arguments.get("query", ""),
                    arguments.get("embeddings_file", ""),
                    arguments.get("top_k", 20),
                ))

            if tool_name == "get_model_info":
                return _result(request_id, get_model_info())

            return _error(request_id, -32601, f"Unknown tool: {tool_name}")

        return _error(request_id, -32601, f"Unknown method: {method}")

    except Exception as e:
        return _error(request.get("id"), -32603, f"Internal error: {str(e)}")
async def main():
    """Main entry point: serve MCP JSON-RPC requests line-by-line over stdio."""
    try:
        while True:
            # Read one line from stdin in an executor so the event loop
            # is not blocked by the synchronous readline call.
            line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
            if not line:
                # EOF: client closed the pipe; shut down cleanly.
                break

            line = line.strip()
            if not line:
                # Skip blank lines between requests.
                continue

            try:
                request = json.loads(line)
                response = await handle_request(request)

                # Write the response as a single JSON line to stdout.
                sys.stdout.write(json.dumps(response) + "\n")
                sys.stdout.flush()

            except json.JSONDecodeError:
                # Malformed JSON from the client: JSON-RPC "Parse error".
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()

            except Exception as e:
                # Any other failure is reported as a JSON-RPC internal error.
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32603,
                        "message": f"Internal error: {str(e)}"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()

    except KeyboardInterrupt:
        # Allow Ctrl-C to terminate the read loop without a traceback.
        pass


if __name__ == "__main__":
    asyncio.run(main())