modify env
This commit is contained in:
parent
b9973abdbd
commit
75c5f0aa80
@ -31,13 +31,19 @@ def get_model(model_name_or_path='sentence-transformers/paraphrase-multilingual-
|
|||||||
# 优先使用本地模型路径
|
# 优先使用本地模型路径
|
||||||
local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
|
local_model_path = "./models/paraphrase-multilingual-MiniLM-L12-v2"
|
||||||
|
|
||||||
|
# 从环境变量获取设备配置,默认为 CPU
|
||||||
|
device = os.environ.get('SENTENCE_TRANSFORMER_DEVICE', 'cpu')
|
||||||
|
if device not in ['cpu', 'cuda', 'mps']:
|
||||||
|
print(f"警告: 不支持的设备类型 '{device}',使用默认 CPU")
|
||||||
|
device = 'cpu'
|
||||||
|
|
||||||
# 检查本地模型是否存在
|
# 检查本地模型是否存在
|
||||||
if os.path.exists(local_model_path):
|
if os.path.exists(local_model_path):
|
||||||
print(f"使用本地模型: {local_model_path}")
|
print(f"使用本地模型: {local_model_path}")
|
||||||
embedder = SentenceTransformer(local_model_path, device='cpu')
|
embedder = SentenceTransformer(local_model_path, device=device)
|
||||||
else:
|
else:
|
||||||
print(f"本地模型不存在,使用HuggingFace模型: {model_name_or_path}")
|
print(f"本地模型不存在,使用HuggingFace模型: {model_name_or_path}")
|
||||||
embedder = SentenceTransformer(model_name_or_path, device='cpu')
|
embedder = SentenceTransformer(model_name_or_path, device=device)
|
||||||
|
|
||||||
return embedder
|
return embedder
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user