feat: ollama支持下载模型

This commit is contained in:
shaohuzhang1 2024-03-22 17:56:56 +08:00
parent bdf5edc203
commit d074424398
13 changed files with 363 additions and 53 deletions

View File

@ -13,13 +13,14 @@ from django.core.cache import caches
memory_cache = caches['default'] memory_cache = caches['default']
def try_lock(key: str): def try_lock(key: str, timeout=None):
""" """
获取锁 获取锁
:param key: 获取锁 key :param key: 获取锁 key
:param timeout 超时时间
:return: 是否获取到锁 :return: 是否获取到锁
""" """
return memory_cache.add(key, 'lock', timeout=timedelta(hours=1).total_seconds()) return memory_cache.add(key, 'lock', timeout=timedelta(hours=1).total_seconds() if timeout is not None else timeout)
def un_lock(key: str): def un_lock(key: str):

View File

@ -0,0 +1,23 @@
# Generated by Django 4.1.13 on 2024-03-22 17:51
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('setting', '0002_systemsetting'),
]
operations = [
migrations.AddField(
model_name='model',
name='meta',
field=models.JSONField(default=dict, verbose_name='模型元数据,用于存储下载,或者错误信息'),
),
migrations.AddField(
model_name='model',
name='status',
field=models.CharField(choices=[('SUCCESS', '成功'), ('ERROR', '失败'), ('DOWNLOAD', '下载中')], default='SUCCESS', max_length=20, verbose_name='设置类型'),
),
]

View File

@ -14,6 +14,15 @@ from common.mixins.app_model_mixin import AppModelMixin
from users.models import User from users.models import User
class Status(models.TextChoices):
"""系统设置类型"""
SUCCESS = "SUCCESS", '成功'
ERROR = "ERROR", "失败"
DOWNLOAD = "DOWNLOAD", '下载中'
class Model(AppModelMixin): class Model(AppModelMixin):
""" """
模型数据 模型数据
@ -22,6 +31,9 @@ class Model(AppModelMixin):
name = models.CharField(max_length=128, verbose_name="名称") name = models.CharField(max_length=128, verbose_name="名称")
status = models.CharField(max_length=20, verbose_name='设置类型', choices=Status.choices,
default=Status.SUCCESS)
model_type = models.CharField(max_length=128, verbose_name="模型类型") model_type = models.CharField(max_length=128, verbose_name="模型类型")
model_name = models.CharField(max_length=128, verbose_name="模型名称") model_name = models.CharField(max_length=128, verbose_name="模型名称")
@ -32,6 +44,8 @@ class Model(AppModelMixin):
credential = models.CharField(max_length=5120, verbose_name="模型认证信息") credential = models.CharField(max_length=5120, verbose_name="模型认证信息")
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
class Meta: class Meta:
db_table = "model" db_table = "model"
unique_together = ['name', 'user_id'] unique_together = ['name', 'user_id']

View File

@ -9,10 +9,42 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from functools import reduce from functools import reduce
from typing import Dict from typing import Dict, Iterator
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
from common.exception.app_exception import AppApiException
class DownModelChunkStatus(Enum):
success = "success"
error = "error"
pulling = "pulling"
unknown = 'unknown'
class ValidCode(Enum):
valid_error = 500
model_not_fount = 404
class DownModelChunk:
def __init__(self, status: DownModelChunkStatus, digest: str, progress: int, details: str, index: int):
self.details = details
self.status = status
self.digest = digest
self.progress = progress
self.index = index
def to_dict(self):
return {
"details": self.details,
"status": self.status.value,
"digest": self.digest,
"progress": self.progress,
"index": self.index
}
class IModelProvider(ABC): class IModelProvider(ABC):
@ -40,6 +72,9 @@ class IModelProvider(ABC):
def get_dialogue_number(self): def get_dialogue_number(self):
pass pass
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
raise AppApiException(500, "当前平台不支持下载模型")
class BaseModelCredential(ABC): class BaseModelCredential(ABC):

View File

@ -9,8 +9,8 @@
import os import os
from typing import Dict from typing import Dict
from langchain_community.chat_models import AzureChatOpenAI
from langchain.schema import HumanMessage from langchain.schema import HumanMessage
from langchain_community.chat_models import AzureChatOpenAI
from common import froms from common import froms
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
@ -18,7 +18,7 @@ from common.froms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \ from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
ModelInfo, \ ModelInfo, \
ModelTypeConst ModelTypeConst, ValidCode
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -27,15 +27,15 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = AzureModelProvider().get_model_type_list() model_type_list = AzureModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(500, f'{model_type} 模型类型不支持') raise AppApiException(ValidCode.valid_error, f'{model_type} 模型类型不支持')
if model_name not in model_dict: if model_name not in model_dict:
raise AppApiException(500, f'{model_name} 模型名称不支持') raise AppApiException(ValidCode.valid_error, f'{model_name} 模型名称不支持')
for key in ['api_base', 'api_key', 'deployment_name']: for key in ['api_base', 'api_key', 'deployment_name']:
if key not in model_credential: if key not in model_credential:
if raise_exception: if raise_exception:
raise AppApiException(500, f'{key} 字段为必填字段') raise AppApiException(ValidCode.valid_error, f'{key} 字段为必填字段')
else: else:
return False return False
try: try:
@ -45,7 +45,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
if isinstance(e, AppApiException): if isinstance(e, AppApiException):
raise e raise e
if raise_exception: if raise_exception:
raise AppApiException(500, '校验失败,请检查参数是否正确') raise AppApiException(ValidCode.valid_error, '校验失败,请检查参数是否正确')
else: else:
return False return False

View File

@ -6,9 +6,14 @@
@date2024/3/5 17:23 @date2024/3/5 17:23
@desc: @desc:
""" """
import json
import os import os
from typing import Dict from typing import Dict, Iterator
from urllib.parse import urlparse, ParseResult
import aiohttp
import requests
from django.http import StreamingHttpResponse
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage from langchain.schema import HumanMessage
@ -17,29 +22,26 @@ from common.exception.app_exception import AppApiException
from common.froms import BaseForm from common.froms import BaseForm
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
BaseModelCredential BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode
from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
""
class OllamaLLMModelCredential(BaseForm, BaseModelCredential): class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
model_type_list = OllamaModelProvider().get_model_type_list() model_type_list = OllamaModelProvider().get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(500, f'{model_type} 模型类型不支持') raise AppApiException(ValidCode.valid_error, f'{model_type} 模型类型不支持')
for key in ['api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(500, f'{key} 字段为必填字段')
else:
return False
try: try:
OllamaModelProvider().get_model(model_type, model_name, model_credential).invoke( model_list = OllamaModelProvider.get_base_model_list(model_credential.get('api_base'))
[HumanMessage(content='valid')])
except Exception as e: except Exception as e:
if raise_exception: raise AppApiException(ValidCode.valid_error, "API 域名无效")
raise AppApiException(500, "校验失败,请检查 api_key secret_key 是否正确") exist = [model for model in model_list.get('models') if
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
if len(exist) == 0:
raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
return True return True
def encryption_dict(self, model_info: Dict[str, object]): def encryption_dict(self, model_info: Dict[str, object]):
@ -86,6 +88,52 @@ model_dict = {
} }
def get_base_url(url: str):
parse = urlparse(url)
return ParseResult(scheme=parse.scheme, netloc=parse.netloc, path='', params='',
query='',
fragment='').geturl()
def convert_to_down_model_chunk(row_str: str, chunk_index: int):
row = json.loads(row_str)
status = DownModelChunkStatus.unknown
digest = ""
progress = 100
if 'status' in row:
digest = row.get('status')
if row.get('status') == 'success':
status = DownModelChunkStatus.success
if row.get('status').__contains__("pulling"):
status = DownModelChunkStatus.pulling
if 'total' in row and 'completed' in row:
progress = (row.get('completed') / row.get('total') * 100)
elif 'error' in row:
status = DownModelChunkStatus.error
digest = row.get('error')
return DownModelChunk(status=status, digest=digest, progress=progress, details=row_str, index=chunk_index)
def convert(response_stream) -> Iterator[DownModelChunk]:
temp = ""
index = 0
for c in response_stream:
index += 1
row_content = c.decode()
temp += row_content
if row_content.endswith('}') or row_content.endswith('\n'):
rows = [t for t in temp.split("\n") if len(t) > 0]
for row in rows:
yield convert_to_down_model_chunk(row, index)
temp = ""
if len(temp) > 0:
print(temp)
rows = [t for t in temp.split("\n") if len(t) > 0]
for row in rows:
yield convert_to_down_model_chunk(row, index)
class OllamaModelProvider(IModelProvider): class OllamaModelProvider(IModelProvider):
def get_model_provide_info(self): def get_model_provide_info(self):
return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content( return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content(
@ -113,3 +161,21 @@ class OllamaModelProvider(IModelProvider):
def get_dialogue_number(self): def get_dialogue_number(self):
return 2 return 2
@staticmethod
def get_base_model_list(api_base):
base_url = get_base_url(api_base)
r = requests.request(method="GET", url=f"{base_url}/api/tags")
r.raise_for_status()
return r.json()
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
api_base = model_credential.get('api_base')
base_url = get_base_url(api_base)
r = requests.request(
method="POST",
url=f"{base_url}/api/pull",
data=json.dumps({"name": model_name}).encode(),
stream=True,
)
return convert(r)

View File

@ -7,6 +7,8 @@
@desc: @desc:
""" """
import json import json
import threading
import time
import uuid import uuid
from typing import Dict from typing import Dict
@ -17,10 +19,36 @@ from application.models import Application
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.rsa_util import encrypt, decrypt from common.util.rsa_util import encrypt, decrypt
from setting.models.model_management import Model from setting.models.model_management import Model, Status
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
class ModelPullManage:
@staticmethod
def pull(model: Model, credential: Dict):
response = ModelProvideConstants[model.provider].value.down_model(model.model_type, model.model_name,
credential)
down_model_chunk = {}
timestamp = time.time()
for chunk in response:
down_model_chunk[chunk.digest] = chunk.to_dict()
if time.time() - timestamp > 5:
QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": list(down_model_chunk.values())})
timestamp = time.time()
status = Status.ERROR
message = ""
down_model_chunk_list = list(down_model_chunk.values())
for chunk in down_model_chunk_list:
if chunk.get('status') == DownModelChunkStatus.success.value:
status = Status.SUCCESS
if chunk.get('status') == DownModelChunkStatus.error.value:
message = chunk.get("digest")
QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": down_model_chunk_list, "message": message},
status=status)
class ModelSerializer(serializers.Serializer): class ModelSerializer(serializers.Serializer):
class Query(serializers.Serializer): class Query(serializers.Serializer):
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
@ -50,7 +78,10 @@ class ModelSerializer(serializers.Serializer):
if self.data.get('provider') is not None: if self.data.get('provider') is not None:
query_params['provider'] = self.data.get('provider') query_params['provider'] = self.data.get('provider')
return [ModelSerializer.model_to_dict(model) for model in model_query_set.filter(**query_params)] return [
{'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name, 'status': model.status, 'meta': model.meta} for model in
model_query_set.filter(**query_params)]
class Edit(serializers.Serializer): class Edit(serializers.Serializer):
user_id = serializers.CharField(required=False, error_messages=ErrMessage.uuid("用户id")) user_id = serializers.CharField(required=False, error_messages=ErrMessage.uuid("用户id"))
@ -88,13 +119,7 @@ class ModelSerializer(serializers.Serializer):
for k in source_encryption_model_credential.keys(): for k in source_encryption_model_credential.keys():
if credential[k] == source_encryption_model_credential[k]: if credential[k] == source_encryption_model_credential[k]:
credential[k] = source_model_credential[k] credential[k] = source_model_credential[k]
# 校验模型认证数据 return credential, model_credential
model_credential.is_valid(
model_type,
model_name,
credential,
raise_exception=True)
return credential
class Create(serializers.Serializer): class Create(serializers.Serializer):
user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id")) user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id"))
@ -124,18 +149,28 @@ class ModelSerializer(serializers.Serializer):
raise_exception=True) raise_exception=True)
def insert(self, user_id, with_valid=False): def insert(self, user_id, with_valid=False):
status = Status.SUCCESS
if with_valid: if with_valid:
self.is_valid(raise_exception=True) try:
self.is_valid(raise_exception=True)
except AppApiException as e:
if e.code == ValidCode.model_not_fount:
status = Status.DOWNLOAD
else:
raise e
credential = self.data.get('credential') credential = self.data.get('credential')
name = self.data.get('name') name = self.data.get('name')
provider = self.data.get('provider') provider = self.data.get('provider')
model_type = self.data.get('model_type') model_type = self.data.get('model_type')
model_name = self.data.get('model_name') model_name = self.data.get('model_name')
model_credential_str = json.dumps(credential) model_credential_str = json.dumps(credential)
model = Model(id=uuid.uuid1(), user_id=user_id, name=name, model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
credential=encrypt(model_credential_str), credential=encrypt(model_credential_str),
provider=provider, model_type=model_type, model_name=model_name) provider=provider, model_type=model_type, model_name=model_name)
model.save() model.save()
if status == Status.DOWNLOAD:
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
thread.start()
return ModelSerializer.Operate(data={'id': model.id, 'user_id': user_id}).one(with_valid=True) return ModelSerializer.Operate(data={'id': model.id, 'user_id': user_id}).one(with_valid=True)
@staticmethod @staticmethod
@ -143,6 +178,8 @@ class ModelSerializer(serializers.Serializer):
credential = json.loads(decrypt(model.credential)) credential = json.loads(decrypt(model.credential))
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name, 'model_name': model.model_name,
'status': model.status,
'meta': model.meta,
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type, 'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
model.model_name).encryption_dict( model.model_name).encryption_dict(
credential)} credential)}
@ -164,6 +201,15 @@ class ModelSerializer(serializers.Serializer):
model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id')) model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
return ModelSerializer.model_to_dict(model) return ModelSerializer.model_to_dict(model)
def one_meta(self, with_valid=False):
if with_valid:
self.is_valid(raise_exception=True)
model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name,
'status': model.status,
'meta': model.meta, }
def delete(self, with_valid=True): def delete(self, with_valid=True):
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
@ -181,7 +227,20 @@ class ModelSerializer(serializers.Serializer):
if model is None: if model is None:
raise AppApiException(500, '不存在的id') raise AppApiException(500, '不存在的id')
else: else:
credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(model=model) credential, model_credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(
model=model)
try:
# 校验模型认证数据
model_credential.is_valid(
model.model_type,
instance.get("model_name"),
credential,
raise_exception=True)
except AppApiException as e:
if e.code == ValidCode.model_not_fount:
model.status = Status.DOWNLOAD
else:
raise e
update_keys = ['credential', 'name', 'model_type', 'model_name'] update_keys = ['credential', 'name', 'model_type', 'model_name']
for update_key in update_keys: for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None: if update_key in instance and instance.get(update_key) is not None:
@ -191,6 +250,9 @@ class ModelSerializer(serializers.Serializer):
else: else:
model.__setattr__(update_key, instance.get(update_key)) model.__setattr__(update_key, instance.get(update_key))
model.save() model.save()
if model.status == Status.DOWNLOAD:
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
thread.start()
return self.one(with_valid=False) return self.one(with_valid=False)

View File

@ -16,6 +16,7 @@ urlpatterns = [
name="provider/model_form"), name="provider/model_form"),
path('model', views.Model.as_view(), name='model'), path('model', views.Model.as_view(), name='model'),
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'), path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting') path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting')
] ]

View File

@ -34,6 +34,17 @@ class Model(APIView):
ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id, ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
with_valid=True)) with_valid=True))
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="下载模型,只试用与Ollama平台",
operation_id="下载模型,只试用与Ollama平台",
request_body=ModelCreateApi.get_request_body_api()
, tags=["模型"])
@has_permissions(PermissionConstants.MODEL_CREATE)
def put(self, request: Request):
return result.success(
ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
with_valid=True))
@action(methods=['GET'], detail=False) @action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取模型列表", @swagger_auto_schema(operation_summary="获取模型列表",
operation_id="获取模型列表", operation_id="获取模型列表",
@ -46,6 +57,18 @@ class Model(APIView):
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list( data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list(
with_valid=True)) with_valid=True))
class ModelMeta(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="查询模型meta信息,该接口不携带认证信息",
operation_id="查询模型meta信息,该接口不携带认证信息",
tags=["模型"])
@has_permissions(PermissionConstants.MODEL_READ)
def get(self, request: Request, model_id: str):
return result.success(
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True))
class Operate(APIView): class Operate(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]

View File

@ -106,6 +106,31 @@ const updateModel: (
return put(`${prefix}/${model_id}`, request, {}, loading) return put(`${prefix}/${model_id}`, request, {}, loading)
} }
/**
* id
* @param model_id id
* @param loading
* @returns
*/
const getModelById: (model_id: string, loading?: Ref<boolean>) => Promise<Result<Model>> = (
model_id,
loading
) => {
return get(`${prefix}/${model_id}`, {}, loading)
}
/**
* id
* @param model_id id
* @param loading
* @returns
*/
const getModelMetaById: (model_id: string, loading?: Ref<boolean>) => Promise<Result<Model>> = (
model_id,
loading
) => {
return get(`${prefix}/${model_id}/meta`, {}, loading)
}
const deleteModel: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = ( const deleteModel: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
model_id, model_id,
loading loading
@ -120,5 +145,7 @@ export default {
listBaseModel, listBaseModel,
createModel, createModel,
updateModel, updateModel,
deleteModel deleteModel,
getModelById,
getModelMetaById
} }

View File

@ -1,4 +1,5 @@
import { store } from '@/stores' import { store } from '@/stores'
import { Dict } from './common'
interface modelRequest { interface modelRequest {
name: string name: string
model_type: string model_type: string
@ -64,6 +65,14 @@ interface Model {
* *
*/ */
provider: string provider: string
/**
*
*/
status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR'
/**
*
*/
meta: Dict<any>
} }
interface CreateModelRequest { interface CreateModelRequest {
/** /**

View File

@ -18,6 +18,7 @@
</template> </template>
<DynamicsForm <DynamicsForm
v-loading="formLoading"
v-model="form_data" v-model="form_data"
:render_data="model_form_field" :render_data="model_form_field"
:model="form_data" :model="form_data"
@ -56,7 +57,7 @@
@change="getModelForm($event)" @change="getModelForm($event)"
v-loading="base_model_loading" v-loading="base_model_loading"
style="width: 100%" style="width: 100%"
v-model="form_data.model_name" v-model="base_form_data.model_name"
class="m-2" class="m-2"
placeholder="请选择基础模型" placeholder="请选择基础模型"
filterable filterable
@ -90,10 +91,12 @@ import type { FormField } from '@/components/dynamics-form/type'
import DynamicsForm from '@/components/dynamics-form/index.vue' import DynamicsForm from '@/components/dynamics-form/index.vue'
import type { FormRules } from 'element-plus' import type { FormRules } from 'element-plus'
import { MsgSuccess } from '@/utils/message' import { MsgSuccess } from '@/utils/message'
const providerValue = ref<Provider>() const providerValue = ref<Provider>()
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>() const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
const emit = defineEmits(['change', 'submit']) const emit = defineEmits(['change', 'submit'])
const loading = ref<boolean>(false) const loading = ref<boolean>(false)
const formLoading = ref<boolean>(false)
const model_type_loading = ref<boolean>(false) const model_type_loading = ref<boolean>(false)
const base_model_loading = ref<boolean>(false) const base_model_loading = ref<boolean>(false)
const model_type_list = ref<Array<KeyValue<string, string>>>([]) const model_type_list = ref<Array<KeyValue<string, string>>>([])
@ -152,21 +155,22 @@ const list_base_model = (model_type: any) => {
} }
} }
const open = (provider: Provider, model: Model) => { const open = (provider: Provider, model: Model) => {
modelValue.value = model ModelApi.getModelById(model.id, formLoading).then((ok) => {
ModelApi.listModelType(model.provider, model_type_loading).then((ok) => { modelValue.value = ok.data
model_type_list.value = ok.data ModelApi.listModelType(model.provider, model_type_loading).then((ok) => {
list_base_model(model.model_type) model_type_list.value = ok.data
list_base_model(model.model_type)
})
providerValue.value = provider
base_form_data.value = {
name: model.name,
model_type: model.model_type,
model_name: model.model_name
}
form_data.value = model.credential
getModelForm(model.model_name)
}) })
providerValue.value = provider
base_form_data.value = {
name: model.name,
model_type: model.model_type,
model_name: model.model_name
}
form_data.value = model.credential
getModelForm(model.model_name)
dialogVisible.value = true dialogVisible.value = true
} }

View File

@ -37,15 +37,32 @@
<script setup lang="ts"> <script setup lang="ts">
import type { Provider, Model } from '@/api/type/model' import type { Provider, Model } from '@/api/type/model'
import ModelApi from '@/api/model' import ModelApi from '@/api/model'
import { computed, ref } from 'vue' import { computed, ref, onMounted, onBeforeUnmount } from 'vue'
import EditModel from '@/views/template/component/EditModel.vue' import EditModel from '@/views/template/component/EditModel.vue'
import { MsgSuccess, MsgConfirm } from '@/utils/message' import { MsgSuccess, MsgConfirm } from '@/utils/message'
const props = defineProps<{ const props = defineProps<{
model: Model model: Model
provider_list: Array<Provider> provider_list: Array<Provider>
}>() }>()
const downModel = ref<Model>()
const progress = computed(() => {
if (downModel.value) {
const down_model_chunk = downModel.value.meta['down_model_chunk']
if (down_model_chunk) {
const maxObj = down_model_chunk.reduce((prev: any, current: any) => {
return (prev.index || 0) > (current.index || 0) ? prev : current
})
return maxObj.progress
}
return 0
}
return 0
})
const emit = defineEmits(['change']) const emit = defineEmits(['change'])
const eidtModelRef = ref<InstanceType<typeof EditModel>>() const eidtModelRef = ref<InstanceType<typeof EditModel>>()
let interval: any
const deleteModel = () => { const deleteModel = () => {
MsgConfirm(`删除模型 `, `是否删除模型:${props.model.name} ?`, { MsgConfirm(`删除模型 `, `是否删除模型:${props.model.name} ?`, {
confirmButtonText: '删除', confirmButtonText: '删除',
@ -67,6 +84,34 @@ const openEditModel = () => {
const icon = computed(() => { const icon = computed(() => {
return props.provider_list.find((p) => p.provider === props.model.provider)?.icon return props.provider_list.find((p) => p.provider === props.model.provider)?.icon
}) })
/**
* 初始化轮询
*/
const initInterval = () => {
interval = setInterval(() => {
if (props.model.status === 'DOWNLOAD') {
ModelApi.getModelMetaById(props.model.id).then((ok) => {
downModel.value = ok.data
})
}
}, 6000)
}
/**
* 关闭轮询
*/
const closeInterval = () => {
if (interval) {
clearInterval(interval)
}
}
onMounted(() => {
initInterval()
})
onBeforeUnmount(() => {
//
closeInterval()
})
</script> </script>
<style lang="scss" scoped> <style lang="scss" scoped>
.model-card { .model-card {