qwen_agent/mcp/semantic_search_server.py
#!/usr/bin/env python3
"""
Semantic search MCP server.

Performs semantic similarity search over precomputed embedding vectors.
Follows the same structure as multi_keyword_search_server.py.
"""
import asyncio
import pickle
import sys
from typing import Any, Dict, List, Optional, Union

import numpy as np
import requests

from mcp_common import (
    load_tools_from_json,
    resolve_file_path,
    create_error_response,
    create_initialize_response,
    create_ping_response,
    create_tools_list_response,
    handle_mcp_streaming,
)
from utils.settings import FASTAPI_URL
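
# Expected layout of the embeddings pickle, inferred from the loading code in
# semantic_search() below (an assumption about the producer, not a guaranteed
# schema):
#   {
#       "chunks": List[str],            # text chunks (newer files)
#       "sentences": List[str],         # one line per entry (older files)
#       "embeddings": np.ndarray,       # shape (num_chunks, dim)
#       "chunking_strategy": str,       # optional, e.g. "line"
#   }
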
def encode_query_via_api(query: str, fastapi_url: Optional[str] = None) -> Optional[np.ndarray]:
    """Encode a single query via the embedding API."""
    if not fastapi_url:
        fastapi_url = FASTAPI_URL
    api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"
    try:
        # Call the encoding endpoint with a single-item batch.
        request_data = {
            "texts": [query],
            "batch_size": 1
        }
        response = requests.post(
            api_endpoint,
            json=request_data,
            timeout=30
        )
        if response.status_code == 200:
            result_data = response.json()
            if result_data.get("success"):
                embeddings_list = result_data.get("embeddings", [])
                if embeddings_list:
                    return np.array(embeddings_list[0])
        # Log to stderr so diagnostics don't mix into the JSON-RPC stream
        # that this server emits on stdout.
        print(f"API encoding failed: {response.status_code} - {response.text}", file=sys.stderr)
        return None
    except Exception as e:
        print(f"API encoding error: {e}", file=sys.stderr)
        return None
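
# Illustrative use (assumes the FastAPI embedding service is reachable at
# FASTAPI_URL; the vector dimension depends on the deployed model):
#   vec = encode_query_via_api("how is the cache invalidated?")
#   if vec is not None:
#       print(vec.shape)  # e.g. (768,)
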
def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Run a semantic search: load the local embeddings file and rank chunks by cosine similarity."""
    # Normalize the query input to a list.
    if isinstance(queries, str):
        queries = [queries]
    # Validate the query list.
    if not queries or not any(q.strip() for q in queries):
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: Queries cannot be empty"
                }
            ]
        }
    # Drop empty queries.
    queries = [q.strip() for q in queries if q.strip()]
    try:
        # Resolve the embeddings file path.
        resolved_embeddings_file = resolve_file_path(embeddings_file)
        # Load the embeddings file.
        with open(resolved_embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)
        # Support both the old and the new on-disk layout.
        if 'chunks' in embedding_data:
            # New layout: arbitrary text chunks.
            chunks = embedding_data['chunks']
            chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
        else:
            # Old layout: one sentence (line) per entry.
            chunks = embedding_data['sentences']
            chunking_strategy = 'line'
        # Coerce to an ndarray so .ndim and the vectorized math below are safe.
        chunk_embeddings = np.asarray(embedding_data['embeddings'])
        all_results = []
        # Process each query.
        for query in queries:
            # Encode the query via the API.
            print(f"Encoding query: {query}", file=sys.stderr)
            query_embedding = encode_query_via_api(query)
            if query_embedding is None:
                print(f"Failed to encode query: {query}", file=sys.stderr)
                continue
            # Cosine similarity between the query and every chunk embedding.
            if chunk_embeddings.ndim > 1:
                cos_scores = np.dot(chunk_embeddings, query_embedding) / (
                    np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8
                )
            else:
                cos_scores = np.zeros(len(chunks))
            # Take the top_k highest-scoring chunks.
            top_indices = np.argsort(-cos_scores)[:top_k]
            for rank, idx in enumerate(top_indices):
                score = cos_scores[idx]
                if score > 0:  # keep only results with some relevance
                    all_results.append({
                        'query': query,
                        'rank': rank + 1,
                        'content': chunks[idx],
                        'similarity_score': float(score),
                        'file_path': embeddings_file
                    })
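        # Note on the scoring above: for each chunk c_i and query q,
        #   cos(q, c_i) = (c_i . q) / (||c_i|| * ||q|| + 1e-8)
        # where the 1e-8 guards against division by zero for all-zero vectors.
        # Scores fall in [-1, 1]; the `score > 0` filter keeps only chunks
        # with positive alignment to the query.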
        if not all_results:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": "No matching results found"
                    }
                ]
            }
        # Sort all results by similarity score, highest first.
        all_results.sort(key=lambda x: x['similarity_score'], reverse=True)
        # Format the output.
        formatted_lines = [f"Found {len(all_results)} results for {len(queries)} queries:", ""]
        for i, result in enumerate(all_results):
            formatted_lines.append(
                f"#{i+1} [query: '{result['query']}'] "
                f"[similarity:{result['similarity_score']:.4f}]: {result['content']}"
            )
        formatted_output = "\n".join(formatted_lines)
        return {
            "content": [
                {
                    "type": "text",
                    "text": formatted_output
                }
            ]
        }
    except FileNotFoundError:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error: Embeddings file not found: {embeddings_file}"
                }
            ]
        }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Search error: {str(e)}"
                }
            ]
        }
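
# Illustrative tools/call request in the shape handle_request() below expects
# (the field values here are made up for the example):
#   {
#       "jsonrpc": "2.0",
#       "id": 1,
#       "method": "tools/call",
#       "params": {
#           "name": "semantic_search",
#           "arguments": {
#               "queries": ["error handling", "retry logic"],
#               "embeddings_file": "docs_embeddings.pkl",
#               "top_k": 10
#           }
#       }
#   }
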
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle an MCP request."""
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")
        if method == "initialize":
            return create_initialize_response(request_id, "semantic-search")
        elif method == "ping":
            return create_ping_response(request_id)
        elif method == "tools/list":
            # Load the tool definitions from the JSON file.
            tools = load_tools_from_json("semantic_search_tools.json")
            return create_tools_list_response(request_id, tools)
        elif method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})
            if tool_name == "semantic_search":
                queries = arguments.get("queries", [])
                # Backwards compatibility with the old single "query" parameter.
                if not queries and "query" in arguments:
                    queries = arguments.get("query", "")
                embeddings_file = arguments.get("embeddings_file", "")
                top_k = arguments.get("top_k", 20)
                result = semantic_search(queries, embeddings_file, top_k)
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": result
                }
            else:
                return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")
        else:
            return create_error_response(request_id, -32601, f"Unknown method: {method}")
    except Exception as e:
        return create_error_response(request.get("id"), -32603, f"Internal error: {str(e)}")
async def main():
    """Main entry point."""
    await handle_mcp_streaming(handle_request)


if __name__ == "__main__":
    asyncio.run(main())
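
# A sketch of how this server is typically exercised (assumption: the MCP host
# launches the process and exchanges JSON-RPC messages via handle_mcp_streaming):
#   python qwen_agent/mcp/semantic_search_server.py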