200 lines
6.2 KiB
Python
200 lines
6.2 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Semantic search MCP server.

Performs semantic similarity search based on embedding vectors.
Modeled on the implementation approach of multi_keyword_search_server.py.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import os
|
|
import pickle
|
|
import sys
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
import numpy as np
|
|
from sentence_transformers import SentenceTransformer, util
|
|
from mcp_common import (
|
|
get_allowed_directory,
|
|
load_tools_from_json,
|
|
resolve_file_path,
|
|
find_file_in_project,
|
|
create_error_response,
|
|
create_success_response,
|
|
create_initialize_response,
|
|
create_ping_response,
|
|
create_tools_list_response,
|
|
handle_mcp_streaming
|
|
)
|
|
|
|
import requests
|
|
|
|
|
|
def _text_response(text: str) -> Dict[str, Any]:
    """Wrap *text* in the single-item text content payload returned to the client."""
    return {"content": [{"type": "text", "text": text}]}


def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Run a semantic search for each query by calling the FastAPI search service.

    Args:
        queries: A single query string or a list of query strings. Entries
            that are not strings, or that are blank after stripping, are
            ignored.
        embeddings_file: Path of the embeddings file. It is passed through
            ``resolve_file_path`` before being sent to the service; the
            unresolved value is what is recorded on each result.
        top_k: Number of hits requested from the service per query.

    Returns:
        A dict of the form ``{"content": [{"type": "text", "text": ...}]}``.
        The text is one of: the formatted result listing (all queries merged
        and sorted by descending similarity score), an "empty queries" error,
        a summary of per-query failures when nothing was found, a
        "no matching results" notice, or the message of an exception raised
        while talking to the service.
    """
    # Normalize the input to a list of non-blank, stripped query strings.
    if isinstance(queries, str):
        queries = [queries]
    cleaned_queries = [
        q.strip() for q in (queries or ()) if isinstance(q, str) and q.strip()
    ]
    if not cleaned_queries:
        return _text_response("Error: Queries cannot be empty")

    try:
        # Base URL of the FastAPI service; a trailing slash is dropped so the
        # endpoint never contains "//".
        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001').rstrip('/')
        api_endpoint = f"{fastapi_url}/api/v1/semantic-search"
        resolved_embeddings_file = resolve_file_path(embeddings_file)

        all_results = []
        failures = []
        for query in cleaned_queries:
            response = requests.post(
                api_endpoint,
                json={
                    "embedding_file": resolved_embeddings_file,
                    "query": query,
                    "top_k": top_k,
                    "min_score": 0.0,
                },
                timeout=30,
            )

            # Diagnostics go to stderr rather than stdout.
            # NOTE(review): stdout is presumably the JSON-RPC channel used by
            # handle_mcp_streaming -- confirm in mcp_common.
            if response.status_code != 200:
                print(f"API 调用失败: {response.status_code} - {response.text}", file=sys.stderr)
                failures.append(f"'{query}': HTTP {response.status_code}")
                continue

            result_data = response.json()
            if not result_data.get("success"):
                print(f"搜索失败: {result_data.get('error', '未知错误')}", file=sys.stderr)
                failures.append(f"'{query}': {result_data.get('error', 'unknown error')}")
                continue

            all_results.extend(
                {
                    'query': query,
                    'rank': res["rank"],
                    'content': res["content"],
                    'similarity_score': res["score"],
                    'file_path': embeddings_file,
                }
                for res in result_data.get("results", [])
            )

        if not all_results:
            # Distinguish "the service failed" from "nothing matched" so the
            # caller is not told a failed search simply had no hits.
            if failures:
                return _text_response("Search failed: " + "; ".join(failures))
            return _text_response("No matching results found")

        # Merge the hits of all queries into one ranking, best score first.
        all_results.sort(key=lambda x: x['similarity_score'], reverse=True)

        formatted_lines = [
            f"Found {len(all_results)} results for {len(cleaned_queries)} queries:",
            "",
        ]
        formatted_lines.extend(
            f"#{i} [query: '{result['query']}'] [similarity:{result['similarity_score']:.4f}]: {result['content']}"
            for i, result in enumerate(all_results, start=1)
        )
        return _text_response("\n".join(formatted_lines))

    except requests.exceptions.RequestException as e:
        return _text_response(f"API request failed: {str(e)}")
    except Exception as e:
        # Boundary handler: any other failure (path resolution, malformed
        # service payload, ...) is reported to the client as text.
        return _text_response(f"Search error: {str(e)}")
|
|
|
|
|
|
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch one MCP request to the matching handler and return its reply.

    Supported methods are ``initialize``, ``ping``, ``tools/list`` and
    ``tools/call`` (for the ``semantic_search`` tool only). An unknown method
    or tool yields an error response with code -32601; any exception raised
    while handling the request yields an error response with code -32603.

    Args:
        request: The decoded request; its ``method``, ``params`` and ``id``
            entries are read.

    Returns:
        The response dict to send back to the client.
    """
    try:
        rpc_method = request.get("method")
        rpc_params = request.get("params", {})
        rpc_id = request.get("id")

        if rpc_method == "initialize":
            return create_initialize_response(rpc_id, "semantic-search")

        if rpc_method == "ping":
            return create_ping_response(rpc_id)

        if rpc_method == "tools/list":
            # Tool definitions are loaded from a JSON file.
            tool_definitions = load_tools_from_json("semantic_search_tools.json")
            return create_tools_list_response(rpc_id, tool_definitions)

        if rpc_method != "tools/call":
            return create_error_response(rpc_id, -32601, f"Unknown method: {rpc_method}")

        tool_name = rpc_params.get("name")
        tool_args = rpc_params.get("arguments", {})

        if tool_name != "semantic_search":
            return create_error_response(rpc_id, -32601, f"Unknown tool: {tool_name}")

        search_queries = tool_args.get("queries", [])
        # Fall back to the legacy single "query" argument.
        if not search_queries and "query" in tool_args:
            search_queries = tool_args.get("query", "")

        search_result = semantic_search(
            search_queries,
            tool_args.get("embeddings_file", ""),
            tool_args.get("top_k", 20),
        )
        return {
            "jsonrpc": "2.0",
            "id": rpc_id,
            "result": search_result,
        }

    except Exception as e:
        return create_error_response(request.get("id"), -32603, f"Internal error: {str(e)}")
|
|
|
|
|
|
async def main():
    """Start the MCP streaming loop with ``handle_request`` as the request handler."""
    streaming_loop = handle_mcp_streaming(handle_request)
    await streaming_loop
|
|
|
|
|
|
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    asyncio.run(main())
|