feat: 客户端不使用cookie存储改为localstore,优化认证代码

This commit is contained in:
zhangshaohu 2024-03-14 05:43:01 +08:00
parent 21a557ef43
commit 0fbd5873f7
20 changed files with 326 additions and 191 deletions

View File

@ -68,6 +68,8 @@ class IChatStep(IBaseChatPipelineStep):
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base("补全问题")) padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base("补全问题"))
# 是否使用流的形式输出 # 是否使用流的形式输出
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出")) stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出"))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
def is_valid(self, *, raise_exception=False): def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True) super().is_valid(raise_exception=True)
@ -90,5 +92,5 @@ class IChatStep(IBaseChatPipelineStep):
chat_model: BaseChatModel = None, chat_model: BaseChatModel = None,
paragraph_list=None, paragraph_list=None,
manage: PiplineManage = None, manage: PiplineManage = None,
padding_problem_text: str = None, stream: bool = True, **kwargs): padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, **kwargs):
pass pass

View File

@ -13,6 +13,7 @@ import traceback
import uuid import uuid
from typing import List from typing import List
from django.db.models import QuerySet
from django.http import StreamingHttpResponse from django.http import StreamingHttpResponse
from langchain.chat_models.base import BaseChatModel from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage from langchain.schema import BaseMessage
@ -21,9 +22,20 @@ from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PiplineManage from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
from common.response import result from common.response import result
def add_access_num(client_id=None, client_type=None):
if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
application_public_access_client = QuerySet(ApplicationPublicAccessClient).filter(id=client_id).first()
if application_public_access_client is not None:
application_public_access_client.access_num = application_public_access_client.access_num + 1
application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
application_public_access_client.save()
def event_content(response, def event_content(response,
chat_id, chat_id,
chat_record_id, chat_record_id,
@ -34,7 +46,8 @@ def event_content(response,
chat_model, chat_model,
message_list: List[BaseMessage], message_list: List[BaseMessage],
problem_text: str, problem_text: str,
padding_problem_text: str = None): padding_problem_text: str = None,
client_id=None, client_type=None):
all_text = '' all_text = ''
try: try:
for chunk in response: for chunk in response:
@ -57,6 +70,7 @@ def event_content(response,
all_text, manage, step, padding_problem_text) all_text, manage, step, padding_problem_text)
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': '', 'is_end': True}) + "\n\n" 'content': '', 'is_end': True}) + "\n\n"
add_access_num(client_id, client_type)
except Exception as e: except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
@ -73,15 +87,16 @@ class BaseChatStep(IChatStep):
manage: PiplineManage = None, manage: PiplineManage = None,
padding_problem_text: str = None, padding_problem_text: str = None,
stream: bool = True, stream: bool = True,
client_id=None, client_type=None,
**kwargs): **kwargs):
if stream: if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model, return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list, paragraph_list,
manage, padding_problem_text) manage, padding_problem_text, client_id, client_type)
else: else:
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model, return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list, paragraph_list,
manage, padding_problem_text) manage, padding_problem_text, client_id, client_type)
def get_details(self, manage, **kwargs): def get_details(self, manage, **kwargs):
return { return {
@ -111,7 +126,8 @@ class BaseChatStep(IChatStep):
chat_model: BaseChatModel = None, chat_model: BaseChatModel = None,
paragraph_list=None, paragraph_list=None,
manage: PiplineManage = None, manage: PiplineManage = None,
padding_problem_text: str = None): padding_problem_text: str = None,
client_id=None, client_type=None):
# 调用模型 # 调用模型
if chat_model is None: if chat_model is None:
chat_result = iter( chat_result = iter(
@ -123,7 +139,7 @@ class BaseChatStep(IChatStep):
r = StreamingHttpResponse( r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
post_response_handler, manage, self, chat_model, message_list, problem_text, post_response_handler, manage, self, chat_model, message_list, problem_text,
padding_problem_text), padding_problem_text, client_id, client_type),
content_type='text/event-stream;charset=utf-8') content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache' r['Cache-Control'] = 'no-cache'
@ -136,7 +152,8 @@ class BaseChatStep(IChatStep):
chat_model: BaseChatModel = None, chat_model: BaseChatModel = None,
paragraph_list=None, paragraph_list=None,
manage: PiplineManage = None, manage: PiplineManage = None,
padding_problem_text: str = None): padding_problem_text: str = None,
client_id=None, client_type=None):
# 调用模型 # 调用模型
if chat_model is None: if chat_model is None:
chat_result = AIMessage( chat_result = AIMessage(
@ -156,5 +173,6 @@ class BaseChatStep(IChatStep):
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
chat_result.content, manage, self, padding_problem_text) chat_result.content, manage, self, padding_problem_text)
add_access_num(client_id, client_type)
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': chat_result.content, 'is_end': True}) 'content': chat_result.content, 'is_end': True})

View File

@ -0,0 +1,28 @@
# Generated by Django 4.1.10 on 2024-03-14 05:03
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('application', '0008_applicationaccesstoken_access_num_and_more'),
]
operations = [
migrations.CreateModel(
name='ApplicationPublicAccessClient',
fields=[
('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
('id', models.UUIDField(primary_key=True, serialize=False, verbose_name='公共访问链接客户端id')),
('access_num', models.IntegerField(default=0, verbose_name='访问总次数次数')),
('intraday_access_num', models.IntegerField(default=0, verbose_name='当日访问次数')),
('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')),
],
options={
'db_table': 'application_public_access_client',
},
),
]

View File

@ -42,3 +42,13 @@ class ApplicationAccessToken(AppModelMixin):
class Meta: class Meta:
db_table = "application_access_token" db_table = "application_access_token"
class ApplicationPublicAccessClient(AppModelMixin):
id = models.UUIDField(max_length=128, primary_key=True, verbose_name="公共访问链接客户端id")
application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id")
access_num = models.IntegerField(default=0, verbose_name="访问总次数次数")
intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数")
class Meta:
db_table = "application_public_access_client"

View File

@ -28,10 +28,8 @@ from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404 from common.exception.app_exception import AppApiException, NotFound404
from common.util.common import getRestSeconds, set_embed_identity_cookie
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from common.util.rsa_util import encrypt
from dataset.models import DataSet, Document from dataset.models import DataSet, Document
from dataset.serializers.common_serializers import list_paragraph from dataset.serializers.common_serializers import list_paragraph
from setting.models import AuthOperate from setting.models import AuthOperate
@ -39,7 +37,6 @@ from setting.models.model_management import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ModelSerializer from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
from smartdoc.settings import JWT_AUTH
token_cache = cache.caches['token_cache'] token_cache = cache.caches['token_cache']
chat_cache = cache.caches['chat_cache'] chat_cache = cache.caches['chat_cache']
@ -114,7 +111,7 @@ class ApplicationSerializer(serializers.Serializer):
protocol = serializers.CharField(required=True, error_messages=ErrMessage.char("协议")) protocol = serializers.CharField(required=True, error_messages=ErrMessage.char("协议"))
token = serializers.CharField(required=True, error_messages=ErrMessage.char("token")) token = serializers.CharField(required=True, error_messages=ErrMessage.char("token"))
def get_embed(self, request, with_valid=True): def get_embed(self, with_valid=True):
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
index_path = os.path.join(PROJECT_DIR, 'apps', "application", 'template', 'embed.js') index_path = os.path.join(PROJECT_DIR, 'apps', "application", 'template', 'embed.js')
@ -136,7 +133,6 @@ class ApplicationSerializer(serializers.Serializer):
application_access_token.white_list), application_access_token.white_list),
'white_active': 'true' if application_access_token.white_active else 'false'})) 'white_active': 'true' if application_access_token.white_active else 'false'}))
response = HttpResponse(s, status=200, headers={'Content-Type': 'text/javascript'}) response = HttpResponse(s, status=200, headers={'Content-Type': 'text/javascript'})
set_embed_identity_cookie(request, response)
return response return response
class AccessTokenSerializer(serializers.Serializer): class AccessTokenSerializer(serializers.Serializer):
@ -197,17 +193,27 @@ class ApplicationSerializer(serializers.Serializer):
class Authentication(serializers.Serializer): class Authentication(serializers.Serializer):
access_token = serializers.CharField(required=True, error_messages=ErrMessage.char("access_token")) access_token = serializers.CharField(required=True, error_messages=ErrMessage.char("access_token"))
def auth(self, with_valid=True): def auth(self, request, with_valid=True):
token = request.META.get('HTTP_AUTHORIZATION', None)
token_details = None
try:
# 校验token
if token is not None:
token_details = signing.loads(token)
except Exception as e:
token = None
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
access_token = self.data.get("access_token") access_token = self.data.get("access_token")
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
if application_access_token is not None and application_access_token.is_active: if application_access_token is not None and application_access_token.is_active:
if token is None or (token_details is not None and 'client_id' not in token_details):
client_id = str(uuid.uuid1())
token = signing.dumps({'application_id': str(application_access_token.application_id), token = signing.dumps({'application_id': str(application_access_token.application_id),
'user_id': str(application_access_token.application.user.id), 'user_id': str(application_access_token.application.user.id),
'access_token': application_access_token.access_token, 'access_token': application_access_token.access_token,
'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value}) 'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value,
token_cache.set(token, application_access_token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA']) 'client_id': client_id})
return token return token
else: else:
raise NotFound404(404, "无效的access_token") raise NotFound404(404, "无效的access_token")

View File

@ -23,7 +23,9 @@ from application.chat_pipeline.step.generate_human_message_step.impl.base_genera
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
from common.exception.app_exception import AppApiException from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken
from common.constants.authentication_type import AuthenticationType
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
from common.util.field_message import ErrMessage from common.util.field_message import ErrMessage
from common.util.rsa_util import decrypt from common.util.rsa_util import decrypt
from common.util.split_model import flat_map from common.util.split_model import flat_map
@ -32,7 +34,6 @@ from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
chat_cache = caches['model_cache'] chat_cache = caches['model_cache']
chat_embed_identity_cache = caches['chat_cache']
class ChatInfo: class ChatInfo:
@ -75,15 +76,16 @@ class ChatInfo:
'chat_model': self.chat_model, 'chat_model': self.chat_model,
'model_id': self.application.model.id if self.application.model is not None else None, 'model_id': self.application.model.id if self.application.model is not None else None,
'problem_optimization': self.application.problem_optimization, 'problem_optimization': self.application.problem_optimization,
'stream': True 'stream': True,
} }
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler, def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
exclude_paragraph_id_list, stream=True): exclude_paragraph_id_list, client_id: str, client_type, stream=True):
params = self.to_base_pipeline_manage_params() params = self.to_base_pipeline_manage_params()
return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler, return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream} 'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'client_id': client_id,
'client_type': client_type}
def append_chat_record(self, chat_record: ChatRecord): def append_chat_record(self, chat_record: ChatRecord):
# 存入缓存中 # 存入缓存中
@ -127,9 +129,37 @@ def get_post_handler(chat_info: ChatInfo):
class ChatMessageSerializer(serializers.Serializer): class ChatMessageSerializer(serializers.Serializer):
chat_id = serializers.UUIDField(required=True) chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id"))
message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))
stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答"))
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答"))
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
def chat(self, message, re_chat: bool, stream: bool): def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first()
if access_client is None:
access_client = ApplicationPublicAccessClient(id=self.data.get('client_id'),
application_id=self.data.get('application_id'),
access_num=0,
intraday_access_num=0)
access_client.save()
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=self.data.get('application_id')).first()
if application_access_token.access_num <= access_client.intraday_access_num:
raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量")
def chat(self):
self.is_valid(raise_exception=True)
message = self.data.get('message')
re_chat = self.data.get('re_chat')
stream = self.data.get('stream')
client_id = self.data.get('client_id')
client_type = self.data.get('client_type')
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
chat_id = self.data.get('chat_id') chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id) chat_info: ChatInfo = chat_cache.get(chat_id)
@ -156,7 +186,7 @@ class ChatMessageSerializer(serializers.Serializer):
exclude_paragraph_id_list = list(set(paragraph_id_list)) exclude_paragraph_id_list = list(set(paragraph_id_list))
# 构建运行参数 # 构建运行参数
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list, params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
stream) client_id, client_type, stream)
# 运行流水线作业 # 运行流水线作业
pipline_message.run(params) pipline_message.run(params)
return pipline_message.context['chat_result'] return pipline_message.context['chat_result']

View File

@ -2,7 +2,7 @@
""" """
@project: maxkb @project: maxkb
@Author @Author
@file application_api.py @file application_key.py
@date2023/11/7 10:50 @date2023/11/7 10:50
@desc: @desc:
""" """

View File

@ -37,7 +37,7 @@ class Application(APIView):
def get(self, request: Request): def get(self, request: Request):
return ApplicationSerializer.Embed( return ApplicationSerializer.Embed(
data={'protocol': request.query_params.get('protocol'), 'token': request.query_params.get('token'), data={'protocol': request.query_params.get('protocol'), 'token': request.query_params.get('token'),
'host': request.query_params.get('host'), }).get_embed(request) 'host': request.query_params.get('host'), }).get_embed()
class Model(APIView): class Model(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]
@ -192,7 +192,8 @@ class Application(APIView):
security=[]) security=[])
def post(self, request: Request): def post(self, request: Request):
return result.success( return result.success(
ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth(), ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth(
request),
headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
"Access-Control-Allow-Methods": "POST", "Access-Control-Allow-Methods": "POST",
"Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"} "Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}

View File

@ -15,6 +15,7 @@ from application.serializers.chat_message_serializers import ChatMessageSerializ
from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi
from common.auth import TokenAuth, has_permissions from common.auth import TokenAuth, has_permissions
from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import Permission, Group, Operate, \ from common.constants.permission_constants import Permission, Group, Operate, \
RoleConstants, ViewPermission, CompareConstants RoleConstants, ViewPermission, CompareConstants
from common.response import result from common.response import result
@ -71,11 +72,15 @@ class ChatView(APIView):
dynamic_tag=keywords.get('application_id'))]) dynamic_tag=keywords.get('application_id'))])
) )
def post(self, request: Request, chat_id: str): def post(self, request: Request, chat_id: str):
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), return ChatMessageSerializer(data={'chat_id': chat_id, 'message': request.data.get('message'),
request.data.get( 're_chat': (request.data.get(
're_chat') if 're_chat' in request.data else False, 're_chat') if 're_chat' in request.data else False),
request.data.get( 'stream': (request.data.get(
'stream') if 'stream' in request.data else True) 'stream') if 'stream' in request.data else True),
'application_id': (request.auth.keywords.get(
'application_id') if request.auth.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value else None),
'client_id': request.auth.client_id,
'client_type': request.auth.client_type}).chat()
@action(methods=['GET'], detail=False) @action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话列表", @swagger_auto_schema(operation_summary="获取对话列表",

View File

@ -14,6 +14,9 @@ from django.db.models import QuerySet
from rest_framework.authentication import TokenAuthentication from rest_framework.authentication import TokenAuthentication
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
from common.auth.handle.impl.application_key import ApplicationKey
from common.auth.handle.impl.public_access_token import PublicAccessToken
from common.auth.handle.impl.user_token import UserToken
from common.constants.authentication_type import AuthenticationType from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants, Permission, Group, \ from common.constants.permission_constants import Auth, get_permission_list_by_role, RoleConstants, Permission, Group, \
Operate Operate
@ -29,6 +32,25 @@ class AnonymousAuthentication(TokenAuthentication):
return None, None return None, None
handles = [UserToken(), PublicAccessToken(), ApplicationKey()]
class TokenDetails:
token_details = None
is_load = False
def __init__(self, token: str):
self.token = token
def get_token_details(self):
if self.token_details is None and not self.is_load:
try:
self.token_details = signing.loads(self.token)
except Exception as e:
self.is_load = True
return self.token_details
class TokenAuth(TokenAuthentication): class TokenAuth(TokenAuthentication):
# 重新 authenticate 方法,自定义认证规则 # 重新 authenticate 方法,自定义认证规则
def authenticate(self, request): def authenticate(self, request):
@ -38,62 +60,11 @@ class TokenAuth(TokenAuthentication):
if auth is None: if auth is None:
raise AppAuthenticationFailed(1003, '未登录,请先登录') raise AppAuthenticationFailed(1003, '未登录,请先登录')
try: try:
if str(auth).startswith("application-"): token_details = TokenDetails(auth)
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=auth).first() for handle in handles:
if application_api_key is None: if handle.support(request, auth, token_details.get_token_details):
raise AppAuthenticationFailed(500, "secret_key 无效") return handle.handle(request, auth, token_details.get_token_details)
if not application_api_key.is_active:
raise AppAuthenticationFailed(500, "secret_key 无效")
permission_list = [Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=str(
application_api_key.application_id)),
Permission(group=Group.APPLICATION,
operate=Operate.MANAGE,
dynamic_tag=str(
application_api_key.application_id))
]
return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY],
permission_list=permission_list,
application_id=application_api_key.application_id)
# 解析 token
auth_details = signing.loads(auth)
cache_token = token_cache.get(auth)
if cache_token is None:
raise AppAuthenticationFailed(1002, "登录过期")
if 'id' in auth_details and auth_details.get('type') == AuthenticationType.USER.value:
user = QuerySet(User).get(id=auth_details['id'])
# 续期
token_cache.touch(auth, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds())
rule = RoleConstants[user.role]
permission_list = get_permission_list_by_role(RoleConstants[user.role])
# 获取用户的应用和知识库的权限
permission_list += get_user_dynamics_permission(str(user.id))
return user, Auth(role_list=[rule],
permission_list=permission_list)
if 'application_id' in auth_details and 'access_token' in auth_details and auth_details.get(
'type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=auth_details.get('application_id')).first()
if application_access_token is None:
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
if not application_access_token.is_active:
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
if not application_access_token.access_token == auth_details.get('access_token'):
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
return application_access_token.application.user, Auth(
role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN],
permission_list=[
Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=str(
application_access_token.application_id))],
application_id=application_access_token.application_id
)
else:
raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户") raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
except Exception as e: except Exception as e:
traceback.format_exc() traceback.format_exc()
if isinstance(e, AppEmbedIdentityFailed) or isinstance(e, AppChatNumOutOfBoundsFailed): if isinstance(e, AppEmbedIdentityFailed) or isinstance(e, AppChatNumOutOfBoundsFailed):

View File

@ -0,0 +1,19 @@
# coding=utf-8
"""
@project: qabot
@Author
@file authenticate.py
@date2024/3/14 03:02
@desc: 认证处理器
"""
from abc import ABC, abstractmethod
class AuthBaseHandle(ABC):
@abstractmethod
def support(self, request, token: str, get_token_details):
pass
@abstractmethod
def handle(self, request, token: str, get_token_details):
pass

View File

@ -0,0 +1,41 @@
# coding=utf-8
"""
@project: qabot
@Author
@file authenticate.py
@date2024/3/14 03:02
@desc: 应用api key认证
"""
from django.db.models import QuerySet
from application.models.api_key_model import ApplicationApiKey
from common.auth.handle.auth_base_handle import AuthBaseHandle
from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import Permission, Group, Operate, RoleConstants, Auth
from common.exception.app_exception import AppAuthenticationFailed
class ApplicationKey(AuthBaseHandle):
def handle(self, request, token: str, get_token_details):
application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=token).first()
if application_api_key is None:
raise AppAuthenticationFailed(500, "secret_key 无效")
if not application_api_key.is_active:
raise AppAuthenticationFailed(500, "secret_key 无效")
permission_list = [Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=str(
application_api_key.application_id)),
Permission(group=Group.APPLICATION,
operate=Operate.MANAGE,
dynamic_tag=str(
application_api_key.application_id))
]
return application_api_key.user, Auth(role_list=[RoleConstants.APPLICATION_KEY],
permission_list=permission_list,
application_id=application_api_key.application_id,
client_id=token,
client_type=AuthenticationType.API_KEY.value)
def support(self, request, token: str, get_token_details):
return str(token).startswith("application-")

View File

@ -0,0 +1,49 @@
# coding=utf-8
"""
@project: qabot
@Author
@file authenticate.py
@date2024/3/14 03:02
@desc: 公共访问连接认证
"""
from django.db.models import QuerySet
from application.models.api_key_model import ApplicationAccessToken
from common.auth.handle.auth_base_handle import AuthBaseHandle
from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, Auth
from common.exception.app_exception import AppAuthenticationFailed
class PublicAccessToken(AuthBaseHandle):
def support(self, request, token: str, get_token_details):
token_details = get_token_details()
if token_details is None:
return False
return (
'application_id' in token_details and
'access_token' in token_details and
token_details.get('type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value)
def handle(self, request, token: str, get_token_details):
auth_details = get_token_details()
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=auth_details.get('application_id')).first()
if application_access_token is None:
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
if not application_access_token.is_active:
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
if not application_access_token.access_token == auth_details.get('access_token'):
raise AppAuthenticationFailed(1002, "身份验证信息不正确")
return application_access_token.application.user, Auth(
role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN],
permission_list=[
Permission(group=Group.APPLICATION,
operate=Operate.USE,
dynamic_tag=str(
application_access_token.application_id))],
application_id=application_access_token.application_id,
client_id=auth_details.get('client_id'),
client_type=AuthenticationType.APPLICATION_ACCESS_TOKEN.value
)

View File

@ -0,0 +1,46 @@
# coding=utf-8
"""
@project: qabot
@Author
@file authenticate.py
@date2024/3/14 03:02
@desc: 用户认证
"""
from django.db.models import QuerySet
from common.auth.handle.auth_base_handle import AuthBaseHandle
from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import RoleConstants, get_permission_list_by_role, Auth
from common.exception.app_exception import AppAuthenticationFailed
from smartdoc.settings import JWT_AUTH
from users.models import User
from django.core import cache
from users.models.user import get_user_dynamics_permission
token_cache = cache.caches['token_cache']
class UserToken(AuthBaseHandle):
def support(self, request, token: str, get_token_details):
auth_details = get_token_details()
if auth_details is None:
return False
return 'id' in auth_details and auth_details.get('type') == AuthenticationType.USER.value
def handle(self, request, token: str, get_token_details):
cache_token = token_cache.get(token)
if cache_token is None:
raise AppAuthenticationFailed(1002, "登录过期")
auth_details = get_token_details()
user = QuerySet(User).get(id=auth_details['id'])
# 续期
token_cache.touch(token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds())
rule = RoleConstants[user.role]
permission_list = get_permission_list_by_role(RoleConstants[user.role])
# 获取用户的应用和知识库的权限
permission_list += get_user_dynamics_permission(str(user.id))
return user, Auth(role_list=[rule],
permission_list=permission_list,
client_id=str(user.id),
client_type=AuthenticationType.USER.value)

View File

@ -10,7 +10,9 @@ from enum import Enum
class AuthenticationType(Enum): class AuthenticationType(Enum):
# 或者 # 普通用户
USER = "USER" USER = "USER"
# 并且 # 公共访问链接
APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN" APPLICATION_ACCESS_TOKEN = "APPLICATION_ACCESS_TOKEN"
# key API
API_KEY = "API_KEY"

View File

@ -151,10 +151,12 @@ class Auth:
用于存储当前用户的角色和权限 用于存储当前用户的角色和权限
""" """
def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission], def __init__(self, role_list: List[RoleConstants], permission_list: List[PermissionConstants | Permission]
**keywords): , client_id, client_type, **keywords):
self.role_list = role_list self.role_list = role_list
self.permission_list = permission_list self.permission_list = permission_list
self.client_id = client_id
self.client_type = client_type
self.keywords = keywords self.keywords = keywords

View File

@ -1,66 +0,0 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file chat_cookie_middleware.py
@date2024/3/13 20:13
@desc:
"""
from django.core import cache
from django.core import signing
from django.db.models import QuerySet
from django.utils.deprecation import MiddlewareMixin
from application.models.api_key_model import ApplicationAccessToken
from common.exception.app_exception import AppEmbedIdentityFailed
from common.response import result
from common.util.common import set_embed_identity_cookie, getRestSeconds
from common.util.rsa_util import decrypt
chat_cache = cache.caches['chat_cache']
class ChatCookieMiddleware(MiddlewareMixin):
def process_response(self, request, response):
if request.path.startswith('/api/application/chat_message') or request.path.startswith(
'/api/application/authentication') or request.path.startswith('/api/application/profile'):
set_embed_identity_cookie(request, response)
if 'embed_identity' in request.COOKIES and request.path.__contains__('/api/application/chat_message/'):
embed_identity = request.COOKIES['embed_identity']
try:
# 如果无法解密 说明embed_identity并非系统颁发
value = decrypt(embed_identity)
except Exception as e:
raise AppEmbedIdentityFailed(1004, '嵌入cookie不正确')
# 对话次数+1
try:
if not chat_cache.incr(value):
# 如果修改失败则设置为1
chat_cache.set(value, 1,
timeout=getRestSeconds())
except Exception as e:
# 如果修改失败则设置为1 证明 key不存在
chat_cache.set(value, 1,
timeout=getRestSeconds())
return response
def process_request(self, request):
if 'embed_identity' in request.COOKIES and request.path.__contains__('/api/application/chat_message/'):
auth = request.META.get('HTTP_AUTHORIZATION', None
)
auth_details = signing.loads(auth)
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=auth_details.get('application_id')).first()
embed_identity = request.COOKIES['embed_identity']
try:
# 如果无法解密 说明embed_identity并非系统颁发
value = decrypt(embed_identity)
except Exception as e:
return result.Result(1003,
message='访问次数超过今日访问量', response_status=460)
embed_identity_number = chat_cache.get(value)
if embed_identity_number is not None:
if application_access_token.access_num <= embed_identity_number:
return result.Result(1003,
message='访问次数超过今日访问量', response_status=461)

View File

@ -6,38 +6,10 @@
@date2023/10/16 16:42 @date2023/10/16 16:42
@desc: @desc:
""" """
import datetime
import importlib import importlib
import uuid
from functools import reduce from functools import reduce
from typing import Dict, List from typing import Dict, List
from django.core import cache
from .rsa_util import encrypt
chat_cache = cache.caches['chat_cache']
def set_embed_identity_cookie(request, response):
if 'embed_identity' in request.COOKIES:
embed_identity = request.COOKIES['embed_identity']
else:
value = str(uuid.uuid1())
embed_identity = encrypt(value)
chat_cache.set(value, 0, timeout=getRestSeconds())
response.set_cookie("embed_identity", embed_identity, max_age=3600 * 24 * 100, samesite='None',
secure=True)
return response
def getRestSeconds():
now = datetime.datetime.now()
today_begin = datetime.datetime(now.year, now.month, now.day, 0, 0, 0)
tomorrow_begin = today_begin + datetime.timedelta(days=1)
rest_seconds = (tomorrow_begin - now).seconds
return rest_seconds
def sub_array(array: List, item_num=10): def sub_array(array: List, item_num=10):
result = [] result = []

View File

@ -46,8 +46,7 @@ MIDDLEWARE = [
'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware', 'django.middleware.common.CommonMiddleware',
'django.contrib.messages.middleware.MessageMiddleware', 'django.contrib.messages.middleware.MessageMiddleware',
'common.middleware.static_headers_middleware.StaticHeadersMiddleware', 'common.middleware.static_headers_middleware.StaticHeadersMiddleware'
'common.middleware.chat_cookie_middleware.ChatCookieMiddleware'
] ]

View File

@ -17,7 +17,7 @@ from common.db.sql_execute import select_list
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
__all__ = ["User", "password_encrypt"] __all__ = ["User", "password_encrypt", 'get_user_dynamics_permission']
def password_encrypt(raw_password): def password_encrypt(raw_password):