import pickle
import re
import numpy as np
import os
from typing import Optional, List, Dict, Any
import requests
import asyncio
import hashlib
import json
import logging

# Configure logger
logger = logging.getLogger('app')


def encode_texts_via_api(texts, batch_size=32):
    """Encode texts via the embedding API service."""
    if not texts:
        return np.array([])

    try:
        # FastAPI service address
        fastapi_url = os.getenv('FASTAPI_URL', 'http://127.0.0.1:8001')
        api_endpoint = f"{fastapi_url}/api/v1/embedding/encode"

        # Call the encoding endpoint
        request_data = {
            "texts": texts,
            "batch_size": batch_size
        }

        response = requests.post(
            api_endpoint,
            json=request_data,
            timeout=60  # generous timeout for large batches
        )

        if response.status_code == 200:
            result_data = response.json()

            if result_data.get("success"):
                embeddings_list = result_data.get("embeddings", [])
                logger.info(f"API encoding succeeded: {len(texts)} texts, embedding dimension: {len(embeddings_list[0]) if embeddings_list else 0}")
                return np.array(embeddings_list)
            else:
                error_msg = result_data.get('error', 'unknown error')
                logger.error(f"API encoding failed: {error_msg}")
                raise Exception(f"API encoding failed: {error_msg}")
        else:
            logger.error(f"API request failed: {response.status_code} - {response.text}")
            raise Exception(f"API request failed: {response.status_code}")

    except Exception as e:
        logger.error(f"API encoding error: {e}")
        raise
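
# The parsing above assumes the encode endpoint returns JSON shaped roughly like
# (a sketch inferred from this client code; the actual service contract may differ):
#
#     {"success": true, "embeddings": [[0.01, -0.02, ...], ...]}
#
# or {"success": false, "error": "..."} on failure, in which case the error is
# logged and re-raised to the caller.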


def clean_text(text):
    """
    Clean text by removing special and meaningless characters.

    Args:
        text (str): Raw text

    Returns:
        str: Cleaned text
    """
    # Remove HTML tags
    text = re.sub(r'<[^>]+>', '', text)

    # Collapse redundant whitespace
    text = re.sub(r'\s+', ' ', text)

    # Remove control and non-printable characters, but keep Unicode word characters
    text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F]', '', text)

    # Strip leading and trailing whitespace
    text = text.strip()

    return text
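
# Example: clean_text("<p>  Hello   world </p>") returns "Hello world"
# (tags removed, whitespace collapsed, ends stripped).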


def is_meaningful_line(text):
    """
    Decide whether a line of text is meaningful.

    Args:
        text (str): Text line

    Returns:
        bool: Whether the line is meaningful
    """
    if not text or len(text.strip()) < 5:
        return False

    # Filter lines that are digits only
    if text.strip().isdigit():
        return False

    # Filter lines that contain only symbols
    if re.match(r'^[^\w\u4e00-\u9fa5]+$', text):
        return False

    # Filter common meaningless lines
    meaningless_patterns = [
        r'^[-=_]{3,}$',  # separator line
        r'^第\d+页$',  # page number line (Chinese "page N")
        r'^\d+\.\s*$',  # numbering only
        r'^[a-zA-Z]\.\s*$',  # single-letter numbering only
    ]

    for pattern in meaningless_patterns:
        if re.match(pattern, text.strip()):
            return False

    return True
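
# Examples: is_meaningful_line("-----") -> False (symbols only),
# is_meaningful_line("12345") -> False (digits only),
# is_meaningful_line("The printer supports duplex printing.") -> True.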


def embed_document(input_file='document.txt', output_file='embedding.pkl',
                   chunking_strategy='line', **chunking_params):
    """
    Read a document file, embed it with the chosen chunking strategy, and save the result as a pickle file.

    Args:
        input_file (str): Path to the input document file
        output_file (str): Path to the output pickle file
        chunking_strategy (str): Chunking strategy, one of 'line', 'paragraph', 'smart'
        **chunking_params: Chunking parameters
            - for the 'line' strategy: no extra parameters
            - for the 'paragraph' and 'smart' strategies:
                - max_chunk_size: maximum chunk size (default 1000)
                - overlap: overlap size (default 100)
                - min_chunk_size: minimum chunk size (default 200)
                - separator: paragraph separator, 'paragraph' only (default '\n')
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            content = f.read()

        chunks = []

        if chunking_strategy == 'line':
            # Original line-by-line processing logic
            lines = content.split('\n')
            original_count = len(lines)

            for line in lines:
                # Clean the text
                cleaned_text = clean_text(line)

                # Keep only meaningful lines
                if is_meaningful_line(cleaned_text):
                    chunks.append(cleaned_text)

            logger.info("Using line-based chunking strategy")
            logger.info(f"Original line count: {original_count}")
            logger.info(f"Valid sentences after cleaning: {len(chunks)}")
            logger.info(f"Filtered ratio: {((original_count - len(chunks)) / original_count * 100):.1f}%")

        elif chunking_strategy == 'paragraph':
            # Paragraph-level chunking strategy
            # Set default parameters
            params = {
                'max_chunk_size': 1000,
                'overlap': 100,
                'min_chunk_size': 200,
                'separator': '\n'
            }
            params.update(chunking_params)

            # Clean whitespace across the whole document first
            cleaned_content = clean_text(content)

            # Chunk by paragraphs
            chunks = paragraph_chunking(cleaned_content, **params)

            logger.info("Using paragraph-level chunking strategy")
            logger.info(f"Total document length: {len(content)} characters")
            logger.info(f"Number of chunks: {len(chunks)}")
            if chunks:
                logger.debug(f"Average chunk size: {sum(len(chunk) for chunk in chunks) / len(chunks):.1f} characters")
                logger.debug(f"Largest chunk size: {max(len(chunk) for chunk in chunks)} characters")
                logger.debug(f"Smallest chunk size: {min(len(chunk) for chunk in chunks)} characters")

        elif chunking_strategy == 'smart':
            # Smart chunking strategy: auto-detect the document format
            params = {
                'max_chunk_size': 1000,
                'overlap': 100,
                'min_chunk_size': 200
            }
            params.update(chunking_params)

            # Chunk using the smart strategy
            chunks = smart_chunking(content, **params)

            logger.info("Using smart chunking strategy")
            logger.info(f"Total document length: {len(content)} characters")
            logger.info(f"Number of chunks: {len(chunks)}")
            if chunks:
                logger.debug(f"Average chunk size: {sum(len(chunk) for chunk in chunks) / len(chunks):.1f} characters")
                logger.debug(f"Largest chunk size: {max(len(chunk) for chunk in chunks)} characters")
                logger.debug(f"Smallest chunk size: {min(len(chunk) for chunk in chunks)} characters")

        else:
            raise ValueError(f"Unsupported chunking strategy: {chunking_strategy}")

        if not chunks:
            logger.warning("Warning: no valid content chunks found!")
            return None

        logger.info(f"Processing {len(chunks)} content chunks...")

        # Encode via the API service
        logger.info("Encoding via the API service...")
        chunk_embeddings = encode_texts_via_api(chunks, batch_size=32)

        embedding_data = {
            'chunks': chunks,
            'embeddings': chunk_embeddings,
            'chunking_strategy': chunking_strategy,
            'chunking_params': chunking_params,
            'model_path': 'api_service'
        }

        with open(output_file, 'wb') as f:
            pickle.dump(embedding_data, f)

        logger.info(f"Saved embeddings to {output_file}")
        return embedding_data

    except FileNotFoundError:
        logger.error(f"Error: file not found: {input_file}")
        return None
    except Exception as e:
        logger.error(f"Error while processing the document: {e}")
        return None
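
# The saved pickle can be loaded back with, for example:
#
#     with open('embedding.pkl', 'rb') as f:
#         data = pickle.load(f)
#     chunks, vectors = data['chunks'], data['embeddings']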


def paragraph_chunking(text, max_chunk_size=1000, overlap=100, min_chunk_size=200, separator='\n\n'):
    """
    Paragraph-level chunking. Currently delegates to fixed-size chunking and does not split by page.

    Args:
        text (str): Input text
        max_chunk_size (int): Maximum chunk size (characters)
        overlap (int): Overlap size (characters)
        min_chunk_size (int): Minimum chunk size (characters)
        separator (str): Paragraph separator (accepted for compatibility, not used by the fixed-size path)

    Returns:
        list: List of text chunks
    """
    if not text or not text.strip():
        return []

    # Use the fixed-length chunking strategy directly, ignoring page markers
    return _fixed_length_chunking(text, max_chunk_size, overlap, min_chunk_size)


def _split_long_content(content, max_size, min_size, separator):
    """
    Split content that is too long.

    Args:
        content (str): Content to split
        max_size (int): Maximum size
        min_size (int): Minimum size
        separator (str): Separator

    Returns:
        list: List of resulting chunks
    """
    if len(content) <= max_size:
        return [content]

    # Try splitting by paragraphs first
    paragraphs = content.split(separator)
    if len(paragraphs) > 1:
        chunks = []
        current_chunk = ""

        for para in paragraphs:
            if not current_chunk:
                current_chunk = para
            elif len(current_chunk + separator + para) <= max_size:
                current_chunk += separator + para
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                current_chunk = para

        if current_chunk:
            chunks.append(current_chunk)

        return chunks

    # If paragraph splitting is not possible, split by sentences
    sentences = _split_into_sentences(content)
    chunks = []
    current_chunk = ""

    for sentence in sentences:
        if not current_chunk:
            current_chunk = sentence
        elif len(current_chunk + " " + sentence) <= max_size:
            current_chunk += " " + sentence
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence

    if current_chunk:
        chunks.append(current_chunk)

    return chunks


def _split_into_sentences(text):
    """
    Split text into sentences.

    Args:
        text (str): Input text

    Returns:
        list: List of sentences
    """
    # Simple sentence splitting (can be refined as needed).
    # Split on periods, question marks, and exclamation marks while keeping
    # dots inside numbers intact.
    sentence_endings = re.compile(r'(?<=[.!?])\s+(?=[\dA-Z\u4e00-\u9fa5])')
    sentences = sentence_endings.split(text.strip())

    return [s.strip() for s in sentences if s.strip()]


def _create_overlap_chunk(previous_chunk, new_paragraph, overlap_size):
    """
    Create a new chunk that overlaps with the previous one.

    Args:
        previous_chunk (str): Previous chunk
        new_paragraph (str): New paragraph
        overlap_size (int): Overlap size

    Returns:
        str: New chunk with overlap
    """
    if overlap_size <= 0:
        return new_paragraph

    # Take the overlap content from the end of the previous chunk
    overlap_text = previous_chunk[-overlap_size:] if len(previous_chunk) > overlap_size else previous_chunk

    # Try to cut the overlap at a sentence boundary
    sentences = _split_into_sentences(overlap_text)
    if len(sentences) > 1:
        # Drop the first sentence, which may be incomplete
        overlap_text = " ".join(sentences[1:])
    elif len(overlap_text) > overlap_size * 0.5:
        # A single sentence of reasonable length: keep it
        pass
    else:
        # Too little overlap content; skip the overlap
        return new_paragraph

    return overlap_text + "\n\n" + new_paragraph


def _add_overlap_to_chunk(previous_chunk, current_chunk, overlap_size):
    """
    Prepend overlap from the previous chunk to the current chunk.

    Args:
        previous_chunk (str): Previous chunk
        current_chunk (str): Current chunk
        overlap_size (int): Overlap size

    Returns:
        str: Chunk with overlap
    """
    if overlap_size <= 0:
        return current_chunk

    # Take the overlap content from the end of the previous chunk
    overlap_text = previous_chunk[-overlap_size:] if len(previous_chunk) > overlap_size else previous_chunk

    # Try to cut at a sentence boundary
    sentences = _split_into_sentences(overlap_text)
    if len(sentences) > 1:
        overlap_text = " ".join(sentences[1:])

    return overlap_text + "\n\n" + current_chunk


def smart_chunking(text, max_chunk_size=1000, overlap=100, min_chunk_size=200):
    """
    Smart chunking: auto-detect the document format and pick the best chunking strategy.

    Args:
        text (str): Input text
        max_chunk_size (int): Maximum chunk size (characters)
        overlap (int): Overlap size (characters)
        min_chunk_size (int): Minimum chunk size (characters)

    Returns:
        list: List of text chunks
    """
    if not text or not text.strip():
        return []

    # Detect the document type (supports '# Page' and '# File' markers)
    has_page_markers = '# Page' in text or '# File' in text
    has_paragraph_breaks = '\n\n' in text
    has_line_breaks = '\n' in text

    # Choose the appropriate separator and strategy
    if has_page_markers:
        # Split on page markers
        return _page_based_chunking(text, max_chunk_size, overlap, min_chunk_size)
    elif has_paragraph_breaks:
        # Split on paragraph breaks
        return paragraph_chunking(text, max_chunk_size, overlap, min_chunk_size, '\n\n')
    elif has_line_breaks:
        # Split on line breaks
        return _line_based_chunking(text, max_chunk_size, overlap, min_chunk_size)
    else:
        # Fall back to fixed-length chunking
        return _fixed_length_chunking(text, max_chunk_size, overlap, min_chunk_size)
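
# Strategy selection at a glance (illustrative; exact chunk boundaries depend on
# max_chunk_size / overlap / min_chunk_size):
#
#     smart_chunking("# Page 1\nIntro...\n# Page 2\nMore...")       # -> page-based chunking
#     smart_chunking("Para one.\n\nPara two.")                      # -> paragraph chunking
#     smart_chunking("line one\nline two")                          # -> line-based chunking
#     smart_chunking("one long unbroken string without newlines")   # -> fixed-length chunking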


def _page_based_chunking(text, max_chunk_size, overlap, min_chunk_size):
    """Page-based chunking strategy."""
    # Split pages with a regular expression (supports '# Page' and '# File' markers)
    page_pattern = r'#\s*(Page\s+\d+|File\s+[^\n]+)'
    pages = re.split(page_pattern, text)

    # Clean up and filter page content
    cleaned_pages = []
    for page in pages:
        page = page.strip()
        if page and len(page) > min_chunk_size * 0.3:  # drop pages that are too small
            cleaned_pages.append(page)

    if not cleaned_pages:
        return []

    # If a page is too large, split it further
    chunks = []
    for page in cleaned_pages:
        if len(page) <= max_chunk_size:
            chunks.append(page)
        else:
            # Page is too large and needs to be split
            sub_chunks = _split_long_content(page, max_chunk_size, min_chunk_size, '\n')
            chunks.extend(sub_chunks)

    # Add overlaps
    if overlap > 0 and len(chunks) > 1:
        chunks = _add_overlaps_to_chunks(chunks, overlap)

    return chunks


def _line_based_chunking(text, max_chunk_size, overlap, min_chunk_size):
    """Line-based chunking strategy."""
    lines = text.split('\n')
    chunks = []
    current_chunk = ""

    for line in lines:
        line = line.strip()
        if not line:
            continue

        if not current_chunk:
            current_chunk = line
        elif len(current_chunk + '\n' + line) <= max_chunk_size:
            current_chunk += '\n' + line
        else:
            if len(current_chunk) >= min_chunk_size:
                chunks.append(current_chunk)
                current_chunk = _create_overlap_for_line(current_chunk, line, overlap)
            else:
                # The current chunk is still below the minimum size: merge it with the
                # new line and split the combined text
                split_chunks = _split_long_content(current_chunk + '\n' + line, max_chunk_size, min_chunk_size, '\n')
                if chunks and split_chunks:
                    split_chunks[0] = _add_overlap_to_chunk(chunks[-1], split_chunks[0], overlap)
                chunks.extend(split_chunks[:-1])
                current_chunk = split_chunks[-1] if split_chunks else ""

    if current_chunk and len(current_chunk) >= min_chunk_size:
        chunks.append(current_chunk)
    elif current_chunk and chunks:
        chunks[-1] += '\n' + current_chunk

    return chunks


def _fixed_length_chunking(text, max_chunk_size, overlap, min_chunk_size):
    """Fixed-length chunking strategy."""
    chunks = []
    start = 0

    while start < len(text):
        end = start + max_chunk_size

        if end >= len(text):
            chunks.append(text[start:])
            break

        # Try to split at a period, question mark, or exclamation mark
        split_pos = end
        for i in range(end, max(start, end - 100), -1):
            if text[i] in '.!?。!?':
                split_pos = i + 1
                break

        chunk = text[start:split_pos]
        if len(chunk) >= min_chunk_size:
            chunks.append(chunk)
            # Step back by the overlap, but always advance by at least one character
            # to avoid an infinite loop when the overlap is as large as the chunk
            start = max(split_pos - overlap, start + 1) if overlap > 0 else split_pos
        else:
            start += max_chunk_size // 2

    return chunks


def _create_overlap_for_line(previous_chunk, new_line, overlap_size):
    """Create overlap for line-based chunking."""
    if overlap_size <= 0:
        return new_line

    # Take the overlap content from the end of the previous chunk
    overlap_text = previous_chunk[-overlap_size:] if len(previous_chunk) > overlap_size else previous_chunk

    # Try to cut at a reasonable boundary
    last_newline = overlap_text.rfind('\n')
    if last_newline > 0:
        overlap_text = overlap_text[last_newline + 1:]

    return overlap_text + '\n' + new_line


def _add_overlaps_to_chunks(chunks, overlap_size):
    """Add overlaps between consecutive chunks."""
    if overlap_size <= 0 or len(chunks) <= 1:
        return chunks

    result = [chunks[0]]

    for i in range(1, len(chunks)):
        previous_chunk = chunks[i-1]
        current_chunk = chunks[i]

        # Take the overlap content from the end of the previous chunk
        overlap_text = previous_chunk[-overlap_size:] if len(previous_chunk) > overlap_size else previous_chunk

        # Try to cut at a reasonable boundary
        last_newline = overlap_text.rfind('\n')
        if last_newline > 0:
            overlap_text = overlap_text[last_newline + 1:]
        elif '.' in overlap_text:
            # Try to cut at a period
            last_period = overlap_text.rfind('.')
            if last_period > 0:
                overlap_text = overlap_text[last_period + 1:].strip()

        if overlap_text:
            combined_chunk = overlap_text + '\n\n' + current_chunk
            result.append(combined_chunk)
        else:
            result.append(current_chunk)

    return result


def split_document_by_pages(input_file='document.txt', output_file='pagination.txt'):
    """
    Split document.txt by page or by file, writing each page's content as a single line of pagination.txt.

    Args:
        input_file (str): Path to the input document file
        output_file (str): Path to the output serialized file
    """
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        pages = []
        current_page = []

        for line in lines:
            line = line.strip()

            # Check whether this is a page separator (supports '# Page' and '# File' markers)
            if re.match(r'^#\s*(Page|File)', line, re.IGNORECASE):
                # If the current page has content, save it
                if current_page:
                    # Merge the current page content into a single line
                    page_content = ' '.join(current_page).strip()
                    if page_content:  # only keep non-empty pages
                        pages.append(page_content)
                    current_page = []
                continue

            # Not a page separator: if the line has content, add it to the current page
            if line:
                current_page.append(line)

        # Handle the last page
        if current_page:
            page_content = ' '.join(current_page).strip()
            if page_content:
                pages.append(page_content)

        logger.info(f"Split into {len(pages)} pages in total")

        # Write the serialized file
        with open(output_file, 'w', encoding='utf-8') as f:
            for page_content in pages:
                f.write(f"{page_content}\n")

        logger.info(f"Serialized page content to {output_file}")
        return pages

    except FileNotFoundError:
        logger.error(f"Error: file not found: {input_file}")
        return []
    except Exception as e:
        logger.error(f"Error while splitting the document: {e}")
        return []


def test_chunking_strategies():
    """
    Test the different chunking strategies and compare their results.
    """
    # Test text (Chinese sample paragraphs used to exercise the chunkers)
    test_text = """
第一段:这是一个测试段落。包含了多个句子。这是为了测试分块功能。

第二段:这是另一个段落。它也包含了多个句子,用来验证分块策略的效果。我们需要确保分块的质量。

第三段:这是第三个段落,内容比较长,包含了更多的信息。这个段落可能会触发分块逻辑,因为它可能会超过最大chunk大小的限制。我们需要确保在这种情况下,分块算法能够正确地处理,并且在句子边界进行分割。

第四段:这是第四个段落。它相对较短。

第五段:这是最后一个段落。它用来测试分块策略的完整性和准确性。
    """

    logger.debug("=" * 60)
    logger.debug("Chunking strategy tests")
    logger.debug("=" * 60)

    # Test 1: paragraph-level chunking (small chunks)
    logger.debug("\n1. Paragraph-level chunking - small chunks (max_size=200):")
    chunks_small = paragraph_chunking(test_text, max_chunk_size=200, overlap=50)
    for i, chunk in enumerate(chunks_small):
        logger.debug(f"Chunk {i+1} (length: {len(chunk)}): {chunk[:50]}...")

    # Test 2: paragraph-level chunking (large chunks)
    logger.debug("\n2. Paragraph-level chunking - large chunks (max_size=500):")
    chunks_large = paragraph_chunking(test_text, max_chunk_size=500, overlap=100)
    for i, chunk in enumerate(chunks_large):
        logger.debug(f"Chunk {i+1} (length: {len(chunk)}): {chunk[:50]}...")

    # Test 3: paragraph-level chunking (no overlap)
    logger.debug("\n3. Paragraph-level chunking - no overlap:")
    chunks_no_overlap = paragraph_chunking(test_text, max_chunk_size=300, overlap=0)
    for i, chunk in enumerate(chunks_no_overlap):
        logger.debug(f"Chunk {i+1} (length: {len(chunk)}): {chunk[:50]}...")

    logger.debug("\nTest summary:")
    logger.debug(f"- small-chunk strategy: {len(chunks_small)} chunks")
    logger.debug(f"- large-chunk strategy: {len(chunks_large)} chunks")
    logger.debug(f"- no-overlap strategy: {len(chunks_no_overlap)} chunks")


def demo_usage():
    """
    Demonstrate how to use the chunking features.
    """
    logger.debug("=" * 60)
    logger.debug("Usage examples")
    logger.debug("=" * 60)

    logger.debug("\n1. Traditional line-based chunking:")
    logger.debug("embed_document('document.txt', 'line_embedding.pkl', chunking_strategy='line')")

    logger.debug("\n2. Paragraph-level chunking (default parameters):")
    logger.debug("embed_document('document.txt', 'paragraph_embedding.pkl', chunking_strategy='paragraph')")

    logger.debug("\n3. Paragraph-level chunking with custom parameters:")
    logger.debug("embed_document('document.txt', 'custom_embedding.pkl',")
    logger.debug("               chunking_strategy='paragraph',")
    logger.debug("               max_chunk_size=1500,")
    logger.debug("               overlap=200,")
    logger.debug("               min_chunk_size=300)")


# When this file is run directly, execute the example below
if __name__ == "__main__":
    # test_chunking_strategies()
    # demo_usage()

    # Example using the smart chunking strategy:
    embed_document("./projects/test/dataset/all_hp_product_spec_book2506/document.txt",
                   "./projects/test/dataset/all_hp_product_spec_book2506/smart_embedding.pkl",
                   chunking_strategy='smart',  # use the smart chunking strategy
                   max_chunk_size=800,         # smaller chunk size
                   overlap=100)


def cache_terms_embeddings(bot_id: str, terms_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Process a list of terms, generate embeddings for them, and cache the result.

    Args:
        bot_id: Bot ID, used as part of the cache key
        terms_list: List of terms; each term may contain name, description, synonyms, etc.

    Returns:
        Dict: Dictionary containing the embedding data
    """
    if not terms_list:
        return {}

    cache_key = f"{bot_id}_terms"
    cache_file = f"projects/cache/{cache_key}.pkl"

    # Make sure the cache directory exists
    os.makedirs("projects/cache", exist_ok=True)

    # Check whether a valid cache already exists
    if os.path.exists(cache_file):
        try:
            with open(cache_file, 'rb') as f:
                cached_data = pickle.load(f)

            # Verify that the cached data matches the current terms
            current_hash = _generate_terms_hash(terms_list)
            if cached_data.get('hash') == current_hash:
                logger.info(f"Using cached terms embeddings for {cache_key}")
                return cached_data
        except Exception as e:
            logger.error(f"Error loading cache: {e}")

    # Prepare the texts to encode
    term_texts = []
    term_info = []

    for term in terms_list:
        # Build the full text of the term for embedding
        term_text_parts = []

        if 'name' in term and term['name']:
            term_text_parts.append(f"Name: {term['name']}")

        if 'description' in term and term['description']:
            term_text_parts.append(f"Description: {term['description']}")

        # Handle synonyms
        synonyms = []
        if 'synonyms' in term and term['synonyms']:
            if isinstance(term['synonyms'], list):
                synonyms = term['synonyms']
            elif isinstance(term['synonyms'], str):
                synonyms = [s.strip() for s in term['synonyms'].split(',') if s.strip()]

        if synonyms:
            term_text_parts.append(f"Synonyms: {', '.join(synonyms)}")

        term_text = " | ".join(term_text_parts)
        term_texts.append(term_text)

        # Keep the original information
        term_info.append({
            'name': term.get('name', ''),
            'description': term.get('description', ''),
            'synonyms': synonyms
        })

    # Generate embeddings
    try:
        embeddings = encode_texts_via_api(term_texts, batch_size=16)

        # Prepare the cache data
        cache_data = {
            'hash': _generate_terms_hash(terms_list),
            'term_info': term_info,
            'embeddings': embeddings,
            'texts': term_texts
        }

        # Save to the cache
        with open(cache_file, 'wb') as f:
            pickle.dump(cache_data, f)

        logger.info(f"Cached {len(term_texts)} terms embeddings to {cache_file}")
        return cache_data

    except Exception as e:
        logger.error(f"Error generating terms embeddings: {e}")
        return {}


def search_similar_terms(query_text: str, cached_terms_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Search the cached terms for those similar to the query text.

    Args:
        query_text: Query text
        cached_terms_data: Cached terms data

    Returns:
        List[Dict]: Matching terms, sorted by similarity in descending order
    """
    if not cached_terms_data or not query_text or 'embeddings' not in cached_terms_data:
        return []

    try:
        # Generate the embedding of the query text
        query_embedding = encode_texts_via_api([query_text], batch_size=1)
        if len(query_embedding) == 0:
            return []

        query_vector = query_embedding[0]
        term_embeddings = cached_terms_data['embeddings']
        term_info = cached_terms_data['term_info']

        # Debug information
        logger.debug(f"DEBUG: Query text: '{query_text}'")
        logger.debug(f"DEBUG: Query vector shape: {query_vector.shape}, norm: {np.linalg.norm(query_vector)}")

        # Compute cosine similarities
        similarities = _cosine_similarity(query_vector, term_embeddings)

        logger.debug(f"DEBUG: Similarities: {similarities}")
        logger.debug(f"DEBUG: Max similarity: {np.max(similarities):.3f}, Mean similarity: {np.mean(similarities):.3f}")

        # Collect the similarity of every term
        matches = []
        for i, similarity in enumerate(similarities):
            match = {
                'term_info': term_info[i],
                'similarity': float(similarity),
                'index': i
            }
            matches.append(match)

        # Sort by similarity, descending
        matches.sort(key=lambda x: x['similarity'], reverse=True)

        # Return only the top-5 results
        return matches[:5]

    except Exception as e:
        logger.error(f"Error in similarity search: {e}")
        return []


def format_terms_analysis(similar_terms: List[Dict[str, Any]]) -> str:
    """
    Format the similar terms into the expected analysis string.

    Args:
        similar_terms: List of similar terms

    Returns:
        str: Formatted terms analysis
    """
    if not similar_terms:
        return ""

    formatted_terms = []

    for i, match in enumerate(similar_terms, 1):
        term_info = match['term_info']
        similarity = match['similarity']

        name = term_info.get('name', '')
        description = term_info.get('description', '')
        synonyms = term_info.get('synonyms', [])

        # Format the synonyms
        synonyms_str = ', '.join(synonyms) if synonyms else 'N/A'

        formatted_term = f"{i}) Name: {name}, Description: {description}, Synonyms: {synonyms_str} (Similarity: {similarity:.3f})"
        formatted_terms.append(formatted_term)

    return "\n".join(formatted_terms)


def _generate_terms_hash(terms_list: List[Dict[str, Any]]) -> str:
    """Generate a hash of the terms list for cache validation."""
    # Serialize the terms list into a canonical string
    terms_str = json.dumps(terms_list, sort_keys=True, ensure_ascii=False)
    return hashlib.md5(terms_str.encode('utf-8')).hexdigest()


def _cosine_similarity(query_vector: np.ndarray, term_embeddings: np.ndarray) -> np.ndarray:
    """
    Compute the cosine similarity between the query vector and all term embeddings.
    Follows the approach used in semantic_search_server.py; the explicit normalization
    below means the inputs do not need to be pre-normalized.

    Args:
        query_vector: Query vector (shape: [embedding_dim])
        term_embeddings: Term embeddings matrix (shape: [n_terms, embedding_dim])

    Returns:
        np.ndarray: Similarity scores (shape: [n_terms])
    """
    # Same algorithm as semantic_search_server.py
    if len(term_embeddings.shape) > 1:
        cos_scores = np.dot(term_embeddings, query_vector) / (
            np.linalg.norm(term_embeddings, axis=1) * np.linalg.norm(query_vector) + 1e-8
        )
    else:
        cos_scores = np.array([0.0] * len(term_embeddings))

    return cos_scores
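
# Example (illustrative values): with query_vector = np.array([1.0, 0.0]) and
# term_embeddings = np.array([[1.0, 0.0], [0.0, 1.0]]), the scores are
# approximately [1.0, 0.0].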


def process_terms_with_embedding(terms_list: List[Dict[str, Any]], bot_id: str, query_text: str) -> str:
    """
    Full terms-processing pipeline: cache embeddings, search by similarity, and format the output.

    Args:
        terms_list: List of terms
        bot_id: Bot ID
        query_text: User query text

    Returns:
        str: Formatted terms analysis result
    """
    if not terms_list or not query_text:
        return ""

    # 1. Cache the embeddings of the terms
    cached_data = cache_terms_embeddings(bot_id, terms_list)

    if not cached_data:
        return ""

    # 2. Search for similar terms (top 5)
    similar_terms = search_similar_terms(query_text, cached_data)

    # 3. Format the output
    if similar_terms:
        return format_terms_analysis(similar_terms)
    else:
        # When no similar terms are found, return an empty string (or a notice);
        # here we return an empty string and let the caller decide how to handle it
        return ""


# Other example calls (commented out):
# split_document_by_pages("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")
# embed_document("/Users/moshui/Documents/felo/qwen-agent/projects/test/dataset/all_hp_product_spec_book2506/document.txt")  # uncomment to run