feat: ollama支持下载模型
This commit is contained in:
parent
bdf5edc203
commit
d074424398
@ -13,13 +13,14 @@ from django.core.cache import caches
|
|||||||
memory_cache = caches['default']
|
memory_cache = caches['default']
|
||||||
|
|
||||||
|
|
||||||
def try_lock(key: str):
|
def try_lock(key: str, timeout=None):
|
||||||
"""
|
"""
|
||||||
获取锁
|
获取锁
|
||||||
:param key: 获取锁 key
|
:param key: 获取锁 key
|
||||||
|
:param timeout 超时时间
|
||||||
:return: 是否获取到锁
|
:return: 是否获取到锁
|
||||||
"""
|
"""
|
||||||
return memory_cache.add(key, 'lock', timeout=timedelta(hours=1).total_seconds())
|
return memory_cache.add(key, 'lock', timeout=timedelta(hours=1).total_seconds() if timeout is not None else timeout)
|
||||||
|
|
||||||
|
|
||||||
def un_lock(key: str):
|
def un_lock(key: str):
|
||||||
|
|||||||
23
apps/setting/migrations/0003_model_meta_model_status.py
Normal file
23
apps/setting/migrations/0003_model_meta_model_status.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# Generated by Django 4.1.13 on 2024-03-22 17:51
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('setting', '0002_systemsetting'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='model',
|
||||||
|
name='meta',
|
||||||
|
field=models.JSONField(default=dict, verbose_name='模型元数据,用于存储下载,或者错误信息'),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='model',
|
||||||
|
name='status',
|
||||||
|
field=models.CharField(choices=[('SUCCESS', '成功'), ('ERROR', '失败'), ('DOWNLOAD', '下载中')], default='SUCCESS', max_length=20, verbose_name='设置类型'),
|
||||||
|
),
|
||||||
|
]
|
||||||
@ -14,6 +14,15 @@ from common.mixins.app_model_mixin import AppModelMixin
|
|||||||
from users.models import User
|
from users.models import User
|
||||||
|
|
||||||
|
|
||||||
|
class Status(models.TextChoices):
|
||||||
|
"""系统设置类型"""
|
||||||
|
SUCCESS = "SUCCESS", '成功'
|
||||||
|
|
||||||
|
ERROR = "ERROR", "失败"
|
||||||
|
|
||||||
|
DOWNLOAD = "DOWNLOAD", '下载中'
|
||||||
|
|
||||||
|
|
||||||
class Model(AppModelMixin):
|
class Model(AppModelMixin):
|
||||||
"""
|
"""
|
||||||
模型数据
|
模型数据
|
||||||
@ -22,6 +31,9 @@ class Model(AppModelMixin):
|
|||||||
|
|
||||||
name = models.CharField(max_length=128, verbose_name="名称")
|
name = models.CharField(max_length=128, verbose_name="名称")
|
||||||
|
|
||||||
|
status = models.CharField(max_length=20, verbose_name='设置类型', choices=Status.choices,
|
||||||
|
default=Status.SUCCESS)
|
||||||
|
|
||||||
model_type = models.CharField(max_length=128, verbose_name="模型类型")
|
model_type = models.CharField(max_length=128, verbose_name="模型类型")
|
||||||
|
|
||||||
model_name = models.CharField(max_length=128, verbose_name="模型名称")
|
model_name = models.CharField(max_length=128, verbose_name="模型名称")
|
||||||
@ -32,6 +44,8 @@ class Model(AppModelMixin):
|
|||||||
|
|
||||||
credential = models.CharField(max_length=5120, verbose_name="模型认证信息")
|
credential = models.CharField(max_length=5120, verbose_name="模型认证信息")
|
||||||
|
|
||||||
|
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = "model"
|
db_table = "model"
|
||||||
unique_together = ['name', 'user_id']
|
unique_together = ['name', 'user_id']
|
||||||
|
|||||||
@ -9,10 +9,42 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Dict
|
from typing import Dict, Iterator
|
||||||
|
|
||||||
from langchain.chat_models.base import BaseChatModel
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
|
||||||
|
|
||||||
|
class DownModelChunkStatus(Enum):
|
||||||
|
success = "success"
|
||||||
|
error = "error"
|
||||||
|
pulling = "pulling"
|
||||||
|
unknown = 'unknown'
|
||||||
|
|
||||||
|
|
||||||
|
class ValidCode(Enum):
|
||||||
|
valid_error = 500
|
||||||
|
model_not_fount = 404
|
||||||
|
|
||||||
|
|
||||||
|
class DownModelChunk:
|
||||||
|
def __init__(self, status: DownModelChunkStatus, digest: str, progress: int, details: str, index: int):
|
||||||
|
self.details = details
|
||||||
|
self.status = status
|
||||||
|
self.digest = digest
|
||||||
|
self.progress = progress
|
||||||
|
self.index = index
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {
|
||||||
|
"details": self.details,
|
||||||
|
"status": self.status.value,
|
||||||
|
"digest": self.digest,
|
||||||
|
"progress": self.progress,
|
||||||
|
"index": self.index
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class IModelProvider(ABC):
|
class IModelProvider(ABC):
|
||||||
|
|
||||||
@ -40,6 +72,9 @@ class IModelProvider(ABC):
|
|||||||
def get_dialogue_number(self):
|
def get_dialogue_number(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
|
||||||
|
raise AppApiException(500, "当前平台不支持下载模型")
|
||||||
|
|
||||||
|
|
||||||
class BaseModelCredential(ABC):
|
class BaseModelCredential(ABC):
|
||||||
|
|
||||||
|
|||||||
@ -9,8 +9,8 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from langchain_community.chat_models import AzureChatOpenAI
|
|
||||||
from langchain.schema import HumanMessage
|
from langchain.schema import HumanMessage
|
||||||
|
from langchain_community.chat_models import AzureChatOpenAI
|
||||||
|
|
||||||
from common import froms
|
from common import froms
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
@ -18,7 +18,7 @@ from common.froms import BaseForm
|
|||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
|
||||||
ModelInfo, \
|
ModelInfo, \
|
||||||
ModelTypeConst
|
ModelTypeConst, ValidCode
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
|
||||||
@ -27,15 +27,15 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||||
model_type_list = AzureModelProvider().get_model_type_list()
|
model_type_list = AzureModelProvider().get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
raise AppApiException(500, f'{model_type} 模型类型不支持')
|
raise AppApiException(ValidCode.valid_error, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
if model_name not in model_dict:
|
if model_name not in model_dict:
|
||||||
raise AppApiException(500, f'{model_name} 模型名称不支持')
|
raise AppApiException(ValidCode.valid_error, f'{model_name} 模型名称不支持')
|
||||||
|
|
||||||
for key in ['api_base', 'api_key', 'deployment_name']:
|
for key in ['api_base', 'api_key', 'deployment_name']:
|
||||||
if key not in model_credential:
|
if key not in model_credential:
|
||||||
if raise_exception:
|
if raise_exception:
|
||||||
raise AppApiException(500, f'{key} 字段为必填字段')
|
raise AppApiException(ValidCode.valid_error, f'{key} 字段为必填字段')
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
@ -45,7 +45,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
|||||||
if isinstance(e, AppApiException):
|
if isinstance(e, AppApiException):
|
||||||
raise e
|
raise e
|
||||||
if raise_exception:
|
if raise_exception:
|
||||||
raise AppApiException(500, '校验失败,请检查参数是否正确')
|
raise AppApiException(ValidCode.valid_error, '校验失败,请检查参数是否正确')
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@ -6,9 +6,14 @@
|
|||||||
@date:2024/3/5 17:23
|
@date:2024/3/5 17:23
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Dict
|
from typing import Dict, Iterator
|
||||||
|
from urllib.parse import urlparse, ParseResult
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import requests
|
||||||
|
from django.http import StreamingHttpResponse
|
||||||
from langchain.chat_models.base import BaseChatModel
|
from langchain.chat_models.base import BaseChatModel
|
||||||
from langchain.schema import HumanMessage
|
from langchain.schema import HumanMessage
|
||||||
|
|
||||||
@ -17,29 +22,26 @@ from common.exception.app_exception import AppApiException
|
|||||||
from common.froms import BaseForm
|
from common.froms import BaseForm
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||||
BaseModelCredential
|
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode
|
||||||
from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel
|
from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
|
""
|
||||||
|
|
||||||
|
|
||||||
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
|
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
|
||||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||||
model_type_list = OllamaModelProvider().get_model_type_list()
|
model_type_list = OllamaModelProvider().get_model_type_list()
|
||||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||||
raise AppApiException(500, f'{model_type} 模型类型不支持')
|
raise AppApiException(ValidCode.valid_error, f'{model_type} 模型类型不支持')
|
||||||
|
|
||||||
for key in ['api_key']:
|
|
||||||
if key not in model_credential:
|
|
||||||
if raise_exception:
|
|
||||||
raise AppApiException(500, f'{key} 字段为必填字段')
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
try:
|
try:
|
||||||
OllamaModelProvider().get_model(model_type, model_name, model_credential).invoke(
|
model_list = OllamaModelProvider.get_base_model_list(model_credential.get('api_base'))
|
||||||
[HumanMessage(content='valid')])
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if raise_exception:
|
raise AppApiException(ValidCode.valid_error, "API 域名无效")
|
||||||
raise AppApiException(500, "校验失败,请检查 api_key secret_key 是否正确")
|
exist = [model for model in model_list.get('models') if
|
||||||
|
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
|
||||||
|
if len(exist) == 0:
|
||||||
|
raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def encryption_dict(self, model_info: Dict[str, object]):
|
def encryption_dict(self, model_info: Dict[str, object]):
|
||||||
@ -86,6 +88,52 @@ model_dict = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_url(url: str):
|
||||||
|
parse = urlparse(url)
|
||||||
|
return ParseResult(scheme=parse.scheme, netloc=parse.netloc, path='', params='',
|
||||||
|
query='',
|
||||||
|
fragment='').geturl()
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_down_model_chunk(row_str: str, chunk_index: int):
|
||||||
|
row = json.loads(row_str)
|
||||||
|
status = DownModelChunkStatus.unknown
|
||||||
|
digest = ""
|
||||||
|
progress = 100
|
||||||
|
if 'status' in row:
|
||||||
|
digest = row.get('status')
|
||||||
|
if row.get('status') == 'success':
|
||||||
|
status = DownModelChunkStatus.success
|
||||||
|
if row.get('status').__contains__("pulling"):
|
||||||
|
status = DownModelChunkStatus.pulling
|
||||||
|
if 'total' in row and 'completed' in row:
|
||||||
|
progress = (row.get('completed') / row.get('total') * 100)
|
||||||
|
elif 'error' in row:
|
||||||
|
status = DownModelChunkStatus.error
|
||||||
|
digest = row.get('error')
|
||||||
|
return DownModelChunk(status=status, digest=digest, progress=progress, details=row_str, index=chunk_index)
|
||||||
|
|
||||||
|
|
||||||
|
def convert(response_stream) -> Iterator[DownModelChunk]:
|
||||||
|
temp = ""
|
||||||
|
index = 0
|
||||||
|
for c in response_stream:
|
||||||
|
index += 1
|
||||||
|
row_content = c.decode()
|
||||||
|
temp += row_content
|
||||||
|
if row_content.endswith('}') or row_content.endswith('\n'):
|
||||||
|
rows = [t for t in temp.split("\n") if len(t) > 0]
|
||||||
|
for row in rows:
|
||||||
|
yield convert_to_down_model_chunk(row, index)
|
||||||
|
temp = ""
|
||||||
|
|
||||||
|
if len(temp) > 0:
|
||||||
|
print(temp)
|
||||||
|
rows = [t for t in temp.split("\n") if len(t) > 0]
|
||||||
|
for row in rows:
|
||||||
|
yield convert_to_down_model_chunk(row, index)
|
||||||
|
|
||||||
|
|
||||||
class OllamaModelProvider(IModelProvider):
|
class OllamaModelProvider(IModelProvider):
|
||||||
def get_model_provide_info(self):
|
def get_model_provide_info(self):
|
||||||
return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content(
|
return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content(
|
||||||
@ -113,3 +161,21 @@ class OllamaModelProvider(IModelProvider):
|
|||||||
|
|
||||||
def get_dialogue_number(self):
|
def get_dialogue_number(self):
|
||||||
return 2
|
return 2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_base_model_list(api_base):
|
||||||
|
base_url = get_base_url(api_base)
|
||||||
|
r = requests.request(method="GET", url=f"{base_url}/api/tags")
|
||||||
|
r.raise_for_status()
|
||||||
|
return r.json()
|
||||||
|
|
||||||
|
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
|
||||||
|
api_base = model_credential.get('api_base')
|
||||||
|
base_url = get_base_url(api_base)
|
||||||
|
r = requests.request(
|
||||||
|
method="POST",
|
||||||
|
url=f"{base_url}/api/pull",
|
||||||
|
data=json.dumps({"name": model_name}).encode(),
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
return convert(r)
|
||||||
|
|||||||
@ -7,6 +7,8 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
@ -17,10 +19,36 @@ from application.models import Application
|
|||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from common.util.rsa_util import encrypt, decrypt
|
from common.util.rsa_util import encrypt, decrypt
|
||||||
from setting.models.model_management import Model
|
from setting.models.model_management import Model, Status
|
||||||
|
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
||||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
|
|
||||||
|
|
||||||
|
class ModelPullManage:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pull(model: Model, credential: Dict):
|
||||||
|
response = ModelProvideConstants[model.provider].value.down_model(model.model_type, model.model_name,
|
||||||
|
credential)
|
||||||
|
down_model_chunk = {}
|
||||||
|
timestamp = time.time()
|
||||||
|
for chunk in response:
|
||||||
|
down_model_chunk[chunk.digest] = chunk.to_dict()
|
||||||
|
if time.time() - timestamp > 5:
|
||||||
|
QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": list(down_model_chunk.values())})
|
||||||
|
timestamp = time.time()
|
||||||
|
status = Status.ERROR
|
||||||
|
message = ""
|
||||||
|
down_model_chunk_list = list(down_model_chunk.values())
|
||||||
|
for chunk in down_model_chunk_list:
|
||||||
|
if chunk.get('status') == DownModelChunkStatus.success.value:
|
||||||
|
status = Status.SUCCESS
|
||||||
|
if chunk.get('status') == DownModelChunkStatus.error.value:
|
||||||
|
message = chunk.get("digest")
|
||||||
|
QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": down_model_chunk_list, "message": message},
|
||||||
|
status=status)
|
||||||
|
|
||||||
|
|
||||||
class ModelSerializer(serializers.Serializer):
|
class ModelSerializer(serializers.Serializer):
|
||||||
class Query(serializers.Serializer):
|
class Query(serializers.Serializer):
|
||||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||||
@ -50,7 +78,10 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
if self.data.get('provider') is not None:
|
if self.data.get('provider') is not None:
|
||||||
query_params['provider'] = self.data.get('provider')
|
query_params['provider'] = self.data.get('provider')
|
||||||
|
|
||||||
return [ModelSerializer.model_to_dict(model) for model in model_query_set.filter(**query_params)]
|
return [
|
||||||
|
{'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||||
|
'model_name': model.model_name, 'status': model.status, 'meta': model.meta} for model in
|
||||||
|
model_query_set.filter(**query_params)]
|
||||||
|
|
||||||
class Edit(serializers.Serializer):
|
class Edit(serializers.Serializer):
|
||||||
user_id = serializers.CharField(required=False, error_messages=ErrMessage.uuid("用户id"))
|
user_id = serializers.CharField(required=False, error_messages=ErrMessage.uuid("用户id"))
|
||||||
@ -88,13 +119,7 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
for k in source_encryption_model_credential.keys():
|
for k in source_encryption_model_credential.keys():
|
||||||
if credential[k] == source_encryption_model_credential[k]:
|
if credential[k] == source_encryption_model_credential[k]:
|
||||||
credential[k] = source_model_credential[k]
|
credential[k] = source_model_credential[k]
|
||||||
# 校验模型认证数据
|
return credential, model_credential
|
||||||
model_credential.is_valid(
|
|
||||||
model_type,
|
|
||||||
model_name,
|
|
||||||
credential,
|
|
||||||
raise_exception=True)
|
|
||||||
return credential
|
|
||||||
|
|
||||||
class Create(serializers.Serializer):
|
class Create(serializers.Serializer):
|
||||||
user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||||
@ -124,18 +149,28 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
raise_exception=True)
|
raise_exception=True)
|
||||||
|
|
||||||
def insert(self, user_id, with_valid=False):
|
def insert(self, user_id, with_valid=False):
|
||||||
|
status = Status.SUCCESS
|
||||||
if with_valid:
|
if with_valid:
|
||||||
|
try:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
|
except AppApiException as e:
|
||||||
|
if e.code == ValidCode.model_not_fount:
|
||||||
|
status = Status.DOWNLOAD
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
credential = self.data.get('credential')
|
credential = self.data.get('credential')
|
||||||
name = self.data.get('name')
|
name = self.data.get('name')
|
||||||
provider = self.data.get('provider')
|
provider = self.data.get('provider')
|
||||||
model_type = self.data.get('model_type')
|
model_type = self.data.get('model_type')
|
||||||
model_name = self.data.get('model_name')
|
model_name = self.data.get('model_name')
|
||||||
model_credential_str = json.dumps(credential)
|
model_credential_str = json.dumps(credential)
|
||||||
model = Model(id=uuid.uuid1(), user_id=user_id, name=name,
|
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
|
||||||
credential=encrypt(model_credential_str),
|
credential=encrypt(model_credential_str),
|
||||||
provider=provider, model_type=model_type, model_name=model_name)
|
provider=provider, model_type=model_type, model_name=model_name)
|
||||||
model.save()
|
model.save()
|
||||||
|
if status == Status.DOWNLOAD:
|
||||||
|
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
|
||||||
|
thread.start()
|
||||||
return ModelSerializer.Operate(data={'id': model.id, 'user_id': user_id}).one(with_valid=True)
|
return ModelSerializer.Operate(data={'id': model.id, 'user_id': user_id}).one(with_valid=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -143,6 +178,8 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
credential = json.loads(decrypt(model.credential))
|
credential = json.loads(decrypt(model.credential))
|
||||||
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||||
'model_name': model.model_name,
|
'model_name': model.model_name,
|
||||||
|
'status': model.status,
|
||||||
|
'meta': model.meta,
|
||||||
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
|
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
|
||||||
model.model_name).encryption_dict(
|
model.model_name).encryption_dict(
|
||||||
credential)}
|
credential)}
|
||||||
@ -164,6 +201,15 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
|
model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
|
||||||
return ModelSerializer.model_to_dict(model)
|
return ModelSerializer.model_to_dict(model)
|
||||||
|
|
||||||
|
def one_meta(self, with_valid=False):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id'))
|
||||||
|
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||||
|
'model_name': model.model_name,
|
||||||
|
'status': model.status,
|
||||||
|
'meta': model.meta, }
|
||||||
|
|
||||||
def delete(self, with_valid=True):
|
def delete(self, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
@ -181,7 +227,20 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
if model is None:
|
if model is None:
|
||||||
raise AppApiException(500, '不存在的id')
|
raise AppApiException(500, '不存在的id')
|
||||||
else:
|
else:
|
||||||
credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(model=model)
|
credential, model_credential = ModelSerializer.Edit(data={**instance, 'user_id': user_id}).is_valid(
|
||||||
|
model=model)
|
||||||
|
try:
|
||||||
|
# 校验模型认证数据
|
||||||
|
model_credential.is_valid(
|
||||||
|
model.model_type,
|
||||||
|
instance.get("model_name"),
|
||||||
|
credential,
|
||||||
|
raise_exception=True)
|
||||||
|
except AppApiException as e:
|
||||||
|
if e.code == ValidCode.model_not_fount:
|
||||||
|
model.status = Status.DOWNLOAD
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
update_keys = ['credential', 'name', 'model_type', 'model_name']
|
update_keys = ['credential', 'name', 'model_type', 'model_name']
|
||||||
for update_key in update_keys:
|
for update_key in update_keys:
|
||||||
if update_key in instance and instance.get(update_key) is not None:
|
if update_key in instance and instance.get(update_key) is not None:
|
||||||
@ -191,6 +250,9 @@ class ModelSerializer(serializers.Serializer):
|
|||||||
else:
|
else:
|
||||||
model.__setattr__(update_key, instance.get(update_key))
|
model.__setattr__(update_key, instance.get(update_key))
|
||||||
model.save()
|
model.save()
|
||||||
|
if model.status == Status.DOWNLOAD:
|
||||||
|
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
|
||||||
|
thread.start()
|
||||||
return self.one(with_valid=False)
|
return self.one(with_valid=False)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -16,6 +16,7 @@ urlpatterns = [
|
|||||||
name="provider/model_form"),
|
name="provider/model_form"),
|
||||||
path('model', views.Model.as_view(), name='model'),
|
path('model', views.Model.as_view(), name='model'),
|
||||||
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
|
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
|
||||||
|
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
|
||||||
path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting')
|
path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting')
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|||||||
@ -34,6 +34,17 @@ class Model(APIView):
|
|||||||
ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
|
ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
|
||||||
with_valid=True))
|
with_valid=True))
|
||||||
|
|
||||||
|
@action(methods=['PUT'], detail=False)
|
||||||
|
@swagger_auto_schema(operation_summary="下载模型,只试用与Ollama平台",
|
||||||
|
operation_id="下载模型,只试用与Ollama平台",
|
||||||
|
request_body=ModelCreateApi.get_request_body_api()
|
||||||
|
, tags=["模型"])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_CREATE)
|
||||||
|
def put(self, request: Request):
|
||||||
|
return result.success(
|
||||||
|
ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id,
|
||||||
|
with_valid=True))
|
||||||
|
|
||||||
@action(methods=['GET'], detail=False)
|
@action(methods=['GET'], detail=False)
|
||||||
@swagger_auto_schema(operation_summary="获取模型列表",
|
@swagger_auto_schema(operation_summary="获取模型列表",
|
||||||
operation_id="获取模型列表",
|
operation_id="获取模型列表",
|
||||||
@ -46,6 +57,18 @@ class Model(APIView):
|
|||||||
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list(
|
data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list(
|
||||||
with_valid=True))
|
with_valid=True))
|
||||||
|
|
||||||
|
class ModelMeta(APIView):
|
||||||
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@action(methods=['GET'], detail=False)
|
||||||
|
@swagger_auto_schema(operation_summary="查询模型meta信息,该接口不携带认证信息",
|
||||||
|
operation_id="查询模型meta信息,该接口不携带认证信息",
|
||||||
|
tags=["模型"])
|
||||||
|
@has_permissions(PermissionConstants.MODEL_READ)
|
||||||
|
def get(self, request: Request, model_id: str):
|
||||||
|
return result.success(
|
||||||
|
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True))
|
||||||
|
|
||||||
class Operate(APIView):
|
class Operate(APIView):
|
||||||
authentication_classes = [TokenAuth]
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
|||||||
@ -106,6 +106,31 @@ const updateModel: (
|
|||||||
return put(`${prefix}/${model_id}`, request, {}, loading)
|
return put(`${prefix}/${model_id}`, request, {}, loading)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取模型详情根据模型id 包括认证信息
|
||||||
|
* @param model_id 模型id
|
||||||
|
* @param loading 加载器
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
const getModelById: (model_id: string, loading?: Ref<boolean>) => Promise<Result<Model>> = (
|
||||||
|
model_id,
|
||||||
|
loading
|
||||||
|
) => {
|
||||||
|
return get(`${prefix}/${model_id}`, {}, loading)
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* 获取模型信息不包括认证信息根据模型id
|
||||||
|
* @param model_id 模型id
|
||||||
|
* @param loading 加载器
|
||||||
|
* @returns
|
||||||
|
*/
|
||||||
|
const getModelMetaById: (model_id: string, loading?: Ref<boolean>) => Promise<Result<Model>> = (
|
||||||
|
model_id,
|
||||||
|
loading
|
||||||
|
) => {
|
||||||
|
return get(`${prefix}/${model_id}/meta`, {}, loading)
|
||||||
|
}
|
||||||
|
|
||||||
const deleteModel: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
|
const deleteModel: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
|
||||||
model_id,
|
model_id,
|
||||||
loading
|
loading
|
||||||
@ -120,5 +145,7 @@ export default {
|
|||||||
listBaseModel,
|
listBaseModel,
|
||||||
createModel,
|
createModel,
|
||||||
updateModel,
|
updateModel,
|
||||||
deleteModel
|
deleteModel,
|
||||||
|
getModelById,
|
||||||
|
getModelMetaById
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import { store } from '@/stores'
|
import { store } from '@/stores'
|
||||||
|
import { Dict } from './common'
|
||||||
interface modelRequest {
|
interface modelRequest {
|
||||||
name: string
|
name: string
|
||||||
model_type: string
|
model_type: string
|
||||||
@ -64,6 +65,14 @@ interface Model {
|
|||||||
* 供应商
|
* 供应商
|
||||||
*/
|
*/
|
||||||
provider: string
|
provider: string
|
||||||
|
/**
|
||||||
|
* 状态
|
||||||
|
*/
|
||||||
|
status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR'
|
||||||
|
/**
|
||||||
|
* 元数据
|
||||||
|
*/
|
||||||
|
meta: Dict<any>
|
||||||
}
|
}
|
||||||
interface CreateModelRequest {
|
interface CreateModelRequest {
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -18,6 +18,7 @@
|
|||||||
</template>
|
</template>
|
||||||
|
|
||||||
<DynamicsForm
|
<DynamicsForm
|
||||||
|
v-loading="formLoading"
|
||||||
v-model="form_data"
|
v-model="form_data"
|
||||||
:render_data="model_form_field"
|
:render_data="model_form_field"
|
||||||
:model="form_data"
|
:model="form_data"
|
||||||
@ -56,7 +57,7 @@
|
|||||||
@change="getModelForm($event)"
|
@change="getModelForm($event)"
|
||||||
v-loading="base_model_loading"
|
v-loading="base_model_loading"
|
||||||
style="width: 100%"
|
style="width: 100%"
|
||||||
v-model="form_data.model_name"
|
v-model="base_form_data.model_name"
|
||||||
class="m-2"
|
class="m-2"
|
||||||
placeholder="请选择基础模型"
|
placeholder="请选择基础模型"
|
||||||
filterable
|
filterable
|
||||||
@ -90,10 +91,12 @@ import type { FormField } from '@/components/dynamics-form/type'
|
|||||||
import DynamicsForm from '@/components/dynamics-form/index.vue'
|
import DynamicsForm from '@/components/dynamics-form/index.vue'
|
||||||
import type { FormRules } from 'element-plus'
|
import type { FormRules } from 'element-plus'
|
||||||
import { MsgSuccess } from '@/utils/message'
|
import { MsgSuccess } from '@/utils/message'
|
||||||
|
|
||||||
const providerValue = ref<Provider>()
|
const providerValue = ref<Provider>()
|
||||||
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
|
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
|
||||||
const emit = defineEmits(['change', 'submit'])
|
const emit = defineEmits(['change', 'submit'])
|
||||||
const loading = ref<boolean>(false)
|
const loading = ref<boolean>(false)
|
||||||
|
const formLoading = ref<boolean>(false)
|
||||||
const model_type_loading = ref<boolean>(false)
|
const model_type_loading = ref<boolean>(false)
|
||||||
const base_model_loading = ref<boolean>(false)
|
const base_model_loading = ref<boolean>(false)
|
||||||
const model_type_list = ref<Array<KeyValue<string, string>>>([])
|
const model_type_list = ref<Array<KeyValue<string, string>>>([])
|
||||||
@ -152,12 +155,12 @@ const list_base_model = (model_type: any) => {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
const open = (provider: Provider, model: Model) => {
|
const open = (provider: Provider, model: Model) => {
|
||||||
modelValue.value = model
|
ModelApi.getModelById(model.id, formLoading).then((ok) => {
|
||||||
|
modelValue.value = ok.data
|
||||||
ModelApi.listModelType(model.provider, model_type_loading).then((ok) => {
|
ModelApi.listModelType(model.provider, model_type_loading).then((ok) => {
|
||||||
model_type_list.value = ok.data
|
model_type_list.value = ok.data
|
||||||
list_base_model(model.model_type)
|
list_base_model(model.model_type)
|
||||||
})
|
})
|
||||||
|
|
||||||
providerValue.value = provider
|
providerValue.value = provider
|
||||||
|
|
||||||
base_form_data.value = {
|
base_form_data.value = {
|
||||||
@ -167,6 +170,7 @@ const open = (provider: Provider, model: Model) => {
|
|||||||
}
|
}
|
||||||
form_data.value = model.credential
|
form_data.value = model.credential
|
||||||
getModelForm(model.model_name)
|
getModelForm(model.model_name)
|
||||||
|
})
|
||||||
dialogVisible.value = true
|
dialogVisible.value = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -37,15 +37,32 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import type { Provider, Model } from '@/api/type/model'
|
import type { Provider, Model } from '@/api/type/model'
|
||||||
import ModelApi from '@/api/model'
|
import ModelApi from '@/api/model'
|
||||||
import { computed, ref } from 'vue'
|
import { computed, ref, onMounted, onBeforeUnmount } from 'vue'
|
||||||
import EditModel from '@/views/template/component/EditModel.vue'
|
import EditModel from '@/views/template/component/EditModel.vue'
|
||||||
import { MsgSuccess, MsgConfirm } from '@/utils/message'
|
import { MsgSuccess, MsgConfirm } from '@/utils/message'
|
||||||
|
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
model: Model
|
model: Model
|
||||||
provider_list: Array<Provider>
|
provider_list: Array<Provider>
|
||||||
}>()
|
}>()
|
||||||
|
const downModel = ref<Model>()
|
||||||
|
|
||||||
|
const progress = computed(() => {
|
||||||
|
if (downModel.value) {
|
||||||
|
const down_model_chunk = downModel.value.meta['down_model_chunk']
|
||||||
|
if (down_model_chunk) {
|
||||||
|
const maxObj = down_model_chunk.reduce((prev: any, current: any) => {
|
||||||
|
return (prev.index || 0) > (current.index || 0) ? prev : current
|
||||||
|
})
|
||||||
|
return maxObj.progress
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
})
|
||||||
const emit = defineEmits(['change'])
|
const emit = defineEmits(['change'])
|
||||||
const eidtModelRef = ref<InstanceType<typeof EditModel>>()
|
const eidtModelRef = ref<InstanceType<typeof EditModel>>()
|
||||||
|
let interval: any
|
||||||
const deleteModel = () => {
|
const deleteModel = () => {
|
||||||
MsgConfirm(`删除模型 `, `是否删除模型:${props.model.name} ?`, {
|
MsgConfirm(`删除模型 `, `是否删除模型:${props.model.name} ?`, {
|
||||||
confirmButtonText: '删除',
|
confirmButtonText: '删除',
|
||||||
@ -67,6 +84,34 @@ const openEditModel = () => {
|
|||||||
const icon = computed(() => {
|
const icon = computed(() => {
|
||||||
return props.provider_list.find((p) => p.provider === props.model.provider)?.icon
|
return props.provider_list.find((p) => p.provider === props.model.provider)?.icon
|
||||||
})
|
})
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 初始化轮询
|
||||||
|
*/
|
||||||
|
const initInterval = () => {
|
||||||
|
interval = setInterval(() => {
|
||||||
|
if (props.model.status === 'DOWNLOAD') {
|
||||||
|
ModelApi.getModelMetaById(props.model.id).then((ok) => {
|
||||||
|
downModel.value = ok.data
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}, 6000)
|
||||||
|
}
|
||||||
|
/**
|
||||||
|
* 关闭轮询
|
||||||
|
*/
|
||||||
|
const closeInterval = () => {
|
||||||
|
if (interval) {
|
||||||
|
clearInterval(interval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
onMounted(() => {
|
||||||
|
initInterval()
|
||||||
|
})
|
||||||
|
onBeforeUnmount(() => {
|
||||||
|
// 清除定时任务
|
||||||
|
closeInterval()
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
<style lang="scss" scoped>
|
<style lang="scss" scoped>
|
||||||
.model-card {
|
.model-card {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user