Split the embedding model out into a standalone API

朱潮 2025-11-20 13:29:44 +08:00
parent 3e0b46ecbf
commit b9f6928b50
9 changed files with 1493 additions and 450 deletions

embedding/__init__.py (new file, +17)

@@ -0,0 +1,17 @@
"""
Embedding Package
提供文本编码和语义搜索功能
"""
from .manager import get_cache_manager, get_model_manager
from .search_service import get_search_service
from .embedding import embed_document, semantic_search, split_document_by_pages
__all__ = [
'get_cache_manager',
'get_model_manager',
'get_search_service',
'embed_document',
'semantic_search',
'split_document_by_pages'
]
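
For reference, a minimal usage sketch of the surface exported here (hypothetical file names; assumes the FastAPI service that backs `encode_texts_via_api` is reachable via `FASTAPI_URL`):

```python
# Build an embedding file for a document, then query it.
from embedding import embed_document, semantic_search

embed_document(input_file='document.txt', output_file='embedding.pkl')
results = semantic_search('how do I reset the device?',
                          embeddings_file='embedding.pkl', top_k=5)
for chunk, score in results:
    print(round(score, 3), chunk[:60])
```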

embedding/embedding.py (modified)

@@ -1,42 +1,51 @@
 import pickle
 import re
 import numpy as np
-from sentence_transformers import SentenceTransformer, util
+import os
+from typing import Optional
+import requests
+import asyncio
 
-# Lazy-loaded model
-embedder = None
-
-def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
-    """Get the model instance (lazy loading).
-    Args:
-        model_name_or_path (str): model name or local path
-            - may be a HuggingFace model name
-            - may be a local model path
-    """
-    global embedder
-    if embedder is None:
-        print("Loading model...")
-        print(f"Model path: {model_name_or_path}")
-        # Check whether this is a local path
-        import os
-        if os.path.exists(model_name_or_path):
-            print("Using local model")
-        else:
-            print("Using HuggingFace model")
-
-        # Read the device from the environment; default to CPU
-        device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
-        if device not in ['cpu', 'cuda', 'mps']:
-            print(f"Warning: unsupported device '{device}', falling back to CPU")
-            device = 'cpu'
-        print(f"Using device: {device}")
-        embedder = SentenceTransformer(model_name_or_path, device=device)
-        print("Model loaded")
-    return embedder
+def encode_texts_via_api(texts, batch_size=32):
+    """Encode texts through the embedding API."""
+    if not texts:
+        return np.array([])
+
+    try:
+        # FastAPI service address
+        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
+        api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"
+
+        # Call the encode endpoint
+        request_data = {
+            "texts": texts,
+            "batch_size": batch_size
+        }
+        response = requests.post(
+            api_endpoint,
+            json=request_data,
+            timeout=60  # generous timeout for large batches
+        )
+        if response.status_code == 200:
+            result_data = response.json()
+            if result_data.get("success"):
+                embeddings_list = result_data.get("embeddings", [])
+                print(f"API encoding succeeded: {len(texts)} texts, embedding dimension: {len(embeddings_list[0]) if embeddings_list else 0}")
+                return np.array(embeddings_list)
+            else:
+                error_msg = result_data.get('error', 'unknown error')
+                print(f"API encoding failed: {error_msg}")
+                raise Exception(f"API encoding failed: {error_msg}")
+        else:
+            print(f"API request failed: {response.status_code} - {response.text}")
+            raise Exception(f"API request failed: {response.status_code}")
+
+    except Exception as e:
+        print(f"API encoding error: {e}")
+        raise
 
 def clean_text(text):
     """
@@ -192,23 +201,16 @@ def embed_document(input_file='document.txt', output_file='embedding.pkl',
     print(f"Processing {len(chunks)} content chunks...")
 
-    # Set the model path inside the function
-    local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
-    import os
-    if os.path.exists(local_model_path):
-        model_path = local_model_path
-    else:
-        model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
-    model = get_model(model_path)
-    chunk_embeddings = model.encode(chunks, convert_to_tensor=True)
+    # Encode through the API
+    print("Encoding via the API...")
+    chunk_embeddings = encode_texts_via_api(chunks, batch_size=32)
 
     embedding_data = {
         'chunks': chunks,
         'embeddings': chunk_embeddings,
         'chunking_strategy': chunking_strategy,
         'chunking_params': chunking_params,
-        'model_path': model_path
+        'model_path': 'api_service'
     }
 
     with open(output_file, 'wb') as f:
@@ -254,18 +256,28 @@ def semantic_search(user_query, embeddings_file='embedding.pkl', top_k=20):
         chunking_strategy = 'line'
         content_type = "sentences"
 
-    # Use the model path stored in embedding_data, if any
-    model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
-    model = get_model(model_path)
-    query_embedding = model.encode(user_query, convert_to_tensor=True)
-    cos_scores = util.cos_sim(query_embedding, chunk_embeddings)[0]
-
-    # Handle tensor conversion across GPU/CPU environments
-    if cos_scores.is_cuda:
-        cos_scores_np = cos_scores.cpu().numpy()
-    else:
-        cos_scores_np = cos_scores.numpy()
+    # Encode the query through the API
+    print("Encoding query via the API...")
+    query_embeddings = encode_texts_via_api([user_query], batch_size=1)
+    query_embedding = query_embeddings[0] if len(query_embeddings) > 0 else np.array([])
+
+    # Cosine similarity against all chunk embeddings
+    if len(chunk_embeddings.shape) > 1:
+        cos_scores = np.dot(chunk_embeddings, query_embedding) / (
+            np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8
+        )
+    else:
+        cos_scores = np.array([0.0])  # compatibility fallback, kept as an array so the branch below holds
+
+    # Normalize cos_scores across possible formats
+    if isinstance(cos_scores, np.ndarray):
+        cos_scores_np = cos_scores
+    else:
+        # PyTorch tensor
+        if hasattr(cos_scores, 'is_cuda') and cos_scores.is_cuda:
+            cos_scores_np = cos_scores.cpu().numpy()
+        else:
+            cos_scores_np = cos_scores.numpy()
 
     top_results = np.argsort(-cos_scores_np)[:top_k]
@@ -273,7 +285,7 @@ def semantic_search(user_query, embeddings_file='embedding.pkl', top_k=20):
     print(f"\nTop {top_k} {content_type} most relevant to the query (chunking strategy: {chunking_strategy}):")
     for i, idx in enumerate(top_results):
         chunk = chunks[idx]
-        score = cos_scores[idx].item()
+        score = cos_scores_np[idx]
         results.append((chunk, score))
         # Show a preview when the content is long
         preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
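
For reference, a sketch of the request/response contract that `encode_texts_via_api` relies on (field names taken from the code above; the address is the same `FASTAPI_URL` default):

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8001/api/v1/embedding/encode",
    json={"texts": ["hello world"], "batch_size": 32},
    timeout=60,
)
data = resp.json()
# Expected on success: {"success": true, "embeddings": [[...]], "shape": [1, 384], ...}
# (384 dimensions for the MiniLM model used by this service)
assert data["success"]
vector = data["embeddings"][0]
```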

embedding/manager.py (new file, +333)

@@ -0,0 +1,333 @@
#!/usr/bin/env python3
"""
Model pool manager and cache system.
Supports a high-concurrency embedding retrieval service.
"""
import os
import asyncio
import time
import pickle
import hashlib
import logging
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
from collections import OrderedDict
import threading
import psutil
import numpy as np
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


@dataclass
class CacheEntry:
    """A single cache entry."""
    embeddings: np.ndarray
    chunks: List[str]
    chunking_strategy: str
    chunking_params: Dict[str, Any]
    model_path: str
    file_path: str
    file_mtime: float        # file modification time
    access_count: int        # number of accesses
    last_access_time: float  # time of last access
    load_time: float         # time taken to load
    memory_size: int         # memory footprint in bytes


class EmbeddingCacheManager:
    """Cache manager for embedding data."""

    def __init__(self, max_cache_size: int = 5, max_memory_mb: int = 1024):
        self.max_cache_size = max_cache_size  # maximum number of cache entries
        self.max_memory_bytes = max_memory_mb * 1024 * 1024  # maximum memory usage
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._lock = threading.RLock()
        self._current_memory_usage = 0
        logger.info(f"EmbeddingCacheManager initialized: max_size={max_cache_size}, max_memory={max_memory_mb}MB")

    def _get_file_key(self, file_path: str) -> str:
        """Build a cache key for a file."""
        # Use the absolute path plus the file's mtime as a unique key
        try:
            abs_path = os.path.abspath(file_path)
            mtime = os.path.getmtime(abs_path)
            key_data = f"{abs_path}:{mtime}"
            return hashlib.md5(key_data.encode()).hexdigest()
        except Exception as e:
            logger.warning(f"Failed to build file key: {file_path}, {e}")
            return hashlib.md5(file_path.encode()).hexdigest()

    def _estimate_memory_size(self, embeddings: np.ndarray, chunks: List[str]) -> int:
        """Estimate the memory footprint of an entry."""
        try:
            embeddings_size = embeddings.nbytes
            chunks_size = sum(len(chunk.encode('utf-8')) for chunk in chunks)
            overhead = 1024 * 1024  # 1MB of overhead
            return embeddings_size + chunks_size + overhead
        except Exception:
            return 100 * 1024 * 1024  # default to 100MB

    def _cleanup_cache(self):
        """Evict entries to free memory."""
        with self._lock:
            # Sort by access time and count; evict the least-used entries
            entries = list(self._cache.items())

            # Work out which entries to evict
            to_remove = []
            current_size = len(self._cache)

            # If the entry count is over the limit, evict the oldest entries
            if current_size > self.max_cache_size:
                to_remove.extend(entries[:current_size - self.max_cache_size])

            # If memory is over the limit, evict by LRU
            if self._current_memory_usage > self.max_memory_bytes:
                # Sort by last access time
                entries.sort(key=lambda x: x[1].last_access_time)
                accumulated_size = 0
                for key, entry in entries:
                    if accumulated_size >= self._current_memory_usage - self.max_memory_bytes:
                        break
                    to_remove.append((key, entry))
                    accumulated_size += entry.memory_size

            # Perform the eviction
            for key, entry in to_remove:
                if key in self._cache:
                    del self._cache[key]
                    self._current_memory_usage -= entry.memory_size
                    logger.info(f"Evicted cache entry: {key} ({entry.memory_size / 1024 / 1024:.1f}MB)")

    async def load_embedding_data(self, file_path: str) -> Optional[CacheEntry]:
        """Load embedding data."""
        cache_key = self._get_file_key(file_path)

        # Check the cache first
        with self._lock:
            if cache_key in self._cache:
                entry = self._cache[cache_key]
                entry.access_count += 1
                entry.last_access_time = time.time()
                # Move to the end (most recently used)
                self._cache.move_to_end(cache_key)
                logger.info(f"Cache hit: {file_path}")
                return entry

        # Cache miss: load the data
        try:
            start_time = time.time()

            # Make sure the file exists
            if not os.path.exists(file_path):
                logger.error(f"File not found: {file_path}")
                return None

            # Load the embedding data
            with open(file_path, 'rb') as f:
                embedding_data = pickle.load(f)

            # Stay compatible with both old and new data layouts
            if 'chunks' in embedding_data:
                chunks = embedding_data['chunks']
                embeddings = embedding_data['embeddings']
                chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
                chunking_params = embedding_data.get('chunking_params', {})
                model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
            else:
                chunks = embedding_data['sentences']
                embeddings = embedding_data['embeddings']
                chunking_strategy = 'line'
                chunking_params = {}
                model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'

            # Make sure embeddings is a numpy array
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)

            # Build the cache entry
            load_time = time.time() - start_time
            file_mtime = os.path.getmtime(file_path)
            memory_size = self._estimate_memory_size(embeddings, chunks)

            entry = CacheEntry(
                embeddings=embeddings,
                chunks=chunks,
                chunking_strategy=chunking_strategy,
                chunking_params=chunking_params,
                model_path=model_path,
                file_path=file_path,
                file_mtime=file_mtime,
                access_count=1,
                last_access_time=time.time(),
                load_time=load_time,
                memory_size=memory_size
            )

            # Add to the cache
            with self._lock:
                self._cache[cache_key] = entry
                self._current_memory_usage += memory_size
                # Evict if needed
                self._cleanup_cache()

            logger.info(f"Loaded: {file_path} ({memory_size / 1024 / 1024:.1f}MB, {load_time:.2f}s)")
            return entry

        except Exception as e:
            logger.error(f"Failed to load embedding data: {file_path}, {e}")
            return None

    def get_cache_stats(self) -> Dict[str, Any]:
        """Return cache statistics."""
        with self._lock:
            return {
                "cache_size": len(self._cache),
                "max_cache_size": self.max_cache_size,
                "memory_usage_mb": self._current_memory_usage / 1024 / 1024,
                "max_memory_mb": self.max_memory_bytes / 1024 / 1024,
                "memory_usage_percent": (self._current_memory_usage / self.max_memory_bytes) * 100,
                "entries": [
                    {
                        "file_path": entry.file_path,
                        "access_count": entry.access_count,
                        "last_access_time": entry.last_access_time,
                        "memory_size_mb": entry.memory_size / 1024 / 1024
                    }
                    for entry in self._cache.values()
                ]
            }

    def clear_cache(self):
        """Clear the cache."""
        with self._lock:
            cleared_count = len(self._cache)
            cleared_memory = self._current_memory_usage
            self._cache.clear()
            self._current_memory_usage = 0
            logger.info(f"Cache cleared: {cleared_count} entries, {cleared_memory / 1024 / 1024:.1f}MB")


class GlobalModelManager:
    """Global model manager."""

    def __init__(self, model_name: str = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
        self.model_name = model_name
        self.local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
        self._model: Optional[SentenceTransformer] = None
        self._lock = asyncio.Lock()
        self._load_time = 0
        self._device = 'cpu'
        logger.info(f"GlobalModelManager initialized: {model_name}")

    async def get_model(self) -> SentenceTransformer:
        """Get the model instance (lazy loading)."""
        if self._model is not None:
            return self._model

        async with self._lock:
            # Double-checked locking
            if self._model is not None:
                return self._model

            try:
                start_time = time.time()

                # Prefer the local model if present
                model_path = self.local_model_path if os.path.exists(self.local_model_path) else self.model_name

                # Read the device configuration
                self._device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
                if self._device not in ['cpu', 'cuda', 'mps']:
                    self._device = 'cpu'

                logger.info(f"Loading model: {model_path} (device: {self._device})")

                # Run the blocking load off the event loop
                loop = asyncio.get_event_loop()
                self._model = await loop.run_in_executor(
                    None,
                    lambda: SentenceTransformer(model_path, device=self._device)
                )

                self._load_time = time.time() - start_time
                logger.info(f"Model loaded in {self._load_time:.2f}s")
                return self._model

            except Exception as e:
                logger.error(f"Model loading failed: {e}")
                raise

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into vectors."""
        if not texts:
            return np.array([])

        model = await self.get_model()

        try:
            # Run the blocking encode off the event loop
            loop = asyncio.get_event_loop()
            embeddings = await loop.run_in_executor(
                None,
                lambda: model.encode(texts, batch_size=batch_size, show_progress_bar=False)
            )

            # Make sure we return a numpy array
            if hasattr(embeddings, 'cpu'):
                embeddings = embeddings.cpu().numpy()
            elif hasattr(embeddings, 'numpy'):
                embeddings = embeddings.numpy()
            elif not isinstance(embeddings, np.ndarray):
                embeddings = np.array(embeddings)

            return embeddings

        except Exception as e:
            logger.error(f"Text encoding failed: {e}")
            raise

    def get_model_info(self) -> Dict[str, Any]:
        """Return model information."""
        return {
            "model_name": self.model_name,
            "local_model_path": self.local_model_path,
            "device": self._device,
            "is_loaded": self._model is not None,
            "load_time": self._load_time
        }


# Global singletons
_cache_manager = None
_model_manager = None


def get_cache_manager() -> EmbeddingCacheManager:
    """Get the cache manager singleton."""
    global _cache_manager
    if _cache_manager is None:
        max_cache_size = int(os.getenv("EMBEDDING_MAX_CACHE_SIZE", "5"))
        max_memory_mb = int(os.getenv("EMBEDDING_MAX_MEMORY_MB", "1024"))
        _cache_manager = EmbeddingCacheManager(max_cache_size, max_memory_mb)
    return _cache_manager


def get_model_manager() -> GlobalModelManager:
    """Get the model manager singleton."""
    global _model_manager
    if _model_manager is None:
        model_name = os.getenv("SENTENCE_TRANSFORMER_MODEL", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
        _model_manager = GlobalModelManager(model_name)
    return _model_manager
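
A short usage sketch of how the two singletons are meant to be used together (assumes an `embedding.pkl` produced by `embed_document`):

```python
import asyncio
from embedding.manager import get_cache_manager, get_model_manager

async def main():
    # Encode a query with the shared model (loaded lazily, off the event loop)
    vecs = await get_model_manager().encode_texts(["hello"], batch_size=1)
    # Load (and cache) a previously generated embedding file
    entry = await get_cache_manager().load_embedding_data("embedding.pkl")
    if entry is not None:
        print(vecs.shape, len(entry.chunks), entry.chunking_strategy)

asyncio.run(main())
```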

embedding/model_client.py (new file, +181)

@@ -0,0 +1,181 @@
#!/usr/bin/env python3
"""
Model service client.
Client-side interface for talking to model_server.py.
"""
import os
import asyncio
import logging
from typing import List, Union, Dict, Any
import numpy as np
import aiohttp
import json
from sentence_transformers import util

logger = logging.getLogger(__name__)


class ModelClient:
    """Async client for the model service."""

    def __init__(self, base_url: str = "http://127.0.0.1:8000"):
        self.base_url = base_url.rstrip('/')
        self.session = None

    async def __aenter__(self):
        self.session = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=300)  # 5-minute timeout
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self.session:
            await self.session.close()

    async def health_check(self) -> Dict[str, Any]:
        """Check service health."""
        try:
            async with self.session.get(f"{self.base_url}/health") as response:
                return await response.json()
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {"status": "unavailable", "error": str(e)}

    async def load_model(self) -> bool:
        """Preload the model."""
        try:
            async with self.session.post(f"{self.base_url}/load_model") as response:
                result = await response.json()
                return response.status == 200
        except Exception as e:
            logger.error(f"Model load failed: {e}")
            return False

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into vectors."""
        if not texts:
            return np.array([])

        try:
            data = {
                "texts": texts,
                "batch_size": batch_size
            }
            async with self.session.post(
                f"{self.base_url}/encode",
                json=data
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"Encoding failed: {error_text}")
                result = await response.json()
                return np.array(result["embeddings"])
        except Exception as e:
            logger.error(f"Text encoding failed: {e}")
            raise

    async def compute_similarity(self, text1: str, text2: str) -> float:
        """Compute the similarity of two texts."""
        try:
            data = {
                "text1": text1,
                "text2": text2
            }
            async with self.session.post(
                f"{self.base_url}/similarity",
                json=data
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"Similarity computation failed: {error_text}")
                result = await response.json()
                return result["similarity_score"]
        except Exception as e:
            logger.error(f"Similarity computation failed: {e}")
            raise


# Synchronous wrapper around the async client
class SyncModelClient:
    """Synchronous model client."""

    def __init__(self, base_url: str = "http://127.0.0.1:8000"):
        self.base_url = base_url
        self._async_client = ModelClient(base_url)

    def _run_async(self, coro):
        """Run an async coroutine to completion."""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(coro)

    def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Synchronous text encoding."""
        async def _encode():
            async with self._async_client as client:
                return await client.encode_texts(texts, batch_size)
        return self._run_async(_encode())

    def compute_similarity(self, text1: str, text2: str) -> float:
        """Synchronous similarity computation."""
        async def _similarity():
            async with self._async_client as client:
                return await client.compute_similarity(text1, text2)
        return self._run_async(_similarity())

    def health_check(self) -> Dict[str, Any]:
        """Synchronous health check."""
        async def _health():
            async with self._async_client as client:
                return await client.health_check()
        return self._run_async(_health())

    def load_model(self) -> bool:
        """Synchronous model preload."""
        async def _load():
            async with self._async_client as client:
                return await client.load_model()
        return self._run_async(_load())


# Global client instance (singleton)
_client_instance = None


def get_model_client(base_url: str = None) -> SyncModelClient:
    """Get the model client singleton."""
    global _client_instance
    if _client_instance is None:
        url = base_url or os.environ.get('MODEL_SERVER_URL', 'http://127.0.0.1:8000')
        _client_instance = SyncModelClient(url)
    return _client_instance


# Convenience functions
def encode_texts(texts: List[str], batch_size: int = 32) -> np.ndarray:
    """Convenience wrapper for text encoding."""
    client = get_model_client()
    return client.encode_texts(texts, batch_size)


def compute_similarity(text1: str, text2: str) -> float:
    """Convenience wrapper for similarity computation."""
    client = get_model_client()
    return client.compute_similarity(text1, text2)


def is_service_available() -> bool:
    """Check whether the model service is reachable."""
    client = get_model_client()
    health = client.health_check()
    return health.get("status") == "running" or health.get("status") == "ready"
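
A usage sketch for the synchronous wrapper (assumes `model_server.py` is listening on the default port):

```python
from embedding.model_client import get_model_client, is_service_available

client = get_model_client("http://127.0.0.1:8000")
if is_service_available():
    embs = client.encode_texts(["first paragraph", "second paragraph"], batch_size=2)
    print(embs.shape)  # e.g. (2, 384)
    print(client.compute_similarity("cat", "dog"))  # cosine similarity
```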

embedding/model_server.py (new file, +240)

@@ -0,0 +1,240 @@
#!/usr/bin/env python3
"""
Standalone model server.
Serves embeddings from a single process to cut memory usage.
"""
import os
import asyncio
import logging
from typing import List, Dict, Any
import numpy as np
from sentence_transformers import SentenceTransformer
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from pydantic import BaseModel

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model instance
model = None
model_config = {
    "model_name": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "local_model_path": "./models/paraphrase-multilingual-MiniLM-L12-v2",
    "device": "cpu"
}

# Pydantic models
class TextRequest(BaseModel):
    texts: List[str]
    batch_size: int = 32

class SimilarityRequest(BaseModel):
    text1: str
    text2: str

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    model_path: str
    device: str

class EmbeddingResponse(BaseModel):
    embeddings: List[List[float]]
    shape: List[int]

class SimilarityResponse(BaseModel):
    similarity_score: float


class ModelServer:
    def __init__(self):
        self.model = None
        self.model_loaded = False

    async def load_model(self):
        """Lazy-load the model."""
        if self.model_loaded:
            return

        logger.info("Loading model...")

        # Prefer a local model if it exists
        if os.path.exists(model_config["local_model_path"]):
            model_path = model_config["local_model_path"]
            logger.info(f"Using local model: {model_path}")
        else:
            model_path = model_config["model_name"]
            logger.info(f"Using HuggingFace model: {model_path}")

        # Read the device from the environment
        device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', model_config["device"])
        if device not in ['cpu', 'cuda', 'mps']:
            logger.warning(f"Unsupported device '{device}', falling back to CPU")
            device = 'cpu'
        logger.info(f"Using device: {device}")

        try:
            self.model = SentenceTransformer(model_path, device=device)
            self.model_loaded = True
            model_config["device"] = device
            model_config["current_model_path"] = model_path
            logger.info("Model loaded")
        except Exception as e:
            logger.error(f"Model loading failed: {e}")
            raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")

    async def encode_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode texts into vectors."""
        if not texts:
            return np.array([])

        await self.load_model()

        try:
            embeddings = self.model.encode(
                texts,
                batch_size=batch_size,
                convert_to_tensor=True,
                show_progress_bar=False
            )
            # Convert to a CPU numpy array
            if embeddings.is_cuda:
                embeddings = embeddings.cpu().numpy()
            else:
                embeddings = embeddings.numpy()
            return embeddings
        except Exception as e:
            logger.error(f"Encoding failed: {e}")
            raise HTTPException(status_code=500, detail=f"Encoding failed: {str(e)}")

    async def compute_similarity(self, text1: str, text2: str) -> float:
        """Compute the similarity of two texts."""
        if not text1 or not text2:
            raise HTTPException(status_code=400, detail="Texts must not be empty")

        await self.load_model()

        try:
            embeddings = self.model.encode([text1, text2], convert_to_tensor=True)
            from sentence_transformers import util
            similarity = util.cos_sim(embeddings[0:1], embeddings[1:2])[0][0]
            return similarity.item()
        except Exception as e:
            logger.error(f"Similarity computation failed: {e}")
            raise HTTPException(status_code=500, detail=f"Similarity computation failed: {str(e)}")


# Server instance
server = ModelServer()

# FastAPI app
app = FastAPI(
    title="Embedding Model Server",
    description="Unified sentence-embedding and similarity service",
    version="1.0.0"
)

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Health check
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check."""
    try:
        if not server.model_loaded:
            return HealthResponse(
                status="ready",
                model_loaded=False,
                model_path="",
                device=model_config["device"]
            )
        return HealthResponse(
            status="running",
            model_loaded=True,
            model_path=model_config.get("current_model_path", ""),
            device=model_config["device"]
        )
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# Model preload
@app.post("/load_model")
async def preload_model():
    """Preload the model."""
    await server.load_model()
    return {"message": "Model loaded"}

# Text embedding
@app.post("/encode", response_model=EmbeddingResponse)
async def encode_texts(request: TextRequest):
    """Encode texts into vectors."""
    try:
        embeddings = await server.encode_texts(request.texts, request.batch_size)
        return EmbeddingResponse(
            embeddings=embeddings.tolist(),
            shape=list(embeddings.shape)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Similarity
@app.post("/similarity", response_model=SimilarityResponse)
async def compute_similarity(request: SimilarityRequest):
    """Compute the similarity of two texts."""
    try:
        similarity = await server.compute_similarity(request.text1, request.text2)
        return SimilarityResponse(similarity_score=similarity)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Batch encoding (for document processing)
@app.post("/batch_encode")
async def batch_encode(texts: List[str], batch_size: int = 64):
    """Encode a large batch of texts."""
    if not texts:
        return {"embeddings": [], "shape": [0, 0]}
    try:
        embeddings = await server.encode_texts(texts, batch_size)
        return {
            "embeddings": embeddings.tolist(),
            "shape": list(embeddings.shape)
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Start the embedding model server")
    parser.add_argument("--host", default="127.0.0.1", help="server host")
    parser.add_argument("--port", type=int, default=8000, help="server port")
    parser.add_argument("--workers", type=int, default=1, help="number of worker processes")
    parser.add_argument("--reload", action="store_true", help="auto-reload for development")

    args = parser.parse_args()

    logger.info(f"Starting model server: http://{args.host}:{args.port}")
    uvicorn.run(
        "model_server:app",
        host=args.host,
        port=args.port,
        workers=args.workers,
        reload=args.reload,
        log_level="info"
    )
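
A quick smoke test against a locally started server (a sketch; start it first with `python model_server.py --port 8000`):

```python
import requests

base = "http://127.0.0.1:8000"
requests.post(f"{base}/load_model")  # warm the model up front
r = requests.post(f"{base}/encode", json={"texts": ["hello"], "batch_size": 1})
print(r.json()["shape"])  # e.g. [1, 384]
r = requests.post(f"{base}/similarity", json={"text1": "hi", "text2": "hello"})
print(r.json()["similarity_score"])
```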

embedding/search_service.py (new file, +259)

@@ -0,0 +1,259 @@
#!/usr/bin/env python3
"""
Semantic retrieval service.
Supports high-concurrency semantic search.
"""
import asyncio
import time
import logging
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
from .manager import get_cache_manager, get_model_manager, CacheEntry

logger = logging.getLogger(__name__)


class SemanticSearchService:
    """Semantic retrieval service."""

    def __init__(self):
        self.cache_manager = get_cache_manager()
        self.model_manager = get_model_manager()

    async def semantic_search(
        self,
        embedding_file: str,
        query: str,
        top_k: int = 20,
        min_score: float = 0.0
    ) -> Dict[str, Any]:
        """
        Run a semantic search.
        Args:
            embedding_file: path to the embedding.pkl file
            query: query keywords
            top_k: number of results to return
            min_score: minimum similarity threshold
        Returns:
            the search result
        """
        start_time = time.time()

        try:
            # Validate the input
            if not embedding_file:
                return {
                    "success": False,
                    "error": "embedding_file must not be empty"
                }

            if not query or not query.strip():
                return {
                    "success": False,
                    "error": "query must not be empty"
                }

            query = query.strip()

            # Load the embedding data
            cache_entry = await self.cache_manager.load_embedding_data(embedding_file)
            if cache_entry is None:
                return {
                    "success": False,
                    "error": f"could not load embedding file: {embedding_file}"
                }

            # Encode the query
            query_embeddings = await self.model_manager.encode_texts([query], batch_size=1)
            if len(query_embeddings) == 0:
                return {
                    "success": False,
                    "error": "query encoding failed"
                }

            query_embedding = query_embeddings[0]

            # Compute similarities
            similarities = self._compute_similarities(query_embedding, cache_entry.embeddings)

            # Take the top_k results
            top_results = self._get_top_results(
                similarities,
                cache_entry.chunks,
                top_k,
                min_score
            )

            processing_time = time.time() - start_time

            return {
                "success": True,
                "query": query,
                "embedding_file": embedding_file,
                "top_k": top_k,
                "processing_time": processing_time,
                "chunking_strategy": cache_entry.chunking_strategy,
                "total_chunks": len(cache_entry.chunks),
                "results": top_results,
                "cache_stats": {
                    "access_count": cache_entry.access_count,
                    "load_time": cache_entry.load_time
                }
            }

        except Exception as e:
            logger.error(f"Semantic search failed: {embedding_file}, {query}, {e}")
            return {
                "success": False,
                "error": f"search failed: {str(e)}"
            }

    def _compute_similarities(self, query_embedding: np.ndarray, embeddings: np.ndarray) -> np.ndarray:
        """Compute similarity scores."""
        try:
            # Make sure dtypes and shapes are right
            if embeddings.ndim == 1:
                embeddings = embeddings.reshape(1, -1)

            # Cosine similarity
            query_norm = np.linalg.norm(query_embedding)
            embeddings_norm = np.linalg.norm(embeddings, axis=1)

            # Guard against division by zero
            if query_norm == 0:
                query_norm = 1e-8
            embeddings_norm = np.where(embeddings_norm == 0, 1e-8, embeddings_norm)

            # Compute the cosine similarities
            similarities = np.dot(embeddings, query_embedding) / (embeddings_norm * query_norm)

            # Clamp to a sane range
            similarities = np.clip(similarities, -1.0, 1.0)

            return similarities

        except Exception as e:
            logger.error(f"Similarity computation failed: {e}")
            return np.array([])

    def _get_top_results(
        self,
        similarities: np.ndarray,
        chunks: List[str],
        top_k: int,
        min_score: float
    ) -> List[Dict[str, Any]]:
        """Take the top_k results."""
        try:
            if len(similarities) == 0:
                return []

            # Indices sorted by descending score
            top_indices = np.argsort(-similarities)[:top_k]

            results = []
            for rank, idx in enumerate(top_indices):
                score = similarities[idx]

                # Drop low-scoring results
                if score < min_score:
                    continue

                chunk = chunks[idx]

                # Build the result entry
                result = {
                    "rank": rank + 1,
                    "score": float(score),
                    "content": chunk,
                    "content_preview": self._get_preview(chunk)
                }
                results.append(result)

            return results

        except Exception as e:
            logger.error(f"Failed to collect top_k results: {e}")
            return []

    def _get_preview(self, content: str, max_length: int = 200) -> str:
        """Build a content preview."""
        if not content:
            return ""

        # Clean the content
        preview = content.replace('\n', ' ').strip()

        # Truncate
        if len(preview) > max_length:
            preview = preview[:max_length] + "..."

        return preview

    async def batch_search(
        self,
        requests: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Batch search.
        Args:
            requests: a list of search requests, each carrying embedding_file, query, top_k
        Returns:
            a list of search results
        """
        if not requests:
            return []

        # Run the searches concurrently
        tasks = []
        for req in requests:
            embedding_file = req.get("embedding_file", "")
            query = req.get("query", "")
            top_k = req.get("top_k", 20)
            min_score = req.get("min_score", 0.0)

            task = self.semantic_search(embedding_file, query, top_k, min_score)
            tasks.append(task)

        # Wait for all of them
        results = await asyncio.gather(*tasks, return_exceptions=True)

        # Convert exceptions into error entries
        processed_results = []
        for i, result in enumerate(results):
            if isinstance(result, Exception):
                processed_results.append({
                    "success": False,
                    "error": f"search error: {str(result)}",
                    "request_index": i
                })
            else:
                result["request_index"] = i
                processed_results.append(result)

        return processed_results

    def get_service_stats(self) -> Dict[str, Any]:
        """Return service statistics."""
        return {
            "cache_manager_stats": self.cache_manager.get_cache_stats(),
            "model_manager_info": self.model_manager.get_model_info()
        }


# Global singleton
_search_service = None


def get_search_service() -> SemanticSearchService:
    """Get the search-service singleton."""
    global _search_service
    if _search_service is None:
        _search_service = SemanticSearchService()
    return _search_service
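
A direct, in-process usage sketch of the service (bypassing the HTTP layer; assumes an existing `embedding.pkl`):

```python
import asyncio
from embedding.search_service import get_search_service

async def main():
    service = get_search_service()
    result = await service.semantic_search(
        embedding_file="embedding.pkl",
        query="maintenance schedule",
        top_k=5,
        min_score=0.2,
    )
    if result["success"]:
        for hit in result["results"]:
            print(hit["rank"], round(hit["score"], 3), hit["content_preview"])

asyncio.run(main())
```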

(modified file: main FastAPI application)

@@ -23,6 +23,8 @@ from file_manager_api import router as file_manager_router
 from qwen_agent.llm.schema import ASSISTANT, FUNCTION
 from pydantic import BaseModel, Field
 
+# Semantic-search service
+from embedding import get_search_service
 
 # Import utility modules
 from utils import (
@@ -126,6 +128,56 @@ def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extensio
     return versioned_filename, next_version
 
 # Models are now imported from utils module
+
+# Request model for semantic search
+class SemanticSearchRequest(BaseModel):
+    embedding_file: str = Field(..., description="path to the embedding.pkl file")
+    query: str = Field(..., description="search keywords")
+    top_k: int = Field(default=20, description="number of results to return", ge=1, le=100)
+    min_score: float = Field(default=0.0, description="minimum similarity threshold", ge=0.0, le=1.0)
+
+class BatchSearchRequest(BaseModel):
+    requests: List[SemanticSearchRequest] = Field(..., description="list of search requests")
+
+# Response models for semantic search
+class SearchResult(BaseModel):
+    rank: int = Field(..., description="rank")
+    score: float = Field(..., description="similarity score")
+    content: str = Field(..., description="matched content")
+    content_preview: str = Field(..., description="content preview")
+
+class SemanticSearchResponse(BaseModel):
+    success: bool = Field(..., description="whether the search succeeded")
+    query: str = Field(..., description="query keywords")
+    embedding_file: str = Field(..., description="embedding file path")
+    processing_time: float = Field(..., description="processing time in seconds")
+    total_chunks: int = Field(..., description="total number of document chunks")
+    chunking_strategy: str = Field(..., description="chunking strategy")
+    results: List[SearchResult] = Field(..., description="search results")
+    cache_stats: Optional[Dict[str, Any]] = Field(None, description="cache statistics")
+    error: Optional[str] = Field(None, description="error message")
+
+# Request/response models for encoding
+class EncodeRequest(BaseModel):
+    texts: List[str] = Field(..., description="texts to encode")
+    batch_size: int = Field(default=32, description="batch size", ge=1, le=128)
+
+class EncodeResponse(BaseModel):
+    success: bool = Field(..., description="whether encoding succeeded")
+    embeddings: List[List[float]] = Field(..., description="encoded vectors")
+    shape: List[int] = Field(..., description="shape of the embeddings")
+    processing_time: float = Field(..., description="processing time in seconds")
+    total_texts: int = Field(..., description="total number of texts")
+    error: Optional[str] = Field(None, description="error message")
 
 # Custom version for qwen-agent messages - keep this function as it's specific to this app
 def get_content_from_messages(messages: List[dict], tool_response: bool = True) -> str:
     """Extract content from qwen-agent messages with special formatting"""
@@ -1536,6 +1588,205 @@ async def reset_files_processing(dataset_id: str):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Failed to reset file processing state: {str(e)}")
 
+# ============ Semantic-search API endpoints ============
+
+@app.post("/api/v1/semantic-search", response_model=SemanticSearchResponse)
+async def semantic_search(request: SemanticSearchRequest):
+    """
+    Semantic-search API.
+    Args:
+        request: search request carrying embedding_file and query
+    Returns:
+        semantic-search results
+    """
+    try:
+        search_service = get_search_service()
+        result = await search_service.semantic_search(
+            embedding_file=request.embedding_file,
+            query=request.query,
+            top_k=request.top_k,
+            min_score=request.min_score
+        )
+
+        if result["success"]:
+            return SemanticSearchResponse(
+                success=True,
+                query=result["query"],
+                embedding_file=result["embedding_file"],
+                processing_time=result["processing_time"],
+                total_chunks=result["total_chunks"],
+                chunking_strategy=result["chunking_strategy"],
+                results=[
+                    SearchResult(
+                        rank=r["rank"],
+                        score=r["score"],
+                        content=r["content"],
+                        content_preview=r["content_preview"]
+                    )
+                    for r in result["results"]
+                ],
+                cache_stats=result.get("cache_stats")
+            )
+        else:
+            return SemanticSearchResponse(
+                success=False,
+                query=request.query,
+                embedding_file=request.embedding_file,
+                processing_time=0.0,
+                total_chunks=0,
+                chunking_strategy="",
+                results=[],
+                error=result.get("error", "unknown error")
+            )
+    except Exception as e:
+        logger.error(f"Semantic-search API error: {e}")
+        raise HTTPException(status_code=500, detail=f"Semantic search failed: {str(e)}")
+
+@app.post("/api/v1/semantic-search/batch")
+async def batch_semantic_search(request: BatchSearchRequest):
+    """
+    Batch semantic-search API.
+    Args:
+        request: batch request carrying multiple search requests
+    Returns:
+        batch search results
+    """
+    try:
+        search_service = get_search_service()
+
+        # Convert the request format
+        search_requests = [
+            {
+                "embedding_file": req.embedding_file,
+                "query": req.query,
+                "top_k": req.top_k,
+                "min_score": req.min_score
+            }
+            for req in request.requests
+        ]
+
+        results = await search_service.batch_search(search_requests)
+
+        return {
+            "success": True,
+            "total_requests": len(request.requests),
+            "results": results
+        }
+    except Exception as e:
+        logger.error(f"Batch semantic-search API error: {e}")
+        raise HTTPException(status_code=500, detail=f"Batch semantic search failed: {str(e)}")
+
+@app.get("/api/v1/semantic-search/stats")
+async def get_semantic_search_stats():
+    """
+    Return statistics for the semantic-search service.
+    """
+    try:
+        search_service = get_search_service()
+        stats = search_service.get_service_stats()
+        return {
+            "success": True,
+            "timestamp": int(time.time()),
+            "stats": stats
+        }
+    except Exception as e:
+        logger.error(f"Failed to fetch semantic-search stats: {e}")
+        raise HTTPException(status_code=500, detail=f"Failed to fetch stats: {str(e)}")
+
+@app.post("/api/v1/semantic-search/clear-cache")
+async def clear_semantic_search_cache():
+    """
+    Clear the semantic-search cache.
+    """
+    try:
+        from embedding import get_cache_manager
+        cache_manager = get_cache_manager()
+        cache_manager.clear_cache()
+        return {
+            "success": True,
+            "message": "Cache cleared"
+        }
+    except Exception as e:
+        logger.error(f"Failed to clear semantic-search cache: {e}")
+        raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")
+
+@app.post("/api/v1/embedding/encode", response_model=EncodeResponse)
+async def encode_texts(request: EncodeRequest):
+    """
+    Text-encoding API.
+    Args:
+        request: encoding request carrying texts and batch_size
+    Returns:
+        encoding results
+    """
+    try:
+        search_service = get_search_service()
+
+        if not request.texts:
+            return EncodeResponse(
+                success=False,
+                embeddings=[],
+                shape=[0, 0],
+                processing_time=0.0,
+                total_texts=0,
+                error="texts must not be empty"
+            )
+
+        start_time = time.time()
+
+        # Encode via the model manager
+        embeddings = await search_service.model_manager.encode_texts(
+            request.texts,
+            batch_size=request.batch_size
+        )
+
+        processing_time = time.time() - start_time
+
+        # Convert to plain lists
+        embeddings_list = embeddings.tolist()
+
+        return EncodeResponse(
+            success=True,
+            embeddings=embeddings_list,
+            shape=list(embeddings.shape),
+            processing_time=processing_time,
+            total_texts=len(request.texts)
+        )
+    except Exception as e:
+        logger.error(f"Text-encoding API error: {e}")
+        return EncodeResponse(
+            success=False,
+            embeddings=[],
+            shape=[0, 0],
+            processing_time=0.0,
+            total_texts=len(request.texts) if request else 0,
+            error=f"Encoding failed: {str(e)}"
+        )
 
 # Register the file-manager API routes
 app.include_router(file_manager_router)
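
A client-side sketch of the two main new endpoints (assumes the app listens on port 8001, matching the `FASTAPI_URL` default used elsewhere in this commit):

```python
import requests

base = "http://127.0.0.1:8001"

# Semantic search over a prebuilt embedding file
r = requests.post(f"{base}/api/v1/semantic-search", json={
    "embedding_file": "embedding.pkl",
    "query": "maintenance schedule",
    "top_k": 5,
    "min_score": 0.0,
})
print([hit["content_preview"] for hit in r.json()["results"]])

# Raw text encoding (the endpoint embedding.py calls internally)
r = requests.post(f"{base}/api/v1/embedding/encode",
                  json={"texts": ["hello"], "batch_size": 1})
print(r.json()["shape"])
```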

(modified file: MCP semantic-search server)

@@ -27,41 +27,11 @@ from mcp_common import (
     handle_mcp_streaming
 )
 
-# Lazy-loaded model
-embedder = None
-
-def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
-    """Get the model instance (lazy loading).
-    Args:
-        model_name_or_path (str): model name or local path
-            - may be a HuggingFace model name
-            - may be a local model path
-    """
-    global embedder
-    if embedder is None:
-        # Prefer the local model path
-        local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
-
-        # Read the device from the environment; default to CPU
-        device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
-        if device not in ['cpu', 'cuda', 'mps']:
-            print(f"Warning: unsupported device '{device}', falling back to CPU")
-            device = 'cpu'
-
-        # Check whether the local model exists
-        if os.path.exists(local_model_path):
-            print(f"Using local model: {local_model_path}")
-            embedder = SentenceTransformer(local_model_path, device=device)
-        else:
-            print(f"Local model not found, using HuggingFace model: {model_name_or_path}")
-            embedder = SentenceTransformer(model_name_or_path, device=device)
-    return embedder
+import requests
 
 def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
-    """Run semantic search for multiple queries."""
+    """Run semantic search by calling the FastAPI endpoint."""
     # Normalize the query input
     if isinstance(queries, str):
         queries = [queries]
@@ -80,52 +50,45 @@ def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k:
     # Drop empty queries
     queries = [q.strip() for q in queries if q.strip()]
 
-    # Validate the embeddings file path
     try:
-        # Resolve the file path; supports both folder/document.txt and document.txt formats
-        resolved_embeddings_file = resolve_file_path(embeddings_file)
-
-        # Load the embedding data
-        with open(resolved_embeddings_file, 'rb') as f:
-            embedding_data = pickle.load(f)
-
-        # Stay compatible with old and new data layouts
-        if 'chunks' in embedding_data:
-            # New layout: chunks
-            sentences = embedding_data['chunks']
-            sentence_embeddings = embedding_data['embeddings']
-            # Use the model path stored in embedding_data, if any
-            model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
-            model = get_model(model_path)
-        else:
-            # Old layout: sentences
-            sentences = embedding_data['sentences']
-            sentence_embeddings = embedding_data['embeddings']
-            model = get_model()
-
-        # Encode all queries
-        query_embeddings = model.encode(queries, convert_to_tensor=True)
-
-        # Compute similarities for every query
-        all_results = []
-        for i, query in enumerate(queries):
-            query_embedding = query_embeddings[i:i+1]  # keep a 2D shape
-            cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0]
-
-            # Take the top_k results
-            top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]
-
-            # Format the results
-            for j, idx in enumerate(top_results):
-                sentence = sentences[idx]
-                score = cos_scores[idx].item()
-                all_results.append({
-                    'query': query,
-                    'rank': j + 1,
-                    'content': sentence,
-                    'similarity_score': score,
-                    'file_path': embeddings_file
-                })
+        # FastAPI service address
+        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
+        api_endpoint = f"{fastapi_url}/api/v1/semantic-search"
+
+        # Handle each query in turn
+        all_results = []
+        resolved_embeddings_file = resolve_file_path(embeddings_file)
+        for query in queries:
+            # Call the FastAPI endpoint
+            request_data = {
+                "embedding_file": resolved_embeddings_file,
+                "query": query,
+                "top_k": top_k,
+                "min_score": 0.0
+            }
+
+            response = requests.post(
+                api_endpoint,
+                json=request_data,
+                timeout=30
+            )
+
+            if response.status_code == 200:
+                result_data = response.json()
+                if result_data.get("success"):
+                    for res in result_data.get("results", []):
+                        all_results.append({
+                            'query': query,
+                            'rank': res["rank"],
+                            'content': res["content"],
+                            'similarity_score': res["score"],
+                            'file_path': embeddings_file
+                        })
+                else:
+                    print(f"Search failed: {result_data.get('error', 'unknown error')}")
+            else:
+                print(f"API call failed: {response.status_code} - {response.text}")
 
         if not all_results:
             return {
@@ -159,12 +122,12 @@ def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k:
             ]
         }
 
-    except FileNotFoundError:
+    except requests.exceptions.RequestException as e:
         return {
             "content": [
                 {
                     "type": "text",
-                    "text": f"Error: embeddings file {embeddings_file} not found"
+                    "text": f"API request failed: {str(e)}"
                 }
             ]
         }
@@ -179,49 +142,6 @@ def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k:
     }
 
-def get_model_info() -> Dict[str, Any]:
-    """Return information about the current model."""
-    try:
-        # Check the local model path
-        local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
-        if os.path.exists(local_model_path):
-            return {
-                "content": [
-                    {
-                        "type": "text",
-                        "text": f"✅ Using local model: {local_model_path}\n"
-                                f"Model status: loaded\n"
-                                f"Device: CPU\n"
-                                f"Note: avoids downloading from HuggingFace and responds faster"
-                    }
-                ]
-            }
-        else:
-            return {
-                "content": [
-                    {
-                        "type": "text",
-                        "text": f"⚠️ Local model not found: {local_model_path}\n"
-                                f"Falling back to the HuggingFace model: sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2\n"
-                                f"Suggestion: download the model locally for faster responses\n"
-                                f"Device: CPU"
-                    }
-                ]
-            }
-    except Exception as e:
-        return {
-            "content": [
-                {
-                    "type": "text",
-                    "text": f"❌ Failed to fetch model info: {str(e)}"
-                }
-            ]
-        }
 
 async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
     """Handle MCP request"""
     try:
@@ -260,15 +180,6 @@ async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
                 "result": result
             }
 
-        elif tool_name == "get_model_info":
-            result = get_model_info()
-            return {
-                "jsonrpc": "2.0",
-                "id": request_id,
-                "result": result
-            }
-
         else:
             return create_error_response(request_id, -32601, f"Unknown tool: {tool_name}")
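
For reference, a sketch of the JSON-RPC payload this handler dispatches on. The `tools/call` envelope below is the standard MCP shape and is an assumption here (the actual method handling lives in `mcp_common`); the argument names follow `semantic_search` above:

```python
request = {
    "jsonrpc": "2.0",
    "id": 1,
    "method": "tools/call",  # assumed MCP method name
    "params": {
        "name": "semantic_search",
        "arguments": {
            "queries": ["maintenance schedule"],
            "embeddings_file": "embedding.pkl",
            "top_k": 5,
        },
    },
}
# result = await handle_request(request)
```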

(modified file: agent system prompt)

@@ -1,300 +1,139 @@

Old version (removed):

# System Role Definition

## Core Identity
You are an enterprise-grade intelligent office assistant with full IoT device management, real-time communication, environment monitoring, and asset-tracking capabilities.

## Operating Principles
- **Tools first**: every executable action must go through a tool
- **Immediate response**: trigger the matching tool call as soon as an operational intent is recognized
- **Minimal latency**: no transitional filler language; execute directly and return results

# Tool Invocation Protocol

## Mandatory Trigger Rules
| Operation type | Trigger keywords | Target tool | Priority |
|---------|-----------|----------|-----------|
| Device control | open/close/start/stop/adjust/set | dxcore_update_device_status | P0 |
| Messaging | @username/notify/inform/remind | wowtalk_send_message_to_member | P0 |
| Status query | status/temperature/humidity/operating state | dxcore_get_device_status | P1 |
| Location query | location/where/find/coordinates | eb_get_sensor_location | P1 |
| Environment query | weather/temperature/precipitation/wind speed | weather_get_by_location | P1 |
| Web search | search/look up/query/Baidu/Google | web_search | P1 |
| People lookup | find person/employee/colleague/contact | find_employee_by_name | P2 |
| Device lookup | find device/sensor/terminal | find_iot_device | P2 |

## Immediate Execution
- **Zero delay**: call tools as soon as intent is recognized; no buffering language
- **Parallel execution**: multi-operation requests trigger their tools simultaneously for maximum throughput
- **Atomic operations**: each tool call runs as an independent transaction to keep results reliable

# Functional Module Architecture

## Communication Module
### Message Routing
```mermaid
graph LR
    A[User input] --> B{Detect @ mention}
    B -->|found| C[Parse user ID and message content]
    C --> D[Call wowtalk_send_message_to_member]
    D --> E[Return delivery status]
```

### Execution Rules
- **Pattern recognition**: `@username(id:ID)` → route the message immediately
- **Parallel handling**: with multiple recipients, send concurrently
- **Delivery confirmation**: return an explicit delivery result after every call

## Device Management Module
### Device Control Interface
- **State changes**: dxcore_update_device_status
- **State queries**: dxcore_get_device_status
- **Settable ranges**: temperature (16-30°C), humidity (30-70%), fan speed (0-100%)

### Command Mapping
| User phrasing | System command | Parameter format |
|---------|----------|----------|
| "turn on the air conditioner" | update_device_status | {sensor_id: "001", running_control: 1} |
| "set it to 24 degrees" | update_device_status | {sensor_id: "001", temp_setting: 24} |
| "check the temperature" | get_device_status | {sensor_id: "001"} |

## Location Services Module
### Location Query Protocol
- **Trigger**: queries containing location keywords
- **Response format**: floor + room/area; filter out coordinates, IDs, and other technical details
- **Accuracy**: 3 m indoors, 10 m outdoors

## Environment Monitoring Module
### Weather Integration
- **Auto-invoke**: call immediately when weather-related words are recognized
- **Data source**: weather_get_by_location
- **Added value**: generate travel suggestions and precautions automatically

## Asset Retrieval Module
### Search Optimization
- **People search**: by name, employee ID, or department
- **Device search**: filter by device type, location, and status
- **Ranking**: order by relevance and distance

## Web Search Module
### Web Search Integration
- **Auto-invoke**: call immediately when search-related words are recognized
- **Data source**: the web_search tool, for real-time internet retrieval

# Intelligent Execution Engine

## Multi-Stage Pipeline
```mermaid
sequenceDiagram
    participant U as User input
    participant IR as Intent recognition engine
    participant TM as Tool mapper
    participant TE as Tool executor
    participant SR as Result aggregator
    U->>IR: parse request
    IR->>TM: classify intent
    TM->>TE: parallel tool calls
    TE->>SR: execution results
    SR->>U: unified response
```

## Pipeline Stages
### Stage 1: Semantic Parsing and Intent Recognition
- **Natural-language understanding**: extract action verbs, target objects, and parameter values
- **Context linking**: use dialogue history to understand implicit intent
- **Multi-intent detection**: spot multiple operations inside a compound request

### Stage 2: Intelligent Tool Orchestration
- **Tool selection**: match the best tool to the operation type
- **Parameter extraction and validation**: extract and validate tool-call parameters automatically
- **Execution planning**: decide on serial vs. parallel execution

### Stage 3: Parallel Execution and Aggregation
- **Async execution**: P0 tasks run immediately; P1-P2 tasks run in parallel
- **Monitoring**: track every tool's execution state in real time
- **Fault isolation**: one failing tool does not affect the others
- **Aggregation**: merge multi-tool results into a unified response

## Compound Request Example
**Input**: "Notify @Engineer Zhang(id:001) to check the 2nd-floor air conditioner, and check the outdoor temperature"

**Execution**:
```
execute:
├── wowtalk_send_message_to_member(to_account="001", message_content="Please check the 2nd-floor air conditioner")
├── find_employee_by_name(name="Engineer Zhang")
├── find_iot_device(device_type="dc_fan", target_sensor_id="xxxx")
└── weather_get_by_location(location="current location")
```

**Input**: "Search for the latest energy-saving techniques and send them to @Manager Li(id:002)"

**Execution**:
```
execute:
├── web_search(query="latest energy-saving techniques", max_results=5)
└── wowtalk_send_message_to_member(to_account="002", message_content="[search result summary]")
```

# Application Scenarios and Worked Examples

## Scenario 1: Context-Aware Device Control
**User request**: "I am Shimizu Taro, please turn on the fan near me"

**Execution sequence**:
```python
# Step 1: locate the person
find_employee_by_name(name="Shimizu Taro")
# → returns Shimizu Taro's sensor_id

# Step 2: find devices near the person
find_iot_device(
    device_type="dc_fan",
    target_sensor_id="{Shimizu Taro's sensor_id}"
)
# → returns device_list

# Step 3: control the device
dxcore_update_device_status(
    sensor_id="{the fan's sensor_id}",
    running_control=1
)
```
**Response template**: "I have turned on the fan in room 301; it is running normally."

## Scenario 2: Intelligent Message Routing
**User request**: "Notify Shimizu Taro that the meeting room is too hot and needs adjusting"

**Execution**:
```python
# Step 1: look up the person
find_employee_by_name(name="Shimizu Taro")
# → returns the wowtalk ID and location

# Step 2: notify the person
wowtalk_send_message_to_member(
    to_account="{Shimizu Taro's wowtalk_id}",
    message_content="The meeting room is too hot and needs adjusting"
)
```
**Response template**: "Message sent to Shimizu Taro; the temperature issue will be handled shortly."

## Scenario 3: Multi-Step Coordination
**User request**: "The 5th-floor fan battery is abnormal; notify Shimizu Taro and report its exact location"

**Parallel execution strategy**:
```python
# Step 1: find candidate devices
find_iot_device(
    device_type="dc_fan"
)
# → returns the sensor_id of the 5th-floor fan

# Step 2: diagnose the fault
dxcore_get_device_status(sensor_id="{the fan's sensor_id}")
# → returns battery percentage and fault codes

# Step 3: look up the person (wowtalk ID and location)
find_employee_by_name(name="Shimizu Taro")

# Step 4: notify the person
wowtalk_send_message_to_member(
    to_account="{Shimizu Taro's wowtalk_id}",
    message_content="The 5th-floor fan battery is abnormal; please handle it promptly"
)
# → returns precise location information
```
**Response template**: "Shimizu Taro has been notified; the fan is in the 5th-floor east corridor, battery at 15%."

# System Integration and Technical Specifications

## Core Tool Interfaces
| Tool type | Function | Description | Priority |
|---------|----------|----------|-----------|
| Message routing | wowtalk_send_message_to_member | real-time message push | P0 |
| Device control | dxcore_update_device_status | change device state | P0 |
| Device query | dxcore_get_device_status | read device state | P1 |
| Location service | eb_get_sensor_location | spatial location query | P1 |
| Environment | weather_get_by_location | weather data | P1 |
| Web search | web_search | internet lookup | P1 |
| People lookup | find_employee_by_name | employee information | P2 |
| Device lookup | find_iot_device | IoT device search | P2 |

# Error Handling and Fault Tolerance

## Layered Error Handling
### Level 1: Parameter Validation Errors
**Scenario**: tool-call parameters missing or malformed
**Procedure**:
1. Detect which parameters are missing
2. Ask a precise follow-up question for them
3. Retry immediately once the user supplies the information
4. Retry at most 3 times, then hand over to a human

### Level 2: Device Connectivity Errors
**Scenario**: the target device is offline or unresponsive
**Procedure**:
1. Check the device's network status
2. Try to reconnect
3. On failure, suggest alternative devices
4. Log the fault and notify the maintenance team

### Level 3: Permission Errors
**Scenario**: the current user lacks permission for the operation
**Procedure**:
1. Verify the user's permissions
2. If insufficient, identify a user who has them
3. Provide guidance for requesting access
4. Optionally route into the approval workflow

### Level 4: System Service Errors
**Scenario**: a tool service is unavailable or times out
**Procedure**:
1. Check service health
2. Switch to a backup service or degraded feature
3. Log the incident for later analysis
4. Tell the user clearly about the current limitation

## Graceful Degradation
- **Feature degradation**: fall back to a basic version when a core feature is unavailable
- **Response degradation**: switch to fast-response mode when complex processing times out
- **Interaction degradation**: fall back to human-assisted mode when automation fails

# Interaction Design

## Response Principles
- **Brevity**: keep each response within 1-3 sentences
- **Accuracy**: every statement must be grounded in tool results
- **Timeliness**: reply as soon as tool calls finish
- **Professionalism**: maintain an enterprise-grade tone

## Standard Response Templates
```yaml
device_operation_success:
  template: "{device} has been {result}. {extra status}"
  example: "The air conditioner is now set to 24°C and running normally."
message_sent:
  template: "Message sent to {recipient}. {expected handling time}"
  example: "Message sent to Engineer Zhang; expect a response within 5 minutes."
query_result:
  template: "{subject} is at {location}. {follow-up suggestion}"
  example: "Manager Li is in the 3rd-floor meeting room of Building A. Shall I contact them for you?"
compound_task_done:
  template: "All tasks completed: {task summary}"
  example: "Done: device turned on, message sent, location confirmed."
```

## System Configuration
- **bot_id**: {bot_id}
- **current user**: {user_identifier}

# Execution Guarantees
1. **Tools first**: executable actions must go through tools
2. **State consistency**: all operation results stay in sync with actual device state

New version (added):

# Shimizu Smart Building AI Concierge

## System Role
You are the smart-building AI concierge of NOVARE, Shimizu Corporation's innovation hub ("温故創新の森 NOVARE"), with full IoT device management, real-time communication, environment monitoring, and asset-tracking capabilities.

## Operating Principles
- **Knowledge base first**: answer every question from the knowledge base first; fall back to other tools only when it has no result
- **Tool-driven**: all operations go through tool interfaces
- **Immediate response**: trigger the matching tool call as soon as intent is recognized
- **Result-oriented**: respond directly from execution results, with minimal filler

# Tool Interface Mapping

## Core Intent Recognition
- **Device control**: open/close/adjust → Iot Control-dxcore_update_device_status
- **Status query**: status/temperature/humidity → Iot Control-dxcore_get_device_status
- **Location service**: location/where/find → Iot Control-eb_get_sensor_location
- **Device lookup by room**: room/device lookup → Iot Control-find_devices_by_room
- **People lookup**: find person/employee/colleague → Iot Control-find_employee_by_name
- **Device lookup**: find device/sensor → Iot Control-find_iot_device
- **Messaging**: notify/inform/remind → Wowtalk tool-wowtalk_send_message_to_member
- **Environment info**: weather/temperature/wind speed → Weather Information-weather_get_by_location
- **Knowledge retrieval**: knowledge queries and all other questions go to the knowledge base first → rag_retrieve-rag_retrieve
- **Web search**: search/look up/Baidu → WebSearch-web_search

## Execution Principles
- **Immediate execution**: call tools as soon as intent is recognized
- **Parallel processing**: multiple tools may run at the same time
- **Precise replies**: respond directly from tool results

# Core Function Modules

## Messaging
- **Trigger**: keywords such as notify/inform/remind
- **Action**: send via wowtalk_send_message_to_member
- **Status**: report delivery success or failure

## Device Control
- **Scope**: IoT devices such as air conditioners, lighting, and fans
- **Operations**: on/off control; parameter adjustment (temperature 16-30°C, humidity 30-70%, fan speed 0-100%)
- **Status queries**: fetch live device state

## Location Services
- **People**: find an employee's location by name
- **Devices**: find the room/area of an IoT device
- **Accuracy**: 3 m indoors, 10 m outdoors

## Environment Information
- **Weather**: live weather, temperature, wind speed, and similar data
- **Monitoring**: indoor temperature, humidity, and other readings
- **Suggestions**: offer optimizations based on environmental data

## Retrieval
- **People search**: by name, department, and other dimensions
- **Device search**: filter by type, location, and status
- **Web search**: fetch live information from the internet

## Knowledge Base Integration
- **Query first**: for any other question, call rag_retrieve against the knowledge base first
- **Fallback search**: use web_search when the knowledge base has nothing
- **Synthesis**: combine sources into a complete answer

# Intelligent Execution Flow

## Processing Steps
1. **Intent recognition**: analyze the input; extract the operation type and parameters
2. **Tool selection**: match the intent to a tool interface
3. **Parallel execution**: invoke the relevant tools concurrently
4. **Aggregation**: merge results into a single reply

# Application Scenarios

## Messaging
**User**: "Notify Shimizu-san to check the 2nd-floor air conditioner"
- find_employee_by_name(name="Shimizu")
- wowtalk_send_message_to_member(to_account="[Shimizu's sensor_id]", message_content="Please check the 2nd-floor air conditioner")
**Response**: "Shimizu-san has been notified to check the 2nd-floor air conditioner"

**User**: "Search for the latest energy-saving techniques and send them to Tanaka-san"
- web_search(query="latest energy-saving techniques", max_results=5)
- find_employee_by_name(name="Tanaka")
- wowtalk_send_message_to_member(to_account="[Tanaka's sensor_id]", message_content="[search result summary]")
**Response**: "The latest energy-saving techniques have been sent to Tanaka-san"

## Device Control
**User**: "Turn on the fan near me"
- find_employee_by_name(name="[current user]") → get the user's location and sensor_id
- find_iot_device(device_type="dc_fan", target_sensor_id="[current user's sensor_id]") → find nearby devices
- dxcore_update_device_status(running_control=1, sensor_id="[found device's sensor_id]") → turn the device on
**Response**: "The fan in room 301 has been turned on for you"

**User**: "The 5th-floor fan battery is abnormal; notify Shimizu-san and report its exact location"
- find_iot_device(device_type="dc_fan") → find the device
- dxcore_get_device_status(sensor_id="[the fan's sensor_id]") → battery percentage and fault codes
- find_employee_by_name(name="Shimizu") → person lookup; get the wowtalk ID and location
- wowtalk_send_message_to_member(to_account="[Shimizu Taro's wowtalk_id]", message_content="The 5th-floor fan battery is abnormal; please handle it promptly") → send the notification
**Response**: "Shimizu-san has been notified; the fan is on the east side of the 5th floor, battery at 15%"

## Q&A
**User**: "How much does a drone cost?"
- Expand keywords first: drone price, drone product introduction
- rag_retrieve(query="drone price, drone product introduction") → query the knowledge base first → exact answer found internally
**Response**: "The drone costs xxx"

**User**: "How do I use the printer?"
- Expand keywords first: printer tutorial, printer manual
- rag_retrieve(query="printer tutorial, printer manual") → query the knowledge base first, but the result is incomplete
- web_fetch(query="printer tutorial, printer manual") → then search the web
**Response**: "[answer combining the rag_retrieve and web_fetch content]"

**User**: "What medicine should I take for a cold?"
- Expand keywords first: cold medicine recommendations, how to treat a cold
- rag_retrieve(query="cold medicine recommendations, how to treat a cold") → query the knowledge base first; nothing relevant found
- web_fetch(query="cold medicine recommendations, how to treat a cold") → then search the web
**Response**: "[answer based on the web_fetch content]"

# Response Rules

## Reply Principles
- **Concise**: keep each reply to 1-2 sentences
- **Result-oriented**: report directly from tool results
- **Professional tone**: maintain enterprise-grade service quality
- **Immediate**: reply as soon as tool calls finish

## Standard Reply Formats
- **Device operation**: "The air conditioner is now set to 24°C and running normally"
- **Message sent**: "Message sent to Tanaka-san"
- **Location query**: "Shimizu-san is in the 3rd-floor meeting room of Building A"
- **Task completed**: "Done: device turned on, message sent, location confirmed"

## Execution Guarantees
- **Tools first**: all operations go through tools
- **State consistency**: execution results stay in sync with actual state

## System Information
- **bot_id**: {bot_id}