249 lines
7.6 KiB
Python
249 lines
7.6 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
语义搜索MCP服务器
|
||
基于embedding向量进行语义相似度搜索
|
||
参考multi_keyword_search_server.py的实现方式
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import os
|
||
import pickle
|
||
import sys
|
||
from typing import Any, Dict, List, Optional, Union
|
||
|
||
import numpy as np
|
||
from mcp_common import (
|
||
get_allowed_directory,
|
||
load_tools_from_json,
|
||
resolve_file_path,
|
||
find_file_in_project,
|
||
create_error_response,
|
||
create_success_response,
|
||
create_initialize_response,
|
||
create_ping_response,
|
||
create_tools_list_response,
|
||
handle_mcp_streaming
|
||
)
|
||
from utils.settings import FASTAPI_URL
|
||
|
||
import requests
|
||
|
||
|
||
def encode_query_via_api(query: str, fastapi_url: Optional[str] = None) -> Optional[np.ndarray]:
    """Encode a single query string into an embedding vector via the FastAPI service.

    Args:
        query: The query text to encode.
        fastapi_url: Base URL of the embedding service; falls back to
            ``FASTAPI_URL`` from settings when not provided.

    Returns:
        A ``np.ndarray`` embedding on success, or ``None`` on any failure
        (HTTP error, unsuccessful API response, empty result, or exception).
        Callers check for ``None`` rather than catching exceptions.
    """
    if not fastapi_url:
        fastapi_url = FASTAPI_URL

    api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"

    try:
        # Call the encoding endpoint with a single-item batch.
        request_data = {
            "texts": [query],
            "batch_size": 1
        }

        response = requests.post(
            api_endpoint,
            json=request_data,
            timeout=30
        )

        if response.status_code == 200:
            result_data = response.json()
            if result_data.get("success"):
                embeddings_list = result_data.get("embeddings", [])
                if embeddings_list:
                    return np.array(embeddings_list[0])

        print(f"API编码失败: {response.status_code} - {response.text}")
        return None

    except Exception as e:
        # Broad catch is deliberate: a network/JSON failure must not crash
        # the MCP server; the caller treats None as "skip this query".
        print(f"API编码异常: {e}")
        return None
||
def _text_response(text: str) -> Dict[str, Any]:
    """Wrap a plain-text message in the standard MCP tool-result structure."""
    return {
        "content": [
            {
                "type": "text",
                "text": text
            }
        ]
    }


def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Run semantic similarity search against a local pickled embeddings file.

    Reads the embeddings file directly and computes cosine similarity between
    each query's embedding (obtained via the encoding API) and every chunk.

    Args:
        queries: A single query string or a list of query strings.
        embeddings_file: Path to the pickle file holding chunk texts and
            their embedding matrix.
        top_k: Maximum number of hits returned per query.

    Returns:
        An MCP content dict with the formatted results, or an error message
        in the same structure — this function never raises to the caller.
    """
    # Normalize a bare string into a one-element list.
    if isinstance(queries, str):
        queries = [queries]

    # Reject an empty list or a list of only-whitespace queries up front.
    if not queries or not any(q.strip() for q in queries):
        return _text_response("Error: Queries cannot be empty")

    # Trim the remaining queries and drop blank entries.
    queries = [q.strip() for q in queries if q.strip()]

    try:
        # Map the (possibly relative) path into the allowed project layout.
        resolved_embeddings_file = resolve_file_path(embeddings_file)

        # NOTE(review): pickle.load on an arbitrary file is unsafe for
        # untrusted input; acceptable only if resolve_file_path restricts
        # paths to trusted project files — confirm that assumption holds.
        with open(resolved_embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)

        # Support both the new ('chunks') and legacy ('sentences') layouts.
        if 'chunks' in embedding_data:
            chunks = embedding_data['chunks']
            chunk_embeddings = embedding_data['embeddings']
            chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
        else:
            chunks = embedding_data['sentences']
            chunk_embeddings = embedding_data['embeddings']
            chunking_strategy = 'line'

        all_results = []

        for query in queries:
            # Encode the query remotely; skip it if the service fails.
            print(f"正在为查询编码: {query}")
            query_embedding = encode_query_via_api(query)

            if query_embedding is None:
                print(f"查询编码失败: {query}")
                continue

            # Cosine similarity against every chunk embedding.
            # Assumes chunk_embeddings is an np.ndarray (has .shape) —
            # TODO confirm; a degenerate 1-D matrix yields all-zero scores.
            if len(chunk_embeddings.shape) > 1:
                cos_scores = np.dot(chunk_embeddings, query_embedding) / (
                    np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8
                )
            else:
                cos_scores = np.array([0.0] * len(chunks))

            # Take the top_k highest-scoring chunks for this query.
            top_indices = np.argsort(-cos_scores)[:top_k]

            for rank, idx in enumerate(top_indices):
                score = cos_scores[idx]
                if score > 0:  # keep only results with some relevance
                    all_results.append({
                        'query': query,
                        'rank': rank + 1,
                        'content': chunks[idx],
                        'similarity_score': float(score),
                        'file_path': embeddings_file
                    })

        if not all_results:
            return _text_response("No matching results found")

        # Merge results across all queries, best similarity first.
        all_results.sort(key=lambda x: x['similarity_score'], reverse=True)

        formatted_lines = []
        formatted_lines.append(f"Found {len(all_results)} results for {len(queries)} queries:")
        formatted_lines.append("")

        for i, result in enumerate(all_results):
            formatted_lines.append(f"#{i+1} [query: '{result['query']}'] [similarity:{result['similarity_score']:.4f}]: {result['content']}")

        formatted_output = "\n".join(formatted_lines)

        return _text_response(formatted_output)

    except FileNotFoundError:
        return _text_response(f"Error: Embeddings file not found: {embeddings_file}")
    except Exception as e:
        # Last-resort guard so tool calls always yield a structured response.
        return _text_response(f"Search error: {str(e)}")
||
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Route a single MCP JSON-RPC request to the matching handler.

    Supports initialize, ping, tools/list, and tools/call (semantic_search);
    anything else yields a -32601 error response.
    """
    try:
        request_id = request.get("id")
        method = request.get("method")
        params = request.get("params", {})

        if method == "initialize":
            return create_initialize_response(request_id, "semantic-search")

        if method == "ping":
            return create_ping_response(request_id)

        if method == "tools/list":
            # Tool definitions live in an external JSON file.
            return create_tools_list_response(
                request_id, load_tools_from_json("semantic_search_tools.json")
            )

        if method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})

            if tool_name != "semantic_search":
                return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")

            search_queries = arguments.get("queries", [])
            # Backward compatibility with the legacy single "query" argument.
            if not search_queries and "query" in arguments:
                search_queries = arguments.get("query", "")
            target_file = arguments.get("embeddings_file", "")
            limit = arguments.get("top_k", 20)

            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": semantic_search(search_queries, target_file, limit)
            }

        return create_error_response(request_id, -32601, f"Unknown method: {method}")

    except Exception as e:
        return create_error_response(request.get("id"), -32603, f"Internal error: {str(e)}")
||
async def main():
    """Main entry point: serve MCP requests over the streaming transport,
    dispatching each one through handle_request."""
    await handle_mcp_streaming(handle_request)


if __name__ == "__main__":
    asyncio.run(main())