fix: chat bugs (#3308)

This commit is contained in:
shaohuzhang1 2025-06-19 14:53:24 +08:00 committed by GitHub
parent 03ec0f3fdf
commit 598b72fd12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 83 additions and 68 deletions

View File

@ -75,7 +75,7 @@ class IChatStep(IBaseChatPipelineStep):
no_references_setting = NoReferencesSetting(required=True, no_references_setting = NoReferencesSetting(required=True,
label=_("No reference segment settings")) label=_("No reference segment settings"))
user_id = serializers.UUIDField(required=True, label=_("User ID")) workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
model_setting = serializers.DictField(required=True, allow_null=True, model_setting = serializers.DictField(required=True, allow_null=True,
label=_("Model settings")) label=_("Model settings"))
@ -102,7 +102,7 @@ class IChatStep(IBaseChatPipelineStep):
chat_id, problem_text, chat_id, problem_text,
post_response_handler: PostResponseHandler, post_response_handler: PostResponseHandler,
model_id: str = None, model_id: str = None,
user_id: str = None, workspace_id: str = None,
paragraph_list=None, paragraph_list=None,
manage: PipelineManage = None, manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None, padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,

View File

@ -26,7 +26,7 @@ from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.flow.tools import Reasoning from application.flow.tools import Reasoning
from application.models import ApplicationChatUserStats, ChatUserType from application.models import ApplicationChatUserStats, ChatUserType
from models_provider.tools import get_model_instance_by_model_user_id from models_provider.tools import get_model_instance_by_model_workspace_id
def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None): def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
@ -157,7 +157,7 @@ class BaseChatStep(IChatStep):
problem_text, problem_text,
post_response_handler: PostResponseHandler, post_response_handler: PostResponseHandler,
model_id: str = None, model_id: str = None,
user_id: str = None, workspace_id: str = None,
paragraph_list=None, paragraph_list=None,
manage: PipelineManage = None, manage: PipelineManage = None,
padding_problem_text: str = None, padding_problem_text: str = None,
@ -167,8 +167,8 @@ class BaseChatStep(IChatStep):
model_params_setting=None, model_params_setting=None,
model_setting=None, model_setting=None,
**kwargs): **kwargs):
chat_model = get_model_instance_by_model_user_id(model_id, user_id, chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting) if model_id is not None else None **model_params_setting) if model_id is not None else None
if stream: if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model, return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list, paragraph_list,

View File

@ -27,7 +27,7 @@ class IResetProblemStep(IBaseChatPipelineStep):
label=_("History Questions")) label=_("History Questions"))
# 大语言模型 # 大语言模型
model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Model id")) model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Model id"))
user_id = serializers.UUIDField(required=True, label=_("User ID")) workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400, problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
label=_("Question completion prompt")) label=_("Question completion prompt"))
@ -50,6 +50,6 @@ class IResetProblemStep(IBaseChatPipelineStep):
@abstractmethod @abstractmethod
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None, def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
problem_optimization_prompt=None, problem_optimization_prompt=None,
user_id=None, workspace_id=None,
**kwargs): **kwargs):
pass pass

View File

@ -14,7 +14,7 @@ from langchain.schema import HumanMessage
from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
from application.models import ChatRecord from application.models import ChatRecord
from common.utils.split_model import flat_map from common.utils.split_model import flat_map
from models_provider.tools import get_model_instance_by_model_user_id from models_provider.tools import get_model_instance_by_model_workspace_id
prompt = _( prompt = _(
"() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag") "() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag")
@ -23,9 +23,9 @@ prompt = _(
class BaseResetProblemStep(IResetProblemStep): class BaseResetProblemStep(IResetProblemStep):
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None, def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
problem_optimization_prompt=None, problem_optimization_prompt=None,
user_id=None, workspace_id=None,
**kwargs) -> str: **kwargs) -> str:
chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id) if model_id is not None else None
if chat_model is None: if chat_model is None:
return problem_text return problem_text
start_index = len(history_chat_record) - 3 start_index = len(history_chat_record) - 3

View File

@ -44,7 +44,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"), validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
message=_("The type only supports embedding|keywords|blend"), code=500) message=_("The type only supports embedding|keywords|blend"), code=500)
], label=_("Retrieval Mode")) ], label=_("Retrieval Mode"))
user_id = serializers.UUIDField(required=True, label=_("User ID")) workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]: def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
return self.InstanceSerializer return self.InstanceSerializer
@ -58,19 +58,19 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str], def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None, search_mode: str = None,
user_id=None, workspace_id=None,
**kwargs) -> List[ParagraphPipelineModel]: **kwargs) -> List[ParagraphPipelineModel]:
""" """
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询 关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
:param similarity: 相关性 :param similarity: 相关性
:param top_n: 查询多少条 :param top_n: 查询多少条
:param problem_text: 用户问题 :param problem_text: 用户问题
:param knowledge_id_list: 需要查询的数据集id列表 :param knowledge_id_list: 需要查询的数据集id列表
:param exclude_document_id_list: 需要排除的文档id :param exclude_document_id_list: 需要排除的文档id
:param exclude_paragraph_id_list: 需要排除段落id :param exclude_paragraph_id_list: 需要排除段落id
:param padding_problem_text 补全问题 :param padding_problem_text 补全问题
:param search_mode 检索模式 :param search_mode 检索模式
:param user_id 用户id :param workspace_id 工作空间id
:return: 段落列表 :return: 段落列表
""" """
pass pass

View File

@ -25,13 +25,13 @@ from models_provider.models import Model
from models_provider.tools import get_model from models_provider.tools import get_model
def get_model_by_id(_id, user_id): def get_model_by_id(_id, workspace_id):
model = QuerySet(Model).filter(id=_id).first() model = QuerySet(Model).filter(id=_id, model_type="EMBEDDING").first()
if model is None: if model is None:
raise Exception(_("Model does not exist")) raise Exception(_("Model does not exist"))
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id): if model.workspace_id is not None:
message = lazy_format(_('No permission to use this model {model_name}'), model_name=model.name) if model.workspace_id != workspace_id:
raise Exception(message) raise Exception(_("Model does not exist"))
return model return model
@ -50,13 +50,13 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str], def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None, search_mode: str = None,
user_id=None, workspace_id=None,
**kwargs) -> List[ParagraphPipelineModel]: **kwargs) -> List[ParagraphPipelineModel]:
if len(knowledge_id_list) == 0: if len(knowledge_id_list) == 0:
return [] return []
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
model_id = get_embedding_id(knowledge_id_list) model_id = get_embedding_id(knowledge_id_list)
model = get_model_by_id(model_id, user_id) model = get_model_by_id(model_id, workspace_id)
self.context['model_name'] = model.name self.context['model_name'] = model.name
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model)) embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(exec_problem_text) embedding_value = embedding_model.embed_query(exec_problem_text)

View File

@ -11,7 +11,6 @@ import json
import re import re
import time import time
from functools import reduce from functools import reduce
from types import AsyncGeneratorType
from typing import List, Dict from typing import List, Dict
from django.db.models import QuerySet from django.db.models import QuerySet
@ -24,7 +23,7 @@ from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from application.flow.tools import Reasoning from application.flow.tools import Reasoning
from models_provider.models import Model from models_provider.models import Model
from models_provider.tools import get_model_credential, get_model_instance_by_model_user_id from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
tool_message_template = """ tool_message_template = """
<details> <details>
@ -206,8 +205,9 @@ class BaseChatNode(IChatNode):
model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '</think>', model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
'reasoning_content_start': '<think>'} 'reasoning_content_start': '<think>'}
self.context['model_setting'] = model_setting self.context['model_setting'] = model_setting
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), workspace_id = self.workflow_manage.get_body().get('workspace_id')
**model_params_setting) chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type, history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
self.runtime_node_id) self.runtime_node_id)
self.context['history_message'] = history_message self.context['history_message'] = history_message

View File

@ -9,7 +9,7 @@ from application.flow.i_step_node import NodeResult
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
from common.utils.common import bytes_to_uploaded_file from common.utils.common import bytes_to_uploaded_file
from oss.serializers.file import FileSerializer from oss.serializers.file import FileSerializer
from models_provider.tools import get_model_instance_by_model_user_id from models_provider.tools import get_model_instance_by_model_workspace_id
class BaseImageGenerateNode(IImageGenerateNode): class BaseImageGenerateNode(IImageGenerateNode):
@ -25,8 +25,9 @@ class BaseImageGenerateNode(IImageGenerateNode):
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
print(model_params_setting) print(model_params_setting)
application = self.workflow_manage.work_flow_post_handler.chat_info.application application = self.workflow_manage.work_flow_post_handler.chat_info.application
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), workspace_id = self.workflow_manage.get_body().get('workspace_id')
**model_params_setting) tti_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number) history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt) question = self.generate_prompt_question(prompt)

View File

@ -11,7 +11,7 @@ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AI
from application.flow.i_step_node import NodeResult, INode from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
from knowledge.models import File from knowledge.models import File
from models_provider.tools import get_model_instance_by_model_user_id from models_provider.tools import get_model_instance_by_model_workspace_id
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
@ -79,9 +79,9 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
# 处理不正确的参数 # 处理不正确的参数
if image is None or not isinstance(image, list): if image is None or not isinstance(image, list):
image = [] image = []
print(model_params_setting) workspace_id = self.workflow_manage.get_body().get('workspace_id')
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), image_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting) **model_params_setting)
# 执行详情中的历史消息不需要图片内容 # 执行详情中的历史消息不需要图片内容
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number) history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
self.context['history_message'] = history_message self.context['history_message'] = history_message

View File

@ -18,7 +18,7 @@ from langchain_core.messages import BaseMessage
from application.flow.i_step_node import NodeResult, INode from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.question_node.i_question_node import IQuestionNode from application.flow.step_node.question_node.i_question_node import IQuestionNode
from models_provider.models import Model from models_provider.models import Model
from models_provider.tools import get_model_instance_by_model_user_id, get_model_credential from models_provider.tools import get_model_instance_by_model_workspace_id, get_model_credential
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
@ -87,8 +87,9 @@ class BaseQuestionNode(IQuestionNode):
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
if model_params_setting is None: if model_params_setting is None:
model_params_setting = get_default_model_params_setting(model_id) model_params_setting = get_default_model_params_setting(model_id)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'), workspace_id = self.workflow_manage.get_body().get('workspace_id')
**model_params_setting) chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number) history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt) question = self.generate_prompt_question(prompt)

View File

@ -12,7 +12,7 @@ from langchain_core.documents import Document
from application.flow.i_step_node import NodeResult from application.flow.i_step_node import NodeResult
from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode
from models_provider.tools import get_model_instance_by_model_user_id from models_provider.tools import get_model_instance_by_model_workspace_id
def merge_reranker_list(reranker_list, result=None): def merge_reranker_list(reranker_list, result=None):
@ -78,8 +78,9 @@ class BaseRerankerNode(IRerankerNode):
self.context['document_list'] = [{'page_content': document.page_content, 'metadata': document.metadata} for self.context['document_list'] = [{'page_content': document.page_content, 'metadata': document.metadata} for
document in documents] document in documents]
self.context['question'] = question self.context['question'] = question
reranker_model = get_model_instance_by_model_user_id(reranker_model_id, workspace_id = self.workflow_manage.get_body().get('workspace_id')
self.flow_params_serializer.data.get('user_id'), reranker_model = get_model_instance_by_model_workspace_id(reranker_model_id,
workspace_id,
top_n=top_n) top_n=top_n)
result = reranker_model.compress_documents( result = reranker_model.compress_documents(
documents, documents,

View File

@ -19,7 +19,7 @@ from common.db.search import native_search
from common.utils.common import get_file_content from common.utils.common import get_file_content
from knowledge.models import Document, Paragraph, Knowledge, SearchMode from knowledge.models import Document, Paragraph, Knowledge, SearchMode
from maxkb.conf import PROJECT_DIR from maxkb.conf import PROJECT_DIR
from models_provider.tools import get_model_instance_by_model_user_id from models_provider.tools import get_model_instance_by_model_workspace_id
def get_embedding_id(dataset_id_list): def get_embedding_id(dataset_id_list):
@ -67,7 +67,8 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
if len(dataset_id_list) == 0: if len(dataset_id_list) == 0:
return get_none_result(question) return get_none_result(question)
model_id = get_embedding_id(dataset_id_list) model_id = get_embedding_id(dataset_id_list)
embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id')) workspace_id = self.workflow_manage.get_body().get('workspace_id')
embedding_model = get_model_instance_by_model_workspace_id(model_id, workspace_id)
embedding_value = embedding_model.embed_query(question) embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector() vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in exclude_document_id_list = [str(document.id) for document in

View File

@ -9,7 +9,7 @@ from application.flow.i_step_node import NodeResult
from application.flow.step_node.speech_to_text_step_node.i_speech_to_text_node import ISpeechToTextNode from application.flow.step_node.speech_to_text_step_node.i_speech_to_text_node import ISpeechToTextNode
from common.utils.common import split_and_transcribe, any_to_mp3 from common.utils.common import split_and_transcribe, any_to_mp3
from knowledge.models import File from knowledge.models import File
from models_provider.tools import get_model_instance_by_model_user_id from models_provider.tools import get_model_instance_by_model_workspace_id
class BaseSpeechToTextNode(ISpeechToTextNode): class BaseSpeechToTextNode(ISpeechToTextNode):
@ -20,7 +20,8 @@ class BaseSpeechToTextNode(ISpeechToTextNode):
self.answer_text = details.get('answer') self.answer_text = details.get('answer')
def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult: def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id')) workspace_id = self.workflow_manage.get_body().get('workspace_id')
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id)
audio_list = audio audio_list = audio
self.context['audio_list'] = audio self.context['audio_list'] = audio

View File

@ -6,8 +6,8 @@ from django.core.files.uploadedfile import InMemoryUploadedFile
from application.flow.i_step_node import NodeResult from application.flow.i_step_node import NodeResult
from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode
from models_provider.tools import get_model_instance_by_model_workspace_id
from oss.serializers.file import FileSerializer from oss.serializers.file import FileSerializer
from models_provider.tools import get_model_instance_by_model_user_id
def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"): def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
@ -42,8 +42,9 @@ class BaseTextToSpeechNode(ITextToSpeechNode):
content, model_params_setting=None, content, model_params_setting=None,
**kwargs) -> NodeResult: **kwargs) -> NodeResult:
self.context['content'] = content self.context['content'] = content
model = get_model_instance_by_model_user_id(tts_model_id, self.flow_params_serializer.data.get('user_id'), workspace_id = self.workflow_manage.get_body().get('workspace_id')
**model_params_setting) model = get_model_instance_by_model_workspace_id(tts_model_id, workspace_id,
**model_params_setting)
audio_byte = model.text_to_speech(content) audio_byte = model.text_to_speech(content)
# 需要把这个音频文件存储到数据库中 # 需要把这个音频文件存储到数据库中
file_name = 'generated_audio.mp3' file_name = 'generated_audio.mp3'

View File

@ -538,6 +538,7 @@ class ApplicationSerializer(serializers.Serializer):
class ApplicationOperateSerializer(serializers.Serializer): class ApplicationOperateSerializer(serializers.Serializer):
application_id = serializers.UUIDField(required=True, label=_("Application ID")) application_id = serializers.UUIDField(required=True, label=_("Application ID"))
user_id = serializers.UUIDField(required=True, label=_("User ID")) user_id = serializers.UUIDField(required=True, label=_("User ID"))
workspace_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_("Workspace ID"))
def is_valid(self, *, raise_exception=False): def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True) super().is_valid(raise_exception=True)
@ -682,7 +683,6 @@ class ApplicationOperateSerializer(serializers.Serializer):
for update_key in update_keys: for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None: if update_key in instance and instance.get(update_key) is not None:
application.__setattr__(update_key, instance.get(update_key)) application.__setattr__(update_key, instance.get(update_key))
print(application.name)
application.save() application.save()
if 'knowledge_id_list' in instance: if 'knowledge_id_list' in instance:
@ -690,11 +690,11 @@ class ApplicationOperateSerializer(serializers.Serializer):
# 当前用户可修改关联的知识库列表 # 当前用户可修改关联的知识库列表
application_knowledge_id_list = [str(knowledge.id) for knowledge in application_knowledge_id_list = [str(knowledge.id) for knowledge in
self.list_knowledge(with_valid=False)] self.list_knowledge(with_valid=False)]
for dataset_id in knowledge_id_list: for knowledge_id in knowledge_id_list:
if not application_knowledge_id_list.__contains__(dataset_id): if not application_knowledge_id_list.__contains__(knowledge_id):
message = lazy_format(_('Unknown knowledge base id {dataset_id}, unable to associate'), message = lazy_format(_('Unknown knowledge base id {dataset_id}, unable to associate'),
dataset_id=dataset_id) dataset_id=knowledge_id)
raise AppApiException(500, message) raise AppApiException(500, str(message))
self.save_application_knowledge_mapping(application_knowledge_id_list, knowledge_id_list, application_id) self.save_application_knowledge_mapping(application_knowledge_id_list, knowledge_id_list, application_id)
return self.one(with_valid=False) return self.one(with_valid=False)
@ -707,8 +707,8 @@ class ApplicationOperateSerializer(serializers.Serializer):
knowledge_list = self.list_knowledge(with_valid=False) knowledge_list = self.list_knowledge(with_valid=False)
mapping_knowledge_id_list = [akm.knowledge_id for akm in mapping_knowledge_id_list = [akm.knowledge_id for akm in
QuerySet(ApplicationKnowledgeMapping).filter(application_id=application_id)] QuerySet(ApplicationKnowledgeMapping).filter(application_id=application_id)]
knowledge_id_list = [d.get('id') for d in knowledge_id_list = [d.id for d in
list(filter(lambda row: mapping_knowledge_id_list.__contains__(row.get('id')), list(filter(lambda row: mapping_knowledge_id_list.__contains__(row.id),
knowledge_list))] knowledge_list))]
return {**ApplicationSerializerModel(application).data, return {**ApplicationSerializerModel(application).data,
'knowledge_id_list': knowledge_id_list} 'knowledge_id_list': knowledge_id_list}
@ -729,5 +729,5 @@ class ApplicationOperateSerializer(serializers.Serializer):
application_id=application_id).delete() application_id=application_id).delete()
# 插入 # 插入
QuerySet(ApplicationKnowledgeMapping).bulk_create( QuerySet(ApplicationKnowledgeMapping).bulk_create(
[ApplicationKnowledgeMapping(application_id=application_id, dataset_id=dataset_id) for dataset_id in [ApplicationKnowledgeMapping(application_id=application_id, knowledge_id=knowledge_id) for knowledge_id in
knowledge_id_list]) if len(knowledge_id_list) > 0 else None knowledge_id_list]) if len(knowledge_id_list) > 0 else None

View File

@ -98,7 +98,7 @@ class ChatInfo:
self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting, self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting,
'search_mode': self.application.knowledge_setting.get('search_mode') or 'embedding', 'search_mode': self.application.knowledge_setting.get('search_mode') or 'embedding',
'no_references_setting': self.get_no_references_setting(self.application.knowledge_setting, model_setting), 'no_references_setting': self.get_no_references_setting(self.application.knowledge_setting, model_setting),
'user_id': self.application.user_id, 'workspace_id': self.application.workspace_id,
'application_id': self.application.id 'application_id': self.application.id
} }

View File

@ -130,6 +130,7 @@ class ApplicationAPI(APIView):
def post(self, request: Request, workspace_id: str, application_id: str): def post(self, request: Request, workspace_id: str, application_id: str):
return ApplicationOperateSerializer( return ApplicationOperateSerializer(
data={'application_id': application_id, data={'application_id': application_id,
'workspace_id': workspace_id,
'user_id': request.user.id}).export(request.data) 'user_id': request.user.id}).export(request.data)
class Operate(APIView): class Operate(APIView):
@ -148,11 +149,12 @@ class ApplicationAPI(APIView):
RoleConstants.WORKSPACE_MANAGE.get_workspace_role()) RoleConstants.WORKSPACE_MANAGE.get_workspace_role())
@log(menu='Application', operate='Deleting application', @log(menu='Application', operate='Deleting application',
get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id')), get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id')),
) )
def delete(self, request: Request, workspace_id: str, application_id: str): def delete(self, request: Request, workspace_id: str, application_id: str):
return result.success(ApplicationOperateSerializer( return result.success(ApplicationOperateSerializer(
data={'application_id': application_id, 'user_id': request.user.id}).delete( data={'application_id': application_id, 'user_id': request.user.id,
'workspace_id': workspace_id, }).delete(
with_valid=True)) with_valid=True))
@extend_schema( @extend_schema(
@ -173,7 +175,8 @@ class ApplicationAPI(APIView):
def put(self, request: Request, workspace_id: str, application_id: str): def put(self, request: Request, workspace_id: str, application_id: str):
return result.success( return result.success(
ApplicationOperateSerializer( ApplicationOperateSerializer(
data={'application_id': application_id, 'user_id': request.user.id}).edit( data={'application_id': application_id, 'user_id': request.user.id,
'workspace_id': workspace_id, }).edit(
request.data)) request.data))
@extend_schema( @extend_schema(
@ -190,7 +193,8 @@ class ApplicationAPI(APIView):
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.ADMIN) RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.ADMIN)
def get(self, request: Request, workspace_id: str, application_id: str): def get(self, request: Request, workspace_id: str, application_id: str):
return result.success(ApplicationOperateSerializer( return result.success(ApplicationOperateSerializer(
data={'application_id': application_id, 'user_id': request.user.id}).one()) data={'application_id': application_id, 'user_id': request.user.id,
'workspace_id': workspace_id, }).one())
class Publish(APIView): class Publish(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]
@ -207,9 +211,10 @@ class ApplicationAPI(APIView):
) )
@log(menu='Application', operate='Publishing an application', @log(menu='Application', operate='Publishing an application',
get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id')), get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id')),
) )
def put(self, request: Request, workspace_id: str, application_id: str): def put(self, request: Request, workspace_id: str, application_id: str):
return result.success( return result.success(
ApplicationOperateSerializer( ApplicationOperateSerializer(
data={'application_id': application_id, 'user_id': request.user.id}).publish(request.data)) data={'application_id': application_id, 'user_id': request.user.id,
'workspace_id': workspace_id, }).publish(request.data))

View File

@ -366,7 +366,7 @@ class OpenChatSerializers(serializers.Serializer):
chat_user_id = self.data.get("chat_user_id") chat_user_id = self.data.get("chat_user_id")
chat_user_type = self.data.get("chat_user_type") chat_user_type = self.data.get("chat_user_type")
debug = self.data.get("debug") debug = self.data.get("debug")
knowledge_id_list = [str(row.dataset_id) for row in knowledge_id_list = [str(row.knowledge_id) for row in
QuerySet(ApplicationKnowledgeMapping).filter( QuerySet(ApplicationKnowledgeMapping).filter(
application_id=application_id)] application_id=application_id)]
chat_id = str(uuid.uuid7()) chat_id = str(uuid.uuid7())

View File

@ -103,21 +103,25 @@ def is_valid_credential(provider, model_type, model_name, model_credential: Dict
raise_exception) raise_exception)
def get_model_by_id(_id, user_id): def get_model_by_id(_id, workspace_id):
model = QuerySet(Model).filter(id=_id).first() model = QuerySet(Model).filter(id=_id).first()
# 手动关闭数据库连接 # 手动关闭数据库连接
connection.close() connection.close()
if model is None: if model is None:
raise Exception(_('Model does not exist')) raise Exception(_('Model does not exist'))
if model.workspace_id:
if model.workspace_id != workspace_id:
raise Exception(_('Model does not exist'))
return model return model
def get_model_instance_by_model_user_id(model_id, user_id, **kwargs): def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs):
""" """
获取模型实例,根据模型相关数据 获取模型实例,根据模型相关数据
@param model_id: 模型id @param model_id: 模型id
@param user_id: 用户id @param workspace_id: 工作空间id
@return: 模型实例 @return: 模型实例
""" """
model = get_model_by_id(model_id, user_id) model = get_model_by_id(model_id, workspace_id)
return ModelManage.get_model(model_id, lambda _id: get_model(model, **kwargs)) return ModelManage.get_model(model_id, lambda _id: get_model(model, **kwargs))