diff --git a/embedding/embedding.py b/embedding/embedding.py
index ef308f7..ce7b3f3 100644
--- a/embedding/embedding.py
+++ b/embedding/embedding.py
@@ -23,10 +23,17 @@ def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-
     import os
     if os.path.exists(model_name_or_path):
         print("Using local model")
-        embedder = SentenceTransformer(model_name_or_path, device='cpu')
     else:
         print("Using HuggingFace model")
-        embedder = SentenceTransformer(model_name_or_path, device='cpu')
+
+    # Read the device configuration from an environment variable; default to CPU
+    device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
+    if device not in ['cpu', 'cuda', 'mps']:
+        print(f"Warning: unsupported device type '{device}', falling back to CPU")
+        device = 'cpu'
+
+    print(f"Using device: {device}")
+    embedder = SentenceTransformer(model_name_or_path, device=device)
     print("Model loaded")
     return embedder
 
@@ -91,7 +98,7 @@ def is_meaningful_line(text):
     return True
 
 def embed_document(input_file='document.txt', output_file='document_embeddings.pkl',
-                   chunking_strategy='line', model_path=None, **chunking_params):
+                   chunking_strategy='line', **chunking_params):
     """
     Read a document file, embed it with the specified chunking strategy, and save the result as a pickle file
 
@@ -99,7 +106,6 @@ def embed_document(input_file='document.txt', output_file='document_embeddings.p
         input_file (str): path to the input document file
        output_file (str): path to the output pickle file
         chunking_strategy (str): chunking strategy, one of 'line', 'paragraph'
-        model_path (str): model path, either a local path or a HuggingFace model name
         **chunking_params: chunking parameters
             - for the 'line' strategy: no extra parameters
             - for the 'paragraph' strategy:
@@ -186,8 +192,12 @@ def embed_document(input_file='document.txt', output_file='document_embeddings.p
     print(f"Processing {len(chunks)} content chunks...")
 
-    # Set the default model path
-    if model_path is None:
+    # Resolve the model path inside the function
+    local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
+    import os
+    if os.path.exists(local_model_path):
+        model_path = local_model_path
+    else:
         model_path = 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2'
 
     model = get_model(model_path)
 
@@ -251,7 +261,13 @@ def semantic_search(user_query, embeddings_file='document_embeddings.pkl', top_k
 
     cos_scores = util.cos_sim(query_embedding, chunk_embeddings)[0]
 
-    top_results = np.argsort(-cos_scores.cpu().numpy())[:top_k]
+    # Handle the tensor conversion for both GPU and CPU environments
+    if cos_scores.is_cuda:
+        cos_scores_np = cos_scores.cpu().numpy()
+    else:
+        cos_scores_np = cos_scores.numpy()
+
+    top_results = np.argsort(-cos_scores_np)[:top_k]
 
     results = []
     print(f"\nTop {top_k} most relevant {content_type} for the query (chunking strategy: {chunking_strategy}):")
@@ -748,14 +764,10 @@ if __name__ == "__main__":
     #test_chunking_strategies()
     #demo_usage()
 
-    # Example using the new paragraph-level chunking:
-    # A local model path can be given to avoid downloading from HuggingFace
-    local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
-
+    # Example using the new smart chunking:
     embed_document("./projects/test/dataset/all_hp_product_spec_book2506/document.txt",
                    "./projects/test/dataset/all_hp_product_spec_book2506/smart_embeddings.pkl",
                    chunking_strategy='smart',  # use the smart chunking strategy
-                   model_path=local_model_path,  # use the local model
                    max_chunk_size=800,  # smaller chunk size
                    overlap=100)
 
diff --git a/utils/dataset_manager.py b/utils/dataset_manager.py
index dd15fa7..8d140d7 100644
--- a/utils/dataset_manager.py
+++ b/utils/dataset_manager.py
@@ -238,16 +238,12 @@ async def download_dataset_files(unique_id: str, files: Dict[str, List[str]]) ->
 
         # Generate embeddings
         print(f"  Generating embeddings for {key}")
-        local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
-        if not os.path.exists(local_model_path):
-            local_model_path = None  # Fallback to HuggingFace model
 
         # Use paragraph chunking strategy with default settings
         embedding_data = embed_document(
             str(document_file),
             str(embeddings_file),
-            chunking_strategy='paragraph',
-            model_path=local_model_path
+            chunking_strategy='paragraph'
         )
 
         if embedding_data:
diff --git a/utils/organize_dataset_files.py b/utils/organize_dataset_files.py
index b2c8a8a..40098ce 100644
--- a/utils/organize_dataset_files.py
+++ b/utils/organize_dataset_files.py
@@ -100,17 +100,11 @@ def organize_single_project_files(unique_id: str, skip_processed=True):
             # Generate embeddings
             print(f"  Generating embeddings for {document_file.name}")
             try:
-                # Set local model path for faster processing
-                local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
-                if not os.path.exists(local_model_path):
-                    local_model_path = None  # Fallback to HuggingFace model
-
                 # Use paragraph chunking strategy with default settings
                 embedding_data = embed_document(
                     str(document_file),
                     str(embeddings_file),
-                    chunking_strategy='paragraph',
-                    model_path=local_model_path
+                    chunking_strategy='paragraph'
                 )
 
                 if embedding_data:
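
A minimal usage sketch of the patched flow. The SENTENCE_TRANSFORMER_DEVICE variable, the ./models/ fallback, and the removal of model_path all come from the diff above; the file paths here are illustrative.

    import os

    # Opt in to GPU inference; get_model() validates the value and
    # falls back to CPU for anything other than 'cpu', 'cuda', or 'mps'.
    os.environ['SENTENCE_TRANSFORMER_DEVICE'] = 'cuda'

    from embedding.embedding import embed_document

    # model_path is no longer a parameter: a local copy under
    # ./models/paraphrase-multilingual-MiniLM-L12-v2 is used when present,
    # otherwise the model is fetched from HuggingFace.
    embedding_data = embed_document('document.txt', 'document_embeddings.pkl',
                                    chunking_strategy='paragraph')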