#!/usr/bin/env python3
"""
Semantic search service.

Provides semantic search with support for high-concurrency workloads.
"""

import asyncio
import time
import logging
from typing import List, Dict, Any

import numpy as np

from .manager import get_cache_manager, get_model_manager, CacheEntry

logger = logging.getLogger(__name__)


class SemanticSearchService:
    """Semantic search service."""

    def __init__(self):
        self.cache_manager = get_cache_manager()
        self.model_manager = get_model_manager()

    async def semantic_search(
        self,
        embedding_file: str,
        query: str,
        top_k: int = 20,
        min_score: float = 0.0
    ) -> Dict[str, Any]:
        """
        Run a semantic search.

        Args:
            embedding_file: Path to the embedding.pkl file
            query: Query string
            top_k: Number of results to return
            min_score: Minimum similarity score threshold

        Returns:
            Search results
        """
        start_time = time.time()

        try:
            # Validate input parameters
            if not embedding_file:
                return {
                    "success": False,
                    "error": "embedding_file must not be empty"
                }

            if not query or not query.strip():
                return {
                    "success": False,
                    "error": "query must not be empty"
                }

            query = query.strip()

            # Load the embedding data
            cache_entry = await self.cache_manager.load_embedding_data(embedding_file)
            if cache_entry is None:
                return {
                    "success": False,
                    "error": f"Failed to load embedding file: {embedding_file}"
                }

            # Encode the query
            query_embeddings = await self.model_manager.encode_texts([query], batch_size=1)
            if len(query_embeddings) == 0:
                return {
                    "success": False,
                    "error": "Failed to encode query"
                }

            query_embedding = query_embeddings[0]

            # Compute similarity scores against all cached chunk embeddings
            similarities = self._compute_similarities(query_embedding, cache_entry.embeddings)

            # Select the top_k results above the score threshold
            top_results = self._get_top_results(
                similarities,
                cache_entry.chunks,
                top_k,
                min_score
            )

            processing_time = time.time() - start_time

            return {
                "success": True,
                "query": query,
                "embedding_file": embedding_file,
                "top_k": top_k,
                "processing_time": processing_time,
                "chunking_strategy": cache_entry.chunking_strategy,
                "total_chunks": len(cache_entry.chunks),
                "results": top_results,
                "cache_stats": {
                    "access_count": cache_entry.access_count,
                    "load_time": cache_entry.load_time
                }
            }

        except Exception as e:
            logger.error(f"Semantic search failed: {embedding_file}, {query}, {e}")
            return {
                "success": False,
                "error": f"Search failed: {str(e)}"
            }

    def _compute_similarities(self, query_embedding: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
        """Compute cosine similarity scores: sim_i = (e_i · q) / (||e_i|| * ||q||)."""
        try:
            # Make sure the embeddings matrix is 2-D
            if embeddings.ndim == 1:
                embeddings = embeddings.reshape(1, -1)

            # Vector norms for cosine similarity
            query_norm = np.linalg.norm(query_embedding)
            embeddings_norm = np.linalg.norm(embeddings, axis=1)

            # Guard against division by zero
            if query_norm == 0:
                query_norm = 1e-8
            embeddings_norm = np.where(embeddings_norm == 0, 1e-8, embeddings_norm)

            # Cosine similarity of the query against every chunk embedding
            similarities = np.dot(embeddings, query_embedding) / (embeddings_norm * query_norm)

            # Clamp to the valid cosine range to absorb floating-point drift
            similarities = np.clip(similarities, -1.0, 1.0)

            return similarities

        except Exception as e:
            logger.error(f"Similarity computation failed: {e}")
            return np.array([])

    def _get_top_results(
        self,
        similarities: np.ndarray,
        chunks: List[str],
        top_k: int,
        min_score: float
    ) -> List[Dict[str, Any]]:
        """Select the top_k results at or above min_score."""
        try:
            if len(similarities) == 0:
                return []

            # Indices of the top_k scores, sorted in descending order
            top_indices = np.argsort(-similarities)[:top_k]

            results = []
            for rank, idx in enumerate(top_indices):
                score = similarities[idx]

                # Scores are sorted descending, so once one falls below the
                # threshold, all remaining scores do too
                if score < min_score:
                    break

                chunk = chunks[idx]

                # Build the result entry
                result = {
                    "rank": rank + 1,
                    "score": float(score),
                    "content": chunk,
                    "content_preview": self._get_preview(chunk)
                }

                results.append(result)

            return results

        except Exception as e:
            logger.error(f"Failed to select top_k results: {e}")
            return []

    def _get_preview(self, content: str, max_length: int = 200) -> str:
        """Build a short preview of the chunk content."""
        if not content:
            return ""

        # Collapse newlines into spaces
        preview = content.replace('\n', ' ').strip()

        # Truncate long previews
        if len(preview) > max_length:
            preview = preview[:max_length] + "..."

        return preview

    async def batch_search(
        self,
        requests: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Run multiple searches as a batch.

        Args:
            requests: List of search requests, each containing embedding_file,
                query, top_k, and optionally min_score

        Returns:
            List of search results
        """
        if not requests:
            return []

        # Run the searches concurrently
        tasks = []
        for req in requests:
            embedding_file = req.get("embedding_file", "")
            query = req.get("query", "")
            top_k = req.get("top_k", 20)
            min_score = req.get("min_score", 0.0)

            task = self.semantic_search(embedding_file, query, top_k, min_score)
            tasks.append(task)

        # Wait for all tasks to finish
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Convert raised exceptions into error results
        processed_results = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                processed_results.append({
                    "success": False,
                    "error": f"Search raised an exception: {str(result)}",
                    "request_index": i
                })
            else:
                result["request_index"] = i
                processed_results.append(result)

        return processed_results

    def get_service_stats(self) -> Dict[str, Any]:
        """Return service statistics."""
        return {
            "cache_manager_stats": self.cache_manager.get_cache_stats(),
            "model_manager_info": self.model_manager.get_model_info()
        }


# Global singleton instance
_search_service = None


def get_search_service() -> SemanticSearchService:
    """Return the shared search service instance."""
    global _search_service
    if _search_service is None:
        _search_service = SemanticSearchService()
    return _search_service
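

# ---------------------------------------------------------------------------
# Minimal usage sketch, illustrative only. Because this module uses a relative
# import, it has to run as part of its package (python -m <package>.<module>).
# The embedding file path and query below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo() -> None:
        service = get_search_service()
        result = await service.semantic_search(
            embedding_file="data/embedding.pkl",  # hypothetical path
            query="cache eviction policy",        # hypothetical query
            top_k=5,
            min_score=0.2,
        )
        if result["success"]:
            for item in result["results"]:
                print(item["rank"], round(item["score"], 3), item["content_preview"])
        else:
            print("Search failed:", result["error"])

    asyncio.run(_demo())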