import pickle
import re

import numpy as np
from sentence_transformers import SentenceTransformer, util

# Lazily loaded model instance
embedder = None


def get_model():
    """Return the shared model instance (lazy loading)."""
    global embedder
    if embedder is None:
        print("Loading model...")
        embedder = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2', device='cpu')
        print("Model loaded")
    return embedder


def clean_text(text):
    """
    Clean text by removing special and meaningless characters.

    Args:
        text (str): Raw text

    Returns:
        str: Cleaned text
    """
    # Remove HTML tags
    text = re.sub(r'<[^>]+>', '', text)

    # Collapse runs of whitespace into a single space
    text = re.sub(r'\s+', ' ', text)

    # Remove control and non-printable characters while keeping Unicode word characters
    text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)

    # Strip leading and trailing whitespace
    text = text.strip()

    return text


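# Illustration only (hypothetical input, not part of the pipeline): clean_text
# strips tags and collapses whitespace, e.g.
#   clean_text("<p>HP  ProBook\t450</p>")  ->  "HP ProBook 450"

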
def is_meaningful_line(text):
    """
    Determine whether a line of text is meaningful.

    Args:
        text (str): Text line

    Returns:
        bool: Whether the line is meaningful
    """
    if not text or len(text.strip()) < 5:
        return False

    # Filter lines that are only digits
    if text.strip().isdigit():
        return False

    # Filter lines that contain only symbols
    if re.match(r'^[^\w\u4e00-\u9fa5]+$', text):
        return False

    # Filter common meaningless lines
    meaningless_patterns = [
        r'^[-=_]{3,}$',      # divider lines
        r'^第\d+页$',         # page-number lines ("page N" in Chinese)
        r'^\d+\.\s*$',       # bare numeric list markers
        r'^[a-zA-Z]\.\s*$',  # bare single-letter list markers
    ]

    for pattern in meaningless_patterns:
        if re.match(pattern, text.strip()):
            return False

    return True


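# Illustration only (hypothetical inputs): short, numeric-only, and symbol-only
# lines are dropped, ordinary sentences are kept, e.g.
#   is_meaningful_line("-----")                            ->  False
#   is_meaningful_line("12345")                            ->  False
#   is_meaningful_line("The notebook supports Wi-Fi 6E.")  ->  True

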
def embed_document(input_file='document.txt', output_file='document_embeddings.pkl'):
    """
    Read the document file, embed it line by line, and save the result as a pickle file.

    Args:
        input_file (str): Path of the input document
        output_file (str): Path of the output pickle file
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        cleaned_sentences = []
        original_count = len(lines)

        for line in lines:
            # Clean the raw line
            cleaned_text = clean_text(line)

            # Keep only meaningful lines
            if is_meaningful_line(cleaned_text):
                cleaned_sentences.append(cleaned_text)

        print(f"Original line count: {original_count}")
        print(f"Valid sentences after cleaning: {len(cleaned_sentences)}")
        print(f"Filtered out: {((original_count - len(cleaned_sentences)) / original_count * 100):.1f}%")

        if not cleaned_sentences:
            print("Warning: no meaningful sentences found!")
            return None

        print(f"Processing {len(cleaned_sentences)} valid sentences...")

        model = get_model()
        sentence_embeddings = model.encode(cleaned_sentences, convert_to_tensor=True)

        embedding_data = {
            'sentences': cleaned_sentences,
            'embeddings': sentence_embeddings
        }

        with open(output_file, 'wb') as f:
            pickle.dump(embedding_data, f)

        print(f"Saved embeddings to {output_file}")
        return embedding_data

    except FileNotFoundError:
        print(f"Error: file not found: {input_file}")
        return None
    except Exception as e:
        print(f"Error while processing the document: {e}")
        return None


def semantic_search(user_query, embeddings_file='document_embeddings.pkl', top_k=20):
    """
    Match the user query against the stored embeddings and return the top_k most relevant sentences.

    Args:
        user_query (str): User query
        embeddings_file (str): Path of the embeddings pickle file
        top_k (int): Number of results to return

    Returns:
        list: List of (sentence, similarity score) tuples
    """
    try:
        with open(embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)

        sentences = embedding_data['sentences']
        sentence_embeddings = embedding_data['embeddings']

        model = get_model()
        query_embedding = model.encode(user_query, convert_to_tensor=True)

        # Cosine similarity between the query and every stored sentence
        cos_scores = util.cos_sim(query_embedding, sentence_embeddings)[0]

        # Indices of the top_k highest-scoring sentences
        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]

        results = []
        print(f"\nTop {top_k} sentences most relevant to the query:")
        for i, idx in enumerate(top_results):
            sentence = sentences[idx]
            score = cos_scores[idx].item()
            results.append((sentence, score))
            print(f"{i + 1}. [{score:.4f}] {sentence}")

        return results

    except FileNotFoundError:
        print(f"Error: embeddings file not found: {embeddings_file}")
        print("Run embed_document() first to generate the embeddings file")
        return []
    except Exception as e:
        print(f"Error during search: {e}")
        return []


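# Usage sketch (query string and file paths below are placeholders): build the
# embeddings once, then reuse the pickle for any number of queries.
#
#   embed_document('document.txt', 'document_embeddings.pkl')
#   hits = semantic_search('battery life', 'document_embeddings.pkl', top_k=5)
#   for sentence, score in hits:
#       print(score, sentence)

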
def split_document_by_pages(input_file='document.txt', output_file='serialization.txt'):
    """
    Split the document file by page and write each page's content as a single line to the output file.

    Args:
        input_file (str): Path of the input document
        output_file (str): Path of the output serialization file
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        pages = []
        current_page = []

        for line in lines:
            line = line.strip()

            # Check for a page separator such as "# Page 12"
            if re.match(r'^#\s*Page\s+\d+', line, re.IGNORECASE):
                # Save the current page if it has any content
                if current_page:
                    # Join the page's lines into a single line
                    page_content = ' '.join(current_page).strip()
                    if page_content:  # keep only non-empty pages
                        pages.append(page_content)
                    current_page = []
                continue

            # Not a page separator: collect non-empty lines into the current page
            if line:
                current_page.append(line)

        # Handle the last page
        if current_page:
            page_content = ' '.join(current_page).strip()
            if page_content:
                pages.append(page_content)

        print(f"Split the document into {len(pages)} pages")

        # Write one page per line to the serialization file
        with open(output_file, 'w', encoding='utf-8') as f:
            for page_content in pages:
                f.write(f"{page_content}\n")

        print(f"Serialized page content to {output_file}")
        return pages

    except FileNotFoundError:
        print(f"Error: file not found: {input_file}")
        return []
    except Exception as e:
        print(f"Error while splitting the document: {e}")
        return []


split_document_by_pages("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")
|
||
# embed_document("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt") # 取消注释来运行
|