feat: enhance model credential validation and support for multiple API versions
This commit is contained in:
parent
795db14c75
commit
044465fcc6
@ -60,7 +60,10 @@ class Reasoning:
|
|||||||
if not self.reasoning_content_is_end:
|
if not self.reasoning_content_is_end:
|
||||||
self.reasoning_content_is_end = True
|
self.reasoning_content_is_end = True
|
||||||
self.content += self.all_content
|
self.content += self.all_content
|
||||||
return {'content': self.all_content, 'reasoning_content': ''}
|
return {'content': self.all_content,
|
||||||
|
'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
|
||||||
|
'') if chunk.additional_kwargs else ''
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
if self.reasoning_content_is_start:
|
if self.reasoning_content_is_start:
|
||||||
self.reasoning_content_chunk += chunk.content
|
self.reasoning_content_chunk += chunk.content
|
||||||
@ -68,7 +71,9 @@ class Reasoning:
|
|||||||
self.reasoning_content_end_tag_prefix)
|
self.reasoning_content_end_tag_prefix)
|
||||||
if self.reasoning_content_is_end:
|
if self.reasoning_content_is_end:
|
||||||
self.content += chunk.content
|
self.content += chunk.content
|
||||||
return {'content': chunk.content, 'reasoning_content': ''}
|
return {'content': chunk.content, 'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
|
||||||
|
'') if chunk.additional_kwargs else ''
|
||||||
|
}
|
||||||
# 是否包含结束
|
# 是否包含结束
|
||||||
if reasoning_content_end_tag_prefix_index > -1:
|
if reasoning_content_end_tag_prefix_index > -1:
|
||||||
if len(self.reasoning_content_chunk) - reasoning_content_end_tag_prefix_index >= self.reasoning_content_end_tag_len:
|
if len(self.reasoning_content_chunk) - reasoning_content_end_tag_prefix_index >= self.reasoning_content_end_tag_len:
|
||||||
@ -93,7 +98,9 @@ class Reasoning:
|
|||||||
else:
|
else:
|
||||||
if self.reasoning_content_is_end:
|
if self.reasoning_content_is_end:
|
||||||
self.content += chunk.content
|
self.content += chunk.content
|
||||||
return {'content': chunk.content, 'reasoning_content': ''}
|
return {'content': chunk.content, 'reasoning_content': chunk.additional_kwargs.get('reasoning_content',
|
||||||
|
'') if chunk.additional_kwargs else ''
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
# aaa
|
# aaa
|
||||||
result = {'content': '', 'reasoning_content': self.reasoning_content_chunk}
|
result = {'content': '', 'reasoning_content': self.reasoning_content_chunk}
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from typing import List, Dict
|
|||||||
from common.forms.base_field import BaseExecField, TriggerType
|
from common.forms.base_field import BaseExecField, TriggerType
|
||||||
|
|
||||||
|
|
||||||
class Radio(BaseExecField):
|
class RadioButton(BaseExecField):
|
||||||
"""
|
"""
|
||||||
下拉单选
|
下拉单选
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from typing import List, Dict
|
|||||||
from common.forms.base_field import BaseExecField, TriggerType
|
from common.forms.base_field import BaseExecField, TriggerType
|
||||||
|
|
||||||
|
|
||||||
class Radio(BaseExecField):
|
class RadioCard(BaseExecField):
|
||||||
"""
|
"""
|
||||||
下拉单选
|
下拉单选
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -40,16 +40,23 @@ class WenxinLLMModelParams(BaseForm):
|
|||||||
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
raise_exception=False):
|
raise_exception=False):
|
||||||
|
# 根据api_version检查必需字段
|
||||||
|
api_version = model_credential.get('api_version', 'v1')
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
|
if api_version == 'v1':
|
||||||
model_type_list = provider.get_model_type_list()
|
model_type_list = provider.get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
raise AppApiException(ValidCode.valid_error.value,
|
raise AppApiException(ValidCode.valid_error.value,
|
||||||
gettext('{model_type} Model type is not supported').format(model_type=model_type))
|
gettext('{model_type} Model type is not supported').format(model_type=model_type))
|
||||||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
|
||||||
model_info = [model.lower() for model in model.client.models()]
|
model_info = [model.lower() for model in model.client.models()]
|
||||||
if not model_info.__contains__(model_name.lower()):
|
if not model_info.__contains__(model_name.lower()):
|
||||||
raise AppApiException(ValidCode.valid_error.value,
|
raise AppApiException(ValidCode.valid_error.value,
|
||||||
gettext('{model_name} The model does not support').format(model_name=model_name))
|
gettext('{model_name} The model does not support').format(model_name=model_name))
|
||||||
for key in ['api_key', 'secret_key']:
|
required_keys = ['api_key', 'secret_key']
|
||||||
|
if api_version == 'v2':
|
||||||
|
required_keys = ['api_base', 'api_key']
|
||||||
|
|
||||||
|
for key in required_keys:
|
||||||
if key not in model_credential:
|
if key not in model_credential:
|
||||||
if raise_exception:
|
if raise_exception:
|
||||||
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
|
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
|
||||||
@ -64,19 +71,47 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def encryption_dict(self, model_info: Dict[str, object]):
|
def encryption_dict(self, model_info: Dict[str, object]):
|
||||||
|
# 根据api_version加密不同字段
|
||||||
|
api_version = model_info.get('api_version', 'v1')
|
||||||
|
if api_version == 'v1':
|
||||||
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
|
return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}
|
||||||
|
else: # v2
|
||||||
|
return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}
|
||||||
|
|
||||||
def build_model(self, model_info: Dict[str, object]):
|
def build_model(self, model_info: Dict[str, object]):
|
||||||
for key in ['api_key', 'secret_key', 'model']:
|
api_version = model_info.get('api_version', 'v1')
|
||||||
|
# 根据api_version检查必需字段
|
||||||
|
if api_version == 'v1':
|
||||||
|
for key in ['api_version', 'api_key', 'secret_key', 'model']:
|
||||||
if key not in model_info:
|
if key not in model_info:
|
||||||
raise AppApiException(500, gettext('{key} is required').format(key=key))
|
raise AppApiException(500, gettext('{key} is required').format(key=key))
|
||||||
self.api_key = model_info.get('api_key')
|
self.api_key = model_info.get('api_key')
|
||||||
self.secret_key = model_info.get('secret_key')
|
self.secret_key = model_info.get('secret_key')
|
||||||
|
else: # v2
|
||||||
|
for key in ['api_version', 'api_base', 'api_key', 'model', ]:
|
||||||
|
if key not in model_info:
|
||||||
|
raise AppApiException(500, gettext('{key} is required').format(key=key))
|
||||||
|
self.api_base = model_info.get('api_base')
|
||||||
|
self.api_key = model_info.get('api_key')
|
||||||
return self
|
return self
|
||||||
|
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
# 动态字段定义 - 根据api_version显示不同字段
|
||||||
|
api_version = forms.Radio('API Version', required=True, text_field='label', value_field='value',
|
||||||
|
option_list=[
|
||||||
|
{'label': 'v1', 'value': 'v1'},
|
||||||
|
{'label': 'v2', 'value': 'v2'}
|
||||||
|
],
|
||||||
|
default_value='v1',
|
||||||
|
provider='',
|
||||||
|
method='', )
|
||||||
|
|
||||||
secret_key = forms.PasswordInputField("Secret Key", required=True)
|
# v2版本字段
|
||||||
|
api_base = forms.TextInputField("API Base", required=False, relation_show_field_dict={"api_version": ["v2"]})
|
||||||
|
|
||||||
|
# v1版本字段
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=False)
|
||||||
|
secret_key = forms.PasswordInputField("Secret Key", required=False,
|
||||||
|
relation_show_field_dict={"api_version": ["v1"]})
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
return WenxinLLMModelParams()
|
return WenxinLLMModelParams()
|
||||||
|
|||||||
@ -17,9 +17,10 @@ from langchain_core.messages import (
|
|||||||
from langchain_core.outputs import ChatGenerationChunk
|
from langchain_core.outputs import ChatGenerationChunk
|
||||||
|
|
||||||
from models_provider.base_model_provider import MaxKBBaseModel
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
|
class QianfanChatModelQianfan(MaxKBBaseModel, QianfanChatEndpoint):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_cache_model():
|
def is_cache_model():
|
||||||
return False
|
return False
|
||||||
@ -27,7 +28,7 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||||
return QianfanChatModel(model=model_name,
|
return QianfanChatModelQianfan(model=model_name,
|
||||||
qianfan_ak=model_credential.get('api_key'),
|
qianfan_ak=model_credential.get('api_key'),
|
||||||
qianfan_sk=model_credential.get('secret_key'),
|
qianfan_sk=model_credential.get('secret_key'),
|
||||||
streaming=model_kwargs.get('streaming', False),
|
streaming=model_kwargs.get('streaming', False),
|
||||||
@ -74,3 +75,30 @@ class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
|
|||||||
if run_manager:
|
if run_manager:
|
||||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
class QianfanChatModelOpenai(MaxKBBaseModel, BaseChatOpenAI):
|
||||||
|
@staticmethod
|
||||||
|
def is_cache_model():
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||||
|
return QianfanChatModelOpenai(
|
||||||
|
model=model_name,
|
||||||
|
openai_api_base=model_credential.get('api_base'),
|
||||||
|
openai_api_key=model_credential.get('api_key'),
|
||||||
|
extra_body=optional_params
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QianfanChatModel(MaxKBBaseModel):
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
api_version = model_credential.get('api_version', 'v1')
|
||||||
|
|
||||||
|
if api_version == "v1":
|
||||||
|
return QianfanChatModelQianfan.new_instance(model_type, model_name, model_credential, **model_kwargs)
|
||||||
|
elif api_version == "v2":
|
||||||
|
return QianfanChatModelOpenai.new_instance(model_type, model_name, model_credential, **model_kwargs)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user