#!/usr/bin/env python3
"""Semantic search service.

Supports high-concurrency semantic search over precomputed chunk embeddings.
"""
import asyncio
import time
import logging
from typing import List, Dict, Any, Optional, Tuple

import numpy as np

from .manager import get_cache_manager, get_model_manager, CacheEntry

logger = logging.getLogger(__name__)


class SemanticSearchService:
    """Semantic retrieval service.

    Loads cached chunk embeddings via the cache manager, encodes queries via
    the model manager, and ranks chunks by cosine similarity.
    """

    def __init__(self):
        # Shared singletons; their lifecycle is owned by the manager module.
        self.cache_manager = get_cache_manager()
        self.model_manager = get_model_manager()

    async def semantic_search(
        self,
        embedding_file: str,
        query: str,
        top_k: int = 20,
        min_score: float = 0.0
    ) -> Dict[str, Any]:
        """Run a semantic search against one embedding file.

        Args:
            embedding_file: Path to the embedding .pkl file.
            query: Search query text.
            top_k: Maximum number of results to return.
            min_score: Minimum cosine-similarity threshold.

        Returns:
            A dict with ``"success": True`` and the ranked results, or
            ``"success": False`` plus an ``"error"`` message. Never raises.
        """
        start_time = time.time()

        try:
            # Validate inputs before doing any expensive work.
            if not embedding_file:
                return {
                    "success": False,
                    "error": "embedding_file 参数不能为空"
                }
            if not query or not query.strip():
                return {
                    "success": False,
                    "error": "query 参数不能为空"
                }
            query = query.strip()

            # Load (possibly cached) embedding data.
            cache_entry = await self.cache_manager.load_embedding_data(embedding_file)
            if cache_entry is None:
                return {
                    "success": False,
                    "error": f"无法加载 embedding 文件: {embedding_file}"
                }

            # Encode the query into a single embedding vector.
            query_embeddings = await self.model_manager.encode_texts([query], batch_size=1)
            if len(query_embeddings) == 0:
                return {
                    "success": False,
                    "error": "查询编码失败"
                }
            query_embedding = query_embeddings[0]

            # Score every chunk against the query, then keep the best ones.
            similarities = self._compute_similarities(query_embedding, cache_entry.embeddings)
            top_results = self._get_top_results(
                similarities, cache_entry.chunks, top_k, min_score
            )

            processing_time = time.time() - start_time
            return {
                "success": True,
                "query": query,
                "embedding_file": embedding_file,
                "top_k": top_k,
                "processing_time": processing_time,
                "chunking_strategy": cache_entry.chunking_strategy,
                "total_chunks": len(cache_entry.chunks),
                "results": top_results,
                "cache_stats": {
                    "access_count": cache_entry.access_count,
                    "load_time": cache_entry.load_time
                }
            }

        except Exception as e:
            # Service boundary: convert any failure into an error payload.
            # Lazy %-style args avoid formatting when the level is disabled.
            logger.error("语义搜索失败: %s, %s, %s", embedding_file, query, e)
            return {
                "success": False,
                "error": f"搜索失败: {str(e)}"
            }

    def _compute_similarities(self, query_embedding: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
        """Return cosine similarities between the query and each row of *embeddings*.

        Returns an empty array on failure so callers degrade to "no results".
        """
        try:
            # A single stored vector may arrive 1-D; normalize to (1, dim).
            if embeddings.ndim == 1:
                embeddings = embeddings.reshape(1, -1)

            query_norm = np.linalg.norm(query_embedding)
            embeddings_norm = np.linalg.norm(embeddings, axis=1)

            # Guard against division by zero for all-zero vectors.
            if query_norm == 0:
                query_norm = 1e-8
            embeddings_norm = np.where(embeddings_norm == 0, 1e-8, embeddings_norm)

            similarities = np.dot(embeddings, query_embedding) / (embeddings_norm * query_norm)

            # Clamp numeric noise so scores stay in the valid cosine range.
            return np.clip(similarities, -1.0, 1.0)

        except Exception as e:
            logger.error("相似度计算失败: %s", e)
            return np.array([])

    def _get_top_results(
        self,
        similarities: np.ndarray,
        chunks: List[str],
        top_k: int,
        min_score: float
    ) -> List[Dict[str, Any]]:
        """Return up to *top_k* chunks with score >= *min_score*, best first.

        Ranks are consecutive (1, 2, 3, ...) even when low-scoring candidates
        are filtered out.
        """
        try:
            if len(similarities) == 0:
                return []

            # Indices of the best candidates, sorted by descending score.
            top_indices = np.argsort(-similarities)[:top_k]

            results: List[Dict[str, Any]] = []
            for idx in top_indices:
                score = float(similarities[idx])

                # Candidates are sorted descending, so once one score falls
                # below the threshold every remaining one does too.
                if score < min_score:
                    break

                chunk = chunks[idx]
                results.append({
                    # len(results) + 1 keeps ranks gap-free after filtering
                    # (the previous enumerate-based rank skipped numbers).
                    "rank": len(results) + 1,
                    "score": score,
                    "content": chunk,
                    "content_preview": self._get_preview(chunk)
                })

            return results

        except Exception as e:
            logger.error("获取 top_k 结果失败: %s", e)
            return []

    def _get_preview(self, content: str, max_length: int = 200) -> str:
        """Return *content* flattened to one line and truncated to *max_length* chars."""
        if not content:
            return ""

        # Collapse newlines so the preview renders on a single line.
        preview = content.replace('\n', ' ').strip()

        if len(preview) > max_length:
            preview = preview[:max_length] + "..."
        return preview

    async def batch_search(
        self,
        requests: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Run many searches concurrently.

        Args:
            requests: List of request dicts; each may carry ``embedding_file``,
                ``query``, ``top_k`` and ``min_score``.

        Returns:
            One result dict per request, in request order, each tagged with
            ``"request_index"``. Failed tasks become ``{"success": False, ...}``
            entries instead of propagating exceptions.
        """
        if not requests:
            return []

        # Fan out all searches concurrently on the event loop.
        tasks = [
            self.semantic_search(
                req.get("embedding_file", ""),
                req.get("query", ""),
                req.get("top_k", 20),
                req.get("min_score", 0.0),
            )
            for req in requests
        ]

        # return_exceptions=True keeps one failure from cancelling the batch.
        results = await asyncio.gather(*tasks, return_exceptions=True)

        processed_results: List[Dict[str, Any]] = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                processed_results.append({
                    "success": False,
                    "error": f"搜索异常: {str(result)}",
                    "request_index": i
                })
            else:
                result["request_index"] = i
                processed_results.append(result)

        return processed_results

    def get_service_stats(self) -> Dict[str, Any]:
        """Return cache and model statistics for monitoring."""
        return {
            "cache_manager_stats": self.cache_manager.get_cache_stats(),
            "model_manager_info": self.model_manager.get_model_info()
        }


# Lazily-created module-level singleton. Not guarded by a lock; fine for a
# single event loop, but NOTE(review): not safe under multi-threaded first use.
_search_service = None


def get_search_service() -> SemanticSearchService:
    """Return the process-wide SemanticSearchService, creating it on first use."""
    global _search_service
    if _search_service is None:
        _search_service = SemanticSearchService()
    return _search_service