diff --git a/embedding/__init__.py b/embedding/__init__.py
index f465e5f..fe52395 100644
--- a/embedding/__init__.py
+++ b/embedding/__init__.py
@@ -5,13 +5,12 @@ Embedding Package
 
 from .manager import get_cache_manager, get_model_manager
 from .search_service import get_search_service
-from .embedding import embed_document, semantic_search, split_document_by_pages
+from .embedding import embed_document, split_document_by_pages
 
 __all__ = [
     'get_cache_manager',
     'get_model_manager',
     'get_search_service',
     'embed_document',
-    'semantic_search',
     'split_document_by_pages'
 ]
\ No newline at end of file
diff --git a/embedding/embedding.py b/embedding/embedding.py
index 2472ec4..956048a 100644
--- a/embedding/embedding.py
+++ b/embedding/embedding.py
@@ -226,81 +226,6 @@ def embed_document(input_file='document.txt', output_file='embedding.pkl',
         print(f"Error while processing the document: {e}")
         return None
 
-def semantic_search(user_query, embeddings_file='embedding.pkl', top_k=20):
-    """
-    Take a user query, run semantic matching, and return the top_k most relevant content chunks
-
-    Args:
-        user_query (str): the user query
-        embeddings_file (str): path to the embeddings file
-        top_k (int): number of results to return
-
-    Returns:
-        list: a list of (content chunk, similarity score) tuples
-    """
-    try:
-        with open(embeddings_file, 'rb') as f:
-            embedding_data = pickle.load(f)
-
-        # Support both the new and the old data layout
-        if 'chunks' in embedding_data:
-            # New layout (keyed by chunks)
-            chunks = embedding_data['chunks']
-            chunk_embeddings = embedding_data['embeddings']
-            chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
-            content_type = "content chunks"
-        else:
-            # Old layout (keyed by sentences)
-            chunks = embedding_data['sentences']
-            chunk_embeddings = embedding_data['embeddings']
-            chunking_strategy = 'line'
-            content_type = "sentences"
-
-        # Encode the query through the API
-        print("Encoding the query via the API...")
-        query_embeddings = encode_texts_via_api([user_query], batch_size=1)
-        query_embedding = query_embeddings[0] if len(query_embeddings) > 0 else np.array([])
-
-        # Compute cosine similarities
-        if len(chunk_embeddings.shape) > 1:
-            cos_scores = np.dot(chunk_embeddings, query_embedding) / (
-                np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8
-            )
-        else:
-            cos_scores = [0.0]  # compatibility fallback
-
-        # Normalize cos_scores across the possible formats
-        if isinstance(cos_scores, np.ndarray):
-            cos_scores_np = cos_scores
-        else:
-            # PyTorch tensor
-            if hasattr(cos_scores, 'is_cuda') and cos_scores.is_cuda:
-                cos_scores_np = cos_scores.cpu().numpy()
-            else:
-                cos_scores_np = cos_scores.numpy()
-
-        top_results = np.argsort(-cos_scores_np)[:top_k]
-
-        results = []
-        print(f"\nTop {top_k} {content_type} most relevant to the query (chunking strategy: {chunking_strategy}):")
-        for i, idx in enumerate(top_results):
-            chunk = chunks[idx]
-            score = cos_scores_np[idx]
-            results.append((chunk, score))
-            # Show a preview when the content is long
-            preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
-            preview = preview.replace('\n', ' ')  # drop newlines for display
-            print(f"{i+1}. [{score:.4f}] {preview}")
-
-        return results
-
-    except FileNotFoundError:
-        print(f"Error: embeddings file not found: {embeddings_file}")
-        print("Run embed_document() first to generate the embeddings file")
-        return []
-    except Exception as e:
-        print(f"Error during search: {e}")
-        return []
 
 
 def paragraph_chunking(text, max_chunk_size=1000, overlap=100, min_chunk_size=200, separator='\n\n'):
@@ -767,9 +692,7 @@ def demo_usage():
     print("                   overlap=200,")
     print("                   min_chunk_size=300)")
 
-    print("\n4. Run a semantic search:")
-    print("semantic_search('your query', 'paragraph_embedding.pkl', top_k=5)")
-
+
 
 # If this file is run directly, execute the tests
 if __name__ == "__main__":
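With semantic_search removed from the embedding package, search now goes through the search service; callers that depended on the old standalone helper can reproduce its core in a few lines. A minimal sketch, assuming the pickle layout shown above ('chunks'/'embeddings', with 'sentences' as the legacy key) and a query vector already obtained from the encode API:

import pickle
import numpy as np

def search_pickle(query_embedding: np.ndarray, embeddings_file: str, top_k: int = 20):
    """Rank stored chunks by cosine similarity against an already-encoded query."""
    with open(embeddings_file, 'rb') as f:
        data = pickle.load(f)
    # New files store 'chunks'; legacy files store 'sentences'.
    chunks = data['chunks'] if 'chunks' in data else data['sentences']
    emb = np.asarray(data['embeddings'])
    # Cosine similarity; the 1e-8 term guards against division by zero.
    scores = emb @ query_embedding / (
        np.linalg.norm(emb, axis=1) * np.linalg.norm(query_embedding) + 1e-8
    )
    top = np.argsort(-scores)[:top_k]
    return [(chunks[i], float(scores[i])) for i in top]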
diff --git a/fastapi_app.py b/fastapi_app.py
index de5470c..4620ddf 100644
--- a/fastapi_app.py
+++ b/fastapi_app.py
@@ -132,35 +132,6 @@ def get_versioned_filename(upload_dir: str, name_without_ext: str, file_extensio
 
 
 # Semantic search request models
-class SemanticSearchRequest(BaseModel):
-    embedding_file: str = Field(..., description="Path to the embedding.pkl file")
-    query: str = Field(..., description="Search query")
-    top_k: int = Field(default=20, description="Number of results to return", ge=1, le=100)
-    min_score: float = Field(default=0.0, description="Minimum similarity threshold", ge=0.0, le=1.0)
-
-
-class BatchSearchRequest(BaseModel):
-    requests: List[SemanticSearchRequest] = Field(..., description="List of search requests")
-
-
-# Semantic search response models
-class SearchResult(BaseModel):
-    rank: int = Field(..., description="Result rank")
-    score: float = Field(..., description="Similarity score")
-    content: str = Field(..., description="Matched content")
-    content_preview: str = Field(..., description="Content preview")
-
-
-class SemanticSearchResponse(BaseModel):
-    success: bool = Field(..., description="Whether the request succeeded")
-    query: str = Field(..., description="Query string")
-    embedding_file: str = Field(..., description="Path to the embedding file")
-    processing_time: float = Field(..., description="Processing time in seconds")
-    total_chunks: int = Field(..., description="Total number of document chunks")
-    chunking_strategy: str = Field(..., description="Chunking strategy")
-    results: List[SearchResult] = Field(..., description="Search results")
-    cache_stats: Optional[Dict[str, Any]] = Field(None, description="Cache statistics")
-    error: Optional[str] = Field(None, description="Error message")
 
 
 # Encode request and response models
@@ -1589,148 +1560,6 @@ async def reset_files_processing(dataset_id: str):
         raise HTTPException(status_code=500, detail=f"Failed to reset file processing status: {str(e)}")
 
 
-# ============ Semantic search API endpoints ============
-
-@app.post("/api/v1/semantic-search", response_model=SemanticSearchResponse)
-async def semantic_search(request: SemanticSearchRequest):
-    """
-    Semantic search API
-
-    Args:
-        request: search request containing embedding_file and query
-
-    Returns:
-        Semantic search results
-    """
-    try:
-        search_service = get_search_service()
-        result = await search_service.semantic_search(
-            embedding_file=request.embedding_file,
-            query=request.query,
-            top_k=request.top_k,
-            min_score=request.min_score
-        )
-
-        if result["success"]:
-            return SemanticSearchResponse(
-                success=True,
-                query=result["query"],
-                embedding_file=result["embedding_file"],
-                processing_time=result["processing_time"],
-                total_chunks=result["total_chunks"],
-                chunking_strategy=result["chunking_strategy"],
-                results=[
-                    SearchResult(
-                        rank=r["rank"],
-                        score=r["score"],
-                        content=r["content"],
-                        content_preview=r["content_preview"]
-                    )
-                    for r in result["results"]
-                ],
-                cache_stats=result.get("cache_stats")
-            )
-        else:
-            return SemanticSearchResponse(
-                success=False,
-                query=request.query,
-                embedding_file=request.embedding_file,
-                processing_time=0.0,
-                total_chunks=0,
-                chunking_strategy="",
-                results=[],
-                error=result.get("error", "Unknown error")
-            )
-
-    except Exception as e:
-        logger.error(f"Semantic search API error: {e}")
-        raise HTTPException(status_code=500, detail=f"Semantic search failed: {str(e)}")
-
-
-@app.post("/api/v1/semantic-search/batch")
-async def batch_semantic_search(request: BatchSearchRequest):
-    """
-    Batch semantic search API
-
-    Args:
-        request: batch request containing multiple search requests
-
-    Returns:
-        Batch search results
-    """
-    try:
-        search_service = get_search_service()
-
-        # Convert the request format
-        search_requests = [
-            {
-                "embedding_file": req.embedding_file,
-                "query": req.query,
-                "top_k": req.top_k,
-                "min_score": req.min_score
-            }
-            for req in request.requests
-        ]
-
-        results = await search_service.batch_search(search_requests)
-
-        return {
-            "success": True,
-            "total_requests": len(request.requests),
-            "results": results
-        }
-
-    except Exception as e:
-        logger.error(f"Batch semantic search API error: {e}")
-        raise HTTPException(status_code=500, detail=f"Batch semantic search failed: {str(e)}")
-
-
-@app.get("/api/v1/semantic-search/stats")
-async def get_semantic_search_stats():
-    """
-    Get semantic search service statistics
-
-    Returns:
-        Service statistics
-    """
-    try:
-        search_service = get_search_service()
-        stats = search_service.get_service_stats()
-
-        return {
-            "success": True,
-            "timestamp": int(time.time()),
-            "stats": stats
-        }
-
-    except Exception as e:
-        logger.error(f"Failed to get semantic search statistics: {e}")
-        raise HTTPException(status_code=500, detail=f"Failed to get statistics: {str(e)}")
-
-
-@app.post("/api/v1/semantic-search/clear-cache")
-async def clear_semantic_search_cache():
-    """
-    Clear the semantic search cache
-
-    Returns:
-        Cache-clearing result
-    """
-    try:
-        from manager import get_cache_manager
-        cache_manager = get_cache_manager()
-        cache_manager.clear_cache()
-
-        return {
-            "success": True,
-            "message": "Cache cleared"
-        }
-
-    except Exception as e:
-        logger.error(f"Failed to clear the semantic search cache: {e}")
-        raise HTTPException(status_code=500, detail=f"Failed to clear cache: {str(e)}")
-
-
 @app.post("/api/v1/embedding/encode", response_model=EncodeResponse)
 async def encode_texts(request: EncodeRequest):
     """
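The dedicated search endpoints are gone, but /api/v1/embedding/encode remains, and it is what the MCP server below now calls for query encoding. A client-side sketch of that request/response shape, assuming the default service URL from the code and the success/embeddings fields that encode_query_via_api reads (the error field is an assumption, mirrored from the other response models):

import requests
import numpy as np

def encode_texts(texts, base_url="http://127.0.0.1:8001", batch_size=32):
    """POST texts to the encode endpoint and return an (n, dim) embedding matrix."""
    response = requests.post(
        f"{base_url}/api/v1/embedding/encode",
        json={"texts": texts, "batch_size": batch_size},
        timeout=30,
    )
    response.raise_for_status()
    payload = response.json()
    if not payload.get("success"):
        # Assumed field; adjust to whatever the service actually returns.
        raise RuntimeError(payload.get("error", "encoding failed"))
    return np.array(payload["embeddings"])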
diff --git a/mcp/semantic_search_server.py b/mcp/semantic_search_server.py
index 5831ced..242fb70 100644
--- a/mcp/semantic_search_server.py
+++ b/mcp/semantic_search_server.py
@@ -13,10 +13,10 @@ import sys
+import pickle  # needed by the local embeddings loader below
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
-from sentence_transformers import SentenceTransformer, util
 
 from mcp_common import (
-    get_allowed_directory, 
-    load_tools_from_json, 
+    get_allowed_directory,
+    load_tools_from_json,
     resolve_file_path,
     find_file_in_project,
    create_error_response,
@@ -30,12 +30,47 @@ from mcp_common import (
 
 import requests
 
+
+def encode_query_via_api(query: str, fastapi_url: str = None) -> Optional[np.ndarray]:
+    """Encode a single query through the API"""
+    if not fastapi_url:
+        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
+
+    api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"
+
+    try:
+        # Call the encode endpoint
+        request_data = {
+            "texts": [query],
+            "batch_size": 1
+        }
+
+        response = requests.post(
+            api_endpoint,
+            json=request_data,
+            timeout=30
+        )
+
+        if response.status_code == 200:
+            result_data = response.json()
+            if result_data.get("success"):
+                embeddings_list = result_data.get("embeddings", [])
+                if embeddings_list:
+                    return np.array(embeddings_list[0])
+
+        print(f"API encoding failed: {response.status_code} - {response.text}")
+        return None
+
+    except Exception as e:
+        print(f"API encoding exception: {e}")
+        return None
+
+
 def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
-    """Run semantic search by calling the FastAPI endpoint"""
+    """Run semantic search by reading the local embeddings file and computing similarities directly"""
     # Normalize the query input
     if isinstance(queries, str):
         queries = [queries]
-    
+
     # Validate the query list
     if not queries or not any(q.strip() for q in queries):
         return {
             "content": [
                 {
@@ -46,50 +81,64 @@ def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k:
                 }
             ]
         }
-    
+
     # Filter out empty queries
     queries = [q.strip() for q in queries if q.strip()]
-    
-    try:
-        # FastAPI service address
-        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
-        api_endpoint = f"{fastapi_url}/api/v1/semantic-search"
-        
-        # Process each query
-        all_results = []
-        resolved_embeddings_file = resolve_file_path(embeddings_file)
-        for query in queries:
-            # Call the FastAPI endpoint
-            request_data = {
-                "embedding_file": resolved_embeddings_file,
-                "query": query,
-                "top_k": top_k,
-                "min_score": 0.0
-            }
-            
-            response = requests.post(
-                api_endpoint,
-                json=request_data,
-                timeout=30
-            )
-            
-            if response.status_code == 200:
-                result_data = response.json()
-                if result_data.get("success"):
-                    for res in result_data.get("results", []):
-                        all_results.append({
-                            'query': query,
-                            'rank': res["rank"],
-                            'content': res["content"],
-                            'similarity_score': res["score"],
-                            'file_path': embeddings_file
-                        })
-                else:
-                    print(f"Search failed: {result_data.get('error', 'unknown error')}")
+    try:
+        # Resolve the embeddings file path
+        resolved_embeddings_file = resolve_file_path(embeddings_file)
+
+        # Load the embeddings file
+        with open(resolved_embeddings_file, 'rb') as f:
+            embedding_data = pickle.load(f)
+
+        # Support both the new and the old data layout
+        if 'chunks' in embedding_data:
+            # New layout (keyed by chunks)
+            chunks = embedding_data['chunks']
+            chunk_embeddings = embedding_data['embeddings']
+            chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
+        else:
+            # Old layout (keyed by sentences)
+            chunks = embedding_data['sentences']
+            chunk_embeddings = embedding_data['embeddings']
+            chunking_strategy = 'line'
+
+        all_results = []
+
+        # Process each query
+        for query in queries:
+            # Encode the query through the API
+            print(f"Encoding query: {query}")
+            query_embedding = encode_query_via_api(query)
+
+            if query_embedding is None:
+                print(f"Query encoding failed: {query}")
+                continue
+
+            # Compute cosine similarities
+            if len(chunk_embeddings.shape) > 1:
+                cos_scores = np.dot(chunk_embeddings, query_embedding) / (
+                    np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding) + 1e-8
+                )
             else:
-                print(f"API call failed: {response.status_code} - {response.text}")
-            
+                cos_scores = np.array([0.0] * len(chunks))
+
+            # Take the top_k results
+            top_indices = np.argsort(-cos_scores)[:top_k]
+
+            for rank, idx in enumerate(top_indices):
+                score = cos_scores[idx]
+                if score > 0:  # keep only results with some relevance
+                    all_results.append({
+                        'query': query,
+                        'rank': rank + 1,
+                        'content': chunks[idx],
+                        'similarity_score': float(score),
+                        'file_path': embeddings_file
+                    })
+
         if not all_results:
             return {
                 "content": [
                     {
@@ -99,20 +148,20 @@ def semantic_search(queries: Union[str, List[str]], embeddings_file: str, top_k:
                    }
                ]
            }
-    
+
         # Sort all results by similarity score
         all_results.sort(key=lambda x: x['similarity_score'], reverse=True)
-    
+
         # Format the output
         formatted_lines = []
         formatted_lines.append(f"Found {len(all_results)} results for {len(queries)} queries:")
         formatted_lines.append("")
-    
+
         for i, result in enumerate(all_results):
            formatted_lines.append(f"#{i+1} [query: '{result['query']}'] [similarity:{result['similarity_score']:.4f}]: {result['content']}")
-    
+
         formatted_output = "\n".join(formatted_lines)
-    
+
         return {
             "content": [
                 {
@@ -121,13 +170,13 @@
                 }
             ]
         }
-    
-    except requests.exceptions.RequestException as e:
+
+    except FileNotFoundError:
         return {
             "content": [
                 {
                     "type": "text",
-                    "text": f"API request failed: {str(e)}"
+                    "text": f"Error: Embeddings file not found: {embeddings_file}"
                 }
            ]
        }
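For reference, the two on-disk layouts the new reader accepts. A sketch that writes the current layout; legacy files differ only in using a 'sentences' key, which the reader maps to the 'line' strategy:

import pickle
import numpy as np

chunks = ["First paragraph ...", "Second paragraph ..."]
# Placeholder vectors; the real dimension depends on the embedding model.
embeddings = np.random.rand(len(chunks), 384).astype(np.float32)

with open("embedding.pkl", "wb") as f:
    pickle.dump(
        {
            "chunks": chunks,                  # legacy files use "sentences" instead
            "embeddings": embeddings,
            "chunking_strategy": "paragraph",  # a missing key falls back to 'unknown'
        },
        f,
    )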
diff --git a/prompt/wowtalk.md b/prompt/wowtalk.md
index ee1d093..3016baa 100644
--- a/prompt/wowtalk.md
+++ b/prompt/wowtalk.md
@@ -13,6 +13,18 @@
 
 ## Core capability recognition
 - **Device control**: turn on/turn off/adjust → Iot Control-dxcore_update_device_status
+  - Air conditioner (dc_fan) device parameters:
+    running_control: run control (optional, 0=stop, 1=start)
+    automatic_manual_operation: automatic/manual mode (optional, 0=manual, 1=automatic)
+    air_volume_control: airflow control (optional, 15=low, 20=medium, 30=high)
+    humi_setting: humidity setting (optional, range: 0-100)
+    temp_setting: temperature setting (optional, range: 0.0-100.0)
+    wind_direction_setting: airflow direction setting (optional, range: -90 to 90)
+    wind_direction_mode: airflow direction mode (optional, 0=automatic, 1=central)
+  - Lighting (light) device parameters:
+    dimming_control: dimming control (optional, 0-100)
+    color_control_x: color temperature control X value (optional, used together with color_control_y)
+    color_control_y: color temperature control Y value (optional, used together with color_control_x)
 - **Status queries**: status/temperature/humidity → Iot Control-dxcore_get_device_status
 - **Location services**: location/where is/find → Iot Control-eb_get_sensor_location
 - **Device lookup**: room/device lookup → Iot Control-find_devices_by_room
@@ -84,11 +96,25 @@
 
 ## Device control scenarios
 
-**User**: "Turn on the fan near me"
+**User**: "Turn on the fan and the lights near me"
 - find_employee_by_name(name="[current user]") → get the user's location and sensor_id
 - find_iot_device(device_type="dc_fan", target_sensor_id="[current user's sensor_id]") → find nearby devices
-- dxcore_update_device_status(running_control=1, sensor_id="[found device's sensor_id]") → turn the device on
-**Response**: "I have turned on the fan in room 301 for you"
+- find_iot_device(device_type="light", target_sensor_id="[current user's sensor_id]") → find nearby devices
+- **Confirmation step**: "About to turn on the room 301 fan (strong airflow) and lights (100% brightness) for you. Confirm?"
+- **After the user agrees**: dxcore_update_device_status(running_control=1, sensor_id="[fan device's sensor_id]") → turn the fan on
+- dxcore_update_device_status(dimming_control=100, sensor_id="[light device's sensor_id]") → turn the lights on
+**Response**: "I have turned on the fan and lights in room 301 for you"
+**If the user declines**: "OK, the device control operation has been cancelled"
+
+**User**: "Turn off the fan and the lights near me"
+- find_employee_by_name(name="[current user]") → get the user's location and sensor_id
+- find_iot_device(device_type="dc_fan", target_sensor_id="[current user's sensor_id]") → find nearby devices
+- find_iot_device(device_type="light", target_sensor_id="[current user's sensor_id]") → find nearby devices
+- **Confirmation step**: "About to turn off the room 301 fan and lights for you. Confirm?"
+- **After the user agrees**: dxcore_update_device_status(running_control=0, sensor_id="[fan device's sensor_id]") → turn the fan off
+- dxcore_update_device_status(dimming_control=0, sensor_id="[light device's sensor_id]") → turn the lights off
+**Response**: "I have turned off the fan and lights in room 301 for you"
+**If the user declines**: "OK, the device control operation has been cancelled"
 
 **User**: "The fan on floor 5 has a battery anomaly; notify 清水さん and report the exact location"
 - find_iot_device(device_type="dc_fan") → find the device
@@ -116,16 +142,36 @@
 **Response**: "[reply based on the web_fetch content]"
 
+## Device control confirmation mechanism
+
+### Standard confirmation format
+Every device control request must be confirmed in the following format:
+"About to [operation][device name][specific parameters]. Confirm?"
+
+### Parameter confirmation details
+- **Air conditioner operations**: state the temperature setting, operating mode, fan speed, etc.
+- **Lighting control**: state the brightness percentage, color temperature setting, etc.
+- **Fan devices**: state the airflow level, running state, etc.
+- **Other devices**: state the key parameters and the expected effect
+
+### Handling refusals
+When the user explicitly declines:
+- Cancel all device control operations
+- Reply: "OK, the device control operation has been cancelled"
+
 # Response guidelines
 
 ## Reply principles
 - **Concise and clear**: keep each reply to 1-2 sentences
+- **Confirmation first**: always confirm before controlling a device; never execute automatically
 - **Result-oriented**: respond directly based on tool execution results
 - **Professional tone**: maintain an enterprise-grade service standard
 - **Immediate response**: reply as soon as the tool calls complete
 
 ## Standard reply formats
+- **Device confirmation**: "About to set the room 301 air conditioner to 24 degrees. Confirm?"
 - **Device operation**: "The air conditioner has been set to 24 degrees and is running normally"
+- **Cancelled operation**: "OK, the device control operation has been cancelled"
 - **Message sent**: "The message has been sent to 田中さん"
 - **Location query**: "清水さん is in the meeting room on the 3rd floor of Building A"
 - **Task completed**: "Done: device turned on, message sent, location confirmed"
 
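The confirmation mechanism above is a behavioral rule for the prompt, but it maps onto a simple gate in whatever runtime hosts these tools. A hypothetical sketch: confirm() stands in for however the host collects user consent, the sensor ids are made up, and only the tool name and parameter names come from the prompt:

def control_with_confirmation(confirm, dxcore_update_device_status, actions):
    """Ask once for all pending device actions, then execute them or cancel."""
    summary = " and ".join(a["description"] for a in actions)
    if not confirm(f"About to {summary}. Confirm?"):
        return "OK, the device control operation has been cancelled"
    for action in actions:
        dxcore_update_device_status(sensor_id=action["sensor_id"], **action["params"])
    return f"Done: {summary}"

# The fan-and-lights scenario from the prompt (sensor ids are illustrative):
actions = [
    {"description": "turn on the room 301 fan (strong airflow)",
     "sensor_id": "fan-301", "params": {"running_control": 1, "air_volume_control": 30}},
    {"description": "turn on the room 301 lights (100% brightness)",
     "sensor_id": "light-301", "params": {"dimming_control": 100}},
]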