feat: Create model and configure advanced parameters

This commit is contained in:
CaptainB 2024-12-10 17:23:34 +08:00 committed by 刘瑞斌
parent 8b33c99235
commit 4d977fd765
15 changed files with 384 additions and 193 deletions

View File

@ -21,6 +21,8 @@ class ImageGenerateNodeSerializer(serializers.Serializer):
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
model_params_setting = serializers.JSONField(required=False, default=dict, error_messages=ErrMessage.json("模型参数设置"))
class IImageGenerateNode(INode): class IImageGenerateNode(INode):
type = 'image-generate-node' type = 'image-generate-node'
@ -32,6 +34,7 @@ class IImageGenerateNode(INode):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id, def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
model_params_setting,
chat_record_id, chat_record_id,
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
pass pass

View File

@ -2,10 +2,13 @@
from functools import reduce from functools import reduce
from typing import List from typing import List
import requests
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
from application.flow.i_step_node import NodeResult from application.flow.i_step_node import NodeResult
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
from common.util.common import bytes_to_uploaded_file
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.tools import get_model_instance_by_model_user_id from setting.models_provider.tools import get_model_instance_by_model_user_id
@ -16,10 +19,12 @@ class BaseImageGenerateNode(IImageGenerateNode):
self.answer_text = details.get('answer') self.answer_text = details.get('answer')
def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id, def execute(self, model_id, prompt, negative_prompt, dialogue_number, dialogue_type, history_chat_record, chat_id,
model_params_setting,
chat_record_id, chat_record_id,
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
print(model_params_setting)
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id')) application = self.workflow_manage.work_flow_post_handler.chat_info.application
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), **model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number) history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt) question = self.generate_prompt_question(prompt)
@ -28,10 +33,21 @@ class BaseImageGenerateNode(IImageGenerateNode):
self.context['message_list'] = message_list self.context['message_list'] = message_list
self.context['dialogue_type'] = dialogue_type self.context['dialogue_type'] = dialogue_type
print(message_list) print(message_list)
print(negative_prompt)
image_urls = tti_model.generate_image(question, negative_prompt) image_urls = tti_model.generate_image(question, negative_prompt)
self.context['image_list'] = image_urls # 保存图片
answer = '\n'.join([f"![Image]({path})" for path in image_urls]) file_urls = []
for image_url in image_urls:
file_name = 'generated_image.png'
file = bytes_to_uploaded_file(requests.get(image_url).content, file_name)
meta = {
'debug': False if application.id else True,
'chat_id': chat_id,
'application_id': str(application.id) if application.id else None,
}
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
file_urls.append(file_url)
self.context['image_list'] = file_urls
answer = '\n'.join([f"![Image]({path})" for path in file_urls])
return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list, return NodeResult({'answer': answer, 'chat_model': tti_model, 'message_list': message_list,
'image': [{'file_id': path.split('/')[-1], 'file_url': path} for path in file_urls], 'image': [{'file_id': path.split('/')[-1], 'file_url': path} for path in file_urls],
'history_message': history_message, 'question': question}, {}) 'history_message': history_message, 'question': question}, {})

View File

@ -10,14 +10,32 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAITTIModelParams(BaseForm):
size = forms.TextInputField(
TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
required=True, default_value='1024x1024')
quality = forms.TextInputField( class OpenAITTIModelParams(BaseForm):
size = forms.SingleSelect(
TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
required=True,
default_value='1024x1024',
option_list=[
{'value': '1024x1024', 'label': '1024x1024'},
{'value': '1024x1792', 'label': '1024x1792'},
{'value': '1792x1024', 'label': '1792x1024'},
],
text_field='label',
value_field='value'
)
quality = forms.SingleSelect(
TooltipLabel('图片质量', ''), TooltipLabel('图片质量', ''),
required=True, default_value='standard') required=True,
default_value='standard',
option_list=[
{'value': 'standard', 'label': 'standard'},
{'value': 'hd', 'label': 'hd'},
],
text_field='label',
value_field='value'
)
n = forms.SliderField( n = forms.SliderField(
TooltipLabel('图片数量', '指定生成图片的数量'), TooltipLabel('图片数量', '指定生成图片的数量'),

View File

@ -1,13 +1,8 @@
from typing import Dict from typing import Dict
import requests
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from openai import OpenAI from openai import OpenAI
from common.config.tokenizer_manage_config import TokenizerManage from common.config.tokenizer_manage_config import TokenizerManage
from common.util.common import bytes_to_uploaded_file
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage from setting.models_provider.impl.base_tti import BaseTextToImage
@ -32,7 +27,7 @@ class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {}} optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}}
for key, value in model_kwargs.items(): for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']: if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value optional_params['params'][key] = value
@ -43,6 +38,9 @@ class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
**optional_params, **optional_params,
) )
def is_cache_model(self):
return False
def check_auth(self): def check_auth(self):
chat = OpenAI(api_key=self.api_key, base_url=self.api_base) chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
response_list = chat.models.with_raw_response.list() response_list = chat.models.with_raw_response.list()
@ -50,18 +48,11 @@ class OpenAITextToImage(MaxKBBaseModel, BaseTextToImage):
# self.generate_image('生成一个小猫图片') # self.generate_image('生成一个小猫图片')
def generate_image(self, prompt: str, negative_prompt: str = None): def generate_image(self, prompt: str, negative_prompt: str = None):
chat = OpenAI(api_key=self.api_key, base_url=self.api_base) chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
res = chat.images.generate(model=self.model, prompt=prompt, **self.params) res = chat.images.generate(model=self.model, prompt=prompt, **self.params)
file_urls = [] file_urls = []
for content in res.data: for content in res.data:
url = content.url url = content.url
print(url) file_urls.append(url)
file_name = 'generated_image.png'
file = bytes_to_uploaded_file(requests.get(url).content, file_name)
meta = {'debug': True}
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
file_urls.append(file_url)
return file_urls return file_urls

View File

@ -19,9 +19,18 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
class QwenModelParams(BaseForm): class QwenModelParams(BaseForm):
size = forms.TextInputField( size = forms.SingleSelect(
TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'), TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
required=True, default_value='1024x1024') required=True,
default_value='1024*1024',
option_list=[
{'value': '1024*1024', 'label': '1024*1024'},
{'value': '720*1280', 'label': '720*1280'},
{'value': '768*1152', 'label': '768*1152'},
{'value': '1280*720', 'label': '1280*720'},
],
text_field='label',
value_field='value')
n = forms.SliderField( n = forms.SliderField(
TooltipLabel('图片数量', '指定生成图片的数量'), TooltipLabel('图片数量', '指定生成图片的数量'),
required=True, default_value=1, required=True, default_value=1,
@ -29,9 +38,25 @@ class QwenModelParams(BaseForm):
_max=4, _max=4,
_step=1, _step=1,
precision=0) precision=0)
style = forms.TextInputField( style = forms.SingleSelect(
TooltipLabel('风格', '指定生成图片的风格'), TooltipLabel('风格', '指定生成图片的风格'),
required=True, default_value='<auto>') required=True,
default_value='<auto>',
option_list=[
{'value': '<auto>', 'label': '默认值,由模型随机输出图像风格'},
{'value': '<photography>', 'label': '摄影'},
{'value': '<portrait>', 'label': '人像写真'},
{'value': '<3d cartoon>', 'label': '3D卡通'},
{'value': '<anime>', 'label': '动画'},
{'value': '<oil painting>', 'label': '油画'},
{'value': '<watercolor>', 'label': '水彩'},
{'value': '<sketch>', 'label': '素描'},
{'value': '<chinese painting>', 'label': '中国画'},
{'value': '<flat illustration>', 'label': '扁平插画'},
],
text_field='label',
value_field='value'
)
class QwenTextToImageModelCredential(BaseForm, BaseModelCredential): class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):

View File

@ -1,16 +1,11 @@
# coding=utf-8 # coding=utf-8
from http import HTTPStatus from http import HTTPStatus
from pathlib import PurePosixPath
from typing import Dict from typing import Dict
from urllib.parse import unquote, urlparse
import requests
from dashscope import ImageSynthesis from dashscope import ImageSynthesis
from langchain_community.chat_models import ChatTongyi from langchain_community.chat_models import ChatTongyi
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from common.util.common import bytes_to_uploaded_file
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage from setting.models_provider.impl.base_tti import BaseTextToImage
@ -28,7 +23,7 @@ class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {}} optional_params = {'params': {'size': '1024*1024', 'style': '<auto>', 'n': 1}}
for key, value in model_kwargs.items(): for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']: if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value optional_params['params'][key] = value
@ -39,6 +34,9 @@ class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
) )
return chat_tong_yi return chat_tong_yi
def is_cache_model(self):
return False
def check_auth(self): def check_auth(self):
chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max') chat = ChatTongyi(api_key=self.api_key, model_name='qwen-max')
chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])]) chat.invoke([HumanMessage([{"type": "text", "text": "你好"}])])
@ -53,11 +51,7 @@ class QwenTextToImageModel(MaxKBBaseModel, BaseTextToImage):
file_urls = [] file_urls = []
if rsp.status_code == HTTPStatus.OK: if rsp.status_code == HTTPStatus.OK:
for result in rsp.output.results: for result in rsp.output.results:
file_name = PurePosixPath(unquote(urlparse(result.url).path)).parts[-1] file_urls.append(result.url)
file = bytes_to_uploaded_file(requests.get(result.url).content, file_name)
meta = {'debug': True}
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
file_urls.append(file_url)
else: else:
print('sync_call Failed, status_code: %s, code: %s, message: %s' % print('sync_call Failed, status_code: %s, code: %s, message: %s' %
(rsp.status_code, rsp.code, rsp.message)) (rsp.status_code, rsp.code, rsp.message))

View File

@ -3,15 +3,12 @@
import json import json
from typing import Dict from typing import Dict
import requests
from tencentcloud.common import credential from tencentcloud.common import credential
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.common.profile.client_profile import ClientProfile from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
from common.util.common import bytes_to_uploaded_file
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage from setting.models_provider.impl.base_tti import BaseTextToImage
from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan from setting.models_provider.impl.tencent_model_provider.model.hunyuan import ChatHunyuan
@ -87,12 +84,8 @@ class TencentTextToImageModel(MaxKBBaseModel, BaseTextToImage):
# 输出json格式的字符串回包 # 输出json格式的字符串回包
print(resp.to_json_string()) print(resp.to_json_string())
file_urls = [] file_urls = []
file_name = 'generated_image.png'
file = bytes_to_uploaded_file(requests.get(resp.ResultImage).content, file_name) file_urls.append(resp.ResultImage)
meta = {'debug': True}
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
file_urls.append(file_url)
return file_urls return file_urls
except TencentCloudSDKException as err: except TencentCloudSDKException as err:
print(err) print(err)

View File

@ -8,10 +8,22 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val
class ZhiPuTTIModelParams(BaseForm): class ZhiPuTTIModelParams(BaseForm):
size = forms.TextInputField( size = forms.SingleSelect(
TooltipLabel('图片尺寸', TooltipLabel('图片尺寸',
'图片尺寸,仅 cogview-3-plus 支持该参数。可选范围:[1024x1024,768x1344,864x1152,1344x768,1152x864,1440x720,720x1440]默认是1024x1024。'), '图片尺寸,仅 cogview-3-plus 支持该参数。可选范围:[1024x1024,768x1344,864x1152,1344x768,1152x864,1440x720,720x1440]默认是1024x1024。'),
required=True, default_value='1024x1024') required=True,
default_value='1024x1024',
option_list=[
{'value': '1024x1024', 'label': '1024x1024'},
{'value': '768x1344', 'label': '768x1344'},
{'value': '864x1152', 'label': '864x1152'},
{'value': '1344x768', 'label': '1344x768'},
{'value': '1152x864', 'label': '1152x864'},
{'value': '1440x720', 'label': '1440x720'},
{'value': '720x1440', 'label': '720x1440'},
],
text_field='label',
value_field='value')
class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential): class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):

View File

@ -1,13 +1,10 @@
from typing import Dict from typing import Dict
import requests
from langchain_community.chat_models import ChatZhipuAI from langchain_community.chat_models import ChatZhipuAI
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from zhipuai import ZhipuAI from zhipuai import ZhipuAI
from common.config.tokenizer_manage_config import TokenizerManage from common.config.tokenizer_manage_config import TokenizerManage
from common.util.common import bytes_to_uploaded_file
from dataset.serializers.file_serializers import FileSerializer
from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage from setting.models_provider.impl.base_tti import BaseTextToImage
@ -30,7 +27,7 @@ class ZhiPuTextToImage(MaxKBBaseModel, BaseTextToImage):
@staticmethod @staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {}} optional_params = {'params': {'size': '1024x1024'}}
for key, value in model_kwargs.items(): for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']: if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value optional_params['params'][key] = value
@ -40,6 +37,9 @@ class ZhiPuTextToImage(MaxKBBaseModel, BaseTextToImage):
**optional_params, **optional_params,
) )
def is_cache_model(self):
return False
def check_auth(self): def check_auth(self):
chat = ChatZhipuAI( chat = ChatZhipuAI(
zhipuai_api_key=self.api_key, zhipuai_api_key=self.api_key,
@ -58,16 +58,11 @@ class ZhiPuTextToImage(MaxKBBaseModel, BaseTextToImage):
response = chat.images.generations( response = chat.images.generations(
model=self.model, # 填写需要调用的模型编码 model=self.model, # 填写需要调用的模型编码
prompt=prompt, # 填写需要生成图片的文本 prompt=prompt, # 填写需要生成图片的文本
**self.params # 填写额外参数 **self.params # 填写额外参数
) )
file_urls = [] file_urls = []
for content in response.data: for content in response.data:
url = content['url'] url = content.url
print(url) file_urls.append(url)
file_name = url.split('/')[-1]
file = bytes_to_uploaded_file(requests.get(url).content, file_name)
meta = {'debug': True}
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
file_urls.append(file_url)
return file_urls return file_urls

View File

@ -30,8 +30,10 @@ from setting.models_provider.constants.model_provider_constants import ModelProv
def get_default_model_params_setting(provider, model_type, model_name): def get_default_model_params_setting(provider, model_type, model_name):
credential = get_model_credential(provider, model_type, model_name) credential = get_model_credential(provider, model_type, model_name)
model_params_setting = credential.get_model_params_setting_form(model_name).to_form_list() setting_form = credential.get_model_params_setting_form(model_name)
return model_params_setting if setting_form is not None:
return setting_form.to_form_list()
return []
class ModelPullManage: class ModelPullManage:
@ -178,6 +180,8 @@ class ModelSerializer(serializers.Serializer):
model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型")) model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型"))
model_params_form = serializers.ListField(required=False, default=list, error_messages=ErrMessage.char("参数配置"))
credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息")) credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息"))
def is_valid(self, *, raise_exception=False): def is_valid(self, *, raise_exception=False):
@ -207,11 +211,12 @@ class ModelSerializer(serializers.Serializer):
model_type = self.data.get('model_type') model_type = self.data.get('model_type')
model_name = self.data.get('model_name') model_name = self.data.get('model_name')
permission_type = self.data.get('permission_type') permission_type = self.data.get('permission_type')
model_params_form = self.data.get('model_params_form')
model_credential_str = json.dumps(credential) model_credential_str = json.dumps(credential)
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
credential=rsa_long_encrypt(model_credential_str), credential=rsa_long_encrypt(model_credential_str),
provider=provider, model_type=model_type, model_name=model_name, provider=provider, model_type=model_type, model_name=model_name,
model_params_form=get_default_model_params_setting(provider, model_type, model_name), model_params_form=model_params_form,
permission_type=permission_type) permission_type=permission_type)
model.save() model.save()
if status == Status.DOWNLOAD: if status == Status.DOWNLOAD:

View File

@ -14,6 +14,8 @@ urlpatterns = [
path('provider/model_type_list', views.Provide.ModelTypeList.as_view(), name="provider/model_type_list"), path('provider/model_type_list', views.Provide.ModelTypeList.as_view(), name="provider/model_type_list"),
path('provider/model_list', views.Provide.ModelList.as_view(), path('provider/model_list', views.Provide.ModelList.as_view(),
name="provider/model_name_list"), name="provider/model_name_list"),
path('provider/model_params_form', views.Provide.ModelParamsForm.as_view(),
name="provider/model_params_form"),
path('provider/model_form', views.Provide.ModelForm.as_view(), path('provider/model_form', views.Provide.ModelForm.as_view(),
name="provider/model_form"), name="provider/model_form"),
path('model', views.Model.as_view(), name='model'), path('model', views.Model.as_view(), name='model'),

View File

@ -16,7 +16,7 @@ from common.constants.permission_constants import PermissionConstants
from common.response import result from common.response import result
from common.util.common import query_params_to_single_dict from common.util.common import query_params_to_single_dict
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer, get_default_model_params_setting
from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi, ModelEditApi from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi, ModelEditApi
@ -207,6 +207,24 @@ class Provide(APIView):
ModelProvideConstants[provider].value.get_model_list( ModelProvideConstants[provider].value.get_model_list(
model_type)) model_type))
class ModelParamsForm(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取模型默认参数",
operation_id="获取模型创建表单",
manual_parameters=ProvideApi.ModelList.get_request_params_api(),
responses=result.get_api_array_response(ProvideApi.ModelList.get_response_body_api())
, tags=["模型"]
)
@has_permissions(PermissionConstants.MODEL_READ)
def get(self, request: Request):
provider = request.query_params.get('provider')
model_type = request.query_params.get('model_type')
model_name = request.query_params.get('model_name')
return result.success(get_default_model_params_setting(provider, model_type, model_name))
class ModelForm(APIView): class ModelForm(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]

View File

@ -98,6 +98,15 @@ const listBaseModel: (
return get(`${prefix_provider}/model_list`, { provider, model_type }, loading) return get(`${prefix_provider}/model_list`, { provider, model_type }, loading)
} }
const listBaseModelParamsForm: (
provider: string,
model_type: string,
model_name: string,
loading?: Ref<boolean>
) => Promise<Result<Array<BaseModel>>> = (provider, model_type, model_name, loading) => {
return get(`${prefix_provider}/model_params_form`, { provider, model_type, model_name}, loading)
}
/** /**
* *
* @param request * @param request
@ -187,6 +196,7 @@ export default {
getModelCreateForm, getModelCreateForm,
listModelType, listModelType,
listBaseModel, listBaseModel,
listBaseModelParamsForm,
createModel, createModel,
updateModel, updateModel,
deleteModel, deleteModel,

View File

@ -22,136 +22,192 @@
> >
</el-breadcrumb> </el-breadcrumb>
</template> </template>
<el-tabs v-model="activeName">
<DynamicsForm <el-tab-pane label="基础信息" name="base-info">
v-model="form_data" <DynamicsForm
:render_data="model_form_field" v-model="form_data"
:model="form_data" :render_data="model_form_field"
ref="dynamicsFormRef" :model="form_data"
label-position="top" ref="dynamicsFormRef"
require-asterisk-position="right" label-position="top"
class="mb-24" require-asterisk-position="right"
label-width="auto" class="mb-24"
> label-width="auto"
<template #default> >
<el-form-item prop="name" :rules="base_form_data_rule.name"> <template #default>
<template #label> <el-form-item prop="name" :rules="base_form_data_rule.name">
<div class="flex align-center" style="display: inline-flex"> <template #label>
<div class="mr-4">
<span>模型名称 </span>
</div>
<el-tooltip effect="dark" placement="right">
<template #content>
<p>MaxKB 中自定义的模型名称</p>
</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-input
v-model="base_form_data.name"
maxlength="64"
show-word-limit
placeholder="请给基础模型设置一个名称"
/>
</el-form-item>
<el-form-item prop="permission_type" :rules="base_form_data_rule.permission_type">
<template #label>
<span>权限</span>
</template>
<el-radio-group v-model="base_form_data.permission_type" class="card__radio">
<el-row :gutter="16">
<template v-for="(value, key) of PermissionType" :key="key">
<el-col :span="12">
<el-card
shadow="never"
class="mb-16"
:class="base_form_data.permission_type === key ? 'active' : ''"
>
<el-radio :value="key" size="large">
<p class="mb-4">{{ value }}</p>
<el-text type="info">
{{ PermissionDesc[key] }}
</el-text>
</el-radio>
</el-card>
</el-col>
</template>
</el-row>
</el-radio-group>
</el-form-item>
<el-form-item prop="model_type" :rules="base_form_data_rule.model_type">
<template #label>
<div class="flex align-center" style="display: inline-flex">
<span class="mr-4">模型类型 </span>
<el-tooltip effect="dark" placement="right">
<template #content>
<p>大语言模型在应用中与AI对话的推理模型</p>
<p>向量模型在知识库中导入文档进行向量化和向量检索召回分段时使用的向量模型</p>
<p>
重排模型在二次召回中根据召回的候选分段和用户问题的匹配度重新排序从而得到更精确的结果
</p>
<p>语音识别在应用中开启语音识别后用于语音转文字的模型</p>
<p>语音合成在应用中开启语音播放后用于文字转语音的模型</p>
</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-select
v-loading="model_type_loading"
@change="list_base_model($event, true)"
v-model="base_form_data.model_type"
class="w-full m-2"
placeholder="请选择模型类型"
>
<el-option
v-for="item in model_type_list"
:key="item.value"
:label="item.key"
:value="item.value"
></el-option>
</el-select>
</el-form-item>
<el-form-item prop="model_name" :rules="base_form_data_rule.model_name">
<template #label>
<div class="flex align-center" style="display: inline-flex">
<div class="mr-4">
<span>基础模型 </span>
<span class="danger">列表中未列出的模型直接输入模型名称回车即可添加</span>
</div>
</div>
</template>
<el-select
@change="getModelForm($event)"
v-loading="base_model_loading"
v-model="base_form_data.model_name"
class="w-full m-2"
placeholder="自定义输入基础模型后回车即可"
filterable
allow-create
default-first-option
>
<el-option v-for="item in base_model_list" :key="item.name" :value="item.name">
<template #default>
<div class="flex align-center" style="display: inline-flex"> <div class="flex align-center" style="display: inline-flex">
<div class="flex-between mr-4"> <div class="mr-4">
<span>{{ item.name }} </span> <span>模型名称 </span>
</div> </div>
<el-tooltip effect="dark" placement="right" v-if="item.desc"> <el-tooltip effect="dark" placement="right">
<template #content> <template #content>
<p class="w-280">{{ item.desc }}</p> <p>MaxKB 中自定义的模型名称</p>
</template> </template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon> <AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip> </el-tooltip>
</div> </div>
</template> </template>
</el-option> <el-input
</el-select> v-model="base_form_data.name"
</el-form-item> maxlength="64"
</template> show-word-limit
</DynamicsForm> placeholder="请给基础模型设置一个名称"
/>
</el-form-item>
<el-form-item prop="permission_type" :rules="base_form_data_rule.permission_type">
<template #label>
<span>权限</span>
</template>
<el-radio-group v-model="base_form_data.permission_type" class="card__radio">
<el-row :gutter="16">
<template v-for="(value, key) of PermissionType" :key="key">
<el-col :span="12">
<el-card
shadow="never"
class="mb-16"
:class="base_form_data.permission_type === key ? 'active' : ''"
>
<el-radio :value="key" size="large">
<p class="mb-4">{{ value }}</p>
<el-text type="info">
{{ PermissionDesc[key] }}
</el-text>
</el-radio>
</el-card>
</el-col>
</template>
</el-row>
</el-radio-group>
</el-form-item>
<el-form-item prop="model_type" :rules="base_form_data_rule.model_type">
<template #label>
<div class="flex align-center" style="display: inline-flex">
<span class="mr-4">模型类型 </span>
<el-tooltip effect="dark" placement="right">
<template #content>
<p>大语言模型在应用中与AI对话的推理模型</p>
<p>向量模型在知识库中导入文档进行向量化和向量检索召回分段时使用的向量模型</p>
<p>
重排模型在二次召回中根据召回的候选分段和用户问题的匹配度重新排序从而得到更精确的结果
</p>
<p>语音识别在应用中开启语音识别后用于语音转文字的模型</p>
<p>语音合成在应用中开启语音播放后用于文字转语音的模型</p>
</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-select
v-loading="model_type_loading"
@change="list_base_model($event, true)"
v-model="base_form_data.model_type"
class="w-full m-2"
placeholder="请选择模型类型"
>
<el-option
v-for="item in model_type_list"
:key="item.value"
:label="item.key"
:value="item.value"
></el-option>
</el-select>
</el-form-item>
<el-form-item prop="model_name" :rules="base_form_data_rule.model_name">
<template #label>
<div class="flex align-center" style="display: inline-flex">
<div class="mr-4">
<span>基础模型 </span>
<span class="danger">列表中未列出的模型直接输入模型名称回车即可添加</span>
</div>
</div>
</template>
<el-select
@change="getModelForm($event)"
v-loading="base_model_loading"
v-model="base_form_data.model_name"
class="w-full m-2"
placeholder="自定义输入基础模型后回车即可"
filterable
allow-create
default-first-option
>
<el-option v-for="item in base_model_list" :key="item.name" :value="item.name">
<template #default>
<div class="flex align-center" style="display: inline-flex">
<div class="flex-between mr-4">
<span>{{ item.name }} </span>
</div>
<el-tooltip effect="dark" placement="right" v-if="item.desc">
<template #content>
<p class="w-280">{{ item.desc }}</p>
</template>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
</el-option>
</el-select>
</el-form-item>
</template>
</DynamicsForm>
</el-tab-pane>
<el-tab-pane label="高级设置" name="advanced-info">
<div class="flex-between">
<h5>模型参数</h5>
<el-button type="text" @click="openAddDrawer()" :disabled="form_data.model_type !== 'LLM' && form_data.model_type !== 'IMAGE' && form_data.model_type !== 'TTS' && form_data.model_type !== 'TTI'">
<AppIcon iconName="Plus" class="add-icon" />添加
</el-button>
</div>
<el-table :data="base_form_data.model_params_form" v-if="base_form_data.model_params_form?.length > 0" class="mb-16">
<el-table-column prop="label" label="显示名称" show-overflow-tooltip>
<template #default="{ row }">
<span v-if="row.label && row.label.input_type === 'TooltipLabel'">{{
row.label.label
}}</span>
<span v-else>{{ row.label }}</span>
</template>
</el-table-column>
<el-table-column prop="field" label="参数" show-overflow-tooltip />
<el-table-column label="组件类型" width="110px">
<template #default="{ row }">
<el-tag type="info" class="info-tag">{{
input_type_list.find((item) => item.value === row.input_type)?.label
}}</el-tag>
</template>
</el-table-column>
<el-table-column prop="default_value" label="默认值" show-overflow-tooltip />
<el-table-column label="必填">
<template #default="{ row }">
<div @click.stop>
<el-switch disabled size="small" v-model="row.required" />
</div>
</template>
</el-table-column>
<el-table-column label="操作" align="left" width="80">
<template #default="{ row, $index }">
<span class="mr-4">
<el-tooltip effect="dark" content="修改" placement="top">
<el-button type="primary" text @click.stop="openAddDrawer(row, $index)">
<el-icon><EditPen /></el-icon>
</el-button>
</el-tooltip>
</span>
<el-tooltip effect="dark" content="删除" placement="top">
<el-button type="primary" text @click="deleteParam($index)">
<el-icon>
<Delete />
</el-icon>
</el-button>
</el-tooltip>
</template>
</el-table-column>
</el-table>
</el-tab-pane>
</el-tabs>
<template #footer> <template #footer>
<span class="dialog-footer"> <span class="dialog-footer">
<el-button @click="close">取消</el-button> <el-button @click="close">取消</el-button>
@ -159,6 +215,8 @@
</span> </span>
</template> </template>
</el-dialog> </el-dialog>
<AddParamDrawer ref="AddParamRef" @refresh="refresh" />
</template> </template>
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed } from 'vue' import { ref, computed } from 'vue'
@ -168,8 +226,10 @@ import ModelApi from '@/api/model'
import type { FormField } from '@/components/dynamics-form/type' import type { FormField } from '@/components/dynamics-form/type'
import DynamicsForm from '@/components/dynamics-form/index.vue' import DynamicsForm from '@/components/dynamics-form/index.vue'
import type { FormRules } from 'element-plus' import type { FormRules } from 'element-plus'
import { MsgSuccess, MsgWarning } from '@/utils/message' import { MsgError, MsgSuccess, MsgWarning } from '@/utils/message'
import { PermissionType, PermissionDesc } from '@/enums/model' import { PermissionType, PermissionDesc } from '@/enums/model'
import { input_type_list } from '@/components/dynamics-form/constructor/data'
import AddParamDrawer from '@/views/template/component/AddParamDrawer.vue'
const providerValue = ref<Provider>() const providerValue = ref<Provider>()
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>() const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
@ -182,6 +242,9 @@ const model_type_list = ref<Array<KeyValue<string, string>>>([])
const base_model_list = ref<Array<BaseModel>>() const base_model_list = ref<Array<BaseModel>>()
const model_form_field = ref<Array<FormField>>([]) const model_form_field = ref<Array<FormField>>([])
const dialogVisible = ref<boolean>(false) const dialogVisible = ref<boolean>(false)
const activeName = ref('base-info')
const AddParamRef = ref()
const base_form_data_rule = ref<FormRules>({ const base_form_data_rule = ref<FormRules>({
name: { required: true, trigger: 'blur', message: '模型名称不能为空' }, name: { required: true, trigger: 'blur', message: '模型名称不能为空' },
@ -195,7 +258,8 @@ const base_form_data = ref<{
permission_type: string permission_type: string
model_type: string model_type: string
model_name: string model_name: string
}>({ name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' }) model_params_form: any
}>({ name: '', model_type: '', model_name: '', permission_type: 'PRIVATE', model_params_form: [] })
const credential_form_data = ref<Dict<any>>({}) const credential_form_data = ref<Dict<any>>({})
@ -206,7 +270,8 @@ const form_data = computed({
name: base_form_data.value.name, name: base_form_data.value.name,
model_type: base_form_data.value.model_type, model_type: base_form_data.value.model_type,
model_name: base_form_data.value.model_name, model_name: base_form_data.value.model_name,
permission_type: base_form_data.value.permission_type permission_type: base_form_data.value.permission_type,
model_params_form: base_form_data.value.model_params_form
} }
}, },
set: (event: any) => { set: (event: any) => {
@ -230,6 +295,11 @@ const getModelForm = (model_name: string) => {
// //
dynamicsFormRef.value?.render(model_form_field.value, undefined) dynamicsFormRef.value?.render(model_form_field.value, undefined)
}) })
ModelApi.listBaseModelParamsForm(providerValue.value.provider, form_data.value.model_type, model_name, base_model_loading)
.then((ok) => {
base_form_data.value.model_params_form = ok.data
})
} }
} }
@ -255,7 +325,7 @@ const list_base_model = (model_type: any, change?: boolean) => {
} }
const close = () => { const close = () => {
base_form_data.value = { name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' } base_form_data.value = { name: '', model_type: '', model_name: '', permission_type: 'PRIVATE', model_params_form: [] }
credential_form_data.value = {} credential_form_data.value = {}
model_form_field.value = [] model_form_field.value = []
base_model_list.value = [] base_model_list.value = []
@ -279,6 +349,44 @@ const submit = () => {
} }
}) })
} }
function openAddDrawer(data?: any, index?: any) {
AddParamRef.value?.open(data, index)
}
function deleteParam(index: any) {
base_form_data.value.model_params_form.splice(index, 1)
}
function refresh(data: any, index: any) {
for (let i = 0; i < base_form_data.value.model_params_form.length; i++) {
let field = base_form_data.value.model_params_form[i].field
let label = base_form_data.value.model_params_form[i].label
if (label && label.input_type === 'TooltipLabel') {
label = label.label
}
let label2 = data.label
if (label2 && label2.input_type === 'TooltipLabel') {
label2 = label2.label
}
if (field === data.field && index !== i) {
MsgError('变量已存在: ' + data.field)
return
}
if (label === label2 && index !== i) {
MsgError('变量已存在: ' + label)
return
}
}
if (index !== null) {
base_form_data.value.model_params_form.splice(index, 1, data)
} else {
base_form_data.value.model_params_form.push(data)
}
}
const toSelectProvider = () => { const toSelectProvider = () => {
close() close()
emit('change') emit('change')

View File

@ -39,6 +39,7 @@ export const baseNode = {
node_data: { node_data: {
name: '', name: '',
desc: '', desc: '',
// @ts-ignore
prologue: t('views.application.prompt.defaultPrologue') prologue: t('views.application.prompt.defaultPrologue')
}, },
config: {} config: {}