add semantic search

朱潮 2025-10-16 21:06:02 +08:00
parent 4e4b094709
commit d0e3e62291
10 changed files with 33519 additions and 10 deletions

.python-version (new file, +1 line)

@@ -0,0 +1 @@
3.12.0

embedding/document.txt (new file, +28874 lines)

File diff suppressed because one or more lines are too long

embedding/embedding.py (new file, +233 lines)

@@ -0,0 +1,233 @@
import pickle
import re
import numpy as np
from sentence_transformers import SentenceTransformer, util

# Lazily loaded model
embedder = None

def get_model():
    """Return the model instance (lazy loading)."""
    global embedder
    if embedder is None:
        print("Loading model...")
        embedder = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', device='cpu')
        print("Model loaded")
    return embedder

def clean_text(text):
    """
    Clean text by removing special and meaningless characters.

    Args:
        text (str): Raw text

    Returns:
        str: Cleaned text
    """
    # Strip HTML tags
    text = re.sub(r'<[^>]+>', '', text)
    # Collapse redundant whitespace
    text = re.sub(r'\s+', ' ', text)
    # Remove control and non-printable characters while keeping Unicode word characters
    text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)
    # Trim leading/trailing whitespace
    text = text.strip()
    return text

def is_meaningful_line(text):
    """
    Decide whether a line of text is meaningful.

    Args:
        text (str): A line of text

    Returns:
        bool: Whether the line is meaningful
    """
    if not text or len(text.strip()) < 5:
        return False
    # Filter out lines that are purely digits
    if text.strip().isdigit():
        return False
    # Filter out lines that contain only symbols
    if re.match(r'^[^\w\u4e00-\u9fa5]+$', text):
        return False
    # Filter out common meaningless lines
    meaningless_patterns = [
        r'^[-=_]{3,}$',      # separator lines
        r'^第\d+页$',         # page numbers ("Page N")
        r'^\d+\.\s*$',       # bare numbering
        r'^[a-zA-Z]\.\s*$',  # bare single-letter numbering
    ]
    for pattern in meaningless_patterns:
        if re.match(pattern, text.strip()):
            return False
    return True

def embed_document(input_file='document.txt', output_file='document_embeddings.pkl'):
    """
    Read document.txt, embed it line by line, and save the result as a pickle file.

    Args:
        input_file (str): Path to the input document
        output_file (str): Path to the output pickle file
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        cleaned_sentences = []
        original_count = len(lines)
        for line in lines:
            # Clean the text
            cleaned_text = clean_text(line)
            # Keep it only if it is meaningful
            if is_meaningful_line(cleaned_text):
                cleaned_sentences.append(cleaned_text)
        print(f"Original line count: {original_count}")
        print(f"Valid sentences after cleaning: {len(cleaned_sentences)}")
        print(f"Filtered out: {((original_count - len(cleaned_sentences)) / original_count * 100):.1f}%")
        if not cleaned_sentences:
            print("Warning: no meaningful sentences found!")
            return None
        print(f"Processing {len(cleaned_sentences)} valid sentences...")
        model = get_model()
        sentence_embeddings = model.encode(cleaned_sentences, convert_to_tensor=True)
        embedding_data = {
            'sentences': cleaned_sentences,
            'embeddings': sentence_embeddings
        }
        with open(output_file, 'wb') as f:
            pickle.dump(embedding_data, f)
        print(f"Saved embeddings to {output_file}")
        return embedding_data
    except FileNotFoundError:
        print(f"Error: file {input_file} not found")
        return None
    except Exception as e:
        print(f"Error while processing the document: {e}")
        return None

def semantic_search(user_query, embeddings_file='document_embeddings.pkl', top_k=20):
    """
    Run a semantic match for the user query and return the top_k most relevant sentences.

    Args:
        user_query (str): User query
        embeddings_file (str): Path to the embeddings file
        top_k (int): Number of results to return

    Returns:
        list: List of (sentence, similarity score) tuples
    """
    try:
        with open(embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)
        sentences = embedding_data['sentences']
        sentence_embeddings = embedding_data['embeddings']
        model = get_model()
        query_embedding = model.encode(user_query, convert_to_tensor=True)
        cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0]
        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]
        results = []
        print(f"\nTop {top_k} sentences most relevant to the query:")
        for i, idx in enumerate(top_results):
            sentence = sentences[idx]
            score = cos_scores[idx].item()
            results.append((sentence, score))
            print(f"{i+1}. [{score:.4f}] {sentence}")
        return results
    except FileNotFoundError:
        print(f"Error: embeddings file {embeddings_file} not found")
        print("Run embed_document() first to generate the embeddings file")
        return []
    except Exception as e:
        print(f"Error during search: {e}")
        return []

def split_document_by_pages(input_file='document.txt', output_file='serialization.txt'):
    """
    Split document.txt by pages and write each page's content as a single line to serialization.txt.

    Args:
        input_file (str): Path to the input document
        output_file (str): Path to the output serialization file
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        pages = []
        current_page = []
        for line in lines:
            line = line.strip()
            # Check whether this line is a page separator
            if re.match(r'^#\s*Page\s+\d+', line, re.IGNORECASE):
                # If the current page already has content, save it
                if current_page:
                    # Merge the current page's content into a single line (same join as the last page below)
                    page_content = ' '.join(current_page).strip()
                    if page_content:  # keep only non-empty pages
                        pages.append(page_content)
                    current_page = []
                continue
            # Not a page separator: if the line has content, add it to the current page
            if line:
                current_page.append(line)
        # Handle the last page
        if current_page:
            page_content = ' '.join(current_page).strip()
            if page_content:
                pages.append(page_content)
        print(f"Split into {len(pages)} pages in total")
        # Write the serialized file
        with open(output_file, 'w', encoding='utf-8') as f:
            for i, page_content in enumerate(pages, 1):
                f.write(f"{page_content}\n")
        print(f"Serialized page contents to {output_file}")
        return pages
    except FileNotFoundError:
        print(f"Error: file {input_file} not found")
        return []
    except Exception as e:
        print(f"Error while splitting the document: {e}")
        return []

split_document_by_pages("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")
# embed_document("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")  # uncomment to run
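A minimal usage sketch of the two functions above (the paths are illustrative, not taken from the repository):

# Hypothetical invocation; adjust paths to your project layout.
data = embed_document("embedding/document.txt", "embedding/document_embeddings.pkl")
if data:
    semantic_search("14-inch notebook battery life", "embedding/document_embeddings.pkl", top_k=5)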

(modified file)

@@ -43,6 +43,17 @@ from file_loaded_agent_manager import get_global_agent_manager, init_global_agen
from gbase_agent import update_agent_llm
from zip_project_handler import zip_handler

def get_zip_url_from_unique_id(unique_id: str) -> Optional[str]:
    """Read the zip_url for a unique_id from unique_map.json."""
    try:
        with open('unique_map.json', 'r', encoding='utf-8') as f:
            unique_map = json.load(f)
        return unique_map.get(unique_id)
    except Exception as e:
        print(f"Error reading unique_map.json: {e}")
        return None

# Global assistant manager configuration
max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "20"))
@@ -74,6 +85,7 @@ class ChatRequest(BaseModel):
    model: str = "qwen3-next"
    model_server: str = ""
    zip_url: Optional[str] = None
    unique_id: Optional[str] = None
    stream: Optional[bool] = False

    class Config:
@@ -180,15 +192,22 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
    else:
        api_key = authorization

    # Get the zip_url parameter from the top level
    # Get the zip_url and unique_id parameters from the top level
    zip_url = request.zip_url
    unique_id = request.unique_id

    # If unique_id is provided, read the zip_url from unique_map.json
    if unique_id:
        zip_url = get_zip_url_from_unique_id(unique_id)
        if not zip_url:
            raise HTTPException(status_code=400, detail=f"No zip_url found for unique_id: {unique_id}")

    if not zip_url:
        raise HTTPException(status_code=400, detail="zip_url is required")

    # Load the project data from the ZIP URL
    print(f"Loading project from ZIP URL: {zip_url}")
    project_dir = zip_handler.get_project_from_zip(zip_url)
    project_dir = zip_handler.get_project_from_zip(zip_url, unique_id if unique_id else None)
    if not project_dir:
        raise HTTPException(status_code=400, detail=f"Failed to load project from ZIP URL: {zip_url}")
@@ -199,7 +218,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
        print(f"Warning: no document.txt file found in project directory {project_dir}")

    # Collect extra parameters as generate_cfg
    exclude_fields = {'messages', 'model', 'model_server', 'zip_url', 'stream'}
    exclude_fields = {'messages', 'model', 'model_server', 'zip_url', 'unique_id', 'stream'}
    generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}

    # Get or create a file-preloaded assistant instance from the global manager
@@ -213,7 +232,19 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
        generate_cfg=generate_cfg
    )

    # Build the message context that includes the project information
    messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
    messages = []
    for msg in request.messages:
        if msg.role == "assistant":
            # Split assistant messages on [ANSWER] and keep only the last segment
            content_parts = msg.content.split("[ANSWER]")
            if content_parts:
                # Take the last segment of text
                last_part = content_parts[-1].strip()
                messages.append({"role": msg.role, "content": last_part})
            else:
                messages.append({"role": msg.role, "content": msg.content})
        else:
            messages.append({"role": msg.role, "content": msg.content})

    # Return a streaming or non-streaming response depending on the stream parameter
    if request.stream:
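As a quick illustration of the assistant-message handling above (the content string is invented), only the text after the last [ANSWER] marker is forwarded to the agent:

content = "reasoning... [ANSWER] draft [ANSWER] Final answer shown to the user."
content.split("[ANSWER]")[-1].strip()  # -> "Final answer shown to the user."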

(new file, +367 lines)

@@ -0,0 +1,367 @@
#!/usr/bin/env python3
"""
Semantic search MCP server.
Performs semantic similarity search over embedding vectors.
Implementation follows the approach of multi_keyword_search_server.py.
"""
import asyncio
import json
import os
import pickle
import sys
from typing import Any, Dict, List, Optional
import numpy as np
from sentence_transformers import SentenceTransformer, util

# Lazily loaded model
embedder = None

def get_model():
    """Return the model instance (lazy loading)."""
    global embedder
    if embedder is None:
        embedder = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', device='cpu')
    return embedder

def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Validate that a file path lies inside the allowed directory."""
    # Convert to an absolute path
    if not os.path.isabs(file_path):
        file_path = os.path.abspath(file_path)
    allowed_dir = os.path.abspath(allowed_dir)
    # Check that the path is inside the allowed directory
    if not file_path.startswith(allowed_dir):
        raise ValueError(f"Access denied: path {file_path} is outside the allowed directory {allowed_dir}")
    # Check for path traversal attempts
    if ".." in file_path:
        raise ValueError("Access denied: path traversal attempt detected")
    return file_path

def get_allowed_directory():
    """Return the directory that may be accessed."""
    # Read the project data directory from the environment
    project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
    return os.path.abspath(project_dir)

def semantic_search(query: str, embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Run a semantic search."""
    if not query.strip():
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: the query must not be empty"
                }
            ]
        }

    # Resolve the project directory restriction
    project_data_dir = get_allowed_directory()

    # Validate the embeddings file path
    try:
        # Resolve relative paths
        if not os.path.isabs(embeddings_file):
            # Try to find the file inside the project directory
            full_path = os.path.join(project_data_dir, embeddings_file.lstrip('./'))
            if not os.path.exists(full_path):
                # If the direct path does not exist, fall back to a recursive search
                found = find_file_in_project(embeddings_file, project_data_dir)
                if found:
                    embeddings_file = found
                else:
                    return {
                        "content": [
                            {
                                "type": "text",
                                "text": f"Error: embeddings file {embeddings_file} was not found in project directory {project_data_dir}"
                            }
                        ]
                    }
            else:
                embeddings_file = full_path
        else:
            if not embeddings_file.startswith(project_data_dir):
                return {
                    "content": [
                        {
                            "type": "text",
                            "text": f"Error: the embeddings file path must be inside the project directory {project_data_dir}"
                        }
                    ]
                }
        if not os.path.exists(embeddings_file):
            return {
                "content": [
                    {
                        "type": "text",
                        "text": f"Error: embeddings file {embeddings_file} does not exist"
                    }
                ]
            }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error: validation of the embeddings file path failed - {str(e)}"
                }
            ]
        }

    try:
        # Load the embedding data
        with open(embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)
        sentences = embedding_data['sentences']
        sentence_embeddings = embedding_data['embeddings']

        # Load the model
        model = get_model()

        # Encode the query
        query_embedding = model.encode(query, convert_to_tensor=True)

        # Compute similarities
        cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0]

        # Take the top_k results
        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]

        # Format the results
        results = []
        for i, idx in enumerate(top_results):
            sentence = sentences[idx]
            score = cos_scores[idx].item()
            results.append({
                'rank': i + 1,
                'content': sentence,
                'similarity_score': score,
                'file_path': embeddings_file
            })

        if not results:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": "No matching results found"
                    }
                ]
            }

        # Format the output
        formatted_output = "\n".join([
            f"#{result['rank']} [similarity:{result['similarity_score']:.4f}]: {result['content']}"
            for result in results
        ])
        return {
            "content": [
                {
                    "type": "text",
                    "text": formatted_output
                }
            ]
        }
    except FileNotFoundError:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error: embeddings file {embeddings_file} not found"
                }
            ]
        }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error during search: {str(e)}"
                }
            ]
        }

def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Recursively search the project directory for a file."""
    for root, dirs, files in os.walk(project_dir):
        if filename in files:
            return os.path.join(root, filename)
    return None

async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle MCP request"""
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")

        if method == "initialize":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {
                        "tools": {}
                    },
                    "serverInfo": {
                        "name": "semantic-search",
                        "version": "1.0.0"
                    }
                }
            }
        elif method == "ping":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "pong": True
                }
            }
        elif method == "tools/list":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "tools": [
                        {
                            "name": "semantic_search",
                            "description": "Semantic search tool that performs similarity search over vector embeddings. Output format: #[rank] [similarity score]: [matched content]",
                            "inputSchema": {
                                "type": "object",
                                "properties": {
                                    "query": {
                                        "type": "string",
                                        "description": "Search query text"
                                    },
                                    "embeddings_file": {
                                        "type": "string",
                                        "description": "Path to the embeddings pickle file"
                                    },
                                    "top_k": {
                                        "type": "integer",
                                        "description": "Maximum number of results to return (default 20)",
                                        "default": 20
                                    }
                                },
                                "required": ["query", "embeddings_file"]
                            }
                        }
                    ]
                }
            }
elif method == "tools/call":
tool_name = params.get("name")
arguments = params.get("arguments", {})
if tool_name == "semantic_search":
query = arguments.get("query", "")
embeddings_file = arguments.get("embeddings_file", "")
top_k = arguments.get("top_k", 20)
result = semantic_search(query, embeddings_file, top_k)
return {
"jsonrpc": "2.0",
"id": request_id,
"result": result
}
else:
return {
"jsonrpc": "2.0",
"id": request_id,
"error": {
"code": -32601,
"message": f"Unknown tool: {tool_name}"
}
}
else:
return {
"jsonrpc": "2.0",
"id": request_id,
"error": {
"code": -32601,
"message": f"Unknown method: {method}"
}
}
except Exception as e:
return {
"jsonrpc": "2.0",
"id": request.get("id"),
"error": {
"code": -32603,
"message": f"Internal error: {str(e)}"
}
}
async def main():
"""Main entry point."""
try:
while True:
# Read from stdin
line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
if not line:
break
line = line.strip()
if not line:
continue
try:
request = json.loads(line)
response = await handle_request(request)
# Write to stdout
sys.stdout.write(json.dumps(response) + "\n")
sys.stdout.flush()
except json.JSONDecodeError:
error_response = {
"jsonrpc": "2.0",
"error": {
"code": -32700,
"message": "Parse error"
}
}
sys.stdout.write(json.dumps(error_response) + "\n")
sys.stdout.flush()
except Exception as e:
error_response = {
"jsonrpc": "2.0",
"error": {
"code": -32603,
"message": f"Internal error: {str(e)}"
}
}
sys.stdout.write(json.dumps(error_response) + "\n")
sys.stdout.flush()
except KeyboardInterrupt:
pass
if __name__ == "__main__":
asyncio.run(main())
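A hedged smoke-test sketch for the server above, driving it over stdio with a single JSON-RPC tools/call request. The script filename and the pickle path are assumptions not fixed by this diff:

import json
import subprocess

# "semantic_search_server.py" is an assumed name for the file added above.
proc = subprocess.Popen(["python", "semantic_search_server.py"],
                        stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True)
request = {
    "jsonrpc": "2.0", "id": 1, "method": "tools/call",
    "params": {"name": "semantic_search",
               "arguments": {"query": "battery life", "embeddings_file": "document_embeddings.pkl", "top_k": 5}},
}
proc.stdin.write(json.dumps(request) + "\n")
proc.stdin.flush()
print(proc.stdout.readline())  # JSON-RPC response containing the formatted matches
proc.terminate()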

poetry.lock (generated, new file, +3952 lines)

File diff suppressed because it is too large

pyproject.toml (new file, +31 lines)

@@ -0,0 +1,31 @@
[project]
name = "catalog-agent"
version = "0.1.0"
description = ""
authors = [
    {name = "朱潮", email = "zhuchaowe@users.noreply.github.com"}
]
readme = "README.md"
requires-python = "3.12.0"
dependencies = [
    "fastapi==0.116.1",
    "uvicorn==0.35.0",
    "requests==2.32.5",
    "qwen-agent[rag,mcp]==0.0.29",
    "pydantic==2.10.5",
    "python-dateutil==2.8.2",
    "torch==2.2.0",
    "transformers",
    "sentence-transformers",
    "numpy<2",
]

[tool.poetry]
package-mode = false

[build-system]
requires = ["poetry-core>=2.0.0,<3.0.0"]
build-backend = "poetry.core.masonry.api"

(modified file)

@@ -11,3 +11,9 @@ qwen-agent[rag,mcp]==0.0.29
# Data processing
pydantic==2.10.5
python-dateutil==2.8.2
# embedding
torch
transformers
sentence-transformers

unique_map.json (new file, +3 lines)

@@ -0,0 +1,3 @@
{
  "b743ccc3-13be-43ea-8ec9-4ce9c86103b3": "public/all_hp_product_spec_book2506.zip"
}
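For reference, a hedged sketch of a chat request that resolves its project through this mapping. The endpoint path and port are assumptions; the fields match ChatRequest in the diff above:

import requests

payload = {
    "model": "qwen3-next",
    "unique_id": "b743ccc3-13be-43ea-8ec9-4ce9c86103b3",  # resolved to public/all_hp_product_spec_book2506.zip via unique_map.json
    "messages": [{"role": "user", "content": "Which notebook models support Wi-Fi 6E?"}],
    "stream": False,
}
# The route is assumed to be an OpenAI-compatible /v1/chat/completions endpoint on localhost:8000.
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())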

(modified file)

@@ -79,12 +79,13 @@ class ZipProjectHandler:
            print(f"Failed to extract ZIP file: {e}")
            return False

    def get_project_from_zip(self, zip_url: str) -> Optional[str]:
    def get_project_from_zip(self, zip_url: str, unique_id: Optional[str] = None) -> Optional[str]:
        """
        Get project data from a ZIP URL or local path.

        Args:
            zip_url: URL or local relative path of the ZIP file
            unique_id: Optional unique identifier, used as the folder name

        Returns:
            Optional[str]: the project directory path on success, or None on failure
@@ -93,16 +94,26 @@
            print(f"Invalid URL or path: {zip_url}")
            return None

        # Check the cache
        url_hash = self._get_url_hash(zip_url)
        cached_project_dir = self.projects_dir / url_hash
        # Use unique_id as the directory name; fall back to url_hash if it is absent
        if unique_id:
            project_dir_name = unique_id
            # When unique_id is used, skip the cache and re-extract so the project structure is correct
            cached_project_dir = self.projects_dir / project_dir_name
        else:
            project_dir_name = self._get_url_hash(zip_url)
            cached_project_dir = self.projects_dir / project_dir_name

        if cached_project_dir.exists():
        if cached_project_dir.exists() and not unique_id:
            print(f"Using cached project directory: {cached_project_dir}")
            return str(cached_project_dir)

        # Download or copy the ZIP file
        zip_filename = f"{url_hash}.zip"
        url_hash = self._get_url_hash(zip_url)
        # When unique_id is used, prefix the ZIP filename with unique_id to avoid collisions
        if unique_id:
            zip_filename = f"{unique_id}_{url_hash}.zip"
        else:
            zip_filename = f"{url_hash}.zip"
        zip_path = self.cache_dir / zip_filename

        if not zip_path.exists():