Optimize vector retrieval
parent bacf9fce60
commit 37784ebefe
@ -5,13 +5,12 @@ Embedding Package
 from .manager import get_cache_manager, get_model_manager
 from .search_service import get_search_service
-from .embedding import embed_document, semantic_search, split_document_by_pages
+from .embedding import embed_document, split_document_by_pages

 __all__ = [
     'get_cache_manager',
     'get_model_manager',
     'get_search_service',
     'embed_document',
-    'semantic_search',
     'split_document_by_pages'
 ]
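The package no longer exports `semantic_search`; retrieval now goes through the search service. A minimal caller-side sketch, assuming the package is importable as `embedding` (the diff only shows relative imports, so the top-level name is an assumption):

```python
# Hypothetical caller update: obtain search through the service singleton
# instead of the removed module-level semantic_search export.
from embedding import get_search_service

service = get_search_service()
# The service interface seen elsewhere in this diff (fastapi_app.py) is async:
#   await service.semantic_search(embedding_file=..., query=..., top_k=20, min_score=0.0)
```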
@ -226,81 +226,6 @@ def embed_document(input_file='document.txt', output_file='embedding.pkl',
         print(f"Error while processing the document: {e}")
         return None
-
-
-def semantic_search(user_query, embeddings_file='embedding.pkl', top_k=20):
-    """
-    Take a user query, run semantic matching, and return the top_k most relevant chunks.
-
-    Args:
-        user_query (str): the user query
-        embeddings_file (str): path to the embeddings file
-        top_k (int): number of results to return
-
-    Returns:
-        list: a list of (chunk, similarity score) tuples
-    """
-    try:
-        with open(embeddings_file, 'rb') as f:
-            embedding_data = pickle.load(f)
-
-        # Support both the new and the old data layout
-        if 'chunks' in embedding_data:
-            # New layout (keyed by 'chunks')
-            chunks = embedding_data['chunks']
-            chunk_embeddings = embedding_data['embeddings']
-            chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
-            content_type = "chunks"
-        else:
-            # Old layout (keyed by 'sentences')
-            chunks = embedding_data['sentences']
-            chunk_embeddings = embedding_data['embeddings']
-            chunking_strategy = 'line'
-            content_type = "sentences"
-
-        # Encode the query through the API
-        print("Encoding the query via the API...")
-        query_embeddings = encode_texts_via_api([user_query], batch_size=1)
-        query_embedding = query_embeddings[0] if len(query_embeddings) > 0 else np.array([])
-
-        # Compute cosine similarity
-        if len(chunk_embeddings.shape) > 1:
-            cos_scores = np.dot(chunk_embeddings, query_embedding) / (
-                np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8
-            )
-        else:
-            cos_scores = [0.0]  # fallback for unexpected shapes
-
-        # Normalize cos_scores to a NumPy array regardless of backend
-        if isinstance(cos_scores, np.ndarray):
-            cos_scores_np = cos_scores
-        else:
-            # PyTorch tensor
-            if hasattr(cos_scores, 'is_cuda') and cos_scores.is_cuda:
-                cos_scores_np = cos_scores.cpu().numpy()
-            else:
-                cos_scores_np = cos_scores.numpy()
-
-        top_results = np.argsort(-cos_scores_np)[:top_k]
-
-        results = []
-        print(f"\nTop {top_k} {content_type} most relevant to the query (chunking strategy: {chunking_strategy}):")
-        for i, idx in enumerate(top_results):
-            chunk = chunks[idx]
-            score = cos_scores_np[idx]
-            results.append((chunk, score))
-            # Show a shortened preview when the content is long
-            preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
-            preview = preview.replace('\n', ' ')  # strip newlines for display
-            print(f"{i+1}. [{score:.4f}] {preview}")
-
-        return results
-
-    except FileNotFoundError:
-        print(f"Error: embeddings file {embeddings_file} not found")
-        print("Run embed_document() first to generate the embeddings file")
-        return []
-    except Exception as e:
-        print(f"Error during search: {e}")
-        return []
-

 def paragraph_chunking(text, max_chunk_size=1000, overlap=100, min_chunk_size=200, separator='\n\n'):
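Both the function removed above and the rewritten MCP-side search later in this diff read the same pickle layout. A minimal sketch of a compatible file; the key names come from the diff, while the array shape and embedding dimension are illustrative assumptions:

```python
import pickle

import numpy as np

# "New" layout as read by the code in this diff. The embedding dimension (384)
# is an illustrative assumption, not something the diff specifies.
embedding_data = {
    'chunks': ["first chunk of text", "second chunk of text"],
    'embeddings': np.random.rand(2, 384).astype(np.float32),  # one row per chunk
    'chunking_strategy': 'paragraph',
}
# Legacy files use the key 'sentences' instead of 'chunks', which the readers
# treat as chunking_strategy == 'line'.

with open('embedding.pkl', 'wb') as f:
    pickle.dump(embedding_data, f)
```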
@ -767,8 +692,6 @@ def demo_usage():
     print("                   overlap=200,")
     print("                   min_chunk_size=300)")
-
-    print("\n4. Run a semantic search:")
-    print("semantic_search('query text', 'paragraph_embedding.pkl', top_k=5)")


 # Run the test when this file is executed directly
fastapi_app.py (171 changed lines)
@ -132,35 +132,6 @@ def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extensio


-# Request model for semantic search
-class SemanticSearchRequest(BaseModel):
-    embedding_file: str = Field(..., description="path to the embedding.pkl file")
-    query: str = Field(..., description="search query")
-    top_k: int = Field(default=20, description="number of results to return", ge=1, le=100)
-    min_score: float = Field(default=0.0, description="minimum similarity threshold", ge=0.0, le=1.0)
-
-
-class BatchSearchRequest(BaseModel):
-    requests: List[SemanticSearchRequest] = Field(..., description="list of search requests")
-
-
-# Response models for semantic search
-class SearchResult(BaseModel):
-    rank: int = Field(..., description="result rank")
-    score: float = Field(..., description="similarity score")
-    content: str = Field(..., description="matched content")
-    content_preview: str = Field(..., description="content preview")
-
-
-class SemanticSearchResponse(BaseModel):
-    success: bool = Field(..., description="whether the request succeeded")
-    query: str = Field(..., description="the query string")
-    embedding_file: str = Field(..., description="path to the embedding file")
-    processing_time: float = Field(..., description="processing time in seconds")
-    total_chunks: int = Field(..., description="total number of document chunks")
-    chunking_strategy: str = Field(..., description="chunking strategy")
-    results: List[SearchResult] = Field(..., description="search results")
-    cache_stats: Optional[Dict[str, Any]] = Field(None, description="cache statistics")
-    error: Optional[str] = Field(None, description="error message")


 # Encode request and response models
@ -1589,148 +1560,6 @@ async def reset_files_processing(dataset_id: str):
         raise HTTPException(status_code=500, detail=f"Failed to reset file processing state: {str(e)}")


-# ============ Semantic search API endpoints ============
-
-@app.post("/api/v1/semantic-search", response_model=SemanticSearchResponse)
-async def semantic_search(request: SemanticSearchRequest):
-    """
-    Semantic search API.
-
-    Args:
-        request: search request carrying embedding_file and query
-
-    Returns:
-        Semantic search results
-    """
-    try:
-        search_service = get_search_service()
-        result = await search_service.semantic_search(
-            embedding_file=request.embedding_file,
-            query=request.query,
-            top_k=request.top_k,
-            min_score=request.min_score
-        )
-
-        if result["success"]:
-            return SemanticSearchResponse(
-                success=True,
-                query=result["query"],
-                embedding_file=result["embedding_file"],
-                processing_time=result["processing_time"],
-                total_chunks=result["total_chunks"],
-                chunking_strategy=result["chunking_strategy"],
-                results=[
-                    SearchResult(
-                        rank=r["rank"],
-                        score=r["score"],
-                        content=r["content"],
-                        content_preview=r["content_preview"]
-                    )
-                    for r in result["results"]
-                ],
-                cache_stats=result.get("cache_stats")
-            )
-        else:
-            return SemanticSearchResponse(
-                success=False,
-                query=request.query,
-                embedding_file=request.embedding_file,
-                processing_time=0.0,
-                total_chunks=0,
-                chunking_strategy="",
-                results=[],
-                error=result.get("error", "unknown error")
-            )
-
-    except Exception as e:
-        logger.error(f"Semantic search API error: {e}")
-        raise HTTPException(status_code=500, detail=f"Semantic search failed: {str(e)}")
-
-
-@app.post("/api/v1/semantic-search/batch")
-async def batch_semantic_search(request: BatchSearchRequest):
-    """
-    Batch semantic search API.
-
-    Args:
-        request: batch request carrying multiple search requests
-
-    Returns:
-        Batch search results
-    """
-    try:
-        search_service = get_search_service()
-
-        # Convert the request format
-        search_requests = [
-            {
-                "embedding_file": req.embedding_file,
-                "query": req.query,
-                "top_k": req.top_k,
-                "min_score": req.min_score
-            }
-            for req in request.requests
-        ]
-
-        results = await search_service.batch_search(search_requests)
-
-        return {
-            "success": True,
-            "total_requests": len(request.requests),
-            "results": results
-        }
-
-    except Exception as e:
-        logger.error(f"Batch semantic search API error: {e}")
-        raise HTTPException(status_code=500, detail=f"Batch semantic search failed: {str(e)}")
-
-
-@app.get("/api/v1/semantic-search/stats")
-async def get_semantic_search_stats():
-    """
-    Return statistics for the semantic search service.
-
-    Returns:
-        Service statistics
-    """
-    try:
-        search_service = get_search_service()
-        stats = search_service.get_service_stats()
-
-        return {
-            "success": True,
-            "timestamp": int(time.time()),
-            "stats": stats
-        }
-
-    except Exception as e:
-        logger.error(f"Failed to fetch semantic search statistics: {e}")
-        raise HTTPException(status_code=500, detail=f"Failed to fetch statistics: {str(e)}")
-
-
-@app.post("/api/v1/semantic-search/clear-cache")
-async def clear_semantic_search_cache():
-    """
-    Clear the semantic search cache.
-
-    Returns:
-        The result of the cleanup
-    """
-    try:
-        from manager import get_cache_manager
-        cache_manager = get_cache_manager()
-        cache_manager.clear_cache()
-
-        return {
-            "success": True,
-            "message": "Cache cleared"
-        }
-
-    except Exception as e:
-        logger.error(f"Failed to clear the semantic search cache: {e}")
-        raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")
-
-
 @app.post("/api/v1/embedding/encode", response_model=EncodeResponse)
 async def encode_texts(request: EncodeRequest):
     """
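With the search endpoints removed, the `/api/v1/embedding/encode` endpoint kept above is the HTTP surface this commit still relies on. A minimal client sketch; the request and response field names (`texts`, `batch_size`, `success`, `embeddings`) are inferred from the MCP-side caller later in this diff, not from the endpoint's own definition:

```python
import requests

def encode_remote(texts, url="http://127.0.0.1:8001"):
    """Encode a batch of texts via the FastAPI encode endpoint (sketch)."""
    resp = requests.post(
        f"{url}/api/v1/embedding/encode",
        json={"texts": texts, "batch_size": len(texts)},
        timeout=30,
    )
    resp.raise_for_status()
    data = resp.json()
    if not data.get("success"):
        raise RuntimeError(data.get("error", "encode failed"))
    return data["embeddings"]  # list of vectors, one per input text

# Example: vectors = encode_remote(["hello world"])
```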
@ -13,7 +13,6 @@ import sys
 from typing import Any, Dict, List, Optional, Union

 import numpy as np
-from sentence_transformers import SentenceTransformer, util
 from mcp_common import (
     get_allowed_directory,
     load_tools_from_json,
@ -30,8 +29,43 @@ from mcp_common import (
 import requests


+def encode_query_via_api(query: str, fastapi_url: str = None) -> np.ndarray:
+    """Encode a single query through the API."""
+    if not fastapi_url:
+        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
+
+    api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"
+
+    try:
+        # Call the encode endpoint
+        request_data = {
+            "texts": [query],
+            "batch_size": 1
+        }
+
+        response = requests.post(
+            api_endpoint,
+            json=request_data,
+            timeout=30
+        )
+
+        if response.status_code == 200:
+            result_data = response.json()
+            if result_data.get("success"):
+                embeddings_list = result_data.get("embeddings", [])
+                if embeddings_list:
+                    return np.array(embeddings_list[0])
+
+        print(f"API encoding failed: {response.status_code} - {response.text}")
+        return None
+
+    except Exception as e:
+        print(f"API encoding error: {e}")
+        return None
+
+
 def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
-    """Run semantic search by calling the FastAPI endpoint."""
+    """Run semantic search by reading the local embeddings file and computing similarity directly."""
     # Normalize the query input
     if isinstance(queries, str):
         queries = [queries]
@ -51,44 +85,58 @@ def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k:
     queries = [q.strip() for q in queries if q.strip()]

     try:
-        # FastAPI service address
-        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
-        api_endpoint = f"{fastapi_url}/api/v1/semantic-search"
+        # Resolve the embeddings file path
+        resolved_embeddings_file = resolve_file_path(embeddings_file)
+
+        # Load the embeddings file
+        with open(resolved_embeddings_file, 'rb') as f:
+            embedding_data = pickle.load(f)
+
+        # Support both the new and the old data layout
+        if 'chunks' in embedding_data:
+            # New layout (keyed by 'chunks')
+            chunks = embedding_data['chunks']
+            chunk_embeddings = embedding_data['embeddings']
+            chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
+        else:
+            # Old layout (keyed by 'sentences')
+            chunks = embedding_data['sentences']
+            chunk_embeddings = embedding_data['embeddings']
+            chunking_strategy = 'line'
+
+        all_results = []

         # Process each query
-        all_results = []
-        resolved_embeddings_file = resolve_file_path(embeddings_file)
         for query in queries:
-            # Call the FastAPI endpoint
-            request_data = {
-                "embedding_file": resolved_embeddings_file,
-                "query": query,
-                "top_k": top_k,
-                "min_score": 0.0
-            }
-
-            response = requests.post(
-                api_endpoint,
-                json=request_data,
-                timeout=30
-            )
-
-            if response.status_code == 200:
-                result_data = response.json()
-                if result_data.get("success"):
-                    for res in result_data.get("results", []):
-                        all_results.append({
-                            'query': query,
-                            'rank': res["rank"],
-                            'content': res["content"],
-                            'similarity_score': res["score"],
-                            'file_path': embeddings_file
-                        })
-                else:
-                    print(f"Search failed: {result_data.get('error', 'unknown error')}")
-            else:
-                print(f"API call failed: {response.status_code} - {response.text}")
+            # Encode the query via the API
+            print(f"Encoding query: {query}")
+            query_embedding = encode_query_via_api(query)
+
+            if query_embedding is None:
+                print(f"Failed to encode query: {query}")
+                continue
+
+            # Compute cosine similarity
+            if len(chunk_embeddings.shape) > 1:
+                cos_scores = np.dot(chunk_embeddings, query_embedding) / (
+                    np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8
+                )
+            else:
+                cos_scores = np.array([0.0] * len(chunks))
+
+            # Take the top_k results
+            top_indices = np.argsort(-cos_scores)[:top_k]
+
+            for rank, idx in enumerate(top_indices):
+                score = cos_scores[idx]
+                if score > 0:  # keep only results with some relevance
+                    all_results.append({
+                        'query': query,
+                        'rank': rank + 1,
+                        'content': chunks[idx],
+                        'similarity_score': float(score),
+                        'file_path': embeddings_file
+                    })

         if not all_results:
             return {
@ -122,12 +170,12 @@ def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k:
             ]
         }

-    except requests.exceptions.RequestException as e:
+    except FileNotFoundError:
         return {
             "content": [
                 {
                     "type": "text",
-                    "text": f"API request failed: {str(e)}"
+                    "text": f"Error: Embeddings file not found: {embeddings_file}"
                 }
             ]
         }
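The rewrite replaces an HTTP round trip per query with the same cosine-similarity scoring done in-process. A self-contained sketch of that scoring step on toy data (the shapes and values are illustrative only):

```python
import numpy as np

# Toy stand-ins for the loaded pickle: 4 chunk vectors and 1 query vector.
chunk_embeddings = np.random.rand(4, 8).astype(np.float32)
query_embedding = np.random.rand(8).astype(np.float32)

# Cosine similarity, exactly as in the diff: dot product divided by the
# product of norms, with a small epsilon to avoid division by zero.
cos_scores = np.dot(chunk_embeddings, query_embedding) / (
    np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8
)

# Indices of the top-k chunks, highest score first.
top_k = 2
top_indices = np.argsort(-cos_scores)[:top_k]
print([(int(i), float(cos_scores[i])) for i in top_indices])
```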
@ -13,6 +13,18 @@

 ## Core capability mapping
 - **Device control**: turn on/off/adjust → Iot Control-dxcore_update_device_status
+  - Air conditioner (dc_fan) device parameters:
+    running_control: run control (optional, 0 = stop, 1 = start)
+    automatic_manual_operation: automatic/manual mode (optional, 0 = manual, 1 = automatic)
+    air_volume_control: air volume (optional, 15 = low, 20 = medium, 30 = high)
+    humi_setting: humidity setpoint (optional, range: 0-100)
+    temp_setting: temperature setpoint (optional, range: 0.0-100.0)
+    wind_direction_setting: wind direction (optional, range: -90 to 90)
+    wind_direction_mode: wind direction mode (optional, 0 = automatic, 1 = central)
+  - Lighting (light) device parameters:
+    dimming_control: dimming control (optional, 0-100)
+    color_control_x: color temperature X value (optional, used together with color_control_y)
+    color_control_y: color temperature Y value (optional, used together with color_control_x)
 - **Status query**: status/temperature/humidity → Iot Control-dxcore_get_device_status
 - **Location service**: location/where is/find → Iot Control-eb_get_sensor_location
 - **Device lookup**: room/device lookup → Iot Control-find_devices_by_room
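The parameter lists above map directly onto dxcore_update_device_status arguments. A hypothetical "set the fan unit to 24°C on high" call expressed as a Python dict of tool arguments; the wire format of the tool call is an assumption, while the parameter names and value ranges come from the list above:

```python
# Hypothetical tool-call arguments; sensor_id is a placeholder obtained
# from find_iot_device, not a real value.
update_fan_args = {
    "sensor_id": "<fan sensor_id>",
    "running_control": 1,             # 1 = start
    "automatic_manual_operation": 0,  # 0 = manual
    "air_volume_control": 30,         # 30 = high
    "temp_setting": 24.0,             # within the documented 0.0-100.0 range
}
```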
@ -84,11 +96,25 @@

 ## Device control scenarios
-**User**: "Turn on the fan near me"
+**User**: "Turn on the fan and the lights near me"
 - find_employee_by_name(name="[current user]") → get the user's location and sensor_id
 - find_iot_device(device_type="dc_fan", target_sensor_id="[current user's sensor_id]") → find nearby devices
-- dxcore_update_device_status(running_control=1, sensor_id="[sensor_id of the device found]") → turn the device on
-**Response**: "The fan in room 301 has been turned on for you"
+- find_iot_device(device_type="light", target_sensor_id="[current user's sensor_id]") → find nearby devices
+- **Confirmation step**: "About to turn on the fan (high speed) and the lights (100% brightness) in room 301 for you. Confirm?"
+- **After the user agrees**: dxcore_update_device_status(running_control=1, sensor_id="[fan device's sensor_id]") → turn on the fan
+- dxcore_update_device_status(dimming_control=100, sensor_id="[light device's sensor_id]") → turn on the lights
+**Response**: "The fan and the lights in room 301 have been turned on for you"
+**If the user declines**: "OK, the device control operation has been cancelled"
+
+**User**: "Turn off the fan and the lights near me"
+- find_employee_by_name(name="[current user]") → get the user's location and sensor_id
+- find_iot_device(device_type="dc_fan", target_sensor_id="[current user's sensor_id]") → find nearby devices
+- find_iot_device(device_type="light", target_sensor_id="[current user's sensor_id]") → find nearby devices
+- **Confirmation step**: "About to turn off the fan and the lights in room 301 for you. Confirm?"
+- **After the user agrees**: dxcore_update_device_status(running_control=0, sensor_id="[fan device's sensor_id]") → turn off the fan
+- dxcore_update_device_status(dimming_control=0, sensor_id="[light device's sensor_id]") → turn off the lights
+**Response**: "The fan and the lights in room 301 have been turned off for you"
+**If the user declines**: "OK, the device control operation has been cancelled"

 **User**: "The fan on floor 5 has a power anomaly; notify 清水さん and report its exact location"
 - find_iot_device(device_type="dc_fan") → find the device
@ -116,16 +142,36 @@
 **Response**: "[reply based on the web_fetch content]"


+## Device control confirmation mechanism
+
+### Standard confirmation format
+Every device control request must be confirmed in the following format:
+"About to [action] [device name] [specific parameters]. Confirm?"
+
+### Parameter confirmation details
+- **Air conditioning**: state the temperature setpoint, operating mode, fan speed, etc.
+- **Lighting**: state the brightness percentage, color temperature, etc.
+- **Fans**: state the speed level, running state, etc.
+- **Other devices**: state the key parameters and the expected effect
+
+### Handling refusals
+When the user explicitly declines:
+- Cancel all device control operations
+- Reply: "OK, the device control operation has been cancelled"
+
 # Response guidelines

 ## Reply principles
 - **Concise and clear**: keep each reply to one or two sentences
+- **Confirmation first**: device control must be confirmed beforehand; never execute automatically
 - **Result-oriented**: reply directly based on tool execution results
 - **Professional tone**: maintain an enterprise-grade service standard
 - **Prompt**: reply immediately once tool calls complete

 ## Standard reply formats
+- **Device confirmation**: "About to set the air conditioner in room 301 to 24 degrees. Confirm?"
 - **Device operation**: "The air conditioner has been set to 24 degrees and is running normally"
+- **Cancelled operation**: "OK, the device control operation has been cancelled"
 - **Message sent**: "The message has been sent to 田中さん"
 - **Location query**: "清水さん is in the meeting room on floor 3 of building A"
 - **Task completed**: "Done: device turned on, message sent, location confirmed"