From ef7cb7560f57786aba25b57daa5c8ebafe3abef6 Mon Sep 17 00:00:00 2001
From: 朱潮
Date: Thu, 20 Nov 2025 20:54:38 +0800
Subject: [PATCH] Remove model_client
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 embedding/__init__.py       |   7 +-
 embedding/manager.py        | 208 +----------------------------
 embedding/model_client.py   | 181 -------------------------
 embedding/model_server.py   | 240 ---------------------------------
 embedding/search_service.py | 259 ------------------------------------
 fastapi_app.py              |   6 +-
 6 files changed, 6 insertions(+), 895 deletions(-)
 delete mode 100644 embedding/model_client.py
 delete mode 100644 embedding/model_server.py
 delete mode 100644 embedding/search_service.py

diff --git a/embedding/__init__.py b/embedding/__init__.py
index fe52395..957bb26 100644
--- a/embedding/__init__.py
+++ b/embedding/__init__.py
@@ -3,14 +3,11 @@
 Embedding Package
 Provides text encoding and semantic search
 """
 
-from .manager import get_cache_manager, get_model_manager
-from .search_service import get_search_service
+from .manager import get_model_manager
 from .embedding import embed_document, split_document_by_pages
 
 __all__ = [
-    'get_cache_manager',
     'get_model_manager',
-    'get_search_service',
     'embed_document',
     'split_document_by_pages'
-]
\ No newline at end of file
+]
diff --git a/embedding/manager.py b/embedding/manager.py
index ed1c2cd..25c4d28 100644
--- a/embedding/manager.py
+++ b/embedding/manager.py
@@ -22,203 +22,6 @@ from sentence_transformers import SentenceTransformer
 logger = logging.getLogger(__name__)
 
-@dataclass
-class CacheEntry:
-    """Cache entry"""
-    embeddings: np.ndarray
-    chunks: List[str]
-    chunking_strategy: str
-    chunking_params: Dict[str, Any]
-    model_path: str
-    file_path: str
-    file_mtime: float  # file modification time
-    access_count: int  # number of accesses
-    last_access_time: float  # time of last access
-    load_time: float  # time taken to load
-    memory_size: int  # memory footprint in bytes
-
-
-class EmbeddingCacheManager:
-    """Cache manager for embedding data"""
-
-    def __init__(self, max_cache_size: int = 5, max_memory_mb: int = 1024):
-        self.max_cache_size = max_cache_size  # maximum number of cache entries
-        self.max_memory_bytes = max_memory_mb * 1024 * 1024  # maximum memory usage
-        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
-        self._lock = threading.RLock()
-        self._current_memory_usage = 0
-
-        logger.info(f"EmbeddingCacheManager initialized: max_size={max_cache_size}, max_memory={max_memory_mb}MB")
-
-    def _get_file_key(self, file_path: str) -> str:
-        """Generate a cache key for a file"""
-        # Build a unique key from the absolute path and the file's mtime
-        try:
-            abs_path = os.path.abspath(file_path)
-            mtime = os.path.getmtime(abs_path)
-            key_data = f"{abs_path}:{mtime}"
-            return hashlib.md5(key_data.encode()).hexdigest()
-        except Exception as e:
-            logger.warning(f"Failed to generate file key: {file_path}, {e}")
-            return hashlib.md5(file_path.encode()).hexdigest()
-
-    def _estimate_memory_size(self, embeddings: np.ndarray, chunks: List[str]) -> int:
-        """Estimate the memory footprint of the data"""
-        try:
-            embeddings_size = embeddings.nbytes
-            chunks_size = sum(len(chunk.encode('utf-8')) for chunk in chunks)
-            overhead = 1024 * 1024  # 1MB of overhead
-            return embeddings_size + chunks_size + overhead
-        except Exception:
-            return 100 * 1024 * 1024  # default to 100MB
-
-    def _cleanup_cache(self):
-        """Evict cache entries to free memory"""
-        with self._lock:
-            # Sort by access time and count, evicting the least-used entries
-            entries = list(self._cache.items())
-
-            # Work out which entries need to be evicted
-            to_remove = []
-            current_size = len(self._cache)
-
-            # If the entry count exceeds the limit, evict the oldest entries
-            if current_size > self.max_cache_size:
-                to_remove.extend(entries[:current_size - self.max_cache_size])
-
-            # If memory usage exceeds the limit, evict by LRU
-            if self._current_memory_usage > self.max_memory_bytes:
-                # Sort by last access time
-                entries.sort(key=lambda x: x[1].last_access_time)
-
-                accumulated_size = 0
-                for key, entry in entries:
-                    if accumulated_size >= self._current_memory_usage - self.max_memory_bytes:
-                        break
-                    to_remove.append((key, entry))
-                    accumulated_size += entry.memory_size
-
-            # Perform the eviction
-            for key, entry in to_remove:
-                if key in self._cache:
-                    del self._cache[key]
-                    self._current_memory_usage -= entry.memory_size
-                    logger.info(f"Evicted cache entry: {key} ({entry.memory_size / 1024 / 1024:.1f}MB)")
-
-    async def load_embedding_data(self, file_path: str) -> Optional[CacheEntry]:
-        """Load embedding data"""
-        cache_key = self._get_file_key(file_path)
-
-        # Check the cache
-        with self._lock:
-            if cache_key in self._cache:
-                entry = self._cache[cache_key]
-                entry.access_count += 1
-                entry.last_access_time = time.time()
-                # Move to the end (most recently used)
-                self._cache.move_to_end(cache_key)
-                logger.info(f"Cache hit: {file_path}")
-                return entry
-
-        # Cache miss; load the data asynchronously
-        try:
-            start_time = time.time()
-
-            # Check that the file exists
-            if not os.path.exists(file_path):
-                logger.error(f"File does not exist: {file_path}")
-                return None
-
-            # Load the embedding data
-            with open(file_path, 'rb') as f:
-                embedding_data = pickle.load(f)
-
-            # Stay compatible with both old and new data layouts
-            if 'chunks' in embedding_data:
-                chunks = embedding_data['chunks']
-                embeddings = embedding_data['embeddings']
-                chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
-                chunking_params = embedding_data.get('chunking_params', {})
-                model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
-            else:
-                chunks = embedding_data['sentences']
-                embeddings = embedding_data['embeddings']
-                chunking_strategy = 'line'
-                chunking_params = {}
-                model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
-
-            # Make sure embeddings is a numpy array
-            if hasattr(embeddings, 'cpu'):
-                embeddings = embeddings.cpu().numpy()
-            elif hasattr(embeddings, 'numpy'):
-                embeddings = embeddings.numpy()
-            elif not isinstance(embeddings, np.ndarray):
-                embeddings = np.array(embeddings)
-
-            # Create the cache entry
-            load_time = time.time() - start_time
-            file_mtime = os.path.getmtime(file_path)
-            memory_size = self._estimate_memory_size(embeddings, chunks)
-
-            entry = CacheEntry(
-                embeddings=embeddings,
-                chunks=chunks,
-                chunking_strategy=chunking_strategy,
-                chunking_params=chunking_params,
-                model_path=model_path,
-                file_path=file_path,
-                file_mtime=file_mtime,
-                access_count=1,
-                last_access_time=time.time(),
-                load_time=load_time,
-                memory_size=memory_size
-            )
-
-            # Add it to the cache
-            with self._lock:
-                self._cache[cache_key] = entry
-                self._current_memory_usage += memory_size
-
-                # Evict if needed
-                self._cleanup_cache()
-
-            logger.info(f"Load complete: {file_path} ({memory_size / 1024 / 1024:.1f}MB, {load_time:.2f}s)")
-            return entry
-
-        except Exception as e:
-            logger.error(f"Failed to load embedding data: {file_path}, {e}")
-            return None
-
-    def get_cache_stats(self) -> Dict[str, Any]:
-        """Return cache statistics"""
-        with self._lock:
-            return {
-                "cache_size": len(self._cache),
-                "max_cache_size": self.max_cache_size,
-                "memory_usage_mb": self._current_memory_usage / 1024 / 1024,
-                "max_memory_mb": self.max_memory_bytes / 1024 / 1024,
-                "memory_usage_percent": (self._current_memory_usage / self.max_memory_bytes) * 100,
-                "entries": [
-                    {
-                        "file_path": entry.file_path,
-                        "access_count": entry.access_count,
-                        "last_access_time": entry.last_access_time,
-                        "memory_size_mb": entry.memory_size / 1024 / 1024
-                    }
-                    for entry in self._cache.values()
-                ]
-            }
-
-    def clear_cache(self):
-        """Clear the cache"""
-        with self._lock:
-            cleared_count = len(self._cache)
-            cleared_memory = self._current_memory_usage
-            self._cache.clear()
-            self._current_memory_usage = 0
-            logger.info(f"Cache cleared: {cleared_count} entries, {cleared_memory / 1024 / 1024:.1f}MB")
-
-
 class GlobalModelManager:
     """Global model manager"""
 
@@ -312,17 +115,8 @@ class GlobalModelManager:
 
 # Global instances
-_cache_manager = None
 _model_manager = None
 
-def get_cache_manager() -> EmbeddingCacheManager:
-    """Return the cache manager instance"""
-    global _cache_manager
-    if _cache_manager is None:
-        max_cache_size = int(os.getenv("EMBEDDING_MAX_CACHE_SIZE", "5"))
-        max_memory_mb = int(os.getenv("EMBEDDING_MAX_MEMORY_MB", "1024"))
-        _cache_manager = EmbeddingCacheManager(max_cache_size, max_memory_mb)
-    return _cache_manager
 
 def get_model_manager() -> GlobalModelManager:
     """Return the model manager instance"""
@@ -330,4 +124,4 @@ def get_model_manager() -> GlobalModelManager:
     if _model_manager is None:
         model_name = os.getenv("SENTENCE_TRANSFORMER_MODEL", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
         _model_manager = GlobalModelManager(model_name)
-    return _model_manager
\ No newline at end of file
+    return _model_manager
diff --git a/embedding/model_client.py b/embedding/model_client.py
deleted file mode 100644
index d168236..0000000
--- a/embedding/model_client.py
+++ /dev/null
@@ -1,181 +0,0 @@
-#!/usr/bin/env python3
-"""
-Model service client
-Client interface for interacting with model_server.py
-"""
-
-import os
-import asyncio
-import logging
-from typing import List, Union, Dict, Any
-import numpy as np
-import aiohttp
-import json
-from sentence_transformers import util
-
-logger = logging.getLogger(__name__)
-
-class ModelClient:
-    """Model service client"""
-
-    def __init__(self, base_url: str = "http://127.0.0.1:8000"):
-        self.base_url = base_url.rstrip('/')
-        self.session = None
-
-    async def __aenter__(self):
-        self.session = aiohttp.ClientSession(
-            timeout=aiohttp.ClientTimeout(total=300)  # 5-minute timeout
-        )
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        if self.session:
-            await self.session.close()
-
-    async def health_check(self) -> Dict[str, Any]:
-        """Check the service health status"""
-        try:
-            async with self.session.get(f"{self.base_url}/health") as response:
-                return await response.json()
-        except Exception as e:
-            logger.error(f"Health check failed: {e}")
-            return {"status": "unavailable", "error": str(e)}
-
-    async def load_model(self) -> bool:
-        """Preload the model"""
-        try:
-            async with self.session.post(f"{self.base_url}/load_model") as response:
-                result = await response.json()
-                return response.status == 200
-        except Exception as e:
-            logger.error(f"Model loading failed: {e}")
-            return False
-
-    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
-        """Encode texts into vectors"""
-        if not texts:
-            return np.array([])
-
-        try:
-            data = {
-                "texts": texts,
-                "batch_size": batch_size
-            }
-
-            async with self.session.post(
-                f"{self.base_url}/encode",
-                json=data
-            ) as response:
-                if response.status != 200:
-                    error_text = await response.text()
-                    raise Exception(f"Encoding failed: {error_text}")
-
-                result = await response.json()
-                return np.array(result["embeddings"])
-
-        except Exception as e:
-            logger.error(f"Text encoding failed: {e}")
-            raise
-
-    async def compute_similarity(self, text1: str, text2: str) -> float:
-        """Compute the similarity between two texts"""
-        try:
-            data = {
-                "text1": text1,
-                "text2": text2
-            }
-
-            async with self.session.post(
-                f"{self.base_url}/similarity",
-                json=data
-            ) as response:
-                if response.status != 200:
-                    error_text = await response.text()
-                    raise Exception(f"Similarity computation failed: {error_text}")
-
-                result = await response.json()
-                return result["similarity_score"]
-
-        except Exception as e:
-            logger.error(f"Similarity computation failed: {e}")
-            raise
-
-# Synchronous wrapper around the client
-class SyncModelClient:
-    """Synchronous model client"""
-
-    def __init__(self, base_url: str = "http://127.0.0.1:8000"):
-        self.base_url = base_url
-        self._async_client = ModelClient(base_url)
-
-    def _run_async(self, coro):
-        """Run an async coroutine"""
-        try:
-            loop = asyncio.get_event_loop()
-        except RuntimeError:
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
-        return loop.run_until_complete(coro)
-
-    def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
-        """Synchronous version of text encoding"""
-        async def _encode():
-            async with self._async_client as client:
-                return await client.encode_texts(texts, batch_size)
-
-        return self._run_async(_encode())
-
-    def compute_similarity(self, text1: str, text2: str) -> float:
-        """Synchronous version of similarity computation"""
-        async def _similarity():
-            async with self._async_client as client:
-                return await client.compute_similarity(text1, text2)
-
-        return self._run_async(_similarity())
-
-    def health_check(self) -> Dict[str, Any]:
-        """Synchronous version of the health check"""
-        async def _health():
-            async with self._async_client as client:
-                return await client.health_check()
-
-        return self._run_async(_health())
-
-    def load_model(self) -> bool:
-        """Synchronous version of model loading"""
-        async def _load():
-            async with self._async_client as client:
-                return await client.load_model()
-
-        return self._run_async(_load())
-
-# Global client instance (singleton)
-_client_instance = None
-
-def get_model_client(base_url: str = None) -> SyncModelClient:
-    """Return the model client instance (singleton)"""
-    global _client_instance
-
-    if _client_instance is None:
-        url = base_url or os.environ.get('MODEL_SERVER_URL', 'http://127.0.0.1:8000')
-        _client_instance = SyncModelClient(url)
-
-    return _client_instance
-
-# Convenience functions
-def encode_texts(texts: List[str], batch_size: int = 32) -> np.ndarray:
-    """Convenience function for text encoding"""
-    client = get_model_client()
-    return client.encode_texts(texts, batch_size)
-
-def compute_similarity(text1: str, text2: str) -> float:
-    """Convenience function for similarity computation"""
-    client = get_model_client()
-    return client.compute_similarity(text1, text2)
-
-def is_service_available() -> bool:
-    """Check whether the model service is available"""
-    client = get_model_client()
-    health = client.health_check()
-    return health.get("status") == "running" or health.get("status") == "ready"
\ No newline at end of file
diff --git a/embedding/model_server.py b/embedding/model_server.py
deleted file mode 100644
index df298f8..0000000
--- a/embedding/model_server.py
+++ /dev/null
@@ -1,240 +0,0 @@
-#!/usr/bin/env python3
-"""
-Standalone model server
-Provides a unified embedding service to reduce memory usage
-"""
-
-import os
-import asyncio
-import logging
-from typing import List, Dict, Any
-import numpy as np
-from sentence_transformers import SentenceTransformer
-from fastapi import FastAPI, HTTPException
-from fastapi.middleware.cors import CORSMiddleware
-import uvicorn
-from pydantic import BaseModel
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-# Global model instance
-model = None
-model_config = {
-    "model_name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
-    "local_model_path": "./models/paraphrase-multilingual-MiniLM-L12-v2",
-    "device": "cpu"
-}
-
-# Pydantic models
-class TextRequest(BaseModel):
-    texts: List[str]
-    batch_size: int = 32
-
-class SimilarityRequest(BaseModel):
-    text1: str
-    text2: str
-
-class HealthResponse(BaseModel):
-    status: str
-    model_loaded: bool
-    model_path: str
-    device: str
-
-class EmbeddingResponse(BaseModel):
-    embeddings: List[List[float]]
-    shape: List[int]
-
-class SimilarityResponse(BaseModel):
-    similarity_score: float
-
-class ModelServer:
-    def __init__(self):
-        self.model = None
-        self.model_loaded = False
-
-    async def load_model(self):
-        """Lazily load the model"""
-        if self.model_loaded:
-            return
-
-        logger.info("Loading model...")
-
-        # Check whether a local copy of the model exists
-        if os.path.exists(model_config["local_model_path"]):
-            model_path = model_config["local_model_path"]
-            logger.info(f"Using local model: {model_path}")
-        else:
-            model_path = model_config["model_name"]
-            logger.info(f"Using HuggingFace model: {model_path}")
-
-        # Read the device setting from the environment
-        device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', model_config["device"])
-        if device not in ['cpu', 'cuda', 'mps']:
-            logger.warning(f"Unsupported device type '{device}', falling back to CPU")
-            device = 'cpu'
-
-        logger.info(f"Using device: {device}")
-
-        try:
-            self.model = SentenceTransformer(model_path, device=device)
-            self.model_loaded = True
-            model_config["device"] = device
-            model_config["current_model_path"] = model_path
-            logger.info("Model loaded")
-        except Exception as e:
-            logger.error(f"Model loading failed: {e}")
-            raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
-
-    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
-        """Encode texts into vectors"""
-        if not texts:
-            return np.array([])
-
-        await self.load_model()
-
-        try:
-            embeddings = self.model.encode(
-                texts,
-                batch_size=batch_size,
-                convert_to_tensor=True,
-                show_progress_bar=False
-            )
-
-            # Move to CPU and convert to a numpy array
-            if embeddings.is_cuda:
-                embeddings = embeddings.cpu().numpy()
-            else:
-                embeddings = embeddings.numpy()
-
-            return embeddings
-        except Exception as e:
-            logger.error(f"Encoding failed: {e}")
-            raise HTTPException(status_code=500, detail=f"Encoding failed: {str(e)}")
-
-    async def compute_similarity(self, text1: str, text2: str) -> float:
-        """Compute the similarity between two texts"""
-        if not text1 or not text2:
-            raise HTTPException(status_code=400, detail="Texts must not be empty")
-
-        await self.load_model()
-
-        try:
-            embeddings = self.model.encode([text1, text2], convert_to_tensor=True)
-            from sentence_transformers import util
-            similarity = util.cos_sim(embeddings[0:1], embeddings[1:2])[0][0]
-            return similarity.item()
-        except Exception as e:
-            logger.error(f"Similarity computation failed: {e}")
-            raise HTTPException(status_code=500, detail=f"Similarity computation failed: {str(e)}")
-
-# Create the server instance
-server = ModelServer()
-
-# Create the FastAPI app
-app = FastAPI(
-    title="Embedding Model Server",
-    description="Unified sentence-embedding and similarity service",
-    version="1.0.0"
-)
-
-# Add CORS middleware
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Health check
-@app.get("/health", response_model=HealthResponse)
-async def health_check():
-    """Health check"""
-    try:
-        if not server.model_loaded:
-            return HealthResponse(
-                status="ready",
-                model_loaded=False,
-                model_path="",
-                device=model_config["device"]
-            )
-
-        return HealthResponse(
-            status="running",
-            model_loaded=True,
-            model_path=model_config.get("current_model_path", ""),
-            device=model_config["device"]
-        )
-    except Exception as e:
-        logger.error(f"Health check failed: {e}")
-        raise HTTPException(status_code=500, detail=str(e))
-
-# Model preloading
-@app.post("/load_model")
-async def preload_model():
-    """Preload the model"""
-    await server.load_model()
-    return {"message": "Model loaded"}
-
-# Text embedding
-@app.post("/encode", response_model=EmbeddingResponse)
-async def encode_texts(request: TextRequest):
-    """Encode texts into vectors"""
-    try:
-        embeddings = await server.encode_texts(request.texts, request.batch_size)
-        return EmbeddingResponse(
-            embeddings=embeddings.tolist(),
-            shape=list(embeddings.shape)
-        )
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
-
-# Similarity computation
-@app.post("/similarity", response_model=SimilarityResponse)
-async def compute_similarity(request: SimilarityRequest):
-    """Compute the similarity between two texts"""
-    try:
-        similarity = await server.compute_similarity(request.text1, request.text2)
-        return SimilarityResponse(similarity_score=similarity)
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
-
-# Batch encoding (for document processing)
-@app.post("/batch_encode")
-async def batch_encode(texts: List[str], batch_size: int = 64):
-    """Batch-encode a large number of texts"""
-    if not texts:
-        return {"embeddings": [], "shape": [0, 0]}
-
-    try:
-        embeddings = await server.encode_texts(texts, batch_size)
-        return {
-            "embeddings": embeddings.tolist(),
-            "shape": list(embeddings.shape)
-        }
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
-
-if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser(description="Start the embedding model server")
-    parser.add_argument("--host", default="127.0.0.1", help="server address")
-    parser.add_argument("--port", type=int, default=8000, help="server port")
-    parser.add_argument("--workers", type=int, default=1, help="number of worker processes")
-    parser.add_argument("--reload", action="store_true", help="auto-reload in development mode")
-
-    args = parser.parse_args()
-
-    logger.info(f"Starting model server: http://{args.host}:{args.port}")
-
-    uvicorn.run(
-        "model_server:app",
-        host=args.host,
-        port=args.port,
-        workers=args.workers,
-        reload=args.reload,
-        log_level="info"
-    )
\ No newline at end of file
diff --git a/embedding/search_service.py b/embedding/search_service.py
deleted file mode 100644
index a2951ee..0000000
--- a/embedding/search_service.py
+++ /dev/null
@@ -1,259 +0,0 @@
-#!/usr/bin/env python3
-"""
-Semantic search service
-Supports highly concurrent semantic search
-"""
-
-import asyncio
-import time
-import logging
-from typing import List, Dict, Any, Optional, Tuple
-import numpy as np
-
-from .manager import get_cache_manager, get_model_manager, CacheEntry
-
-logger = logging.getLogger(__name__)
-
-
-class SemanticSearchService:
-    """Semantic search service"""
-
-    def __init__(self):
-        self.cache_manager = get_cache_manager()
-        self.model_manager = get_model_manager()
-
-    async def semantic_search(
-        self,
-        embedding_file: str,
-        query: str,
-        top_k: int = 20,
-        min_score: float = 0.0
-    ) -> Dict[str, Any]:
-        """
-        Run a semantic search
-
-        Args:
-            embedding_file: path to the embedding.pkl file
-            query: the query keywords
-            top_k: number of results to return
-            min_score: minimum similarity threshold
-
-        Returns:
-            The search results
-        """
-        start_time = time.time()
-
-        try:
-            # Validate the input parameters
-            if not embedding_file:
-                return {
-                    "success": False,
-                    "error": "The embedding_file parameter must not be empty"
-                }
-
-            if not query or not query.strip():
-                return {
-                    "success": False,
-                    "error": "The query parameter must not be empty"
-                }
-
-            query = query.strip()
-
-            # Load the embedding data
-            cache_entry = await self.cache_manager.load_embedding_data(embedding_file)
-            if cache_entry is None:
-                return {
-                    "success": False,
-                    "error": f"Unable to load embedding file: {embedding_file}"
-                }
-
-            # Encode the query
-            query_embeddings = await self.model_manager.encode_texts([query], batch_size=1)
-            if len(query_embeddings) == 0:
-                return {
-                    "success": False,
-                    "error": "Failed to encode the query"
-                }
-
-            query_embedding = query_embeddings[0]
-
-            # Compute similarities
-            similarities = self._compute_similarities(query_embedding, cache_entry.embeddings)
-
-            # Take the top_k results
-            top_results = self._get_top_results(
-                similarities,
-                cache_entry.chunks,
-                top_k,
-                min_score
-            )
-
-            processing_time = time.time() - start_time
-
-            return {
-                "success": True,
-                "query": query,
-                "embedding_file": embedding_file,
-                "top_k": top_k,
-                "processing_time": processing_time,
-                "chunking_strategy": cache_entry.chunking_strategy,
-                "total_chunks": len(cache_entry.chunks),
-                "results": top_results,
-                "cache_stats": {
-                    "access_count": cache_entry.access_count,
-                    "load_time": cache_entry.load_time
-                }
-            }
-
-        except Exception as e:
-            logger.error(f"Semantic search failed: {embedding_file}, {query}, {e}")
-            return {
-                "success": False,
-                "error": f"Search failed: {str(e)}"
-            }
-
-    def _compute_similarities(self, query_embedding: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
-        """Compute similarity scores"""
-        try:
-            # Make sure the dtype and shape are correct
-            if embeddings.ndim == 1:
-                embeddings = embeddings.reshape(1, -1)
-
-            # Use cosine similarity
-            query_norm = np.linalg.norm(query_embedding)
-            embeddings_norm = np.linalg.norm(embeddings, axis=1)
-
-            # Guard against division by zero
-            if query_norm == 0:
-                query_norm = 1e-8
-            embeddings_norm = np.where(embeddings_norm == 0, 1e-8, embeddings_norm)
-
-            # Compute the cosine similarity
-            similarities = np.dot(embeddings, query_embedding) / (embeddings_norm * query_norm)
-
-            # Keep the results within a sane range
-            similarities = np.clip(similarities, -1.0, 1.0)
-
-            return similarities
-
-        except Exception as e:
-            logger.error(f"Similarity computation failed: {e}")
-            return np.array([])
-
-    def _get_top_results(
-        self,
-        similarities: np.ndarray,
-        chunks: List[str],
-        top_k: int,
-        min_score: float
-    ) -> List[Dict[str, Any]]:
-        """Take the top_k results"""
-        try:
-            if len(similarities) == 0:
-                return []
-
-            # Indices sorted by descending score
-            top_indices = np.argsort(-similarities)[:top_k]
-
-            results = []
-            for rank, idx in enumerate(top_indices):
-                score = similarities[idx]
-
-                # Filter out low-scoring results
-                if score < min_score:
-                    continue
-
-                chunk = chunks[idx]
-
-                # Build the result entry
-                result = {
-                    "rank": rank + 1,
-                    "score": float(score),
-                    "content": chunk,
-                    "content_preview": self._get_preview(chunk)
-                }
-
-                results.append(result)
-
-            return results
-
-        except Exception as e:
-            logger.error(f"Failed to take the top_k results: {e}")
-            return []
-
-    def _get_preview(self, content: str, max_length: int = 200) -> str:
-        """Return a content preview"""
-        if not content:
-            return ""
-
-        # Clean up the content
-        preview = content.replace('\n', ' ').strip()
-
-        # Truncate
-        if len(preview) > max_length:
-            preview = preview[:max_length] + "..."
-
-        return preview
-
-    async def batch_search(
-        self,
-        requests: List[Dict[str, Any]]
-    ) -> List[Dict[str, Any]]:
-        """
-        Batch search
-
-        Args:
-            requests: list of search requests, each containing embedding_file, query, top_k
-
-        Returns:
-            The list of search results
-        """
-        if not requests:
-            return []
-
-        # Run the searches concurrently
-        tasks = []
-        for req in requests:
-            embedding_file = req.get("embedding_file", "")
-            query = req.get("query", "")
-            top_k = req.get("top_k", 20)
-            min_score = req.get("min_score", 0.0)
-
-            task = self.semantic_search(embedding_file, query, top_k, min_score)
-            tasks.append(task)
-
-        # Wait for all tasks to finish
-        results = await asyncio.gather(*tasks, return_exceptions=True)
-
-        # Handle exceptions
-        processed_results = []
-        for i, result in enumerate(results):
-            if isinstance(result, Exception):
-                processed_results.append({
-                    "success": False,
-                    "error": f"Search error: {str(result)}",
                    "request_index": i
-                })
-            else:
-                result["request_index"] = i
-                processed_results.append(result)
-
-        return processed_results
-
-    def get_service_stats(self) -> Dict[str, Any]:
-        """Return service statistics"""
-        return {
-            "cache_manager_stats": self.cache_manager.get_cache_stats(),
-            "model_manager_info": self.model_manager.get_model_info()
-        }
-
-
-# Global instance
-_search_service = None
-
-def get_search_service() -> SemanticSearchService:
-    """Return the search service instance"""
-    global _search_service
-    if _search_service is None:
-        _search_service = SemanticSearchService()
-    return _search_service
\ No newline at end of file
diff --git a/fastapi_app.py b/fastapi_app.py
index 4620ddf..d095b80 100644
--- a/fastapi_app.py
+++ b/fastapi_app.py
@@ -24,7 +24,7 @@ from qwen_agent.llm.schema import ASSISTANT, FUNCTION
 from pydantic import BaseModel, Field
 
 # Import the semantic search service
-from embedding import get_search_service
+from embedding import get_model_manager
 
 # Import utility modules
 from utils import (
@@ -1572,7 +1572,7 @@ async def encode_texts(request: EncodeRequest):
         Encoding results
     """
     try:
-        search_service = get_search_service()
+        model_manager = get_model_manager()
 
         if not request.texts:
             return EncodeResponse(
@@ -1587,7 +1587,7 @@
         start_time = time.time()
 
         # Encode the texts with the model manager
-        embeddings = await search_service.model_manager.encode_texts(
+        embeddings = await model_manager.encode_texts(
             request.texts,
             batch_size=request.batch_size
         )
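
Note: with the cache manager, search service, and HTTP model client/server removed,
callers now use GlobalModelManager directly. Below is a minimal sketch of the
surviving call path, mirroring the patched fastapi_app.py hunk above. The script
name and sample texts are illustrative, and it assumes encode_texts returns a
numpy array, as the removed search service expected:

    # example_encode.py -- illustrative only; assumes the embedding package
    # from this repository is importable.
    import asyncio

    from embedding import get_model_manager


    async def main():
        # get_model_manager() lazily builds the GlobalModelManager singleton;
        # the model is chosen via the SENTENCE_TRANSFORMER_MODEL env var.
        model_manager = get_model_manager()

        # Same awaited call as in the patched fastapi_app.py.
        embeddings = await model_manager.encode_texts(
            ["hello world", "semantic search"],
            batch_size=32,
        )
        print(embeddings.shape)  # e.g. (2, 384) for MiniLM-L12-v2


    if __name__ == "__main__":
        asyncio.run(main())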