modify env
This commit is contained in:
parent
b9973abdbd
commit
75c5f0aa80
@ -31,13 +31,19 @@ def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-
|
||||
# 优先使用本地模型路径
|
||||
local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
|
||||
|
||||
# 从环境变量获取设备配置,默认为 CPU
|
||||
device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
|
||||
if device not in ['cpu', 'cuda', 'mps']:
|
||||
print(f"警告: 不支持的设备类型 '{device}',使用默认 CPU")
|
||||
device = 'cpu'
|
||||
|
||||
# 检查本地模型是否存在
|
||||
if os.path.exists(local_model_path):
|
||||
print(f"使用本地模型: {local_model_path}")
|
||||
embedder = SentenceTransformer(local_model_path, device='cpu')
|
||||
embedder = SentenceTransformer(local_model_path, device=device)
|
||||
else:
|
||||
print(f"本地模型不存在,使用HuggingFace模型: {model_name_or_path}")
|
||||
embedder = SentenceTransformer(model_name_or_path, device='cpu')
|
||||
embedder = SentenceTransformer(model_name_or_path, device=device)
|
||||
|
||||
return embedder
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user