feat: application chat (#3213)

This commit is contained in:
shaohuzhang1 2025-06-09 16:18:43 +08:00 committed by GitHub
parent 5f10b70e24
commit 3807cf1960
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 835 additions and 138 deletions

View File

@ -12,17 +12,17 @@ from typing import Type
from rest_framework import serializers from rest_framework import serializers
from dataset.models import Paragraph from knowledge.models import Paragraph
class ParagraphPipelineModel: class ParagraphPipelineModel:
def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str, def __init__(self, _id: str, document_id: str, knowledge_id: str, content: str, title: str, status: str,
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str, is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
hit_handling_method: str, directly_return_similarity: float, meta: dict = None): hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
self.id = _id self.id = _id
self.document_id = document_id self.document_id = document_id
self.dataset_id = dataset_id self.knowledge_id = knowledge_id
self.content = content self.content = content
self.title = title self.title = title
self.status = status, self.status = status,
@ -39,7 +39,7 @@ class ParagraphPipelineModel:
return { return {
'id': self.id, 'id': self.id,
'document_id': self.document_id, 'document_id': self.document_id,
'dataset_id': self.dataset_id, 'knowledge_id': self.knowledge_id,
'content': self.content, 'content': self.content,
'title': self.title, 'title': self.title,
'status': self.status, 'status': self.status,
@ -66,7 +66,7 @@ class ParagraphPipelineModel:
if isinstance(paragraph, Paragraph): if isinstance(paragraph, Paragraph):
self.paragraph = {'id': paragraph.id, self.paragraph = {'id': paragraph.id,
'document_id': paragraph.document_id, 'document_id': paragraph.document_id,
'dataset_id': paragraph.dataset_id, 'knowledge_id': paragraph.knowledge_id,
'content': paragraph.content, 'content': paragraph.content,
'title': paragraph.title, 'title': paragraph.title,
'status': paragraph.status, 'status': paragraph.status,
@ -106,7 +106,7 @@ class ParagraphPipelineModel:
def build(self): def build(self):
return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')), return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
str(self.paragraph.get('dataset_id')), str(self.paragraph.get('knowledge_id')),
self.paragraph.get('content'), self.paragraph.get('title'), self.paragraph.get('content'), self.paragraph.get('title'),
self.paragraph.get('status'), self.paragraph.get('status'),
self.paragraph.get('is_active'), self.paragraph.get('is_active'),

View File

@ -44,7 +44,7 @@ class PostResponseHandler:
@abstractmethod @abstractmethod
def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str, def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
answer_text, answer_text,
manage, step, padding_problem_text: str = None, client_id=None, **kwargs): manage, step, padding_problem_text: str = None, **kwargs):
pass pass
@ -68,8 +68,9 @@ class IChatStep(IBaseChatPipelineStep):
label=_("Completion Question")) label=_("Completion Question"))
# 是否使用流的形式输出 # 是否使用流的形式输出
stream = serializers.BooleanField(required=False, label=_("Streaming Output")) stream = serializers.BooleanField(required=False, label=_("Streaming Output"))
client_id = serializers.CharField(required=True, label=_("Client id")) chat_user_id = serializers.CharField(required=True, label=_("Chat user id"))
client_type = serializers.CharField(required=True, label=_("Client Type"))
chat_user_type = serializers.CharField(required=True, label=_("Chat user Type"))
# 未查询到引用分段 # 未查询到引用分段
no_references_setting = NoReferencesSetting(required=True, no_references_setting = NoReferencesSetting(required=True,
label=_("No reference segment settings")) label=_("No reference segment settings"))
@ -104,6 +105,6 @@ class IChatStep(IBaseChatPipelineStep):
user_id: str = None, user_id: str = None,
paragraph_list=None, paragraph_list=None,
manage: PipelineManage = None, manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs): no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
pass pass

View File

@ -25,15 +25,16 @@ from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineMode
from application.chat_pipeline.pipeline_manage import PipelineManage from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.flow.tools import Reasoning from application.flow.tools import Reasoning
from application.models.application_api_key import ApplicationPublicAccessClient from application.models import ApplicationChatUserStats, ChatUserType
from common.constants.authentication_type import AuthenticationType
from models_provider.tools import get_model_instance_by_model_user_id from models_provider.tools import get_model_instance_by_model_user_id
def add_access_num(client_id=None, client_type=None, application_id=None): def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value and application_id is not None: if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
application_public_access_client = (QuerySet(ApplicationPublicAccessClient).filter(client_id=client_id, chat_user_type) and application_id is not None:
application_id=application_id) application_public_access_client = (QuerySet(ApplicationChatUserStats).filter(chat_user_id=chat_user_id,
chat_user_type=chat_user_type,
application_id=application_id)
.first()) .first())
if application_public_access_client is not None: if application_public_access_client is not None:
application_public_access_client.access_num = application_public_access_client.access_num + 1 application_public_access_client.access_num = application_public_access_client.access_num + 1
@ -124,11 +125,9 @@ def event_content(response,
request_token = 0 request_token = 0
response_token = 0 response_token = 0
write_context(step, manage, request_token, response_token, all_text) write_context(step, manage, request_token, response_token, all_text)
asker = manage.context.get('form_data', {}).get('asker', None)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text, client_id, all_text, manage, step, padding_problem_text,
reasoning_content=reasoning_content if reasoning_content_enable else '' reasoning_content=reasoning_content if reasoning_content_enable else '')
, asker=asker)
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node', yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], '', True, [], '', True,
request_token, response_token, request_token, response_token,
@ -139,10 +138,8 @@ def event_content(response,
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
all_text = 'Exception:' + str(e) all_text = 'Exception:' + str(e)
write_context(step, manage, 0, 0, all_text) write_context(step, manage, 0, 0, all_text)
asker = manage.context.get('form_data', {}).get('asker', None)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text, client_id, reasoning_content='', all_text, manage, step, padding_problem_text, reasoning_content='')
asker=asker)
add_access_num(client_id, client_type, manage.context.get('application_id')) add_access_num(client_id, client_type, manage.context.get('application_id'))
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node', yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], all_text, [], all_text,
@ -165,7 +162,7 @@ class BaseChatStep(IChatStep):
manage: PipelineManage = None, manage: PipelineManage = None,
padding_problem_text: str = None, padding_problem_text: str = None,
stream: bool = True, stream: bool = True,
client_id=None, client_type=None, chat_user_id=None, chat_user_type=None,
no_references_setting=None, no_references_setting=None,
model_params_setting=None, model_params_setting=None,
model_setting=None, model_setting=None,
@ -175,12 +172,13 @@ class BaseChatStep(IChatStep):
if stream: if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model, return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list, paragraph_list,
manage, padding_problem_text, client_id, client_type, no_references_setting, manage, padding_problem_text, chat_user_id, chat_user_type,
no_references_setting,
model_setting) model_setting)
else: else:
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model, return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list, paragraph_list,
manage, padding_problem_text, client_id, client_type, no_references_setting, manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
model_setting) model_setting)
def get_details(self, manage, **kwargs): def get_details(self, manage, **kwargs):
@ -235,7 +233,7 @@ class BaseChatStep(IChatStep):
paragraph_list=None, paragraph_list=None,
manage: PipelineManage = None, manage: PipelineManage = None,
padding_problem_text: str = None, padding_problem_text: str = None,
client_id=None, client_type=None, chat_user_id=None, chat_user_type=None,
no_references_setting=None, no_references_setting=None,
model_setting=None): model_setting=None):
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list, chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
@ -244,7 +242,8 @@ class BaseChatStep(IChatStep):
r = StreamingHttpResponse( r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
post_response_handler, manage, self, chat_model, message_list, problem_text, post_response_handler, manage, self, chat_model, message_list, problem_text,
padding_problem_text, client_id, client_type, is_ai_chat, model_setting), padding_problem_text, chat_user_id, chat_user_type, is_ai_chat,
model_setting),
content_type='text/event-stream;charset=utf-8') content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache' r['Cache-Control'] = 'no-cache'

View File

@ -15,7 +15,7 @@ from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineMode
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \ from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
IGenerateHumanMessageStep IGenerateHumanMessageStep
from application.models import ChatRecord from application.models import ChatRecord
from common.util.split_model import flat_map from common.utils.common import flat_map
class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep): class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):

View File

@ -26,8 +26,8 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
padding_problem_text = serializers.CharField(required=False, padding_problem_text = serializers.CharField(required=False,
label=_("System completes question text")) label=_("System completes question text"))
# 需要查询的数据集id列表 # 需要查询的数据集id列表
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), knowledge_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
label=_("Dataset id list")) label=_("Dataset id list"))
# 需要排除的文档id # 需要排除的文档id
exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
label=_("List of document ids to exclude")) label=_("List of document ids to exclude"))
@ -55,7 +55,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
self.context['paragraph_list'] = paragraph_list self.context['paragraph_list'] = paragraph_list
@abstractmethod @abstractmethod
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None, search_mode: str = None,
user_id=None, user_id=None,
@ -65,7 +65,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
:param similarity: 相关性 :param similarity: 相关性
:param top_n: 查询多少条 :param top_n: 查询多少条
:param problem_text: 用户问题 :param problem_text: 用户问题
:param dataset_id_list: 需要查询的数据集id列表 :param knowledge_id_list: 需要查询的数据集id列表
:param exclude_document_id_list: 需要排除的文档id :param exclude_document_id_list: 需要排除的文档id
:param exclude_paragraph_id_list: 需要排除段落id :param exclude_paragraph_id_list: 需要排除段落id
:param padding_problem_text 补全问题 :param padding_problem_text 补全问题

View File

@ -35,42 +35,33 @@ def get_model_by_id(_id, user_id):
return model return model
def get_embedding_id(dataset_id_list): def get_embedding_id(knowledge_id_list):
<<<<<<< Updated upstream:apps/chat/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
if len(set([dataset.embedding_model_id for dataset in dataset_list])) > 1:
raise Exception(_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
if len(dataset_list) == 0:
raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base"))
return dataset_list[0].embedding_model_id
=======
knowledge_list = QuerySet(Knowledge).filter(id__in=dataset_id_list)
if len(set([knowledge.embedding_mode_id for knowledge in knowledge_list])) > 1: if len(set([knowledge.embedding_mode_id for knowledge in knowledge_list])) > 1:
raise Exception( raise Exception(
_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled.")) _("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
if len(knowledge_list) == 0: if len(knowledge_list) == 0:
raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base")) raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base"))
return knowledge_list[0].embedding_mode_id return knowledge_list[0].embedding_mode_id
>>>>>>> Stashed changes:apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
class BaseSearchDatasetStep(ISearchDatasetStep): class BaseSearchDatasetStep(ISearchDatasetStep):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None, search_mode: str = None,
user_id=None, user_id=None,
**kwargs) -> List[ParagraphPipelineModel]: **kwargs) -> List[ParagraphPipelineModel]:
if len(dataset_id_list) == 0: if len(knowledge_id_list) == 0:
return [] return []
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
model_id = get_embedding_id(dataset_id_list) model_id = get_embedding_id(knowledge_id_list)
model = get_model_by_id(model_id, user_id) model = get_model_by_id(model_id, user_id)
self.context['model_name'] = model.name self.context['model_name'] = model.name
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model)) embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(exec_problem_text) embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector() vector = VectorStore.get_embedding_vector()
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list, embedding_list = vector.query(exec_problem_text, embedding_value, knowledge_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode)) exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
if embedding_list is None: if embedding_list is None:
return [] return []

View File

@ -18,8 +18,8 @@ from rest_framework import serializers
from rest_framework.exceptions import ValidationError, ErrorDetail from rest_framework.exceptions import ValidationError, ErrorDetail
from application.flow.common import Answer, NodeChunk from application.flow.common import Answer, NodeChunk
from application.models import ChatRecord from application.models import ChatRecord, ChatUserType
from application.models import ApplicationChatClientStats from application.models import ApplicationChatUserStats
from common.constants.authentication_type import AuthenticationType from common.constants.authentication_type import AuthenticationType
from common.field.common import InstanceField from common.field.common import InstanceField
@ -45,10 +45,10 @@ def is_interrupt(node, step_variable: Dict, global_variable: Dict):
class WorkFlowPostHandler: class WorkFlowPostHandler:
def __init__(self, chat_info, client_id, client_type): def __init__(self, chat_info, chat_user_id, chat_user_type):
self.chat_info = chat_info self.chat_info = chat_info
self.client_id = client_id self.chat_user_id = chat_user_id
self.client_type = client_type self.chat_user_type = chat_user_type
def handler(self, chat_id, def handler(self, chat_id,
chat_record_id, chat_record_id,
@ -84,13 +84,13 @@ class WorkFlowPostHandler:
run_time=time.time() - workflow.context['start_time'], run_time=time.time() - workflow.context['start_time'],
index=0) index=0)
asker = workflow.context.get('asker', None) asker = workflow.context.get('asker', None)
self.chat_info.append_chat_record(chat_record, self.client_id, asker) self.chat_info.append_chat_record(chat_record)
# 重新设置缓存 self.chat_info.set_cahce()
chat_cache.set(chat_id, if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
self.chat_info, timeout=60 * 30) self.chat_user_type):
if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: application_public_access_client = (QuerySet(ApplicationChatUserStats)
application_public_access_client = (QuerySet(ApplicationChatClientStats) .filter(chat_user_id=self.chat_user_id,
.filter(client_id=self.client_id, chat_user_type=self.chat_user_type,
application_id=self.chat_info.application.id).first()) application_id=self.chat_info.application.id).first())
if application_public_access_client is not None: if application_public_access_client is not None:
application_public_access_client.access_num = application_public_access_client.access_num + 1 application_public_access_client.access_num = application_public_access_client.access_num + 1

View File

@ -0,0 +1,56 @@
# Generated by Django 5.2 on 2025-06-09 05:55
import uuid
import uuid_utils.compat
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0003_applicationaccesstoken_show_exec_chat_client_type_and_more'),
]
operations = [
migrations.RemoveIndex(
model_name='applicationchatclientstats',
name='application_applica_f89647_idx',
),
migrations.RenameField(
model_name='chat',
old_name='client_id',
new_name='chat_user_id',
),
migrations.RenameField(
model_name='chat',
old_name='client_type',
new_name='chat_user_type',
),
migrations.RemoveField(
model_name='applicationchatclientstats',
name='client_id',
),
migrations.RemoveField(
model_name='applicationchatclientstats',
name='client_type',
),
migrations.AddField(
model_name='applicationchatclientstats',
name='chat_user_id',
field=models.UUIDField(default=uuid_utils.compat.uuid7, verbose_name='对话用户id'),
),
migrations.AddField(
model_name='applicationchatclientstats',
name='chat_user_type',
field=models.CharField(choices=[('ANONYMOUS_USER', '匿名用户'), ('CHAT_USER', '对话用户'), ('SYSTEM_API_KEY', '系统API_KEY'), ('APPLICATION_API_KEY', '应用API_KEY')], default='ANONYMOUS_USER', max_length=64, verbose_name='对话用户类型'),
),
migrations.AlterField(
model_name='chat',
name='id',
field=models.UUIDField(default=uuid.UUID('01975341-b4e8-7d52-913b-1bb67d7d8107'), editable=False, primary_key=True, serialize=False, verbose_name='主键id'),
),
migrations.AddIndex(
model_name='applicationchatclientstats',
index=models.Index(fields=['application_id', 'chat_user_id'], name='application_applica_23b4d2_idx'),
),
]

View File

@ -0,0 +1,32 @@
# Generated by Django 5.2 on 2025-06-09 07:31
import uuid
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0004_remove_applicationchatclientstats_application_applica_f89647_idx_and_more'),
]
operations = [
migrations.RenameModel(
old_name='ApplicationChatClientStats',
new_name='ApplicationChatUserStats',
),
migrations.RenameIndex(
model_name='applicationchatuserstats',
new_name='application_applica_1652ba_idx',
old_name='application_applica_23b4d2_idx',
),
migrations.AlterField(
model_name='chat',
name='id',
field=models.UUIDField(default=uuid.UUID('01975399-efa5-7dc3-8f97-edc67332ed24'), editable=False, primary_key=True, serialize=False, verbose_name='主键id'),
),
migrations.AlterModelTable(
name='applicationchatuserstats',
table='application_chat_user_stats',
),
]

View File

@ -17,7 +17,7 @@ from common.encoder.encoder import SystemEncoder
from common.mixins.app_model_mixin import AppModelMixin from common.mixins.app_model_mixin import AppModelMixin
class ClientType(models.TextChoices): class ChatUserType(models.TextChoices):
ANONYMOUS_USER = "ANONYMOUS_USER", '匿名用户' ANONYMOUS_USER = "ANONYMOUS_USER", '匿名用户'
CHAT_USER = "CHAT_USER", "对话用户" CHAT_USER = "CHAT_USER", "对话用户"
SYSTEM_API_KEY = "SYSTEM_API_KEY", "系统API_KEY" SYSTEM_API_KEY = "SYSTEM_API_KEY", "系统API_KEY"
@ -28,9 +28,9 @@ class Chat(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7(), editable=False, verbose_name="主键id") id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7(), editable=False, verbose_name="主键id")
application = models.ForeignKey(Application, on_delete=models.CASCADE) application = models.ForeignKey(Application, on_delete=models.CASCADE)
abstract = models.CharField(max_length=1024, verbose_name="摘要") abstract = models.CharField(max_length=1024, verbose_name="摘要")
client_id = models.UUIDField(verbose_name="客户端id", default=None, null=True) chat_user_id = models.UUIDField(verbose_name="客户端id", default=None, null=True)
client_type = models.CharField(max_length=64, verbose_name="客户端类型", choices=ClientType.choices, chat_user_type = models.CharField(max_length=64, verbose_name="客户端类型", choices=ChatUserType.choices,
default=ClientType.ANONYMOUS_USER) default=ChatUserType.ANONYMOUS_USER)
is_deleted = models.BooleanField(verbose_name="逻辑删除", default=False) is_deleted = models.BooleanField(verbose_name="逻辑删除", default=False)
class Meta: class Meta:
@ -86,17 +86,17 @@ class ChatRecord(AppModelMixin):
db_table = "application_chat_record" db_table = "application_chat_record"
class ApplicationChatClientStats(AppModelMixin): class ApplicationChatUserStats(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
client_id = models.UUIDField(max_length=128, default=uuid.uuid7, verbose_name="公共访问链接客户端id") chat_user_id = models.UUIDField(max_length=128, default=uuid.uuid7, verbose_name="对话用户id")
client_type = models.CharField(max_length=64, verbose_name="客户端类型", choices=ClientType.choices, chat_user_type = models.CharField(max_length=64, verbose_name="对话用户类型", choices=ChatUserType.choices,
default=ClientType.ANONYMOUS_USER) default=ChatUserType.ANONYMOUS_USER)
application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id") application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id")
access_num = models.IntegerField(default=0, verbose_name="访问总次数次数") access_num = models.IntegerField(default=0, verbose_name="访问总次数次数")
intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数") intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数")
class Meta: class Meta:
db_table = "application_chat_client_stats" db_table = "application_chat_user_stats"
indexes = [ indexes = [
models.Index(fields=['application_id', 'client_id']), models.Index(fields=['application_id', 'chat_user_id']),
] ]

View File

@ -0,0 +1,137 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file common.py
@date2025/6/9 13:42
@desc:
"""
from datetime import datetime
from typing import List
from django.core.cache import cache
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
from application.models import Application, WorkFlowVersion, ChatRecord, Chat
from common.constants.cache_version import Cache_Version
from models_provider.models import Model
from models_provider.tools import get_model_credential
class ChatInfo:
def __init__(self,
chat_id: str,
chat_user_id: str,
chat_user_type: str,
knowledge_id_list: List[str],
exclude_document_id_list: list[str],
application: Application,
work_flow_version: WorkFlowVersion = None):
"""
:param chat_id: 对话id
:param chat_user_id 对话用户id
:param chat_user_type 对话用户类型
:param knowledge_id_list: 知识库列表
:param exclude_document_id_list: 排除的文档
:param application: 应用信息
"""
self.chat_id = chat_id
self.chat_user_id = chat_user_id
self.chat_user_type = chat_user_type
self.application = application
self.knowledge_id_list = knowledge_id_list
self.exclude_document_id_list = exclude_document_id_list
self.chat_record_list: List[ChatRecord] = []
self.work_flow_version = work_flow_version
@staticmethod
def get_no_references_setting(knowledge_setting, model_setting):
no_references_setting = knowledge_setting.get(
'no_references_setting', {
'status': 'ai_questioning',
'value': '{question}'})
if no_references_setting.get('status') == 'ai_questioning':
no_references_prompt = model_setting.get('no_references_prompt', '{question}')
no_references_setting['value'] = no_references_prompt if len(no_references_prompt) > 0 else "{question}"
return no_references_setting
def to_base_pipeline_manage_params(self):
knowledge_setting = self.application.knowledge_setting
model_setting = self.application.model_setting
model_id = self.application.model.id if self.application.model is not None else None
model_params_setting = None
if model_id is not None:
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
model_params_setting = credential.get_model_params_setting_form(model.model_name).get_default_form_data()
return {
'knowledge_id_list': self.knowledge_id_list,
'exclude_document_id_list': self.exclude_document_id_list,
'exclude_paragraph_id_list': [],
'top_n': knowledge_setting.get('top_n') or 3,
'similarity': knowledge_setting.get('similarity') or 0.6,
'max_paragraph_char_number': knowledge_setting.get('max_paragraph_char_number') or 5000,
'history_chat_record': self.chat_record_list,
'chat_id': self.chat_id,
'dialogue_number': self.application.dialogue_number,
'problem_optimization_prompt': self.application.problem_optimization_prompt if self.application.problem_optimization_prompt is not None and len(
self.application.problem_optimization_prompt) > 0 else _(
"() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag"),
'prompt': model_setting.get(
'prompt') if 'prompt' in model_setting and len(model_setting.get(
'prompt')) > 0 else Application.get_default_model_prompt(),
'system': model_setting.get(
'system', None),
'model_id': model_id,
'problem_optimization': self.application.problem_optimization,
'stream': True,
'model_setting': model_setting,
'model_params_setting': model_params_setting if self.application.model_params_setting is None or len(
self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting,
'search_mode': self.application.knowledge_setting.get('search_mode') or 'embedding',
'no_references_setting': self.get_no_references_setting(self.application.knowledge_setting, model_setting),
'user_id': self.application.user_id,
'application_id': self.application.id
}
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
exclude_paragraph_id_list, chat_user_id: str, chat_user_type, stream=True,
form_data=None):
if form_data is None:
form_data = {}
params = self.to_base_pipeline_manage_params()
return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'chat_user_id': chat_user_id,
'chat_user_type': chat_user_type, 'form_data': form_data}
def append_chat_record(self, chat_record: ChatRecord):
chat_record.problem_text = chat_record.problem_text[0:10240] if chat_record.problem_text is not None else ""
chat_record.answer_text = chat_record.answer_text[0:40960] if chat_record.problem_text is not None else ""
is_save = True
# 存入缓存中
for index in range(len(self.chat_record_list)):
record = self.chat_record_list[index]
if record.id == chat_record.id:
self.chat_record_list[index] = chat_record
is_save = False
if is_save:
self.chat_record_list.append(chat_record)
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
timeout=60 * 30)
if self.application.id is not None:
Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024],
chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type).save()
else:
QuerySet(Chat).filter(id=self.chat_id).update(update_time=datetime.now())
# 插入会话记录
chat_record.save()
def set_cache(self):
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
timeout=60 * 30)
@staticmethod
def get_cache(chat_id):
return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version())

29
apps/chat/api/chat_api.py Normal file
View File

@ -0,0 +1,29 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file chat_api.py
@date2025/6/9 15:23
@desc:
"""
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter
from chat.serializers.chat import ChatMessageSerializers
from common.mixins.api_mixin import APIMixin
class ChatAPI(APIMixin):
@staticmethod
def get_parameters():
return [OpenApiParameter(
name="chat_id",
description="对话id",
type=OpenApiTypes.STR,
location='path',
required=True,
)]
@staticmethod
def get_request():
return ChatMessageSerializers

View File

@ -6,14 +6,19 @@
@date2025/6/6 19:59 @date2025/6/6 19:59
@desc: @desc:
""" """
from chat.serializers.chat_authentication import AuthenticationSerializer
from django.utils.translation import gettext_lazy as _
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter
from chat.serializers.chat_authentication import AnonymousAuthenticationSerializer
from common.mixins.api_mixin import APIMixin from common.mixins.api_mixin import APIMixin
class ChatAuthenticationAPI(APIMixin): class ChatAuthenticationAPI(APIMixin):
@staticmethod @staticmethod
def get_request(): def get_request():
return AuthenticationSerializer() return AnonymousAuthenticationSerializer
@staticmethod @staticmethod
def get_parameters(): def get_parameters():
@ -22,3 +27,35 @@ class ChatAuthenticationAPI(APIMixin):
@staticmethod @staticmethod
def get_response(): def get_response():
pass pass
class ChatAuthenticationProfileAPI(APIMixin):
@staticmethod
def get_parameters():
return [OpenApiParameter(
name="access_token",
description=_("access_token"),
type=OpenApiTypes.STR,
location='query',
required=True,
)]
class ChatOpenAPI(APIMixin):
@staticmethod
def get_parameters():
return [OpenApiParameter(
name="workspace_id",
description="工作空间id",
type=OpenApiTypes.STR,
location='path',
required=True,
),
OpenApiParameter(
name="application_id",
description="应用id",
type=OpenApiTypes.STR,
location='path',
required=True,
)]

View File

@ -0,0 +1,346 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file chat.py
@date2025/6/9 11:23
@desc:
"""
from gettext import gettext
from typing import List
import uuid_utils.compat as uuid
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep
from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
BaseGenerateHumanMessageStep
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
from application.flow.common import Answer
from application.flow.i_step_node import WorkFlowPostHandler
from application.flow.workflow_manage import WorkflowManage, Flow
from application.models import Application, ApplicationTypeChoices, WorkFlowVersion, ApplicationKnowledgeMapping, \
ChatUserType, ApplicationChatUserStats, ApplicationAccessToken, ChatRecord, Chat
from application.serializers.common import ChatInfo
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed, ChatException
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.system_to_response import SystemToResponse
from common.utils.common import flat_map
from knowledge.models import Document, Paragraph
from models_provider.models import Model, Status
class ChatMessageSerializers(serializers.Serializer):
message = serializers.CharField(required=True, label=_("User Questions"))
stream = serializers.BooleanField(required=True,
label=_("Is the answer in streaming mode"))
re_chat = serializers.BooleanField(required=True, label=_("Do you want to reply again"))
chat_record_id = serializers.UUIDField(required=False, allow_null=True,
label=_("Conversation record id"))
node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
label=_("Node id"))
runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
label=_("Runtime node id"))
node_data = serializers.DictField(required=False, allow_null=True,
label=_("Node parameters"))
form_data = serializers.DictField(required=False, label=_("Global variables"))
image_list = serializers.ListField(required=False, label=_("picture"))
document_list = serializers.ListField(required=False, label=_("document"))
audio_list = serializers.ListField(required=False, label=_("Audio"))
other_list = serializers.ListField(required=False, label=_("Other"))
child_node = serializers.DictField(required=False, allow_null=True,
label=_("Child Nodes"))
def get_post_handler(chat_info: ChatInfo):
class PostHandler(PostResponseHandler):
def handler(self,
chat_id,
chat_record_id,
paragraph_list: List[Paragraph],
problem_text: str,
answer_text,
manage: PipelineManage,
step: BaseChatStep,
padding_problem_text: str = None,
**kwargs):
answer_list = [[Answer(answer_text, 'ai-chat-node', 'ai-chat-node', 'ai-chat-node', {}, 'ai-chat-node',
kwargs.get('reasoning_content', '')).to_dict()]]
chat_record = ChatRecord(id=chat_record_id,
chat_id=chat_id,
problem_text=problem_text,
answer_text=answer_text,
details=manage.get_details(),
message_tokens=manage.context['message_tokens'],
answer_tokens=manage.context['answer_tokens'],
answer_text_list=answer_list,
run_time=manage.context['run_time'],
index=len(chat_info.chat_record_list) + 1)
chat_info.append_chat_record(chat_record)
# 重新设置缓存
chat_info.set_cache()
return PostHandler()
class ChatSerializers(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
chat_user_id = serializers.CharField(required=True, label=_("Client id"))
chat_user_type = serializers.CharField(required=True, label=_("Client Type"))
application_id = serializers.UUIDField(required=True, allow_null=True,
label=_("Application ID"))
def is_valid_application_workflow(self, *, raise_exception=False):
self.is_valid_intraday_access_num()
def is_valid_chat_id(self, chat_info: ChatInfo):
if self.data.get('application_id') is not None and self.data.get('application_id') != str(
chat_info.application.id):
raise ChatException(500, _("Conversation does not exist"))
def is_valid_intraday_access_num(self):
if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
self.data.get('chat_user_type')):
access_client = QuerySet(ApplicationChatUserStats).filter(chat_user_id=self.data.get('chat_user_id'),
application_id=self.data.get(
'application_id')).first()
if access_client is None:
access_client = ApplicationChatUserStats(chat_user_id=self.data.get('chat_user_id'),
chat_user_type=self.data.get('chat_user_type'),
application_id=self.data.get('application_id'),
access_num=0,
intraday_access_num=0)
access_client.save()
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=self.data.get('application_id')).first()
if application_access_token.access_num <= access_client.intraday_access_num:
raise AppChatNumOutOfBoundsFailed(1002, _("The number of visits exceeds today's visits"))
def is_valid_application_simple(self, *, chat_info: ChatInfo, raise_exception=False):
self.is_valid_intraday_access_num()
model = chat_info.application.model
if model is None:
return chat_info
model = QuerySet(Model).filter(id=model.id).first()
if model is None:
return chat_info
if model.status == Status.ERROR:
raise ChatException(500, _("The current model is not available"))
if model.status == Status.DOWNLOAD:
raise ChatException(500, _("The model is downloading, please try again later"))
return chat_info
def chat_simple(self, chat_info: ChatInfo, instance, base_to_response):
message = instance.get('message')
re_chat = instance.get('re_chat')
stream = instance.get('stream')
chat_user_id = self.data.get('chat_user_id')
chat_user_type = self.data.get('chat_user_type')
form_data = instance.get("form_data")
pipeline_manage_builder = PipelineManage.builder()
# 如果开启了问题优化,则添加上问题优化步骤
if chat_info.application.problem_optimization:
pipeline_manage_builder.append_step(BaseResetProblemStep)
# 构建流水线管理器
pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep)
.append_step(BaseGenerateHumanMessageStep)
.append_step(BaseChatStep)
.add_base_to_response(base_to_response)
.build())
exclude_paragraph_id_list = []
# 相同问题是否需要排除已经查询到的段落
if re_chat:
paragraph_id_list = flat_map(
[[paragraph.get('id') for paragraph in chat_record.details['search_step']['paragraph_list']] for
chat_record in chat_info.chat_record_list if
chat_record.problem_text == message and 'search_step' in chat_record.details and 'paragraph_list' in
chat_record.details['search_step']])
exclude_paragraph_id_list = list(set(paragraph_id_list))
# 构建运行参数
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
chat_user_id, chat_user_type, stream, form_data)
# 运行流水线作业
pipeline_message.run(params)
return pipeline_message.context['chat_result']
@staticmethod
def get_chat_record(chat_info, chat_record_id):
if chat_info is not None:
chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
str(chat_record.id) == str(chat_record_id)]
if chat_record_list is not None and len(chat_record_list):
return chat_record_list[-1]
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_info.chat_id).first()
if chat_record is None:
raise ChatException(500, _("Conversation record does not exist"))
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id).first()
return chat_record
def chat_work_flow(self, chat_info: ChatInfo, instance: dict, base_to_response):
message = self.data.get('message')
re_chat = self.data.get('re_chat')
stream = self.data.get('stream')
chat_user_id = instance.get('chat_user_id')
chat_user_type = instance.get('chat_user_type')
form_data = self.data.get('form_data')
image_list = self.data.get('image_list')
document_list = self.data.get('document_list')
audio_list = self.data.get('audio_list')
other_list = self.data.get('other_list')
user_id = chat_info.application.user_id
chat_record_id = self.data.get('chat_record_id')
chat_record = None
history_chat_record = chat_info.chat_record_list
if chat_record_id is not None:
chat_record = self.get_chat_record(chat_info, chat_record_id)
history_chat_record = [r for r in chat_info.chat_record_list if str(r.id) != chat_record_id]
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
{'history_chat_record': history_chat_record, 'question': message,
'chat_id': chat_info.chat_id, 'chat_record_id': str(
uuid.uuid1()) if chat_record is None else chat_record.id,
'stream': stream,
're_chat': re_chat,
'chat_user_id': chat_user_id,
'chat_user_type': chat_user_type,
'user_id': user_id},
WorkFlowPostHandler(chat_info, chat_user_id, chat_user_type),
base_to_response, form_data, image_list, document_list, audio_list,
other_list,
self.data.get('runtime_node_id'),
self.data.get('node_data'), chat_record, self.data.get('child_node'))
r = work_flow_manage.run()
return r
def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToResponse()):
super().is_valid(raise_exception=True)
ChatMessageSerializers(data=instance).is_valid(raise_exception=True)
chat_info = self.get_chat_info()
self.is_valid_chat_id(chat_info)
if chat_info.application.type == ApplicationTypeChoices.SIMPLE:
self.is_valid_application_simple(raise_exception=True, chat_info=chat_info),
return self.chat_simple(chat_info, instance, base_to_response)
else:
self.is_valid_application_workflow(raise_exception=True)
return self.chat_work_flow(chat_info, instance, base_to_response)
def get_chat_info(self):
self.is_valid(raise_exception=True)
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = ChatInfo.get_cache(chat_id)
if chat_info is None:
chat_info: ChatInfo = self.re_open_chat(chat_id)
chat_info.set_cache()
return chat_info
def re_open_chat(self, chat_id: str):
chat = QuerySet(Chat).filter(id=chat_id).first()
if chat is None:
raise ChatException(500, _("Conversation does not exist"))
application = QuerySet(Application).filter(id=chat.application_id).first()
if application is None:
raise ChatException(500, _("Application does not exist"))
if application.type == ApplicationTypeChoices.SIMPLE:
return self.re_open_chat_simple(chat_id, application)
else:
return self.re_open_chat_work_flow(chat_id, application)
def re_open_chat_simple(self, chat_id, application):
# 数据集id列表
knowledge_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationKnowledgeMapping).filter(
application_id=application.id)]
# 需要排除的文档
exclude_document_id_list = [str(document.id) for document in
QuerySet(Document).filter(
knowledge_id__in=knowledge_id_list,
is_active=False)]
chat_info = ChatInfo(chat_id, self.data.get('chat_user_id'), self.data.get('chat_user_type'), knowledge_id_list,
exclude_document_id_list, application)
chat_record_list = list(QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5])
chat_record_list.sort(key=lambda r: r.create_time)
for chat_record in chat_record_list:
chat_info.chat_record_list.append(chat_record)
return chat_info
def re_open_chat_work_flow(self, chat_id, application):
work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application.id).order_by(
'-create_time')[0:1].first()
if work_flow_version is None:
raise ChatException(500, _("The application has not been published. Please use it after publishing."))
chat_info = ChatInfo(chat_id, self.data.get('chat_user_id'), self.data.get('chat_user_type'), [], [],
application, work_flow_version)
chat_record_list = list(QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5])
chat_record_list.sort(key=lambda r: r.create_time)
for chat_record in chat_record_list:
chat_info.chat_record_list.append(chat_record)
return chat_info
class OpenChatSerializers(serializers.Serializer):
workspace_id = serializers.CharField(required=True)
application_id = serializers.UUIDField(required=True)
chat_user_id = serializers.CharField(required=True, label=_("Client id"))
chat_user_type = serializers.CharField(required=True, label=_("Client Type"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
workspace_id = self.data.get('workspace_id')
application_id = self.data.get('application_id')
if not QuerySet(Application).filter(id=application_id, workspace_id=workspace_id).exists():
raise AppApiException(500, gettext('Application does not exist'))
def open(self):
self.is_valid(raise_exception=True)
application_id = self.data.get('application_id')
application = QuerySet(Application).get(id=application_id)
if application.type == ApplicationTypeChoices.SIMPLE:
return self.open_simple(application)
else:
return self.open_work_flow(application)
def open_work_flow(self, application):
self.is_valid(raise_exception=True)
application_id = self.data.get('application_id')
chat_user_id = self.data.get("chat_user_id")
chat_user_type = self.data.get("chat_user_type")
chat_id = str(uuid.uuid7())
work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application_id).order_by(
'-create_time')[0:1].first()
if work_flow_version is None:
raise AppApiException(500,
gettext(
"The application has not been published. Please use it after publishing."))
ChatInfo(chat_id, chat_user_id, chat_user_type, [],
[],
application, work_flow_version).set_cache()
return chat_id
def open_simple(self, application):
application_id = self.data.get('application_id')
chat_user_id = self.data.get("chat_user_id")
chat_user_type = self.data.get("chat_user_type")
knowledge_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationKnowledgeMapping).filter(
application_id=application_id)]
chat_id = str(uuid.uuid7())
ChatInfo(chat_id, chat_user_id, chat_user_type, knowledge_id_list,
[str(document.id) for document in
QuerySet(Document).filter(
knowledge_id__in=knowledge_id_list,
is_active=False)],
application).set_cache()
return chat_id

View File

@ -14,43 +14,18 @@ from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import serializers from rest_framework import serializers
from application.models import ApplicationAccessToken, ClientType, Application, ApplicationTypeChoices, WorkFlowVersion from application.models import ApplicationAccessToken, ChatUserType, Application, ApplicationTypeChoices, \
WorkFlowVersion
from application.serializers.application import ApplicationSerializerModel from application.serializers.application import ApplicationSerializerModel
from common.auth.common import ChatUserToken, ChatAuthentication from common.auth.common import ChatUserToken, ChatAuthentication
from common.constants.authentication_type import AuthenticationType from common.constants.authentication_type import AuthenticationType
from common.constants.cache_version import Cache_Version from common.constants.cache_version import Cache_Version
from common.database_model_manage.database_model_manage import DatabaseModelManage from common.database_model_manage.database_model_manage import DatabaseModelManage
from common.exception.app_exception import NotFound404, AppApiException, AppUnauthorizedFailed from common.exception.app_exception import NotFound404, AppUnauthorizedFailed
def auth(application_id, access_token, authentication_value, token_details): class AnonymousAuthenticationSerializer(serializers.Serializer):
client_id = token_details.get('client_id')
if client_id is None:
client_id = str(uuid.uuid1())
_type = AuthenticationType.CHAT_ANONYMOUS_USER
if authentication_value is not None:
application_setting_model = DatabaseModelManage.get_model('application_setting')
if application_setting_model is not None:
application_setting = QuerySet(application_setting_model).filter(application_id=application_id).first()
if application_setting.authentication:
auth_type = application_setting.authentication_value.get('type')
auth_value = authentication_value.get(auth_type + '_value')
if auth_type == 'password':
if authentication_value.get('type') == 'password':
if auth_value == authentication_value.get(auth_type + '_value'):
return ChatUserToken(application_id, None, access_token, _type, ClientType.ANONYMOUS_USER,
client_id, ChatAuthentication(auth_type, True, True))
else:
raise AppApiException(500, '认证方式不匹配')
return ChatUserToken(application_id, None, access_token, _type, ClientType.ANONYMOUS_USER,
client_id, ChatAuthentication(None, False, False))
class AuthenticationSerializer(serializers.Serializer):
access_token = serializers.CharField(required=True, label=_("access_token")) access_token = serializers.CharField(required=True, label=_("access_token"))
authentication_value = serializers.JSONField(required=False, allow_null=True,
label=_("Certification Information"))
def auth(self, request, with_valid=True): def auth(self, request, with_valid=True):
token = request.META.get('HTTP_AUTHORIZATION') token = request.META.get('HTTP_AUTHORIZATION')
@ -65,16 +40,42 @@ class AuthenticationSerializer(serializers.Serializer):
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
access_token = self.data.get("access_token") access_token = self.data.get("access_token")
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
authentication_value = self.data.get('authentication_value', None)
if application_access_token is not None and application_access_token.is_active: if application_access_token is not None and application_access_token.is_active:
chat_user_token = auth(application_access_token.application_id, access_token, authentication_value, chat_user_id = token_details.get('chat_user_id') or str(uuid.uuid1())
token_details) _type = AuthenticationType.CHAT_ANONYMOUS_USER
return ChatUserToken(application_access_token.application_id, None, access_token, _type,
return chat_user_token.to_token() ChatUserType.ANONYMOUS_USER,
chat_user_id, ChatAuthentication(None, False, False)).to_token()
else: else:
raise NotFound404(404, _("Invalid access_token")) raise NotFound404(404, _("Invalid access_token"))
class AuthProfileSerializer(serializers.Serializer):
access_token = serializers.CharField(required=True, label=_("access_token"))
def profile(self):
self.is_valid(raise_exception=True)
access_token = self.data.get("access_token")
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
application_id = application_access_token.application_id
profile = {
'authentication': False
}
application_setting_model = DatabaseModelManage.get_model('application_setting')
if application_setting_model:
application_setting = QuerySet(application_setting_model).filter(application_id=application_id).first()
profile = {
'icon': application_setting.application.icon,
'application_name': application_setting.application.name,
'bg_icon': application_setting.chat_background,
'authentication': application_setting.authentication,
'authentication_type': application_setting.authentication_value.get(
'type', 'password'),
'login_value': application_setting.authentication_value.get('login_value', [])
}
return profile
class ApplicationProfileSerializer(serializers.Serializer): class ApplicationProfileSerializer(serializers.Serializer):
application_id = serializers.UUIDField(required=True, label=_("Application ID")) application_id = serializers.UUIDField(required=True, label=_("Application ID"))
@ -119,11 +120,6 @@ class ApplicationProfileSerializer(serializers.Serializer):
'avatar': application_setting.avatar, 'avatar': application_setting.avatar,
'show_avatar': application_setting.show_avatar, 'show_avatar': application_setting.show_avatar,
'float_icon': application_setting.float_icon, 'float_icon': application_setting.float_icon,
'authentication': application_setting.authentication,
'authentication_type': application_setting.authentication_value.get(
'type', 'password'),
'login_value': application_setting.authentication_value.get(
'login_value', []),
'disclaimer': application_setting.disclaimer, 'disclaimer': application_setting.disclaimer,
'disclaimer_value': application_setting.disclaimer_value, 'disclaimer_value': application_setting.disclaimer_value,
'custom_theme': application_setting.custom_theme, 'custom_theme': application_setting.custom_theme,

View File

@ -6,6 +6,9 @@ app_name = 'chat'
urlpatterns = [ urlpatterns = [
path('chat/embed', views.ChatEmbedView.as_view()), path('chat/embed', views.ChatEmbedView.as_view()),
path('application/authentication', views.Authentication.as_view()), path('application/anonymous_authentication', views.AnonymousAuthentication.as_view()),
path('profile', views.ApplicationProfile.as_view()) path('auth/profile', views.AuthProfile.as_view()),
path('profile', views.ApplicationProfile.as_view()),
path('chat_message/<str:chat_id>', views.ChatView.as_view()),
path('workspace/<str:workspace_id>/application/<str:application_id>/open', views.OpenView.as_view())
] ]

View File

@ -12,14 +12,18 @@ from drf_spectacular.utils import extend_schema
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.views import APIView from rest_framework.views import APIView
from chat.api.chat_authentication_api import ChatAuthenticationAPI from chat.api.chat_api import ChatAPI
from chat.serializers.chat_authentication import AuthenticationSerializer, ApplicationProfileSerializer from chat.api.chat_authentication_api import ChatAuthenticationAPI, ChatAuthenticationProfileAPI, ChatOpenAPI
from chat.serializers.chat import OpenChatSerializers, ChatSerializers
from chat.serializers.chat_authentication import AnonymousAuthenticationSerializer, ApplicationProfileSerializer, \
AuthProfileSerializer
from common.auth import TokenAuth from common.auth import TokenAuth
from common.constants.permission_constants import ChatAuth
from common.exception.app_exception import AppAuthenticationFailed from common.exception.app_exception import AppAuthenticationFailed
from common.result import result from common.result import result
class Authentication(APIView): class AnonymousAuthentication(APIView):
def options(self, request, *args, **kwargs): def options(self, request, *args, **kwargs):
return HttpResponse( return HttpResponse(
headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
@ -28,18 +32,16 @@ class Authentication(APIView):
@extend_schema( @extend_schema(
methods=['POST'], methods=['POST'],
description=_('Application Certification'), description=_('Application Anonymous Certification'),
summary=_('Application Certification'), summary=_('Application Anonymous Certification'),
operation_id=_('Application Certification'), # type: ignore operation_id=_('Application Anonymous Certification'), # type: ignore
request=ChatAuthenticationAPI.get_request(), request=ChatAuthenticationAPI.get_request(),
responses=None, responses=None,
tags=[_('Chat')] # type: ignore tags=[_('Chat')] # type: ignore
) )
def post(self, request: Request): def post(self, request: Request):
return result.success( return result.success(
AuthenticationSerializer(data={'access_token': request.data.get("access_token"), AnonymousAuthenticationSerializer(data={'access_token': request.data.get("access_token")}).auth(
'authentication_value': request.data.get(
'authentication_value')}).auth(
request), request),
headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
"Access-Control-Allow-Methods": "POST", "Access-Control-Allow-Methods": "POST",
@ -60,11 +62,61 @@ class ApplicationProfile(APIView):
tags=[_('Chat')] # type: ignore tags=[_('Chat')] # type: ignore
) )
def get(self, request: Request): def get(self, request: Request):
if 'application_id' in request.auth.keywords: if isinstance(request.auth, ChatAuth):
return result.success(ApplicationProfileSerializer( return result.success(ApplicationProfileSerializer(
data={'application_id': request.auth.keywords.get('application_id')}).profile()) data={'application_id': request.auth.application_id}).profile())
raise AppAuthenticationFailed(401, "身份异常") raise AppAuthenticationFailed(401, "身份异常")
class AuthProfile(APIView):
@extend_schema(
methods=['GET'],
description=_("Get application authentication information"),
summary=_("Get application authentication information"),
operation_id=_("Get application authentication information"), # type: ignore
parameters=ChatAuthenticationProfileAPI.get_parameters(),
responses=None,
tags=[_('Chat')] # type: ignore
)
def get(self, request: Request):
return result.success(
AuthProfileSerializer(data={'access_token': request.query_params.get("access_token")}).profile())
class ChatView(APIView): class ChatView(APIView):
pass authentication_classes = [TokenAuth]
@extend_schema(
methods=['POST'],
description=_("dialogue"),
summary=_("dialogue"),
operation_id=_("dialogue"), # type: ignore
request=ChatAPI.get_request(),
parameters=ChatAPI.get_parameters(),
responses=None,
tags=[_('Chat')] # type: ignore
)
def post(self, request: Request, chat_id: str):
return ChatSerializers(data={'chat_id': chat_id,
'chat_user_id': request.auth.chat_user_id,
'chat_user_type': request.auth.chat_user_type,
'application_id': request.auth.application_id}
).chat(request.data)
class OpenView(APIView):
authentication_classes = [TokenAuth]
@extend_schema(
methods=['GET'],
description=_("Get the session id according to the application id"),
summary=_("Get the session id according to the application id"),
operation_id=_("Get the session id according to the application id"), # type: ignore
parameters=ChatOpenAPI.get_parameters(),
responses=None,
tags=[_('Chat')] # type: ignore
)
def get(self, request: Request, workspace_id: str, application_id: str):
return result.success(OpenChatSerializers(
data={'workspace_id': workspace_id, 'application_id': application_id,
'chat_user_id': request.auth.chat_user_id, 'chat_user_type': request.auth.chat_user_type}).open())

View File

@ -32,14 +32,14 @@ class ChatAuthentication:
class ChatUserToken: class ChatUserToken:
def __init__(self, application_id, user_id, access_token, _type, client_type, client_id, def __init__(self, application_id, user_id, access_token, _type, chat_user_type, chat_user_id,
authentication: ChatAuthentication): authentication: ChatAuthentication):
self.application_id = application_id self.application_id = application_id
self.user_id = user_id, self.user_id = user_id,
self.access_token = access_token self.access_token = access_token
self.type = _type self.type = _type
self.client_type = client_type self.chat_user_type = chat_user_type
self.client_id = client_id self.chat_user_id = chat_user_id
self.authentication = authentication self.authentication = authentication
def to_dict(self): def to_dict(self):
@ -48,8 +48,8 @@ class ChatUserToken:
'user_id': str(self.user_id), 'user_id': str(self.user_id),
'access_token': self.access_token, 'access_token': self.access_token,
'type': str(self.type.value), 'type': str(self.type.value),
'client_type': str(self.client_type), 'chat_user_type': str(self.chat_user_type),
'client_id': str(self.client_id), 'chat_user_id': str(self.chat_user_id),
'authentication': self.authentication.to_string() 'authentication': self.authentication.to_string()
} }
@ -59,6 +59,6 @@ class ChatUserToken:
@staticmethod @staticmethod
def new_instance(token_dict): def new_instance(token_dict):
return ChatUserToken(token_dict.get('application_id'), token_dict.get('user_id'), return ChatUserToken(token_dict.get('application_id'), token_dict.get('user_id'),
token_dict.get('access_token'), token_dict.get('type'), token_dict.get('client_type'), token_dict.get('access_token'), token_dict.get('type'), token_dict.get('chat_user_type'),
token_dict.get('client_id'), token_dict.get('chat_user_id'),
ChatAuthentication.new_instance(token_dict.get('authentication'))) ChatAuthentication.new_instance(token_dict.get('authentication')))

View File

@ -9,12 +9,11 @@
from django.db.models import QuerySet from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from application.models import ApplicationAccessToken, ClientType from application.models import ApplicationAccessToken
from common.auth.common import ChatUserToken from common.auth.common import ChatUserToken
from common.auth.handle.auth_base_handle import AuthBaseHandle from common.auth.handle.auth_base_handle import AuthBaseHandle
from common.constants.authentication_type import AuthenticationType from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, Auth from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, ChatAuth
from common.exception.app_exception import AppAuthenticationFailed, ChatException from common.exception.app_exception import AppAuthenticationFailed, ChatException
@ -45,11 +44,11 @@ class ChatAnonymousUserToken(AuthBaseHandle):
if request.path != '/api/application/profile': if request.path != '/api/application/profile':
if chat_user_token.authentication.is_auth and not chat_user_token.authentication.auth_passed: if chat_user_token.authentication.is_auth and not chat_user_token.authentication.auth_passed:
raise ChatException(1002, _('Authentication information is incorrect')) raise ChatException(1002, _('Authentication information is incorrect'))
return None, Auth( return None, ChatAuth(
current_role_list=[RoleConstants.CHAT_ANONYMOUS_USER], current_role_list=[RoleConstants.CHAT_ANONYMOUS_USER],
permission_list=[ permission_list=[
Permission(group=Group.APPLICATION, Permission(group=Group.APPLICATION,
operate=Operate.USE)], operate=Operate.USE)],
application_id=application_access_token.application_id, application_id=application_access_token.application_id,
client_id=auth_details.get('client_id'), chat_user_id=chat_user_token.chat_user_id,
client_type=ClientType.ANONYMOUS_USER) chat_user_type=chat_user_token.chat_user_type)

View File

@ -27,6 +27,9 @@ class Cache_Version(Enum):
# 应用对接三方应用的缓存 # 应用对接三方应用的缓存
APPLICATION_THIRD_PARTY = "APPLICATION:THIRD_PARTY", lambda key: key APPLICATION_THIRD_PARTY = "APPLICATION:THIRD_PARTY", lambda key: key
# 对话
CHAT = "CHAT", lambda key: key
def get_version(self): def get_version(self):
return self.value[0] return self.value[0]

View File

@ -925,6 +925,22 @@ def get_permission_list_by_resource_group(resource_group: ResourcePermissionGrou
PermissionConstants[k].value.resource_permission_group_list.__contains__(resource_group)] PermissionConstants[k].value.resource_permission_group_list.__contains__(resource_group)]
class ChatAuth:
def __init__(self,
current_role_list: List[RoleConstants | Role],
permission_list: List[PermissionConstants | Permission],
chat_user_id,
chat_user_type,
application_id):
# 权限列表
self.permission_list = permission_list
# 角色列表
self.role_list = current_role_list
self.chat_user_id = chat_user_id
self.chat_user_type = chat_user_type
self.application_id = application_id
class Auth: class Auth:
""" """
用于存储当前用户的角色和权限 用于存储当前用户的角色和权限