fix: 模型校验错误

This commit is contained in:
zhangshaohu 2024-03-23 00:38:52 +08:00
parent 31a441c1c0
commit 0c9a7c15b6
3 changed files with 12 additions and 15 deletions

View File

@ -27,15 +27,15 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = AzureModelProvider().get_model_type_list() model_type_list = AzureModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error, f'{model_type} 模型类型不支持') raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
if model_name not in model_dict: if model_name not in model_dict:
raise AppApiException(ValidCode.valid_error, f'{model_name} 模型名称不支持') raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型名称不支持')
for key in ['api_base', 'api_key', 'deployment_name']: for key in ['api_base', 'api_key', 'deployment_name']:
if key not in model_credential: if key not in model_credential:
if raise_exception: if raise_exception:
raise AppApiException(ValidCode.valid_error, f'{key} 字段为必填字段') raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else: else:
return False return False
try: try:
@ -45,7 +45,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
if isinstance(e, AppApiException): if isinstance(e, AppApiException):
raise e raise e
if raise_exception: if raise_exception:
raise AppApiException(ValidCode.valid_error, '校验失败,请检查参数是否正确') raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
else: else:
return False return False

View File

@ -11,11 +11,8 @@ import os
from typing import Dict, Iterator from typing import Dict, Iterator
from urllib.parse import urlparse, ParseResult from urllib.parse import urlparse, ParseResult
import aiohttp
import requests import requests
from django.http import StreamingHttpResponse
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage
from common import froms from common import froms
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
@ -33,11 +30,11 @@ class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = OllamaModelProvider().get_model_type_list() model_type_list = OllamaModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error, f'{model_type} 模型类型不支持') raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
try: try:
model_list = OllamaModelProvider.get_base_model_list(model_credential.get('api_base')) model_list = OllamaModelProvider.get_base_model_list(model_credential.get('api_base'))
except Exception as e: except Exception as e:
raise AppApiException(ValidCode.valid_error, "API 域名无效") raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
exist = [model for model in model_list.get('models') if exist = [model for model in model_list.get('models') if
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name] model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
if len(exist) == 0: if len(exist) == 0:
@ -159,7 +156,7 @@ class OllamaModelProvider(IModelProvider):
@staticmethod @staticmethod
def get_base_model_list(api_base): def get_base_model_list(api_base):
base_url = get_base_url(api_base) base_url = get_base_url(api_base)
r = requests.request(method="GET", url=f"{base_url}/api/tags") r = requests.request(method="GET", url=f"{base_url}/api/tags", timeout=5)
r.raise_for_status() r.raise_for_status()
return r.json() return r.json()

View File

@ -17,7 +17,7 @@ from common.exception.app_exception import AppApiException
from common.froms import BaseForm from common.froms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, BaseModelCredential, \
ModelInfo, IModelProvider ModelInfo, IModelProvider, ValidCode
from setting.models_provider.impl.wenxin_model_provider.model.qian_fan_chat_model import QianfanChatModel from setting.models_provider.impl.wenxin_model_provider.model.qian_fan_chat_model import QianfanChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -26,15 +26,15 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = WenxinModelProvider().get_model_type_list() model_type_list = WenxinModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(500, f'{model_type} 模型类型不支持') raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
if model_name not in model_dict: if model_name not in model_dict:
raise AppApiException(500, f'{model_name} 模型名称不支持') raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型名称不支持')
for key in ['api_key', 'secret_key']: for key in ['api_key', 'secret_key']:
if key not in model_credential: if key not in model_credential:
if raise_exception: if raise_exception:
raise AppApiException(500, f'{key} 字段为必填字段') raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else: else:
return False return False
try: try:
@ -42,7 +42,7 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
[HumanMessage(content='valid')]) [HumanMessage(content='valid')])
except Exception as e: except Exception as e:
if raise_exception: if raise_exception:
raise AppApiException(500, "校验失败,请检查 api_key secret_key 是否正确") raise AppApiException(ValidCode.valid_error.value, "校验失败,请检查 api_key secret_key 是否正确")
return True return True
def encryption_dict(self, model_info: Dict[str, object]): def encryption_dict(self, model_info: Dict[str, object]):