add SENTENCE_TRANSFORMER_DEVICE

朱潮 2025-10-20 19:56:50 +08:00
parent bd4435c1ec
commit 607d20492c
3 changed files with 26 additions and 24 deletions


@@ -23,10 +23,17 @@ def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-
     import os
     if os.path.exists(model_name_or_path):
         print("Using local model")
-        embedder = SentenceTransformer(model_name_or_path, device='cpu')
     else:
         print("Using HuggingFace model")
-        embedder = SentenceTransformer(model_name_or_path, device='cpu')
+    # Read the device from an environment variable, defaulting to CPU
+    device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
+    if device not in ['cpu', 'cuda', 'mps']:
+        print(f"Warning: unsupported device type '{device}', falling back to CPU")
+        device = 'cpu'
+    print(f"Using device: {device}")
+    embedder = SentenceTransformer(model_name_or_path, device=device)
     print("Model loading complete")
     return embedder
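
Usage note (not part of the diff): with this change the device is selected through the SENTENCE_TRANSFORMER_DEVICE environment variable, and any value other than 'cpu', 'cuda', or 'mps' falls back to CPU. A minimal sketch, assuming the variable is set before get_model() runs:

import os

# Illustrative only: pick the device before loading the model.
os.environ['SENTENCE_TRANSFORMER_DEVICE'] = 'cuda'   # or 'mps' / 'cpu'
embedder = get_model()                               # unsupported values fall back to 'cpu'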
@@ -91,7 +98,7 @@ def is_meaningful_line(text):
     return True

 def embed_document(input_file='document.txt', output_file='document_embeddings.pkl',
-                   chunking_strategy='line', model_path=None, **chunking_params):
+                   chunking_strategy='line', **chunking_params):
     """
     Read a document file, embed it with the specified chunking strategy, and save the result as a pickle file
@@ -99,7 +106,6 @@ def embed_document(input_file='document.txt', output_file='document_embeddings.p
         input_file (str): path to the input document file
         output_file (str): path to the output pickle file
         chunking_strategy (str): chunking strategy, one of 'line', 'paragraph'
-        model_path (str): model path, either a local path or a HuggingFace model name
         **chunking_params: chunking parameters
             - for the 'line' strategy: no extra parameters
             - for the 'paragraph' strategy:
@@ -186,8 +192,12 @@ def embed_document(input_file='document.txt', output_file='document_embeddings.p
     print(f"Processing {len(chunks)} content chunks...")
-    # Set the default model path
-    if model_path is None:
+    # Resolve the model path inside the function
+    local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
+    import os
+    if os.path.exists(local_model_path):
+        model_path = local_model_path
+    else:
         model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
     model = get_model(model_path)
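
Caller note (illustrative, not part of the diff): since model_path is gone from the signature, a call now passes only the file paths and chunking options; the local model under ./models/paraphrase-multilingual-MiniLM-L12-v2 is picked up automatically when it exists, otherwise the HuggingFace model name is the fallback.

# Sketch of a call after this change (file names are placeholders).
embedding_data = embed_document('document.txt', 'document_embeddings.pkl',
                                chunking_strategy='paragraph')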
@@ -251,7 +261,13 @@ def semantic_search(user_query, embeddings_file='document_embeddings.pkl', top_k
     cos_scores = util.cos_sim(query_embedding, chunk_embeddings)[0]
-    top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]
+    # Handle tensor conversion in both GPU and CPU environments
+    if cos_scores.is_cuda:
+        cos_scores_np = cos_scores.cpu().numpy()
+    else:
+        cos_scores_np = cos_scores.numpy()
+    top_results = np.argsort(-cos_scores_np)[:top_k]
     results = []
     print(f"\nTop {top_k} {content_type} most relevant to the query (chunking strategy: {chunking_strategy}):")
@@ -748,14 +764,10 @@ if __name__ == "__main__":
     #test_chunking_strategies()
     #demo_usage()
-    # Example using the new paragraph-level chunking:
-    # A local model path can be given to avoid downloading from HuggingFace
-    local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
+    # Example using the new smart chunking:
     embed_document("./projects/test/dataset/all_hp_product_spec_book2506/document.txt",
                    "./projects/test/dataset/all_hp_product_spec_book2506/smart_embeddings.pkl",
                    chunking_strategy='smart',       # use the smart chunking strategy
-                   model_path=local_model_path,     # use the local model
                    max_chunk_size=800,              # smaller chunk size
                    overlap=100)


@@ -238,16 +238,12 @@ async def download_dataset_files(unique_id: str, files: Dict[str, List[str]]) ->
         # Generate embeddings
         print(f" Generating embeddings for {key}")
-        local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
-        if not os.path.exists(local_model_path):
-            local_model_path = None # Fallback to HuggingFace model
         # Use paragraph chunking strategy with default settings
         embedding_data = embed_document(
             str(document_file),
             str(embeddings_file),
-            chunking_strategy='paragraph',
-            model_path=local_model_path
+            chunking_strategy='paragraph'
         )
         if embedding_data:


@@ -100,17 +100,11 @@ def organize_single_project_files(unique_id: str, skip_processed=True):
         # Generate embeddings
         print(f" Generating embeddings for {document_file.name}")
         try:
-            # Set local model path for faster processing
-            local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
-            if not os.path.exists(local_model_path):
-                local_model_path = None # Fallback to HuggingFace model
             # Use paragraph chunking strategy with default settings
             embedding_data = embed_document(
                 str(document_file),
                 str(embeddings_file),
-                chunking_strategy='paragraph',
-                model_path=local_model_path
+                chunking_strategy='paragraph'
             )
             if embedding_data: