refactor: model
This commit is contained in:
parent
173f7e8321
commit
594ca6cd89
@ -12,6 +12,7 @@ from models_provider.impl.kimi_model_provider.kimi_model_provider import KimiMod
|
|||||||
from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
|
from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
|
||||||
from models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
|
from models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
|
||||||
from models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
|
from models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider
|
||||||
|
from models_provider.impl.regolo_model_provider.regolo_model_provider import RegoloModelProvider
|
||||||
from models_provider.impl.siliconCloud_model_provider.siliconCloud_model_provider import SiliconCloudModelProvider
|
from models_provider.impl.siliconCloud_model_provider.siliconCloud_model_provider import SiliconCloudModelProvider
|
||||||
from models_provider.impl.tencent_cloud_model_provider.tencent_cloud_model_provider import TencentCloudModelProvider
|
from models_provider.impl.tencent_cloud_model_provider.tencent_cloud_model_provider import TencentCloudModelProvider
|
||||||
from models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider
|
from models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider
|
||||||
@ -44,3 +45,4 @@ class ModelProvideConstants(Enum):
|
|||||||
aliyun_bai_lian_model_provider = AliyunBaiLianModelProvider()
|
aliyun_bai_lian_model_provider = AliyunBaiLianModelProvider()
|
||||||
model_anthropic_provider = AnthropicModelProvider()
|
model_anthropic_provider = AnthropicModelProvider()
|
||||||
model_siliconCloud_provider = SiliconCloudModelProvider()
|
model_siliconCloud_provider = SiliconCloudModelProvider()
|
||||||
|
model_regolo_provider = RegoloModelProvider()
|
||||||
|
|||||||
@ -14,11 +14,30 @@ from langchain_core.messages import HumanMessage
|
|||||||
|
|
||||||
from common import forms
|
from common import forms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.forms import BaseForm
|
from common.forms import BaseForm, TooltipLabel
|
||||||
from common.utils.logger import maxkb_logger
|
from common.utils.logger import maxkb_logger
|
||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class QwenModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=1.0,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.9,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class QwenVLModelCredential(BaseForm, BaseModelCredential):
|
class QwenVLModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(
|
def is_valid(
|
||||||
@ -70,3 +89,6 @@ class QwenVLModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
return QwenModelParams()
|
||||||
|
|||||||
@ -12,6 +12,25 @@ from common.utils.logger import maxkb_logger
|
|||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicImageModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=0.7,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class AnthropicImageModelCredential(BaseForm, BaseModelCredential):
|
class AnthropicImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_base = forms.TextInputField(_('API URL'), required=True)
|
api_base = forms.TextInputField(_('API URL'), required=True)
|
||||||
api_key = forms.PasswordInputField(_('API Key'), required=True)
|
api_key = forms.PasswordInputField(_('API Key'), required=True)
|
||||||
@ -51,4 +70,4 @@ class AnthropicImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
return AnthropicImageModelParams()
|
||||||
|
|||||||
@ -15,7 +15,22 @@ from django.utils.translation import gettext_lazy as _, gettext
|
|||||||
|
|
||||||
|
|
||||||
class AzureOpenAIImageModelParams(BaseForm):
|
class AzureOpenAIImageModelParams(BaseForm):
|
||||||
pass
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=0.7,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|||||||
@ -12,6 +12,25 @@ from common.utils.logger import maxkb_logger
|
|||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiImageModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=0.7,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageModelCredential(BaseForm, BaseModelCredential):
|
class GeminiImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
@ -50,4 +69,4 @@ class GeminiImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
return GeminiImageModelParams()
|
||||||
|
|||||||
@ -7,6 +7,26 @@ from common.forms import BaseForm, TooltipLabel
|
|||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
from django.utils.translation import gettext_lazy as _, gettext
|
from django.utils.translation import gettext_lazy as _, gettext
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaImageModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=0.7,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class OllamaImageModelCredential(BaseForm, BaseModelCredential):
|
class OllamaImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_base = forms.TextInputField('API URL', required=True)
|
api_base = forms.TextInputField('API URL', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -33,4 +53,4 @@ class OllamaImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
return OllamaImageModelParams()
|
||||||
|
|||||||
@ -14,6 +14,25 @@ from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
|||||||
from django.utils.translation import gettext_lazy as _, gettext
|
from django.utils.translation import gettext_lazy as _, gettext
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIImageModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=0.7,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_base = forms.TextInputField('API URL', required=True)
|
api_base = forms.TextInputField('API URL', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -53,4 +72,4 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
return OpenAIImageModelParams()
|
||||||
|
|||||||
@ -0,0 +1,8 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: __init__.py.py
|
||||||
|
@date:2024/3/28 16:25
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
@ -0,0 +1,52 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: embedding.py
|
||||||
|
@date:2024/7/12 16:45
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import traceback
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from django.utils.translation import gettext as _
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm
|
||||||
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class RegoloEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
|
raise_exception=True):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value,
|
||||||
|
_('{model_type} Model type is not supported').format(model_type=model_type))
|
||||||
|
|
||||||
|
for key in ['api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, _('{key} is required').format(key=key))
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential)
|
||||||
|
model.embed_query(_('Hello'))
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value,
|
||||||
|
_('Verification failed, please check whether the parameters are correct: {error}').format(
|
||||||
|
error=str(e)))
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -0,0 +1,75 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
import traceback
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm, TooltipLabel
|
||||||
|
from django.utils.translation import gettext_lazy as _, gettext
|
||||||
|
|
||||||
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class RegoloImageModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=0.7,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
|
class RegoloImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
api_base = forms.TextInputField('API URL', required=True)
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value,
|
||||||
|
gettext('{model_type} Model type is not supported').format(model_type=model_type))
|
||||||
|
|
||||||
|
for key in ['api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
|
res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])])
|
||||||
|
for chunk in res:
|
||||||
|
print(chunk)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value,
|
||||||
|
gettext(
|
||||||
|
'Verification failed, please check whether the parameters are correct: {error}').format(
|
||||||
|
error=str(e)))
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
return RegoloImageModelParams()
|
||||||
@ -0,0 +1,78 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/7/11 18:32
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import traceback
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from django.utils.translation import gettext_lazy as _, gettext
|
||||||
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm, TooltipLabel
|
||||||
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class RegoloLLMModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=0.7,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
|
class RegoloLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value,
|
||||||
|
gettext('{model_type} Model type is not supported').format(model_type=model_type))
|
||||||
|
|
||||||
|
for key in ['api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
|
model.invoke([HumanMessage(content=gettext('Hello'))])
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value,
|
||||||
|
gettext(
|
||||||
|
'Verification failed, please check whether the parameters are correct: {error}').format(
|
||||||
|
error=str(e)))
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
return RegoloLLMModelParams()
|
||||||
@ -0,0 +1,89 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
import traceback
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from django.utils.translation import gettext_lazy as _, gettext
|
||||||
|
|
||||||
|
from common import forms
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.forms import BaseForm, TooltipLabel
|
||||||
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class RegoloTTIModelParams(BaseForm):
|
||||||
|
size = forms.SingleSelect(
|
||||||
|
TooltipLabel(_('Image size'),
|
||||||
|
_('The image generation endpoint allows you to create raw images based on text prompts. ')),
|
||||||
|
required=True,
|
||||||
|
default_value='1024x1024',
|
||||||
|
option_list=[
|
||||||
|
{'value': '1024x1024', 'label': '1024x1024'},
|
||||||
|
{'value': '1024x1792', 'label': '1024x1792'},
|
||||||
|
{'value': '1792x1024', 'label': '1792x1024'},
|
||||||
|
],
|
||||||
|
text_field='label',
|
||||||
|
value_field='value'
|
||||||
|
)
|
||||||
|
|
||||||
|
quality = forms.SingleSelect(
|
||||||
|
TooltipLabel(_('Picture quality'), _('''
|
||||||
|
By default, images are produced in standard quality.
|
||||||
|
''')),
|
||||||
|
required=True,
|
||||||
|
default_value='standard',
|
||||||
|
option_list=[
|
||||||
|
{'value': 'standard', 'label': 'standard'},
|
||||||
|
{'value': 'hd', 'label': 'hd'},
|
||||||
|
],
|
||||||
|
text_field='label',
|
||||||
|
value_field='value'
|
||||||
|
)
|
||||||
|
|
||||||
|
n = forms.SliderField(
|
||||||
|
TooltipLabel(_('Number of pictures'),
|
||||||
|
_('1 as default')),
|
||||||
|
required=True, default_value=1,
|
||||||
|
_min=1,
|
||||||
|
_max=10,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
|
class RegoloTextToImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
|
raise_exception=False):
|
||||||
|
model_type_list = provider.get_model_type_list()
|
||||||
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
|
raise AppApiException(ValidCode.valid_error.value,
|
||||||
|
gettext('{model_type} Model type is not supported').format(model_type=model_type))
|
||||||
|
|
||||||
|
for key in ['api_key']:
|
||||||
|
if key not in model_credential:
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value, gettext('{key} is required').format(key=key))
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||||
|
res = model.check_auth()
|
||||||
|
print(res)
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
if isinstance(e, AppApiException):
|
||||||
|
raise e
|
||||||
|
if raise_exception:
|
||||||
|
raise AppApiException(ValidCode.valid_error.value,
|
||||||
|
gettext(
|
||||||
|
'Verification failed, please check whether the parameters are correct: {error}').format(
|
||||||
|
error=str(e)))
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def encryption_dict(self, model: Dict[str, object]):
|
||||||
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
|
def get_model_params_setting_form(self, model_name):
|
||||||
|
return RegoloTTIModelParams()
|
||||||
@ -0,0 +1,64 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
|
||||||
|
<svg
|
||||||
|
id="Livello_2"
|
||||||
|
data-name="Livello 2"
|
||||||
|
viewBox="0 0 104.4 104.38"
|
||||||
|
version="1.1"
|
||||||
|
sodipodi:docname="Regolo_logo_positive.svg"
|
||||||
|
width="100%" height="100%"
|
||||||
|
inkscape:version="1.4 (e7c3feb100, 2024-10-09)"
|
||||||
|
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
|
||||||
|
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
xmlns:svg="http://www.w3.org/2000/svg">
|
||||||
|
<sodipodi:namedview
|
||||||
|
id="namedview13"
|
||||||
|
pagecolor="#ffffff"
|
||||||
|
bordercolor="#666666"
|
||||||
|
borderopacity="1.0"
|
||||||
|
inkscape:showpageshadow="2"
|
||||||
|
inkscape:pageopacity="0.0"
|
||||||
|
inkscape:pagecheckerboard="0"
|
||||||
|
inkscape:deskcolor="#d1d1d1"
|
||||||
|
inkscape:zoom="2.1335227"
|
||||||
|
inkscape:cx="119.05193"
|
||||||
|
inkscape:cy="48.511318"
|
||||||
|
inkscape:window-width="1920"
|
||||||
|
inkscape:window-height="1025"
|
||||||
|
inkscape:window-x="0"
|
||||||
|
inkscape:window-y="0"
|
||||||
|
inkscape:window-maximized="1"
|
||||||
|
inkscape:current-layer="g13" />
|
||||||
|
<defs
|
||||||
|
id="defs1">
|
||||||
|
<style
|
||||||
|
id="style1">
|
||||||
|
.cls-1 {
|
||||||
|
fill: #303030;
|
||||||
|
}
|
||||||
|
|
||||||
|
.cls-2 {
|
||||||
|
fill: #59e389;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</defs>
|
||||||
|
<g
|
||||||
|
id="Grafica"
|
||||||
|
transform="translate(0,-40.87)">
|
||||||
|
<g
|
||||||
|
id="g13">
|
||||||
|
<path
|
||||||
|
class="cls-1"
|
||||||
|
d="m 104.39,105.96 v 36.18 c 0,0.32 -0.05,0.62 -0.14,0.91 -0.39,1.27 -1.58,2.2 -2.99,2.2 H 65.08 c -1.73,0 -3.13,-1.41 -3.13,-3.13 V 113.4 c 0,-0.15 0,-0.29 0,-0.44 v -7 c 0,-1.73 1.4,-3.13 3.13,-3.13 h 36.19 c 1.5,0 2.77,1.07 3.06,2.5 0.05,0.21 0.07,0.41 0.07,0.63 z"
|
||||||
|
id="path1" />
|
||||||
|
<path
|
||||||
|
class="cls-1"
|
||||||
|
d="m 104.39,105.96 v 36.18 c 0,0.32 -0.05,0.62 -0.14,0.91 -0.39,1.27 -1.58,2.2 -2.99,2.2 H 65.08 c -1.73,0 -3.13,-1.41 -3.13,-3.13 V 113.4 c 0,-0.15 0,-0.29 0,-0.44 v -7 c 0,-1.73 1.4,-3.13 3.13,-3.13 h 36.19 c 1.5,0 2.77,1.07 3.06,2.5 0.05,0.21 0.07,0.41 0.07,0.63 z"
|
||||||
|
id="path2" />
|
||||||
|
<path
|
||||||
|
class="cls-2"
|
||||||
|
d="M 101.27,40.88 H 65.09 c -1.73,0 -3.13,1.4 -3.13,3.13 v 28.71 c 0,4.71 -1.88,9.23 -5.2,12.56 L 44.42,97.61 c -3.32,3.33 -7.85,5.2 -12.55,5.2 H 18.98 c -2.21,0 -3.99,-1.79 -3.99,-3.99 V 87.29 c 0,-2.21 1.79,-3.99 3.99,-3.99 h 20.34 c 1.41,0 2.59,-0.93 2.99,-2.2 0.09,-0.29 0.14,-0.59 0.14,-0.91 V 44 c 0,-0.22 -0.02,-0.42 -0.07,-0.63 -0.29,-1.43 -1.56,-2.5 -3.06,-2.5 H 3.13 C 1.4,40.87 0,42.27 0,44 v 7 c 0,0.15 0,0.29 0,0.44 v 28.72 c 0,1.72 1.41,3.13 3.13,3.13 h 3.16 c 2.21,0 3.99,1.79 3.99,3.99 v 11.53 c 0,2.21 -1.79,3.99 -3.99,3.99 H 3.15 c -1.73,0 -3.13,1.4 -3.13,3.13 v 36.19 c 0,1.72 1.41,3.13 3.13,3.13 h 36.19 c 1.73,0 3.13,-1.41 3.13,-3.13 V 113.4 c 0,-4.7 1.87,-9.23 5.2,-12.55 L 60,88.51 c 3.33,-3.32 7.85,-5.2 12.56,-5.2 h 28.71 c 1.73,0 3.13,-1.4 3.13,-3.13 V 44 c 0,-1.73 -1.4,-3.13 -3.13,-3.13 z"
|
||||||
|
id="path3" />
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 2.8 KiB |
@ -0,0 +1,23 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: MaxKB
|
||||||
|
@Author:虎
|
||||||
|
@file: embedding.py
|
||||||
|
@date:2024/7/12 17:44
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from langchain_community.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class RegoloEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
return RegoloEmbeddingModel(
|
||||||
|
api_key=model_credential.get('api_key'),
|
||||||
|
model=model_name,
|
||||||
|
openai_api_base="https://api.regolo.ai/v1",
|
||||||
|
)
|
||||||
@ -0,0 +1,19 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
|
class RegoloImage(MaxKBBaseModel, BaseChatOpenAI):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||||
|
return RegoloImage(
|
||||||
|
model_name=model_name,
|
||||||
|
openai_api_base="https://api.regolo.ai/v1",
|
||||||
|
openai_api_key=model_credential.get('api_key'),
|
||||||
|
streaming=True,
|
||||||
|
stream_usage=True,
|
||||||
|
extra_body=optional_params
|
||||||
|
)
|
||||||
38
apps/models_provider/impl/regolo_model_provider/model/llm.py
Normal file
38
apps/models_provider/impl/regolo_model_provider/model/llm.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: llm.py
|
||||||
|
@date:2024/4/18 15:28
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
from langchain_openai.chat_models import ChatOpenAI
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI
|
||||||
|
|
||||||
|
|
||||||
|
def custom_get_token_ids(text: str):
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return tokenizer.encode(text)
|
||||||
|
|
||||||
|
|
||||||
|
class RegoloChatModel(MaxKBBaseModel, BaseChatOpenAI):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_cache_model():
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
|
||||||
|
return RegoloChatModel(
|
||||||
|
model=model_name,
|
||||||
|
openai_api_base="https://api.regolo.ai/v1",
|
||||||
|
openai_api_key=model_credential.get('api_key'),
|
||||||
|
extra_body=optional_params
|
||||||
|
)
|
||||||
58
apps/models_provider/impl/regolo_model_provider/model/tti.py
Normal file
58
apps/models_provider/impl/regolo_model_provider/model/tti.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
from common.config.tokenizer_manage_config import TokenizerManage
|
||||||
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
|
from models_provider.impl.base_tti import BaseTextToImage
|
||||||
|
|
||||||
|
|
||||||
|
def custom_get_token_ids(text: str):
|
||||||
|
tokenizer = TokenizerManage.get_tokenizer()
|
||||||
|
return tokenizer.encode(text)
|
||||||
|
|
||||||
|
|
||||||
|
class RegoloTextToImage(MaxKBBaseModel, BaseTextToImage):
|
||||||
|
api_base: str
|
||||||
|
api_key: str
|
||||||
|
model: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.api_key = kwargs.get('api_key')
|
||||||
|
self.api_base = "https://api.regolo.ai/v1"
|
||||||
|
self.model = kwargs.get('model')
|
||||||
|
self.params = kwargs.get('params')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
|
optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}}
|
||||||
|
for key, value in model_kwargs.items():
|
||||||
|
if key not in ['model_id', 'use_local', 'streaming']:
|
||||||
|
optional_params['params'][key] = value
|
||||||
|
return RegoloTextToImage(
|
||||||
|
model=model_name,
|
||||||
|
api_base="https://api.regolo.ai/v1",
|
||||||
|
api_key=model_credential.get('api_key'),
|
||||||
|
**optional_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_cache_model(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def check_auth(self):
|
||||||
|
chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
|
||||||
|
response_list = chat.models.with_raw_response.list()
|
||||||
|
|
||||||
|
# self.generate_image('生成一个小猫图片')
|
||||||
|
|
||||||
|
def generate_image(self, prompt: str, negative_prompt: str = None):
|
||||||
|
chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
|
||||||
|
res = chat.images.generate(model=self.model, prompt=prompt, **self.params)
|
||||||
|
file_urls = []
|
||||||
|
for content in res.data:
|
||||||
|
url = content.url
|
||||||
|
file_urls.append(url)
|
||||||
|
|
||||||
|
return file_urls
|
||||||
@ -0,0 +1,88 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: openai_model_provider.py
|
||||||
|
@date:2024/3/28 16:26
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from django.utils.translation import gettext as _
|
||||||
|
|
||||||
|
from common.utils.common import get_file_content
|
||||||
|
from maxkb.conf import PROJECT_DIR
|
||||||
|
from models_provider.base_model_provider import ModelInfo, ModelTypeConst, ModelInfoManage, IModelProvider, \
|
||||||
|
ModelProvideInfo
|
||||||
|
from models_provider.impl.regolo_model_provider.credential.embedding import RegoloEmbeddingCredential
|
||||||
|
from models_provider.impl.regolo_model_provider.credential.llm import RegoloLLMModelCredential
|
||||||
|
from models_provider.impl.regolo_model_provider.credential.tti import RegoloTextToImageModelCredential
|
||||||
|
from models_provider.impl.regolo_model_provider.model.embedding import RegoloEmbeddingModel
|
||||||
|
from models_provider.impl.regolo_model_provider.model.llm import RegoloChatModel
|
||||||
|
from models_provider.impl.regolo_model_provider.model.tti import RegoloTextToImage
|
||||||
|
|
||||||
|
openai_llm_model_credential = RegoloLLMModelCredential()
|
||||||
|
openai_tti_model_credential = RegoloTextToImageModelCredential()
|
||||||
|
model_info_list = [
|
||||||
|
ModelInfo('Phi-4', '', ModelTypeConst.LLM,
|
||||||
|
openai_llm_model_credential, RegoloChatModel
|
||||||
|
),
|
||||||
|
ModelInfo('DeepSeek-R1-Distill-Qwen-32B', '', ModelTypeConst.LLM,
|
||||||
|
openai_llm_model_credential,
|
||||||
|
RegoloChatModel),
|
||||||
|
ModelInfo('maestrale-chat-v0.4-beta', '',
|
||||||
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
|
RegoloChatModel),
|
||||||
|
ModelInfo('Llama-3.3-70B-Instruct',
|
||||||
|
'',
|
||||||
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
|
RegoloChatModel),
|
||||||
|
ModelInfo('Llama-3.1-8B-Instruct',
|
||||||
|
'',
|
||||||
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
|
RegoloChatModel),
|
||||||
|
ModelInfo('DeepSeek-Coder-6.7B-Instruct', '',
|
||||||
|
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||||
|
RegoloChatModel)
|
||||||
|
]
|
||||||
|
open_ai_embedding_credential = RegoloEmbeddingCredential()
|
||||||
|
model_info_embedding_list = [
|
||||||
|
ModelInfo('gte-Qwen2', '',
|
||||||
|
ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
|
||||||
|
RegoloEmbeddingModel),
|
||||||
|
]
|
||||||
|
|
||||||
|
model_info_tti_list = [
|
||||||
|
ModelInfo('FLUX.1-dev', '',
|
||||||
|
ModelTypeConst.TTI, openai_tti_model_credential,
|
||||||
|
RegoloTextToImage),
|
||||||
|
ModelInfo('sdxl-turbo', '',
|
||||||
|
ModelTypeConst.TTI, openai_tti_model_credential,
|
||||||
|
RegoloTextToImage),
|
||||||
|
]
|
||||||
|
model_info_manage = (
|
||||||
|
ModelInfoManage.builder()
|
||||||
|
.append_model_info_list(model_info_list)
|
||||||
|
.append_default_model_info(
|
||||||
|
ModelInfo('gpt-3.5-turbo', _('The latest gpt-3.5-turbo, updated with OpenAI adjustments'), ModelTypeConst.LLM,
|
||||||
|
openai_llm_model_credential, RegoloChatModel
|
||||||
|
))
|
||||||
|
.append_model_info_list(model_info_embedding_list)
|
||||||
|
.append_default_model_info(model_info_embedding_list[0])
|
||||||
|
.append_model_info_list(model_info_tti_list)
|
||||||
|
.append_default_model_info(model_info_tti_list[0])
|
||||||
|
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RegoloModelProvider(IModelProvider):
|
||||||
|
|
||||||
|
def get_model_info_manage(self):
|
||||||
|
return model_info_manage
|
||||||
|
|
||||||
|
def get_model_provide_info(self):
|
||||||
|
return ModelProvideInfo(provider='model_regolo_provider', name='Regolo', icon=get_file_content(
|
||||||
|
os.path.join(PROJECT_DIR, "apps", 'models_provider', 'impl', 'regolo_model_provider',
|
||||||
|
'icon',
|
||||||
|
'regolo_icon_svg')))
|
||||||
@ -14,6 +14,25 @@ from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
|||||||
from django.utils.translation import gettext_lazy as _, gettext
|
from django.utils.translation import gettext_lazy as _, gettext
|
||||||
|
|
||||||
|
|
||||||
|
class SiliconCloudImageModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=0.7,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class SiliconCloudImageModelCredential(BaseForm, BaseModelCredential):
|
class SiliconCloudImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_base = forms.TextInputField('API URL', required=True)
|
api_base = forms.TextInputField('API URL', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -53,4 +72,4 @@ class SiliconCloudImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
return SiliconCloudImageModelParams()
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
"""
|
"""
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
|
import requests
|
||||||
from langchain_community.embeddings import OpenAIEmbeddings
|
from langchain_community.embeddings import OpenAIEmbeddings
|
||||||
|
|
||||||
from models_provider.base_model_provider import MaxKBBaseModel
|
from models_provider.base_model_provider import MaxKBBaseModel
|
||||||
@ -17,7 +18,26 @@ class SiliconCloudEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||||
return SiliconCloudEmbeddingModel(
|
return SiliconCloudEmbeddingModel(
|
||||||
api_key=model_credential.get('api_key'),
|
openai_api_key=model_credential.get('api_key'),
|
||||||
model=model_name,
|
model=model_name,
|
||||||
openai_api_base=model_credential.get('api_base'),
|
openai_api_base=model_credential.get('api_base'),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def embed_query(self, text: str) -> list:
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"input": text
|
||||||
|
}
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.openai_api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(self.openai_api_base + '/embeddings', json=payload, headers=headers)
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
# 假设返回结构中有 'data[0].embedding'
|
||||||
|
return data["data"][0]["embedding"]
|
||||||
|
|
||||||
|
def embed_documents(self, texts: list) -> list:
|
||||||
|
return [self.embed_query(text) for text in texts]
|
||||||
|
|||||||
@ -19,6 +19,25 @@ from common.utils.logger import maxkb_logger
|
|||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class TencentModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=1.0,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.9,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class TencentVisionModelCredential(BaseForm, BaseModelCredential):
|
class TencentVisionModelCredential(BaseForm, BaseModelCredential):
|
||||||
|
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||||
@ -57,4 +76,4 @@ class TencentVisionModelCredential(BaseForm, BaseModelCredential):
|
|||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
return TencentModelParams()
|
||||||
|
|||||||
@ -11,6 +11,26 @@ from common.forms import BaseForm, TooltipLabel
|
|||||||
from common.utils.logger import maxkb_logger
|
from common.utils.logger import maxkb_logger
|
||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class VllmImageModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=0.7,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class VllmImageModelCredential(BaseForm, BaseModelCredential):
|
class VllmImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_base = forms.TextInputField('API URL', required=True)
|
api_base = forms.TextInputField('API URL', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -50,4 +70,4 @@ class VllmImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
return VllmImageModelParams()
|
||||||
|
|||||||
@ -11,6 +11,26 @@ from common.forms import BaseForm, TooltipLabel
|
|||||||
from common.utils.logger import maxkb_logger
|
from common.utils.logger import maxkb_logger
|
||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class VolcanicEngineImageModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=0.95,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=1024,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential):
|
class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_base = forms.TextInputField('API URL', required=True)
|
api_base = forms.TextInputField('API URL', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -50,4 +70,4 @@ class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
return VolcanicEngineImageModelParams()
|
||||||
|
|||||||
@ -70,11 +70,6 @@ model_info_list = [
|
|||||||
ModelTypeConst.TTI,
|
ModelTypeConst.TTI,
|
||||||
volcanic_engine_tti_model_credential, VolcanicEngineTextToImage
|
volcanic_engine_tti_model_credential, VolcanicEngineTextToImage
|
||||||
),
|
),
|
||||||
ModelInfo('anime_v1.3',
|
|
||||||
_('Animation 1.3.0-Vincent Picture'),
|
|
||||||
ModelTypeConst.TTI,
|
|
||||||
volcanic_engine_tti_model_credential, VolcanicEngineTextToImage
|
|
||||||
),
|
|
||||||
ModelInfo('anime_v1.3.1',
|
ModelInfo('anime_v1.3.1',
|
||||||
_('Animation 1.3.1-Vincent Picture'),
|
_('Animation 1.3.1-Vincent Picture'),
|
||||||
ModelTypeConst.TTI,
|
ModelTypeConst.TTI,
|
||||||
|
|||||||
@ -11,6 +11,25 @@ from common.utils.logger import maxkb_logger
|
|||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class XinferenceImageModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=0.7,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=800,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class XinferenceImageModelCredential(BaseForm, BaseModelCredential):
|
class XinferenceImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_base = forms.TextInputField('API URL', required=True)
|
api_base = forms.TextInputField('API URL', required=True)
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
@ -49,4 +68,4 @@ class XinferenceImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
return XinferenceImageModelParams()
|
||||||
|
|||||||
@ -11,6 +11,26 @@ from common.forms import BaseForm, TooltipLabel
|
|||||||
from common.utils.logger import maxkb_logger
|
from common.utils.logger import maxkb_logger
|
||||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||||
|
|
||||||
|
|
||||||
|
class ZhiPuImageModelParams(BaseForm):
|
||||||
|
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||||
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')),
|
||||||
|
required=True, default_value=0.95,
|
||||||
|
_min=0.1,
|
||||||
|
_max=1.0,
|
||||||
|
_step=0.01,
|
||||||
|
precision=2)
|
||||||
|
|
||||||
|
max_tokens = forms.SliderField(
|
||||||
|
TooltipLabel(_('Output the maximum Tokens'),
|
||||||
|
_('Specify the maximum number of tokens that the model can generate')),
|
||||||
|
required=True, default_value=1024,
|
||||||
|
_min=1,
|
||||||
|
_max=100000,
|
||||||
|
_step=1,
|
||||||
|
precision=0)
|
||||||
|
|
||||||
|
|
||||||
class ZhiPuImageModelCredential(BaseForm, BaseModelCredential):
|
class ZhiPuImageModelCredential(BaseForm, BaseModelCredential):
|
||||||
api_key = forms.PasswordInputField('API Key', required=True)
|
api_key = forms.PasswordInputField('API Key', required=True)
|
||||||
|
|
||||||
@ -49,4 +69,4 @@ class ZhiPuImageModelCredential(BaseForm, BaseModelCredential):
|
|||||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||||
|
|
||||||
def get_model_params_setting_form(self, model_name):
|
def get_model_params_setting_form(self, model_name):
|
||||||
pass
|
return ZhiPuImageModelParams()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user