refactor: 自动填充api_url
This commit is contained in:
parent
b8ba2458c0
commit
fcbfd8a07c
@ -9,7 +9,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||||||
|
|
||||||
|
|
||||||
class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
|
class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
volcanic_api_url = forms.TextInputField('API 域名', required=True)
|
volcanic_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://openspeech.bytedance.com/api/v2/asr')
|
||||||
volcanic_app_id = forms.TextInputField('App ID', required=True)
|
volcanic_app_id = forms.TextInputField('App ID', required=True)
|
||||||
volcanic_token = forms.PasswordInputField('Token', required=True)
|
volcanic_token = forms.PasswordInputField('Token', required=True)
|
||||||
volcanic_cluster = forms.TextInputField('Cluster', required=True)
|
volcanic_cluster = forms.TextInputField('Cluster', required=True)
|
||||||
|
|||||||
@ -0,0 +1,45 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
volcanic_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://openspeech.bytedance.com/api/v1/tts/ws_binary')
|
||||||
|
volcanic_app_id = forms.TextInputField('App ID', required=True)
|
||||||
|
volcanic_token = forms.PasswordInputField('Token', required=True)
|
||||||
|
volcanic_cluster = forms.TextInputField('Cluster', required=True)
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.check_auth()
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))}
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
pass
|
||||||
@ -14,6 +14,7 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
|
|||||||
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||||
|
from setting.models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential
|
||||||
from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
|
from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel
|
||||||
from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential
|
from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential
|
||||||
from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText
|
from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText
|
||||||
@ -23,6 +24,7 @@ from smartdoc.conf import PROJECT_DIR
|
|||||||
|
|
||||||
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
|
volcanic_engine_llm_model_credential = OpenAILLMModelCredential()
|
||||||
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()
|
volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential()
|
||||||
|
volcanic_engine_tts_model_credential = VolcanicEngineTTSModelCredential()
|
||||||
|
|
||||||
model_info_list = [
|
model_info_list = [
|
||||||
ModelInfo('ep-xxxxxxxxxx-yyyy',
|
ModelInfo('ep-xxxxxxxxxx-yyyy',
|
||||||
@ -38,7 +40,7 @@ model_info_list = [
|
|||||||
ModelInfo('tts',
|
ModelInfo('tts',
|
||||||
'',
|
'',
|
||||||
ModelTypeConst.TTS,
|
ModelTypeConst.TTS,
|
||||||
volcanic_engine_stt_model_credential, VolcanicEngineTextToSpeech
|
volcanic_engine_tts_model_credential, VolcanicEngineTextToSpeech
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
|
|||||||
|
|
||||||
|
|
||||||
class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
|
class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
|
||||||
spark_api_url = forms.TextInputField('API 域名', required=True)
|
spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://iat-api.xfyun.cn/v2/iat')
|
||||||
spark_app_id = forms.TextInputField('APP ID', required=True)
|
spark_app_id = forms.TextInputField('APP ID', required=True)
|
||||||
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
||||||
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
||||||
|
|||||||
@ -0,0 +1,46 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class XunFeiTTSModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://tts-api.xfyun.cn/v2/tts')
|
||||||
|
spark_app_id = forms.TextInputField('APP ID', required=True)
|
||||||
|
spark_api_key = forms.PasswordInputField("API Key", required=True)
|
||||||
|
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
|
for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.check_auth()
|
||||||
|
except Exception as e:
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
pass
|
||||||
@ -14,6 +14,7 @@ from setting.models_provider.base_model_provider import ModelProvideInfo, ModelT
|
|||||||
ModelInfoManage
|
ModelInfoManage
|
||||||
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
|
||||||
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
|
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
|
||||||
|
from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
|
||||||
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
|
||||||
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
|
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
|
||||||
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
|
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
|
||||||
@ -23,13 +24,13 @@ ssl._create_default_https_context = ssl.create_default_context()
|
|||||||
|
|
||||||
qwen_model_credential = XunFeiLLMModelCredential()
|
qwen_model_credential = XunFeiLLMModelCredential()
|
||||||
stt_model_credential = XunFeiSTTModelCredential()
|
stt_model_credential = XunFeiSTTModelCredential()
|
||||||
|
tts_model_credential = XunFeiTTSModelCredential()
|
||||||
model_info_list = [
|
model_info_list = [
|
||||||
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
|
||||||
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
||||||
ModelInfo('iat-niche', '小语种识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
|
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
|
||||||
ModelInfo('tts', '', ModelTypeConst.TTS, stt_model_credential, XFSparkTextToSpeech),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user