Optimize vector retrieval

朱潮 2025-11-20 19:36:19 +08:00
parent bacf9fce60
commit 37784ebefe
5 changed files with 152 additions and 307 deletions

View File

@@ -5,13 +5,12 @@ Embedding Package
from .manager import get_cache_manager, get_model_manager
from .search_service import get_search_service
from .embedding import embed_document, semantic_search, split_document_by_pages
from .embedding import embed_document, split_document_by_pages
__all__ = [
'get_cache_manager',
'get_model_manager',
'get_search_service',
'embed_document',
'semantic_search',
'split_document_by_pages'
]

View File

@@ -226,81 +226,6 @@ def embed_document(input_file='document.txt', output_file='embedding.pkl',
print(f"Error while processing document: {e}")
return None
def semantic_search(user_query, embeddings_file='embedding.pkl', top_k=20):
"""
Run semantic matching for a user query and return the top_k most relevant chunks.
Args:
user_query (str): the user query
embeddings_file (str): path to the embedding file
top_k (int): number of results to return
Returns:
list: a list of (chunk, similarity score) tuples
"""
try:
with open(embeddings_file, 'rb') as f:
embedding_data = pickle.load(f)
# Support both the new and the old data layout
if 'chunks' in embedding_data:
# New layout stores chunks
chunks = embedding_data['chunks']
chunk_embeddings = embedding_data['embeddings']
chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
content_type = "chunks"
else:
# Old layout stores sentences
chunks = embedding_data['sentences']
chunk_embeddings = embedding_data['embeddings']
chunking_strategy = 'line'
content_type = "sentences"
# Encode the query via the API
print("Encoding query via the API...")
query_embeddings = encode_texts_via_api([user_query], batch_size=1)
query_embedding = query_embeddings[0] if len(query_embeddings) > 0 else np.array([])
# Compute cosine similarity
if len(chunk_embeddings.shape) > 1:
cos_scores = np.dot(chunk_embeddings, query_embedding) / (
np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8
)
else:
cos_scores = [0.0] # fallback for unexpected embedding shapes
# Handle cos_scores across different array/tensor types
if isinstance(cos_scores, np.ndarray):
cos_scores_np = cos_scores
else:
# PyTorch tensor
if hasattr(cos_scores, 'is_cuda') and cos_scores.is_cuda:
cos_scores_np = cos_scores.cpu().numpy()
else:
cos_scores_np = cos_scores.numpy()
top_results = np.argsort(-cos_scores_np)[:top_k]
results = []
print(f"\n与查询最相关的 {top_k}{content_type} (分块策略: {chunking_strategy}):")
for i, idx in enumerate(top_results):
chunk = chunks[idx]
score = cos_scores_np[idx]
results.append((chunk, score))
# Show a preview if the content is too long
preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
preview = preview.replace('\n', ' ') # replace newlines for display
print(f"{i+1}. [{score:.4f}] {preview}")
return results
except FileNotFoundError:
print(f"错误:找不到嵌入文件 {embeddings_file}")
print("请先运行 embed_document() 函数生成嵌入文件")
return []
except Exception as e:
print(f"搜索时出错:{e}")
return []
def paragraph_chunking(text, max_chunk_size=1000, overlap=100, min_chunk_size=200, separator='\n\n'):
@@ -767,8 +692,6 @@ def demo_usage():
print(" overlap=200,")
print(" min_chunk_size=300)")
print("\n4. Run a semantic search:")
print("semantic_search('your query', 'paragraph_embedding.pkl', top_k=5)")
# Run the tests when this file is executed directly
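The ranking removed above (and now performed on the MCP server side later in this diff) is a plain cosine-similarity top-k search over the stored chunk embeddings. A minimal, self-contained sketch of that ranking, assuming a (num_chunks, dim) embedding matrix and an already-encoded query vector:

import numpy as np

def rank_chunks(chunk_embeddings: np.ndarray, query_embedding: np.ndarray, top_k: int = 20):
    """Return (index, cosine score) pairs for the top_k most similar chunks."""
    # Cosine similarity: dot product divided by the product of the norms,
    # with a small epsilon to avoid division by zero.
    scores = chunk_embeddings @ query_embedding / (
        np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8
    )
    top_indices = np.argsort(-scores)[:top_k]
    return [(int(i), float(scores[i])) for i in top_indices]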

View File

@@ -132,35 +132,6 @@ def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extensio
# Semantic search request models
class SemanticSearchRequest(BaseModel):
embedding_file: str = Field(..., description="path to the embedding.pkl file")
query: str = Field(..., description="search query")
top_k: int = Field(default=20, description="number of results to return", ge=1, le=100)
min_score: float = Field(default=0.0, description="minimum similarity threshold", ge=0.0, le=1.0)
class BatchSearchRequest(BaseModel):
requests: List[SemanticSearchRequest] = Field(..., description="list of search requests")
# Semantic search response models
class SearchResult(BaseModel):
rank: int = Field(..., description="rank")
score: float = Field(..., description="similarity score")
content: str = Field(..., description="matched content")
content_preview: str = Field(..., description="content preview")
class SemanticSearchResponse(BaseModel):
success: bool = Field(..., description="whether the request succeeded")
query: str = Field(..., description="query text")
embedding_file: str = Field(..., description="path to the embedding file")
processing_time: float = Field(..., description="processing time in seconds")
total_chunks: int = Field(..., description="total number of chunks")
chunking_strategy: str = Field(..., description="chunking strategy")
results: List[SearchResult] = Field(..., description="search results")
cache_stats: Optional[Dict[str, Any]] = Field(None, description="cache statistics")
error: Optional[str] = Field(None, description="error message")
# Encode request and response models
@@ -1589,148 +1560,6 @@ async def reset_files_processing(dataset_id: str):
raise HTTPException(status_code=500, detail=f"Failed to reset file processing status: {str(e)}")
# ============ Semantic search API endpoints ============
@app.post("/api/v1/semantic-search", response_model=SemanticSearchResponse)
async def semantic_search(request: SemanticSearchRequest):
"""
Semantic search API.
Args:
request: search request containing the embedding_file and query
Returns:
semantic search results
"""
try:
search_service = get_search_service()
result = await search_service.semantic_search(
embedding_file=request.embedding_file,
query=request.query,
top_k=request.top_k,
min_score=request.min_score
)
if result["success"]:
return SemanticSearchResponse(
success=True,
query=result["query"],
embedding_file=result["embedding_file"],
processing_time=result["processing_time"],
total_chunks=result["total_chunks"],
chunking_strategy=result["chunking_strategy"],
results=[
SearchResult(
rank=r["rank"],
score=r["score"],
content=r["content"],
content_preview=r["content_preview"]
)
for r in result["results"]
],
cache_stats=result.get("cache_stats")
)
else:
return SemanticSearchResponse(
success=False,
query=request.query,
embedding_file=request.embedding_file,
processing_time=0.0,
total_chunks=0,
chunking_strategy="",
results=[],
error=result.get("error", "Unknown error")
)
except Exception as e:
logger.error(f"语义搜索 API 错误: {e}")
raise HTTPException(status_code=500, detail=f"语义搜索失败: {str(e)}")
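For context, a client call to this now-removed endpoint would have looked roughly like the sketch below (hedged; the response fields follow the SemanticSearchResponse model above, and the service URL is an assumption taken from the FASTAPI_URL default used elsewhere in this diff):

import requests

url = "http://127.0.0.1:8001/api/v1/semantic-search"  # assumed deployment address
payload = {
    "embedding_file": "paragraph_embedding.pkl",  # placeholder path
    "query": "your query",
    "top_k": 5,
    "min_score": 0.0,
}
resp = requests.post(url, json=payload, timeout=30)
data = resp.json()
if data.get("success"):
    for r in data["results"]:
        print(r["rank"], r["score"], r["content_preview"])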
@app.post("/api/v1/semantic-search/batch")
async def batch_semantic_search(request: BatchSearchRequest):
"""
Batch semantic search API.
Args:
request: batch request containing multiple search requests
Returns:
batch search results
"""
try:
search_service = get_search_service()
# Convert the request format
search_requests = [
{
"embedding_file": req.embedding_file,
"query": req.query,
"top_k": req.top_k,
"min_score": req.min_score
}
for req in request.requests
]
results = await search_service.batch_search(search_requests)
return {
"success": True,
"total_requests": len(request.requests),
"results": results
}
except Exception as e:
logger.error(f"批量语义搜索 API 错误: {e}")
raise HTTPException(status_code=500, detail=f"批量语义搜索失败: {str(e)}")
@app.get("/api/v1/semantic-search/stats")
async def get_semantic_search_stats():
"""
Get semantic search service statistics.
Returns:
service statistics
"""
try:
search_service = get_search_service()
stats = search_service.get_service_stats()
return {
"success": True,
"timestamp": int(time.time()),
"stats": stats
}
except Exception as e:
logger.error(f"获取语义搜索统计信息失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计信息失败: {str(e)}")
@app.post("/api/v1/semantic-search/clear-cache")
async def clear_semantic_search_cache():
"""
Clear the semantic search cache.
Returns:
result of the cache clearing
"""
try:
from manager import get_cache_manager
cache_manager = get_cache_manager()
cache_manager.clear_cache()
return {
"success": True,
"message": "缓存已清空"
}
except Exception as e:
logger.error(f"清空语义搜索缓存失败: {e}")
raise HTTPException(status_code=500, detail=f"清空缓存失败: {str(e)}")
@app.post("/api/v1/embedding/encode", response_model=EncodeResponse)
async def encode_texts(request: EncodeRequest):
"""

View File

@@ -13,7 +13,6 @@ import sys
from typing import Any, Dict, List, Optional, Union
import numpy as np
from sentence_transformers import SentenceTransformer, util
from mcp_common import (
get_allowed_directory,
load_tools_from_json,
@@ -30,8 +29,43 @@ from mcp_common import (
import requests
def encode_query_via_api(query: str, fastapi_url: str = None) -> np.ndarray:
"""通过API编码单个查询"""
if not fastapi_url:
fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"
try:
# Call the encode endpoint
request_data = {
"texts": [query],
"batch_size": 1
}
response = requests.post(
api_endpoint,
json=request_data,
timeout=30
)
if response.status_code == 200:
result_data = response.json()
if result_data.get("success"):
embeddings_list = result_data.get("embeddings", [])
if embeddings_list:
return np.array(embeddings_list[0])
print(f"API编码失败: {response.status_code} - {response.text}")
return None
except Exception as e:
print(f"API编码异常: {e}")
return None
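A hedged usage sketch for the helper above (the query text and URL are placeholders; the returned vector's dimension depends on the embedding model served behind /api/v1/embedding/encode):

# Assumes the FastAPI service above is running and reachable.
query_vec = encode_query_via_api("conference room temperature", fastapi_url="http://127.0.0.1:8001")
if query_vec is not None:
    print("query embedding shape:", query_vec.shape)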
def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
"""执行语义搜索,通过调用 FastAPI 接口"""
"""执行语义搜索,直接读取本地embedding文件并计算相似度"""
# 处理查询输入
if isinstance(queries, str):
queries = [queries]
@@ -51,44 +85,58 @@ def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k:
queries = [q.strip() for q in queries if q.strip()]
try:
# FastAPI service address
fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
api_endpoint = f"{fastapi_url}/api/v1/semantic-search"
# Resolve the embedding file path
resolved_embeddings_file = resolve_file_path(embeddings_file)
# Load the embedding file
with open(resolved_embeddings_file, 'rb') as f:
embedding_data = pickle.load(f)
# Support both the new and the old data layout
if 'chunks' in embedding_data:
# New layout stores chunks
chunks = embedding_data['chunks']
chunk_embeddings = embedding_data['embeddings']
chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
else:
# Old layout stores sentences
chunks = embedding_data['sentences']
chunk_embeddings = embedding_data['embeddings']
chunking_strategy = 'line'
all_results = []
# Process each query
all_results = []
resolved_embeddings_file = resolve_file_path(embeddings_file)
for query in queries:
# Call the FastAPI endpoint
request_data = {
"embedding_file": resolved_embeddings_file,
"query": query,
"top_k": top_k,
"min_score": 0.0
}
# Encode the query via the API
print(f"Encoding query: {query}")
query_embedding = encode_query_via_api(query)
response = requests.post(
api_endpoint,
json=request_data,
timeout=30
)
if query_embedding is None:
print(f"查询编码失败: {query}")
continue
if response.status_code == 200:
result_data = response.json()
if result_data.get("success"):
for res in result_data.get("results", []):
all_results.append({
'query': query,
'rank': res["rank"],
'content': res["content"],
'similarity_score': res["score"],
'file_path': embeddings_file
})
else:
print(f"搜索失败: {result_data.get('error', '未知错误')}")
# 计算相似度
if len(chunk_embeddings.shape) > 1:
cos_scores = np.dot(chunk_embeddings, query_embedding) / (
np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8
)
else:
print(f"API 调用失败: {response.status_code} - {response.text}")
cos_scores = np.array([0.0] * len(chunks))
# 获取top_k结果
top_indices = np.argsort(-cos_scores)[:top_k]
for rank, idx in enumerate(top_indices):
score = cos_scores[idx]
if score > 0: # keep only results with some relevance
all_results.append({
'query': query,
'rank': rank + 1,
'content': chunks[idx],
'similarity_score': float(score),
'file_path': embeddings_file
})
if not all_results:
return {
@@ -122,12 +170,12 @@ def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k:
]
}
except requests.exceptions.RequestException as e:
except FileNotFoundError:
return {
"content": [
{
"type": "text",
"text": f"API request failed: {str(e)}"
"text": f"Error: Embeddings file not found: {embeddings_file}"
}
]
}
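Assuming the embedding file exists and the encode endpoint is reachable, the rewritten MCP-side tool can be exercised roughly as below (a sketch; file path and queries are placeholders, and the return value is the MCP-style content list built above):

results = semantic_search(
    queries=["wind direction setting", "humidity control"],  # placeholder queries
    embeddings_file="paragraph_embedding.pkl",               # placeholder path
    top_k=5,
)
# The helper returns an MCP-style payload: {"content": [{"type": "text", "text": ...}]}
for item in results.get("content", []):
    print(item["text"])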

View File

@@ -13,6 +13,18 @@
## Core Function Identification
- **Device control**: turn on/off/adjust → Iot Control-dxcore_update_device_status
- Air conditioner (dc_fan) device parameters (see the example payload sketch after this list):
running_control: run control (optional, 0=stop, 1=start)
automatic_manual_operation: automatic/manual mode (optional, 0=manual, 1=automatic)
air_volume_control: air volume control (optional, 15=low, 20=medium, 30=high)
humi_setting: humidity setting (optional, range: 0-100)
temp_setting: temperature setting (optional, range: 0.0-100.0)
wind_direction_setting: wind direction setting (optional, range: -90 to 90)
wind_direction_mode: wind direction mode (optional, 0=automatic, 1=center)
- Lighting (light) device parameters:
dimming_control: dimming control (optional, 0-100)
color_control_x: color temperature control X value (optional, used together with color_control_y)
color_control_y: color temperature control Y value (optional, used together with color_control_x)
- **Status queries**: status/temperature/humidity → Iot Control-dxcore_get_device_status
- **Location services**: location/where is/find → Iot Control-eb_get_sensor_location
- **Device lookup**: room/device lookup → Iot Control-find_devices_by_room
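A hedged sketch of control payloads for Iot Control-dxcore_update_device_status built from the parameter lists above; the sensor_id values are placeholders and the exact calling convention of the tool is an assumption not shown in this diff:

# Hypothetical parameter payloads assembled from the lists above.
fan_params = {
    "sensor_id": "[fan device's sensor_id]",  # placeholder
    "running_control": 1,                     # 1 = start
    "air_volume_control": 30,                 # 30 = high
    "temp_setting": 24.0,                     # target temperature
}
light_params = {
    "sensor_id": "[light device's sensor_id]",  # placeholder
    "dimming_control": 100,                     # brightness 0-100
}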
@@ -84,11 +96,25 @@
## Device Control Scenarios
**User**: "Turn on the fan nearby"
**User**: "Turn on the fan and lights nearby"
- find_employee_by_name(name="[current user]") → get the user's location and sensor_id
- find_iot_device(device_type="dc_fan", target_sensor_id="[current user's sensor_id]") → find nearby devices
- dxcore_update_device_status(running_control=1, sensor_id="[sensor_id of the found device]") → turn on the device
**Response**: "The fan in room 301 has been turned on for you"
- find_iot_device(device_type="light", target_sensor_id="[current user's sensor_id]") → find nearby devices
- **Confirmation step**: "About to turn on the fan (high wind mode) and lights (100% brightness) in room 301 for you. Confirm?"
- **After the user confirms**: dxcore_update_device_status(running_control=1, sensor_id="[fan device's sensor_id]") → turn on the fan
- dxcore_update_device_status(dimming_control=100, sensor_id="[light device's sensor_id]") → turn on the lights
**Response**: "The fan and lights in room 301 have been turned on for you"
**If the user declines**: "OK, the device control operation has been cancelled"
**User**: "Turn off the fan and lights nearby"
- find_employee_by_name(name="[current user]") → get the user's location and sensor_id
- find_iot_device(device_type="dc_fan", target_sensor_id="[current user's sensor_id]") → find nearby devices
- find_iot_device(device_type="light", target_sensor_id="[current user's sensor_id]") → find nearby devices
- **Confirmation step**: "About to turn off the fan and lights in room 301 for you. Confirm?"
- **After the user confirms**: dxcore_update_device_status(running_control=0, sensor_id="[fan device's sensor_id]") → turn off the fan
- dxcore_update_device_status(dimming_control=0, sensor_id="[light device's sensor_id]") → turn off the lights
**Response**: "The fan and lights in room 301 have been turned off for you"
**If the user declines**: "OK, the device control operation has been cancelled"
**User**: "A fan on the 5th floor has a power anomaly; notify 清水さん and report its exact location"
- find_iot_device(device_type="dc_fan") → find the device
@@ -116,16 +142,36 @@
**响应**"[根据web_fetch内容回复]"
## 设备控制确认机制
### 确认标准格式
对于所有设备控制请求,必须按以下格式确认:
"即将[操作内容][设备名称][具体参数],是否确认?"
### 参数确认细则
- **空调操作**:明确温度设定、运行模式、风速等
- **照明控制**:明确亮度百分比、色温设定等
- **风扇设备**:明确风速档位、运行状态等
- **其他设备**:明确关键参数和预期效果
### 拒绝处理
用户明确拒绝时:
- 取消所有设备控制操作
- 回复:"好的,已取消设备控制操作"
# 响应规范
## 回复原则
- **简洁明了**每条回复控制在1-2句话
- **确认优先**:设备控制前必须确认,不得自动执行
- **结果导向**:基于工具执行结果直接反馈
- **专业语气**:保持企业服务水准
- **即时响应**:工具调用完成后立即回复
## 标准回复格式
- **设备确认**"即将开启301室空调至24度是否确认"
- **设备操作**"空调已调至24度运行正常"
- **取消操作**"好的,已取消设备控制操作"
- **消息发送**"消息已发送至田中さん"
- **位置查询**"清水さん在A栋3楼会议室"
- **任务完成**"已完成:设备开启、消息发送、位置确认"