add semantic search
parent 4e4b094709
commit d0e3e62291

.python-version (new file, 1 line)
@@ -0,0 +1 @@
3.12.0

embedding/document.txt (new file, 28874 lines)
File diff suppressed because one or more lines are too long

embedding/embedding.py (new file, 233 lines)
@@ -0,0 +1,233 @@
import pickle
import re
import numpy as np
from sentence_transformers import SentenceTransformer, util

# Lazily loaded model
embedder = None

def get_model():
    """Return the model instance (lazy loading)."""
    global embedder
    if embedder is None:
        print("Loading model...")
        embedder = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', device='cpu')
        print("Model loaded")
    return embedder

def clean_text(text):
    """
    Clean text by removing special and meaningless characters.

    Args:
        text (str): Raw text

    Returns:
        str: Cleaned text
    """
    # Strip HTML tags
    text = re.sub(r'<[^>]+>', '', text)

    # Collapse runs of whitespace
    text = re.sub(r'\s+', ' ', text)

    # Remove control and non-printable characters while keeping Unicode word characters
    text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)

    # Trim leading and trailing whitespace
    text = text.strip()

    return text

def is_meaningful_line(text):
    """
    Decide whether a line of text is meaningful.

    Args:
        text (str): Text line

    Returns:
        bool: Whether the line is meaningful
    """
    if not text or len(text.strip()) < 5:
        return False

    # Filter lines that are digits only
    if text.strip().isdigit():
        return False

    # Filter lines that contain only symbols
    if re.match(r'^[^\w\u4e00-\u9fa5]+$', text):
        return False

    # Filter common meaningless lines
    meaningless_patterns = [
        r'^[-=_]{3,}$',      # divider lines
        r'^第\d+页$',         # page numbers ("Page N")
        r'^\d+\.\s*$',       # bare numbering
        r'^[a-zA-Z]\.\s*$',  # bare single-letter numbering
    ]

    for pattern in meaningless_patterns:
        if re.match(pattern, text.strip()):
            return False

    return True

def embed_document(input_file='document.txt', output_file='document_embeddings.pkl'):
    """
    Read document.txt, embed it line by line, and save the result as a pickle file.

    Args:
        input_file (str): Path to the input document
        output_file (str): Path to the output pickle file
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        cleaned_sentences = []
        original_count = len(lines)

        for line in lines:
            # Clean the text
            cleaned_text = clean_text(line)

            # Keep it only if it is meaningful
            if is_meaningful_line(cleaned_text):
                cleaned_sentences.append(cleaned_text)

        print(f"Original line count: {original_count}")
        print(f"Valid sentences after cleaning: {len(cleaned_sentences)}")
        print(f"Filtered out: {((original_count - len(cleaned_sentences)) / original_count * 100):.1f}%")

        if not cleaned_sentences:
            print("Warning: no meaningful sentences found!")
            return None

        print(f"Processing {len(cleaned_sentences)} valid sentences...")

        model = get_model()
        sentence_embeddings = model.encode(cleaned_sentences, convert_to_tensor=True)

        embedding_data = {
            'sentences': cleaned_sentences,
            'embeddings': sentence_embeddings
        }

        with open(output_file, 'wb') as f:
            pickle.dump(embedding_data, f)

        print(f"Saved embeddings to {output_file}")
        return embedding_data

    except FileNotFoundError:
        print(f"Error: file {input_file} not found")
        return None
    except Exception as e:
        print(f"Error while processing the document: {e}")
        return None

def semantic_search(user_query, embeddings_file='document_embeddings.pkl', top_k=20):
    """
    Take a user query, run semantic matching, and return the top_k most relevant sentences.

    Args:
        user_query (str): User query
        embeddings_file (str): Path to the embeddings file
        top_k (int): Number of results to return

    Returns:
        list: List of (sentence, similarity score) tuples
    """
    try:
        with open(embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)

        sentences = embedding_data['sentences']
        sentence_embeddings = embedding_data['embeddings']

        model = get_model()
        query_embedding = model.encode(user_query, convert_to_tensor=True)

        cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0]

        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]

        results = []
        print(f"\nTop {top_k} sentences most relevant to the query:")
        for i, idx in enumerate(top_results):
            sentence = sentences[idx]
            score = cos_scores[idx].item()
            results.append((sentence, score))
            print(f"{i+1}. [{score:.4f}] {sentence}")

        return results

    except FileNotFoundError:
        print(f"Error: embeddings file {embeddings_file} not found")
        print("Run embed_document() first to generate the embeddings file")
        return []
    except Exception as e:
        print(f"Error while searching: {e}")
        return []


def split_document_by_pages(input_file='document.txt', output_file='serialization.txt'):
    """
    Split document.txt by pages and write each page's content as a single line to serialization.txt.

    Args:
        input_file (str): Path to the input document
        output_file (str): Path to the output serialization file
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        pages = []
        current_page = []

        for line in lines:
            line = line.strip()

            # Check for a page separator
            if re.match(r'^#\s*Page\s+\d+', line, re.IGNORECASE):
                # If the current page has content, save it
                if current_page:
                    # Merge the current page's content into one line
                    page_content = '\\n'.join(current_page).strip()
                    if page_content:  # only keep non-empty pages
                        pages.append(page_content)
                    current_page = []
                continue

            # Not a page separator and non-empty: append to the current page
            if line:
                current_page.append(line)

        # Handle the last page (joined with '\\n' for consistency with the pages above)
        if current_page:
            page_content = '\\n'.join(current_page).strip()
            if page_content:
                pages.append(page_content)

        print(f"Split into {len(pages)} pages in total")

        # Write the serialization file
        with open(output_file, 'w', encoding='utf-8') as f:
            for i, page_content in enumerate(pages, 1):
                f.write(f"{page_content}\n")

        print(f"Serialized page content to {output_file}")
        return pages

    except FileNotFoundError:
        print(f"Error: file {input_file} not found")
        return []
    except Exception as e:
        print(f"Error while splitting the document: {e}")
        return []

split_document_by_pages("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")
# embed_document("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")  # uncomment to run

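For reference, a minimal end-to-end sketch of how the helpers above chain together; the file paths and the query string are placeholders, not taken from the commit:

# Hypothetical paths and query: embed the document once, then search the pickle it produced.
data = embed_document('document.txt', 'document_embeddings.pkl')
if data is not None:
    for sentence, score in semantic_search("battery capacity", 'document_embeddings.pkl', top_k=5):
        print(f"{score:.4f}  {sentence}")
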
@@ -43,6 +43,17 @@ from file_loaded_agent_manager import get_global_agent_manager, init_global_agen
 from gbase_agent import update_agent_llm
 from zip_project_handler import zip_handler
 
+
+def get_zip_url_from_unique_id(unique_id: str) -> Optional[str]:
+    """Read the zip_url for a unique_id from unique_map.json"""
+    try:
+        with open('unique_map.json', 'r', encoding='utf-8') as f:
+            unique_map = json.load(f)
+        return unique_map.get(unique_id)
+    except Exception as e:
+        print(f"Error reading unique_map.json: {e}")
+        return None
+
 # Global assistant manager configuration
 max_cached_agents = int(os.getenv("MAX_CACHED_AGENTS", "20"))
 
@@ -74,6 +85,7 @@ class ChatRequest(BaseModel):
     model: str = "qwen3-next"
     model_server: str = ""
     zip_url: Optional[str] = None
+    unique_id: Optional[str] = None
     stream: Optional[bool] = False
 
     class Config:
@@ -180,15 +192,22 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
     else:
         api_key = authorization
 
-    # Get zip_url from the top level of the request
+    # Get zip_url and unique_id from the top level of the request
    zip_url = request.zip_url
+    unique_id = request.unique_id
 
+    # If a unique_id was provided, look up the zip_url in unique_map.json
+    if unique_id:
+        zip_url = get_zip_url_from_unique_id(unique_id)
+        if not zip_url:
+            raise HTTPException(status_code=400, detail=f"No zip_url found for unique_id: {unique_id}")
+
     if not zip_url:
         raise HTTPException(status_code=400, detail="zip_url is required")
 
     # Fetch the project data from the ZIP URL
     print(f"Loading project from ZIP URL: {zip_url}")
-    project_dir = zip_handler.get_project_from_zip(zip_url)
+    project_dir = zip_handler.get_project_from_zip(zip_url, unique_id if unique_id else None)
     if not project_dir:
         raise HTTPException(status_code=400, detail=f"Failed to load project from ZIP URL: {zip_url}")
 
@@ -199,7 +218,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
         print(f"Warning: no document.txt file found in project directory {project_dir}")
 
     # Collect extra parameters as generate_cfg
-    exclude_fields = {'messages', 'model', 'model_server', 'zip_url', 'stream'}
+    exclude_fields = {'messages', 'model', 'model_server', 'zip_url', 'unique_id', 'stream'}
     generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
 
     # Get or create a file-preloaded assistant instance from the global manager
@@ -213,7 +232,19 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
         generate_cfg=generate_cfg
     )
     # Build the message context including project information
-    messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
+    messages = []
+    for msg in request.messages:
+        if msg.role == "assistant":
+            # Split assistant messages on [ANSWER] and keep only the last segment
+            content_parts = msg.content.split("[ANSWER]")
+            if content_parts:
+                # Take the last segment, stripped of surrounding whitespace
+                last_part = content_parts[-1].strip()
+                messages.append({"role": msg.role, "content": last_part})
+            else:
+                messages.append({"role": msg.role, "content": msg.content})
+        else:
+            messages.append({"role": msg.role, "content": msg.content})
 
     # Decide between streaming and non-streaming response based on the stream parameter
     if request.stream:

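A hedged sketch of what a client request could look like after this change; the host, port, and route are assumptions (the route decorator is not visible in this hunk), while the unique_id value comes from the unique_map.json added below:

import requests

# Assumed endpoint; only the handler body is shown in this diff.
payload = {
    "model": "qwen3-next",
    "messages": [{"role": "user", "content": "Which models support 32 GB RAM?"}],
    # Resolved server-side to a zip_url via unique_map.json
    "unique_id": "b743ccc3-13be-43ea-8ec9-4ce9c86103b3",
    "stream": False,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())
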
mcp/semantic_search_server.py (new file, 367 lines)
@@ -0,0 +1,367 @@
#!/usr/bin/env python3
"""
Semantic search MCP server
Performs semantic similarity search over embedding vectors
Implementation modeled on multi_keyword_search_server.py
"""

import asyncio
import json
import os
import pickle
import sys
from typing import Any, Dict, List, Optional

import numpy as np
from sentence_transformers import SentenceTransformer, util

# Lazily loaded model
embedder = None

def get_model():
    """Return the model instance (lazy loading)."""
    global embedder
    if embedder is None:
        embedder = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', device='cpu')
    return embedder


def validate_file_path(file_path: str, allowed_dir: str) -> str:
    """Verify that a file path lies inside the allowed directory."""
    # Convert to an absolute path
    if not os.path.isabs(file_path):
        file_path = os.path.abspath(file_path)

    allowed_dir = os.path.abspath(allowed_dir)

    # Check that the path is inside the allowed directory
    if not file_path.startswith(allowed_dir):
        raise ValueError(f"Access denied: path {file_path} is not inside the allowed directory {allowed_dir}")

    # Check for path traversal attempts
    if ".." in file_path:
        raise ValueError("Access denied: path traversal attempt detected")

    return file_path


def get_allowed_directory():
    """Return the directory the server is allowed to access."""
    # Read the project data directory from the environment
    project_dir = os.getenv("PROJECT_DATA_DIR", "./projects")
    return os.path.abspath(project_dir)


def semantic_search(query: str, embeddings_file: str, top_k: int = 20) -> Dict[str, Any]:
    """Run a semantic search."""
    if not query.strip():
        return {
            "content": [
                {
                    "type": "text",
                    "text": "Error: query must not be empty"
                }
            ]
        }

    # Restrict access to the project directory
    project_data_dir = get_allowed_directory()

    # Validate the embeddings file path
    try:
        # Resolve relative paths
        if not os.path.isabs(embeddings_file):
            # Try to find the file inside the project directory
            full_path = os.path.join(project_data_dir, embeddings_file.lstrip('./'))
            if not os.path.exists(full_path):
                # If the direct path does not exist, search recursively
                found = find_file_in_project(embeddings_file, project_data_dir)
                if found:
                    embeddings_file = found
                else:
                    return {
                        "content": [
                            {
                                "type": "text",
                                "text": f"Error: embeddings file {embeddings_file} not found in project directory {project_data_dir}"
                            }
                        ]
                    }
            else:
                embeddings_file = full_path
        else:
            if not embeddings_file.startswith(project_data_dir):
                return {
                    "content": [
                        {
                            "type": "text",
                            "text": f"Error: the embeddings file path must be inside the project directory {project_data_dir}"
                        }
                    ]
                }
            if not os.path.exists(embeddings_file):
                return {
                    "content": [
                        {
                            "type": "text",
                            "text": f"Error: embeddings file {embeddings_file} does not exist"
                        }
                    ]
                }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error: embeddings file path validation failed - {str(e)}"
                }
            ]
        }

    try:
        # Load the embedding data
        with open(embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)

        sentences = embedding_data['sentences']
        sentence_embeddings = embedding_data['embeddings']

        # Load the model
        model = get_model()

        # Encode the query
        query_embedding = model.encode(query, convert_to_tensor=True)

        # Compute similarities
        cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0]

        # Take the top_k results
        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]

        # Build the result list
        results = []
        for i, idx in enumerate(top_results):
            sentence = sentences[idx]
            score = cos_scores[idx].item()
            results.append({
                'rank': i + 1,
                'content': sentence,
                'similarity_score': score,
                'file_path': embeddings_file
            })

        if not results:
            return {
                "content": [
                    {
                        "type": "text",
                        "text": "No matching results found"
                    }
                ]
            }

        # Format the output
        formatted_output = "\n".join([
            f"#{result['rank']} [similarity:{result['similarity_score']:.4f}]: {result['content']}"
            for result in results
        ])

        return {
            "content": [
                {
                    "type": "text",
                    "text": formatted_output
                }
            ]
        }

    except FileNotFoundError:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error: embeddings file {embeddings_file} not found"
                }
            ]
        }
    except Exception as e:
        return {
            "content": [
                {
                    "type": "text",
                    "text": f"Error while searching: {str(e)}"
                }
            ]
        }


def find_file_in_project(filename: str, project_dir: str) -> Optional[str]:
    """Recursively search for a file inside the project directory."""
    for root, dirs, files in os.walk(project_dir):
        if filename in files:
            return os.path.join(root, filename)
    return None


async def handle_request(request: Dict[str, Any]) -> Dict[str, Any]:
    """Handle MCP request"""
    try:
        method = request.get("method")
        params = request.get("params", {})
        request_id = request.get("id")

        if method == "initialize":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {
                        "tools": {}
                    },
                    "serverInfo": {
                        "name": "semantic-search",
                        "version": "1.0.0"
                    }
                }
            }

        elif method == "ping":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "pong": True
                }
            }

        elif method == "tools/list":
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "result": {
                    "tools": [
                        {
                            "name": "semantic_search",
                            "description": "Semantic search tool based on vector embeddings. Result format: #[rank] [similarity score]: [matched content]",
                            "inputSchema": {
                                "type": "object",
                                "properties": {
                                    "query": {
                                        "type": "string",
                                        "description": "Search query text"
                                    },
                                    "embeddings_file": {
                                        "type": "string",
                                        "description": "Path to the embeddings pickle file"
                                    },
                                    "top_k": {
                                        "type": "integer",
                                        "description": "Maximum number of results to return, default 20",
                                        "default": 20
                                    }
                                },
                                "required": ["query", "embeddings_file"]
                            }
                        }
                    ]
                }
            }

        elif method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})

            if tool_name == "semantic_search":
                query = arguments.get("query", "")
                embeddings_file = arguments.get("embeddings_file", "")
                top_k = arguments.get("top_k", 20)

                result = semantic_search(query, embeddings_file, top_k)

                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": result
                }

            else:
                return {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "error": {
                        "code": -32601,
                        "message": f"Unknown tool: {tool_name}"
                    }
                }

        else:
            return {
                "jsonrpc": "2.0",
                "id": request_id,
                "error": {
                    "code": -32601,
                    "message": f"Unknown method: {method}"
                }
            }

    except Exception as e:
        return {
            "jsonrpc": "2.0",
            "id": request.get("id"),
            "error": {
                "code": -32603,
                "message": f"Internal error: {str(e)}"
            }
        }


async def main():
    """Main entry point."""
    try:
        while True:
            # Read from stdin
            line = await asyncio.get_event_loop().run_in_executor(None, sys.stdin.readline)
            if not line:
                break

            line = line.strip()
            if not line:
                continue

            try:
                request = json.loads(line)
                response = await handle_request(request)

                # Write to stdout
                sys.stdout.write(json.dumps(response) + "\n")
                sys.stdout.flush()

            except json.JSONDecodeError:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32700,
                        "message": "Parse error"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()

            except Exception as e:
                error_response = {
                    "jsonrpc": "2.0",
                    "error": {
                        "code": -32603,
                        "message": f"Internal error: {str(e)}"
                    }
                }
                sys.stdout.write(json.dumps(error_response) + "\n")
                sys.stdout.flush()

    except KeyboardInterrupt:
        pass


if __name__ == "__main__":
    asyncio.run(main())

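A minimal sketch of exercising the server above over stdio, assuming it is started from the repository root and that the embeddings pickle resolves inside PROJECT_DATA_DIR; the query text and file name are placeholders:

import json
import subprocess

# Spawn the MCP server and send a single tools/call request over stdin.
proc = subprocess.Popen(
    ["python", "mcp/semantic_search_server.py"],
    stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True,
)
request = {
    "jsonrpc": "2.0", "id": 1, "method": "tools/call",
    "params": {
        "name": "semantic_search",
        "arguments": {
            "query": "battery capacity",                   # placeholder query
            "embeddings_file": "document_embeddings.pkl",  # must resolve inside PROJECT_DATA_DIR
            "top_k": 5,
        },
    },
}
proc.stdin.write(json.dumps(request) + "\n")
proc.stdin.flush()
print(proc.stdout.readline())  # JSON-RPC response with the formatted matches
proc.terminate()
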
poetry.lock (generated, new file, 3952 lines)
File diff suppressed because it is too large

pyproject.toml (new file, 31 lines)
@@ -0,0 +1,31 @@
[project]
name = "catalog-agent"
version = "0.1.0"
description = ""
authors = [
    {name = "朱潮",email = "zhuchaowe@users.noreply.github.com"}
]
readme = "README.md"
requires-python = "3.12.0"
dependencies = [
    "fastapi==0.116.1",
    "uvicorn==0.35.0",
    "requests==2.32.5",
    "qwen-agent[rag,mcp]==0.0.29",
    "pydantic==2.10.5",
    "python-dateutil==2.8.2",
    "torch==2.2.0",
    "transformers",
    "sentence-transformers",
    "numpy<2",
]


[tool.poetry]
package-mode = false

[build-system]
requires = ["poetry-core>=2.0.0,<3.0.0"]
build-backend = "poetry.core.masonry.api"

@@ -11,3 +11,9 @@ qwen-agent[rag,mcp]==0.0.29
 # data processing
 pydantic==2.10.5
 python-dateutil==2.8.2
+
+
+# embedding
+torch
+transformers
+sentence-transformers

unique_map.json (new file, 3 lines)
@@ -0,0 +1,3 @@
{
    "b743ccc3-13be-43ea-8ec9-4ce9c86103b3": "public/all_hp_product_spec_book2506.zip"
}

@@ -79,12 +79,13 @@ class ZipProjectHandler:
             print(f"Failed to extract ZIP file: {e}")
             return False
 
-    def get_project_from_zip(self, zip_url: str) -> Optional[str]:
+    def get_project_from_zip(self, zip_url: str, unique_id: Optional[str] = None) -> Optional[str]:
         """
         Fetch project data from a ZIP URL or local path
 
         Args:
             zip_url: URL or local relative path of the ZIP file
+            unique_id: Optional unique identifier, used as the folder name
 
         Returns:
             Optional[str]: Project directory path on success, None on failure
@@ -93,16 +94,26 @@ class ZipProjectHandler:
             print(f"Invalid URL or path: {zip_url}")
             return None
 
-        # Check the cache
-        url_hash = self._get_url_hash(zip_url)
-        cached_project_dir = self.projects_dir / url_hash
+        # Use unique_id as the directory name; fall back to the URL hash otherwise
+        if unique_id:
+            project_dir_name = unique_id
+            # When a unique_id is used, skip the cache and re-extract to keep the project structure correct
+            cached_project_dir = self.projects_dir / project_dir_name
+        else:
+            project_dir_name = self._get_url_hash(zip_url)
+            cached_project_dir = self.projects_dir / project_dir_name
 
-        if cached_project_dir.exists():
+        if cached_project_dir.exists() and not unique_id:
             print(f"Using cached project directory: {cached_project_dir}")
             return str(cached_project_dir)
 
         # Download or copy the ZIP file
-        zip_filename = f"{url_hash}.zip"
+        url_hash = self._get_url_hash(zip_url)
+        # When a unique_id is used, prefix the ZIP filename with it to avoid collisions
+        if unique_id:
+            zip_filename = f"{unique_id}_{url_hash}.zip"
+        else:
+            zip_filename = f"{url_hash}.zip"
         zip_path = self.cache_dir / zip_filename
 
         if not zip_path.exists():

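To illustrate the new signature, a small hypothetical call mirroring how chat_completions now invokes the handler; the zip path and unique_id are the values from unique_map.json above:

from zip_project_handler import zip_handler

# With a unique_id the project is extracted into projects/<unique_id> and the cache check is bypassed.
project_dir = zip_handler.get_project_from_zip(
    "public/all_hp_product_spec_book2506.zip",
    unique_id="b743ccc3-13be-43ea-8ec9-4ce9c86103b3",
)
print(project_dir)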