# catalog-agent/embedding/embedding.py
import os
import pickle
import re

import numpy as np
from sentence_transformers import SentenceTransformer, util

# The embedding model is loaded lazily on first use.
embedder = None
def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
    """Return the (lazily loaded) SentenceTransformer instance.

    Args:
        model_name_or_path (str): Model name or local path.
            - Can be a HuggingFace model name.
            - Can be a path to a local model directory.
    """
    global embedder
    if embedder is None:
        print("Loading model...")
        print(f"Model path: {model_name_or_path}")
        # Check whether the argument points at a local directory
        if os.path.exists(model_name_or_path):
            print("Using local model")
        else:
            print("Using HuggingFace model")
        embedder = SentenceTransformer(model_name_or_path, device='cpu')
        print("Model loaded")
    return embedder

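
# Illustrative sketch (not part of the original module): get_model caches the
# SentenceTransformer in the module-level 'embedder', so repeated calls reuse one instance.
def _demo_get_model_is_cached():
    m1 = get_model()   # loads the default multilingual MiniLM model on first call
    m2 = get_model()   # second call skips loading and returns the cached instance
    return m1 is m2    # True
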
def clean_text(text):
    """
    Clean a piece of text by stripping markup and meaningless characters.

    Args:
        text (str): Raw text.
    Returns:
        str: Cleaned text.
    """
    # Strip HTML tags
    text = re.sub(r'<[^>]+>', '', text)
    # Collapse runs of whitespace
    text = re.sub(r'\s+', ' ', text)
    # Remove control and non-printable characters while keeping Unicode word characters
    text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)
    # Trim leading/trailing whitespace
    text = text.strip()
    return text

def is_meaningful_line(text):
    """
    Decide whether a line of text carries meaningful content.

    Args:
        text (str): A single line of text.
    Returns:
        bool: True if the line is considered meaningful.
    """
    if not text or len(text.strip()) < 5:
        return False
    # Filter lines that are purely numeric
    if text.strip().isdigit():
        return False
    # Filter lines that contain only symbols (no word or CJK characters)
    if re.match(r'^[^\w\u4e00-\u9fa5]+$', text):
        return False
    # Filter other common meaningless lines
    meaningless_patterns = [
        r'^[-=_]{3,}$',      # horizontal rules
        r'^第\d+页$',         # Chinese page-number lines ("page N")
        r'^\d+\.\s*$',       # bare numbered bullets
        r'^[a-zA-Z]\.\s*$',  # bare single-letter bullets
    ]
    for pattern in meaningless_patterns:
        if re.match(pattern, text.strip()):
            return False
    return True

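
# Illustrative sketch (not part of the original pipeline): shows how clean_text and
# is_meaningful_line cooperate to filter raw lines. The sample strings are made up.
def _demo_line_filtering():
    raw_lines = [
        "<p>HP ProBook 450 G10, 15.6 inch FHD display</p>",
        "-----",
        "42",
        "第3页",
        "Supports up to 32 GB DDR4 memory.",
    ]
    cleaned = [clean_text(line) for line in raw_lines]
    kept = [line for line in cleaned if is_meaningful_line(line)]
    # expected: only the two descriptive product sentences remain
    return kept
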
def embed_document(input_file='document.txt', output_file='document_embeddings.pkl',
                   chunking_strategy='line', model_path=None, **chunking_params):
    """
    Read a document file, embed it with the chosen chunking strategy, and save the result as a pickle file.

    Args:
        input_file (str): Path to the input document.
        output_file (str): Path to the output pickle file.
        chunking_strategy (str): Chunking strategy, one of 'line', 'paragraph', 'smart'.
        model_path (str): Model path; either a local path or a HuggingFace model name.
        **chunking_params: Strategy-specific parameters.
            - 'line' strategy: no extra parameters.
            - 'paragraph' and 'smart' strategies:
                - max_chunk_size: maximum chunk size (default 1000)
                - overlap: overlap size (default 100)
                - min_chunk_size: minimum chunk size (default 200)
                - separator: paragraph separator, 'paragraph' only (default: a single newline)
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            content = f.read()
        chunks = []
        if chunking_strategy == 'line':
            # Original line-by-line processing
            lines = content.split('\n')
            original_count = len(lines)
            for line in lines:
                # Clean the text
                cleaned_text = clean_text(line)
                # Keep only meaningful lines
                if is_meaningful_line(cleaned_text):
                    chunks.append(cleaned_text)
            print("Using line-based chunking strategy")
            print(f"Original line count: {original_count}")
            print(f"Valid sentences after cleaning: {len(chunks)}")
            print(f"Filtered out: {((original_count - len(chunks)) / original_count * 100):.1f}%")
        elif chunking_strategy == 'paragraph':
            # Paragraph-level chunking strategy
            # Default parameters
            params = {
                'max_chunk_size': 1000,
                'overlap': 100,
                'min_chunk_size': 200,
                'separator': '\n'
            }
            params.update(chunking_params)
            # Clean whitespace across the whole document first
            cleaned_content = clean_text(content)
            # Chunk by paragraph
            chunks = paragraph_chunking(cleaned_content, **params)
            print("Using paragraph-level chunking strategy")
            print(f"Document length: {len(content)} characters")
            print(f"Number of chunks: {len(chunks)}")
            if chunks:
                print(f"Average chunk size: {sum(len(chunk) for chunk in chunks) / len(chunks):.1f} characters")
                print(f"Largest chunk size: {max(len(chunk) for chunk in chunks)} characters")
                print(f"Smallest chunk size: {min(len(chunk) for chunk in chunks)} characters")
        elif chunking_strategy == 'smart':
            # Smart chunking strategy: auto-detects the document format
            params = {
                'max_chunk_size': 1000,
                'overlap': 100,
                'min_chunk_size': 200
            }
            params.update(chunking_params)
            # Chunk adaptively
            chunks = smart_chunking(content, **params)
            print("Using smart chunking strategy")
            print(f"Document length: {len(content)} characters")
            print(f"Number of chunks: {len(chunks)}")
            if chunks:
                print(f"Average chunk size: {sum(len(chunk) for chunk in chunks) / len(chunks):.1f} characters")
                print(f"Largest chunk size: {max(len(chunk) for chunk in chunks)} characters")
                print(f"Smallest chunk size: {min(len(chunk) for chunk in chunks)} characters")
        else:
            raise ValueError(f"Unsupported chunking strategy: {chunking_strategy}")
        if not chunks:
            print("Warning: no valid content chunks found!")
            return None
        print(f"Processing {len(chunks)} content chunks...")
        # Default model path
        if model_path is None:
            model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
        model = get_model(model_path)
        chunk_embeddings = model.encode(chunks, convert_to_tensor=True)
        embedding_data = {
            'chunks': chunks,
            'embeddings': chunk_embeddings,
            'chunking_strategy': chunking_strategy,
            'chunking_params': chunking_params,
            'model_path': model_path
        }
        with open(output_file, 'wb') as f:
            pickle.dump(embedding_data, f)
        print(f"Saved embeddings to {output_file}")
        return embedding_data
    except FileNotFoundError:
        print(f"Error: file {input_file} not found")
        return None
    except Exception as e:
        print(f"Error while processing the document: {e}")
        return None

def semantic_search(user_query, embeddings_file='document_embeddings.pkl', top_k=20):
    """
    Run a semantic match for a user query and return the top_k most relevant content chunks.

    Args:
        user_query (str): The user query.
        embeddings_file (str): Path to the embeddings pickle file.
        top_k (int): Number of results to return.
    Returns:
        list: A list of (chunk, similarity score) tuples.
    """
    try:
        with open(embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)
        # Support both the new and the old data layout
        if 'chunks' in embedding_data:
            # New layout: uses 'chunks'
            chunks = embedding_data['chunks']
            chunk_embeddings = embedding_data['embeddings']
            chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
            content_type = "chunks"
        else:
            # Old layout: uses 'sentences'
            chunks = embedding_data['sentences']
            chunk_embeddings = embedding_data['embeddings']
            chunking_strategy = 'line'
            content_type = "sentences"
        # Use the model path stored alongside the embeddings, if any
        model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
        model = get_model(model_path)
        query_embedding = model.encode(user_query, convert_to_tensor=True)
        cos_scores = util.cos_sim(query_embedding, chunk_embeddings)[0]
        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]
        results = []
        print(f"\nTop {top_k} {content_type} most relevant to the query (chunking strategy: {chunking_strategy}):")
        for i, idx in enumerate(top_results):
            chunk = chunks[idx]
            score = cos_scores[idx].item()
            results.append((chunk, score))
            # Show a shortened preview if the chunk is long
            preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
            preview = preview.replace('\n', ' ')  # replace newlines for display
            print(f"{i+1}. [{score:.4f}] {preview}")
        return results
    except FileNotFoundError:
        print(f"Error: embeddings file {embeddings_file} not found")
        print("Run embed_document() first to generate the embeddings file")
        return []
    except Exception as e:
        print(f"Error during search: {e}")
        return []

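
# Illustrative sketch (not part of the original module): a minimal end-to-end run that embeds
# a document and then queries it. The file names and the query string below are assumptions.
def _demo_embed_and_search():
    data = embed_document('document.txt', 'demo_embeddings.pkl',
                          chunking_strategy='paragraph',
                          max_chunk_size=800, overlap=100)
    if data is not None:
        # Returns a list of (chunk, score) pairs, highest similarity first
        return semantic_search('battery life', 'demo_embeddings.pkl', top_k=5)
    return []
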
def paragraph_chunking(text, max_chunk_size=1000, overlap=100, min_chunk_size=200, separator='\n\n'):
    """
    Paragraph-level chunking. In its current form it chunks by fixed size and does not split on page markers.

    Args:
        text (str): Input text.
        max_chunk_size (int): Maximum chunk size in characters.
        overlap (int): Overlap between consecutive chunks, in characters.
        min_chunk_size (int): Minimum chunk size in characters.
        separator (str): Paragraph separator (currently unused; kept for API compatibility).
    Returns:
        list: The list of text chunks.
    """
    if not text or not text.strip():
        return []
    # Delegate to fixed-length chunking; page markers are ignored here
    return _fixed_length_chunking(text, max_chunk_size, overlap, min_chunk_size)

def _split_long_content(content, max_size, min_size, separator):
    """
    Split content that is too long.

    Args:
        content (str): Content to split.
        max_size (int): Maximum size.
        min_size (int): Minimum size (currently not enforced).
        separator (str): Separator.
    Returns:
        list: The resulting chunks.
    """
    if len(content) <= max_size:
        return [content]
    # First try to split on paragraphs
    paragraphs = content.split(separator)
    if len(paragraphs) > 1:
        chunks = []
        current_chunk = ""
        for para in paragraphs:
            if not current_chunk:
                current_chunk = para
            elif len(current_chunk + separator + para) <= max_size:
                current_chunk += separator + para
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = para
        if current_chunk:
            chunks.append(current_chunk)
        return chunks
    # Fall back to splitting on sentences when there are no paragraph breaks
    sentences = _split_into_sentences(content)
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if not current_chunk:
            current_chunk = sentence
        elif len(current_chunk + " " + sentence) <= max_size:
            current_chunk += " " + sentence
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence
    if current_chunk:
        chunks.append(current_chunk)
    return chunks

def _split_into_sentences(text):
    """
    Split text into sentences.

    Args:
        text (str): Input text.
    Returns:
        list: List of sentences.
    """
    # Simple sentence splitting (can be refined as needed): split after '.', '!' or '?'
    # followed by whitespace and a digit, uppercase letter, or CJK character, so that
    # decimal points inside numbers are preserved.
    sentence_endings = re.compile(r'(?<=[.!?])\s+(?=[\dA-Z\u4e00-\u9fa5])')
    sentences = sentence_endings.split(text.strip())
    return [s.strip() for s in sentences if s.strip()]

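
# Illustrative sketch: what the sentence splitter produces on a made-up spec line. Note that
# it only keys on Latin punctuation followed by whitespace; decimal points are kept intact.
def _demo_sentence_split():
    text = "Weighs 1.38 kg. Battery rated at 51 Wh! Supports Wi-Fi 6E?"
    # expected: ['Weighs 1.38 kg.', 'Battery rated at 51 Wh!', 'Supports Wi-Fi 6E?']
    return _split_into_sentences(text)
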
def _create_overlap_chunk(previous_chunk, new_paragraph, overlap_size):
    """
    Build a new chunk that starts with overlapping content from the previous chunk.

    Args:
        previous_chunk (str): The previous chunk.
        new_paragraph (str): The new paragraph.
        overlap_size (int): Overlap size.
    Returns:
        str: The new chunk with the overlap prepended.
    """
    if overlap_size <= 0:
        return new_paragraph
    # Take the overlap from the tail of the previous chunk
    overlap_text = previous_chunk[-overlap_size:] if len(previous_chunk) > overlap_size else previous_chunk
    # Try to cut the overlap at a sentence boundary
    sentences = _split_into_sentences(overlap_text)
    if len(sentences) > 1:
        # Drop the first sentence, which may be truncated
        overlap_text = " ".join(sentences[1:])
    elif len(overlap_text) > overlap_size * 0.5:
        # A single sentence of reasonable length: keep it
        pass
    else:
        # The overlap is too small to be useful; skip it
        return new_paragraph
    return overlap_text + "\n\n" + new_paragraph

def _add_overlap_to_chunk(previous_chunk, current_chunk, overlap_size):
    """
    Prepend overlap from the previous chunk to the current chunk.

    Args:
        previous_chunk (str): The previous chunk.
        current_chunk (str): The current chunk.
        overlap_size (int): Overlap size.
    Returns:
        str: The current chunk with the overlap prepended.
    """
    if overlap_size <= 0:
        return current_chunk
    # Take the overlap from the tail of the previous chunk
    overlap_text = previous_chunk[-overlap_size:] if len(previous_chunk) > overlap_size else previous_chunk
    # Try to cut at a sentence boundary
    sentences = _split_into_sentences(overlap_text)
    if len(sentences) > 1:
        overlap_text = " ".join(sentences[1:])
    return overlap_text + "\n\n" + current_chunk

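
# Illustrative sketch: how _add_overlap_to_chunk prefixes the current chunk with the tail of
# the previous one, dropping the (possibly truncated) first sentence of that tail. The two
# sample strings are made up.
def _demo_add_overlap_to_chunk():
    prev = "Ships with a 65 W adapter. Charges to 50% in 30 minutes."
    curr = "The chassis is made of recycled aluminium."
    # expected: "Charges to 50% in 30 minutes.\n\nThe chassis is made of recycled aluminium."
    return _add_overlap_to_chunk(prev, curr, overlap_size=40)
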
def smart_chunking(text, max_chunk_size=1000, overlap=100, min_chunk_size=200):
    """
    Smart chunking: detect the document format and pick the best chunking strategy automatically.

    Args:
        text (str): Input text.
        max_chunk_size (int): Maximum chunk size in characters.
        overlap (int): Overlap between consecutive chunks, in characters.
        min_chunk_size (int): Minimum chunk size in characters.
    Returns:
        list: The list of text chunks.
    """
    if not text or not text.strip():
        return []
    # Detect the document layout ('# Page' and '# File' markers are both supported)
    has_page_markers = '# Page' in text or '# File' in text
    has_paragraph_breaks = '\n\n' in text
    has_line_breaks = '\n' in text
    # Pick the separator/strategy that matches the layout
    if has_page_markers:
        # Split on page markers
        return _page_based_chunking(text, max_chunk_size, overlap, min_chunk_size)
    elif has_paragraph_breaks:
        # Split on blank lines between paragraphs
        return paragraph_chunking(text, max_chunk_size, overlap, min_chunk_size, '\n\n')
    elif has_line_breaks:
        # Split on line breaks
        return _line_based_chunking(text, max_chunk_size, overlap, min_chunk_size)
    else:
        # Fall back to fixed-length chunking
        return _fixed_length_chunking(text, max_chunk_size, overlap, min_chunk_size)

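
# Illustrative sketch: smart_chunking picks its strategy from markers in the text. The sample
# below is made up; its '# Page' headers send it down the page-based path, so the call should
# return roughly one chunk per page body.
def _demo_smart_chunking():
    sample = (
        "# Page 1\n"
        "Spec overview for model X. Lightweight chassis with a 14 inch panel.\n"
        "# Page 2\n"
        "Ports: two USB-C, one HDMI, one audio jack.\n"
    )
    return smart_chunking(sample, max_chunk_size=200, overlap=0, min_chunk_size=10)
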
def _page_based_chunking(text, max_chunk_size, overlap, min_chunk_size):
    """Page-based chunking strategy."""
    # Split on page markers ('# Page N' and '# File ...' are both supported).
    # A non-capturing group keeps the markers themselves out of the split result.
    page_pattern = r'#\s*(?:Page\s+\d+|File\s+[^\n]+)'
    pages = re.split(page_pattern, text)
    # Clean and filter the page contents
    cleaned_pages = []
    for page in pages:
        page = page.strip()
        if page and len(page) > min_chunk_size * 0.3:  # drop pages that are too small
            cleaned_pages.append(page)
    if not cleaned_pages:
        return []
    # Split pages that exceed the chunk limit
    chunks = []
    for page in cleaned_pages:
        if len(page) <= max_chunk_size:
            chunks.append(page)
        else:
            # The page is too large and has to be split further
            sub_chunks = _split_long_content(page, max_chunk_size, min_chunk_size, '\n')
            chunks.extend(sub_chunks)
    # Add overlaps
    if overlap > 0 and len(chunks) > 1:
        chunks = _add_overlaps_to_chunks(chunks, overlap)
    return chunks

def _line_based_chunking(text, max_chunk_size, overlap, min_chunk_size):
    """Line-based chunking strategy."""
    lines = text.split('\n')
    chunks = []
    current_chunk = ""
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if not current_chunk:
            current_chunk = line
        elif len(current_chunk + '\n' + line) <= max_chunk_size:
            current_chunk += '\n' + line
        else:
            if len(current_chunk) >= min_chunk_size:
                chunks.append(current_chunk)
                current_chunk = _create_overlap_for_line(current_chunk, line, overlap)
            else:
                # current_chunk is still below min_chunk_size but adding this line would
                # exceed max_chunk_size, so split the combined text instead
                split_chunks = _split_long_content(current_chunk + '\n' + line, max_chunk_size, min_chunk_size, '\n')
                if chunks and split_chunks:
                    split_chunks[0] = _add_overlap_to_chunk(chunks[-1], split_chunks[0], overlap)
                chunks.extend(split_chunks[:-1])
                current_chunk = split_chunks[-1] if split_chunks else ""
    if current_chunk and len(current_chunk) >= min_chunk_size:
        chunks.append(current_chunk)
    elif current_chunk and chunks:
        chunks[-1] += '\n' + current_chunk
    return chunks

def _fixed_length_chunking(text, max_chunk_size, overlap, min_chunk_size):
    """Fixed-length chunking strategy."""
    chunks = []
    start = 0
    while start < len(text):
        end = start + max_chunk_size
        if end >= len(text):
            chunks.append(text[start:])
            break
        # Prefer to split at a sentence-ending punctuation mark near the cut point
        split_pos = end
        for i in range(end, max(start, end - 100), -1):
            if text[i] in '.!?。!?':
                split_pos = i + 1
                break
        chunk = text[start:split_pos]
        if len(chunk) >= min_chunk_size:
            chunks.append(chunk)
            next_start = split_pos - overlap if overlap > 0 else split_pos
            # Always move forward, even when the overlap reaches back to the current start
            start = next_start if next_start > start else split_pos
        else:
            start += max_chunk_size // 2
    return chunks

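
# Illustrative sketch: fixed-length chunking on a made-up string. With these parameters the
# cut lands on the '.' closest to the 50-character mark, and the second chunk starts with the
# last 8 characters of the first ("follows.") because of the overlap.
def _demo_fixed_length_chunking():
    text = "Sentence one is here. Sentence two follows. Sentence three ends the sample."
    # expected: ['Sentence one is here. Sentence two follows.',
    #            'follows. Sentence three ends the sample.']
    return _fixed_length_chunking(text, max_chunk_size=50, overlap=8, min_chunk_size=5)
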
def _create_overlap_for_line(previous_chunk, new_line, overlap_size):
    """Create the overlap for line-based chunking."""
    if overlap_size <= 0:
        return new_line
    # Take the overlap from the tail of the previous chunk
    overlap_text = previous_chunk[-overlap_size:] if len(previous_chunk) > overlap_size else previous_chunk
    # Try to cut at a line boundary
    last_newline = overlap_text.rfind('\n')
    if last_newline > 0:
        overlap_text = overlap_text[last_newline + 1:]
    return overlap_text + '\n' + new_line

def _add_overlaps_to_chunks(chunks, overlap_size):
    """Add overlaps between consecutive chunks."""
    if overlap_size <= 0 or len(chunks) <= 1:
        return chunks
    result = [chunks[0]]
    for i in range(1, len(chunks)):
        previous_chunk = chunks[i-1]
        current_chunk = chunks[i]
        # Take the overlap from the tail of the previous chunk
        overlap_text = previous_chunk[-overlap_size:] if len(previous_chunk) > overlap_size else previous_chunk
        # Try to cut at a reasonable boundary, preferring a line break
        last_newline = overlap_text.rfind('\n')
        if last_newline > 0:
            overlap_text = overlap_text[last_newline + 1:]
        elif '.' in overlap_text:
            # Otherwise try to cut at a period
            last_period = overlap_text.rfind('.')
            if last_period > 0:
                overlap_text = overlap_text[last_period + 1:].strip()
        if overlap_text:
            combined_chunk = overlap_text + '\n\n' + current_chunk
            result.append(combined_chunk)
        else:
            result.append(current_chunk)
    return result

def split_document_by_pages(input_file='document.txt', output_file='pagination.txt'):
    """
    Split document.txt on page or file markers and write each page's content as one line of pagination.txt.

    Args:
        input_file (str): Path to the input document.
        output_file (str): Path to the serialized output file.
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        pages = []
        current_page = []
        for line in lines:
            line = line.strip()
            # Check for a page separator ('# Page' and '# File' are both supported)
            if re.match(r'^#\s*(Page|File)', line, re.IGNORECASE):
                # Save the current page if it has content
                if current_page:
                    # Join the page content into a single line
                    page_content = ' '.join(current_page).strip()
                    if page_content:  # keep only non-empty pages
                        pages.append(page_content)
                    current_page = []
                continue
            # Not a separator: add non-empty lines to the current page
            if line:
                current_page.append(line)
        # Handle the last page
        if current_page:
            page_content = ' '.join(current_page).strip()
            if page_content:
                pages.append(page_content)
        print(f"Split into {len(pages)} pages in total")
        # Write the serialized file, one page per line
        with open(output_file, 'w', encoding='utf-8') as f:
            for page_content in pages:
                f.write(f"{page_content}\n")
        print(f"Serialized page contents to {output_file}")
        return pages
    except FileNotFoundError:
        print(f"Error: file {input_file} not found")
        return []
    except Exception as e:
        print(f"Error while splitting the document: {e}")
        return []

def test_chunking_strategies():
    """
    Test the different chunking strategies and compare the results.
    """
    # Test text (kept in Chinese: the pipeline targets Chinese documents)
    test_text = """
    第一段:这是一个测试段落。包含了多个句子。这是为了测试分块功能。
    第二段:这是另一个段落。它也包含了多个句子,用来验证分块策略的效果。我们需要确保分块的质量。
    第三段:这是第三个段落,内容比较长,包含了更多的信息。这个段落可能会触发分块逻辑,因为它可能会超过最大chunk大小的限制。我们需要确保在这种情况下,分块算法能够正确地处理,并且在句子边界进行分割。
    第四段:这是第四个段落。它相对较短。
    第五段:这是最后一个段落。它用来测试分块策略的完整性和准确性。
    """
    print("=" * 60)
    print("Chunking strategy test")
    print("=" * 60)
    # Test 1: paragraph chunking with small chunks
    print("\n1. Paragraph chunking - small chunks (max_size=200):")
    chunks_small = paragraph_chunking(test_text, max_chunk_size=200, overlap=50)
    for i, chunk in enumerate(chunks_small):
        print(f"Chunk {i+1} (length: {len(chunk)}): {chunk[:50]}...")
    # Test 2: paragraph chunking with large chunks
    print("\n2. Paragraph chunking - large chunks (max_size=500):")
    chunks_large = paragraph_chunking(test_text, max_chunk_size=500, overlap=100)
    for i, chunk in enumerate(chunks_large):
        print(f"Chunk {i+1} (length: {len(chunk)}): {chunk[:50]}...")
    # Test 3: paragraph chunking without overlap
    print("\n3. Paragraph chunking - no overlap:")
    chunks_no_overlap = paragraph_chunking(test_text, max_chunk_size=300, overlap=0)
    for i, chunk in enumerate(chunks_no_overlap):
        print(f"Chunk {i+1} (length: {len(chunk)}): {chunk[:50]}...")
    print("\nTest summary:")
    print(f"- small-chunk strategy: {len(chunks_small)} chunks")
    print(f"- large-chunk strategy: {len(chunks_large)} chunks")
    print(f"- no-overlap strategy: {len(chunks_no_overlap)} chunks")

def demo_usage():
    """
    Show how to use the chunking options.
    """
    print("=" * 60)
    print("Usage examples")
    print("=" * 60)
    print("\n1. Traditional line-based chunking:")
    print("embed_document('document.txt', 'line_embeddings.pkl', chunking_strategy='line')")
    print("\n2. Paragraph-level chunking (default parameters):")
    print("embed_document('document.txt', 'paragraph_embeddings.pkl', chunking_strategy='paragraph')")
    print("\n3. Paragraph-level chunking with custom parameters:")
    print("embed_document('document.txt', 'custom_embeddings.pkl',")
    print("               chunking_strategy='paragraph',")
    print("               max_chunk_size=1500,")
    print("               overlap=200,")
    print("               min_chunk_size=300)")
    print("\n4. Semantic search:")
    print("semantic_search('query text', 'paragraph_embeddings.pkl', top_k=5)")

# Run an example when this file is executed directly
if __name__ == "__main__":
    # test_chunking_strategies()
    # demo_usage()
    # Example: smart chunking with a local model.
    # A local model path can be given to avoid downloading from HuggingFace.
    local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
    embed_document("./projects/test/dataset/all_hp_product_spec_book2506/document.txt",
                   "./projects/test/dataset/all_hp_product_spec_book2506/smart_embeddings.pkl",
                   chunking_strategy='smart',    # use the smart chunking strategy
                   model_path=local_model_path,  # use the local model
                   max_chunk_size=800,           # smaller chunk size
                   overlap=100)
    # Other example calls (commented out):
    # split_document_by_pages("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")
    # embed_document("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")  # uncomment to run