qwen_agent/mcp/semantic_search_server.py
#!/usr/bin/env python3
"""
语义搜索MCP服务器
基于embedding向量进行语义相似度搜索
参考multi_keyword_search_server.py的实现方式
"""
import asyncio
import json
import os
import pickle
import sys
from typing import Any, Dict, List, Optional
import numpy as np
from sentence_transformers import SentenceTransformer, util
from mcp_common import (
    get_allowed_directory,
    load_tools_from_json,
    resolve_file_path,
    find_file_in_project,
    create_error_response,
    create_success_response,
    create_initialize_response,
    create_ping_response,
    create_tools_list_response,
    handle_mcp_streaming
)

# Lazily loaded model instance
embedder = None


def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
    """Return the model instance (lazily loaded).

    Args:
        model_name_or_path (str): Model name or local path.
            - May be a HuggingFace model name.
            - May be a path to a local model.
    """
    global embedder
    if embedder is None:
        # Prefer a local copy of the model if one exists
        local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
        # Read the device from the environment, defaulting to CPU
        device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
        if device not in ['cpu', 'cuda', 'mps']:
            # Log to stderr so diagnostics never mix into the MCP output stream
            print(f"Warning: unsupported device type '{device}', falling back to CPU", file=sys.stderr)
            device = 'cpu'
        # Check whether the local model exists
        if os.path.exists(local_model_path):
            print(f"Using local model: {local_model_path}", file=sys.stderr)
            embedder = SentenceTransformer(local_model_path, device=device)
        else:
            print(f"Local model not found, using HuggingFace model: {model_name_or_path}", file=sys.stderr)
            embedder = SentenceTransformer(model_name_or_path, device=device)
    return embedder
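
# Example: pre-downloading the model so the local path above exists at
# startup (a minimal sketch to run once from the project root; assumes
# network access to HuggingFace):
#
#     from sentence_transformers import SentenceTransformer
#     m = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
#     m.save('./models/paraphrase-multilingual-MiniLM-L12-v2')
#
# The device can then be overridden per launch, e.g.:
#     SENTENCE_TRANSFORMER_DEVICE=cuda python semantic_search_server.py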


def semantic_search(query: str, embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Run a semantic search against a precomputed embeddings file."""
    if not query.strip():
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: Query cannot be empty"
                }
            ]
        }
    # Validate the embeddings file path
    try:
        # Resolve the file path; supports both folder/document.txt and document.txt forms
        resolved_embeddings_file = resolve_file_path(embeddings_file)
        # Load the embedding data
        with open(resolved_embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)
        # Support both the new and the old data layouts
        if 'chunks' in embedding_data:
            # New layout: keyed by 'chunks'
            sentences = embedding_data['chunks']
            sentence_embeddings = embedding_data['embeddings']
            # Use the model path stored alongside the embeddings, if present
            model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
            model = get_model(model_path)
        else:
            # Old layout: keyed by 'sentences'
            sentences = embedding_data['sentences']
            sentence_embeddings = embedding_data['embeddings']
            model = get_model()
        # Encode the query
        query_embedding = model.encode(query, convert_to_tensor=True)
        # Compute cosine similarities
        cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0]
        # Take the top_k results
        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]
        # Format the results
        results = []
        for i, idx in enumerate(top_results):
            sentence = sentences[idx]
            score = cos_scores[idx].item()
            results.append({
                'rank': i + 1,
                'content': sentence,
                'similarity_score': score,
                'file_path': embeddings_file
            })
        if not results:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": "No matching results found"
                    }
                ]
            }
        # Format the output
        formatted_output = "\n".join([
            f"#{result['rank']} [similarity:{result['similarity_score']:.4f}]: {result['content']}"
            for result in results
        ])
        return {
            "content": [
                {
                    "type": "text",
                    "text": formatted_output
                }
            ]
        }
    except FileNotFoundError:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error: embeddings file {embeddings_file} not found"
                }
            ]
        }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Search error: {str(e)}"
                }
            ]
        }


def get_model_info() -> Dict[str, Any]:
    """Return information about the currently configured model."""
    try:
        # Check the local model path
        local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
        device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
        if os.path.exists(local_model_path):
            return {
                "content": [
                    {
                        "type": "text",
                        "text": f"✅ Using local model: {local_model_path}\n"
                                f"Model status: {'loaded' if embedder is not None else 'not yet loaded (lazy)'}\n"
                                f"Device: {device}\n"
                                f"Note: avoids downloading from HuggingFace, improving response time"
                    }
                ]
            }
        else:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": f"⚠️ Local model not found: {local_model_path}\n"
                                f"Falling back to HuggingFace model: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2\n"
                                f"Suggestion: download the model locally to improve response time\n"
                                f"Device: {device}"
                    }
                ]
            }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"❌ Failed to get model info: {str(e)}"
                }
            ]
        }


async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle an MCP request."""
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")
        if method == "initialize":
            return create_initialize_response(request_id, "semantic-search")
        elif method == "ping":
            return create_ping_response(request_id)
        elif method == "tools/list":
            # Load tool definitions from the JSON file
            tools = load_tools_from_json("semantic_search_tools.json")
            return create_tools_list_response(request_id, tools)
        elif method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})
            if tool_name == "semantic_search":
                query = arguments.get("query", "")
                embeddings_file = arguments.get("embeddings_file", "")
                top_k = arguments.get("top_k", 20)
                result = semantic_search(query, embeddings_file, top_k)
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": result
                }
            elif tool_name == "get_model_info":
                result = get_model_info()
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": result
                }
            else:
                return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")
        else:
            return create_error_response(request_id, -32601, f"Unknown method: {method}")
    except Exception as e:
        return create_error_response(request.get("id"), -32603, f"Internal error: {str(e)}")
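
# Example tools/call request this handler accepts, and the shape of its
# reply (a sketch of the MCP JSON-RPC exchange; 'document.pkl' and the
# result text are illustrative):
#
#     -> {"jsonrpc": "2.0", "id": 1, "method": "tools/call",
#         "params": {"name": "semantic_search",
#                    "arguments": {"query": "how is the device configured",
#                                  "embeddings_file": "document.pkl",
#                                  "top_k": 5}}}
#     <- {"jsonrpc": "2.0", "id": 1,
#         "result": {"content": [{"type": "text",
#                                 "text": "#1 [similarity:0.8123]: ..."}]}}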


async def main():
    """Main entry point."""
    await handle_mcp_streaming(handle_request)


if __name__ == "__main__":
    asyncio.run(main())
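
# Quick smoke test over stdio (assumes handle_mcp_streaming reads one
# JSON-RPC message per line on stdin, which is an assumption about
# mcp_common's transport):
#
#     echo '{"jsonrpc":"2.0","id":1,"method":"ping"}' | python semantic_search_server.py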