qwen_agent/mcp/semantic_search_server.py
朱潮 425f3c5bb4 chore: replace Chinese comments and log messages with English
Convert all Chinese comments, docstrings, logger/print output,
HTTPException detail messages, and API response messages to English
across the entire codebase. Functional zh/ja localized strings
(e.g. prompt templates, timezone display names, date formats) are
preserved as-is.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-30 19:45:35 +08:00

249 lines
7.7 KiB
Python

#!/usr/bin/env python3
"""
Semantic search MCP server.
Performs semantic similarity search based on embedding vectors.
References the implementation approach in multi_keyword_search_server.py.
"""
import asyncio
import json
import os
import pickle
import sys
from typing import Any, Dict, List, Optional, Union
import numpy as np
from mcp_common import (
get_allowed_directory,
load_tools_from_json,
resolve_file_path,
find_file_in_project,
create_error_response,
create_success_response,
create_initialize_response,
create_ping_response,
create_tools_list_response,
handle_mcp_streaming
)
from utils.settings import FASTAPI_URL
import requests
def encode_query_via_api(query: str, fastapi_url: Optional[str] = None) -> Optional[np.ndarray]:
    """Encode a single query string into an embedding vector through the HTTP API.

    Args:
        query: Text to encode.
        fastapi_url: Base URL of the embedding service. When empty or
            ``None`` the module-level ``FASTAPI_URL`` setting is used.

    Returns:
        The first embedding returned by the service as a NumPy array, or
        ``None`` when the request raises, the status code is not 200, the
        response does not report ``success``, or no embedding is returned.
    """
    if not fastapi_url:
        fastapi_url = FASTAPI_URL
    api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"

    try:
        # Call the encoding endpoint with a one-element batch.
        response = requests.post(
            api_endpoint,
            json={"texts": [query], "batch_size": 1},
            timeout=30,
        )
        if response.status_code == 200:
            result_data = response.json()
            if result_data.get("success"):
                embeddings_list = result_data.get("embeddings", [])
                if embeddings_list:
                    return np.array(embeddings_list[0])
        # Diagnostics go to stderr. NOTE(review): this module is served through
        # handle_mcp_streaming; assuming that speaks MCP over stdio, anything
        # printed to stdout would be interleaved with protocol messages.
        print(f"API encoding failed: {response.status_code} - {response.text}", file=sys.stderr)
        return None
    except Exception as e:
        # Best-effort boundary: any transport or decoding failure is reported
        # and turned into ``None`` so the caller can skip this query.
        print(f"API encoding error: {e}", file=sys.stderr)
        return None
def _text_response(text: str) -> Dict[str, Any]:
    """Wrap *text* in a result dict holding a single text content item."""
    return {"content": [{"type": "text", "text": text}]}


def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Run semantic search by reading the local embeddings file and computing similarity.

    Each query is encoded with ``encode_query_via_api`` and compared against
    every stored chunk embedding using cosine similarity. Results from all
    queries are merged and ordered by descending similarity.

    Args:
        queries: A single query string or a list of query strings. Blank
            strings and non-string entries are ignored.
        embeddings_file: Path of the pickled embeddings file; it is passed
            through ``resolve_file_path`` before being opened.
        top_k: Maximum number of chunks kept per query. Must be a positive
            integer; ``None`` keeps every chunk.

    Returns:
        A dict of the form ``{"content": [{"type": "text", "text": ...}]}``
        whose text is the formatted result list, a "no results" message, or
        an error description. Failures are reported in the text rather than
        raised.
    """
    # Normalize the query input to a list of non-blank strings.
    if isinstance(queries, str):
        queries = [queries]
    elif not isinstance(queries, (list, tuple)):
        queries = []
    # Non-string entries (e.g. a JSON null) are dropped instead of crashing on .strip().
    queries = [q.strip() for q in queries if isinstance(q, str) and q.strip()]
    if not queries:
        return _text_response("Error: Queries cannot be empty")

    # A negative top_k would silently slice "all but the last k" results below.
    if top_k is not None:
        try:
            top_k = int(top_k)
        except (TypeError, ValueError):
            return _text_response("Error: top_k must be a positive integer")
        if top_k < 1:
            return _text_response("Error: top_k must be a positive integer")

    try:
        # Resolve embeddings file path
        resolved_embeddings_file = resolve_file_path(embeddings_file)

        # NOTE(security): pickle.load can execute arbitrary code contained in
        # the file, and the path originates from the tool-call arguments.
        # Only files produced by a trusted pipeline should be reachable here.
        with open(resolved_embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)

        # Support both the new layout ('chunks') and the old one ('sentences').
        if 'chunks' in embedding_data:
            chunks = embedding_data['chunks']
        else:
            chunks = embedding_data['sentences']
        # asarray lets embeddings stored as nested lists work as well as ndarrays.
        chunk_embeddings = np.asarray(embedding_data['embeddings'])

        # The chunk norms do not depend on the query, so compute them once.
        has_vectors = chunk_embeddings.ndim > 1
        chunk_norms = np.linalg.norm(chunk_embeddings, axis=1) if has_vectors else None

        all_results = []
        encoded_any = False
        for query in queries:
            # Diagnostics go to stderr. NOTE(review): assuming the server
            # speaks MCP over stdio, stdout must carry protocol messages only.
            print(f"Encoding query: {query}", file=sys.stderr)
            query_embedding = encode_query_via_api(query)
            if query_embedding is None:
                print(f"Query encoding failed: {query}", file=sys.stderr)
                continue
            encoded_any = True

            # Cosine similarity; the epsilon guards against zero-norm vectors.
            if has_vectors:
                cos_scores = np.dot(chunk_embeddings, query_embedding) / (
                    chunk_norms * np.linalg.norm(query_embedding) + 1e-8
                )
            else:
                cos_scores = np.zeros(len(chunks))

            # Indices of the best-scoring chunks, highest score first.
            top_indices = np.argsort(-cos_scores)[:top_k]
            for rank, idx in enumerate(top_indices, start=1):
                score = cos_scores[idx]
                if score > 0:  # Only include results with some relevance
                    all_results.append({
                        'query': query,
                        'rank': rank,
                        'content': chunks[idx],
                        'similarity_score': float(score),
                        'file_path': embeddings_file
                    })

        if not encoded_any:
            # Every query failed to encode: report that instead of claiming
            # that the search ran and found nothing.
            return _text_response("Error: Failed to encode queries via the embedding API")
        if not all_results:
            return _text_response("No matching results found")

        # Merge the per-query results into one list ordered by similarity.
        all_results.sort(key=lambda x: x['similarity_score'], reverse=True)

        formatted_lines = [f"Found {len(all_results)} results for {len(queries)} queries:", ""]
        formatted_lines.extend(
            f"#{i} [query: '{result['query']}'] [similarity:{result['similarity_score']:.4f}]: {result['content']}"
            for i, result in enumerate(all_results, start=1)
        )
        return _text_response("\n".join(formatted_lines))
    except FileNotFoundError:
        return _text_response(f"Error: Embeddings file not found: {embeddings_file}")
    except Exception as e:
        return _text_response(f"Search error: {str(e)}")
async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch one MCP request dict to the matching handler.

    Recognized methods are ``initialize``, ``ping``, ``tools/list`` and
    ``tools/call`` (for the ``semantic_search`` tool). Unknown methods and
    unknown tools produce a -32601 error response; any exception raised
    while handling produces a -32603 error response.
    """
    try:
        rpc_method = request.get("method")
        rpc_params = request.get("params", {})
        rpc_id = request.get("id")

        if rpc_method == "initialize":
            return create_initialize_response(rpc_id, "semantic-search")

        if rpc_method == "ping":
            return create_ping_response(rpc_id)

        if rpc_method == "tools/list":
            # Tool definitions are loaded from the JSON file on each call.
            tool_definitions = load_tools_from_json("semantic_search_tools.json")
            return create_tools_list_response(rpc_id, tool_definitions)

        if rpc_method != "tools/call":
            return create_error_response(rpc_id, -32601, f"Unknown method: {rpc_method}")

        tool_name = rpc_params.get("name")
        tool_args = rpc_params.get("arguments", {})
        if tool_name != "semantic_search":
            return create_error_response(rpc_id, -32601, f"Unknown tool: {tool_name}")

        search_queries = tool_args.get("queries", [])
        # Fall back to the legacy single-value "query" argument.
        if not search_queries and "query" in tool_args:
            search_queries = tool_args.get("query", "")
        search_result = semantic_search(
            search_queries,
            tool_args.get("embeddings_file", ""),
            tool_args.get("top_k", 20),
        )
        return {
            "jsonrpc": "2.0",
            "id": rpc_id,
            "result": search_result
        }
    except Exception as e:
        return create_error_response(request.get("id"), -32603, f"Internal error: {str(e)}")
async def main():
    """Hand ``handle_request`` to ``handle_mcp_streaming`` and wait for it to finish."""
    serving = handle_mcp_streaming(handle_request)
    await serving
# Run the server only when executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())