import os
import pickle
import re

import numpy as np
from sentence_transformers import SentenceTransformer, util

# Lazily loaded model instance
embedder = None


def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'):
    """Return the model instance (loaded lazily on first use).

    Args:
        model_name_or_path (str): Model name or local path.
            - Can be a HuggingFace model name.
            - Can be a path to a local model directory.
    """
    global embedder
    if embedder is None:
        print("Loading model...")
        print(f"Model path: {model_name_or_path}")

        # Check whether the argument points to a local model directory
        if os.path.exists(model_name_or_path):
            print("Using local model")
        else:
            print("Using HuggingFace model")
        embedder = SentenceTransformer(model_name_or_path, device='cpu')

        print("Model loaded")
    return embedder


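# Usage note (illustrative only; the path below is a placeholder): get_model()
# caches the SentenceTransformer in the module-level `embedder`, so only the
# first call pays the loading cost; later calls return the cached instance
# regardless of the argument.
#
#     model = get_model("./models/paraphrase-multilingual-MiniLM-L12-v2")  # loads once
#     model = get_model()                                                  # reuses the cached instance
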
def clean_text(text):
    """
    Clean text by removing special and meaningless characters.

    Args:
        text (str): Raw text.

    Returns:
        str: Cleaned text.
    """
    # Strip HTML tags
    text = re.sub(r'<[^>]+>', '', text)

    # Collapse runs of whitespace into a single space
    text = re.sub(r'\s+', ' ', text)

    # Remove control and non-printable characters, keeping Unicode word characters
    text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)

    # Trim leading and trailing whitespace
    text = text.strip()

    return text


def is_meaningful_line(text):
    """
    Decide whether a line of text carries meaningful content.

    Args:
        text (str): Text line.

    Returns:
        bool: True if the line is meaningful.
    """
    if not text or len(text.strip()) < 5:
        return False

    # Filter lines that are purely numeric
    if text.strip().isdigit():
        return False

    # Filter lines that contain only symbols (no word or CJK characters)
    if re.match(r'^[^\w\u4e00-\u9fa5]+$', text):
        return False

    # Filter other common meaningless lines
    meaningless_patterns = [
        r'^[-=_]{3,}$',        # divider lines
        r'^第\d+页$',          # Chinese page numbers ("page N")
        r'^\d+\.\s*$',         # bare numbering
        r'^[a-zA-Z]\.\s*$',    # bare single-letter numbering
    ]

    for pattern in meaningless_patterns:
        if re.match(pattern, text.strip()):
            return False

    return True


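# Illustrative examples of the filtering above (based on the checks in
# is_meaningful_line; not exhaustive):
#
#     is_meaningful_line("-----")                                  # False: symbols only
#     is_meaningful_line("12345")                                  # False: purely numeric
#     is_meaningful_line("HP EliteBook 840 G10 specifications")    # True
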
def embed_document(input_file='document.txt', output_file='document_embeddings.pkl',
                   chunking_strategy='line', model_path=None, **chunking_params):
    """
    Read a document file, embed it with the chosen chunking strategy, and save
    the result as a pickle file.

    Args:
        input_file (str): Path to the input document file.
        output_file (str): Path to the output pickle file.
        chunking_strategy (str): Chunking strategy, one of 'line', 'paragraph', 'smart'.
        model_path (str): Model path, either a local path or a HuggingFace model name.
        **chunking_params: Chunking parameters.
            - For the 'line' strategy: no extra parameters.
            - For the 'paragraph' and 'smart' strategies:
                - max_chunk_size: maximum chunk size (default 1000)
                - overlap: overlap size (default 100)
                - min_chunk_size: minimum chunk size (default 200)
                - separator: paragraph separator ('paragraph' only; defaults to a newline)
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            content = f.read()

        chunks = []

        if chunking_strategy == 'line':
            # Original line-by-line processing
            lines = content.split('\n')
            original_count = len(lines)

            for line in lines:
                # Clean the text
                cleaned_text = clean_text(line)

                # Keep only meaningful lines
                if is_meaningful_line(cleaned_text):
                    chunks.append(cleaned_text)

            print("Using line-based chunking strategy")
            print(f"Original line count: {original_count}")
            print(f"Valid sentences after cleaning: {len(chunks)}")
            print(f"Filtered out: {((original_count - len(chunks)) / original_count * 100):.1f}%")

        elif chunking_strategy == 'paragraph':
            # Paragraph-level chunking strategy
            # Default parameters
            params = {
                'max_chunk_size': 1000,
                'overlap': 100,
                'min_chunk_size': 200,
                'separator': '\n'
            }
            params.update(chunking_params)

            # Clean whitespace across the whole document first
            cleaned_content = clean_text(content)

            # Chunk by paragraphs
            chunks = paragraph_chunking(cleaned_content, **params)

            print("Using paragraph-level chunking strategy")
            print(f"Document length: {len(content)} characters")
            print(f"Number of chunks: {len(chunks)}")
            if chunks:
                print(f"Average chunk size: {sum(len(chunk) for chunk in chunks) / len(chunks):.1f} characters")
                print(f"Largest chunk size: {max(len(chunk) for chunk in chunks)} characters")
                print(f"Smallest chunk size: {min(len(chunk) for chunk in chunks)} characters")

        elif chunking_strategy == 'smart':
            # Smart chunking strategy: auto-detects the document format
            params = {
                'max_chunk_size': 1000,
                'overlap': 100,
                'min_chunk_size': 200
            }
            params.update(chunking_params)

            # Chunk with the smart strategy
            chunks = smart_chunking(content, **params)

            print("Using smart chunking strategy")
            print(f"Document length: {len(content)} characters")
            print(f"Number of chunks: {len(chunks)}")
            if chunks:
                print(f"Average chunk size: {sum(len(chunk) for chunk in chunks) / len(chunks):.1f} characters")
                print(f"Largest chunk size: {max(len(chunk) for chunk in chunks)} characters")
                print(f"Smallest chunk size: {min(len(chunk) for chunk in chunks)} characters")

        else:
            raise ValueError(f"Unsupported chunking strategy: {chunking_strategy}")

        if not chunks:
            print("Warning: no valid content chunks found!")
            return None

        print(f"Processing {len(chunks)} content chunks...")

        # Default model path
        if model_path is None:
            model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'

        model = get_model(model_path)
        chunk_embeddings = model.encode(chunks, convert_to_tensor=True)

        embedding_data = {
            'chunks': chunks,
            'embeddings': chunk_embeddings,
            'chunking_strategy': chunking_strategy,
            'chunking_params': chunking_params,
            'model_path': model_path
        }

        with open(output_file, 'wb') as f:
            pickle.dump(embedding_data, f)

        print(f"Saved embeddings to {output_file}")
        return embedding_data

    except FileNotFoundError:
        print(f"Error: file {input_file} not found")
        return None
    except Exception as e:
        print(f"Error while processing the document: {e}")
        return None


def semantic_search(user_query, embeddings_file='document_embeddings.pkl', top_k=20):
    """
    Run semantic matching for a user query and return the top_k most relevant
    content chunks.

    Args:
        user_query (str): User query.
        embeddings_file (str): Path to the embeddings file.
        top_k (int): Number of results to return.

    Returns:
        list: List of (chunk, similarity score) tuples.
    """
    try:
        with open(embeddings_file, 'rb') as f:
            embedding_data = pickle.load(f)

        # Support both the new and the old data layout
        if 'chunks' in embedding_data:
            # New layout (keyed by 'chunks')
            chunks = embedding_data['chunks']
            chunk_embeddings = embedding_data['embeddings']
            chunking_strategy = embedding_data.get('chunking_strategy', 'unknown')
            content_type = "chunks"
        else:
            # Old layout (keyed by 'sentences')
            chunks = embedding_data['sentences']
            chunk_embeddings = embedding_data['embeddings']
            chunking_strategy = 'line'
            content_type = "sentences"

        # Use the model path stored alongside the embeddings, if available
        model_path = embedding_data.get('model_path', 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
        model = get_model(model_path)
        query_embedding = model.encode(user_query, convert_to_tensor=True)

        cos_scores = util.cos_sim(query_embedding, chunk_embeddings)[0]

        top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]

        results = []
        print(f"\nTop {top_k} {content_type} most relevant to the query (chunking strategy: {chunking_strategy}):")
        for i, idx in enumerate(top_results):
            chunk = chunks[idx]
            score = cos_scores[idx].item()
            results.append((chunk, score))
            # Show a shortened preview if the content is long
            preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
            preview = preview.replace('\n', ' ')  # replace newlines for display
            print(f"{i+1}. [{score:.4f}] {preview}")

        return results

    except FileNotFoundError:
        print(f"Error: embeddings file {embeddings_file} not found")
        print("Run embed_document() first to generate the embeddings file")
        return []
    except Exception as e:
        print(f"Error during search: {e}")
        return []


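# End-to-end sketch (illustrative only; the file names and the query are
# placeholders): build the index once with embed_document(), then query it
# any number of times with semantic_search().
#
#     embed_document("document.txt", "document_embeddings.pkl",
#                    chunking_strategy='paragraph', max_chunk_size=800, overlap=100)
#     hits = semantic_search("battery capacity", "document_embeddings.pkl", top_k=5)
#     for chunk, score in hits:
#         print(score, chunk[:80])
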
def paragraph_chunking(text, max_chunk_size=1000, overlap=100, min_chunk_size=200, separator='\n\n'):
    """
    Paragraph-level chunking entry point. The current implementation chunks by a
    fixed size rather than by page boundaries.

    Args:
        text (str): Input text.
        max_chunk_size (int): Maximum chunk size (characters).
        overlap (int): Overlap size (characters).
        min_chunk_size (int): Minimum chunk size (characters).
        separator (str): Paragraph separator (kept for API compatibility; not
            used by the fixed-length strategy).

    Returns:
        list: List of text chunks.
    """
    if not text or not text.strip():
        return []

    # Delegate to fixed-length chunking; page markers are ignored here
    return _fixed_length_chunking(text, max_chunk_size, overlap, min_chunk_size)


def _split_long_content(content, max_size, min_size, separator):
    """
    Split a piece of content that exceeds the maximum size.

    Args:
        content (str): Content to split.
        max_size (int): Maximum size.
        min_size (int): Minimum size.
        separator (str): Separator.

    Returns:
        list: List of resulting chunks.
    """
    if len(content) <= max_size:
        return [content]

    # First try to split by paragraphs
    paragraphs = content.split(separator)
    if len(paragraphs) > 1:
        chunks = []
        current_chunk = ""

        for para in paragraphs:
            if not current_chunk:
                current_chunk = para
            elif len(current_chunk + separator + para) <= max_size:
                current_chunk += separator + para
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = para

        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    # If splitting by paragraphs is not possible, fall back to sentences
    sentences = _split_into_sentences(content)
    chunks = []
    current_chunk = ""

    for sentence in sentences:
        if not current_chunk:
            current_chunk = sentence
        elif len(current_chunk + " " + sentence) <= max_size:
            current_chunk += " " + sentence
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence

    if current_chunk:
        chunks.append(current_chunk)

    return chunks


def _split_into_sentences(text):
    """
    Split text into sentences.

    Args:
        text (str): Input text.

    Returns:
        list: List of sentences.
    """
    # Simple sentence splitting (can be refined as needed): split on '.', '!'
    # or '?' followed by whitespace and then a digit, an uppercase letter or a
    # CJK character, so dots inside numbers survive.
    sentence_endings = re.compile(r'(?<=[.!?])\s+(?=[\dA-Z\u4e00-\u9fa5])')
    sentences = sentence_endings.split(text.strip())

    return [s.strip() for s in sentences if s.strip()]


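# Illustrative behaviour of the splitter above: decimal points are left alone
# because no whitespace follows them, while a period followed by a space and a
# capital letter starts a new sentence.
#
#     _split_into_sentences("Weight is 1.38 kg. Battery lasts 10 hours.")
#     # -> ["Weight is 1.38 kg.", "Battery lasts 10 hours."]
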
def _create_overlap_chunk(previous_chunk, new_paragraph, overlap_size):
    """
    Build a new chunk that carries overlapping context from the previous chunk.

    Args:
        previous_chunk (str): Previous chunk.
        new_paragraph (str): New paragraph.
        overlap_size (int): Overlap size.

    Returns:
        str: New chunk with overlap prepended.
    """
    if overlap_size <= 0:
        return new_paragraph

    # Take the overlap from the tail of the previous chunk
    overlap_text = previous_chunk[-overlap_size:] if len(previous_chunk) > overlap_size else previous_chunk

    # Try to cut the overlap at a sentence boundary
    sentences = _split_into_sentences(overlap_text)
    if len(sentences) > 1:
        # Drop the first sentence, which may be truncated
        overlap_text = " ".join(sentences[1:])
    elif len(overlap_text) > overlap_size * 0.5:
        # A single sentence of reasonable length: keep it
        pass
    else:
        # Too little overlap to be useful, skip it
        return new_paragraph

    return overlap_text + "\n\n" + new_paragraph


def _add_overlap_to_chunk(previous_chunk, current_chunk, overlap_size):
    """
    Prepend overlap from the previous chunk to the current chunk.

    Args:
        previous_chunk (str): Previous chunk.
        current_chunk (str): Current chunk.
        overlap_size (int): Overlap size.

    Returns:
        str: Chunk with overlap prepended.
    """
    if overlap_size <= 0:
        return current_chunk

    # Take the overlap from the tail of the previous chunk
    overlap_text = previous_chunk[-overlap_size:] if len(previous_chunk) > overlap_size else previous_chunk

    # Try to cut at a sentence boundary
    sentences = _split_into_sentences(overlap_text)
    if len(sentences) > 1:
        overlap_text = " ".join(sentences[1:])

    return overlap_text + "\n\n" + current_chunk


def smart_chunking(text, max_chunk_size=1000, overlap=100, min_chunk_size=200):
    """
    Smart chunking: detect the document format and pick the best strategy.

    Args:
        text (str): Input text.
        max_chunk_size (int): Maximum chunk size (characters).
        overlap (int): Overlap size (characters).
        min_chunk_size (int): Minimum chunk size (characters).

    Returns:
        list: List of text chunks.
    """
    if not text or not text.strip():
        return []

    # Detect the document layout (supports '# Page' and '# File' markers)
    has_page_markers = '# Page' in text or '# File' in text
    has_paragraph_breaks = '\n\n' in text
    has_line_breaks = '\n' in text

    # Pick the separator and strategy accordingly
    if has_page_markers:
        # Split on page markers
        return _page_based_chunking(text, max_chunk_size, overlap, min_chunk_size)
    elif has_paragraph_breaks:
        # Split on blank lines between paragraphs
        return paragraph_chunking(text, max_chunk_size, overlap, min_chunk_size, '\n\n')
    elif has_line_breaks:
        # Split on single newlines
        return _line_based_chunking(text, max_chunk_size, overlap, min_chunk_size)
    else:
        # Fall back to fixed-length chunking
        return _fixed_length_chunking(text, max_chunk_size, overlap, min_chunk_size)


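# Dispatch summary for smart_chunking (illustrative inputs):
#
#     "# Page 1\n..." or "# File a.txt\n..."  -> _page_based_chunking
#     "para one\n\npara two"                  -> paragraph_chunking (separator '\n\n')
#     "line one\nline two"                    -> _line_based_chunking
#     one long unbroken string                -> _fixed_length_chunking
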
def _page_based_chunking(text, max_chunk_size, overlap, min_chunk_size):
    """Chunking strategy based on page markers."""
    # Split on page markers (supports '# Page N' and '# File <name>' headings).
    # Because the pattern contains a capture group, re.split also returns the
    # matched headings; the length filter below drops them.
    page_pattern = r'#\s*(Page\s+\d+|File\s+[^\n]+)'
    pages = re.split(page_pattern, text)

    # Clean and filter the page contents
    cleaned_pages = []
    for page in pages:
        page = page.strip()
        if page and len(page) > min_chunk_size * 0.3:  # drop pages that are too small
            cleaned_pages.append(page)

    if not cleaned_pages:
        return []

    # Split pages that exceed the maximum chunk size
    chunks = []
    for page in cleaned_pages:
        if len(page) <= max_chunk_size:
            chunks.append(page)
        else:
            # Page too large: split it further
            sub_chunks = _split_long_content(page, max_chunk_size, min_chunk_size, '\n')
            chunks.extend(sub_chunks)

    # Add overlap between consecutive chunks
    if overlap > 0 and len(chunks) > 1:
        chunks = _add_overlaps_to_chunks(chunks, overlap)

    return chunks


def _line_based_chunking(text, max_chunk_size, overlap, min_chunk_size):
    """Chunking strategy based on individual lines."""
    lines = text.split('\n')
    chunks = []
    current_chunk = ""

    for line in lines:
        line = line.strip()
        if not line:
            continue

        if not current_chunk:
            current_chunk = line
        elif len(current_chunk + '\n' + line) <= max_chunk_size:
            current_chunk += '\n' + line
        else:
            if len(current_chunk) >= min_chunk_size:
                chunks.append(current_chunk)
                current_chunk = _create_overlap_for_line(current_chunk, line, overlap)
            else:
                # Current chunk plus this line is too long: split the combined text
                split_chunks = _split_long_content(current_chunk + '\n' + line, max_chunk_size, min_chunk_size, '\n')
                if chunks and split_chunks:
                    split_chunks[0] = _add_overlap_to_chunk(chunks[-1], split_chunks[0], overlap)
                chunks.extend(split_chunks[:-1])
                current_chunk = split_chunks[-1] if split_chunks else ""

    if current_chunk and len(current_chunk) >= min_chunk_size:
        chunks.append(current_chunk)
    elif current_chunk and chunks:
        chunks[-1] += '\n' + current_chunk

    return chunks


def _fixed_length_chunking(text, max_chunk_size, overlap, min_chunk_size):
    """Fixed-length chunking strategy."""
    chunks = []
    start = 0

    while start < len(text):
        end = start + max_chunk_size

        if end >= len(text):
            chunks.append(text[start:])
            break

        # Try to cut at a sentence-ending punctuation mark near the window end
        split_pos = end
        for i in range(end, max(start, end - 100), -1):
            if text[i] in '.!?。!?':
                split_pos = i + 1
                break

        chunk = text[start:split_pos]
        if len(chunk) >= min_chunk_size:
            chunks.append(chunk)
            start = split_pos - overlap if overlap > 0 else split_pos
        else:
            start += max_chunk_size // 2

    return chunks


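# Worked example of the window arithmetic above (illustrative): with
# max_chunk_size=1000 and overlap=100, and assuming a sentence boundary is
# found right at each window end, consecutive chunks cover roughly
# [0, 1000), [900, 1900), [1800, 2800), ... so neighbouring chunks share about
# 100 characters of context. If none of '.', '!', '?', '。', '!', '?' occurs
# in the last 100 characters of a window, the cut falls at the raw window end.
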
def _create_overlap_for_line(previous_chunk, new_line, overlap_size):
    """Create overlap for line-based chunking."""
    if overlap_size <= 0:
        return new_line

    # Take the overlap from the tail of the previous chunk
    overlap_text = previous_chunk[-overlap_size:] if len(previous_chunk) > overlap_size else previous_chunk

    # Try to cut at a line boundary
    last_newline = overlap_text.rfind('\n')
    if last_newline > 0:
        overlap_text = overlap_text[last_newline + 1:]

    return overlap_text + '\n' + new_line


def _add_overlaps_to_chunks(chunks, overlap_size):
    """Prepend overlap from each chunk to its successor."""
    if overlap_size <= 0 or len(chunks) <= 1:
        return chunks

    result = [chunks[0]]

    for i in range(1, len(chunks)):
        previous_chunk = chunks[i - 1]
        current_chunk = chunks[i]

        # Take the overlap from the tail of the previous chunk
        overlap_text = previous_chunk[-overlap_size:] if len(previous_chunk) > overlap_size else previous_chunk

        # Try to cut at a reasonable boundary
        last_newline = overlap_text.rfind('\n')
        if last_newline > 0:
            overlap_text = overlap_text[last_newline + 1:]
        elif '.' in overlap_text:
            # Otherwise try to cut at the last period
            last_period = overlap_text.rfind('.')
            if last_period > 0:
                overlap_text = overlap_text[last_period + 1:].strip()

        if overlap_text:
            combined_chunk = overlap_text + '\n\n' + current_chunk
            result.append(combined_chunk)
        else:
            result.append(current_chunk)

    return result


def split_document_by_pages(input_file='document.txt', output_file='pagination.txt'):
    """
    Split document.txt by page or file markers and write each page's content as
    a single line to pagination.txt.

    Args:
        input_file (str): Path to the input document file.
        output_file (str): Path to the serialized output file.
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        pages = []
        current_page = []

        for line in lines:
            line = line.strip()

            # Check for a page separator (supports '# Page' and '# File' markers)
            if re.match(r'^#\s*(Page|File)', line, re.IGNORECASE):
                # Save the current page if it has content
                if current_page:
                    # Join the page content into a single line
                    page_content = ' '.join(current_page).strip()
                    if page_content:  # keep non-empty pages only
                        pages.append(page_content)
                    current_page = []
                continue

            # Not a separator: append non-empty lines to the current page
            if line:
                current_page.append(line)

        # Handle the last page
        if current_page:
            page_content = ' '.join(current_page).strip()
            if page_content:
                pages.append(page_content)

        print(f"Split the document into {len(pages)} pages")

        # Write the serialized pages
        with open(output_file, 'w', encoding='utf-8') as f:
            for i, page_content in enumerate(pages, 1):
                f.write(f"{page_content}\n")

        print(f"Serialized page contents to {output_file}")
        return pages

    except FileNotFoundError:
        print(f"Error: file {input_file} not found")
        return []
    except Exception as e:
        print(f"Error while splitting the document: {e}")
        return []


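# Illustrative input/output for split_document_by_pages(): given a document
# with markers such as
#
#     # Page 1
#     First line of page one.
#     Second line of page one.
#     # Page 2
#     Only line of page two.
#
# pagination.txt will contain one line per page:
#
#     First line of page one. Second line of page one.
#     Only line of page two.
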
def test_chunking_strategies():
    """
    Compare the different chunking strategies on a small sample text.
    """
    # Chinese sample text (kept in Chinese so the CJK-aware chunking is exercised)
    test_text = """
第一段:这是一个测试段落。包含了多个句子。这是为了测试分块功能。

第二段:这是另一个段落。它也包含了多个句子,用来验证分块策略的效果。我们需要确保分块的质量。

第三段:这是第三个段落,内容比较长,包含了更多的信息。这个段落可能会触发分块逻辑,因为它可能会超过最大chunk大小的限制。我们需要确保在这种情况下,分块算法能够正确地处理,并且在句子边界进行分割。

第四段:这是第四个段落。它相对较短。

第五段:这是最后一个段落。它用来测试分块策略的完整性和准确性。
"""

    print("=" * 60)
    print("Chunking strategy test")
    print("=" * 60)

    # Test 1: paragraph-level chunking with small chunks
    print("\n1. Paragraph-level chunking - small chunks (max_size=200):")
    chunks_small = paragraph_chunking(test_text, max_chunk_size=200, overlap=50)
    for i, chunk in enumerate(chunks_small):
        print(f"Chunk {i+1} (length: {len(chunk)}): {chunk[:50]}...")

    # Test 2: paragraph-level chunking with large chunks
    print("\n2. Paragraph-level chunking - large chunks (max_size=500):")
    chunks_large = paragraph_chunking(test_text, max_chunk_size=500, overlap=100)
    for i, chunk in enumerate(chunks_large):
        print(f"Chunk {i+1} (length: {len(chunk)}): {chunk[:50]}...")

    # Test 3: paragraph-level chunking without overlap
    print("\n3. Paragraph-level chunking - no overlap:")
    chunks_no_overlap = paragraph_chunking(test_text, max_chunk_size=300, overlap=0)
    for i, chunk in enumerate(chunks_no_overlap):
        print(f"Chunk {i+1} (length: {len(chunk)}): {chunk[:50]}...")

    print("\nTest summary:")
    print(f"- small-chunk strategy: {len(chunks_small)} chunks")
    print(f"- large-chunk strategy: {len(chunks_large)} chunks")
    print(f"- no-overlap strategy: {len(chunks_no_overlap)} chunks")


def demo_usage():
    """
    Show how to use the chunking features.
    """
    print("=" * 60)
    print("Usage examples")
    print("=" * 60)

    print("\n1. Traditional line-based chunking:")
    print("embed_document('document.txt', 'line_embeddings.pkl', chunking_strategy='line')")

    print("\n2. Paragraph-level chunking (default parameters):")
    print("embed_document('document.txt', 'paragraph_embeddings.pkl', chunking_strategy='paragraph')")

    print("\n3. Paragraph-level chunking with custom parameters:")
    print("embed_document('document.txt', 'custom_embeddings.pkl',")
    print("               chunking_strategy='paragraph',")
    print("               max_chunk_size=1500,")
    print("               overlap=200,")
    print("               min_chunk_size=300)")

    print("\n4. Semantic search:")
    print("semantic_search('query text', 'paragraph_embeddings.pkl', top_k=5)")


# Run the example when this file is executed directly
if __name__ == "__main__":
    # test_chunking_strategies()
    # demo_usage()

    # Example using the smart chunking strategy:
    # a local model path can be given to avoid downloading from HuggingFace.
    local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"

    embed_document("./projects/test/dataset/all_hp_product_spec_book2506/document.txt",
                   "./projects/test/dataset/all_hp_product_spec_book2506/smart_embeddings.pkl",
                   chunking_strategy='smart',      # use the smart chunking strategy
                   model_path=local_model_path,    # use the local model
                   max_chunk_size=800,             # smaller chunk size
                   overlap=100)

    # Other example calls (commented out):
    # split_document_by_pages("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")
    # embed_document("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")  # uncomment to run