add SENTENCE_TRANSFORMER_DEVICE
This commit is contained in:
parent bd4435c1ec
commit 607d20492c
@@ -23,10 +23,17 @@ def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-
     import os
     if os.path.exists(model_name_or_path):
         print("Using local model")
-        embedder = SentenceTransformer(model_name_or_path, device='cpu')
     else:
         print("Using HuggingFace model")
-        embedder = SentenceTransformer(model_name_or_path, device='cpu')
+
+    # Read the device from an environment variable, defaulting to CPU
+    device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
+    if device not in ['cpu', 'cuda', 'mps']:
+        print(f"Warning: unsupported device type '{device}', falling back to CPU")
+        device = 'cpu'
+
+    print(f"Using device: {device}")
+    embedder = SentenceTransformer(model_name_or_path, device=device)
 
     print("Model loading complete")
     return embedder
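For context, a minimal sketch of how the new switch is used (it assumes get_model is imported from this module; nothing else below is part of the commit):

    import os
    # select the GPU; any value outside 'cpu', 'cuda', 'mps' falls back to 'cpu'
    os.environ['SENTENCE_TRANSFORMER_DEVICE'] = 'cuda'  # 'mps' on Apple Silicon
    embedder = get_model()  # get_model reads the variable at call time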
@@ -91,7 +98,7 @@ def is_meaningful_line(text):
     return True
 
 def embed_document(input_file='document.txt', output_file='document_embeddings.pkl',
-                   chunking_strategy='line', model_path=None, **chunking_params):
+                   chunking_strategy='line', **chunking_params):
     """
     Read a document file, embed it with the chosen chunking strategy, and save the result as a pickle file
@@ -99,7 +106,6 @@ def embed_document(input_file='document.txt', output_file='document_embeddings.p
         input_file (str): path to the input document file
         output_file (str): path to the output pickle file
         chunking_strategy (str): chunking strategy, one of 'line', 'paragraph'
-        model_path (str): model path, either a local path or a HuggingFace model name
         **chunking_params: chunking parameters
             - for the 'line' strategy: no extra parameters
             - for the 'paragraph' strategy:
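With model_path removed from the signature, callers now pass only the I/O paths and chunking options, matching the call-site changes further down; a sketch with placeholder paths:

    # embed_document resolves the model location internally
    embedding_data = embed_document('document.txt', 'document_embeddings.pkl',
                                    chunking_strategy='paragraph')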
@@ -186,8 +192,12 @@ def embed_document(input_file='document.txt', output_file='document_embeddings.p
 
     print(f"Processing {len(chunks)} content chunks...")
 
-    # Set the default model path
-    if model_path is None:
+    # Resolve the model path inside the function
+    local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
+    import os
+    if os.path.exists(local_model_path):
+        model_path = local_model_path
+    else:
         model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
 
     model = get_model(model_path)
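The local branch above is only taken when ./models/paraphrase-multilingual-MiniLM-L12-v2 exists; one way to create it is a one-time save, shown here as an illustration rather than as part of the commit:

    from sentence_transformers import SentenceTransformer

    # download once from HuggingFace, then persist for offline runs
    model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
    model.save('./models/paraphrase-multilingual-MiniLM-L12-v2')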
@@ -251,7 +261,13 @@ def semantic_search(user_query, embeddings_file='document_embeddings.pkl', top_k
 
     cos_scores = util.cos_sim(query_embedding, chunk_embeddings)[0]
 
-    top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]
+    # Handle tensor conversion in both GPU and CPU environments
+    if cos_scores.is_cuda:
+        cos_scores_np = cos_scores.cpu().numpy()
+    else:
+        cos_scores_np = cos_scores.numpy()
+
+    top_results = np.argsort(-cos_scores_np)[:top_k]
 
     results = []
     print(f"\nTop {top_k} {content_type} most relevant to the query (chunking strategy: {chunking_strategy}):")
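A side note on the branch above: torch's Tensor.cpu() returns the tensor unchanged when it already lives in CPU memory, so the equivalent one-liner also works on both devices; the explicit is_cuda check mainly documents intent:

    cos_scores_np = cos_scores.cpu().numpy()  # .cpu() is a no-op for CPU tensors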
@@ -748,14 +764,10 @@ if __name__ == "__main__":
     #test_chunking_strategies()
     #demo_usage()
 
-    # Example of the new paragraph-level chunking:
-    # A local model path can be given to avoid downloading from HuggingFace
-    local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
-
+    # Example of the new smart chunking:
     embed_document("./projects/test/dataset/all_hp_product_spec_book2506/document.txt",
                    "./projects/test/dataset/all_hp_product_spec_book2506/smart_embeddings.pkl",
                    chunking_strategy='smart',   # use the smart chunking strategy
-                   model_path=local_model_path, # use the local model
                    max_chunk_size=800,          # smaller chunk size
                    overlap=100)
 
@@ -238,16 +238,12 @@ async def download_dataset_files(unique_id: str, files: Dict[str, List[str]]) ->
 
             # Generate embeddings
             print(f" Generating embeddings for {key}")
-            local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
-            if not os.path.exists(local_model_path):
-                local_model_path = None  # Fallback to HuggingFace model
 
             # Use paragraph chunking strategy with default settings
             embedding_data = embed_document(
                 str(document_file),
                 str(embeddings_file),
-                chunking_strategy='paragraph',
-                model_path=local_model_path
+                chunking_strategy='paragraph'
             )
 
             if embedding_data:
@@ -100,17 +100,11 @@ def organize_single_project_files(unique_id: str, skip_processed=True):
         # Generate embeddings
         print(f" Generating embeddings for {document_file.name}")
         try:
-            # Set local model path for faster processing
-            local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
-            if not os.path.exists(local_model_path):
-                local_model_path = None  # Fallback to HuggingFace model
-
             # Use paragraph chunking strategy with default settings
             embedding_data = embed_document(
                 str(document_file),
                 str(embeddings_file),
-                chunking_strategy='paragraph',
-                model_path=local_model_path
+                chunking_strategy='paragraph'
             )
 
             if embedding_data: