feat: 客户端不使用cookie存储改为localstore,优化认证代码
This commit is contained in:
parent
21a557ef43
commit
0fbd5873f7
@ -68,6 +68,8 @@ class IChatStep(IBaseChatPipelineStep):
|
|||||||
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base("补全问题"))
|
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base("补全问题"))
|
||||||
# 是否使用流的形式输出
|
# 是否使用流的形式输出
|
||||||
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出"))
|
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出"))
|
||||||
|
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
|
||||||
|
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
|
||||||
|
|
||||||
def is_valid(self, *, raise_exception=False):
|
def is_valid(self, *, raise_exception=False):
|
||||||
super().is_valid(raise_exception=True)
|
super().is_valid(raise_exception=True)
|
||||||
@ -90,5 +92,5 @@ class IChatStep(IBaseChatPipelineStep):
|
|||||||
chat_model: BaseChatModel = None,
|
chat_model: BaseChatModel = None,
|
||||||
paragraph_list=None,
|
paragraph_list=None,
|
||||||
manage: PiplineManage = None,
|
manage: PiplineManage = None,
|
||||||
padding_problem_text: str = None, stream: bool = True, **kwargs):
|
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -13,6 +13,7 @@ import traceback
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from django.db.models import QuerySet
|
||||||
from django.http import StreamingHttpResponse
|
from django.http import StreamingHttpResponse
|
||||||
from langchain.chat_models.base import BaseChatModel
|
from langchain.chat_models.base import BaseChatModel
|
||||||
from langchain.schema import BaseMessage
|
from langchain.schema import BaseMessage
|
||||||
@ -21,9 +22,20 @@ from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage
|
|||||||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||||
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
|
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
|
||||||
|
from application.models.api_key_model import ApplicationPublicAccessClient
|
||||||
|
from common.constants.authentication_type import AuthenticationType
|
||||||
from common.response import result
|
from common.response import result
|
||||||
|
|
||||||
|
|
||||||
|
def add_access_num(client_id=None, client_type=None):
|
||||||
|
if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
|
||||||
|
application_public_access_client = QuerySet(ApplicationPublicAccessClient).filter(id=client_id).first()
|
||||||
|
if application_public_access_client is not None:
|
||||||
|
application_public_access_client.access_num = application_public_access_client.access_num + 1
|
||||||
|
application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
|
||||||
|
application_public_access_client.save()
|
||||||
|
|
||||||
|
|
||||||
def event_content(response,
|
def event_content(response,
|
||||||
chat_id,
|
chat_id,
|
||||||
chat_record_id,
|
chat_record_id,
|
||||||
@ -34,7 +46,8 @@ def event_content(response,
|
|||||||
chat_model,
|
chat_model,
|
||||||
message_list: List[BaseMessage],
|
message_list: List[BaseMessage],
|
||||||
problem_text: str,
|
problem_text: str,
|
||||||
padding_problem_text: str = None):
|
padding_problem_text: str = None,
|
||||||
|
client_id=None, client_type=None):
|
||||||
all_text = ''
|
all_text = ''
|
||||||
try:
|
try:
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
@ -57,6 +70,7 @@ def event_content(response,
|
|||||||
all_text, manage, step, padding_problem_text)
|
all_text, manage, step, padding_problem_text)
|
||||||
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||||
'content': '', 'is_end': True}) + "\n\n"
|
'content': '', 'is_end': True}) + "\n\n"
|
||||||
|
add_access_num(client_id, client_type)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
|
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
|
||||||
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||||
@ -73,15 +87,16 @@ class BaseChatStep(IChatStep):
|
|||||||
manage: PiplineManage = None,
|
manage: PiplineManage = None,
|
||||||
padding_problem_text: str = None,
|
padding_problem_text: str = None,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
|
client_id=None, client_type=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
if stream:
|
if stream:
|
||||||
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
||||||
paragraph_list,
|
paragraph_list,
|
||||||
manage, padding_problem_text)
|
manage, padding_problem_text, client_id, client_type)
|
||||||
else:
|
else:
|
||||||
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
||||||
paragraph_list,
|
paragraph_list,
|
||||||
manage, padding_problem_text)
|
manage, padding_problem_text, client_id, client_type)
|
||||||
|
|
||||||
def get_details(self, manage, **kwargs):
|
def get_details(self, manage, **kwargs):
|
||||||
return {
|
return {
|
||||||
@ -111,7 +126,8 @@ class BaseChatStep(IChatStep):
|
|||||||
chat_model: BaseChatModel = None,
|
chat_model: BaseChatModel = None,
|
||||||
paragraph_list=None,
|
paragraph_list=None,
|
||||||
manage: PiplineManage = None,
|
manage: PiplineManage = None,
|
||||||
padding_problem_text: str = None):
|
padding_problem_text: str = None,
|
||||||
|
client_id=None, client_type=None):
|
||||||
# 调用模型
|
# 调用模型
|
||||||
if chat_model is None:
|
if chat_model is None:
|
||||||
chat_result = iter(
|
chat_result = iter(
|
||||||
@ -123,7 +139,7 @@ class BaseChatStep(IChatStep):
|
|||||||
r = StreamingHttpResponse(
|
r = StreamingHttpResponse(
|
||||||
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
|
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
|
||||||
post_response_handler, manage, self, chat_model, message_list, problem_text,
|
post_response_handler, manage, self, chat_model, message_list, problem_text,
|
||||||
padding_problem_text),
|
padding_problem_text, client_id, client_type),
|
||||||
content_type='text/event-stream;charset=utf-8')
|
content_type='text/event-stream;charset=utf-8')
|
||||||
|
|
||||||
r['Cache-Control'] = 'no-cache'
|
r['Cache-Control'] = 'no-cache'
|
||||||
@ -136,7 +152,8 @@ class BaseChatStep(IChatStep):
|
|||||||
chat_model: BaseChatModel = None,
|
chat_model: BaseChatModel = None,
|
||||||
paragraph_list=None,
|
paragraph_list=None,
|
||||||
manage: PiplineManage = None,
|
manage: PiplineManage = None,
|
||||||
padding_problem_text: str = None):
|
padding_problem_text: str = None,
|
||||||
|
client_id=None, client_type=None):
|
||||||
# 调用模型
|
# 调用模型
|
||||||
if chat_model is None:
|
if chat_model is None:
|
||||||
chat_result = AIMessage(
|
chat_result = AIMessage(
|
||||||
@ -156,5 +173,6 @@ class BaseChatStep(IChatStep):
|
|||||||
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
|
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
|
||||||
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||||
chat_result.content, manage, self, padding_problem_text)
|
chat_result.content, manage, self, padding_problem_text)
|
||||||
|
add_access_num(client_id, client_type)
|
||||||
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||||
'content': chat_result.content, 'is_end': True})
|
'content': chat_result.content, 'is_end': True})
|
||||||
|
|||||||
@ -0,0 +1,28 @@
|
|||||||
|
# Generated by Django 4.1.10 on 2024-03-14 05:03
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
import django.db.models.deletion
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('application', '0008_applicationaccesstoken_access_num_and_more'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name='ApplicationPublicAccessClient',
|
||||||
|
fields=[
|
||||||
|
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
|
||||||
|
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
|
||||||
|
('id', models.UUIDField(primary_key=True, serialize=False, verbose_name='公共访问链接客户端id')),
|
||||||
|
('access_num', models.IntegerField(default=0, verbose_name='访问总次数次数')),
|
||||||
|
('intraday_access_num', models.IntegerField(default=0, verbose_name='当日访问次数')),
|
||||||
|
('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
'db_table': 'application_public_access_client',
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
@ -42,3 +42,13 @@ class ApplicationAccessToken(AppModelMixin):
|
|||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = "application_access_token"
|
db_table = "application_access_token"
|
||||||
|
|
||||||
|
|
||||||
|
class ApplicationPublicAccessClient(AppModelMixin):
|
||||||
|
id = models.UUIDField(max_length=128, primary_key=True, verbose_name="公共访问链接客户端id")
|
||||||
|
application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id")
|
||||||
|
access_num = models.IntegerField(default=0, verbose_name="访问总次数次数")
|
||||||
|
intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数")
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
db_table = "application_public_access_client"
|
||||||
|
|||||||
@ -28,10 +28,8 @@ from common.constants.authentication_type import AuthenticationType
|
|||||||
from common.db.search import get_dynamics_model, native_search, native_page_search
|
from common.db.search import get_dynamics_model, native_search, native_page_search
|
||||||
from common.db.sql_execute import select_list
|
from common.db.sql_execute import select_list
|
||||||
from common.exception.app_exception import AppApiException, NotFound404
|
from common.exception.app_exception import AppApiException, NotFound404
|
||||||
from common.util.common import getRestSeconds, set_embed_identity_cookie
|
|
||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from common.util.rsa_util import encrypt
|
|
||||||
from dataset.models import DataSet, Document
|
from dataset.models import DataSet, Document
|
||||||
from dataset.serializers.common_serializers import list_paragraph
|
from dataset.serializers.common_serializers import list_paragraph
|
||||||
from setting.models import AuthOperate
|
from setting.models import AuthOperate
|
||||||
@ -39,7 +37,6 @@ from setting.models.model_management import Model
|
|||||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
from setting.serializers.provider_serializers import ModelSerializer
|
from setting.serializers.provider_serializers import ModelSerializer
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
from smartdoc.settings import JWT_AUTH
|
|
||||||
|
|
||||||
token_cache = cache.caches['token_cache']
|
token_cache = cache.caches['token_cache']
|
||||||
chat_cache = cache.caches['chat_cache']
|
chat_cache = cache.caches['chat_cache']
|
||||||
@ -114,7 +111,7 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
protocol = serializers.CharField(required=True, error_messages=ErrMessage.char("协议"))
|
protocol = serializers.CharField(required=True, error_messages=ErrMessage.char("协议"))
|
||||||
token = serializers.CharField(required=True, error_messages=ErrMessage.char("token"))
|
token = serializers.CharField(required=True, error_messages=ErrMessage.char("token"))
|
||||||
|
|
||||||
def get_embed(self, request, with_valid=True):
|
def get_embed(self, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
index_path = os.path.join(PROJECT_DIR, 'apps', "application", 'template', 'embed.js')
|
index_path = os.path.join(PROJECT_DIR, 'apps', "application", 'template', 'embed.js')
|
||||||
@ -136,7 +133,6 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
application_access_token.white_list),
|
application_access_token.white_list),
|
||||||
'white_active': 'true' if application_access_token.white_active else 'false'}))
|
'white_active': 'true' if application_access_token.white_active else 'false'}))
|
||||||
response = HttpResponse(s, status=200, headers={'Content-Type': 'text/javascript'})
|
response = HttpResponse(s, status=200, headers={'Content-Type': 'text/javascript'})
|
||||||
set_embed_identity_cookie(request, response)
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
class AccessTokenSerializer(serializers.Serializer):
|
class AccessTokenSerializer(serializers.Serializer):
|
||||||
@ -197,17 +193,27 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
class Authentication(serializers.Serializer):
|
class Authentication(serializers.Serializer):
|
||||||
access_token = serializers.CharField(required=True, error_messages=ErrMessage.char("access_token"))
|
access_token = serializers.CharField(required=True, error_messages=ErrMessage.char("access_token"))
|
||||||
|
|
||||||
def auth(self, with_valid=True):
|
def auth(self, request, with_valid=True):
|
||||||
|
token = request.META.get('HTTP_AUTHORIZATION', None)
|
||||||
|
token_details = None
|
||||||
|
try:
|
||||||
|
# 校验token
|
||||||
|
if token is not None:
|
||||||
|
token_details = signing.loads(token)
|
||||||
|
except Exception as e:
|
||||||
|
token = None
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
access_token = self.data.get("access_token")
|
access_token = self.data.get("access_token")
|
||||||
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
|
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
|
||||||
if application_access_token is not None and application_access_token.is_active:
|
if application_access_token is not None and application_access_token.is_active:
|
||||||
token = signing.dumps({'application_id': str(application_access_token.application_id),
|
if token is None or (token_details is not None and 'client_id' not in token_details):
|
||||||
'user_id': str(application_access_token.application.user.id),
|
client_id = str(uuid.uuid1())
|
||||||
'access_token': application_access_token.access_token,
|
token = signing.dumps({'application_id': str(application_access_token.application_id),
|
||||||
'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value})
|
'user_id': str(application_access_token.application.user.id),
|
||||||
token_cache.set(token, application_access_token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'])
|
'access_token': application_access_token.access_token,
|
||||||
|
'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value,
|
||||||
|
'client_id': client_id})
|
||||||
return token
|
return token
|
||||||
else:
|
else:
|
||||||
raise NotFound404(404, "无效的access_token")
|
raise NotFound404(404, "无效的access_token")
|
||||||
|
|||||||
@ -23,7 +23,9 @@ from application.chat_pipeline.step.generate_human_message_step.impl.base_genera
|
|||||||
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
|
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
|
||||||
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
|
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
|
||||||
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
|
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
|
||||||
from common.exception.app_exception import AppApiException
|
from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken
|
||||||
|
from common.constants.authentication_type import AuthenticationType
|
||||||
|
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
|
||||||
from common.util.field_message import ErrMessage
|
from common.util.field_message import ErrMessage
|
||||||
from common.util.rsa_util import decrypt
|
from common.util.rsa_util import decrypt
|
||||||
from common.util.split_model import flat_map
|
from common.util.split_model import flat_map
|
||||||
@ -32,7 +34,6 @@ from setting.models import Model
|
|||||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
|
|
||||||
chat_cache = caches['model_cache']
|
chat_cache = caches['model_cache']
|
||||||
chat_embed_identity_cache = caches['chat_cache']
|
|
||||||
|
|
||||||
|
|
||||||
class ChatInfo:
|
class ChatInfo:
|
||||||
@ -75,15 +76,16 @@ class ChatInfo:
|
|||||||
'chat_model': self.chat_model,
|
'chat_model': self.chat_model,
|
||||||
'model_id': self.application.model.id if self.application.model is not None else None,
|
'model_id': self.application.model.id if self.application.model is not None else None,
|
||||||
'problem_optimization': self.application.problem_optimization,
|
'problem_optimization': self.application.problem_optimization,
|
||||||
'stream': True
|
'stream': True,
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
|
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
|
||||||
exclude_paragraph_id_list, stream=True):
|
exclude_paragraph_id_list, client_id: str, client_type, stream=True):
|
||||||
params = self.to_base_pipeline_manage_params()
|
params = self.to_base_pipeline_manage_params()
|
||||||
return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
|
return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
|
||||||
'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream}
|
'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'client_id': client_id,
|
||||||
|
'client_type': client_type}
|
||||||
|
|
||||||
def append_chat_record(self, chat_record: ChatRecord):
|
def append_chat_record(self, chat_record: ChatRecord):
|
||||||
# 存入缓存中
|
# 存入缓存中
|
||||||
@ -127,9 +129,37 @@ def get_post_handler(chat_info: ChatInfo):
|
|||||||
|
|
||||||
|
|
||||||
class ChatMessageSerializer(serializers.Serializer):
|
class ChatMessageSerializer(serializers.Serializer):
|
||||||
chat_id = serializers.UUIDField(required=True)
|
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id"))
|
||||||
|
message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))
|
||||||
|
stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答"))
|
||||||
|
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答"))
|
||||||
|
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
|
||||||
|
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
|
||||||
|
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
|
||||||
|
|
||||||
def chat(self, message, re_chat: bool, stream: bool):
|
def is_valid(self, *, raise_exception=False):
|
||||||
|
super().is_valid(raise_exception=True)
|
||||||
|
if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
|
||||||
|
access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first()
|
||||||
|
if access_client is None:
|
||||||
|
access_client = ApplicationPublicAccessClient(id=self.data.get('client_id'),
|
||||||
|
application_id=self.data.get('application_id'),
|
||||||
|
access_num=0,
|
||||||
|
intraday_access_num=0)
|
||||||
|
access_client.save()
|
||||||
|
|
||||||
|
application_access_token = QuerySet(ApplicationAccessToken).filter(
|
||||||
|
application_id=self.data.get('application_id')).first()
|
||||||
|
if application_access_token.access_num <= access_client.intraday_access_num:
|
||||||
|
raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量")
|
||||||
|
|
||||||
|
def chat(self):
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
message = self.data.get('message')
|
||||||
|
re_chat = self.data.get('re_chat')
|
||||||
|
stream = self.data.get('stream')
|
||||||
|
client_id = self.data.get('client_id')
|
||||||
|
client_type = self.data.get('client_type')
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
chat_id = self.data.get('chat_id')
|
chat_id = self.data.get('chat_id')
|
||||||
chat_info: ChatInfo = chat_cache.get(chat_id)
|
chat_info: ChatInfo = chat_cache.get(chat_id)
|
||||||
@ -156,7 +186,7 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||||||
exclude_paragraph_id_list = list(set(paragraph_id_list))
|
exclude_paragraph_id_list = list(set(paragraph_id_list))
|
||||||
# 构建运行参数
|
# 构建运行参数
|
||||||
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
|
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
|
||||||
stream)
|
client_id, client_type, stream)
|
||||||
# 运行流水线作业
|
# 运行流水线作业
|
||||||
pipline_message.run(params)
|
pipline_message.run(params)
|
||||||
return pipline_message.context['chat_result']
|
return pipline_message.context['chat_result']
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
"""
|
"""
|
||||||
@project: maxkb
|
@project: maxkb
|
||||||
@Author:虎
|
@Author:虎
|
||||||
@file: application_api.py
|
@file: application_key.py
|
||||||
@date:2023/11/7 10:50
|
@date:2023/11/7 10:50
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -37,7 +37,7 @@ class Application(APIView):
|
|||||||
def get(self, request: Request):
|
def get(self, request: Request):
|
||||||
return ApplicationSerializer.Embed(
|
return ApplicationSerializer.Embed(
|
||||||
data={'protocol': request.query_params.get('protocol'), 'token': request.query_params.get('token'),
|
data={'protocol': request.query_params.get('protocol'), 'token': request.query_params.get('token'),
|
||||||
'host': request.query_params.get('host'), }).get_embed(request)
|
'host': request.query_params.get('host'), }).get_embed()
|
||||||
|
|
||||||
class Model(APIView):
|
class Model(APIView):
|
||||||
authentication_classes = [TokenAuth]
|
authentication_classes = [TokenAuth]
|
||||||
@ -192,7 +192,8 @@ class Application(APIView):
|
|||||||
security=[])
|
security=[])
|
||||||
def post(self, request: Request):
|
def post(self, request: Request):
|
||||||
return result.success(
|
return result.success(
|
||||||
ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth(),
|
ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth(
|
||||||
|
request),
|
||||||
headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
|
headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
|
||||||
"Access-Control-Allow-Methods": "POST",
|
"Access-Control-Allow-Methods": "POST",
|
||||||
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}
|
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from application.serializers.chat_message_serializers import ChatMessageSerializ
|
|||||||
from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
|
from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
|
||||||
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi
|
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi
|
||||||
from common.auth import TokenAuth, has_permissions
|
from common.auth import TokenAuth, has_permissions
|
||||||
|
from common.constants.authentication_type import AuthenticationType
|
||||||
from common.constants.permission_constants import Permission, Group, Operate, \
|
from common.constants.permission_constants import Permission, Group, Operate, \
|
||||||
RoleConstants, ViewPermission, CompareConstants
|
RoleConstants, ViewPermission, CompareConstants
|
||||||
from common.response import result
|
from common.response import result
|
||||||
@ -71,11 +72,15 @@ class ChatView(APIView):
|
|||||||
dynamic_tag=keywords.get('application_id'))])
|
dynamic_tag=keywords.get('application_id'))])
|
||||||
)
|
)
|
||||||
def post(self, request: Request, chat_id: str):
|
def post(self, request: Request, chat_id: str):
|
||||||
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'),
|
return ChatMessageSerializer(data={'chat_id': chat_id, 'message': request.data.get('message'),
|
||||||
request.data.get(
|
're_chat': (request.data.get(
|
||||||
're_chat') if 're_chat' in request.data else False,
|
're_chat') if 're_chat' in request.data else False),
|
||||||
request.data.get(
|
'stream': (request.data.get(
|
||||||
'stream') if 'stream' in request.data else True)
|
'stream') if 'stream' in request.data else True),
|
||||||
|
'application_id': (request.auth.keywords.get(
|
||||||
|
'application_id') if request.auth.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value else None),
|
||||||
|
'client_id': request.auth.client_id,
|
||||||
|
'client_type': request.auth.client_type}).chat()
|
||||||
|
|
||||||
@action(methods=['GET'], detail=False)
|
@action(methods=['GET'], detail=False)
|
||||||
@swagger_auto_schema(operation_summary="获取对话列表",
|
@swagger_auto_schema(operation_summary="获取对话列表",
|
||||||
|
|||||||
@ -14,6 +14,9 @@ from django.db.models import QuerySet
|
|||||||
from rest_framework.authentication import TokenAuthentication
|
from rest_framework.authentication import TokenAuthentication
|
||||||
|
|
||||||
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
|
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
|
||||||
|
from common.auth.handle.impl.application_key import ApplicationKey
|
||||||
|
from common.auth.handle.impl.public_access_token import PublicAccessToken
|
||||||
|
from common.auth.handle.impl.user_token import UserToken
|
||||||
from common.constants.authentication_type import AuthenticationType
|
from common.constants.authentication_type import AuthenticationType
|
||||||
from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants, Permission, Group, \
|
from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants, Permission, Group, \
|
||||||
Operate
|
Operate
|
||||||
@ -29,6 +32,25 @@ class AnonymousAuthentication(TokenAuthentication):
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
handles = [UserToken(), PublicAccessToken(), ApplicationKey()]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenDetails:
|
||||||
|
token_details = None
|
||||||
|
is_load = False
|
||||||
|
|
||||||
|
def __init__(self, token: str):
|
||||||
|
self.token = token
|
||||||
|
|
||||||
|
def get_token_details(self):
|
||||||
|
if self.token_details is None and not self.is_load:
|
||||||
|
try:
|
||||||
|
self.token_details = signing.loads(self.token)
|
||||||
|
except Exception as e:
|
||||||
|
self.is_load = True
|
||||||
|
return self.token_details
|
||||||
|
|
||||||
|
|
||||||
class TokenAuth(TokenAuthentication):
|
class TokenAuth(TokenAuthentication):
|
||||||
# 重新 authenticate 方法,自定义认证规则
|
# 重新 authenticate 方法,自定义认证规则
|
||||||
def authenticate(self, request):
|
def authenticate(self, request):
|
||||||
@ -38,62 +60,11 @@ class TokenAuth(TokenAuthentication):
|
|||||||
if auth is None:
|
if auth is None:
|
||||||
raise AppAuthenticationFailed(1003, '未登录,请先登录')
|
raise AppAuthenticationFailed(1003, '未登录,请先登录')
|
||||||
try:
|
try:
|
||||||
if str(auth).startswith("application-"):
|
token_details = TokenDetails(auth)
|
||||||
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=auth).first()
|
for handle in handles:
|
||||||
if application_api_key is None:
|
if handle.support(request, auth, token_details.get_token_details):
|
||||||
raise AppAuthenticationFailed(500, "secret_key 无效")
|
return handle.handle(request, auth, token_details.get_token_details)
|
||||||
if not application_api_key.is_active:
|
raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
|
||||||
raise AppAuthenticationFailed(500, "secret_key 无效")
|
|
||||||
permission_list = [Permission(group=Group.APPLICATION,
|
|
||||||
operate=Operate.USE,
|
|
||||||
dynamic_tag=str(
|
|
||||||
application_api_key.application_id)),
|
|
||||||
Permission(group=Group.APPLICATION,
|
|
||||||
operate=Operate.MANAGE,
|
|
||||||
dynamic_tag=str(
|
|
||||||
application_api_key.application_id))
|
|
||||||
]
|
|
||||||
return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY],
|
|
||||||
permission_list=permission_list,
|
|
||||||
application_id=application_api_key.application_id)
|
|
||||||
# 解析 token
|
|
||||||
auth_details = signing.loads(auth)
|
|
||||||
cache_token = token_cache.get(auth)
|
|
||||||
if cache_token is None:
|
|
||||||
raise AppAuthenticationFailed(1002, "登录过期")
|
|
||||||
if 'id' in auth_details and auth_details.get('type') == AuthenticationType.USER.value:
|
|
||||||
user = QuerySet(User).get(id=auth_details['id'])
|
|
||||||
# 续期
|
|
||||||
token_cache.touch(auth, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds())
|
|
||||||
rule = RoleConstants[user.role]
|
|
||||||
permission_list = get_permission_list_by_role(RoleConstants[user.role])
|
|
||||||
# 获取用户的应用和知识库的权限
|
|
||||||
permission_list += get_user_dynamics_permission(str(user.id))
|
|
||||||
return user, Auth(role_list=[rule],
|
|
||||||
permission_list=permission_list)
|
|
||||||
if 'application_id' in auth_details and 'access_token' in auth_details and auth_details.get(
|
|
||||||
'type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
|
|
||||||
application_access_token = QuerySet(ApplicationAccessToken).filter(
|
|
||||||
application_id=auth_details.get('application_id')).first()
|
|
||||||
if application_access_token is None:
|
|
||||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
|
||||||
if not application_access_token.is_active:
|
|
||||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
|
||||||
if not application_access_token.access_token == auth_details.get('access_token'):
|
|
||||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
|
||||||
return application_access_token.application.user, Auth(
|
|
||||||
role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN],
|
|
||||||
permission_list=[
|
|
||||||
Permission(group=Group.APPLICATION,
|
|
||||||
operate=Operate.USE,
|
|
||||||
dynamic_tag=str(
|
|
||||||
application_access_token.application_id))],
|
|
||||||
application_id=application_access_token.application_id
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.format_exc()
|
traceback.format_exc()
|
||||||
if isinstance(e, AppEmbedIdentityFailed) or isinstance(e, AppChatNumOutOfBoundsFailed):
|
if isinstance(e, AppEmbedIdentityFailed) or isinstance(e, AppChatNumOutOfBoundsFailed):
|
||||||
|
|||||||
19
apps/common/auth/handle/auth_base_handle.py
Normal file
19
apps/common/auth/handle/auth_base_handle.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: qabot
|
||||||
|
@Author:虎
|
||||||
|
@file: authenticate.py
|
||||||
|
@date:2024/3/14 03:02
|
||||||
|
@desc: 认证处理器
|
||||||
|
"""
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class AuthBaseHandle(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def support(self, request, token: str, get_token_details):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def handle(self, request, token: str, get_token_details):
|
||||||
|
pass
|
||||||
41
apps/common/auth/handle/impl/application_key.py
Normal file
41
apps/common/auth/handle/impl/application_key.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: qabot
|
||||||
|
@Author:虎
|
||||||
|
@file: authenticate.py
|
||||||
|
@date:2024/3/14 03:02
|
||||||
|
@desc: 应用api key认证
|
||||||
|
"""
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
|
from application.models.api_key_model import ApplicationApiKey
|
||||||
|
from common.auth.handle.auth_base_handle import AuthBaseHandle
|
||||||
|
from common.constants.authentication_type import AuthenticationType
|
||||||
|
from common.constants.permission_constants import Permission, Group, Operate, RoleConstants, Auth
|
||||||
|
from common.exception.app_exception import AppAuthenticationFailed
|
||||||
|
|
||||||
|
|
||||||
|
class ApplicationKey(AuthBaseHandle):
|
||||||
|
def handle(self, request, token: str, get_token_details):
|
||||||
|
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=token).first()
|
||||||
|
if application_api_key is None:
|
||||||
|
raise AppAuthenticationFailed(500, "secret_key 无效")
|
||||||
|
if not application_api_key.is_active:
|
||||||
|
raise AppAuthenticationFailed(500, "secret_key 无效")
|
||||||
|
permission_list = [Permission(group=Group.APPLICATION,
|
||||||
|
operate=Operate.USE,
|
||||||
|
dynamic_tag=str(
|
||||||
|
application_api_key.application_id)),
|
||||||
|
Permission(group=Group.APPLICATION,
|
||||||
|
operate=Operate.MANAGE,
|
||||||
|
dynamic_tag=str(
|
||||||
|
application_api_key.application_id))
|
||||||
|
]
|
||||||
|
return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY],
|
||||||
|
permission_list=permission_list,
|
||||||
|
application_id=application_api_key.application_id,
|
||||||
|
client_id=token,
|
||||||
|
client_type=AuthenticationType.API_KEY.value)
|
||||||
|
|
||||||
|
def support(self, request, token: str, get_token_details):
|
||||||
|
return str(token).startswith("application-")
|
||||||
49
apps/common/auth/handle/impl/public_access_token.py
Normal file
49
apps/common/auth/handle/impl/public_access_token.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: qabot
|
||||||
|
@Author:虎
|
||||||
|
@file: authenticate.py
|
||||||
|
@date:2024/3/14 03:02
|
||||||
|
@desc: 公共访问连接认证
|
||||||
|
"""
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
|
from application.models.api_key_model import ApplicationAccessToken
|
||||||
|
from common.auth.handle.auth_base_handle import AuthBaseHandle
|
||||||
|
from common.constants.authentication_type import AuthenticationType
|
||||||
|
from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, Auth
|
||||||
|
from common.exception.app_exception import AppAuthenticationFailed
|
||||||
|
|
||||||
|
|
||||||
|
class PublicAccessToken(AuthBaseHandle):
|
||||||
|
def support(self, request, token: str, get_token_details):
|
||||||
|
token_details = get_token_details()
|
||||||
|
if token_details is None:
|
||||||
|
return False
|
||||||
|
return (
|
||||||
|
'application_id' in token_details and
|
||||||
|
'access_token' in token_details and
|
||||||
|
token_details.get('type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value)
|
||||||
|
|
||||||
|
def handle(self, request, token: str, get_token_details):
|
||||||
|
auth_details = get_token_details()
|
||||||
|
application_access_token = QuerySet(ApplicationAccessToken).filter(
|
||||||
|
application_id=auth_details.get('application_id')).first()
|
||||||
|
if application_access_token is None:
|
||||||
|
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
||||||
|
if not application_access_token.is_active:
|
||||||
|
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
||||||
|
if not application_access_token.access_token == auth_details.get('access_token'):
|
||||||
|
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
|
||||||
|
|
||||||
|
return application_access_token.application.user, Auth(
|
||||||
|
role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN],
|
||||||
|
permission_list=[
|
||||||
|
Permission(group=Group.APPLICATION,
|
||||||
|
operate=Operate.USE,
|
||||||
|
dynamic_tag=str(
|
||||||
|
application_access_token.application_id))],
|
||||||
|
application_id=application_access_token.application_id,
|
||||||
|
client_id=auth_details.get('client_id'),
|
||||||
|
client_type=AuthenticationType.APPLICATION_ACCESS_TOKEN.value
|
||||||
|
)
|
||||||
46
apps/common/auth/handle/impl/user_token.py
Normal file
46
apps/common/auth/handle/impl/user_token.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: qabot
|
||||||
|
@Author:虎
|
||||||
|
@file: authenticate.py
|
||||||
|
@date:2024/3/14 03:02
|
||||||
|
@desc: 用户认证
|
||||||
|
"""
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
|
from common.auth.handle.auth_base_handle import AuthBaseHandle
|
||||||
|
from common.constants.authentication_type import AuthenticationType
|
||||||
|
from common.constants.permission_constants import RoleConstants, get_permission_list_by_role, Auth
|
||||||
|
from common.exception.app_exception import AppAuthenticationFailed
|
||||||
|
from smartdoc.settings import JWT_AUTH
|
||||||
|
from users.models import User
|
||||||
|
from django.core import cache
|
||||||
|
|
||||||
|
from users.models.user import get_user_dynamics_permission
|
||||||
|
|
||||||
|
token_cache = cache.caches['token_cache']
|
||||||
|
|
||||||
|
|
||||||
|
class UserToken(AuthBaseHandle):
|
||||||
|
def support(self, request, token: str, get_token_details):
|
||||||
|
auth_details = get_token_details()
|
||||||
|
if auth_details is None:
|
||||||
|
return False
|
||||||
|
return 'id' in auth_details and auth_details.get('type') == AuthenticationType.USER.value
|
||||||
|
|
||||||
|
def handle(self, request, token: str, get_token_details):
|
||||||
|
cache_token = token_cache.get(token)
|
||||||
|
if cache_token is None:
|
||||||
|
raise AppAuthenticationFailed(1002, "登录过期")
|
||||||
|
auth_details = get_token_details()
|
||||||
|
user = QuerySet(User).get(id=auth_details['id'])
|
||||||
|
# 续期
|
||||||
|
token_cache.touch(token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds())
|
||||||
|
rule = RoleConstants[user.role]
|
||||||
|
permission_list = get_permission_list_by_role(RoleConstants[user.role])
|
||||||
|
# 获取用户的应用和知识库的权限
|
||||||
|
permission_list += get_user_dynamics_permission(str(user.id))
|
||||||
|
return user, Auth(role_list=[rule],
|
||||||
|
permission_list=permission_list,
|
||||||
|
client_id=str(user.id),
|
||||||
|
client_type=AuthenticationType.USER.value)
|
||||||
@ -10,7 +10,9 @@ from enum import Enum
|
|||||||
|
|
||||||
|
|
||||||
class AuthenticationType(Enum):
|
class AuthenticationType(Enum):
|
||||||
# 或者
|
# 普通用户
|
||||||
USER = "USER"
|
USER = "USER"
|
||||||
# 并且
|
# 公共访问链接
|
||||||
APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN"
|
APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN"
|
||||||
|
# key API
|
||||||
|
API_KEY = "API_KEY"
|
||||||
|
|||||||
@ -151,10 +151,12 @@ class Auth:
|
|||||||
用于存储当前用户的角色和权限
|
用于存储当前用户的角色和权限
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission],
|
def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission]
|
||||||
**keywords):
|
, client_id, client_type, **keywords):
|
||||||
self.role_list = role_list
|
self.role_list = role_list
|
||||||
self.permission_list = permission_list
|
self.permission_list = permission_list
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_type = client_type
|
||||||
self.keywords = keywords
|
self.keywords = keywords
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,66 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
"""
|
|
||||||
@project: maxkb
|
|
||||||
@Author:虎
|
|
||||||
@file: chat_cookie_middleware.py
|
|
||||||
@date:2024/3/13 20:13
|
|
||||||
@desc:
|
|
||||||
"""
|
|
||||||
from django.core import cache
|
|
||||||
from django.core import signing
|
|
||||||
from django.db.models import QuerySet
|
|
||||||
from django.utils.deprecation import MiddlewareMixin
|
|
||||||
|
|
||||||
from application.models.api_key_model import ApplicationAccessToken
|
|
||||||
from common.exception.app_exception import AppEmbedIdentityFailed
|
|
||||||
from common.response import result
|
|
||||||
from common.util.common import set_embed_identity_cookie, getRestSeconds
|
|
||||||
from common.util.rsa_util import decrypt
|
|
||||||
|
|
||||||
chat_cache = cache.caches['chat_cache']
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCookieMiddleware(MiddlewareMixin):
|
|
||||||
|
|
||||||
def process_response(self, request, response):
|
|
||||||
if request.path.startswith('/api/application/chat_message') or request.path.startswith(
|
|
||||||
'/api/application/authentication') or request.path.startswith('/api/application/profile'):
|
|
||||||
set_embed_identity_cookie(request, response)
|
|
||||||
if 'embed_identity' in request.COOKIES and request.path.__contains__('/api/application/chat_message/'):
|
|
||||||
embed_identity = request.COOKIES['embed_identity']
|
|
||||||
try:
|
|
||||||
# 如果无法解密 说明embed_identity并非系统颁发
|
|
||||||
value = decrypt(embed_identity)
|
|
||||||
except Exception as e:
|
|
||||||
raise AppEmbedIdentityFailed(1004, '嵌入cookie不正确')
|
|
||||||
# 对话次数+1
|
|
||||||
try:
|
|
||||||
if not chat_cache.incr(value):
|
|
||||||
# 如果修改失败则设置为1
|
|
||||||
chat_cache.set(value, 1,
|
|
||||||
timeout=getRestSeconds())
|
|
||||||
except Exception as e:
|
|
||||||
# 如果修改失败则设置为1 证明 key不存在
|
|
||||||
chat_cache.set(value, 1,
|
|
||||||
timeout=getRestSeconds())
|
|
||||||
return response
|
|
||||||
|
|
||||||
def process_request(self, request):
|
|
||||||
if 'embed_identity' in request.COOKIES and request.path.__contains__('/api/application/chat_message/'):
|
|
||||||
auth = request.META.get('HTTP_AUTHORIZATION', None
|
|
||||||
)
|
|
||||||
auth_details = signing.loads(auth)
|
|
||||||
application_access_token = QuerySet(ApplicationAccessToken).filter(
|
|
||||||
application_id=auth_details.get('application_id')).first()
|
|
||||||
embed_identity = request.COOKIES['embed_identity']
|
|
||||||
try:
|
|
||||||
# 如果无法解密 说明embed_identity并非系统颁发
|
|
||||||
value = decrypt(embed_identity)
|
|
||||||
except Exception as e:
|
|
||||||
return result.Result(1003,
|
|
||||||
message='访问次数超过今日访问量', response_status=460)
|
|
||||||
embed_identity_number = chat_cache.get(value)
|
|
||||||
if embed_identity_number is not None:
|
|
||||||
if application_access_token.access_num <= embed_identity_number:
|
|
||||||
return result.Result(1003,
|
|
||||||
message='访问次数超过今日访问量', response_status=461)
|
|
||||||
@ -6,38 +6,10 @@
|
|||||||
@date:2023/10/16 16:42
|
@date:2023/10/16 16:42
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import datetime
|
|
||||||
import importlib
|
import importlib
|
||||||
import uuid
|
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
from django.core import cache
|
|
||||||
|
|
||||||
from .rsa_util import encrypt
|
|
||||||
|
|
||||||
chat_cache = cache.caches['chat_cache']
|
|
||||||
|
|
||||||
|
|
||||||
def set_embed_identity_cookie(request, response):
|
|
||||||
if 'embed_identity' in request.COOKIES:
|
|
||||||
embed_identity = request.COOKIES['embed_identity']
|
|
||||||
else:
|
|
||||||
value = str(uuid.uuid1())
|
|
||||||
embed_identity = encrypt(value)
|
|
||||||
chat_cache.set(value, 0, timeout=getRestSeconds())
|
|
||||||
response.set_cookie("embed_identity", embed_identity, max_age=3600 * 24 * 100, samesite='None',
|
|
||||||
secure=True)
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def getRestSeconds():
|
|
||||||
now = datetime.datetime.now()
|
|
||||||
today_begin = datetime.datetime(now.year, now.month, now.day, 0, 0, 0)
|
|
||||||
tomorrow_begin = today_begin + datetime.timedelta(days=1)
|
|
||||||
rest_seconds = (tomorrow_begin - now).seconds
|
|
||||||
return rest_seconds
|
|
||||||
|
|
||||||
|
|
||||||
def sub_array(array: List, item_num=10):
|
def sub_array(array: List, item_num=10):
|
||||||
result = []
|
result = []
|
||||||
|
|||||||
@ -46,8 +46,7 @@ MIDDLEWARE = [
|
|||||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||||
'django.middleware.common.CommonMiddleware',
|
'django.middleware.common.CommonMiddleware',
|
||||||
'django.contrib.messages.middleware.MessageMiddleware',
|
'django.contrib.messages.middleware.MessageMiddleware',
|
||||||
'common.middleware.static_headers_middleware.StaticHeadersMiddleware',
|
'common.middleware.static_headers_middleware.StaticHeadersMiddleware'
|
||||||
'common.middleware.chat_cookie_middleware.ChatCookieMiddleware'
|
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from common.db.sql_execute import select_list
|
|||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
|
|
||||||
__all__ = ["User", "password_encrypt"]
|
__all__ = ["User", "password_encrypt", 'get_user_dynamics_permission']
|
||||||
|
|
||||||
|
|
||||||
def password_encrypt(raw_password):
|
def password_encrypt(raw_password):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user