catalog-agent/embedding/embedding.py

import pickle
import re

import numpy as np
from sentence_transformers import SentenceTransformer, util

# Lazy-load the model so importing this module stays cheap
embedder = None


def get_model():
    """Return the shared model instance, loading it on first use."""
    global embedder
    if embedder is None:
        print("Loading model...")
        embedder = SentenceTransformer(
            'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
            device='cpu',
        )
        print("Model loaded")
    return embedder
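
# Illustrative (hypothetical) usage: the first call loads the model once;
# later calls reuse the cached instance, so repeated runs stay fast.
#
#     model = get_model()   # loads the model on first call
#     model = get_model()   # returns the cached embedder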

def clean_text(text):
    """
    Clean a piece of text by removing special and meaningless characters.

    Args:
        text (str): Raw text.

    Returns:
        str: Cleaned text.
    """
    # Strip HTML tags
    text = re.sub(r'<[^>]+>', '', text)
    # Collapse runs of whitespace into single spaces
    text = re.sub(r'\s+', ' ', text)
    # Drop control and non-printable characters, keeping Unicode word characters
    text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)
    # Trim leading/trailing whitespace
    text = text.strip()
    return text
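
# A few illustrative (hypothetical) inputs, traced through the regexes above:
#
#     clean_text('<b>Hello</b>\tworld\n')  -> 'Hello world'
#         (tags stripped, whitespace collapsed, ends trimmed)
#     clean_text('  foo   bar  ')          -> 'foo bar'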

def is_meaningful_line(text):
    """
    Decide whether a line of text carries meaningful content.

    Args:
        text (str): A line of text.

    Returns:
        bool: True if the line is meaningful.
    """
    if not text or len(text.strip()) < 5:
        return False
    # Filter lines that are pure digits
    if text.strip().isdigit():
        return False
    # Filter lines made up only of symbols
    if re.match(r'^[^\w\u4e00-\u9fa5]+$', text):
        return False
    # Filter common meaningless lines
    meaningless_patterns = [
        r'^[-=_]{3,}$',      # divider lines
        r'^第\d+页$',         # Chinese page markers ("Page N")
        r'^\d+\.\s*$',       # bare numbering
        r'^[a-zA-Z]\.\s*$',  # bare single-letter numbering
    ]
    for pattern in meaningless_patterns:
        if re.match(pattern, text.strip()):
            return False
    return True
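
# Illustrative (hypothetical) calls, based on the filters above:
#
#     is_meaningful_line('-----')                  -> False  (symbols only)
#     is_meaningful_line('123456')                 -> False  (pure digits)
#     is_meaningful_line('HP EliteBook 840 spec')  -> True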

def embed_document(input_file='document.txt', output_file='document_embeddings.pkl'):
    """
    Read input_file, embed it line by line, and save the result as a pickle file.

    Args:
        input_file (str): Path to the input document.
        output_file (str): Path to the output pickle file.
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        cleaned_sentences = []
        original_count = len(lines)
        for line in lines:
            # Clean the text
            cleaned_text = clean_text(line)
            # Keep the line only if it is meaningful
            if is_meaningful_line(cleaned_text):
                cleaned_sentences.append(cleaned_text)
        print(f"Original line count: {original_count}")
        print(f"Valid sentences after cleaning: {len(cleaned_sentences)}")
        if original_count:
            # Guard against an empty input file (avoids division by zero)
            print(f"Filtered out: {(original_count - len(cleaned_sentences)) / original_count * 100:.1f}%")
        if not cleaned_sentences:
            print("Warning: no meaningful sentences found!")
            return None
        print(f"Processing {len(cleaned_sentences)} valid sentences...")
        model = get_model()
        sentence_embeddings = model.encode(cleaned_sentences, convert_to_tensor=True)
        embedding_data = {
            'sentences': cleaned_sentences,
            'embeddings': sentence_embeddings
        }
        with open(output_file, 'wb') as f:
            pickle.dump(embedding_data, f)
        print(f"Saved embeddings to {output_file}")
        return embedding_data
    except FileNotFoundError:
        print(f"Error: file {input_file} not found")
        return None
    except Exception as e:
        print(f"Error while processing the document: {e}")
        return None
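
# A minimal sketch of reading the saved pickle back outside this module
# (assumes the file was produced by embed_document above):
#
#     with open('document_embeddings.pkl', 'rb') as f:
#         data = pickle.load(f)
#     print(len(data['sentences']))    # number of embedded sentences
#     print(data['embeddings'].shape)  # (num_sentences, embedding_dim)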

def semantic_search(user_query, embeddings_file='document_embeddings.pkl', top_k=20):
    """
    Match a user query against the stored embeddings and return the top_k
    most relevant sentences.

    Args:
        user_query (str): The user query.
        embeddings_file (str): Path to the embeddings pickle file.
        top_k (int): Number of results to return.

    Returns:
        list: A list of (sentence, similarity score) tuples.
    """
    try:
        with open(embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)
        sentences = embedding_data['sentences']
        sentence_embeddings = embedding_data['embeddings']
        model = get_model()
        query_embedding = model.encode(user_query, convert_to_tensor=True)
        cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0]
        # Sort by descending similarity and keep the top_k indices
        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]
        results = []
        print(f"\nTop {top_k} sentences most relevant to the query:")
        for i, idx in enumerate(top_results):
            sentence = sentences[idx]
            score = cos_scores[idx].item()
            results.append((sentence, score))
            print(f"{i+1}. [{score:.4f}] {sentence}")
        return results
    except FileNotFoundError:
        print(f"Error: embeddings file {embeddings_file} not found")
        print("Run embed_document() first to generate the embeddings file")
        return []
    except Exception as e:
        print(f"Error during search: {e}")
        return []
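
# Example usage (hypothetical query; requires document_embeddings.pkl to exist):
#
#     hits = semantic_search('EliteBook battery capacity', top_k=5)
#     for sentence, score in hits:
#         print(score, sentence)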

def split_document_by_pages(input_file='document.txt', output_file='serialization.txt'):
    """
    Split input_file by pages, flatten each page's content onto a single line,
    and write the result to output_file.

    Args:
        input_file (str): Path to the input document.
        output_file (str): Path to the output serialization file.
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        pages = []
        current_page = []
        for line in lines:
            line = line.strip()
            # Check for a page separator
            if re.match(r'^#\s*Page\s+\d+', line, re.IGNORECASE):
                # If the current page has content, save it
                if current_page:
                    # Merge the page's lines into a single line
                    # (space-joined, matching the final-page handling below)
                    page_content = ' '.join(current_page).strip()
                    if page_content:  # only keep non-empty pages
                        pages.append(page_content)
                    current_page = []
                continue
            # Not a separator: if the line has content, add it to the current page
            if line:
                current_page.append(line)
        # Handle the final page
        if current_page:
            page_content = ' '.join(current_page).strip()
            if page_content:
                pages.append(page_content)
        print(f"Split into {len(pages)} pages in total")
        # Write the serialized output
        with open(output_file, 'w', encoding='utf-8') as f:
            for page_content in pages:
                f.write(f"{page_content}\n")
        print(f"Serialized page contents to {output_file}")
        return pages
    except FileNotFoundError:
        print(f"Error: file {input_file} not found")
        return []
    except Exception as e:
        print(f"Error while splitting the document: {e}")
        return []
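
# The page-separator format this function expects (inferred from the regex
# above) looks like:
#
#     # Page 1
#     ...page text...
#     # Page 2
#     ...page text...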

if __name__ == '__main__':
    split_document_by_pages("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")
    # embed_document("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")  # uncomment to run