modify env

This commit is contained in:
朱潮 2025-10-22 10:44:22 +08:00
parent b9973abdbd
commit 75c5f0aa80

View File

@ -31,13 +31,19 @@ def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-
# 优先使用本地模型路径 # 优先使用本地模型路径
local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2" local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
# 从环境变量获取设备配置,默认为 CPU
device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
if device not in ['cpu', 'cuda', 'mps']:
print(f"警告: 不支持的设备类型 '{device}',使用默认 CPU")
device = 'cpu'
# 检查本地模型是否存在 # 检查本地模型是否存在
if os.path.exists(local_model_path): if os.path.exists(local_model_path):
print(f"使用本地模型: {local_model_path}") print(f"使用本地模型: {local_model_path}")
embedder = SentenceTransformer(local_model_path, device='cpu') embedder = SentenceTransformer(local_model_path, device=device)
else: else:
print(f"本地模型不存在使用HuggingFace模型: {model_name_or_path}") print(f"本地模型不存在使用HuggingFace模型: {model_name_or_path}")
embedder = SentenceTransformer(model_name_or_path, device='cpu') embedder = SentenceTransformer(model_name_or_path, device=device)
return embedder return embedder