From b9f6928b5000ce3fd7357dd50bcdba3f513d76e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Thu, 20 Nov 2025 13:29:44 +0800 Subject: [PATCH] =?UTF-8?q?embedding=20=E6=A8=A1=E5=9E=8B=E7=8B=AC?= =?UTF-8?q?=E7=AB=8B=E4=B8=BAapi?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- embedding/__init__.py | 17 ++ embedding/embedding.py | 120 ++++++----- embedding/manager.py | 333 ++++++++++++++++++++++++++++++ embedding/model_client.py | 181 ++++++++++++++++ embedding/model_server.py | 240 ++++++++++++++++++++++ embedding/search_service.py | 259 +++++++++++++++++++++++ fastapi_app.py | 251 ++++++++++++++++++++++ mcp/semantic_search_server.py | 165 ++++----------- prompt/wowtalk.md | 377 ++++++++++------------------------ 9 files changed, 1493 insertions(+), 450 deletions(-) create mode 100644 embedding/__init__.py create mode 100644 embedding/manager.py create mode 100644 embedding/model_client.py create mode 100644 embedding/model_server.py create mode 100644 embedding/search_service.py diff --git a/embedding/__init__.py b/embedding/__init__.py new file mode 100644 index 0000000..f465e5f --- /dev/null +++ b/embedding/__init__.py @@ -0,0 +1,17 @@ +""" +Embedding Package +提供文本编码和语义搜索功能 +""" + +from .manager import get_cache_manager, get_model_manager +from .search_service import get_search_service +from .embedding import embed_document, semantic_search, split_document_by_pages + +__all__ = [ + 'get_cache_manager', + 'get_model_manager', + 'get_search_service', + 'embed_document', + 'semantic_search', + 'split_document_by_pages' +] \ No newline at end of file diff --git a/embedding/embedding.py b/embedding/embedding.py index 8e73780..2472ec4 100644 --- a/embedding/embedding.py +++ b/embedding/embedding.py @@ -1,42 +1,51 @@ import pickle import re import numpy as np -from sentence_transformers import SentenceTransformer, util +import os +from typing import Optional +import requests +import asyncio -# 延迟加载模型 -embedder = None - -def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'): - """获取模型实例(延迟加载) +def encode_texts_via_api(texts, batch_size=32): + """通过 API 接口编码文本""" + if not texts: + return np.array([]) - Args: - model_name_or_path (str): 模型名称或本地路径 - - 可以是 HuggingFace 模型名称 - - 可以是本地模型路径 - """ - global embedder - if embedder is None: - print("正在加载模型...") - print(f"模型路径: {model_name_or_path}") + try: + # FastAPI 服务地址 + fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001') + api_endpoint = f"{fastapi_url}/api/v1/embedding/encode" - # 检查是否是本地路径 - import os - if os.path.exists(model_name_or_path): - print("使用本地模型") + # 调用编码接口 + request_data = { + "texts": texts, + "batch_size": batch_size + } + + response = requests.post( + api_endpoint, + json=request_data, + timeout=60 # 增加超时时间 + ) + + if response.status_code == 200: + result_data = response.json() + + if result_data.get("success"): + embeddings_list = result_data.get("embeddings", []) + print(f"API编码成功,处理了 {len(texts)} 个文本,embedding维度: {len(embeddings_list[0]) if embeddings_list else 0}") + return np.array(embeddings_list) + else: + error_msg = result_data.get('error', '未知错误') + print(f"API编码失败: {error_msg}") + raise Exception(f"API编码失败: {error_msg}") else: - print("使用 HuggingFace 模型") - - # 从环境变量获取设备配置,默认为 CPU - device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu') - if device not in ['cpu', 'cuda', 'mps']: - print(f"警告: 不支持的设备类型 '{device}',使用默认 CPU") - device = 'cpu' - - print(f"使用设备: {device}") - embedder = 
SentenceTransformer(model_name_or_path, device=device) - - print("模型加载完成") - return embedder + print(f"API请求失败: {response.status_code} - {response.text}") + raise Exception(f"API请求失败: {response.status_code}") + + except Exception as e: + print(f"API编码异常: {e}") + raise def clean_text(text): """ @@ -192,23 +201,16 @@ def embed_document(input_file='document.txt', output_file='embedding.pkl', print(f"正在处理 {len(chunks)} 个内容块...") - # 在函数内部设置模型路径 - local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2" - import os - if os.path.exists(local_model_path): - model_path = local_model_path - else: - model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2' - - model = get_model(model_path) - chunk_embeddings = model.encode(chunks, convert_to_tensor=True) + # 使用API接口进行编码 + print("使用API接口进行编码...") + chunk_embeddings = encode_texts_via_api(chunks, batch_size=32) embedding_data = { 'chunks': chunks, 'embeddings': chunk_embeddings, 'chunking_strategy': chunking_strategy, 'chunking_params': chunking_params, - 'model_path': model_path + 'model_path': 'api_service' } with open(output_file, 'wb') as f: @@ -254,18 +256,28 @@ def semantic_search(user_query, embeddings_file='embedding.pkl', top_k=20): chunking_strategy = 'line' content_type = "句子" - # 从embedding_data中获取模型路径(如果有的话) - model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2') - model = get_model(model_path) - query_embedding = model.encode(user_query, convert_to_tensor=True) + # 使用API接口进行编码 + print("使用API接口进行查询编码...") + query_embeddings = encode_texts_via_api([user_query], batch_size=1) + query_embedding = query_embeddings[0] if len(query_embeddings) > 0 else np.array([]) - cos_scores = util.cos_sim(query_embedding, chunk_embeddings)[0] - - # 处理 GPU/CPU 环境下的 tensor 转换 - if cos_scores.is_cuda: - cos_scores_np = cos_scores.cpu().numpy() + # 计算相似度 + if len(chunk_embeddings.shape) > 1: + cos_scores = np.dot(chunk_embeddings, query_embedding) / ( + np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8 + ) else: - cos_scores_np = cos_scores.numpy() + cos_scores = np.array([0.0]) # 兼容性处理:保持 ndarray 类型,避免下游对列表调用 .numpy() 崩溃 + + # 处理不同格式下的 cos_scores + if isinstance(cos_scores, np.ndarray): + cos_scores_np = cos_scores + else: + # PyTorch tensor + if hasattr(cos_scores, 'is_cuda') and cos_scores.is_cuda: + cos_scores_np = cos_scores.cpu().numpy() + else: + cos_scores_np = cos_scores.numpy() top_results = np.argsort(-cos_scores_np)[:top_k] @@ -273,7 +285,7 @@ def semantic_search(user_query, embeddings_file='embedding.pkl', top_k=20): results = [] print(f"\n与查询最相关的 {top_k} 个{content_type} (分块策略: {chunking_strategy}):") for i, idx in enumerate(top_results): chunk = chunks[idx] - score = cos_scores[idx].item() + score = cos_scores_np[idx] results.append((chunk, score)) # 显示内容预览(如果内容太长) preview = chunk[:100] + "..." 
if len(chunk) > 100 else chunk diff --git a/embedding/manager.py b/embedding/manager.py new file mode 100644 index 0000000..ed1c2cd --- /dev/null +++ b/embedding/manager.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 +""" +模型池管理器和缓存系统 +支持高并发的 embedding 检索服务 +""" + +import os +import asyncio +import time +import pickle +import hashlib +import logging +from typing import Dict, List, Optional, Any, Tuple +from dataclasses import dataclass +from collections import OrderedDict +import threading +import psutil +import numpy as np + +from sentence_transformers import SentenceTransformer + +logger = logging.getLogger(__name__) + + +@dataclass +class CacheEntry: + """缓存条目""" + embeddings: np.ndarray + chunks: List[str] + chunking_strategy: str + chunking_params: Dict[str, Any] + model_path: str + file_path: str + file_mtime: float # 文件修改时间 + access_count: int # 访问次数 + last_access_time: float # 最后访问时间 + load_time: float # 加载时间 + memory_size: int # 占用内存大小(字节) + + +class EmbeddingCacheManager: + """Embedding 数据缓存管理器""" + + def __init__(self, max_cache_size: int = 5, max_memory_mb: int = 1024): + self.max_cache_size = max_cache_size # 最大缓存条目数 + self.max_memory_bytes = max_memory_mb * 1024 * 1024 # 最大内存使用量 + self._cache: OrderedDict[str, CacheEntry] = OrderedDict() + self._lock = threading.RLock() + self._current_memory_usage = 0 + + logger.info(f"EmbeddingCacheManager 初始化: max_size={max_cache_size}, max_memory={max_memory_mb}MB") + + def _get_file_key(self, file_path: str) -> str: + """生成文件缓存键""" + # 使用绝对路径和文件修改时间生成唯一键 + try: + abs_path = os.path.abspath(file_path) + mtime = os.path.getmtime(abs_path) + key_data = f"{abs_path}:{mtime}" + return hashlib.md5(key_data.encode()).hexdigest() + except Exception as e: + logger.warning(f"生成文件键失败: {file_path}, {e}") + return hashlib.md5(file_path.encode()).hexdigest() + + def _estimate_memory_size(self, embeddings: np.ndarray, chunks: List[str]) -> int: + """估算数据内存占用""" + try: + embeddings_size = embeddings.nbytes + chunks_size = sum(len(chunk.encode('utf-8')) for chunk in chunks) + overhead = 1024 * 1024 # 1MB 开销 + return embeddings_size + chunks_size + overhead + except Exception: + return 100 * 1024 * 1024 # 默认100MB + + def _cleanup_cache(self): + """清理缓存以释放内存""" + with self._lock: + # 按访问时间和次数排序,清理最少使用的条目 + entries = list(self._cache.items()) + + # 计算需要清理的条目 + to_remove = [] + current_size = len(self._cache) + + # 如果缓存条目超限,清理最老的条目 + if current_size > self.max_cache_size: + to_remove.extend(entries[:current_size - self.max_cache_size]) + + # 如果内存超限,按LRU策略清理 + if self._current_memory_usage > self.max_memory_bytes: + # 按最后访问时间排序 + entries.sort(key=lambda x: x[1].last_access_time) + + accumulated_size = 0 + for key, entry in entries: + if accumulated_size >= self._current_memory_usage - self.max_memory_bytes: + break + to_remove.append((key, entry)) + accumulated_size += entry.memory_size + + # 执行清理 + for key, entry in to_remove: + if key in self._cache: + del self._cache[key] + self._current_memory_usage -= entry.memory_size + logger.info(f"清理缓存条目: {key} ({entry.memory_size / 1024 / 1024:.1f}MB)") + + async def load_embedding_data(self, file_path: str) -> Optional[CacheEntry]: + """加载 embedding 数据""" + cache_key = self._get_file_key(file_path) + + # 检查缓存 + with self._lock: + if cache_key in self._cache: + entry = self._cache[cache_key] + entry.access_count += 1 + entry.last_access_time = time.time() + # 移动到末尾(最近使用) + self._cache.move_to_end(cache_key) + logger.info(f"缓存命中: {file_path}") + return entry + + # 缓存未命中,异步加载数据 + try: + start_time = time.time() + + # 
检查文件是否存在 + if not os.path.exists(file_path): + logger.error(f"文件不存在: {file_path}") + return None + + # 加载 embedding 数据 + with open(file_path, 'rb') as f: + embedding_data = pickle.load(f) + + # 兼容新旧数据结构 + if 'chunks' in embedding_data: + chunks = embedding_data['chunks'] + embeddings = embedding_data['embeddings'] + chunking_strategy = embedding_data.get('chunking_strategy', 'unknown') + chunking_params = embedding_data.get('chunking_params', {}) + model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2') + else: + chunks = embedding_data['sentences'] + embeddings = embedding_data['embeddings'] + chunking_strategy = 'line' + chunking_params = {} + model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2' + + # 确保 embeddings 是 numpy 数组 + if hasattr(embeddings, 'cpu'): + embeddings = embeddings.cpu().numpy() + elif hasattr(embeddings, 'numpy'): + embeddings = embeddings.numpy() + elif not isinstance(embeddings, np.ndarray): + embeddings = np.array(embeddings) + + # 创建缓存条目 + load_time = time.time() - start_time + file_mtime = os.path.getmtime(file_path) + memory_size = self._estimate_memory_size(embeddings, chunks) + + entry = CacheEntry( + embeddings=embeddings, + chunks=chunks, + chunking_strategy=chunking_strategy, + chunking_params=chunking_params, + model_path=model_path, + file_path=file_path, + file_mtime=file_mtime, + access_count=1, + last_access_time=time.time(), + load_time=load_time, + memory_size=memory_size + ) + + # 添加到缓存 + with self._lock: + self._cache[cache_key] = entry + self._current_memory_usage += memory_size + + # 清理缓存 + self._cleanup_cache() + + logger.info(f"加载完成: {file_path} ({memory_size / 1024 / 1024:.1f}MB, {load_time:.2f}s)") + return entry + + except Exception as e: + logger.error(f"加载 embedding 数据失败: {file_path}, {e}") + return None + + def get_cache_stats(self) -> Dict[str, Any]: + """获取缓存统计信息""" + with self._lock: + return { + "cache_size": len(self._cache), + "max_cache_size": self.max_cache_size, + "memory_usage_mb": self._current_memory_usage / 1024 / 1024, + "max_memory_mb": self.max_memory_bytes / 1024 / 1024, + "memory_usage_percent": (self._current_memory_usage / self.max_memory_bytes) * 100, + "entries": [ + { + "file_path": entry.file_path, + "access_count": entry.access_count, + "last_access_time": entry.last_access_time, + "memory_size_mb": entry.memory_size / 1024 / 1024 + } + for entry in self._cache.values() + ] + } + + def clear_cache(self): + """清空缓存""" + with self._lock: + cleared_count = len(self._cache) + cleared_memory = self._current_memory_usage + self._cache.clear() + self._current_memory_usage = 0 + logger.info(f"清空缓存: {cleared_count} 个条目, {cleared_memory / 1024 / 1024:.1f}MB") + + +class GlobalModelManager: + """全局模型管理器""" + + def __init__(self, model_name: str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'): + self.model_name = model_name + self.local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2" + self._model: Optional[SentenceTransformer] = None + self._lock = asyncio.Lock() + self._load_time = 0 + self._device = 'cpu' + + logger.info(f"GlobalModelManager 初始化: {model_name}") + + async def get_model(self) -> SentenceTransformer: + """获取模型实例(延迟加载)""" + if self._model is not None: + return self._model + + async with self._lock: + # 双重检查 + if self._model is not None: + return self._model + + try: + start_time = time.time() + + # 检查本地模型 + model_path = self.local_model_path if os.path.exists(self.local_model_path) else self.model_name + + 
# 获取设备配置 + self._device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu') + if self._device not in ['cpu', 'cuda', 'mps']: + self._device = 'cpu' + + logger.info(f"加载模型: {model_path} (device: {self._device})") + + # 在事件循环中运行阻塞操作 + loop = asyncio.get_event_loop() + self._model = await loop.run_in_executor( + None, + lambda: SentenceTransformer(model_path, device=self._device) + ) + + self._load_time = time.time() - start_time + logger.info(f"模型加载完成: {self._load_time:.2f}s") + + return self._model + + except Exception as e: + logger.error(f"模型加载失败: {e}") + raise + + async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray: + """编码文本为向量""" + if not texts: + return np.array([]) + + model = await self.get_model() + + try: + # 在事件循环中运行阻塞操作 + loop = asyncio.get_event_loop() + embeddings = await loop.run_in_executor( + None, + lambda: model.encode(texts, batch_size=batch_size, show_progress_bar=False) + ) + + # 确保返回 numpy 数组 + if hasattr(embeddings, 'cpu'): + embeddings = embeddings.cpu().numpy() + elif hasattr(embeddings, 'numpy'): + embeddings = embeddings.numpy() + elif not isinstance(embeddings, np.ndarray): + embeddings = np.array(embeddings) + + return embeddings + + except Exception as e: + logger.error(f"文本编码失败: {e}") + raise + + def get_model_info(self) -> Dict[str, Any]: + """获取模型信息""" + return { + "model_name": self.model_name, + "local_model_path": self.local_model_path, + "device": self._device, + "is_loaded": self._model is not None, + "load_time": self._load_time + } + + +# 全局实例 +_cache_manager = None +_model_manager = None + +def get_cache_manager() -> EmbeddingCacheManager: + """获取缓存管理器实例""" + global _cache_manager + if _cache_manager is None: + max_cache_size = int(os.getenv("EMBEDDING_MAX_CACHE_SIZE", "5")) + max_memory_mb = int(os.getenv("EMBEDDING_MAX_MEMORY_MB", "1024")) + _cache_manager = EmbeddingCacheManager(max_cache_size, max_memory_mb) + return _cache_manager + +def get_model_manager() -> GlobalModelManager: + """获取模型管理器实例""" + global _model_manager + if _model_manager is None: + model_name = os.getenv("SENTENCE_TRANSFORMER_MODEL", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") + _model_manager = GlobalModelManager(model_name) + return _model_manager \ No newline at end of file diff --git a/embedding/model_client.py b/embedding/model_client.py new file mode 100644 index 0000000..d168236 --- /dev/null +++ b/embedding/model_client.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +""" +模型服务客户端 +提供与 model_server.py 交互的客户端接口 +""" + +import os +import asyncio +import logging +from typing import List, Union, Dict, Any +import numpy as np +import aiohttp +import json +from sentence_transformers import util + +logger = logging.getLogger(__name__) + +class ModelClient: + """模型服务客户端""" + + def __init__(self, base_url: str = "http://127.0.0.1:8000"): + self.base_url = base_url.rstrip('/') + self.session = None + + async def __aenter__(self): + self.session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=300) # 5分钟超时 + ) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + if self.session: + await self.session.close() + + async def health_check(self) -> Dict[str, Any]: + """检查服务健康状态""" + try: + async with self.session.get(f"{self.base_url}/health") as response: + return await response.json() + except Exception as e: + logger.error(f"健康检查失败: {e}") + return {"status": "unavailable", "error": str(e)} + + async def load_model(self) -> bool: + """预加载模型""" + try: + async with 
self.session.post(f"{self.base_url}/load_model") as response: + result = await response.json() + return response.status == 200 + except Exception as e: + logger.error(f"模型加载失败: {e}") + return False + + async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray: + """编码文本为向量""" + if not texts: + return np.array([]) + + try: + data = { + "texts": texts, + "batch_size": batch_size + } + + async with self.session.post( + f"{self.base_url}/encode", + json=data + ) as response: + if response.status != 200: + error_text = await response.text() + raise Exception(f"编码失败: {error_text}") + + result = await response.json() + return np.array(result["embeddings"]) + + except Exception as e: + logger.error(f"文本编码失败: {e}") + raise + + async def compute_similarity(self, text1: str, text2: str) -> float: + """计算两个文本的相似度""" + try: + data = { + "text1": text1, + "text2": text2 + } + + async with self.session.post( + f"{self.base_url}/similarity", + json=data + ) as response: + if response.status != 200: + error_text = await response.text() + raise Exception(f"相似度计算失败: {error_text}") + + result = await response.json() + return result["similarity_score"] + + except Exception as e: + logger.error(f"相似度计算失败: {e}") + raise + +# 同步版本的客户端包装器 +class SyncModelClient: + """同步模型客户端""" + + def __init__(self, base_url: str = "http://127.0.0.1:8000"): + self.base_url = base_url + self._async_client = ModelClient(base_url) + + def _run_async(self, coro): + """运行异步协程""" + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + return loop.run_until_complete(coro) + + def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray: + """同步版本的文本编码""" + async def _encode(): + async with self._async_client as client: + return await client.encode_texts(texts, batch_size) + + return self._run_async(_encode()) + + def compute_similarity(self, text1: str, text2: str) -> float: + """同步版本的相似度计算""" + async def _similarity(): + async with self._async_client as client: + return await client.compute_similarity(text1, text2) + + return self._run_async(_similarity()) + + def health_check(self) -> Dict[str, Any]: + """同步版本的健康检查""" + async def _health(): + async with self._async_client as client: + return await client.health_check() + + return self._run_async(_health()) + + def load_model(self) -> bool: + """同步版本的模型加载""" + async def _load(): + async with self._async_client as client: + return await client.load_model() + + return self._run_async(_load()) + +# 全局客户端实例(单例模式) +_client_instance = None + +def get_model_client(base_url: str = None) -> SyncModelClient: + """获取模型客户端实例(单例)""" + global _client_instance + + if _client_instance is None: + url = base_url or os.environ.get('MODEL_SERVER_URL', 'http://127.0.0.1:8000') + _client_instance = SyncModelClient(url) + + return _client_instance + +# 便捷函数 +def encode_texts(texts: List[str], batch_size: int = 32) -> np.ndarray: + """便捷的文本编码函数""" + client = get_model_client() + return client.encode_texts(texts, batch_size) + +def compute_similarity(text1: str, text2: str) -> float: + """便捷的相似度计算函数""" + client = get_model_client() + return client.compute_similarity(text1, text2) + +def is_service_available() -> bool: + """检查模型服务是否可用""" + client = get_model_client() + health = client.health_check() + return health.get("status") == "running" or health.get("status") == "ready" \ No newline at end of file diff --git a/embedding/model_server.py b/embedding/model_server.py new file mode 100644 index 
0000000..df298f8 --- /dev/null +++ b/embedding/model_server.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +""" +独立的模型服务器 +提供统一的 embedding 服务,减少内存占用 +""" + +import os +import asyncio +import logging +from typing import List, Dict, Any +import numpy as np +from sentence_transformers import SentenceTransformer +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +import uvicorn +from pydantic import BaseModel + +# 配置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# 全局模型实例 +model = None +model_config = { + "model_name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", + "local_model_path": "./models/paraphrase-multilingual-MiniLM-L12-v2", + "device": "cpu" +} + +# Pydantic 模型 +class TextRequest(BaseModel): + texts: List[str] + batch_size: int = 32 + +class SimilarityRequest(BaseModel): + text1: str + text2: str + +class HealthResponse(BaseModel): + status: str + model_loaded: bool + model_path: str + device: str + +class EmbeddingResponse(BaseModel): + embeddings: List[List[float]] + shape: List[int] + +class SimilarityResponse(BaseModel): + similarity_score: float + +class ModelServer: + def __init__(self): + self.model = None + self.model_loaded = False + + async def load_model(self): + """延迟加载模型""" + if self.model_loaded: + return + + logger.info("正在加载模型...") + + # 检查本地模型是否存在 + if os.path.exists(model_config["local_model_path"]): + model_path = model_config["local_model_path"] + logger.info(f"使用本地模型: {model_path}") + else: + model_path = model_config["model_name"] + logger.info(f"使用 HuggingFace 模型: {model_path}") + + # 从环境变量获取设备配置 + device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', model_config["device"]) + if device not in ['cpu', 'cuda', 'mps']: + logger.warning(f"不支持的设备类型 '{device}',使用默认 CPU") + device = 'cpu' + + logger.info(f"使用设备: {device}") + + try: + self.model = SentenceTransformer(model_path, device=device) + self.model_loaded = True + model_config["device"] = device + model_config["current_model_path"] = model_path + logger.info("模型加载完成") + except Exception as e: + logger.error(f"模型加载失败: {e}") + raise HTTPException(status_code=500, detail=f"模型加载失败: {str(e)}") + + async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray: + """编码文本为向量""" + if not texts: + return np.array([]) + + await self.load_model() + + try: + embeddings = self.model.encode( + texts, + batch_size=batch_size, + convert_to_tensor=True, + show_progress_bar=False + ) + + # 转换为 CPU numpy 数组 + if embeddings.is_cuda: + embeddings = embeddings.cpu().numpy() + else: + embeddings = embeddings.numpy() + + return embeddings + except Exception as e: + logger.error(f"编码失败: {e}") + raise HTTPException(status_code=500, detail=f"编码失败: {str(e)}") + + async def compute_similarity(self, text1: str, text2: str) -> float: + """计算两个文本的相似度""" + if not text1 or not text2: + raise HTTPException(status_code=400, detail="文本不能为空") + + await self.load_model() + + try: + embeddings = self.model.encode([text1, text2], convert_to_tensor=True) + from sentence_transformers import util + similarity = util.cos_sim(embeddings[0:1], embeddings[1:2])[0][0] + return similarity.item() + except Exception as e: + logger.error(f"相似度计算失败: {e}") + raise HTTPException(status_code=500, detail=f"相似度计算失败: {str(e)}") + +# 创建服务器实例 +server = ModelServer() + +# 创建 FastAPI 应用 +app = FastAPI( + title="Embedding Model Server", + description="统一的句子嵌入和相似度计算服务", + version="1.0.0" +) + +# 添加 CORS 中间件 +app.add_middleware( + CORSMiddleware, + 
allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 健康检查 +@app.get("/health", response_model=HealthResponse) +async def health_check(): + """健康检查""" + try: + if not server.model_loaded: + return HealthResponse( + status="ready", + model_loaded=False, + model_path="", + device=model_config["device"] + ) + + return HealthResponse( + status="running", + model_loaded=True, + model_path=model_config.get("current_model_path", ""), + device=model_config["device"] + ) + except Exception as e: + logger.error(f"健康检查失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +# 预加载模型 +@app.post("/load_model") +async def preload_model(): + """预加载模型""" + await server.load_model() + return {"message": "模型加载完成"} + +# 文本嵌入 +@app.post("/encode", response_model=EmbeddingResponse) +async def encode_texts(request: TextRequest): + """编码文本为向量""" + try: + embeddings = await server.encode_texts(request.texts, request.batch_size) + return EmbeddingResponse( + embeddings=embeddings.tolist(), + shape=list(embeddings.shape) + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +# 相似度计算 +@app.post("/similarity", response_model=SimilarityResponse) +async def compute_similarity(request: SimilarityRequest): + """计算两个文本的相似度""" + try: + similarity = await server.compute_similarity(request.text1, request.text2) + return SimilarityResponse(similarity_score=similarity) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +# 批量编码(用于文档处理) +@app.post("/batch_encode") +async def batch_encode(texts: List[str], batch_size: int = 64): + """批量编码大量文本""" + if not texts: + return {"embeddings": [], "shape": [0, 0]} + + try: + embeddings = await server.encode_texts(texts, batch_size) + return { + "embeddings": embeddings.tolist(), + "shape": list(embeddings.shape) + } + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="启动嵌入模型服务器") + parser.add_argument("--host", default="127.0.0.1", help="服务器地址") + parser.add_argument("--port", type=int, default=8000, help="服务器端口") + parser.add_argument("--workers", type=int, default=1, help="工作进程数") + parser.add_argument("--reload", action="store_true", help="开发模式自动重载") + + args = parser.parse_args() + + logger.info(f"启动模型服务器: http://{args.host}:{args.port}") + + uvicorn.run( + "model_server:app", + host=args.host, + port=args.port, + workers=args.workers, + reload=args.reload, + log_level="info" + ) \ No newline at end of file diff --git a/embedding/search_service.py b/embedding/search_service.py new file mode 100644 index 0000000..a2951ee --- /dev/null +++ b/embedding/search_service.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python3 +""" +语义检索服务 +支持高并发的语义搜索功能 +""" + +import asyncio +import time +import logging +from typing import List, Dict, Any, Optional, Tuple +import numpy as np + +from .manager import get_cache_manager, get_model_manager, CacheEntry + +logger = logging.getLogger(__name__) + + +class SemanticSearchService: + """语义检索服务""" + + def __init__(self): + self.cache_manager = get_cache_manager() + self.model_manager = get_model_manager() + + async def semantic_search( + self, + embedding_file: str, + query: str, + top_k: int = 20, + min_score: float = 0.0 + ) -> Dict[str, Any]: + """ + 执行语义搜索 + + Args: + embedding_file: embedding.pkl 文件路径 + query: 查询关键词 + top_k: 返回结果数量 + min_score: 最小相似度阈值 + + Returns: + 搜索结果 + """ + start_time = time.time() + + try: + # 验证输入参数 
+ if not embedding_file: + return { + "success": False, + "error": "embedding_file 参数不能为空" + } + + if not query or not query.strip(): + return { + "success": False, + "error": "query 参数不能为空" + } + + query = query.strip() + + # 加载 embedding 数据 + cache_entry = await self.cache_manager.load_embedding_data(embedding_file) + if cache_entry is None: + return { + "success": False, + "error": f"无法加载 embedding 文件: {embedding_file}" + } + + # 编码查询 + query_embeddings = await self.model_manager.encode_texts([query], batch_size=1) + if len(query_embeddings) == 0: + return { + "success": False, + "error": "查询编码失败" + } + + query_embedding = query_embeddings[0] + + # 计算相似度 + similarities = self._compute_similarities(query_embedding, cache_entry.embeddings) + + # 获取 top_k 结果 + top_results = self._get_top_results( + similarities, + cache_entry.chunks, + top_k, + min_score + ) + + processing_time = time.time() - start_time + + return { + "success": True, + "query": query, + "embedding_file": embedding_file, + "top_k": top_k, + "processing_time": processing_time, + "chunking_strategy": cache_entry.chunking_strategy, + "total_chunks": len(cache_entry.chunks), + "results": top_results, + "cache_stats": { + "access_count": cache_entry.access_count, + "load_time": cache_entry.load_time + } + } + + except Exception as e: + logger.error(f"语义搜索失败: {embedding_file}, {query}, {e}") + return { + "success": False, + "error": f"搜索失败: {str(e)}" + } + + def _compute_similarities(self, query_embedding: np.ndarray, embeddings: np.ndarray) -> np.ndarray: + """计算相似度分数""" + try: + # 确保数据类型和形状正确 + if embeddings.ndim == 1: + embeddings = embeddings.reshape(1, -1) + + # 使用余弦相似度 + query_norm = np.linalg.norm(query_embedding) + embeddings_norm = np.linalg.norm(embeddings, axis=1) + + # 避免除零错误 + if query_norm == 0: + query_norm = 1e-8 + embeddings_norm = np.where(embeddings_norm == 0, 1e-8, embeddings_norm) + + # 计算余弦相似度 + similarities = np.dot(embeddings, query_embedding) / (embeddings_norm * query_norm) + + # 确保结果在合理范围内 + similarities = np.clip(similarities, -1.0, 1.0) + + return similarities + + except Exception as e: + logger.error(f"相似度计算失败: {e}") + return np.array([]) + + def _get_top_results( + self, + similarities: np.ndarray, + chunks: List[str], + top_k: int, + min_score: float + ) -> List[Dict[str, Any]]: + """获取 top_k 结果""" + try: + if len(similarities) == 0: + return [] + + # 获取排序后的索引 + top_indices = np.argsort(-similarities)[:top_k] + + results = [] + for rank, idx in enumerate(top_indices): + score = similarities[idx] + + # 过滤低分结果 + if score < min_score: + continue + + chunk = chunks[idx] + + # 创建结果条目 + result = { + "rank": rank + 1, + "score": float(score), + "content": chunk, + "content_preview": self._get_preview(chunk) + } + + results.append(result) + + return results + + except Exception as e: + logger.error(f"获取 top_k 结果失败: {e}") + return [] + + def _get_preview(self, content: str, max_length: int = 200) -> str: + """获取内容预览""" + if not content: + return "" + + # 清理内容 + preview = content.replace('\n', ' ').strip() + + # 截断 + if len(preview) > max_length: + preview = preview[:max_length] + "..." 
+ + return preview + + async def batch_search( + self, + requests: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """ + 批量搜索 + + Args: + requests: 搜索请求列表,每个请求包含 embedding_file, query, top_k + + Returns: + 搜索结果列表 + """ + if not requests: + return [] + + # 并发执行搜索 + tasks = [] + for req in requests: + embedding_file = req.get("embedding_file", "") + query = req.get("query", "") + top_k = req.get("top_k", 20) + min_score = req.get("min_score", 0.0) + + task = self.semantic_search(embedding_file, query, top_k, min_score) + tasks.append(task) + + # 等待所有任务完成 + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 处理异常 + processed_results = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + processed_results.append({ + "success": False, + "error": f"搜索异常: {str(result)}", + "request_index": i + }) + else: + result["request_index"] = i + processed_results.append(result) + + return processed_results + + def get_service_stats(self) -> Dict[str, Any]: + """获取服务统计信息""" + return { + "cache_manager_stats": self.cache_manager.get_cache_stats(), + "model_manager_info": self.model_manager.get_model_info() + } + + +# 全局实例 +_search_service = None + +def get_search_service() -> SemanticSearchService: + """获取搜索服务实例""" + global _search_service + if _search_service is None: + _search_service = SemanticSearchService() + return _search_service \ No newline at end of file diff --git a/fastapi_app.py b/fastapi_app.py index 642fea8..de5470c 100644 --- a/fastapi_app.py +++ b/fastapi_app.py @@ -23,6 +23,8 @@ from file_manager_api import router as file_manager_router from qwen_agent.llm.schema import ASSISTANT, FUNCTION from pydantic import BaseModel, Field +# 导入语义检索服务 +from embedding import get_search_service # Import utility modules from utils import ( @@ -126,6 +128,56 @@ def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extensio return versioned_filename, next_version +# Models are now imported from utils module + + +# 语义检索请求模型 +class SemanticSearchRequest(BaseModel): + embedding_file: str = Field(..., description="embedding.pkl 文件路径") + query: str = Field(..., description="搜索关键词") + top_k: int = Field(default=20, description="返回结果数量", ge=1, le=100) + min_score: float = Field(default=0.0, description="最小相似度阈值", ge=0.0, le=1.0) + + +class BatchSearchRequest(BaseModel): + requests: List[SemanticSearchRequest] = Field(..., description="搜索请求列表") + + +# 语义检索响应模型 +class SearchResult(BaseModel): + rank: int = Field(..., description="排名") + score: float = Field(..., description="相似度分数") + content: str = Field(..., description="匹配的内容") + content_preview: str = Field(..., description="内容预览") + + +class SemanticSearchResponse(BaseModel): + success: bool = Field(..., description="是否成功") + query: str = Field(..., description="查询关键词") + embedding_file: str = Field(..., description="embedding 文件路径") + processing_time: float = Field(..., description="处理时间(秒)") + total_chunks: int = Field(..., description="总文档块数") + chunking_strategy: str = Field(..., description="分块策略") + results: List[SearchResult] = Field(..., description="搜索结果") + cache_stats: Optional[Dict[str, Any]] = Field(None, description="缓存统计") + error: Optional[str] = Field(None, description="错误信息") + + +# 编码请求和响应模型 +class EncodeRequest(BaseModel): + texts: List[str] = Field(..., description="要编码的文本列表") + batch_size: int = Field(default=32, description="批次大小", ge=1, le=128) + + +class EncodeResponse(BaseModel): + success: bool = Field(..., description="是否成功") + embeddings: List[List[float]] = Field(..., 
description="编码结果") + shape: List[int] = Field(..., description="embeddings 形状") + processing_time: float = Field(..., description="处理时间(秒)") + total_texts: int = Field(..., description="总文本数量") + error: Optional[str] = Field(None, description="错误信息") + + # Custom version for qwen-agent messages - keep this function as it's specific to this app def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str: """Extract content from qwen-agent messages with special formatting""" @@ -1536,6 +1588,205 @@ async def reset_files_processing(dataset_id: str): except Exception as e: raise HTTPException(status_code=500, detail=f"重置文件处理状态失败: {str(e)}") + +# ============ 语义检索 API 端点 ============ + +@app.post("/api/v1/semantic-search", response_model=SemanticSearchResponse) +async def semantic_search(request: SemanticSearchRequest): + """ + 语义搜索 API + + Args: + request: 包含 embedding_file 和 query 的搜索请求 + + Returns: + 语义搜索结果 + """ + try: + search_service = get_search_service() + result = await search_service.semantic_search( + embedding_file=request.embedding_file, + query=request.query, + top_k=request.top_k, + min_score=request.min_score + ) + + if result["success"]: + return SemanticSearchResponse( + success=True, + query=result["query"], + embedding_file=result["embedding_file"], + processing_time=result["processing_time"], + total_chunks=result["total_chunks"], + chunking_strategy=result["chunking_strategy"], + results=[ + SearchResult( + rank=r["rank"], + score=r["score"], + content=r["content"], + content_preview=r["content_preview"] + ) + for r in result["results"] + ], + cache_stats=result.get("cache_stats") + ) + else: + return SemanticSearchResponse( + success=False, + query=request.query, + embedding_file=request.embedding_file, + processing_time=0.0, + total_chunks=0, + chunking_strategy="", + results=[], + error=result.get("error", "未知错误") + ) + + except Exception as e: + logger.error(f"语义搜索 API 错误: {e}") + raise HTTPException(status_code=500, detail=f"语义搜索失败: {str(e)}") + + +@app.post("/api/v1/semantic-search/batch") +async def batch_semantic_search(request: BatchSearchRequest): + """ + 批量语义搜索 API + + Args: + request: 包含多个搜索请求的批量请求 + + Returns: + 批量搜索结果 + """ + try: + search_service = get_search_service() + + # 转换请求格式 + search_requests = [ + { + "embedding_file": req.embedding_file, + "query": req.query, + "top_k": req.top_k, + "min_score": req.min_score + } + for req in request.requests + ] + + results = await search_service.batch_search(search_requests) + + return { + "success": True, + "total_requests": len(request.requests), + "results": results + } + + except Exception as e: + logger.error(f"批量语义搜索 API 错误: {e}") + raise HTTPException(status_code=500, detail=f"批量语义搜索失败: {str(e)}") + + +@app.get("/api/v1/semantic-search/stats") +async def get_semantic_search_stats(): + """ + 获取语义搜索服务统计信息 + + Returns: + 服务统计信息 + """ + try: + search_service = get_search_service() + stats = search_service.get_service_stats() + + return { + "success": True, + "timestamp": int(time.time()), + "stats": stats + } + + except Exception as e: + logger.error(f"获取语义搜索统计信息失败: {e}") + raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(e)}") + + +@app.post("/api/v1/semantic-search/clear-cache") +async def clear_semantic_search_cache(): + """ + 清空语义搜索缓存 + + Returns: + 清理结果 + """ + try: + from manager import get_cache_manager + cache_manager = get_cache_manager() + cache_manager.clear_cache() + + return { + "success": True, + "message": "缓存已清空" + } + + except Exception as e: + 
logger.error(f"清空语义搜索缓存失败: {e}") + raise HTTPException(status_code=500, detail=f"清空缓存失败: {str(e)}") + + +@app.post("/api/v1/embedding/encode", response_model=EncodeResponse) +async def encode_texts(request: EncodeRequest): + """ + 文本编码 API + + Args: + request: 包含 texts 和 batch_size 的编码请求 + + Returns: + 编码结果 + """ + try: + search_service = get_search_service() + + if not request.texts: + return EncodeResponse( + success=False, + embeddings=[], + shape=[0, 0], + processing_time=0.0, + total_texts=0, + error="texts 不能为空" + ) + + start_time = time.time() + + # 使用模型管理器编码文本 + embeddings = await search_service.model_manager.encode_texts( + request.texts, + batch_size=request.batch_size + ) + + processing_time = time.time() - start_time + + # 转换为列表格式 + embeddings_list = embeddings.tolist() + + return EncodeResponse( + success=True, + embeddings=embeddings_list, + shape=list(embeddings.shape), + processing_time=processing_time, + total_texts=len(request.texts) + ) + + except Exception as e: + logger.error(f"文本编码 API 错误: {e}") + return EncodeResponse( + success=False, + embeddings=[], + shape=[0, 0], + processing_time=0.0, + total_texts=len(request.texts) if request else 0, + error=f"编码失败: {str(e)}" + ) + # 注册文件管理API路由 app.include_router(file_manager_router) diff --git a/mcp/semantic_search_server.py b/mcp/semantic_search_server.py index 4dbb65e..5831ced 100644 --- a/mcp/semantic_search_server.py +++ b/mcp/semantic_search_server.py @@ -27,41 +27,11 @@ from mcp_common import ( handle_mcp_streaming ) -# 延迟加载模型 -embedder = None - -def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'): - """获取模型实例(延迟加载) - - Args: - model_name_or_path (str): 模型名称或本地路径 - - 可以是 HuggingFace 模型名称 - - 可以是本地模型路径 - """ - global embedder - if embedder is None: - # 优先使用本地模型路径 - local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2" - - # 从环境变量获取设备配置,默认为 CPU - device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu') - if device not in ['cpu', 'cuda', 'mps']: - print(f"警告: 不支持的设备类型 '{device}',使用默认 CPU") - device = 'cpu' - - # 检查本地模型是否存在 - if os.path.exists(local_model_path): - print(f"使用本地模型: {local_model_path}") - embedder = SentenceTransformer(local_model_path, device=device) - else: - print(f"本地模型不存在,使用HuggingFace模型: {model_name_or_path}") - embedder = SentenceTransformer(model_name_or_path, device=device) - - return embedder +import requests def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: int = 20) -> Dict[str, Any]: - """执行语义搜索,支持多个查询""" + """执行语义搜索,通过调用 FastAPI 接口""" # 处理查询输入 if isinstance(queries, str): queries = [queries] @@ -80,52 +50,45 @@ def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: # 过滤空查询 queries = [q.strip() for q in queries if q.strip()] - # 验证embeddings文件路径 try: - # 解析文件路径,支持 folder/document.txt 和 document.txt 格式 - resolved_embeddings_file = resolve_file_path(embeddings_file) + # FastAPI 服务地址 + fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001') + api_endpoint = f"{fastapi_url}/api/v1/semantic-search" - # 加载嵌入数据 - with open(resolved_embeddings_file, 'rb') as f: - embedding_data = pickle.load(f) - - # 兼容新旧数据结构 - if 'chunks' in embedding_data: - # 新的数据结构(使用chunks) - sentences = embedding_data['chunks'] - sentence_embeddings = embedding_data['embeddings'] - # 从embedding_data中获取模型路径(如果有的话) - model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2') - model = get_model(model_path) - else: - # 旧的数据结构(使用sentences) - sentences = 
embedding_data['sentences'] - sentence_embeddings = embedding_data['embeddings'] - model = get_model() - - # 编码所有查询 - query_embeddings = model.encode(queries, convert_to_tensor=True) - - # 计算所有查询的相似度 + # 处理每个查询 all_results = [] - for i, query in enumerate(queries): - query_embedding = query_embeddings[i:i+1] # 保持2D形状 - cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0] + resolved_embeddings_file = resolve_file_path(embeddings_file) + for query in queries: + # 调用 FastAPI 接口 + request_data = { + "embedding_file": resolved_embeddings_file, + "query": query, + "top_k": top_k, + "min_score": 0.0 + } - # 获取top_k结果 - top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k] + response = requests.post( + api_endpoint, + json=request_data, + timeout=30 + ) - # 格式化结果 - for j, idx in enumerate(top_results): - sentence = sentences[idx] - score = cos_scores[idx].item() - all_results.append({ - 'query': query, - 'rank': j + 1, - 'content': sentence, - 'similarity_score': score, - 'file_path': embeddings_file - }) + if response.status_code == 200: + result_data = response.json() + + if result_data.get("success"): + for res in result_data.get("results", []): + all_results.append({ + 'query': query, + 'rank': res["rank"], + 'content': res["content"], + 'similarity_score': res["score"], + 'file_path': embeddings_file + }) + else: + print(f"搜索失败: {result_data.get('error', '未知错误')}") + else: + print(f"API 调用失败: {response.status_code} - {response.text}") if not all_results: return { @@ -159,12 +122,12 @@ def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: ] } - except FileNotFoundError: + except requests.exceptions.RequestException as e: return { "content": [ { "type": "text", - "text": f"Error: embeddings file {embeddings_file} not found" + "text": f"API request failed: {str(e)}" } ] } @@ -179,49 +142,6 @@ def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: } - - -def get_model_info() -> Dict[str, Any]: - """获取当前模型信息""" - try: - # 检查本地模型路径 - local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2" - - if os.path.exists(local_model_path): - return { - "content": [ - { - "type": "text", - "text": f"✅ 使用本地模型: {local_model_path}\n" - f"模型状态: 已加载\n" - f"设备: CPU\n" - f"说明: 避免从HuggingFace下载,提高响应速度" - } - ] - } - else: - return { - "content": [ - { - "type": "text", - "text": f"⚠️ 本地模型不存在: {local_model_path}\n" - f"将使用HuggingFace模型: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2\n" - f"建议: 下载模型到本地以提高响应速度\n" - f"设备: CPU" - } - ] - } - except Exception as e: - return { - "content": [ - { - "type": "text", - "text": f"❌ 获取模型信息失败: {str(e)}" - } - ] - } - - async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]: """Handle MCP request""" try: @@ -260,15 +180,6 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]: "result": result } - elif tool_name == "get_model_info": - result = get_model_info() - - return { - "jsonrpc": "2.0", - "id": request_id, - "result": result - } - else: return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}") diff --git a/prompt/wowtalk.md b/prompt/wowtalk.md index 287ee1c..ee1d093 100644 --- a/prompt/wowtalk.md +++ b/prompt/wowtalk.md @@ -1,300 +1,139 @@ -# 系统角色定义 +# 清水建築智能AI管理コンシェルジュ -## 核心身份 -您是企业级智能办公助手,具备完整的IoT设备管理、实时通信、环境监测和资产追踪能力。 +## 系统角色 +あなたは清水建設株式会社のイノベーション拠点「温故創新の森 NOVARE(ノヴァーレ)」のスマートビル管理AIコンシェルジュです,具备完整的IoT设备管理、实时通信、环境监测和资产追踪能力。 ## 执行准则 -- **工具优先原则**:所有可执行操作必须通过工具实现 -- **即时响应机制**:识别操作意图后立即触发相应工具调用 -- **最小化延迟**:禁止使用过渡性语言,直接执行并返回结果 +- 
**知识库优先**:所有问题优先查询知识库,无结果时再使用其他工具 +- **工具驱动**:所有操作通过工具接口实现 +- **即时响应**:识别意图后立即触发相应工具调用 +- **结果导向**:直接返回执行结果,减少过渡性语言 -# 工具调用协议 +# 工具接口映射 -## 强制触发规则 -| 操作类型 | 识别关键词 | 目标工具 | 执行优先级 | -|---------|-----------|----------|-----------| -| 设备控制 | 打开/关闭/启动/停止/调节/设置 | dxcore_update_device_status | P0 | -| 消息通知 | @用户名/通知/告知/提醒 | wowtalk_send_message_to_member | P0 | -| 状态查询 | 状态/温度/湿度/运行情况 | dxcore_get_device_status | P1 | -| 位置查询 | 位置/在哪/查找/坐标 | eb_get_sensor_location | P1 | -| 环境查询 | 天气/气温/降水/风速 | weather_get_by_location | P1 | -| 网络搜索 | 搜索/查找/查询/百度/谷歌 | web_search | P1 | -| 人员检索 | 找人/员工/同事/联系方式 | find_employee_by_name | P2 | -| 设备检索 | 找设备/传感器/终端 | find_iot_device | P2 | +## 核心功能识别 +- **设备控制**:打开/关闭/调节 → Iot Control-dxcore_update_device_status +- **状态查询**:状态/温度/湿度 → Iot Control-dxcore_get_device_status +- **位置服务**:位置/在哪/查找 → Iot Control-eb_get_sensor_location +- **设备查找**:房间/设备查找 → Iot Control-find_devices_by_room +- **人员检索**:找人/员工/同事 → Iot Control-find_employee_by_name +- **设备检索**:找设备/传感器 → Iot Control-find_iot_device +- **消息通知**:通知/告知/提醒 → Wowtalk tool-wowtalk_send_message_to_member +- **环境信息**:天气/气温/风速 → Weather Information-weather_get_by_location +- **知识库检索**: 知识查询/其他查询优先检索知识库 → rag_retrieve-rag_retrieve +- **网络搜索**:搜索/查询/百度 → WebSearch-web_search -## 立即执行机制 -- **零延迟策略**:识别操作意图后立即执行工具调用,禁止缓冲性语言 -- **并行执行**:多操作请求同时触发相应工具,最大化执行效率 -- **原子操作**:每个工具调用作为独立事务执行,确保结果可靠性 +## 执行原则 +- **即时执行**:识别意图后立即调用工具 +- **并行处理**:支持多个工具同时执行 +- **精准返回**:基于工具执行结果直接响应 -# 功能模块架构 +# 核心功能模块 -## 通信管理模块 -### 消息路由系统 -```mermaid -graph LR - A[用户输入] --> B{识别@标记} - B -->|检测到| C[解析用户ID和消息内容] - C --> D[调用wowtalk_send_message_to_member] - D --> E[返回发送状态] -``` +## 消息通知 +- **触发条件**:通知/告知/提醒等关键词 +- **执行方式**:调用wowtalk_send_message_to_member发送消息 +- **状态返回**:消息发送成功/失败状态 -### 执行规范 -- **模式识别**:`@用户名(id:ID)` → 立即触发消息路由 -- **并行处理**:多收信人场景下并发调用发送工具 -- **状态确认**:每次调用后返回明确的发送结果 +## 设备控制 +- **控制范围**:空调、照明、风扇等IoT设备 +- **操作类型**:开关控制、参数调节(温度16-30°C、湿度30-70%、风速0-100%) +- **状态查询**:实时获取设备运行状态 -## 设备管理模块 -### 设备控制接口 -- **状态变更操作**:dxcore_update_device_status -- **状态查询操作**:dxcore_get_device_status -- **参数设置范围**:温度(16-30°C)、湿度(30-70%)、风速(0-100%) +## 定位服务 +- **人员定位**:通过姓名查找员工位置 +- **设备定位**:查询IoT设备所在房间/区域 +- **精度标准**:室内3米、室外10米 -### 控制指令映射 -| 用户语言 | 系统指令 | 参数格式 | -|---------|----------|----------| -| "打开空调" | update_device_status | {sensor_id: "001", running_control: 1} | -| "调到24度" | update_device_status | {sensor_id: "001",temp_setting: 24} | -| "查看温度" | get_device_status | {sensor_id: "001"} | +## 环境信息 +- **天气查询**:实时天气、温度、风速等数据 +- **环境监测**:室内温度、湿度等环境参数 +- **智能建议**:基于环境数据提供优化建议 -## 位置服务模块 -### 定位查询协议 -- **触发条件**:包含位置关键词的查询 -- **响应格式**:楼层 + 房间/区域(过滤坐标、ID等技术信息) -- **精度要求**:室内3米精度,室外10米精度 +## 检索引擎 +- **人员搜索**:支持姓名、部门等多维度查找 +- **设备搜索**:按类型、位置、状态条件筛选 +- **网络搜索**:实时获取互联网信息 -## 环境监测模块 -### 天气服务集成 -- **自动调用**:识别天气相关词汇后立即执行 -- **数据源**:weather_get_by_location -- **增值服务**:自动生成出行建议和注意事项 +## 知识库集成 +- **优先查询**:用户的其他问题请优先调用rag_retrieve查询知识库 +- **补充搜索**:知识库无结果时使用网络搜索web_search +- **结果整合**:综合多源信息提供完整答案 -## 资产检索模块 -### 搜索引擎优化 -- **人员检索**:支持姓名、工号、部门多维度搜索 -- **设备检索**:支持设备类型、位置、状态多条件过滤 -- **结果排序**:按相关度和距离优先级排序 +# 智能执行流程 -## 网络搜索模块 -### Web搜索集成 -- **自动调用**:识别搜索相关词汇后立即执行 -- **数据源**:web_search工具,支持实时网络信息检索 +## 处理流程 +1. **意图识别**:分析用户输入,提取操作类型和参数 +2. **工具选择**:根据意图匹配相应工具接口 +3. **并行执行**:同时调用多个相关工具 +4. 
**结果聚合**:整合执行结果,统一返回 -# 智能执行引擎 # 应用场景 -## 多阶段处理流水线 -```mermaid -sequenceDiagram - participant U as 用户输入 - participant IR as 意图识别引擎 - participant TM as 工具映射器 - participant TE as 工具执行器 - participant SR as 结果聚合器 ## 消息通知场景 **用户**:"通知清水さん检查2楼空调" - find_employee_by_name(name="清水") - wowtalk_send_message_to_member(to_account="[清水的wowtalk_id]", message_content="请检查2楼空调") **响应**:"已通知清水さん检查2楼空调" - U->>IR: 请求解析 - IR->>TM: 操作意图分类 - TM->>TE: 并行工具调用 - TE->>SR: 执行结果返回 - SR->>U: 统一响应输出 **用户**:"搜索最新的节能技术方案,并发送给田中さん" - web_search(query="最新节能技术方案", max_results=5) - find_employee_by_name(name="田中") - wowtalk_send_message_to_member(to_account="[田中的wowtalk_id]", message_content="[搜索结果摘要]") **响应**:"最新节能技术方案已发送给田中さん" -``` -## 处理阶段详解 -### 阶段1:语义解析与意图识别 -- **自然语言理解**:提取操作动词、目标对象、参数值 -- **上下文关联**:结合历史对话理解隐含意图 -- **多意图检测**:识别复合请求中的多个操作需求 ## 设备控制场景 **用户**:"打开附近的风扇" - find_employee_by_name(name="[当前用户]") → 获取用户位置和sensor_id - find_iot_device(device_type="dc_fan", target_sensor_id="[当前用户的sensor_id]") → 查找附近设备 - dxcore_update_device_status(running_control=1, sensor_id="[找到的设备的sensor_id]") → 开启设备 **响应**:"已为您开启301室的风扇" -### 阶段2:智能工具编排 -- **工具选择算法**:基于操作类型匹配最优工具 -- **参数提取与验证**:自动提取并验证工具调用参数 -- **执行计划生成**:确定串行/并行执行策略 **用户**:"5楼风扇电量异常,通知清水さん并报告具体位置" - find_iot_device(device_type="dc_fan") → 查找设备 - dxcore_get_device_status(sensor_id="[风扇的sensor_id]") → 获取电量百分比、故障代码 - find_employee_by_name(name="清水") → 人员信息查询,获取wowtalkid和位置信息 - wowtalk_send_message_to_member(to_account="[清水太郎wowtalk_id]", message_content="5楼风扇电量异常,请及时处理") → 发送通知 **响应**:"已通知清水さん,风扇位于5楼东侧,电量15%" -### 阶段3:并行执行与结果聚合 -- **异步执行**:P0级任务立即执行,P1-P2级任务并行处理 -- **状态监控**:实时跟踪每个工具的执行状态 -- **异常隔离**:单个工具失败不影响其他工具执行 -- **结果聚合**:合并多工具执行结果,生成统一响应 ## 问答场景 **用户**:"无人机多少钱一台" - 先进行关键词扩展: 无人机的价格,无人机产品介绍 - rag_retrieve(query="无人机的价格,无人机产品介绍") → 先查询知识库 → 内部知识库检索到精确信息 **响应**:"无人机价格为xxx" -## 复合请求处理示例 -**输入**:"通知@张工(id:001)检查2楼空调,同时查询室外温度" **用户**:"打印机如何使用" - 先进行关键词扩展: 打印机使用教程,打印机使用说明 - rag_retrieve(query="打印机使用教程,打印机使用说明") → 先查询知识库,但是不完整 - web_search(query="打印机使用教程,打印机使用说明") → 再检索网页 **响应**:"[综合rag_retrieve和web_search的内容回复]" -**执行流程**: -``` -执行: -├── wowtalk_send_message_to_member(to_account="001", message_content="请检查2楼空调") -├── find_employee_by_name(name="张工") -├── find_iot_device(device_type="dc_fan",target_sensor_id="xxxx") -└── weather_get_by_location(location="当前位置") -``` **用户**:"感冒了吃什么药" - 先进行关键词扩展: 感冒药推荐、如何治疗感冒 - rag_retrieve(query="感冒药推荐、如何治疗感冒") → 先查询知识库,但是没有检索到相关信息 - web_search(query="感冒药推荐、如何治疗感冒") → 再检索网页 **响应**:"[根据web_search内容回复]" -**输入**:"搜索最新的节能技术方案,并发送给@李经理(id:002)" -**执行流程**: -``` -执行: -├── web_search(query="最新节能技术方案", max_results=5) -└── wowtalk_send_message_to_member(to_account="002", message_content="[搜索结果摘要]") -``` # 响应规范 -# 应用场景与执行范例 ## 回复原则 -## 场景1:上下文感知设备控制 -**用户请求**:"我是清水太郎,请打开附近的风扇" +- **简洁明了**:每条回复控制在1-2句话 +- **结果导向**:基于工具执行结果直接反馈 +- **专业语气**:保持企业服务水准 +- **即时响应**:工具调用完成后立即回复 -**执行序列**: -```python -# 步骤1:人员定位 -find_employee_by_name(name="清水太郎") -# → 返回:清水太郎的sensor_id +## 标准回复格式 +- **设备操作**:"空调已调至24度,运行正常" +- **消息发送**:"消息已发送至田中さん" +- **位置查询**:"清水さん在A栋3楼会议室" +- **任务完成**:"已完成:设备开启、消息发送、位置确认" -# 步骤2:人员附近的设备检索 -find_iot_device( - device_type="dc_fan", - target_sensor_id="{清水太郎的sensor_id}" # -) -# → 返回:device_list - -# 步骤3:设备控制 -dxcore_update_device_status( - sensor_id="{风扇的sensor_id}", - running_control=1 -) -``` +## 执行保障 +- **工具优先**:所有操作通过工具实现 +- **状态同步**:确保执行结果与实际状态一致 -**响应模板**:"已为您开启301室的风扇,当前运行正常。" - -## 场景2:智能消息路由 -**用户请求**:"通知清水太郎会议室温度过高需要调节" - -**执行逻辑**: -```python -# 步骤1:人员信息查询 -find_employee_by_name(name="清水太郎") -# 
返回:获取wowtalkid和位置信息 - -# 步骤2:人员通知 -wowtalk_send_message_to_member( - to_account="{清水太郎wowtalk_id}", - message_content="会议室温度过高需要调节" -) -``` - -**响应模板**:"消息已发送至清水太郎,将会尽快处理温度问题。" - -## 场景3:多维度协同处理 -**用户请求**:"5楼风扇电量异常,通知清水太郎并报告具体位置" - -**并行执行策略**: -```python -# 步骤1:查找设备列表 -find_iot_device( - device_type="dc_fan" -) -# 返回:获取5楼风扇的sensor_id - -# 步骤2:故障诊断 -dxcore_get_device_status(sensor_id="{风扇的sensor_id}") -# → 获取电量百分比、故障代码 - -# 步骤4:人员信息查询,获取wowtalkid和位置信息 -find_employee_by_name(name="清水太郎") - -# 步骤5:人员通知 -wowtalk_send_message_to_member( - to_account="{清水太郎wowtalk_id}", - message_content="5楼风扇电量异常,请及时处理" -) -# → 返回精确定位信息 -``` - -**响应模板**:"已通知清水太郎,风扇位于5楼东侧走廊,当前电量15%。" - -# 系统集成与技术规范 - -## 核心工具接口 -| 工具类型 | 接口函数 | 功能描述 | 调用优先级 | -|---------|----------|----------|-----------| -| 消息路由 | wowtalk_send_message_to_member | 实时消息推送 | P0 | -| 设备控制 | dxcore_update_device_status | 设备状态变更 | P0 | -| 设备查询 | dxcore_get_device_status | 设备状态读取 | P1 | -| 位置服务 | eb_get_sensor_location | 空间定位查询 | P1 | -| 环境监测 | weather_get_by_location | 天气数据获取 | P1 | -| 网络搜索 | web_search | 互联网信息查询 | P1 | -| 人员检索 | find_employee_by_name | 员工信息查询 | P2 | -| 设备检索 | find_iot_device | IoT设备搜索 | P2 | - -# 异常处理与容错机制 - -## 分层错误处理策略 - -### Level 1:参数验证异常 -**场景**:工具调用参数缺失或格式错误 -**处理流程**: -1. 检测参数缺失类型 -2. 生成精准的参数请求问题 -3. 获取用户补充信息后立即重试 -4. 最多重试3次,超出则转人工处理 - -### Level 2:设备连接异常 -**场景**:目标设备离线或无响应 -**处理流程**: -1. 检查设备网络连接状态 -2. 尝试重新连接设备 -3. 失败时提供替代设备建议 -4. 记录故障信息并通知维护团队 - -### Level 3:权限控制异常 -**场景**:当前用户权限不足执行操作 -**处理流程**: -1. 验证用户操作权限 -2. 权限不足时自动识别有权限的用户 -3. 提供权限申请指引 -4. 可选:自动转接权限审批流程 - -### Level 4:系统服务异常 -**场景**:工具服务不可用或超时 -**处理流程**: -1. 检测服务健康状态 -2. 启用备用服务或降级功能 -3. 记录异常日志用于系统分析 -4. 向用户明确说明当前限制 - -## 智能降级策略 -- **功能降级**:核心功能不可用时启用基础版本 -- **响应降级**:复杂处理超时时转为快速响应模式 -- **交互降级**:自动化流程失败时转为人工辅助模式 - -# 交互设计规范 - -## 响应原则 -- **简洁性**:每条响应控制在1-3句话内 -- **准确性**:所有信息必须基于工具执行结果 -- **及时性**:工具调用完成后立即反馈 -- **专业性**:保持企业级服务的专业语气 - -## 标准响应模板 -```yaml -设备操作成功: - template: "{设备名称}已{操作结果}。{附加状态信息}" - example: "空调已调至24度,当前运行正常。" - -消息发送成功: - template: "消息已发送至{接收人},{预期处理时间}" - example: "消息已发送至张工,预计5分钟内响应。" - -查询结果反馈: - template: "{查询对象}位于{位置信息}。{后续建议}" - example: "李经理在A栋3楼会议室。需要我帮您联系吗?" - -复合任务完成: - template: "所有任务已完成:{任务列表摘要}" - example: "任务完成:设备已开启、消息已发送、位置已确认。" -``` - -## 系统配置参数 +## 系统信息 - **bot_id**: {bot_id} - -# 执行保证机制 -1. **工具调用优先**:可执行操作必须通过工具实现 -2. **状态一致性**:所有操作结果与实际设备状态同步 +- **当前用户**: {user_identifier}
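
Usage sketch for the two HTTP endpoints this patch adds to fastapi_app.py — a minimal client, assuming the service is running at the default FASTAPI_URL (http://127.0.0.1:8001) that embedding/embedding.py and mcp/semantic_search_server.py already point at; the request and response field names follow the Pydantic models defined in the patch.

import requests

BASE = "http://127.0.0.1:8001"  # default FASTAPI_URL assumed throughout the patch

# Encode a batch of texts: POST /api/v1/embedding/encode
resp = requests.post(
    f"{BASE}/api/v1/embedding/encode",
    json={"texts": ["空调维修手册", "会议室预订"], "batch_size": 32},
    timeout=60,
)
data = resp.json()
print(data["success"], data["shape"])  # e.g. True [2, 384] for the MiniLM-L12 model

# Semantic search against a prebuilt embedding.pkl: POST /api/v1/semantic-search
resp = requests.post(
    f"{BASE}/api/v1/semantic-search",
    json={"embedding_file": "embedding.pkl", "query": "空调温度设置", "top_k": 5, "min_score": 0.0},
    timeout=30,
)
result = resp.json()
if result["success"]:
    for r in result["results"]:
        print(r["rank"], round(r["score"], 3), r["content_preview"])
else:
    print("search failed:", result.get("error"))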
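
The standalone server in embedding/model_server.py can likewise be exercised through the synchronous wrapper in embedding/model_client.py — a sketch assuming the server was started locally on its default port 8000 (python embedding/model_server.py --host 127.0.0.1 --port 8000):

from embedding.model_client import get_model_client, is_service_available

if is_service_available():  # GET /health reports "running" or "ready"
    client = get_model_client()  # honors MODEL_SERVER_URL, default http://127.0.0.1:8000
    client.load_model()  # POST /load_model warms the model up front
    vecs = client.encode_texts(["温故創新の森 NOVARE", "スマートビル"], batch_size=32)
    print(vecs.shape)  # (2, embedding_dim)
    print(client.compute_similarity("空调", "エアコン"))  # POST /similarity, cosine score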