feat: add model setting
This commit is contained in:
parent
9b0f9b04b7
commit
6f6b163416
@ -47,20 +47,20 @@ class ModelManage:
|
|||||||
ModelManage.cache.delete(_id)
|
ModelManage.cache.delete(_id)
|
||||||
|
|
||||||
|
|
||||||
class VectorStore:
|
# class VectorStore:
|
||||||
from embedding.vector.pg_vector import PGVector
|
# from embedding.vector.pg_vector import PGVector
|
||||||
from embedding.vector.base_vector import BaseVectorStore
|
# from embedding.vector.base_vector import BaseVectorStore
|
||||||
instance_map = {
|
# instance_map = {
|
||||||
'pg_vector': PGVector,
|
# 'pg_vector': PGVector,
|
||||||
}
|
# }
|
||||||
instance = None
|
# instance = None
|
||||||
|
#
|
||||||
@staticmethod
|
# @staticmethod
|
||||||
def get_embedding_vector() -> BaseVectorStore:
|
# def get_embedding_vector() -> BaseVectorStore:
|
||||||
from embedding.vector.pg_vector import PGVector
|
# from embedding.vector.pg_vector import PGVector
|
||||||
if VectorStore.instance is None:
|
# if VectorStore.instance is None:
|
||||||
from maxkb.const import CONFIG
|
# from maxkb.const import CONFIG
|
||||||
vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
|
# vector_store_class = VectorStore.instance_map.get(CONFIG.get("VECTOR_STORE_NAME"),
|
||||||
PGVector)
|
# PGVector)
|
||||||
VectorStore.instance = vector_store_class()
|
# VectorStore.instance = vector_store_class()
|
||||||
return VectorStore.instance
|
# return VectorStore.instance
|
||||||
|
|||||||
@ -13,7 +13,8 @@ import io
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
from typing import List
|
from functools import reduce
|
||||||
|
from typing import List, Dict
|
||||||
|
|
||||||
from django.core.files.uploadedfile import InMemoryUploadedFile
|
from django.core.files.uploadedfile import InMemoryUploadedFile
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
@ -50,13 +51,13 @@ def group_by(list_source: List, key):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
CHAR_SET = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
|
CHAR_SET = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
|
||||||
|
|
||||||
|
|
||||||
def get_random_chars(number=6):
|
def get_random_chars(number=6):
|
||||||
return "".join([CHAR_SET[random.randint(0, len(CHAR_SET) - 1)] for index in range(number)])
|
return "".join([CHAR_SET[random.randint(0, len(CHAR_SET) - 1)] for index in range(number)])
|
||||||
|
|
||||||
|
|
||||||
def encryption(message: str):
|
def encryption(message: str):
|
||||||
"""
|
"""
|
||||||
加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890
|
加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890
|
||||||
@ -122,7 +123,6 @@ def get_file_content(path):
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
|
def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
|
||||||
content_type, _ = mimetypes.guess_type(file_name)
|
content_type, _ = mimetypes.guess_type(file_name)
|
||||||
if content_type is None:
|
if content_type is None:
|
||||||
@ -205,3 +205,9 @@ def split_and_transcribe(file_path, model, max_segment_length_ms=59000, audio_fo
|
|||||||
full_text.append(text)
|
full_text.append(text)
|
||||||
return ' '.join(full_text)
|
return ' '.join(full_text)
|
||||||
|
|
||||||
|
|
||||||
|
def query_params_to_single_dict(query_params: Dict):
|
||||||
|
return reduce(lambda x, y: {**x, **y}, list(
|
||||||
|
filter(lambda item: item is not None, [({key: value} if value is not None and len(value) > 0 else None) for
|
||||||
|
key, value in
|
||||||
|
query_params.items()])), {})
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,8 +1,12 @@
|
|||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
|
from drf_spectacular.types import OpenApiTypes
|
||||||
|
from drf_spectacular.utils import OpenApiParameter
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
from common.mixins.api_mixin import APIMixin
|
from common.mixins.api_mixin import APIMixin
|
||||||
from common.result import ResultSerializer
|
from common.result import ResultSerializer
|
||||||
from models_provider.serializers.model import ModelCreateRequest, ModelModelSerializer
|
from models_provider.serializers.model_serializer import ModelModelSerializer, ModelCreateRequest
|
||||||
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
|
|
||||||
class ModelCreateResponse(ResultSerializer):
|
class ModelCreateResponse(ResultSerializer):
|
||||||
@ -10,6 +14,12 @@ class ModelCreateResponse(ResultSerializer):
|
|||||||
return ModelModelSerializer()
|
return ModelModelSerializer()
|
||||||
|
|
||||||
|
|
||||||
|
class ModelListResponse(APIMixin):
|
||||||
|
@staticmethod
|
||||||
|
def get_response():
|
||||||
|
return serializers.ListSerializer(child=ModelModelSerializer())
|
||||||
|
|
||||||
|
|
||||||
class ModelCreateAPI(APIMixin):
|
class ModelCreateAPI(APIMixin):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_request():
|
def get_request():
|
||||||
@ -18,3 +28,47 @@ class ModelCreateAPI(APIMixin):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_response():
|
def get_response():
|
||||||
return ModelCreateResponse
|
return ModelCreateResponse
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_query_params_api(cls):
|
||||||
|
return [OpenApiParameter(
|
||||||
|
name="workspace_id",
|
||||||
|
description=_("workspace id"),
|
||||||
|
type=OpenApiTypes.STR,
|
||||||
|
location=OpenApiParameter.PATH,
|
||||||
|
required=True,
|
||||||
|
)]
|
||||||
|
|
||||||
|
|
||||||
|
class GetModelApi(APIMixin):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_query_params_api():
|
||||||
|
return [OpenApiParameter(
|
||||||
|
name="workspace_id",
|
||||||
|
description=_("workspace id"),
|
||||||
|
type=OpenApiTypes.STR,
|
||||||
|
location=OpenApiParameter.PATH,
|
||||||
|
required=True,
|
||||||
|
), OpenApiParameter(
|
||||||
|
name="model_id",
|
||||||
|
description=_("model id"),
|
||||||
|
type=OpenApiTypes.STR,
|
||||||
|
location=OpenApiParameter.PATH,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_response():
|
||||||
|
return ModelModelSerializer
|
||||||
|
|
||||||
|
|
||||||
|
class ModelEditApi(APIMixin):
|
||||||
|
@staticmethod
|
||||||
|
def get_request():
|
||||||
|
return ModelCreateRequest
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_response():
|
||||||
|
return ModelModelSerializer
|
||||||
|
|||||||
@ -30,29 +30,65 @@ class ModelListSerializer(serializers.Serializer):
|
|||||||
desc = serializers.CharField(required=True, label=_("model name"))
|
desc = serializers.CharField(required=True, label=_("model name"))
|
||||||
|
|
||||||
|
|
||||||
|
class ModelParamsFormSerializer(serializers.Serializer):
|
||||||
|
input_type = serializers.CharField(required=False, label=_("input type"))
|
||||||
|
label = serializers.CharField(required=False, label=_("label"))
|
||||||
|
text_field = serializers.CharField(required=False, label=_("text field"))
|
||||||
|
value_field = serializers.CharField(required=False, label=_("value field"))
|
||||||
|
provider = serializers.CharField(required=False, label=_("provider"))
|
||||||
|
method = serializers.CharField(required=False, label=_("method"))
|
||||||
|
required = serializers.BooleanField(required=False, label=_("required"))
|
||||||
|
default_value = serializers.CharField(required=False, label=_("default value"))
|
||||||
|
relation_show_field_dict = serializers.DictField(required=False, label=_("relation show field dict"))
|
||||||
|
relation_trigger_field_dict = serializers.DictField(required=False, label=_("relation trigger field dict"))
|
||||||
|
trigger_type = serializers.CharField(required=False, label=_("trigger type"))
|
||||||
|
attrs = serializers.DictField(required=False, label=_("attrs"))
|
||||||
|
props_info = serializers.DictField(required=False, label=_("props info"))
|
||||||
|
|
||||||
|
|
||||||
class ProvideApi(APIMixin):
|
class ProvideApi(APIMixin):
|
||||||
|
class ModelParamsForm(APIMixin):
|
||||||
|
@staticmethod
|
||||||
|
def get_query_params_api():
|
||||||
|
return [OpenApiParameter(
|
||||||
|
name="model_type",
|
||||||
|
description=_("model type"),
|
||||||
|
type=OpenApiTypes.STR,
|
||||||
|
location=OpenApiParameter.QUERY,
|
||||||
|
required=True,
|
||||||
|
), OpenApiParameter(
|
||||||
|
name="provider",
|
||||||
|
description=_("provider"),
|
||||||
|
type=OpenApiTypes.STR,
|
||||||
|
location=OpenApiParameter.QUERY,
|
||||||
|
required=True,
|
||||||
|
), OpenApiParameter(
|
||||||
|
name="model_name",
|
||||||
|
description=_("model name"),
|
||||||
|
type=OpenApiTypes.STR,
|
||||||
|
location=OpenApiParameter.QUERY,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_response():
|
||||||
|
return serializers.ListSerializer(child=ModelParamsFormSerializer())
|
||||||
|
|
||||||
class ModelList(APIMixin):
|
class ModelList(APIMixin):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_query_params_api():
|
def get_query_params_api():
|
||||||
return [OpenApiParameter(
|
return [OpenApiParameter(
|
||||||
# 参数的名称是done
|
|
||||||
name="model_type",
|
name="model_type",
|
||||||
# 对参数的备注
|
description=_("model type"),
|
||||||
description="model_type",
|
|
||||||
# 指定参数的类型
|
|
||||||
type=OpenApiTypes.STR,
|
type=OpenApiTypes.STR,
|
||||||
location=OpenApiParameter.QUERY,
|
location=OpenApiParameter.QUERY,
|
||||||
# 指定必须给
|
required=True,
|
||||||
required=False,
|
|
||||||
), OpenApiParameter(
|
), OpenApiParameter(
|
||||||
# 参数的名称是done
|
|
||||||
name="provider",
|
name="provider",
|
||||||
# 对参数的备注
|
description=_("provider"),
|
||||||
description="provider",
|
|
||||||
# 指定参数的类型
|
|
||||||
type=OpenApiTypes.STR,
|
type=OpenApiTypes.STR,
|
||||||
location=OpenApiParameter.QUERY,
|
location=OpenApiParameter.QUERY,
|
||||||
# 指定必须给
|
|
||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
@ -72,7 +108,7 @@ class ProvideApi(APIMixin):
|
|||||||
# 参数的名称是done
|
# 参数的名称是done
|
||||||
name="provider",
|
name="provider",
|
||||||
# 对参数的备注
|
# 对参数的备注
|
||||||
description="provider",
|
description=_("provider"),
|
||||||
# 指定参数的类型
|
# 指定参数的类型
|
||||||
type=OpenApiTypes.STR,
|
type=OpenApiTypes.STR,
|
||||||
location=OpenApiParameter.QUERY,
|
location=OpenApiParameter.QUERY,
|
||||||
|
|||||||
@ -15,7 +15,7 @@ class BaiLianLLMModelParams(BaseForm):
|
|||||||
temperature = forms.SliderField(
|
temperature = forms.SliderField(
|
||||||
TooltipLabel(
|
TooltipLabel(
|
||||||
_('Temperature'),
|
_('Temperature'),
|
||||||
_('Higher values make the output more random, while lower values make it more focused and deterministic.')
|
_('Higher values make the output more random, while lower values make it more focused and deterministic')
|
||||||
),
|
),
|
||||||
required=True,
|
required=True,
|
||||||
default_value=0.7,
|
default_value=0.7,
|
||||||
|
|||||||
@ -1,181 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
import json
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
import uuid_utils.compat as uuid
|
|
||||||
from django.db.models import QuerySet
|
|
||||||
from django.utils.translation import gettext_lazy as _
|
|
||||||
from rest_framework import serializers
|
|
||||||
|
|
||||||
from common.exception.app_exception import AppApiException
|
|
||||||
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
|
|
||||||
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
|
||||||
from models_provider.constants.model_provider_constants import ModelProvideConstants
|
|
||||||
from models_provider.models import Model, Status
|
|
||||||
|
|
||||||
|
|
||||||
class ModelModelSerializer(serializers.ModelSerializer):
|
|
||||||
class Meta:
|
|
||||||
model = Model
|
|
||||||
fields = [
|
|
||||||
'id', 'name', 'status', 'model_type', 'model_name',
|
|
||||||
'user', 'provider', 'credential', 'meta',
|
|
||||||
'model_params_form', 'workspace_id'
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ModelCreateRequest(serializers.Serializer):
|
|
||||||
name = serializers.CharField(required=True, max_length=64, label=_("model name"))
|
|
||||||
provider = serializers.CharField(required=True, label=_("provider"))
|
|
||||||
model_type = serializers.CharField(required=True, label=_("model type"))
|
|
||||||
model_name = serializers.CharField(required=True, label=_("model name"))
|
|
||||||
model_params_form = serializers.ListField(required=False, default=list, label=_("parameter configuration"))
|
|
||||||
credential = serializers.DictField(required=True, label=_("certification information"))
|
|
||||||
|
|
||||||
|
|
||||||
class ModelPullManage:
|
|
||||||
@staticmethod
|
|
||||||
def pull(model: Model, credential: Dict):
|
|
||||||
try:
|
|
||||||
response = ModelProvideConstants[model.provider].value.down_model(
|
|
||||||
model.model_type, model.model_name, credential
|
|
||||||
)
|
|
||||||
down_model_chunk = {}
|
|
||||||
last_update_time = time.time()
|
|
||||||
|
|
||||||
for chunk in response:
|
|
||||||
down_model_chunk[chunk.digest] = chunk.to_dict()
|
|
||||||
if time.time() - last_update_time > 5:
|
|
||||||
current_model = QuerySet(Model).filter(id=model.id).first()
|
|
||||||
if current_model and current_model.status == Status.PAUSE_DOWNLOAD:
|
|
||||||
return
|
|
||||||
QuerySet(Model).filter(id=model.id).update(
|
|
||||||
meta={"down_model_chunk": list(down_model_chunk.values())}
|
|
||||||
)
|
|
||||||
last_update_time = time.time()
|
|
||||||
|
|
||||||
status = Status.ERROR
|
|
||||||
message = ""
|
|
||||||
for chunk in down_model_chunk.values():
|
|
||||||
if chunk.get('status') == DownModelChunkStatus.success.value:
|
|
||||||
status = Status.SUCCESS
|
|
||||||
elif chunk.get('status') == DownModelChunkStatus.error.value:
|
|
||||||
message = chunk.get("digest")
|
|
||||||
|
|
||||||
QuerySet(Model).filter(id=model.id).update(
|
|
||||||
meta={"down_model_chunk": [], "message": message},
|
|
||||||
status=status
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
QuerySet(Model).filter(id=model.id).update(
|
|
||||||
meta={"down_model_chunk": [], "message": str(e)},
|
|
||||||
status=Status.ERROR
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelSerializer(serializers.Serializer):
|
|
||||||
@staticmethod
|
|
||||||
def model_to_dict(model: Model):
|
|
||||||
credential = json.loads(rsa_long_decrypt(model.credential))
|
|
||||||
return {
|
|
||||||
'id': str(model.id),
|
|
||||||
'provider': model.provider,
|
|
||||||
'name': model.name,
|
|
||||||
'model_type': model.model_type,
|
|
||||||
'model_name': model.model_name,
|
|
||||||
'status': model.status,
|
|
||||||
'meta': model.meta,
|
|
||||||
'credential': ModelProvideConstants[model.provider].value.get_model_credential(
|
|
||||||
model.model_type, model.model_name
|
|
||||||
).encryption_dict(credential),
|
|
||||||
'workspace_id': model.workspace_id
|
|
||||||
}
|
|
||||||
|
|
||||||
class Operate(serializers.Serializer):
|
|
||||||
id = serializers.UUIDField(required=True, label=_("模型id"))
|
|
||||||
user_id = serializers.UUIDField(required=True, label=_("user id"))
|
|
||||||
|
|
||||||
def is_valid(self, *, raise_exception=False):
|
|
||||||
super().is_valid(raise_exception=True)
|
|
||||||
model = QuerySet(Model).filter(
|
|
||||||
id=self.data.get("id"), user_id=self.data.get("user_id")
|
|
||||||
).first()
|
|
||||||
if model is None:
|
|
||||||
raise AppApiException(500, _('模型不存在'))
|
|
||||||
|
|
||||||
def one(self, with_valid=False):
|
|
||||||
if with_valid:
|
|
||||||
self.is_valid(raise_exception=True)
|
|
||||||
model = QuerySet(Model).get(
|
|
||||||
id=self.data.get('id'), user_id=self.data.get('user_id')
|
|
||||||
)
|
|
||||||
return ModelSerializer.model_to_dict(model)
|
|
||||||
|
|
||||||
class Create(serializers.Serializer):
|
|
||||||
user_id = serializers.UUIDField(required=True, label=_('user id'))
|
|
||||||
name = serializers.CharField(required=True, max_length=64, label=_("model name"))
|
|
||||||
provider = serializers.CharField(required=True, label=_("provider"))
|
|
||||||
model_type = serializers.CharField(required=True, label=_("model type"))
|
|
||||||
model_name = serializers.CharField(required=True, label=_("model name"))
|
|
||||||
model_params_form = serializers.ListField(required=False, default=list, label=_("parameter configuration"))
|
|
||||||
credential = serializers.DictField(required=True, label=_("certification information"))
|
|
||||||
workspace_id = serializers.CharField(required=False, label=_("workspace id"), max_length=128)
|
|
||||||
|
|
||||||
def is_valid(self, *, raise_exception=False):
|
|
||||||
super().is_valid(raise_exception=True)
|
|
||||||
if QuerySet(Model).filter(
|
|
||||||
user_id=self.data.get('user_id'),
|
|
||||||
name=self.data.get('name'),
|
|
||||||
workspace_id=self.data.get('workspace_id')
|
|
||||||
).exists():
|
|
||||||
raise AppApiException(
|
|
||||||
500,
|
|
||||||
_('Model name【{model_name}】already exists').format(model_name=self.data.get("name"))
|
|
||||||
)
|
|
||||||
default_params = {item['field']: item['default_value'] for item in self.data.get('model_params_form')}
|
|
||||||
ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(
|
|
||||||
self.data.get('model_type'),
|
|
||||||
self.data.get('model_name'),
|
|
||||||
self.data.get('credential'),
|
|
||||||
default_params,
|
|
||||||
raise_exception=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def insert(self, workspace_id, with_valid=True):
|
|
||||||
status = Status.SUCCESS
|
|
||||||
if with_valid:
|
|
||||||
try:
|
|
||||||
self.is_valid(raise_exception=True)
|
|
||||||
except AppApiException as e:
|
|
||||||
if e.code == ValidCode.model_not_fount:
|
|
||||||
status = Status.DOWNLOAD
|
|
||||||
else:
|
|
||||||
raise e
|
|
||||||
|
|
||||||
credential = self.data.get('credential')
|
|
||||||
model_data = {
|
|
||||||
'id': uuid.uuid1(),
|
|
||||||
'status': status,
|
|
||||||
'user_id': self.data.get('user_id'),
|
|
||||||
'name': self.data.get('name'),
|
|
||||||
'credential': rsa_long_encrypt(json.dumps(credential)),
|
|
||||||
'provider': self.data.get('provider'),
|
|
||||||
'model_type': self.data.get('model_type'),
|
|
||||||
'model_name': self.data.get('model_name'),
|
|
||||||
'model_params_form': self.data.get('model_params_form'),
|
|
||||||
'workspace_id': workspace_id
|
|
||||||
}
|
|
||||||
model = Model(**model_data)
|
|
||||||
try:
|
|
||||||
model.save()
|
|
||||||
except Exception as save_error:
|
|
||||||
# 可添加日志记录
|
|
||||||
raise AppApiException(500, _('模型保存失败')) from save_error
|
|
||||||
|
|
||||||
if status == Status.DOWNLOAD:
|
|
||||||
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
|
|
||||||
thread.start()
|
|
||||||
|
|
||||||
return ModelModelSerializer(model).data
|
|
||||||
389
apps/models_provider/serializers/model_serializer.py
Normal file
389
apps/models_provider/serializers/model_serializer.py
Normal file
@ -0,0 +1,389 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import uuid_utils.compat as uuid
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from common.config.embedding_config import ModelManage
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
|
||||||
|
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
||||||
|
from models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
|
from models_provider.models import Model, Status
|
||||||
|
from models_provider.tools import get_model_credential
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_model_params_setting(provider, model_type, model_name):
|
||||||
|
credential = get_model_credential(provider, model_type, model_name)
|
||||||
|
setting_form = credential.get_model_params_setting_form(model_name)
|
||||||
|
if setting_form is not None:
|
||||||
|
return setting_form.to_form_list()
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class ModelModelSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = Model
|
||||||
|
fields = [
|
||||||
|
'id', 'name', 'status', 'model_type', 'model_name',
|
||||||
|
'user', 'provider', 'credential', 'meta',
|
||||||
|
'model_params_form', 'workspace_id'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCreateRequest(serializers.Serializer):
|
||||||
|
name = serializers.CharField(required=True, max_length=64, label=_("model name"))
|
||||||
|
provider = serializers.CharField(required=True, label=_("provider"))
|
||||||
|
model_type = serializers.CharField(required=True, label=_("model type"))
|
||||||
|
model_name = serializers.CharField(required=True, label=_("base model"))
|
||||||
|
model_params_form = serializers.ListField(required=False, default=list, label=_("parameter configuration"))
|
||||||
|
credential = serializers.DictField(required=True, label=_("certification information"))
|
||||||
|
|
||||||
|
|
||||||
|
class ModelPullManage:
|
||||||
|
@staticmethod
|
||||||
|
def pull(model: Model, credential: Dict):
|
||||||
|
try:
|
||||||
|
response = ModelProvideConstants[model.provider].value.down_model(
|
||||||
|
model.model_type, model.model_name, credential
|
||||||
|
)
|
||||||
|
down_model_chunk = {}
|
||||||
|
last_update_time = time.time()
|
||||||
|
|
||||||
|
for chunk in response:
|
||||||
|
down_model_chunk[chunk.digest] = chunk.to_dict()
|
||||||
|
if time.time() - last_update_time > 5:
|
||||||
|
current_model = QuerySet(Model).filter(id=model.id).first()
|
||||||
|
if current_model and current_model.status == Status.PAUSE_DOWNLOAD:
|
||||||
|
return
|
||||||
|
QuerySet(Model).filter(id=model.id).update(
|
||||||
|
meta={"down_model_chunk": list(down_model_chunk.values())}
|
||||||
|
)
|
||||||
|
last_update_time = time.time()
|
||||||
|
|
||||||
|
status = Status.ERROR
|
||||||
|
message = ""
|
||||||
|
for chunk in down_model_chunk.values():
|
||||||
|
if chunk.get('status') == DownModelChunkStatus.success.value:
|
||||||
|
status = Status.SUCCESS
|
||||||
|
elif chunk.get('status') == DownModelChunkStatus.error.value:
|
||||||
|
message = chunk.get("digest")
|
||||||
|
|
||||||
|
QuerySet(Model).filter(id=model.id).update(
|
||||||
|
meta={"down_model_chunk": [], "message": message},
|
||||||
|
status=status
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
QuerySet(Model).filter(id=model.id).update(
|
||||||
|
meta={"down_model_chunk": [], "message": str(e)},
|
||||||
|
status=Status.ERROR
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSerializer(serializers.Serializer):
|
||||||
|
@staticmethod
|
||||||
|
def model_to_dict(model: Model):
|
||||||
|
credential = json.loads(rsa_long_decrypt(model.credential))
|
||||||
|
return {
|
||||||
|
'id': str(model.id),
|
||||||
|
'provider': model.provider,
|
||||||
|
'name': model.name,
|
||||||
|
'model_type': model.model_type,
|
||||||
|
'model_name': model.model_name,
|
||||||
|
'status': model.status,
|
||||||
|
'meta': model.meta,
|
||||||
|
'credential': ModelProvideConstants[model.provider].value.get_model_credential(
|
||||||
|
model.model_type, model.model_name
|
||||||
|
).encryption_dict(credential),
|
||||||
|
'workspace_id': model.workspace_id
|
||||||
|
}
|
||||||
|
|
||||||
|
class Operate(serializers.Serializer):
|
||||||
|
id = serializers.UUIDField(required=True, label=_("model id"))
|
||||||
|
user_id = serializers.UUIDField(required=True, label=_("user id"))
|
||||||
|
|
||||||
|
def is_valid(self, *, raise_exception=False):
|
||||||
|
super().is_valid(raise_exception=True)
|
||||||
|
model = QuerySet(Model).filter(
|
||||||
|
id=self.data.get("id")
|
||||||
|
).first()
|
||||||
|
if model is None:
|
||||||
|
raise AppApiException(500, _('Model does not exist'))
|
||||||
|
|
||||||
|
def one(self, with_valid=False):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
model = QuerySet(Model).get(
|
||||||
|
id=self.data.get('id')
|
||||||
|
)
|
||||||
|
return ModelSerializer.model_to_dict(model)
|
||||||
|
|
||||||
|
def one_meta(self, with_valid=False):
|
||||||
|
model = None
|
||||||
|
if with_valid:
|
||||||
|
super().is_valid(raise_exception=True)
|
||||||
|
model = QuerySet(Model).filter(id=self.data.get("id")).first()
|
||||||
|
if model is None:
|
||||||
|
raise AppApiException(500, _('Model does not exist'))
|
||||||
|
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||||
|
'model_name': model.model_name,
|
||||||
|
'status': model.status,
|
||||||
|
'meta': model.meta
|
||||||
|
}
|
||||||
|
|
||||||
|
def pause_download(self, with_valid=True):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def delete(self, with_valid=True):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
model_id = self.data.get('id')
|
||||||
|
model = Model.objects.filter(id=model_id).first()
|
||||||
|
if not model:
|
||||||
|
raise AppApiException(500, _("Model does not exist"))
|
||||||
|
# TODO : 这里可以添加模型删除的逻辑,需要注意删除模型时的权限和关联关系
|
||||||
|
# if model.model_type == 'LLM':
|
||||||
|
# application_count = Application.objects.filter(model_id=model_id).count()
|
||||||
|
# if application_count > 0:
|
||||||
|
# raise AppApiException(500, f"该模型关联了{application_count} 个应用,无法删除该模型。")
|
||||||
|
# elif model.model_type == 'EMBEDDING':
|
||||||
|
# dataset_count = DataSet.objects.filter(embedding_mode_id=model_id).count()
|
||||||
|
# if dataset_count > 0:
|
||||||
|
# raise AppApiException(500, f"该模型关联了{dataset_count} 个知识库,无法删除该模型。")
|
||||||
|
# elif model.model_type == 'TTS':
|
||||||
|
# dataset_count = Application.objects.filter(tts_model_id=model_id).count()
|
||||||
|
# if dataset_count > 0:
|
||||||
|
# raise AppApiException(500, f"该模型关联了{dataset_count} 个应用,无法删除该模型。")
|
||||||
|
# elif model.model_type == 'STT':
|
||||||
|
# dataset_count = Application.objects.filter(stt_model_id=model_id).count()
|
||||||
|
# if dataset_count > 0:
|
||||||
|
# raise AppApiException(500, f"该模型关联了{dataset_count} 个应用,无法删除该模型。")
|
||||||
|
model.delete()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def edit(self, instance: Dict, user_id: str, with_valid=True):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
model = QuerySet(Model).filter(id=self.data.get('id')).first()
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
raise AppApiException(500, _('Model does not exist'))
|
||||||
|
else:
|
||||||
|
credential, model_credential, provider_handler = ModelSerializer.Edit(
|
||||||
|
data={**instance}).is_valid(
|
||||||
|
model=model)
|
||||||
|
try:
|
||||||
|
model.status = Status.SUCCESS
|
||||||
|
default_params = {item['field']: item['default_value'] for item in model.model_params_form}
|
||||||
|
# 校验模型认证数据
|
||||||
|
provider_handler.is_valid_credential(model.model_type,
|
||||||
|
instance.get("model_name"),
|
||||||
|
credential,
|
||||||
|
default_params,
|
||||||
|
raise_exception=True)
|
||||||
|
|
||||||
|
except AppApiException as e:
|
||||||
|
if e.code == ValidCode.model_not_fount:
|
||||||
|
model.status = Status.DOWNLOAD
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
update_keys = ['credential', 'name', 'model_type', 'model_name']
|
||||||
|
for update_key in update_keys:
|
||||||
|
if update_key in instance and instance.get(update_key) is not None:
|
||||||
|
if update_key == 'credential':
|
||||||
|
model_credential_str = json.dumps(credential)
|
||||||
|
model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
|
||||||
|
else:
|
||||||
|
model.__setattr__(update_key, instance.get(update_key))
|
||||||
|
|
||||||
|
ModelManage.delete_key(str(model.id))
|
||||||
|
model.save()
|
||||||
|
if model.status == Status.DOWNLOAD:
|
||||||
|
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
|
||||||
|
thread.start()
|
||||||
|
return self.one(with_valid=False)
|
||||||
|
|
||||||
|
class Edit(serializers.Serializer):
|
||||||
|
user_id = serializers.CharField(required=False, label=(_('user id')))
|
||||||
|
|
||||||
|
name = serializers.CharField(required=False, max_length=64,
|
||||||
|
label=(_("model name")))
|
||||||
|
|
||||||
|
model_type = serializers.CharField(required=False, label=(_("model type")))
|
||||||
|
|
||||||
|
model_name = serializers.CharField(required=False, label=(_("base model")))
|
||||||
|
|
||||||
|
credential = serializers.DictField(required=False,
|
||||||
|
label=(_("certification information")))
|
||||||
|
|
||||||
|
def is_valid(self, model=None, raise_exception=False):
|
||||||
|
super().is_valid(raise_exception=True)
|
||||||
|
filter_params = {'workspace_id': self.data.get('workspace_id')}
|
||||||
|
if 'name' in self.data and self.data.get('name') is not None:
|
||||||
|
filter_params['name'] = self.data.get('name')
|
||||||
|
if QuerySet(Model).exclude(id=model.id).filter(**filter_params).exists():
|
||||||
|
raise AppApiException(500, _('base model【{model_name}】already exists').format(
|
||||||
|
model_name=self.data.get("name")))
|
||||||
|
|
||||||
|
ModelSerializer.model_to_dict(model)
|
||||||
|
|
||||||
|
provider = model.provider
|
||||||
|
model_type = self.data.get('model_type')
|
||||||
|
model_name = self.data.get(
|
||||||
|
'model_name')
|
||||||
|
credential = self.data.get('credential')
|
||||||
|
provider_handler = ModelProvideConstants[provider].value
|
||||||
|
model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
|
||||||
|
model_name)
|
||||||
|
source_model_credential = json.loads(rsa_long_decrypt(model.credential))
|
||||||
|
source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
|
||||||
|
if credential is not None:
|
||||||
|
for k in source_encryption_model_credential.keys():
|
||||||
|
if k in credential and credential[k] == source_encryption_model_credential[k]:
|
||||||
|
credential[k] = source_model_credential[k]
|
||||||
|
return credential, model_credential, provider_handler
|
||||||
|
|
||||||
|
class Create(serializers.Serializer):
|
||||||
|
user_id = serializers.UUIDField(required=True, label=_('user id'))
|
||||||
|
name = serializers.CharField(required=True, max_length=64, label=_("model name"))
|
||||||
|
provider = serializers.CharField(required=True, label=_("provider"))
|
||||||
|
model_type = serializers.CharField(required=True, label=_("model type"))
|
||||||
|
model_name = serializers.CharField(required=True, label=_("base model"))
|
||||||
|
model_params_form = serializers.ListField(required=False, default=list, label=_("parameter configuration"))
|
||||||
|
credential = serializers.DictField(required=True, label=_("certification information"))
|
||||||
|
workspace_id = serializers.CharField(required=False, label=_("workspace id"), max_length=128)
|
||||||
|
|
||||||
|
def is_valid(self, *, raise_exception=False):
|
||||||
|
super().is_valid(raise_exception=True)
|
||||||
|
if QuerySet(Model).filter(
|
||||||
|
name=self.data.get('name'),
|
||||||
|
workspace_id=self.data.get('workspace_id')
|
||||||
|
).exists():
|
||||||
|
raise AppApiException(
|
||||||
|
500,
|
||||||
|
_('base model【{model_name}】already exists').format(model_name=self.data.get("name"))
|
||||||
|
)
|
||||||
|
default_params = {item['field']: item['default_value'] for item in self.data.get('model_params_form')}
|
||||||
|
ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(
|
||||||
|
self.data.get('model_type'),
|
||||||
|
self.data.get('model_name'),
|
||||||
|
self.data.get('credential'),
|
||||||
|
default_params,
|
||||||
|
raise_exception=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def insert(self, workspace_id, with_valid=True):
|
||||||
|
status = Status.SUCCESS
|
||||||
|
if with_valid:
|
||||||
|
try:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
except AppApiException as e:
|
||||||
|
if e.code == ValidCode.model_not_fount:
|
||||||
|
status = Status.DOWNLOAD
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
credential = self.data.get('credential')
|
||||||
|
model_data = {
|
||||||
|
'id': uuid.uuid1(),
|
||||||
|
'status': status,
|
||||||
|
'user_id': self.data.get('user_id'),
|
||||||
|
'name': self.data.get('name'),
|
||||||
|
'credential': rsa_long_encrypt(json.dumps(credential)),
|
||||||
|
'provider': self.data.get('provider'),
|
||||||
|
'model_type': self.data.get('model_type'),
|
||||||
|
'model_name': self.data.get('model_name'),
|
||||||
|
'model_params_form': self.data.get('model_params_form'),
|
||||||
|
'workspace_id': workspace_id
|
||||||
|
}
|
||||||
|
model = Model(**model_data)
|
||||||
|
try:
|
||||||
|
model.save()
|
||||||
|
except Exception as save_error:
|
||||||
|
# 可添加日志记录
|
||||||
|
raise AppApiException(500, _("Model saving failed")) from save_error
|
||||||
|
|
||||||
|
if status == Status.DOWNLOAD:
|
||||||
|
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
return ModelModelSerializer(model).data
|
||||||
|
|
||||||
|
class Query(serializers.Serializer):
|
||||||
|
name = serializers.CharField(required=False, max_length=64, label=_('model name'))
|
||||||
|
model_type = serializers.CharField(required=False, label=_('model type'))
|
||||||
|
model_name = serializers.CharField(required=False, label=_('base model'))
|
||||||
|
provider = serializers.CharField(required=False, label=_('provider'))
|
||||||
|
create_user = serializers.CharField(required=False, label=_('create user'))
|
||||||
|
workspace_id = serializers.CharField(required=False, label=_('workspace id'))
|
||||||
|
|
||||||
|
def list(self, with_valid):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
|
||||||
|
query_params = self._build_query_params()
|
||||||
|
return self._fetch_models(query_params)
|
||||||
|
|
||||||
|
def _build_query_params(self):
|
||||||
|
query_params = {}
|
||||||
|
for field in ['name', 'model_type', 'model_name', 'provider', 'create_user', 'workspace_id']:
|
||||||
|
value = self.data.get(field)
|
||||||
|
if value is not None:
|
||||||
|
if field == 'name':
|
||||||
|
query_params[f'{field}__icontains'] = value
|
||||||
|
elif field == 'create_user':
|
||||||
|
query_params['user_id'] = value
|
||||||
|
else:
|
||||||
|
query_params[field] = value
|
||||||
|
return query_params
|
||||||
|
|
||||||
|
def _fetch_models(self, query_params):
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
'id': str(model.id),
|
||||||
|
'provider': model.provider,
|
||||||
|
'name': model.name,
|
||||||
|
'model_type': model.model_type,
|
||||||
|
'model_name': model.model_name,
|
||||||
|
'status': model.status,
|
||||||
|
'meta': model.meta,
|
||||||
|
'user_id': model.user_id,
|
||||||
|
'username': model.user.username
|
||||||
|
}
|
||||||
|
for model in Model.objects.filter(**query_params).order_by("-create_time")
|
||||||
|
]
|
||||||
|
|
||||||
|
class ModelParams(serializers.Serializer):
|
||||||
|
id = serializers.UUIDField(required=True, label=_('model id'))
|
||||||
|
|
||||||
|
def is_valid(self, *, raise_exception=False):
|
||||||
|
super().is_valid(raise_exception=True)
|
||||||
|
model = QuerySet(Model).filter(id=self.data.get("id")).first()
|
||||||
|
if model is None:
|
||||||
|
raise AppApiException(500, _("Model does not exist"))
|
||||||
|
|
||||||
|
def get_model_params(self, with_valid=True):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
model_id = self.data.get('id')
|
||||||
|
model = QuerySet(Model).filter(id=model_id).first()
|
||||||
|
return model.model_params_form
|
||||||
|
|
||||||
|
def save_model_params_form(self, model_params_form, with_valid=True):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
if model_params_form is None:
|
||||||
|
model_params_form = []
|
||||||
|
model_id = self.data.get('id')
|
||||||
|
model = QuerySet(Model).filter(id=model_id).first()
|
||||||
|
model.model_params_form = model_params_form
|
||||||
|
model.save()
|
||||||
|
return True
|
||||||
@ -109,8 +109,6 @@ def get_model_by_id(_id, user_id):
|
|||||||
connection.close()
|
connection.close()
|
||||||
if model is None:
|
if model is None:
|
||||||
raise Exception(_('Model does not exist'))
|
raise Exception(_('Model does not exist'))
|
||||||
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
|
|
||||||
raise Exception(_('No permission to use this model') + f"{model.name}")
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -4,18 +4,19 @@ from . import views
|
|||||||
|
|
||||||
app_name = "models_provider"
|
app_name = "models_provider"
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
# path('provider/<str:provider>/<str:method>', views.Provide.Exec.as_view(), name='provide_exec'),
|
|
||||||
path('provider', views.Provide.as_view(), name='provide'),
|
path('provider', views.Provide.as_view(), name='provide'),
|
||||||
path('provider/model_type_list', views.Provide.ModelTypeList.as_view(), name="provider/model_type_list"),
|
path('provider/model_type_list', views.Provide.ModelTypeList.as_view(), name="provider/model_type_list"),
|
||||||
path('provider/model_list', views.Provide.ModelList.as_view(), name="provider/model_name_list"),
|
path('provider/model_list', views.Provide.ModelList.as_view(), name="provider/model_name_list"),
|
||||||
# path('provider/model_params_form', views.Provide.ModelParamsForm.as_view(),
|
path('provider/model_params_form', views.Provide.ModelParamsForm.as_view(),
|
||||||
# name="provider/model_params_form"),
|
name="provider/model_params_form"),
|
||||||
# path('provider/model_form', views.Provide.ModelForm.as_view(),
|
path('provider/model_form', views.Provide.ModelForm.as_view(),
|
||||||
# name="provider/model_form"),
|
name="provider/model_form"),
|
||||||
path('workspace/<str:workspace_id>/model', views.Model.as_view(), name='model'),
|
path('workspace/<str:workspace_id>/model', views.Model.as_view(), name='model'),
|
||||||
# path('workspace/<str:workspace_id>/model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(),
|
path('workspace/<str:workspace_id>/model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(),
|
||||||
# name='model/model_params_form'),
|
name='model/model_params_form'),
|
||||||
# path('workspace/<str:workspace_id>/model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
|
path('workspace/<str:workspace_id>/model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
|
||||||
# path('workspace/<str:workspace_id>/model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(), name='model/operate'),
|
path('workspace/<str:workspace_id>/model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(),
|
||||||
# path('workspace/<str:workspace_id>/model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
|
name='model/operate'),
|
||||||
|
path('workspace/<str:workspace_id>/model/<str:model_id>/meta', views.Model.ModelMeta.as_view(),
|
||||||
|
name='model/operate/meta'),
|
||||||
]
|
]
|
||||||
|
|||||||
@ -15,8 +15,10 @@ from common.auth import TokenAuth
|
|||||||
from common.auth.authentication import has_permissions
|
from common.auth.authentication import has_permissions
|
||||||
from common.constants.permission_constants import PermissionConstants
|
from common.constants.permission_constants import PermissionConstants
|
||||||
from common.result import result
|
from common.result import result
|
||||||
from models_provider.api.model import ModelCreateAPI
|
from common.utils.common import query_params_to_single_dict
|
||||||
from models_provider.serializers.model import ModelSerializer
|
from models_provider.api.model import ModelCreateAPI, GetModelApi, ModelEditApi, ModelListResponse
|
||||||
|
from models_provider.api.provide import ProvideApi
|
||||||
|
from models_provider.serializers.model_serializer import ModelSerializer
|
||||||
|
|
||||||
|
|
||||||
class Model(APIView):
|
class Model(APIView):
|
||||||
@ -26,10 +28,127 @@ class Model(APIView):
|
|||||||
description=_("Create model"),
|
description=_("Create model"),
|
||||||
operation_id=_("Create model"),
|
operation_id=_("Create model"),
|
||||||
tags=[_("Model")],
|
tags=[_("Model")],
|
||||||
|
parameters=ModelCreateAPI.get_query_params_api(),
|
||||||
request=ModelCreateAPI.get_request(),
|
request=ModelCreateAPI.get_request(),
|
||||||
responses=ModelCreateAPI.get_response())
|
responses=ModelCreateAPI.get_response())
|
||||||
@has_permissions(PermissionConstants.MODEL_CREATE)
|
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission())
|
||||||
def post(self, request: Request, workspace_id: str):
|
def post(self, request: Request, workspace_id: str):
|
||||||
return result.success(
|
return result.success(
|
||||||
ModelSerializer.Create(data={**request.data, 'user_id': request.user.id}).insert(workspace_id,
|
ModelSerializer.Create(data={**request.data, 'user_id': request.user.id}).insert(workspace_id,
|
||||||
with_valid=True))
|
with_valid=True))
|
||||||
|
|
||||||
|
# @extend_schema(methods=['PUT'],
|
||||||
|
# description=_('Update model'),
|
||||||
|
# operation_id=_('Update model'),
|
||||||
|
# request=ModelEditApi.get_request(),
|
||||||
|
# responses=ModelCreateApi.get_response(),
|
||||||
|
# tags=[_('Model')])
|
||||||
|
# @has_permissions(PermissionConstants.MODEL_CREATE)
|
||||||
|
# def put(self, request: Request):
|
||||||
|
# return result.success(
|
||||||
|
# ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
|
||||||
|
# with_valid=True))
|
||||||
|
|
||||||
|
@extend_schema(methods=['GET'],
|
||||||
|
description=_('Query model list'),
|
||||||
|
operation_id=_('Query model list'),
|
||||||
|
parameters=ModelCreateAPI.get_query_params_api(),
|
||||||
|
responses=ModelListResponse.get_response(),
|
||||||
|
tags=[_('Model')])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
||||||
|
def get(self, request: Request):
|
||||||
|
return result.success(
|
||||||
|
ModelSerializer.Query(
|
||||||
|
data={**query_params_to_single_dict(request.query_params)}).list(
|
||||||
|
with_valid=True))
|
||||||
|
|
||||||
|
class Operate(APIView):
|
||||||
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@extend_schema(methods=['PUT'],
|
||||||
|
description=_('Update model'),
|
||||||
|
operation_id=_('Update model'),
|
||||||
|
request=ModelEditApi.get_request(),
|
||||||
|
parameters=GetModelApi.get_query_params_api(),
|
||||||
|
responses=ModelEditApi.get_response(),
|
||||||
|
tags=[_('Model')])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_EDIT.get_workspace_permission())
|
||||||
|
def put(self, request: Request, workspace_id, model_id: str):
|
||||||
|
return result.success(
|
||||||
|
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).edit(request.data,
|
||||||
|
str(request.user.id)))
|
||||||
|
|
||||||
|
@extend_schema(methods=['DELETE'],
|
||||||
|
description=_('Delete model'),
|
||||||
|
operation_id=_('Delete model'),
|
||||||
|
parameters=GetModelApi.get_query_params_api(),
|
||||||
|
tags=[_('Model')])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_DELETE.get_workspace_permission())
|
||||||
|
def delete(self, request: Request, workspace_id: str, model_id: str):
|
||||||
|
return result.success(
|
||||||
|
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).delete())
|
||||||
|
|
||||||
|
@extend_schema(methods=['GET'],
|
||||||
|
description=_('Query model details'),
|
||||||
|
operation_id=_('Query model details'),
|
||||||
|
parameters=GetModelApi.get_query_params_api(),
|
||||||
|
responses=GetModelApi.get_response(),
|
||||||
|
tags=[_('Model')])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
||||||
|
def get(self, request: Request, workspace_id: str, model_id: str):
|
||||||
|
return result.success(
|
||||||
|
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one(with_valid=True))
|
||||||
|
|
||||||
|
class ModelParamsForm(APIView):
|
||||||
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@extend_schema(methods=['GET'],
|
||||||
|
description=_('Get model parameter form'),
|
||||||
|
operation_id=_('Get model parameter form'),
|
||||||
|
parameters=GetModelApi.get_query_params_api(),
|
||||||
|
responses=ProvideApi.ModelParamsForm.get_response(),
|
||||||
|
tags=[_('Model')])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
||||||
|
def get(self, request: Request, workspace_id: str, model_id: str):
|
||||||
|
return result.success(
|
||||||
|
ModelSerializer.ModelParams(data={'id': model_id}).get_model_params())
|
||||||
|
|
||||||
|
@extend_schema(methods=['PUT'],
|
||||||
|
description=_('Save model parameter form'),
|
||||||
|
operation_id=_('Save model parameter form'),
|
||||||
|
parameters=GetModelApi.get_query_params_api(),
|
||||||
|
responses=ProvideApi.ModelParamsForm.get_response(),
|
||||||
|
tags=[_('Model')])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
||||||
|
def put(self, request: Request, workspace_id: str, model_id: str):
|
||||||
|
return result.success(
|
||||||
|
ModelSerializer.ModelParams(data={'id': model_id}).save_model_params_form(request.data))
|
||||||
|
|
||||||
|
class ModelMeta(APIView):
|
||||||
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@extend_schema(methods=['GET'],
|
||||||
|
description=_(
|
||||||
|
'Query model meta information, this interface does not carry authentication information'),
|
||||||
|
operation_id=_(
|
||||||
|
'Query model meta information, this interface does not carry authentication information'),
|
||||||
|
parameters=GetModelApi.get_query_params_api(),
|
||||||
|
responses=GetModelApi.get_response(),
|
||||||
|
tags=[_('Model')])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_READ.get_workspace_permission())
|
||||||
|
def get(self, request: Request, workspace_id: str, model_id: str):
|
||||||
|
return result.success(
|
||||||
|
ModelSerializer.Operate(data={'id': model_id}).one_meta(with_valid=True))
|
||||||
|
|
||||||
|
class PauseDownload(APIView):
|
||||||
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@extend_schema(methods=['PUT'],
|
||||||
|
description=_('Pause model download'),
|
||||||
|
operation_id=_('Pause model download'),
|
||||||
|
parameters=GetModelApi.get_query_params_api(),
|
||||||
|
tags=[_('Model')])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_CREATE.get_workspace_permission())
|
||||||
|
def put(self, request: Request, workspace_id: str, model_id: str):
|
||||||
|
return result.success(
|
||||||
|
ModelSerializer.Operate(data={'id': model_id}).pause_download())
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from common.auth.authentication import has_permissions
|
|||||||
from common.constants.permission_constants import PermissionConstants
|
from common.constants.permission_constants import PermissionConstants
|
||||||
from models_provider.api.provide import ProvideApi
|
from models_provider.api.provide import ProvideApi
|
||||||
from models_provider.constants.model_provider_constants import ModelProvideConstants
|
from models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
|
from models_provider.serializers.model_serializer import get_default_model_params_setting
|
||||||
|
|
||||||
|
|
||||||
class Provide(APIView):
|
class Provide(APIView):
|
||||||
@ -66,3 +67,37 @@ class Provide(APIView):
|
|||||||
return result.success(
|
return result.success(
|
||||||
ModelProvideConstants[provider].value.get_model_list(
|
ModelProvideConstants[provider].value.get_model_list(
|
||||||
model_type))
|
model_type))
|
||||||
|
|
||||||
|
class ModelParamsForm(APIView):
|
||||||
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@extend_schema(methods=['GET'],
|
||||||
|
description=_('Get model default parameters'),
|
||||||
|
operation_id=_('Get the model creation form'),
|
||||||
|
parameters=ProvideApi.ModelParamsForm.get_query_params_api(),
|
||||||
|
responses=ProvideApi.ModelParamsForm.get_response(),
|
||||||
|
tags=[_('Model')])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_READ)
|
||||||
|
def get(self, request: Request):
|
||||||
|
provider = request.query_params.get('provider')
|
||||||
|
model_type = request.query_params.get('model_type')
|
||||||
|
model_name = request.query_params.get('model_name')
|
||||||
|
|
||||||
|
return result.success(get_default_model_params_setting(provider, model_type, model_name))
|
||||||
|
|
||||||
|
class ModelForm(APIView):
|
||||||
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@extend_schema(methods=['GET'],
|
||||||
|
description=_('Get the model creation form'),
|
||||||
|
operation_id=_('Get the model creation form'),
|
||||||
|
parameters=ProvideApi.ModelParamsForm.get_query_params_api(),
|
||||||
|
responses=ProvideApi.ModelParamsForm.get_response(),
|
||||||
|
tags=[_('Model')])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_READ)
|
||||||
|
def get(self, request: Request):
|
||||||
|
provider = request.query_params.get('provider')
|
||||||
|
model_type = request.query_params.get('model_type')
|
||||||
|
model_name = request.query_params.get('model_name')
|
||||||
|
return result.success(
|
||||||
|
ModelProvideConstants[provider].value.get_model_credential(model_type, model_name).to_form_list())
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user