refactor: model

This commit is contained in:
wxg0103 2025-06-26 17:26:50 +08:00
parent d49f448a5f
commit e8f80094ce
3 changed files with 20 additions and 23 deletions

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import json import json
import os
import threading import threading
import time import time
from typing import Dict from typing import Dict
@ -11,8 +12,11 @@ from rest_framework import serializers
from django.db.models.query_utils import Q from django.db.models.query_utils import Q
from common.config.embedding_config import ModelManage from common.config.embedding_config import ModelManage
from common.database_model_manage.database_model_manage import DatabaseModelManage from common.database_model_manage.database_model_manage import DatabaseModelManage
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.utils.common import get_file_content
from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt from common.utils.rsa_util import rsa_long_encrypt, rsa_long_decrypt
from maxkb.conf import PROJECT_DIR
from models_provider.base_model_provider import ValidCode, DownModelChunkStatus from models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from models_provider.constants.model_provider_constants import ModelProvideConstants from models_provider.constants.model_provider_constants import ModelProvideConstants
from models_provider.models import Model, Status from models_provider.models import Model, Status
@ -412,27 +416,13 @@ class ModelSerializer(serializers.Serializer):
return True return True
def get_authorized_tool(tool_query_set, workspace_id, model_workspace_authorization): def get_authorized_tool(tool_query_set, workspace_id):
# 对所有工作空间拉黑的工具 model_id_list = select_list(get_file_content(
non_auths = QuerySet(model_workspace_authorization).filter( os.path.join(PROJECT_DIR, "apps", "models_provider", 'sql',
Q(workspace_id='None') & Q(authentication_type='WHITE_LIST') 'list_share_authorized_model.sql'
).values_list('model_id', flat=True) )), [workspace_id, workspace_id])
# 授权给所有工作空间的工具
all_auths = QuerySet(model_workspace_authorization).filter(
Q(workspace_id='None') & Q(authentication_type='BLACK_LIST')
).values_list('model_id', flat=True)
# 查询白名单授权的工具
white_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
workspace_id=workspace_id, authentication_type='WHITE_LIST'
).values_list('model_id', flat=True)
# 查询黑名单授权的工具
black_authorized_tool_ids = QuerySet(model_workspace_authorization).filter(
workspace_id=workspace_id, authentication_type='BLACK_LIST'
).values_list('model_id', flat=True)
tool_query_set = tool_query_set.filter( tool_query_set = tool_query_set.filter(
id__in=list(white_authorized_tool_ids) + list(all_auths) id__in=[k.get('model_id') for k in model_id_list]
).exclude(
id__in=list(black_authorized_tool_ids) + list(non_auths)
) )
return tool_query_set return tool_query_set
@ -471,8 +461,7 @@ class WorkspaceSharedModelSerializer(serializers.Serializer):
if workspace_id: if workspace_id:
model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization") model_workspace_authorization = DatabaseModelManage.get_model("model_workspace_authorization")
if model_workspace_authorization is not None: if model_workspace_authorization is not None:
queryset = get_authorized_tool(queryset, workspace_id, queryset = get_authorized_tool(queryset, workspace_id)
model_workspace_authorization=model_workspace_authorization)
for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']: for field in ['name', 'model_type', 'model_name', 'provider', 'create_user']:
value = self.data.get(field) value = self.data.get(field)

View File

@ -0,0 +1,8 @@
select model_id
from model_workspace_authorization
where case
when authentication_type = 'WHITE_LIST' then
%s = any (workspace_id_list)
else
not %s = any(workspace_id_list)
end

View File

@ -9,7 +9,7 @@ export default {
}, },
tip: { tip: {
createSuccessMessage: '创建模型成功', createSuccessMessage: '创建模型成功',
createErrorMessage: '基础信息填写错误', createErrorMessage: '基础信息填写错误',
errorMessage: '变量已存在: ', errorMessage: '变量已存在: ',
emptyMessage1: '请先选择基础信息的模型类型和基础模型', emptyMessage1: '请先选择基础信息的模型类型和基础模型',
emptyMessage2: '所选模型不支持参数设置', emptyMessage2: '所选模型不支持参数设置',