qwen_agent/embedding/search_service.py
#!/usr/bin/env python3
"""
语义检索服务
支持高并发的语义搜索功能
"""
import asyncio
import logging
import time
from typing import Any, Dict, List, Optional

import numpy as np

from .manager import CacheEntry, get_cache_manager, get_model_manager

logger = logging.getLogger(__name__)


class SemanticSearchService:
    """Semantic search service."""

    def __init__(self):
        self.cache_manager = get_cache_manager()
        self.model_manager = get_model_manager()

    async def semantic_search(
        self,
        embedding_file: str,
        query: str,
        top_k: int = 20,
        min_score: float = 0.0
    ) -> Dict[str, Any]:
"""
执行语义搜索
Args:
embedding_file: embedding.pkl 文件路径
query: 查询关键词
top_k: 返回结果数量
min_score: 最小相似度阈值
Returns:
搜索结果
"""
        start_time = time.time()
        try:
            # Validate input arguments.
            if not embedding_file:
                return {
                    "success": False,
                    "error": "embedding_file must not be empty"
                }
            if not query or not query.strip():
                return {
                    "success": False,
                    "error": "query must not be empty"
                }
            query = query.strip()

            # Load the embedding data.
            cache_entry = await self.cache_manager.load_embedding_data(embedding_file)
            if cache_entry is None:
                return {
                    "success": False,
                    "error": f"failed to load embedding file: {embedding_file}"
                }

            # Encode the query.
            query_embeddings = await self.model_manager.encode_texts([query], batch_size=1)
            if len(query_embeddings) == 0:
                return {
                    "success": False,
                    "error": "failed to encode query"
                }
            query_embedding = query_embeddings[0]

            # Compute similarity scores against all chunk embeddings.
            similarities = self._compute_similarities(query_embedding, cache_entry.embeddings)

            # Collect the top_k results.
            top_results = self._get_top_results(
                similarities,
                cache_entry.chunks,
                top_k,
                min_score
            )

            processing_time = time.time() - start_time
            return {
                "success": True,
                "query": query,
                "embedding_file": embedding_file,
                "top_k": top_k,
                "processing_time": processing_time,
                "chunking_strategy": cache_entry.chunking_strategy,
                "total_chunks": len(cache_entry.chunks),
                "results": top_results,
                "cache_stats": {
                    "access_count": cache_entry.access_count,
                    "load_time": cache_entry.load_time
                }
            }
        except Exception as e:
            logger.error(f"semantic search failed: {embedding_file}, {query}, {e}")
            return {
                "success": False,
                "error": f"search failed: {str(e)}"
            }
    def _compute_similarities(self, query_embedding: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
        """Compute cosine similarity scores between the query and all chunk embeddings."""
        try:
            # Ensure the embeddings matrix is 2-D: (num_chunks, dim).
            if embeddings.ndim == 1:
                embeddings = embeddings.reshape(1, -1)

            # Cosine similarity: dot(a, b) / (||a|| * ||b||).
            query_norm = np.linalg.norm(query_embedding)
            embeddings_norm = np.linalg.norm(embeddings, axis=1)

            # Guard against division by zero.
            if query_norm == 0:
                query_norm = 1e-8
            embeddings_norm = np.where(embeddings_norm == 0, 1e-8, embeddings_norm)

            similarities = np.dot(embeddings, query_embedding) / (embeddings_norm * query_norm)

            # Clamp to the valid cosine range to absorb floating-point drift.
            similarities = np.clip(similarities, -1.0, 1.0)
            return similarities
        except Exception as e:
            logger.error(f"similarity computation failed: {e}")
            return np.array([])
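
    # Worked example (illustrative only, not part of the service): for stored
    # rows e1 = [1, 0] and e2 = [0, 1] and a query vector [1, 0], all norms
    # are 1, so np.dot(embeddings, query) -> [1.0, 0.0]: the first chunk
    # matches the query exactly and the second is orthogonal to it.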
    def _get_top_results(
        self,
        similarities: np.ndarray,
        chunks: List[str],
        top_k: int,
        min_score: float
    ) -> List[Dict[str, Any]]:
        """Select the top_k highest-scoring chunks at or above min_score."""
        try:
            if len(similarities) == 0:
                return []

            # Indices of the top_k scores, in descending score order.
            top_indices = np.argsort(-similarities)[:top_k]

            results = []
            for rank, idx in enumerate(top_indices):
                score = similarities[idx]
                # Scores are sorted descending, so once one falls below the
                # threshold every remaining one does too.
                if score < min_score:
                    break
                chunk = chunks[idx]
                result = {
                    "rank": rank + 1,
                    "score": float(score),
                    "content": chunk,
                    "content_preview": self._get_preview(chunk)
                }
                results.append(result)
            return results
        except Exception as e:
            logger.error(f"failed to collect top_k results: {e}")
            return []
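
    # Worked example (illustrative only): with similarities [0.2, 0.9, 0.5]
    # and top_k = 2, np.argsort(-similarities)[:2] -> [1, 2], so the chunks
    # at indices 1 and 2 come back as ranks 1 and 2 (assuming min_score <= 0.5).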
    def _get_preview(self, content: str, max_length: int = 200) -> str:
        """Build a short single-line preview of a chunk."""
        if not content:
            return ""
        # Collapse newlines and trim surrounding whitespace.
        preview = content.replace('\n', ' ').strip()
        # Truncate overlong previews.
        if len(preview) > max_length:
            preview = preview[:max_length] + "..."
        return preview
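
    # Worked example (illustrative only): _get_preview("line one\nline two")
    # returns "line one line two"; content longer than max_length is cut to
    # max_length characters with a trailing "...".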
    async def batch_search(
        self,
        requests: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Run multiple searches concurrently.

        Args:
            requests: List of search requests, each containing
                embedding_file, query, and optionally top_k and min_score.

        Returns:
            List of search results, one per request.
        """
        if not requests:
            return []

        # Launch all searches concurrently.
        tasks = []
        for req in requests:
            embedding_file = req.get("embedding_file", "")
            query = req.get("query", "")
            top_k = req.get("top_k", 20)
            min_score = req.get("min_score", 0.0)
            task = self.semantic_search(embedding_file, query, top_k, min_score)
            tasks.append(task)

        # Wait for all tasks; capture exceptions instead of letting them raise.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Convert exceptions into error results.
        processed_results = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                processed_results.append({
                    "success": False,
                    "error": f"search raised an exception: {str(result)}",
                    "request_index": i
                })
            else:
                result["request_index"] = i
                processed_results.append(result)
        return processed_results
    def get_service_stats(self) -> Dict[str, Any]:
        """Return service-level statistics."""
        return {
            "cache_manager_stats": self.cache_manager.get_cache_stats(),
            "model_manager_info": self.model_manager.get_model_info()
        }

# Global singleton instance.
_search_service: Optional[SemanticSearchService] = None


def get_search_service() -> SemanticSearchService:
    """Return the process-wide search service instance."""
    global _search_service
    if _search_service is None:
        _search_service = SemanticSearchService()
    return _search_service
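

# Minimal usage sketch (illustrative only, not part of the original module):
# runs one search and one concurrent batch against a hypothetical embedding
# file. The path "./embeddings/example.pkl" and the queries are placeholders;
# because this module uses a relative import, run it as
# `python -m qwen_agent.embedding.search_service`.
if __name__ == "__main__":

    async def _demo() -> None:
        service = get_search_service()

        # Single query.
        result = await service.semantic_search(
            embedding_file="./embeddings/example.pkl",  # hypothetical path
            query="how to configure logging",
            top_k=5,
            min_score=0.3,
        )
        print(result.get("success"), len(result.get("results", [])))

        # Concurrent batch of queries via batch_search.
        batch = await service.batch_search([
            {"embedding_file": "./embeddings/example.pkl", "query": "install steps"},
            {"embedding_file": "./embeddings/example.pkl", "query": "api reference", "top_k": 3},
        ])
        for item in batch:
            print(item["request_index"], item.get("success"))

    asyncio.run(_demo())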