fix: 修复模型没权限使用时报错 (#825)

This commit is contained in:
shaohuzhang1 2024-07-22 11:44:38 +08:00 committed by GitHub
parent fb5ad9a06b
commit a404a5c6e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 64 additions and 75 deletions

View File

@ -53,8 +53,7 @@ class IChatStep(IBaseChatPipelineStep):
# 对话列表 # 对话列表
message_list = serializers.ListField(required=True, child=MessageField(required=True), message_list = serializers.ListField(required=True, child=MessageField(required=True),
error_messages=ErrMessage.list("对话列表")) error_messages=ErrMessage.list("对话列表"))
# 大语言模型 model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型"))
# 段落列表 # 段落列表
paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表")) paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
# 对话id # 对话id
@ -73,6 +72,8 @@ class IChatStep(IBaseChatPipelineStep):
# 未查询到引用分段 # 未查询到引用分段
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置")) no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
def is_valid(self, *, raise_exception=False): def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True) super().is_valid(raise_exception=True)
message_list: List = self.initial_data.get('message_list') message_list: List = self.initial_data.get('message_list')
@ -91,7 +92,8 @@ class IChatStep(IBaseChatPipelineStep):
def execute(self, message_list: List[BaseMessage], def execute(self, message_list: List[BaseMessage],
chat_id, problem_text, chat_id, problem_text,
post_response_handler: PostResponseHandler, post_response_handler: PostResponseHandler,
chat_model: BaseChatModel = None, model_id: str = None,
user_id: str = None,
paragraph_list=None, paragraph_list=None,
manage: PipelineManage = None, manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,

View File

@ -26,6 +26,7 @@ from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, Post
from application.models.api_key_model import ApplicationPublicAccessClient from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType from common.constants.authentication_type import AuthenticationType
from common.response import result from common.response import result
from setting.models_provider.tools import get_model_instance_by_model_user_id
def add_access_num(client_id=None, client_type=None): def add_access_num(client_id=None, client_type=None):
@ -101,7 +102,8 @@ class BaseChatStep(IChatStep):
chat_id, chat_id,
problem_text, problem_text,
post_response_handler: PostResponseHandler, post_response_handler: PostResponseHandler,
chat_model: BaseChatModel = None, model_id: str = None,
user_id: str = None,
paragraph_list=None, paragraph_list=None,
manage: PipelineManage = None, manage: PipelineManage = None,
padding_problem_text: str = None, padding_problem_text: str = None,
@ -109,6 +111,7 @@ class BaseChatStep(IChatStep):
client_id=None, client_type=None, client_id=None, client_type=None,
no_references_setting=None, no_references_setting=None,
**kwargs): **kwargs):
chat_model = get_model_instance_by_model_user_id(model_id, user_id)
if stream: if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model, return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list, paragraph_list,

View File

@ -111,7 +111,7 @@ class FlowParamsSerializer(serializers.Serializer):
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型")) client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案")) re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))

View File

@ -32,6 +32,7 @@ class IChatNode(INode):
def _run(self): def _run(self):
return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data)
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
chat_record_id,
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
pass pass

View File

@ -6,21 +6,17 @@
@date2024/6/4 14:30 @date2024/6/4 14:30
@desc: @desc:
""" """
import json
import time import time
from functools import reduce from functools import reduce
from typing import List, Dict from typing import List, Dict
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from application.flow import tools from application.flow import tools
from application.flow.i_step_node import NodeResult, INode from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from common.util.rsa_util import rsa_long_decrypt from setting.models_provider.tools import get_model_instance_by_model_user_id
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@ -125,13 +121,7 @@ def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable:
class BaseChatNode(IChatNode): class BaseChatNode(IChatNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
model = QuerySet(Model).filter(id=model_id).first() chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
if model is None:
raise Exception("模型不存在")
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(model.credential)),
streaming=True)
history_message = self.get_history_message(history_chat_record, dialogue_number) history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt) question = self.generate_prompt_question(prompt)

View File

@ -6,21 +6,17 @@
@date2024/6/4 14:30 @date2024/6/4 14:30
@desc: @desc:
""" """
import json
import time import time
from functools import reduce from functools import reduce
from typing import List, Dict from typing import List, Dict
from django.db.models import QuerySet
from langchain.schema import HumanMessage, SystemMessage from langchain.schema import HumanMessage, SystemMessage
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from application.flow import tools from application.flow import tools
from application.flow.i_step_node import NodeResult, INode from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.question_node.i_question_node import IQuestionNode from application.flow.step_node.question_node.i_question_node import IQuestionNode
from common.util.rsa_util import rsa_long_decrypt from setting.models_provider.tools import get_model_instance_by_model_user_id
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow): def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
@ -125,13 +121,7 @@ def to_response(chat_id, chat_record_id, node_variable: Dict, workflow_variable:
class BaseQuestionNode(IQuestionNode): class BaseQuestionNode(IQuestionNode):
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
model = QuerySet(Model).filter(id=model_id).first() chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
if model is None:
raise Exception("模型不存在")
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(model.credential)),
streaming=True)
history_message = self.get_history_message(history_chat_record, dialogue_number) history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt) question = self.generate_prompt_question(prompt)

View File

@ -13,25 +13,15 @@ from django.db.models import QuerySet
from application.flow.i_step_node import NodeResult from application.flow.i_step_node import NodeResult
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
from common.config.embedding_config import VectorStore, ModelManage from common.config.embedding_config import VectorStore
from common.db.search import native_search from common.db.search import native_search
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from dataset.models import Document, Paragraph, DataSet from dataset.models import Document, Paragraph, DataSet
from embedding.models import SearchMode from embedding.models import SearchMode
from setting.models import Model from setting.models_provider.tools import get_model_instance_by_model_user_id
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
def get_model_by_id(_id, user_id):
model = QuerySet(Model).filter(id=_id).first()
if model is None:
raise Exception("模型不存在")
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
raise Exception(f"无权限使用此模型:{model.name}")
return model
def get_embedding_id(dataset_id_list): def get_embedding_id(dataset_id_list):
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
@ -55,8 +45,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
if len(dataset_id_list) == 0: if len(dataset_id_list) == 0:
return get_none_result(question) return get_none_result(question)
model_id = get_embedding_id(dataset_id_list) model_id = get_embedding_id(dataset_id_list)
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id')) embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(question) embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector() vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in exclude_document_id_list = [str(document.id) for document in

View File

@ -639,10 +639,11 @@ class ApplicationSerializer(serializers.Serializer):
application.model_id = None application.model_id = None
else: else:
model = QuerySet(Model).filter( model = QuerySet(Model).filter(
id=instance.get('model_id'), id=instance.get('model_id')).first()
user_id=application.user_id).first()
if model is None: if model is None:
raise AppApiException(500, "模型不存在") raise AppApiException(500, "模型不存在")
if not model.is_permission(application.user_id):
raise AppApiException(500, f"沒有权限使用该模型:{model.name}")
if 'work_flow' in instance: if 'work_flow' in instance:
# 当前用户可修改关联的知识库列表 # 当前用户可修改关联的知识库列表
application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in

View File

@ -267,14 +267,6 @@ class ChatMessageSerializer(serializers.Serializer):
@staticmethod @staticmethod
def re_open_chat_simple(chat_id, application): def re_open_chat_simple(chat_id, application):
model = QuerySet(Model).filter(id=application.model_id).first()
chat_model = None
if model is not None:
# 对话模型
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(model.credential)),
streaming=True)
# 数据集id列表 # 数据集id列表
dataset_id_list = [str(row.dataset_id) for row in dataset_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationDatasetMapping).filter( QuerySet(ApplicationDatasetMapping).filter(
@ -285,7 +277,7 @@ class ChatMessageSerializer(serializers.Serializer):
QuerySet(Document).filter( QuerySet(Document).filter(
dataset_id__in=dataset_id_list, dataset_id__in=dataset_id_list,
is_active=False)] is_active=False)]
return ChatInfo(chat_id, chat_model, dataset_id_list, exclude_document_id_list, application) return ChatInfo(chat_id, None, dataset_id_list, exclude_document_id_list, application)
@staticmethod @staticmethod
def re_open_chat_work_flow(chat_id, application): def re_open_chat_work_flow(chat_id, application):

View File

@ -7,7 +7,6 @@
@desc: @desc:
""" """
import datetime import datetime
import json
import os import os
import re import re
import uuid import uuid
@ -38,13 +37,11 @@ from common.util.common import post
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock from common.util.lock import try_lock, un_lock
from common.util.rsa_util import rsa_long_decrypt
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers from dataset.serializers.paragraph_serializers import ParagraphSerializers
from setting.models import Model from setting.models import Model
from setting.models_provider import get_model from setting.models_provider import get_model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
chat_cache = caches['model_cache'] chat_cache = caches['model_cache']
@ -238,16 +235,12 @@ class ChatSerializers(serializers.Serializer):
def open_simple(self, application): def open_simple(self, application):
application_id = self.data.get('application_id') application_id = self.data.get('application_id')
model = QuerySet(Model).filter(id=application.model_id).first()
dataset_id_list = [str(row.dataset_id) for row in dataset_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationDatasetMapping).filter( QuerySet(ApplicationDatasetMapping).filter(
application_id=application_id)] application_id=application_id)]
chat_model = None
if model is not None:
chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model))
chat_id = str(uuid.uuid1()) chat_id = str(uuid.uuid1())
chat_cache.set(chat_id, chat_cache.set(chat_id,
ChatInfo(chat_id, chat_model, dataset_id_list, ChatInfo(chat_id, None, dataset_id_list,
[str(document.id) for document in [str(document.id) for document in
QuerySet(Document).filter( QuerySet(Document).filter(
dataset_id__in=dataset_id_list, dataset_id__in=dataset_id_list,
@ -318,24 +311,14 @@ class ChatSerializers(serializers.Serializer):
user_id = self.is_valid(raise_exception=True) user_id = self.is_valid(raise_exception=True)
chat_id = str(uuid.uuid1()) chat_id = str(uuid.uuid1())
model_id = self.data.get('model_id') model_id = self.data.get('model_id')
if model_id is not None and len(model_id) > 0:
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(
model.credential)),
streaming=True)
else:
model = None
chat_model = None
dataset_id_list = self.data.get('dataset_id_list') dataset_id_list = self.data.get('dataset_id_list')
application = Application(id=None, dialogue_number=3, model=model, application = Application(id=None, dialogue_number=3, model_id=model_id,
dataset_setting=self.data.get('dataset_setting'), dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'), model_setting=self.data.get('model_setting'),
problem_optimization=self.data.get('problem_optimization'), problem_optimization=self.data.get('problem_optimization'),
user_id=user_id) user_id=user_id)
chat_cache.set(chat_id, chat_cache.set(chat_id,
ChatInfo(chat_id, chat_model, dataset_id_list, ChatInfo(chat_id, None, dataset_id_list,
[str(document.id) for document in [str(document.id) for document in
QuerySet(Document).filter( QuerySet(Document).filter(
dataset_id__in=dataset_id_list, dataset_id__in=dataset_id_list,

View File

@ -56,6 +56,11 @@ class Model(AppModelMixin):
permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices, permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices,
default=PermissionType.PRIVATE) default=PermissionType.PRIVATE)
def is_permission(self, user_id):
    """
    Report whether the given user is allowed to use this model.

    A model whose permission type is PRIVATE may only be used by the user
    whose id matches this model's ``user_id``; a model with any other
    permission type may be used by every user.

    @param user_id: id of the user requesting access (compared as a string)
    @return: True when the user may use the model, False otherwise
    """
    # Ids may arrive as UUID objects or plain strings, so compare their
    # string forms.
    is_owner = str(user_id) == str(self.user_id)
    # Deny only when the model is private AND the caller is not its owner.
    # The previous check denied the owner and admitted everyone else.
    if self.permission_type == PermissionType.PRIVATE and not is_owner:
        return False
    return True
class Meta: class Meta:
db_table = "model" db_table = "model"
unique_together = ['name', 'user_id'] unique_together = ['name', 'user_id']

View File

@ -0,0 +1,33 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file tools.py
@date2024/7/22 11:18
@desc:
"""
from django.db.models import QuerySet
from common.config.embedding_config import ModelManage
from setting.models import Model
from setting.models_provider import get_model
def get_model_by_id(_id, user_id):
    """
    Load a model record by id and check that the user may use it.

    @param _id: id of the model record to look up
    @param user_id: id of the user who wants to use the model
    @return: the matching Model record
    @raise Exception: when no record has this id, or when the record is
                      PRIVATE and belongs to a different user
    """
    record = QuerySet(Model).filter(id=_id).first()
    if record is None:
        raise Exception("模型不存在")
    # Non-private models are open to everyone; private ones only to their
    # owner (ids are compared in string form).
    if record.permission_type != 'PRIVATE' or str(record.user_id) == str(user_id):
        return record
    raise Exception(f"无权限使用此模型:{record.name}")
def get_model_instance_by_model_user_id(model_id, user_id):
    """
    Get a model instance for the given model id on behalf of a user.

    The model record is fetched (and permission-checked) through
    get_model_by_id, then handed to ModelManage.get_model together with a
    factory that builds the instance from that record.

    @param model_id: model id
    @param user_id: user id
    @return: model instance
    """
    record = get_model_by_id(model_id, user_id)

    def build_instance(_id):
        # The id argument is ignored: the record has already been loaded
        # and checked above.
        return get_model(record)

    return ModelManage.get_model(model_id, build_instance)