feat: 优化对话逻辑
This commit is contained in:
parent
7349f00c54
commit
3f87335c80
55
apps/application/chat_pipeline/I_base_chat_pipeline.py
Normal file
55
apps/application/chat_pipeline/I_base_chat_pipeline.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: I_base_chat_pipeline.py
|
||||||
|
@date:2024/1/9 17:25
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
|
||||||
|
class IBaseChatPipelineStep:
    """Base class for a single step of the chat pipeline.

    A concrete step validates its input against a serializer, runs its
    logic via ``_run`` and records timing information in its own context.
    """

    def __init__(self):
        # Per-step context holding validated args, timings and results.
        self.context = {}

    @abstractmethod
    def get_step_serializer(self, manage) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this step's arguments."""
        pass

    def valid_args(self, manage):
        """Validate the manager context against this step's serializer.

        Stores the validated data under ``self.context['step_args']``;
        raises a serializer validation error on invalid input.
        """
        serializer_class = self.get_step_serializer(manage)
        serializer = serializer_class(data=manage.context)
        serializer.is_valid(raise_exception=True)
        self.context['step_args'] = serializer.data

    def run(self, manage):
        """Validate arguments, execute the step and record timings.

        :param manage: pipeline manager carrying the shared context
        :return: None; results are written into the step/manager contexts
        """
        started = time.time()
        # Fail fast on bad input before doing any work.
        self.valid_args(manage)
        self._run(manage)
        self.context['start_time'] = started
        self.context['run_time'] = time.time() - started

    def _run(self, manage):
        """Concrete step logic; overridden by subclasses."""
        pass

    def execute(self, **kwargs):
        """Business execution hook; overridden by subclasses."""
        pass

    def get_details(self, manage, **kwargs):
        """Return run details for this step, or None when it reports none.

        :return: detail dict or None
        """
        return None
|
||||||
8
apps/application/chat_pipeline/__init__.py
Normal file
8
apps/application/chat_pipeline/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: __init__.py.py
|
||||||
|
@date:2024/1/9 17:23
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
45
apps/application/chat_pipeline/pipeline_manage.py
Normal file
45
apps/application/chat_pipeline/pipeline_manage.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: pipeline_manage.py
|
||||||
|
@date:2024/1/9 17:40
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
from functools import reduce
|
||||||
|
from typing import List, Type, Dict
|
||||||
|
|
||||||
|
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||||
|
|
||||||
|
|
||||||
|
class PiplineManage:
    """Drives an ordered list of pipeline steps over a shared context."""

    def __init__(self, step_list: List[Type[IBaseChatPipelineStep]]):
        # One executor instance per registered step class.
        self.step_list = [step_class() for step_class in step_list]
        # Shared context; token counters are accumulated by the steps.
        self.context = {'message_tokens': 0, 'answer_tokens': 0}

    def run(self, context: Dict = None):
        """Merge the given context (if any) and execute every step in order."""
        self.context['start_time'] = time.time()
        if context is not None:
            for key, value in context.items():
                self.context[key] = value
        for step in self.step_list:
            step.run(self)

    def get_details(self):
        """Collect per-step details keyed by each step's ``step_type``."""
        details = [step.get_details(self) for step in self.step_list]
        keyed = [{detail.get('step_type'): detail}
                 for detail in details if detail is not None]
        return reduce(lambda acc, item: {**acc, **item}, keyed, {})

    class builder:
        """Fluent builder collecting step classes for a PiplineManage."""

        def __init__(self):
            self.step_list: List[Type[IBaseChatPipelineStep]] = []

        def append_step(self, step: Type[IBaseChatPipelineStep]):
            """Register a step class; returns self for chaining."""
            self.step_list.append(step)
            return self

        def build(self):
            """Create the manager from the collected step classes."""
            return PiplineManage(step_list=self.step_list)
|
||||||
8
apps/application/chat_pipeline/step/__init__.py
Normal file
8
apps/application/chat_pipeline/step/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: __init__.py.py
|
||||||
|
@date:2024/1/9 18:23
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: __init__.py.py
|
||||||
|
@date:2024/1/9 18:23
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
88
apps/application/chat_pipeline/step/chat_step/i_chat_step.py
Normal file
88
apps/application/chat_pipeline/step/chat_step/i_chat_step.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: i_chat_step.py
|
||||||
|
@date:2024/1/9 18:17
|
||||||
|
@desc: 对话
|
||||||
|
"""
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Type, List
|
||||||
|
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.schema import BaseMessage
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||||
|
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||||
|
from common.field.common import InstanceField
|
||||||
|
from dataset.models import Paragraph
|
||||||
|
|
||||||
|
|
||||||
|
class ModelField(serializers.Field):
    """Serializer field that passes through a LangChain chat model instance."""

    def to_internal_value(self, data):
        # Accept only genuine BaseChatModel instances; anything else fails.
        if isinstance(data, BaseChatModel):
            return data
        self.fail('模型类型错误', value=data)

    def to_representation(self, value):
        # The model object is returned as-is.
        return value
|
||||||
|
|
||||||
|
|
||||||
|
class MessageField(serializers.Field):
    """Serializer field that passes through a LangChain BaseMessage instance."""

    def to_internal_value(self, data):
        # Accept only genuine BaseMessage instances; anything else fails.
        if isinstance(data, BaseMessage):
            return data
        self.fail('message类型错误', value=data)

    def to_representation(self, value):
        # The message object is returned as-is.
        return value
|
||||||
|
|
||||||
|
|
||||||
|
class PostResponseHandler:
    """Callback interface invoked once the full answer has been streamed.

    Implementations receive the chat identifiers, the retrieved paragraphs,
    the (possibly optimized) question and the final answer text — presumably
    to persist the chat record and update statistics (TODO confirm against
    the concrete implementation).
    """

    @abstractmethod
    def handler(self, chat_id, chat_record_id, paragraph_list: List[Paragraph], problem_text: str, answer_text,
                manage, step, padding_problem_text: str = None, **kwargs):
        pass
|
||||||
|
|
||||||
|
|
||||||
|
class IChatStep(IBaseChatPipelineStep):
    """Chat step interface: sends the prepared message list to the LLM."""

    class InstanceSerializer(serializers.Serializer):
        # Chat message list sent to the model
        message_list = serializers.ListField(required=True, child=MessageField(required=True))
        # Large language model
        chat_model = ModelField()
        # Retrieved paragraph list
        paragraph_list = serializers.ListField()
        # Chat id
        chat_id = serializers.UUIDField(required=True)
        # User question
        problem_text = serializers.CharField(required=True)
        # Post-response handler callback
        post_response_handler = InstanceField(model_type=PostResponseHandler)
        # Optimized/completed question (optional)
        padding_problem_text = serializers.CharField(required=False)

        def is_valid(self, *, raise_exception=False):
            # NOTE: this override belongs on the serializer — it calls
            # super().is_valid() and reads self.initial_data, both of which
            # only exist on a Serializer, not on the pipeline step.
            super().is_valid(raise_exception=True)
            message_list: List = self.initial_data.get('message_list')
            # ListField cannot type-check opaque message objects, so verify here.
            for message in message_list:
                if not isinstance(message, BaseMessage):
                    raise Exception("message 类型错误")

    def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
        """Return the serializer validating this step's arguments."""
        return self.InstanceSerializer

    def _run(self, manage: PiplineManage):
        # Delegate to execute() and publish the streaming result.
        chat_result = self.execute(**self.context['step_args'], manage=manage)
        manage.context['chat_result'] = chat_result

    @abstractmethod
    def execute(self, message_list: List[BaseMessage],
                chat_id, problem_text,
                post_response_handler: PostResponseHandler,
                chat_model: BaseChatModel = None,
                paragraph_list=None,
                manage: PiplineManage = None,
                padding_problem_text: str = None, **kwargs):
        """Run the chat against the model.

        :param message_list: messages to send to the model
        :param chat_id: chat id
        :param problem_text: original user question
        :param post_response_handler: callback invoked after streaming ends
        :param chat_model: the LLM; None triggers the fallback behavior
        :param paragraph_list: retrieved paragraphs
        :param manage: pipeline manager
        :param padding_problem_text: optimized question, if any
        :return: implementation-defined streaming result
        """
        pass
|
||||||
@ -0,0 +1,111 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: base_chat_step.py
|
||||||
|
@date:2024/1/9 18:25
|
||||||
|
@desc: 对话step Base实现
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from django.http import StreamingHttpResponse
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.schema import BaseMessage
|
||||||
|
from langchain.schema.messages import BaseMessageChunk, HumanMessage
|
||||||
|
|
||||||
|
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||||
|
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
|
||||||
|
from dataset.models import Paragraph
|
||||||
|
|
||||||
|
|
||||||
|
def event_content(response,
                  chat_id,
                  chat_record_id,
                  paragraph_list: List[Paragraph],
                  post_response_handler: PostResponseHandler,
                  manage,
                  step,
                  chat_model,
                  message_list: List[BaseMessage],
                  problem_text: str,
                  padding_problem_text: str = None):
    """Server-sent-events generator for a streamed chat answer.

    Yields one ``data:`` JSON event per model chunk, then a terminating
    ``is_end: True`` event, records token counts and run times on the step
    and manager contexts, and finally invokes the post-response handler.
    On any exception it logs the traceback and emits a single error event.
    """
    all_text = ''
    try:
        for chunk in response:
            all_text += chunk.content
            yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                                         'content': chunk.content, 'is_end': False}) + "\n\n"

        # Terminal event marking the end of the stream.
        yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                                     'content': '', 'is_end': True}) + "\n\n"
        # Token accounting for the request and the full streamed answer.
        # NOTE(review): chat_model may be None on BaseChatStep's fallback
        # path (paragraph echo); these calls would then raise and be
        # reported as an error event below — confirm this is intended.
        request_token = chat_model.get_num_tokens_from_messages(message_list)
        response_token = chat_model.get_num_tokens(all_text)
        step.context['message_tokens'] = request_token
        step.context['answer_tokens'] = response_token
        current_time = time.time()
        step.context['answer_text'] = all_text
        step.context['run_time'] = current_time - step.context['start_time']
        manage.context['run_time'] = current_time - manage.context['start_time']
        manage.context['message_tokens'] = manage.context['message_tokens'] + request_token
        manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
        # Persist/post-process the completed answer.
        post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                      all_text, manage, step, padding_problem_text)
    except Exception as e:
        logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
        yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                                     'content': '异常' + str(e), 'is_end': True}) + "\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseChatStep(IChatStep):
    """Default chat step: streams the model answer as server-sent events."""

    def execute(self, message_list: List[BaseMessage],
                chat_id,
                problem_text,
                post_response_handler: PostResponseHandler,
                chat_model: BaseChatModel = None,
                paragraph_list=None,
                manage: PiplineManage = None,
                padding_problem_text: str = None,
                **kwargs):
        """Invoke the model (or fall back to echoing the retrieved
        paragraphs when no model is configured) and wrap the stream in a
        StreamingHttpResponse of SSE events."""
        if chat_model is None:
            # No model configured: answer with the paragraphs themselves.
            chat_result = iter(
                [BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
        else:
            chat_result = chat_model.stream(message_list)

        chat_record_id = uuid.uuid1()
        stream = event_content(chat_result, chat_id, chat_record_id, paragraph_list,
                               post_response_handler, manage, self, chat_model, message_list, problem_text,
                               padding_problem_text)
        response = StreamingHttpResponse(
            streaming_content=stream,
            content_type='text/event-stream;charset=utf-8')
        response['Cache-Control'] = 'no-cache'
        return response

    def get_details(self, manage, **kwargs):
        """Report run details of the chat step."""
        return {
            'step_type': 'chat_step',
            'run_time': self.context['run_time'],
            'model_id': str(manage.context['model_id']),
            'message_list': self.reset_message_list(self.context['step_args'].get('message_list'),
                                                    self.context['answer_text']),
            'message_tokens': self.context['message_tokens'],
            'answer_tokens': self.context['answer_tokens'],
            'cost': 0,
        }

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """Convert messages into role/content dicts and append the final answer."""
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai',
                   'content': message.content}
                  for message in message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: __init__.py.py
|
||||||
|
@date:2024/1/9 18:23
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
@ -0,0 +1,68 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: i_generate_human_message_step.py
|
||||||
|
@date:2024/1/9 18:15
|
||||||
|
@desc: 生成对话模板
|
||||||
|
"""
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Type, List
|
||||||
|
|
||||||
|
from langchain.schema import BaseMessage
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||||
|
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||||
|
from application.models import ChatRecord
|
||||||
|
from common.field.common import InstanceField
|
||||||
|
from dataset.models import Paragraph
|
||||||
|
|
||||||
|
|
||||||
|
class IGenerateHumanMessageStep(IBaseChatPipelineStep):
    """Step interface that builds the LLM message list (history + template)."""

    class InstanceSerializer(serializers.Serializer):
        # User question
        problem_text = serializers.CharField(required=True)
        # Retrieved knowledge-base paragraphs
        paragraph_list = serializers.ListField(child=InstanceField(model_type=Paragraph, required=True))
        # Previous question/answer records
        history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
        # Number of history rounds to carry
        dialogue_number = serializers.IntegerField(required=True)
        # Maximum characters of paragraph data to carry
        max_paragraph_char_number = serializers.IntegerField(required=True)
        # Prompt template
        prompt = serializers.CharField(required=True)
        # Optimized/completed question (optional)
        padding_problem_text = serializers.CharField(required=False)

    def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
        """Return the serializer validating this step's arguments."""
        return self.InstanceSerializer

    def _run(self, manage: PiplineManage):
        # Publish the generated message list for the downstream chat step.
        manage.context['message_list'] = self.execute(**self.context['step_args'])

    @abstractmethod
    def execute(self,
                problem_text: str,
                paragraph_list: List[Paragraph],
                history_chat_record: List[ChatRecord],
                dialogue_number: int,
                max_paragraph_char_number: int,
                prompt: str,
                padding_problem_text: str = None,
                **kwargs) -> List[BaseMessage]:
        """Build the message list to send to the model.

        :param problem_text: original user question text
        :param paragraph_list: retrieved paragraph list
        :param history_chat_record: previous chat records
        :param dialogue_number: number of history rounds to include
        :param max_paragraph_char_number: maximum paragraph characters
        :param prompt: prompt template
        :param padding_problem_text: user-optimized question text
        :param kwargs: extra arguments
        :return: messages for the model
        """
        pass
|
||||||
@ -0,0 +1,57 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: base_generate_human_message_step.py.py
|
||||||
|
@date:2024/1/10 17:50
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.schema import BaseMessage, HumanMessage
|
||||||
|
|
||||||
|
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
|
||||||
|
IGenerateHumanMessageStep
|
||||||
|
from application.models import ChatRecord
|
||||||
|
from common.util.split_model import flat_map
|
||||||
|
from dataset.models import Paragraph
|
||||||
|
|
||||||
|
|
||||||
|
class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
    """Default implementation: recent history rounds plus one templated
    HumanMessage carrying the paragraph data and the question."""

    def execute(self, problem_text: str,
                paragraph_list: List[Paragraph],
                history_chat_record: List[ChatRecord],
                dialogue_number: int,
                max_paragraph_char_number: int,
                prompt: str,
                padding_problem_text: str = None,
                **kwargs) -> List[BaseMessage]:
        # Prefer the optimized question when the reset-problem step produced one.
        question = problem_text if padding_problem_text is None else padding_problem_text
        first = max(len(history_chat_record) - dialogue_number, 0)
        history_pairs = [[record.get_human_message(), record.get_ai_message()]
                         for record in history_chat_record[first:]]
        return [*flat_map(history_pairs),
                self.to_human_message(prompt, question, max_paragraph_char_number, paragraph_list)]

    @staticmethod
    def to_human_message(prompt: str,
                         problem: str,
                         max_paragraph_char_number: int,
                         paragraph_list: List[Paragraph]):
        """Render the prompt template with paragraph data capped at the limit.

        Each paragraph is wrapped in <data></data>; the paragraph that
        crosses the character budget is truncated and iteration stops.
        """
        if not paragraph_list:
            # No retrieved context: send the bare question.
            return HumanMessage(content=problem)
        data_list = []
        used_chars = 0
        for paragraph in paragraph_list:
            content = f"{paragraph.title}:{paragraph.content}"
            used_chars += len(content)
            if used_chars > max_paragraph_char_number:
                # Negative slice keeps only the part of this paragraph that
                # still fits within the budget, then stop.
                kept = content[0:max_paragraph_char_number - used_chars]
                data_list.append(f"<data>{kept}</data>")
                break
            data_list.append(f"<data>{content}</data>")
        data = "\n".join(data_list)
        return HumanMessage(content=prompt.format(**{'data': data, 'question': problem}))
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: __init__.py.py
|
||||||
|
@date:2024/1/9 18:23
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
@ -0,0 +1,49 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: i_reset_problem_step.py
|
||||||
|
@date:2024/1/9 18:12
|
||||||
|
@desc: 重写处理问题
|
||||||
|
"""
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Type, List
|
||||||
|
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||||
|
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||||
|
from application.chat_pipeline.step.chat_step.i_chat_step import ModelField
|
||||||
|
from application.models import ChatRecord
|
||||||
|
from common.field.common import InstanceField
|
||||||
|
|
||||||
|
|
||||||
|
class IResetProblemStep(IBaseChatPipelineStep):
    """Step interface that rewrites/completes the user question via an LLM."""

    class InstanceSerializer(serializers.Serializer):
        # User question text
        problem_text = serializers.CharField(required=True)
        # Previous question/answer records
        history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
        # LLM used for the rewrite
        chat_model = ModelField()

    def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
        """Return the serializer validating this step's arguments."""
        return self.InstanceSerializer

    def _run(self, manage: PiplineManage):
        optimized_question = self.execute(**self.context.get('step_args'))
        # Original question as entered by the user.
        original_question = self.context.get('step_args').get('problem_text')
        self.context['problem_text'] = original_question
        self.context['padding_problem_text'] = optimized_question
        manage.context['problem_text'] = original_question
        manage.context['padding_problem_text'] = optimized_question
        # Accumulate this step's token usage onto the pipeline totals.
        manage.context['message_tokens'] += self.context.get('message_tokens')
        manage.context['answer_tokens'] += self.context.get('answer_tokens')

    @abstractmethod
    def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
                **kwargs):
        """Produce the optimized/completed question text."""
        pass
|
||||||
@ -0,0 +1,47 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: base_reset_problem_step.py
|
||||||
|
@date:2024/1/10 14:35
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from langchain.chat_models.base import BaseChatModel
|
||||||
|
from langchain.schema import HumanMessage
|
||||||
|
|
||||||
|
from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
|
||||||
|
from application.models import ChatRecord
|
||||||
|
from common.util.split_model import flat_map
|
||||||
|
|
||||||
|
# Prompt template for question optimization: the user question is placed in
# the parentheses and the model is asked to output a completed question
# wrapped in <data></data> tags.
prompt = (
    '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中')
|
||||||
|
|
||||||
|
|
||||||
|
class BaseResetProblemStep(IResetProblemStep):
    """Default question-optimization step: asks the LLM to complete the
    question using up to the last 3 rounds of chat history."""

    def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
                **kwargs) -> str:
        """Rewrite the question and record token usage in the step context.

        :param problem_text: original user question
        :param history_chat_record: previous chat records (may be None)
        :param chat_model: LLM used for the rewrite
        :return: the completed question text
        """
        # Guard the declared default: treat None history as empty.
        if history_chat_record is None:
            history_chat_record = []
        # Carry at most the last 3 rounds of history.
        start_index = len(history_chat_record) - 3
        history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
                           for index in
                           range(start_index if start_index > 0 else 0, len(history_chat_record))]
        message_list = [*flat_map(history_message),
                        HumanMessage(content=prompt.format(**{'question': problem_text}))]
        response = chat_model(message_list)
        response_text = response.content
        # The prompt asks the model to wrap the completed question in
        # <data></data>. Models do not always comply; previously a missing
        # tag raised ValueError from str.index(). Fall back to the raw
        # response text when the tags are absent or malformed.
        start = response_text.find('<data>')
        end = response_text.find('</data>')
        if start != -1 and end > start:
            padding_problem = response_text[start + len('<data>'):end]
        else:
            padding_problem = response_text
        self.context['message_tokens'] = chat_model.get_num_tokens_from_messages(message_list)
        self.context['answer_tokens'] = chat_model.get_num_tokens(padding_problem)
        return padding_problem

    def get_details(self, manage, **kwargs):
        """Report run details of the problem-padding step."""
        return {
            'step_type': 'problem_padding',
            'run_time': self.context['run_time'],
            'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None,
            'message_tokens': self.context['message_tokens'],
            'answer_tokens': self.context['answer_tokens'],
            'cost': 0,
            'padding_problem_text': self.context.get('padding_problem_text'),
            'problem_text': self.context.get("step_args").get('problem_text'),
        }
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: __init__.py.py
|
||||||
|
@date:2024/1/9 18:24
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
@ -0,0 +1,58 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: i_search_dataset_step.py
|
||||||
|
@date:2024/1/9 18:10
|
||||||
|
@desc: 检索知识库
|
||||||
|
"""
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import List, Type
|
||||||
|
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
|
||||||
|
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||||
|
from dataset.models import Paragraph
|
||||||
|
|
||||||
|
|
||||||
|
class ISearchDatasetStep(IBaseChatPipelineStep):
    """Knowledge-base retrieval step interface."""

    class InstanceSerializer(serializers.Serializer):
        # Original question text
        problem_text = serializers.CharField(required=True)
        # System-optimized question text (optional)
        padding_problem_text = serializers.CharField(required=False)
        # Dataset ids to search
        dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
        # Document ids to exclude
        exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
        # Vector/paragraph ids to exclude
        exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
        # Number of rows to fetch
        top_n = serializers.IntegerField(required=True)
        # Similarity threshold in [0, 1]
        similarity = serializers.FloatField(required=True, max_value=1, min_value=0)

    def get_step_serializer(self, manage: PiplineManage) -> Type[InstanceSerializer]:
        """Return the serializer validating this step's arguments."""
        return self.InstanceSerializer

    def _run(self, manage: PiplineManage):
        # Publish the retrieved paragraphs for downstream steps.
        manage.context['paragraph_list'] = self.execute(**self.context['step_args'])

    @abstractmethod
    def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
                **kwargs) -> List[Paragraph]:
        """Search the datasets for relevant paragraphs.

        On questions: when an optimized question is present it is used for
        the query; otherwise the original user question is used.

        :param similarity: similarity threshold
        :param top_n: number of rows to fetch
        :param problem_text: user question
        :param dataset_id_list: dataset ids to search
        :param exclude_document_id_list: document ids to exclude
        :param exclude_paragraph_id_list: paragraph ids to exclude
        :param padding_problem_text: optimized question
        :return: list of paragraphs
        """
        pass
|
||||||
@ -0,0 +1,58 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: base_search_dataset_step.py
|
||||||
|
@date:2024/1/10 10:33
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from django.db.models import QuerySet
|
||||||
|
|
||||||
|
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
|
||||||
|
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||||
|
from dataset.models import Paragraph
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSearchDatasetStep(ISearchDatasetStep):
    """Default dataset search: embed the question and query the vector store."""

    def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
                **kwargs) -> List[Paragraph]:
        """Embed the (possibly optimized) question and fetch matching paragraphs.

        :return: paragraphs matching the query, [] when nothing is found
        """
        # Search with the optimized question when one is available.
        exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
        embedding_model = EmbeddingModel.get_embedding_model()
        embedding_value = embedding_model.embed_query(exec_problem_text)
        vector = VectorStore.get_embedding_vector()
        embedding_list = vector.query(embedding_value, dataset_id_list, exclude_document_id_list,
                                      exclude_paragraph_id_list, True, top_n, similarity)
        if embedding_list is None:
            return []
        return self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)

    @staticmethod
    def list_paragraph(paragraph_id_list: List, vector):
        """Load paragraphs by id; purge vector rows whose paragraph is gone.

        :param paragraph_id_list: paragraph ids returned by the vector store
        :param vector: vector store handle, used to delete stale rows
        :return: QuerySet of existing paragraphs
        """
        if not paragraph_id_list:
            return []
        paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list)
        # Dirty data in the vector store (ids with no backing paragraph)
        # is deleted on sight.
        if len(paragraph_list) != len(paragraph_id_list):
            # Set membership instead of list.__contains__ (idiomatic, O(1));
            # normalize to str so UUID vs str ids compare consistently —
            # a type mismatch here previously made every id look missing.
            exist_id_set = {str(row.id) for row in paragraph_list}
            for paragraph_id in paragraph_id_list:
                if str(paragraph_id) not in exist_id_set:
                    vector.delete_by_paragraph_id(paragraph_id)
        return paragraph_list

    def get_details(self, manage, **kwargs):
        """Report run details of the search step."""
        step_args = self.context['step_args']

        return {
            'step_type': 'search_step',
            'run_time': self.context['run_time'],
            'problem_text': step_args.get(
                'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
            'model_name': EmbeddingModel.get_embedding_model().model_name,
            'message_tokens': 0,
            'answer_tokens': 0,
            'cost': 0
        }
|
||||||
@ -0,0 +1,55 @@
|
|||||||
|
# Generated by Django 4.1.10 on 2024-01-12 18:46
|
||||||
|
|
||||||
|
import django.contrib.postgres.fields
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Rework ChatRecord: drop the dataset/paragraph/source linkage and add
    cost, details, referenced-paragraph ids and run-time bookkeeping.

    NOTE(review): the new field is named 'const' but its verbose_name is
    '总费用' (total cost) — likely a typo for 'cost'; renaming would require
    a follow-up migration, so it is left as generated.
    """

    dependencies = [
        ('application', '0002_alter_chatrecord_dataset'),
    ]

    operations = [
        migrations.RemoveField(
            model_name='chatrecord',
            name='dataset',
        ),
        migrations.RemoveField(
            model_name='chatrecord',
            name='paragraph',
        ),
        migrations.RemoveField(
            model_name='chatrecord',
            name='source_id',
        ),
        migrations.RemoveField(
            model_name='chatrecord',
            name='source_type',
        ),
        migrations.AddField(
            model_name='chatrecord',
            name='const',
            field=models.IntegerField(default=0, verbose_name='总费用'),
        ),
        migrations.AddField(
            model_name='chatrecord',
            name='details',
            field=models.JSONField(default=list, verbose_name='对话详情'),
        ),
        migrations.AddField(
            model_name='chatrecord',
            name='paragraph_id_list',
            field=django.contrib.postgres.fields.ArrayField(base_field=models.UUIDField(blank=True), default=list, size=None, verbose_name='引用段落id列表'),
        ),
        migrations.AddField(
            model_name='chatrecord',
            name='run_time',
            field=models.FloatField(default=0, verbose_name='运行时长'),
        ),
        migrations.AlterField(
            model_name='chatrecord',
            name='answer_text',
            field=models.CharField(max_length=4096, verbose_name='答案'),
        ),
    ]
|
||||||
@ -0,0 +1,38 @@
|
|||||||
|
# Generated by Django 4.1.10 on 2024-01-15 16:07
|
||||||
|
|
||||||
|
import application.models.application
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
    """Drop Application.example; add dataset/model setting JSON fields and
    the problem-optimization flag; alter ChatRecord.details default.

    NOTE(review): `default={}` on the JSONField is a shared mutable default
    (Django warns about this); migration 0005 corrects it to `default=dict`.
    Left as generated since applied migration history must not change.
    """

    dependencies = [
        ('application', '0003_remove_chatrecord_dataset_and_more'),
    ]

    operations = [
        migrations.RemoveField(
            model_name='application',
            name='example',
        ),
        migrations.AddField(
            model_name='application',
            name='dataset_setting',
            field=models.JSONField(default=application.models.application.get_dataset_setting_dict, verbose_name='数据集参数设置'),
        ),
        migrations.AddField(
            model_name='application',
            name='model_setting',
            field=models.JSONField(default=application.models.application.get_model_setting_dict, verbose_name='模型参数相关设置'),
        ),
        migrations.AddField(
            model_name='application',
            name='problem_optimization',
            field=models.BooleanField(default=False, verbose_name='问题优化'),
        ),
        migrations.AlterField(
            model_name='chatrecord',
            name='details',
            field=models.JSONField(default={}, verbose_name='对话详情'),
        ),
    ]
|
||||||
18
apps/application/migrations/0005_alter_chatrecord_details.py
Normal file
18
apps/application/migrations/0005_alter_chatrecord_details.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# Generated by Django 4.1.10 on 2024-01-16 11:22
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('application', '0004_remove_application_example_and_more'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name='chatrecord',
|
||||||
|
name='details',
|
||||||
|
field=models.JSONField(default=dict, verbose_name='对话详情'),
|
||||||
|
),
|
||||||
|
]
|
||||||
@ -10,23 +10,47 @@ import uuid
|
|||||||
|
|
||||||
from django.contrib.postgres.fields import ArrayField
|
from django.contrib.postgres.fields import ArrayField
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from langchain.schema import HumanMessage, AIMessage
|
||||||
|
|
||||||
from common.mixins.app_model_mixin import AppModelMixin
|
from common.mixins.app_model_mixin import AppModelMixin
|
||||||
from dataset.models.data_set import DataSet, Paragraph
|
from dataset.models.data_set import DataSet
|
||||||
from embedding.models import SourceType
|
|
||||||
from setting.models.model_management import Model
|
from setting.models.model_management import Model
|
||||||
from users.models import User
|
from users.models import User
|
||||||
|
|
||||||
|
|
||||||
|
def get_dataset_setting_dict():
|
||||||
|
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000}
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_setting_dict():
|
||||||
|
return {'prompt': Application.get_default_model_prompt()}
|
||||||
|
|
||||||
|
|
||||||
class Application(AppModelMixin):
|
class Application(AppModelMixin):
|
||||||
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
||||||
name = models.CharField(max_length=128, verbose_name="应用名称")
|
name = models.CharField(max_length=128, verbose_name="应用名称")
|
||||||
desc = models.CharField(max_length=128, verbose_name="引用描述", default="")
|
desc = models.CharField(max_length=128, verbose_name="引用描述", default="")
|
||||||
prologue = models.CharField(max_length=1024, verbose_name="开场白", default="")
|
prologue = models.CharField(max_length=1024, verbose_name="开场白", default="")
|
||||||
example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True), default=list)
|
|
||||||
dialogue_number = models.IntegerField(default=0, verbose_name="会话数量")
|
dialogue_number = models.IntegerField(default=0, verbose_name="会话数量")
|
||||||
user = models.ForeignKey(User, on_delete=models.DO_NOTHING)
|
user = models.ForeignKey(User, on_delete=models.DO_NOTHING)
|
||||||
model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
|
model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
|
||||||
|
dataset_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict)
|
||||||
|
model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
|
||||||
|
problem_optimization = models.BooleanField(verbose_name="问题优化", default=False)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_model_prompt():
|
||||||
|
return ('已知信息:'
|
||||||
|
'\n{data}'
|
||||||
|
'\n回答要求:'
|
||||||
|
'\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。'
|
||||||
|
'\n- 避免提及你是从<data></data>中获得的知识。'
|
||||||
|
'\n- 请保持答案与<data></data>中描述的一致。'
|
||||||
|
'\n- 请使用markdown 语法优化答案的格式。'
|
||||||
|
'\n- <data></data>中的图片链接、链接地址和脚本语言请完整返回。'
|
||||||
|
'\n- 请使用与问题相同的语言来回答。'
|
||||||
|
'\n问题:'
|
||||||
|
'\n{question}')
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = "application"
|
db_table = "application"
|
||||||
@ -65,20 +89,28 @@ class ChatRecord(AppModelMixin):
|
|||||||
chat = models.ForeignKey(Chat, on_delete=models.CASCADE)
|
chat = models.ForeignKey(Chat, on_delete=models.CASCADE)
|
||||||
vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
|
vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
|
||||||
default=VoteChoices.UN_VOTE)
|
default=VoteChoices.UN_VOTE)
|
||||||
dataset = models.ForeignKey(DataSet, on_delete=models.SET_NULL, verbose_name="数据集", blank=True, null=True)
|
paragraph_id_list = ArrayField(verbose_name="引用段落id列表",
|
||||||
paragraph = models.ForeignKey(Paragraph, on_delete=models.SET_NULL, verbose_name="段落id", blank=True, null=True)
|
base_field=models.UUIDField(max_length=128, blank=True)
|
||||||
source_id = models.UUIDField(max_length=128, verbose_name="资源id 段落/问题 id ", null=True)
|
, default=list)
|
||||||
source_type = models.CharField(verbose_name='资源类型', max_length=2, choices=SourceType.choices,
|
problem_text = models.CharField(max_length=1024, verbose_name="问题")
|
||||||
default=SourceType.PROBLEM, blank=True, null=True)
|
answer_text = models.CharField(max_length=4096, verbose_name="答案")
|
||||||
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
|
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
|
||||||
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
|
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
|
||||||
problem_text = models.CharField(max_length=1024, verbose_name="问题")
|
const = models.IntegerField(verbose_name="总费用", default=0)
|
||||||
answer_text = models.CharField(max_length=1024, verbose_name="答案")
|
details = models.JSONField(verbose_name="对话详情", default=dict)
|
||||||
improve_paragraph_id_list = ArrayField(verbose_name="改进标注列表",
|
improve_paragraph_id_list = ArrayField(verbose_name="改进标注列表",
|
||||||
base_field=models.UUIDField(max_length=128, blank=True)
|
base_field=models.UUIDField(max_length=128, blank=True)
|
||||||
, default=list)
|
, default=list)
|
||||||
|
run_time = models.FloatField(verbose_name="运行时长", default=0)
|
||||||
index = models.IntegerField(verbose_name="对话下标")
|
index = models.IntegerField(verbose_name="对话下标")
|
||||||
|
|
||||||
|
def get_human_message(self):
|
||||||
|
if 'problem_padding' in self.details:
|
||||||
|
return HumanMessage(content=self.details.get('problem_padding').get('padding_problem_text'))
|
||||||
|
return HumanMessage(content=self.problem_text)
|
||||||
|
|
||||||
|
def get_ai_message(self):
|
||||||
|
return AIMessage(content=self.answer_text)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = "application_chat_record"
|
db_table = "application_chat_record"
|
||||||
|
|||||||
@ -63,15 +63,35 @@ class ApplicationSerializerModel(serializers.ModelSerializer):
|
|||||||
fields = "__all__"
|
fields = "__all__"
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetSettingSerializer(serializers.Serializer):
|
||||||
|
top_n = serializers.FloatField(required=True)
|
||||||
|
similarity = serializers.FloatField(required=True, max_value=1, min_value=0)
|
||||||
|
max_paragraph_char_number = serializers.IntegerField(required=True, max_value=10000)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSettingSerializer(serializers.Serializer):
|
||||||
|
prompt = serializers.CharField(required=True, max_length=4096)
|
||||||
|
|
||||||
|
|
||||||
class ApplicationSerializer(serializers.Serializer):
|
class ApplicationSerializer(serializers.Serializer):
|
||||||
name = serializers.CharField(required=True)
|
name = serializers.CharField(required=True)
|
||||||
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
||||||
model_id = serializers.CharField(required=True)
|
model_id = serializers.CharField(required=True)
|
||||||
multiple_rounds_dialogue = serializers.BooleanField(required=True)
|
multiple_rounds_dialogue = serializers.BooleanField(required=True)
|
||||||
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
||||||
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True), allow_null=True)
|
|
||||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
|
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
|
||||||
allow_null=True)
|
allow_null=True)
|
||||||
|
# 数据集相关设置
|
||||||
|
dataset_setting = DatasetSettingSerializer(required=True)
|
||||||
|
# 模型相关设置
|
||||||
|
model_setting = ModelSettingSerializer(required=True)
|
||||||
|
# 问题补全
|
||||||
|
problem_optimization = serializers.BooleanField(required=True)
|
||||||
|
|
||||||
|
def is_valid(self, *, user_id=None, raise_exception=False):
|
||||||
|
super().is_valid(raise_exception=True)
|
||||||
|
ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
|
||||||
|
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
|
||||||
|
|
||||||
class AccessTokenSerializer(serializers.Serializer):
|
class AccessTokenSerializer(serializers.Serializer):
|
||||||
application_id = serializers.UUIDField(required=True)
|
application_id = serializers.UUIDField(required=True)
|
||||||
@ -135,13 +155,13 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
model_id = serializers.CharField(required=False)
|
model_id = serializers.CharField(required=False)
|
||||||
multiple_rounds_dialogue = serializers.BooleanField(required=False)
|
multiple_rounds_dialogue = serializers.BooleanField(required=False)
|
||||||
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
|
||||||
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True))
|
|
||||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
|
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
|
||||||
|
# 数据集相关设置
|
||||||
def is_valid(self, *, user_id=None, raise_exception=False):
|
dataset_setting = serializers.JSONField(required=False, allow_null=True)
|
||||||
super().is_valid(raise_exception=True)
|
# 模型相关设置
|
||||||
ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
|
model_setting = serializers.JSONField(required=False, allow_null=True)
|
||||||
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
|
# 问题补全
|
||||||
|
problem_optimization = serializers.BooleanField(required=False, allow_null=True)
|
||||||
|
|
||||||
class Create(serializers.Serializer):
|
class Create(serializers.Serializer):
|
||||||
user_id = serializers.UUIDField(required=True)
|
user_id = serializers.UUIDField(required=True)
|
||||||
@ -168,9 +188,12 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def to_application_model(user_id: str, application: Dict):
|
def to_application_model(user_id: str, application: Dict):
|
||||||
return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'),
|
return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'),
|
||||||
prologue=application.get('prologue'), example=application.get('example'),
|
prologue=application.get('prologue'),
|
||||||
dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0,
|
dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0,
|
||||||
user_id=user_id, model_id=application.get('model_id'),
|
user_id=user_id, model_id=application.get('model_id'),
|
||||||
|
dataset_setting=application.get('dataset_setting'),
|
||||||
|
model_setting=application.get('model_setting'),
|
||||||
|
problem_optimization=application.get('problem_optimization')
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -267,7 +290,7 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
class ApplicationModel(serializers.ModelSerializer):
|
class ApplicationModel(serializers.ModelSerializer):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Application
|
model = Application
|
||||||
fields = ['id', 'name', 'desc', 'prologue', 'example', 'dialogue_number']
|
fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number']
|
||||||
|
|
||||||
class Operate(serializers.Serializer):
|
class Operate(serializers.Serializer):
|
||||||
application_id = serializers.UUIDField(required=True)
|
application_id = serializers.UUIDField(required=True)
|
||||||
@ -317,8 +340,9 @@ class ApplicationSerializer(serializers.Serializer):
|
|||||||
|
|
||||||
model = QuerySet(Model).get(id=instance.get('model_id') if 'model_id' in instance else application.model_id)
|
model = QuerySet(Model).get(id=instance.get('model_id') if 'model_id' in instance else application.model_id)
|
||||||
|
|
||||||
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'example', 'status',
|
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
|
||||||
'api_key_is_active']
|
'dataset_setting', 'model_setting', 'problem_optimization'
|
||||||
|
'api_key_is_active']
|
||||||
for update_key in update_keys:
|
for update_key in update_keys:
|
||||||
if update_key in instance and instance.get(update_key) is not None:
|
if update_key in instance and instance.get(update_key) is not None:
|
||||||
if update_key == 'multiple_rounds_dialogue':
|
if update_key == 'multiple_rounds_dialogue':
|
||||||
|
|||||||
@ -7,194 +7,181 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
import json
|
import json
|
||||||
import uuid
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from django.db.models import QuerySet
|
|
||||||
from django.http import StreamingHttpResponse
|
|
||||||
from langchain.chat_models.base import BaseChatModel
|
|
||||||
from langchain.schema import HumanMessage
|
|
||||||
from rest_framework import serializers, status
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from common import event
|
from django.db.models import QuerySet
|
||||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
from langchain.chat_models.base import BaseChatModel
|
||||||
from common.response import result
|
from rest_framework import serializers
|
||||||
from dataset.models import Paragraph
|
|
||||||
from embedding.models import SourceType
|
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||||
from setting.models.model_management import Model
|
from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
|
||||||
|
from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep
|
||||||
|
from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
|
||||||
|
BaseGenerateHumanMessageStep
|
||||||
|
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
|
||||||
|
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
|
||||||
|
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
|
||||||
|
from common.exception.app_exception import AppApiException
|
||||||
|
from common.util.rsa_util import decrypt
|
||||||
|
from common.util.split_model import flat_map
|
||||||
|
from dataset.models import Paragraph, Document
|
||||||
|
from setting.models import Model
|
||||||
|
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
|
|
||||||
chat_cache = cache
|
chat_cache = cache
|
||||||
|
|
||||||
|
|
||||||
class MessageManagement:
|
|
||||||
@staticmethod
|
|
||||||
def get_message(title: str, content: str, message: str):
|
|
||||||
if content is None:
|
|
||||||
return HumanMessage(content=message)
|
|
||||||
return HumanMessage(content=(
|
|
||||||
f'已知信息:{title}:{content} '
|
|
||||||
'根据上述已知信息,请简洁和专业的来回答用户的问题。已知信息中的图片、链接地址和脚本语言请直接返回。如果无法从已知信息中得到答案,请说 “没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作” 或 “根据已知信息无法回答该问题,建议联系官方技术支持人员”,不允许在答案中添加编造成分,答案请使用中文。'
|
|
||||||
f'问题是:{message}'))
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage:
|
|
||||||
def __init__(self, id: str, problem: str, title: str, paragraph: str, embedding_id: str, dataset_id: str,
|
|
||||||
document_id: str,
|
|
||||||
paragraph_id,
|
|
||||||
source_type: SourceType,
|
|
||||||
source_id: str,
|
|
||||||
answer: str,
|
|
||||||
message_tokens: int,
|
|
||||||
answer_token: int,
|
|
||||||
chat_model=None,
|
|
||||||
chat_message=None):
|
|
||||||
self.id = id
|
|
||||||
self.problem = problem
|
|
||||||
self.title = title
|
|
||||||
self.paragraph = paragraph
|
|
||||||
self.embedding_id = embedding_id
|
|
||||||
self.dataset_id = dataset_id
|
|
||||||
self.document_id = document_id
|
|
||||||
self.paragraph_id = paragraph_id
|
|
||||||
self.source_type = source_type
|
|
||||||
self.source_id = source_id
|
|
||||||
self.answer = answer
|
|
||||||
self.message_tokens = message_tokens
|
|
||||||
self.answer_token = answer_token
|
|
||||||
self.chat_model = chat_model
|
|
||||||
self.chat_message = chat_message
|
|
||||||
|
|
||||||
def get_chat_message(self):
|
|
||||||
return MessageManagement.get_message(self.problem, self.paragraph, self.problem)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatInfo:
|
class ChatInfo:
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
model: Model,
|
|
||||||
chat_model: BaseChatModel,
|
chat_model: BaseChatModel,
|
||||||
application_id: str | None,
|
|
||||||
dataset_id_list: List[str],
|
dataset_id_list: List[str],
|
||||||
exclude_document_id_list: list[str],
|
exclude_document_id_list: list[str],
|
||||||
dialogue_number: int):
|
application: Application):
|
||||||
|
"""
|
||||||
|
:param chat_id: 对话id
|
||||||
|
:param chat_model: 对话模型
|
||||||
|
:param dataset_id_list: 数据集列表
|
||||||
|
:param exclude_document_id_list: 排除的文档
|
||||||
|
:param application: 应用信息
|
||||||
|
"""
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.application_id = application_id
|
self.application = application
|
||||||
self.model = model
|
|
||||||
self.chat_model = chat_model
|
self.chat_model = chat_model
|
||||||
self.dataset_id_list = dataset_id_list
|
self.dataset_id_list = dataset_id_list
|
||||||
self.exclude_document_id_list = exclude_document_id_list
|
self.exclude_document_id_list = exclude_document_id_list
|
||||||
self.dialogue_number = dialogue_number
|
self.chat_record_list: List[ChatRecord] = []
|
||||||
self.chat_message_list: List[ChatMessage] = []
|
|
||||||
|
|
||||||
def append_chat_message(self, chat_message: ChatMessage):
|
def to_base_pipeline_manage_params(self):
|
||||||
self.chat_message_list.append(chat_message)
|
dataset_setting = self.application.dataset_setting
|
||||||
if self.application_id is not None:
|
model_setting = self.application.model_setting
|
||||||
|
return {
|
||||||
|
'dataset_id_list': self.dataset_id_list,
|
||||||
|
'exclude_document_id_list': self.exclude_document_id_list,
|
||||||
|
'exclude_paragraph_id_list': [],
|
||||||
|
'top_n': dataset_setting.get('top_n') if 'top_n' in dataset_setting else 3,
|
||||||
|
'similarity': dataset_setting.get('similarity') if 'similarity' in dataset_setting else 0.6,
|
||||||
|
'max_paragraph_char_number': dataset_setting.get(
|
||||||
|
'max_paragraph_char_number') if 'max_paragraph_char_number' in dataset_setting else 5000,
|
||||||
|
'history_chat_record': self.chat_record_list,
|
||||||
|
'chat_id': self.chat_id,
|
||||||
|
'dialogue_number': self.application.dialogue_number,
|
||||||
|
'prompt': model_setting.get(
|
||||||
|
'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
|
||||||
|
'chat_model': self.chat_model,
|
||||||
|
'model_id': self.application.model.id if self.application.model is not None else None,
|
||||||
|
'problem_optimization': self.application.problem_optimization
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
|
||||||
|
exclude_paragraph_id_list):
|
||||||
|
params = self.to_base_pipeline_manage_params()
|
||||||
|
return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
|
||||||
|
'exclude_paragraph_id_list': exclude_paragraph_id_list}
|
||||||
|
|
||||||
|
def append_chat_record(self, chat_record: ChatRecord):
|
||||||
|
# 存入缓存中
|
||||||
|
self.chat_record_list.append(chat_record)
|
||||||
|
if self.application.id is not None:
|
||||||
# 插入数据库
|
# 插入数据库
|
||||||
event.ListenerChatMessage.record_chat_message_signal.send(
|
if not QuerySet(Chat).filter(id=self.chat_id).exists():
|
||||||
event.RecordChatMessageArgs(len(self.chat_message_list) - 1, self.chat_id, self.application_id,
|
Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text).save()
|
||||||
chat_message)
|
# 插入会话记录
|
||||||
)
|
chat_record.save()
|
||||||
# 异步更新token
|
|
||||||
event.ListenerChatMessage.update_chat_message_token_signal.send(chat_message)
|
|
||||||
|
|
||||||
def get_context_message(self):
|
|
||||||
start_index = len(self.chat_message_list) - self.dialogue_number
|
def get_post_handler(chat_info: ChatInfo):
|
||||||
return [self.chat_message_list[index].get_chat_message() for index in
|
class PostHandler(PostResponseHandler):
|
||||||
range(start_index if start_index > 0 else 0, len(self.chat_message_list))]
|
|
||||||
|
def handler(self,
|
||||||
|
chat_id: UUID,
|
||||||
|
chat_record_id,
|
||||||
|
paragraph_list: List[Paragraph],
|
||||||
|
problem_text: str,
|
||||||
|
answer_text,
|
||||||
|
manage: PiplineManage,
|
||||||
|
step: BaseChatStep,
|
||||||
|
padding_problem_text: str = None,
|
||||||
|
**kwargs):
|
||||||
|
chat_record = ChatRecord(id=chat_record_id,
|
||||||
|
chat_id=chat_id,
|
||||||
|
paragraph_id_list=[str(p.id) for p in paragraph_list],
|
||||||
|
problem_text=problem_text,
|
||||||
|
answer_text=answer_text,
|
||||||
|
details=manage.get_details(),
|
||||||
|
message_tokens=manage.context['message_tokens'],
|
||||||
|
answer_tokens=manage.context['answer_tokens'],
|
||||||
|
run_time=manage.context['run_time'],
|
||||||
|
index=len(chat_info.chat_record_list) + 1)
|
||||||
|
chat_info.append_chat_record(chat_record)
|
||||||
|
# 重新设置缓存
|
||||||
|
chat_cache.set(chat_id,
|
||||||
|
chat_info, timeout=60 * 30)
|
||||||
|
|
||||||
|
return PostHandler()
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageSerializer(serializers.Serializer):
|
class ChatMessageSerializer(serializers.Serializer):
|
||||||
chat_id = serializers.UUIDField(required=True)
|
chat_id = serializers.UUIDField(required=True)
|
||||||
|
|
||||||
def chat(self, message):
|
def chat(self, message, re_chat: bool):
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
chat_id = self.data.get('chat_id')
|
chat_id = self.data.get('chat_id')
|
||||||
chat_info: ChatInfo = chat_cache.get(chat_id)
|
chat_info: ChatInfo = chat_cache.get(chat_id)
|
||||||
if chat_info is None:
|
if chat_info is None:
|
||||||
return result.Result(response_status=status.HTTP_404_NOT_FOUND, code=404, message="会话过期")
|
chat_info = self.re_open_chat(chat_id)
|
||||||
|
chat_cache.set(chat_id,
|
||||||
|
chat_info, timeout=60 * 30)
|
||||||
|
|
||||||
chat_model = chat_info.chat_model
|
pipline_manage_builder = PiplineManage.builder()
|
||||||
vector = VectorStore.get_embedding_vector()
|
# 如果开启了问题优化,则添加上问题优化步骤
|
||||||
# 向量库检索
|
if chat_info.application.problem_optimization:
|
||||||
_value = vector.search(message, chat_info.dataset_id_list, chat_info.exclude_document_id_list,
|
pipline_manage_builder.append_step(BaseResetProblemStep)
|
||||||
[chat_message.embedding_id for chat_message in
|
# 构建流水线管理器
|
||||||
(list(filter(lambda row: row.problem == message, chat_info.chat_message_list)))],
|
pipline_message = (pipline_manage_builder.append_step(BaseSearchDatasetStep)
|
||||||
True,
|
.append_step(BaseGenerateHumanMessageStep)
|
||||||
EmbeddingModel.get_embedding_model())
|
.append_step(BaseChatStep)
|
||||||
# 查询段落id详情
|
.build())
|
||||||
paragraph = None
|
exclude_paragraph_id_list = []
|
||||||
if _value is not None:
|
# 相同问题是否需要排除已经查询到的段落
|
||||||
paragraph = QuerySet(Paragraph).get(id=_value.get('paragraph_id'))
|
if re_chat:
|
||||||
if paragraph is None:
|
paragraph_id_list = flat_map([row.paragraph_id_list for row in
|
||||||
vector.delete_by_paragraph_id(_value.get('paragraph_id'))
|
filter(lambda chat_record: chat_record == message,
|
||||||
|
chat_info.chat_record_list)])
|
||||||
|
exclude_paragraph_id_list = list(set(paragraph_id_list))
|
||||||
|
# 构建运行参数
|
||||||
|
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list)
|
||||||
|
# 运行流水线作业
|
||||||
|
pipline_message.run(params)
|
||||||
|
return pipline_message.context['chat_result']
|
||||||
|
|
||||||
title, content = (None, None) if paragraph is None else (paragraph.title, paragraph.content)
|
@staticmethod
|
||||||
_id = str(uuid.uuid1())
|
def re_open_chat(chat_id: str):
|
||||||
|
chat = QuerySet(Chat).filter(id=chat_id).first()
|
||||||
|
if chat is None:
|
||||||
|
raise AppApiException(500, "会话不存在")
|
||||||
|
application = QuerySet(Application).filter(id=chat.application_id).first()
|
||||||
|
if application is None:
|
||||||
|
raise AppApiException(500, "应用不存在")
|
||||||
|
model = QuerySet(Model).filter(id=application.model_id).first()
|
||||||
|
chat_model = None
|
||||||
|
if model is not None:
|
||||||
|
# 对话模型
|
||||||
|
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||||
|
json.loads(
|
||||||
|
decrypt(model.credential)),
|
||||||
|
streaming=True)
|
||||||
|
# 数据集id列表
|
||||||
|
dataset_id_list = [str(row.dataset_id) for row in
|
||||||
|
QuerySet(ApplicationDatasetMapping).filter(
|
||||||
|
application_id=application.id)]
|
||||||
|
|
||||||
embedding_id, dataset_id, document_id, paragraph_id, source_type, source_id = (_value.get(
|
# 需要排除的文档
|
||||||
'id'), _value.get(
|
exclude_document_id_list = [str(document.id) for document in
|
||||||
'dataset_id'), _value.get(
|
QuerySet(Document).filter(
|
||||||
'document_id'), _value.get(
|
dataset_id__in=dataset_id_list,
|
||||||
'paragraph_id'), _value.get(
|
is_active=False)]
|
||||||
'source_type'), _value.get(
|
return ChatInfo(chat_id, chat_model, dataset_id_list, exclude_document_id_list, application)
|
||||||
'source_id')) if _value is not None else (None, None, None, None, None, None)
|
|
||||||
|
|
||||||
if chat_model is None:
|
|
||||||
def event_block_content(c: str):
|
|
||||||
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
|
|
||||||
'is_end': True,
|
|
||||||
'content': c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~'}) + "\n\n"
|
|
||||||
chat_info.append_chat_message(
|
|
||||||
ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
|
|
||||||
paragraph_id,
|
|
||||||
source_type,
|
|
||||||
source_id,
|
|
||||||
c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~',
|
|
||||||
0,
|
|
||||||
0))
|
|
||||||
# 重新设置缓存
|
|
||||||
chat_cache.set(chat_id,
|
|
||||||
chat_info, timeout=60 * 30)
|
|
||||||
|
|
||||||
r = StreamingHttpResponse(streaming_content=event_block_content(content),
|
|
||||||
content_type='text/event-stream;charset=utf-8')
|
|
||||||
|
|
||||||
r['Cache-Control'] = 'no-cache'
|
|
||||||
return r
|
|
||||||
# 获取上下文
|
|
||||||
history_message = chat_info.get_context_message()
|
|
||||||
|
|
||||||
# 构建会话请求问题
|
|
||||||
chat_message = [*history_message, MessageManagement.get_message(title, content, message)]
|
|
||||||
# 对话
|
|
||||||
result_data = chat_model.stream(chat_message)
|
|
||||||
|
|
||||||
def event_content(response):
|
|
||||||
all_text = ''
|
|
||||||
try:
|
|
||||||
for chunk in response:
|
|
||||||
all_text += chunk.content
|
|
||||||
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
|
|
||||||
'content': chunk.content, 'is_end': False}) + "\n\n"
|
|
||||||
|
|
||||||
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
|
|
||||||
'content': '', 'is_end': True}) + "\n\n"
|
|
||||||
chat_info.append_chat_message(
|
|
||||||
ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
|
|
||||||
paragraph_id,
|
|
||||||
source_type,
|
|
||||||
source_id, all_text,
|
|
||||||
0,
|
|
||||||
0,
|
|
||||||
chat_message=chat_message, chat_model=chat_model))
|
|
||||||
# 重新设置缓存
|
|
||||||
chat_cache.set(chat_id,
|
|
||||||
chat_info, timeout=60 * 30)
|
|
||||||
except Exception as e:
|
|
||||||
yield e
|
|
||||||
|
|
||||||
r = StreamingHttpResponse(streaming_content=event_content(result_data),
|
|
||||||
content_type='text/event-stream;charset=utf-8')
|
|
||||||
|
|
||||||
r['Cache-Control'] = 'no-cache'
|
|
||||||
return r
|
|
||||||
|
|||||||
@ -10,7 +10,8 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict
|
from functools import reduce
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.db import transaction
|
from django.db import transaction
|
||||||
@ -18,7 +19,8 @@ from django.db.models import QuerySet
|
|||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
|
||||||
from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
|
from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
|
||||||
from application.serializers.application_serializers import ModelDatasetAssociation
|
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
|
||||||
|
ModelSettingSerializer
|
||||||
from application.serializers.chat_message_serializers import ChatInfo
|
from application.serializers.chat_message_serializers import ChatInfo
|
||||||
from common.db.search import native_search, native_page_search, page_search
|
from common.db.search import native_search, native_page_search, page_search
|
||||||
from common.event import ListenerManagement
|
from common.event import ListenerManagement
|
||||||
@ -26,8 +28,8 @@ from common.exception.app_exception import AppApiException
|
|||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from common.util.lock import try_lock, un_lock
|
from common.util.lock import try_lock, un_lock
|
||||||
from common.util.rsa_util import decrypt
|
from common.util.rsa_util import decrypt
|
||||||
|
from common.util.split_model import flat_map
|
||||||
from dataset.models import Document, Problem, Paragraph
|
from dataset.models import Document, Problem, Paragraph
|
||||||
from embedding.models import SourceType, Embedding
|
|
||||||
from setting.models import Model
|
from setting.models import Model
|
||||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||||
from smartdoc.conf import PROJECT_DIR
|
from smartdoc.conf import PROJECT_DIR
|
||||||
@ -106,12 +108,12 @@ class ChatSerializers(serializers.Serializer):
|
|||||||
|
|
||||||
chat_id = str(uuid.uuid1())
|
chat_id = str(uuid.uuid1())
|
||||||
chat_cache.set(chat_id,
|
chat_cache.set(chat_id,
|
||||||
ChatInfo(chat_id, model, chat_model, application_id, dataset_id_list,
|
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||||
[str(document.id) for document in
|
[str(document.id) for document in
|
||||||
QuerySet(Document).filter(
|
QuerySet(Document).filter(
|
||||||
dataset_id__in=dataset_id_list,
|
dataset_id__in=dataset_id_list,
|
||||||
is_active=False)],
|
is_active=False)],
|
||||||
application.dialogue_number), timeout=60 * 30)
|
application), timeout=60 * 30)
|
||||||
return chat_id
|
return chat_id
|
||||||
|
|
||||||
class OpenTempChat(serializers.Serializer):
|
class OpenTempChat(serializers.Serializer):
|
||||||
@ -122,6 +124,12 @@ class ChatSerializers(serializers.Serializer):
|
|||||||
multiple_rounds_dialogue = serializers.BooleanField(required=True)
|
multiple_rounds_dialogue = serializers.BooleanField(required=True)
|
||||||
|
|
||||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
|
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
|
||||||
|
# 数据集相关设置
|
||||||
|
dataset_setting = DatasetSettingSerializer(required=True)
|
||||||
|
# 模型相关设置
|
||||||
|
model_setting = ModelSettingSerializer(required=True)
|
||||||
|
# 问题补全
|
||||||
|
problem_optimization = serializers.BooleanField(required=True)
|
||||||
|
|
||||||
def is_valid(self, *, raise_exception=False):
|
def is_valid(self, *, raise_exception=False):
|
||||||
super().is_valid(raise_exception=True)
|
super().is_valid(raise_exception=True)
|
||||||
@ -140,42 +148,62 @@ class ChatSerializers(serializers.Serializer):
|
|||||||
json.loads(
|
json.loads(
|
||||||
decrypt(model.credential)),
|
decrypt(model.credential)),
|
||||||
streaming=True)
|
streaming=True)
|
||||||
|
application = Application(id=None, dialogue_number=3, model=model,
|
||||||
|
dataset_setting=self.data.get('dataset_setting'),
|
||||||
|
model_setting=self.data.get('model_setting'),
|
||||||
|
problem_optimization=self.data.get('problem_optimization'))
|
||||||
chat_cache.set(chat_id,
|
chat_cache.set(chat_id,
|
||||||
ChatInfo(chat_id, model, chat_model, None, dataset_id_list,
|
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||||
[str(document.id) for document in
|
[str(document.id) for document in
|
||||||
QuerySet(Document).filter(
|
QuerySet(Document).filter(
|
||||||
dataset_id__in=dataset_id_list,
|
dataset_id__in=dataset_id_list,
|
||||||
is_active=False)],
|
is_active=False)],
|
||||||
3 if self.data.get('multiple_rounds_dialogue') else 1), timeout=60 * 30)
|
application), timeout=60 * 30)
|
||||||
return chat_id
|
return chat_id
|
||||||
|
|
||||||
|
|
||||||
def vote_exec(source_type: SourceType, source_id: str, field: str, post_handler):
|
|
||||||
if source_type == SourceType.PROBLEM:
|
|
||||||
problem = QuerySet(Problem).get(id=source_id)
|
|
||||||
if problem is not None:
|
|
||||||
problem.__setattr__(field, post_handler(problem))
|
|
||||||
problem.save()
|
|
||||||
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
|
|
||||||
embedding.__setattr__(field, problem.__getattribute__(field))
|
|
||||||
embedding.save()
|
|
||||||
if source_type == SourceType.PARAGRAPH:
|
|
||||||
paragraph = QuerySet(Paragraph).get(id=source_id)
|
|
||||||
if paragraph is not None:
|
|
||||||
paragraph.__setattr__(field, post_handler(paragraph))
|
|
||||||
paragraph.save()
|
|
||||||
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
|
|
||||||
embedding.__setattr__(field, paragraph.__getattribute__(field))
|
|
||||||
embedding.save()
|
|
||||||
|
|
||||||
|
|
||||||
class ChatRecordSerializerModel(serializers.ModelSerializer):
|
class ChatRecordSerializerModel(serializers.ModelSerializer):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = ChatRecord
|
model = ChatRecord
|
||||||
fields = "__all__"
|
fields = ['id', 'chat_id', 'vote_status', 'problem_text', 'answer_text',
|
||||||
|
'message_tokens', 'answer_tokens', 'const', 'improve_paragraph_id_list', 'run_time', 'index']
|
||||||
|
|
||||||
|
|
||||||
class ChatRecordSerializer(serializers.Serializer):
|
class ChatRecordSerializer(serializers.Serializer):
|
||||||
|
class Operate(serializers.Serializer):
|
||||||
|
chat_id = serializers.UUIDField(required=True)
|
||||||
|
|
||||||
|
chat_record_id = serializers.UUIDField(required=True)
|
||||||
|
|
||||||
|
def one(self, with_valid=True):
|
||||||
|
if with_valid:
|
||||||
|
self.is_valid(raise_exception=True)
|
||||||
|
chat_record_id = self.data.get('chat_record_id')
|
||||||
|
chat_id = self.data.get('chat_id')
|
||||||
|
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
|
||||||
|
dataset_list = []
|
||||||
|
paragraph_list = []
|
||||||
|
if len(chat_record.paragraph_id_list) > 0:
|
||||||
|
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=chat_record.paragraph_id_list),
|
||||||
|
get_file_content(
|
||||||
|
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
|
||||||
|
'list_dataset_paragraph_by_paragraph_id.sql')),
|
||||||
|
with_table_name=True)
|
||||||
|
dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
|
||||||
|
[{row.get(
|
||||||
|
'dataset_id'): row.get(
|
||||||
|
"dataset_name")} for
|
||||||
|
row in
|
||||||
|
paragraph_list],
|
||||||
|
{}).items()]
|
||||||
|
|
||||||
|
return {
|
||||||
|
**ChatRecordSerializerModel(chat_record).data,
|
||||||
|
'padding_problem_text': chat_record.details.get(
|
||||||
|
'padding_problem_text') if 'problem_padding' in chat_record.details else None,
|
||||||
|
'dataset_list': dataset_list,
|
||||||
|
'paragraph_list': paragraph_list}
|
||||||
|
|
||||||
class Query(serializers.Serializer):
|
class Query(serializers.Serializer):
|
||||||
application_id = serializers.UUIDField(required=True)
|
application_id = serializers.UUIDField(required=True)
|
||||||
chat_id = serializers.UUIDField(required=True)
|
chat_id = serializers.UUIDField(required=True)
|
||||||
@ -183,15 +211,57 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||||||
def list(self, with_valid=True):
|
def list(self, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
|
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))
|
||||||
return [ChatRecordSerializerModel(chat_record).data for chat_record in
|
return [ChatRecordSerializerModel(chat_record).data for chat_record in
|
||||||
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))]
|
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))]
|
||||||
|
|
||||||
|
def reset_chat_record_list(self, chat_record_list: List[ChatRecord]):
|
||||||
|
paragraph_id_list = flat_map([chat_record.paragraph_id_list for chat_record in chat_record_list])
|
||||||
|
# 去重
|
||||||
|
paragraph_id_list = list(set(paragraph_id_list))
|
||||||
|
paragraph_list = self.search_paragraph(paragraph_id_list)
|
||||||
|
return [self.reset_chat_record(chat_record, paragraph_list) for chat_record in chat_record_list]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def search_paragraph(paragraph_id_list: List[str]):
|
||||||
|
paragraph_list = []
|
||||||
|
if len(paragraph_id_list) > 0:
|
||||||
|
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
|
||||||
|
get_file_content(
|
||||||
|
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
|
||||||
|
'list_dataset_paragraph_by_paragraph_id.sql')),
|
||||||
|
with_table_name=True)
|
||||||
|
return paragraph_list
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def reset_chat_record(chat_record, all_paragraph_list):
|
||||||
|
paragraph_list = list(
|
||||||
|
filter(lambda paragraph: chat_record.paragraph_id_list.__contains__(str(paragraph.get('id'))),
|
||||||
|
all_paragraph_list))
|
||||||
|
dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
|
||||||
|
[{row.get(
|
||||||
|
'dataset_id'): row.get(
|
||||||
|
"dataset_name")} for
|
||||||
|
row in
|
||||||
|
paragraph_list],
|
||||||
|
{}).items()]
|
||||||
|
return {
|
||||||
|
**ChatRecordSerializerModel(chat_record).data,
|
||||||
|
'padding_problem_text': chat_record.details.get('problem_padding').get(
|
||||||
|
'padding_problem_text') if 'problem_padding' in chat_record.details else None,
|
||||||
|
'dataset_list': dataset_list,
|
||||||
|
'paragraph_list': paragraph_list
|
||||||
|
}
|
||||||
|
|
||||||
def page(self, current_page: int, page_size: int, with_valid=True):
|
def page(self, current_page: int, page_size: int, with_valid=True):
|
||||||
if with_valid:
|
if with_valid:
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
return page_search(current_page, page_size,
|
page = page_search(current_page, page_size,
|
||||||
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"),
|
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"),
|
||||||
post_records_handler=lambda chat_record: ChatRecordSerializerModel(chat_record).data)
|
post_records_handler=lambda chat_record: chat_record)
|
||||||
|
records = page.get('records')
|
||||||
|
page['records'] = self.reset_chat_record_list(records)
|
||||||
|
return page
|
||||||
|
|
||||||
class Vote(serializers.Serializer):
|
class Vote(serializers.Serializer):
|
||||||
chat_id = serializers.UUIDField(required=True)
|
chat_id = serializers.UUIDField(required=True)
|
||||||
@ -216,38 +286,20 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||||||
if vote_status == VoteChoices.STAR:
|
if vote_status == VoteChoices.STAR:
|
||||||
# 点赞
|
# 点赞
|
||||||
chat_record_details_model.vote_status = VoteChoices.STAR
|
chat_record_details_model.vote_status = VoteChoices.STAR
|
||||||
# 点赞数量 +1
|
|
||||||
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
|
|
||||||
'star_num',
|
|
||||||
lambda r: r.star_num + 1)
|
|
||||||
|
|
||||||
if vote_status == VoteChoices.TRAMPLE:
|
if vote_status == VoteChoices.TRAMPLE:
|
||||||
# 点踩
|
# 点踩
|
||||||
chat_record_details_model.vote_status = VoteChoices.TRAMPLE
|
chat_record_details_model.vote_status = VoteChoices.TRAMPLE
|
||||||
# 点踩数量+1
|
|
||||||
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
|
|
||||||
'trample_num',
|
|
||||||
lambda r: r.trample_num + 1)
|
|
||||||
chat_record_details_model.save()
|
chat_record_details_model.save()
|
||||||
else:
|
else:
|
||||||
if vote_status == VoteChoices.UN_VOTE:
|
if vote_status == VoteChoices.UN_VOTE:
|
||||||
# 取消点赞
|
# 取消点赞
|
||||||
chat_record_details_model.vote_status = VoteChoices.UN_VOTE
|
chat_record_details_model.vote_status = VoteChoices.UN_VOTE
|
||||||
chat_record_details_model.save()
|
chat_record_details_model.save()
|
||||||
if chat_record_details_model.vote_status == VoteChoices.STAR:
|
|
||||||
# 点赞数量 -1
|
|
||||||
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
|
|
||||||
'star_num', lambda r: r.star_num - 1)
|
|
||||||
if chat_record_details_model.vote_status == VoteChoices.TRAMPLE:
|
|
||||||
# 点踩数量 -1
|
|
||||||
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
|
|
||||||
'trample_num', lambda r: r.trample_num - 1)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise AppApiException(500, "已经投票过,请先取消后再进行投票")
|
raise AppApiException(500, "已经投票过,请先取消后再进行投票")
|
||||||
finally:
|
finally:
|
||||||
un_lock(self.data.get('chat_record_id'))
|
un_lock(self.data.get('chat_record_id'))
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
class ImproveSerializer(serializers.Serializer):
|
class ImproveSerializer(serializers.Serializer):
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
SELECT * FROM ( SELECT * FROM application ${application_custom_sql} UNION
|
SELECT *,to_json(dataset_setting) as dataset_setting,to_json(model_setting) as model_setting FROM ( SELECT * FROM application ${application_custom_sql} UNION
|
||||||
SELECT
|
SELECT
|
||||||
*
|
*
|
||||||
FROM
|
FROM
|
||||||
|
|||||||
@ -0,0 +1,6 @@
|
|||||||
|
SELECT
|
||||||
|
paragraph.*,
|
||||||
|
dataset."name" AS "dataset_name"
|
||||||
|
FROM
|
||||||
|
paragraph paragraph
|
||||||
|
LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id
|
||||||
@ -58,6 +58,7 @@ class ApplicationApi(ApiMixin):
|
|||||||
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'),
|
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'),
|
||||||
|
|
||||||
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间'),
|
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间'),
|
||||||
|
|
||||||
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||||
items=openapi.Schema(type=openapi.TYPE_STRING),
|
items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||||
title="关联知识库Id列表",
|
title="关联知识库Id列表",
|
||||||
@ -133,7 +134,7 @@ class ApplicationApi(ApiMixin):
|
|||||||
def get_request_body_api():
|
def get_request_body_api():
|
||||||
return openapi.Schema(
|
return openapi.Schema(
|
||||||
type=openapi.TYPE_OBJECT,
|
type=openapi.TYPE_OBJECT,
|
||||||
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'],
|
required=[],
|
||||||
properties={
|
properties={
|
||||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
|
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
|
||||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
|
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
|
||||||
@ -141,11 +142,52 @@ class ApplicationApi(ApiMixin):
|
|||||||
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
|
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
|
||||||
description="是否开启多轮对话"),
|
description="是否开启多轮对话"),
|
||||||
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
|
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
|
||||||
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
|
|
||||||
title="示例列表", description="示例列表"),
|
|
||||||
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||||
items=openapi.Schema(type=openapi.TYPE_STRING),
|
items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||||
title="关联知识库Id列表", description="关联知识库Id列表"),
|
title="关联知识库Id列表", description="关联知识库Id列表"),
|
||||||
|
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
|
||||||
|
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
|
||||||
|
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
|
||||||
|
description="是否开启问题优化", default=True)
|
||||||
|
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
class DatasetSetting(ApiMixin):
|
||||||
|
@staticmethod
|
||||||
|
def get_request_body_api():
|
||||||
|
return openapi.Schema(
|
||||||
|
type=openapi.TYPE_OBJECT,
|
||||||
|
required=[''],
|
||||||
|
properties={
|
||||||
|
'top_n': openapi.Schema(type=openapi.TYPE_NUMBER, title="引用分段数", description="引用分段数",
|
||||||
|
default=5),
|
||||||
|
'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title='相似度', description="相似度",
|
||||||
|
default=0.6),
|
||||||
|
'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='最多引用字符数',
|
||||||
|
description="最多引用字符数", default=3000),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
class ModelSetting(ApiMixin):
|
||||||
|
@staticmethod
|
||||||
|
def get_request_body_api():
|
||||||
|
return openapi.Schema(
|
||||||
|
type=openapi.TYPE_OBJECT,
|
||||||
|
required=['prompt'],
|
||||||
|
properties={
|
||||||
|
'prompt': openapi.Schema(type=openapi.TYPE_STRING, title="提示词", description="提示词",
|
||||||
|
default=('已知信息:'
|
||||||
|
'\n{data}'
|
||||||
|
'\n回答要求:'
|
||||||
|
'\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。'
|
||||||
|
'\n- 避免提及你是从<data></data>中获得的知识。'
|
||||||
|
'\n- 请保持答案与<data></data>中描述的一致。'
|
||||||
|
'\n- 请使用markdown 语法优化答案的格式。'
|
||||||
|
'\n- <data></data>中的图片链接、链接地址和脚本语言请完整返回。'
|
||||||
|
'\n- 请使用与问题相同的语言来回答。'
|
||||||
|
'\n问题:'
|
||||||
|
'\n{question}')),
|
||||||
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -155,7 +197,8 @@ class ApplicationApi(ApiMixin):
|
|||||||
def get_request_body_api():
|
def get_request_body_api():
|
||||||
return openapi.Schema(
|
return openapi.Schema(
|
||||||
type=openapi.TYPE_OBJECT,
|
type=openapi.TYPE_OBJECT,
|
||||||
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'],
|
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting',
|
||||||
|
'problem_optimization'],
|
||||||
properties={
|
properties={
|
||||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
|
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
|
||||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
|
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
|
||||||
@ -163,11 +206,13 @@ class ApplicationApi(ApiMixin):
|
|||||||
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
|
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
|
||||||
description="是否开启多轮对话"),
|
description="是否开启多轮对话"),
|
||||||
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
|
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
|
||||||
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
|
|
||||||
title="示例列表", description="示例列表"),
|
|
||||||
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||||
items=openapi.Schema(type=openapi.TYPE_STRING),
|
items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||||
title="关联知识库Id列表", description="关联知识库Id列表")
|
title="关联知识库Id列表", description="关联知识库Id列表"),
|
||||||
|
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
|
||||||
|
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
|
||||||
|
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
|
||||||
|
description="是否开启问题优化", default=True)
|
||||||
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
"""
|
"""
|
||||||
from drf_yasg import openapi
|
from drf_yasg import openapi
|
||||||
|
|
||||||
|
from application.swagger_api.application_api import ApplicationApi
|
||||||
from common.mixins.api_mixin import ApiMixin
|
from common.mixins.api_mixin import ApiMixin
|
||||||
|
|
||||||
|
|
||||||
@ -19,6 +20,7 @@ class ChatApi(ApiMixin):
|
|||||||
required=['message'],
|
required=['message'],
|
||||||
properties={
|
properties={
|
||||||
'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
|
'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
|
||||||
|
're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default="重新生成")
|
||||||
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -73,14 +75,19 @@ class ChatApi(ApiMixin):
|
|||||||
def get_request_body_api():
|
def get_request_body_api():
|
||||||
return openapi.Schema(
|
return openapi.Schema(
|
||||||
type=openapi.TYPE_OBJECT,
|
type=openapi.TYPE_OBJECT,
|
||||||
required=['model_id', 'multiple_rounds_dialogue'],
|
required=['model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting',
|
||||||
|
'problem_optimization'],
|
||||||
properties={
|
properties={
|
||||||
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
|
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
|
||||||
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||||
items=openapi.Schema(type=openapi.TYPE_STRING),
|
items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||||
title="关联知识库Id列表", description="关联知识库Id列表"),
|
title="关联知识库Id列表", description="关联知识库Id列表"),
|
||||||
'multiple_rounds_dialogue': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮会话",
|
'multiple_rounds_dialogue': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮会话",
|
||||||
description="是否开启多轮会话")
|
description="是否开启多轮会话"),
|
||||||
|
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
|
||||||
|
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
|
||||||
|
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
|
||||||
|
description="是否开启问题优化", default=True)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -25,6 +25,8 @@ urlpatterns = [
|
|||||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()),
|
path('application/<str:application_id>/chat/<chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()),
|
||||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/<int:current_page>/<int:page_size>',
|
path('application/<str:application_id>/chat/<chat_id>/chat_record/<int:current_page>/<int:page_size>',
|
||||||
views.ChatView.ChatRecord.Page.as_view()),
|
views.ChatView.ChatRecord.Page.as_view()),
|
||||||
|
path('application/<str:application_id>/chat/<chat_id>/chat_record/<chat_record_id>',
|
||||||
|
views.ChatView.ChatRecord.Operate.as_view()),
|
||||||
path('application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/vote',
|
path('application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/vote',
|
||||||
views.ChatView.ChatRecord.Vote.as_view(),
|
views.ChatView.ChatRecord.Vote.as_view(),
|
||||||
name=''),
|
name=''),
|
||||||
|
|||||||
@ -71,7 +71,8 @@ class ChatView(APIView):
|
|||||||
dynamic_tag=keywords.get('application_id'))])
|
dynamic_tag=keywords.get('application_id'))])
|
||||||
)
|
)
|
||||||
def post(self, request: Request, chat_id: str):
|
def post(self, request: Request, chat_id: str):
|
||||||
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'))
|
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), request.data.get(
|
||||||
|
're_chat') if 're_chat' in request.data else False)
|
||||||
|
|
||||||
@action(methods=['GET'], detail=False)
|
@action(methods=['GET'], detail=False)
|
||||||
@swagger_auto_schema(operation_summary="获取对话列表",
|
@swagger_auto_schema(operation_summary="获取对话列表",
|
||||||
@ -134,6 +135,27 @@ class ChatView(APIView):
|
|||||||
class ChatRecord(APIView):
|
class ChatRecord(APIView):
|
||||||
authentication_classes = [TokenAuth]
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
class Operate(APIView):
|
||||||
|
authentication_classes = [TokenAuth]
|
||||||
|
|
||||||
|
@action(methods=['GET'], detail=False)
|
||||||
|
@swagger_auto_schema(operation_summary="获取对话记录详情",
|
||||||
|
operation_id="获取对话记录详情",
|
||||||
|
manual_parameters=ChatRecordApi.get_request_params_api(),
|
||||||
|
responses=result.get_api_array_response(ChatRecordApi.get_response_body_api()),
|
||||||
|
tags=["应用/对话日志"]
|
||||||
|
)
|
||||||
|
@has_permissions(
|
||||||
|
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
|
||||||
|
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||||
|
dynamic_tag=keywords.get('application_id'))])
|
||||||
|
)
|
||||||
|
def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str):
|
||||||
|
return result.success(ChatRecordSerializer.Operate(
|
||||||
|
data={'application_id': application_id,
|
||||||
|
'chat_id': chat_id,
|
||||||
|
'chat_record_id': chat_record_id}).one())
|
||||||
|
|
||||||
@action(methods=['GET'], detail=False)
|
@action(methods=['GET'], detail=False)
|
||||||
@swagger_auto_schema(operation_summary="获取对话记录列表",
|
@swagger_auto_schema(operation_summary="获取对话记录列表",
|
||||||
operation_id="获取对话记录列表",
|
operation_id="获取对话记录列表",
|
||||||
|
|||||||
@ -83,7 +83,7 @@ def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, s
|
|||||||
field_replace_dict = get_field_replace_dict(queryset)
|
field_replace_dict = get_field_replace_dict(queryset)
|
||||||
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
|
||||||
field_replace_dict=field_replace_dict)
|
field_replace_dict=field_replace_dict)
|
||||||
sql, params = app_sql_compiler.get_query_str(with_table_name)
|
sql, params = app_sql_compiler.get_query_str(with_table_name=with_table_name)
|
||||||
return sql, params
|
return sql, params
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -7,10 +7,8 @@
|
|||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
from .listener_manage import *
|
from .listener_manage import *
|
||||||
from .listener_chat_message import *
|
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
listener_manage.ListenerManagement().run()
|
listener_manage.ListenerManagement().run()
|
||||||
listener_chat_message.ListenerChatMessage().run()
|
|
||||||
QuerySet(Document).filter(status=Status.embedding).update(**{'status': Status.error})
|
QuerySet(Document).filter(status=Status.embedding).update(**{'status': Status.error})
|
||||||
|
|||||||
@ -1,67 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
"""
|
|
||||||
@project: maxkb
|
|
||||||
@Author:虎
|
|
||||||
@file: listener_manage.py
|
|
||||||
@date:2023/10/20 14:01
|
|
||||||
@desc:
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from blinker import signal
|
|
||||||
from django.db.models import QuerySet
|
|
||||||
|
|
||||||
from application.models import ChatRecord, Chat
|
|
||||||
from application.serializers.chat_message_serializers import ChatMessage
|
|
||||||
from common.event.common import poxy
|
|
||||||
|
|
||||||
|
|
||||||
class RecordChatMessageArgs:
|
|
||||||
def __init__(self, index: int, chat_id: str, application_id: str, chat_message: ChatMessage):
|
|
||||||
self.index = index
|
|
||||||
self.chat_id = chat_id
|
|
||||||
self.application_id = application_id
|
|
||||||
self.chat_message = chat_message
|
|
||||||
|
|
||||||
|
|
||||||
class ListenerChatMessage:
|
|
||||||
record_chat_message_signal = signal("record_chat_message")
|
|
||||||
update_chat_message_token_signal = signal("update_chat_message_token")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def record_chat_message(args: RecordChatMessageArgs):
|
|
||||||
if not QuerySet(Chat).filter(id=args.chat_id).exists():
|
|
||||||
Chat(id=args.chat_id, application_id=args.application_id, abstract=args.chat_message.problem).save()
|
|
||||||
# 插入会话记录
|
|
||||||
try:
|
|
||||||
chat_record = ChatRecord(
|
|
||||||
id=args.chat_message.id,
|
|
||||||
chat_id=args.chat_id,
|
|
||||||
dataset_id=args.chat_message.dataset_id,
|
|
||||||
paragraph_id=args.chat_message.paragraph_id,
|
|
||||||
source_id=args.chat_message.source_id,
|
|
||||||
source_type=args.chat_message.source_type,
|
|
||||||
problem_text=args.chat_message.problem,
|
|
||||||
answer_text=args.chat_message.answer,
|
|
||||||
index=args.index,
|
|
||||||
message_tokens=args.chat_message.message_tokens,
|
|
||||||
answer_tokens=args.chat_message.answer_token)
|
|
||||||
chat_record.save()
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@poxy
|
|
||||||
def update_token(chat_message: ChatMessage):
|
|
||||||
if chat_message.chat_model is not None:
|
|
||||||
logging.getLogger("max_kb").info("开始更新token")
|
|
||||||
message_token = chat_message.chat_model.get_num_tokens_from_messages(chat_message.chat_message)
|
|
||||||
answer_token = chat_message.chat_model.get_num_tokens(chat_message.answer)
|
|
||||||
# 修改token数量
|
|
||||||
QuerySet(ChatRecord).filter(id=chat_message.id).update(
|
|
||||||
**{'message_tokens': message_token, 'answer_tokens': answer_token})
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
# 记录会话
|
|
||||||
ListenerChatMessage.record_chat_message_signal.connect(self.record_chat_message)
|
|
||||||
ListenerChatMessage.update_chat_message_token_signal.connect(self.update_token)
|
|
||||||
34
apps/common/field/common.py
Normal file
34
apps/common/field/common.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
"""
|
||||||
|
@project: maxkb
|
||||||
|
@Author:虎
|
||||||
|
@file: common.py
|
||||||
|
@date:2024/1/11 18:44
|
||||||
|
@desc:
|
||||||
|
"""
|
||||||
|
from rest_framework import serializers
|
||||||
|
|
||||||
|
|
||||||
|
class InstanceField(serializers.Field):
|
||||||
|
def __init__(self, model_type, **kwargs):
|
||||||
|
self.model_type = model_type
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def to_internal_value(self, data):
|
||||||
|
if not isinstance(data, self.model_type):
|
||||||
|
self.fail('message类型错误', value=data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def to_representation(self, value):
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionField(serializers.Field):
|
||||||
|
|
||||||
|
def to_internal_value(self, data):
|
||||||
|
if not callable(data):
|
||||||
|
self.fail('不是一個函數', value=data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def to_representation(self, value):
|
||||||
|
return value
|
||||||
@ -5,9 +5,7 @@ SELECT
|
|||||||
problem.dataset_id AS dataset_id,
|
problem.dataset_id AS dataset_id,
|
||||||
0 AS source_type,
|
0 AS source_type,
|
||||||
problem."content" AS "text",
|
problem."content" AS "text",
|
||||||
paragraph.is_active AS is_active,
|
paragraph.is_active AS is_active
|
||||||
problem.star_num as star_num,
|
|
||||||
problem.trample_num as trample_num
|
|
||||||
FROM
|
FROM
|
||||||
problem problem
|
problem problem
|
||||||
LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
|
LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
|
||||||
@ -23,9 +21,7 @@ SELECT
|
|||||||
concat_ws('
|
concat_ws('
|
||||||
',concat_ws('
|
',concat_ws('
|
||||||
',paragraph.title,paragraph."content"),paragraph.title) AS "text",
|
',paragraph.title,paragraph."content"),paragraph.title) AS "text",
|
||||||
paragraph.is_active AS is_active,
|
paragraph.is_active AS is_active
|
||||||
paragraph.star_num as star_num,
|
|
||||||
paragraph.trample_num as trample_num
|
|
||||||
FROM
|
FROM
|
||||||
paragraph paragraph
|
paragraph paragraph
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,37 @@
|
|||||||
|
# Generated by Django 4.1.10 on 2024-01-16 11:22
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('dataset', '0003_alter_paragraph_content'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name='paragraph',
|
||||||
|
name='hit_num',
|
||||||
|
),
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name='paragraph',
|
||||||
|
name='star_num',
|
||||||
|
),
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name='paragraph',
|
||||||
|
name='trample_num',
|
||||||
|
),
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name='problem',
|
||||||
|
name='hit_num',
|
||||||
|
),
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name='problem',
|
||||||
|
name='star_num',
|
||||||
|
),
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name='problem',
|
||||||
|
name='trample_num',
|
||||||
|
),
|
||||||
|
]
|
||||||
@ -74,9 +74,6 @@ class Paragraph(AppModelMixin):
|
|||||||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
|
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
|
||||||
content = models.CharField(max_length=4096, verbose_name="段落内容")
|
content = models.CharField(max_length=4096, verbose_name="段落内容")
|
||||||
title = models.CharField(max_length=256, verbose_name="标题", default="")
|
title = models.CharField(max_length=256, verbose_name="标题", default="")
|
||||||
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
|
|
||||||
star_num = models.IntegerField(verbose_name="点赞数", default=0)
|
|
||||||
trample_num = models.IntegerField(verbose_name="点踩数", default=0)
|
|
||||||
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
|
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
|
||||||
default=Status.embedding)
|
default=Status.embedding)
|
||||||
is_active = models.BooleanField(default=True)
|
is_active = models.BooleanField(default=True)
|
||||||
@ -94,9 +91,6 @@ class Problem(AppModelMixin):
|
|||||||
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
|
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||||
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
|
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
|
||||||
content = models.CharField(max_length=256, verbose_name="问题内容")
|
content = models.CharField(max_length=256, verbose_name="问题内容")
|
||||||
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
|
|
||||||
star_num = models.IntegerField(verbose_name="点赞数", default=0)
|
|
||||||
trample_num = models.IntegerField(verbose_name="点踩数", default=0)
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = "problem"
|
db_table = "problem"
|
||||||
|
|||||||
@ -26,12 +26,11 @@ from application.models import ApplicationDatasetMapping
|
|||||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||||
from common.db.search import get_dynamics_model, native_page_search, native_search
|
from common.db.search import get_dynamics_model, native_page_search, native_search
|
||||||
from common.db.sql_execute import select_list
|
from common.db.sql_execute import select_list
|
||||||
from common.event.listener_manage import ListenerManagement, SyncWebDatasetArgs
|
|
||||||
from common.exception.app_exception import AppApiException
|
from common.exception.app_exception import AppApiException
|
||||||
from common.mixins.api_mixin import ApiMixin
|
from common.mixins.api_mixin import ApiMixin
|
||||||
from common.util.common import post
|
from common.util.common import post
|
||||||
from common.util.file_util import get_file_content
|
from common.util.file_util import get_file_content
|
||||||
from common.util.fork import ChildLink, Fork, ForkManage
|
from common.util.fork import ChildLink, Fork
|
||||||
from common.util.split_model import get_split_model
|
from common.util.split_model import get_split_model
|
||||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
|
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
|
||||||
from dataset.serializers.common_serializers import list_paragraph
|
from dataset.serializers.common_serializers import list_paragraph
|
||||||
@ -286,7 +285,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
properties={
|
properties={
|
||||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
|
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
|
||||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
|
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
|
||||||
'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", description="web站点url"),
|
'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url",
|
||||||
|
description="web站点url"),
|
||||||
'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
|
'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -369,7 +369,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||||||
dataset_id = uuid.uuid1()
|
dataset_id = uuid.uuid1()
|
||||||
dataset = DataSet(
|
dataset = DataSet(
|
||||||
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
|
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
|
||||||
'type': Type.web, 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}})
|
'type': Type.web,
|
||||||
|
'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}})
|
||||||
dataset.save()
|
dataset.save()
|
||||||
ListenerManagement.sync_web_dataset_signal.send(
|
ListenerManagement.sync_web_dataset_signal.send(
|
||||||
SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),
|
SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from dataset.serializers.problem_serializers import ProblemInstanceSerializer, P
|
|||||||
class ParagraphSerializer(serializers.ModelSerializer):
|
class ParagraphSerializer(serializers.ModelSerializer):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Paragraph
|
model = Paragraph
|
||||||
fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
|
fields = ['id', 'content', 'is_active', 'document_id', 'title',
|
||||||
'create_time', 'update_time']
|
'create_time', 'update_time']
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -24,7 +24,7 @@ from embedding.vector.pg_vector import PGVector
|
|||||||
class ProblemSerializer(serializers.ModelSerializer):
|
class ProblemSerializer(serializers.ModelSerializer):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = Problem
|
model = Problem
|
||||||
fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id', 'document_id',
|
fields = ['id', 'content', 'dataset_id', 'document_id',
|
||||||
'create_time', 'update_time']
|
'create_time', 'update_time']
|
||||||
|
|
||||||
|
|
||||||
@ -77,8 +77,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
|
|||||||
'document_id': self.data.get('document_id'),
|
'document_id': self.data.get('document_id'),
|
||||||
'paragraph_id': self.data.get('paragraph_id'),
|
'paragraph_id': self.data.get('paragraph_id'),
|
||||||
'dataset_id': self.data.get('dataset_id'),
|
'dataset_id': self.data.get('dataset_id'),
|
||||||
'star_num': 0,
|
|
||||||
'trample_num': 0})
|
})
|
||||||
|
|
||||||
return ProblemSerializers.Operate(
|
return ProblemSerializers.Operate(
|
||||||
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
|
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),
|
||||||
|
|||||||
@ -0,0 +1,37 @@
|
|||||||
|
# Generated by Django 4.1.10 on 2024-01-16 11:22
|
||||||
|
|
||||||
|
import django.contrib.postgres.fields
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('embedding', '0001_initial'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name='embedding',
|
||||||
|
name='star_num',
|
||||||
|
),
|
||||||
|
migrations.RemoveField(
|
||||||
|
model_name='embedding',
|
||||||
|
name='trample_num',
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='embedding',
|
||||||
|
name='keywords',
|
||||||
|
field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=256), default=list, size=None, verbose_name='关键词列表'),
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='embedding',
|
||||||
|
name='meta',
|
||||||
|
field=models.JSONField(default=dict, verbose_name='元数据'),
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name='embedding',
|
||||||
|
name='source_type',
|
||||||
|
field=models.CharField(choices=[('0', '问题'), ('1', '段落'), ('2', '标题')], default='0', max_length=5, verbose_name='资源类型'),
|
||||||
|
),
|
||||||
|
]
|
||||||
@ -6,6 +6,7 @@
|
|||||||
@date:2023/9/21 15:46
|
@date:2023/9/21 15:46
|
||||||
@desc:
|
@desc:
|
||||||
"""
|
"""
|
||||||
|
from django.contrib.postgres.fields import ArrayField
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
||||||
from common.field.vector_field import VectorField
|
from common.field.vector_field import VectorField
|
||||||
@ -16,6 +17,7 @@ class SourceType(models.TextChoices):
|
|||||||
"""订单类型"""
|
"""订单类型"""
|
||||||
PROBLEM = 0, '问题'
|
PROBLEM = 0, '问题'
|
||||||
PARAGRAPH = 1, '段落'
|
PARAGRAPH = 1, '段落'
|
||||||
|
TITLE = 2, '标题'
|
||||||
|
|
||||||
|
|
||||||
class Embedding(models.Model):
|
class Embedding(models.Model):
|
||||||
@ -36,10 +38,10 @@ class Embedding(models.Model):
|
|||||||
|
|
||||||
embedding = VectorField(verbose_name="向量")
|
embedding = VectorField(verbose_name="向量")
|
||||||
|
|
||||||
star_num = models.IntegerField(default=0, verbose_name="点赞数量")
|
keywords = ArrayField(verbose_name="关键词列表",
|
||||||
|
base_field=models.CharField(max_length=256), default=list)
|
||||||
|
|
||||||
trample_num = models.IntegerField(default=0,
|
meta = models.JSONField(verbose_name="元数据", default=dict)
|
||||||
verbose_name="点踩数量")
|
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
db_table = "embedding"
|
db_table = "embedding"
|
||||||
|
|||||||
@ -1,15 +1,17 @@
|
|||||||
SELECT * FROM (SELECT
|
SELECT
|
||||||
*,
|
paragraph_id,
|
||||||
( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
|
comprehensive_score,
|
||||||
CASE
|
comprehensive_score as similarity
|
||||||
|
|
||||||
WHEN embedding.star_num - embedding.trample_num = 0 THEN
|
|
||||||
0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
|
|
||||||
END AS score
|
|
||||||
FROM
|
FROM
|
||||||
embedding,
|
(
|
||||||
( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
|
SELECT DISTINCT ON
|
||||||
${embedding_query}
|
("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
|
||||||
) temp
|
FROM
|
||||||
WHERE similarity>0.5
|
( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query}) TEMP
|
||||||
ORDER BY (similarity + score) DESC LIMIT 1
|
ORDER BY
|
||||||
|
paragraph_id,
|
||||||
|
similarity DESC
|
||||||
|
) DISTINCT_TEMP
|
||||||
|
WHERE comprehensive_score>%s
|
||||||
|
ORDER BY comprehensive_score DESC
|
||||||
|
LIMIT %s
|
||||||
@ -1,34 +1,17 @@
|
|||||||
SELECT
|
SELECT
|
||||||
similarity,
|
|
||||||
paragraph_id,
|
paragraph_id,
|
||||||
comprehensive_score
|
comprehensive_score,
|
||||||
|
comprehensive_score as similarity
|
||||||
FROM
|
FROM
|
||||||
(
|
(
|
||||||
SELECT DISTINCT ON
|
SELECT DISTINCT ON
|
||||||
( "paragraph_id" ) ( similarity + score ),*,
|
("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
|
||||||
( similarity + score ) AS comprehensive_score
|
|
||||||
FROM
|
FROM
|
||||||
(
|
( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query} ) TEMP
|
||||||
SELECT
|
|
||||||
*,
|
|
||||||
( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
|
|
||||||
CASE
|
|
||||||
|
|
||||||
WHEN embedding.star_num - embedding.trample_num = 0 THEN
|
|
||||||
0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
|
|
||||||
END AS score
|
|
||||||
FROM
|
|
||||||
embedding,
|
|
||||||
( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
|
|
||||||
${embedding_query}
|
|
||||||
) TEMP
|
|
||||||
WHERE
|
|
||||||
similarity > %s
|
|
||||||
ORDER BY
|
ORDER BY
|
||||||
paragraph_id,
|
paragraph_id,
|
||||||
( similarity + score )
|
similarity DESC
|
||||||
DESC
|
) DISTINCT_TEMP
|
||||||
) ss
|
WHERE comprehensive_score>%s
|
||||||
ORDER BY
|
ORDER BY comprehensive_score DESC
|
||||||
comprehensive_score DESC
|
LIMIT %s
|
||||||
LIMIT %s
|
|
||||||
@ -99,8 +99,6 @@ class BaseVectorStore(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||||
is_active: bool,
|
is_active: bool,
|
||||||
star_num: int,
|
|
||||||
trample_num: int,
|
|
||||||
embedding: HuggingFaceEmbeddings):
|
embedding: HuggingFaceEmbeddings):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -108,11 +106,20 @@ class BaseVectorStore(ABC):
|
|||||||
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
|
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||||
exclude_id_list: list[str],
|
exclude_paragraph_list: list[str],
|
||||||
is_active: bool,
|
is_active: bool,
|
||||||
embedding: HuggingFaceEmbeddings):
|
embedding: HuggingFaceEmbeddings):
|
||||||
|
if dataset_id_list is None or len(dataset_id_list) == 0:
|
||||||
|
return []
|
||||||
|
embedding_query = embedding.embed_query(query_text)
|
||||||
|
result = self.query(embedding_query, dataset_id_list, exclude_document_id_list, exclude_paragraph_list,
|
||||||
|
is_active, 1, 0.65)
|
||||||
|
return result[0]
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||||
|
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@ -33,8 +33,6 @@ class PGVector(BaseVectorStore):
|
|||||||
|
|
||||||
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||||
is_active: bool,
|
is_active: bool,
|
||||||
star_num: int,
|
|
||||||
trample_num: int,
|
|
||||||
embedding: HuggingFaceEmbeddings):
|
embedding: HuggingFaceEmbeddings):
|
||||||
text_embedding = embedding.embed_query(text)
|
text_embedding = embedding.embed_query(text)
|
||||||
embedding = Embedding(id=uuid.uuid1(),
|
embedding = Embedding(id=uuid.uuid1(),
|
||||||
@ -44,10 +42,7 @@ class PGVector(BaseVectorStore):
|
|||||||
paragraph_id=paragraph_id,
|
paragraph_id=paragraph_id,
|
||||||
source_id=source_id,
|
source_id=source_id,
|
||||||
embedding=text_embedding,
|
embedding=text_embedding,
|
||||||
source_type=source_type,
|
source_type=source_type)
|
||||||
star_num=star_num,
|
|
||||||
trample_num=trample_num
|
|
||||||
)
|
|
||||||
embedding.save()
|
embedding.save()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -61,8 +56,6 @@ class PGVector(BaseVectorStore):
|
|||||||
is_active=text_list[index].get('is_active', True),
|
is_active=text_list[index].get('is_active', True),
|
||||||
source_id=text_list[index].get('source_id'),
|
source_id=text_list[index].get('source_id'),
|
||||||
source_type=text_list[index].get('source_type'),
|
source_type=text_list[index].get('source_type'),
|
||||||
star_num=text_list[index].get('star_num'),
|
|
||||||
trample_num=text_list[index].get('trample_num'),
|
|
||||||
embedding=embeddings[index]) for index in
|
embedding=embeddings[index]) for index in
|
||||||
range(0, len(text_list))]
|
range(0, len(text_list))]
|
||||||
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
|
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
|
||||||
@ -78,29 +71,27 @@ class PGVector(BaseVectorStore):
|
|||||||
'hit_test.sql')),
|
'hit_test.sql')),
|
||||||
with_table_name=True)
|
with_table_name=True)
|
||||||
embedding_model = select_list(exec_sql,
|
embedding_model = select_list(exec_sql,
|
||||||
[json.dumps(embedding_query), *exec_params, *exec_params, similarity, top_number])
|
[json.dumps(embedding_query), *exec_params, similarity, top_number])
|
||||||
return embedding_model
|
return embedding_model
|
||||||
|
|
||||||
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||||
exclude_id_list: list[str],
|
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
|
||||||
is_active: bool,
|
|
||||||
embedding: HuggingFaceEmbeddings):
|
|
||||||
exclude_dict = {}
|
exclude_dict = {}
|
||||||
if dataset_id_list is None or len(dataset_id_list) == 0:
|
if dataset_id_list is None or len(dataset_id_list) == 0:
|
||||||
return None
|
return []
|
||||||
query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active)
|
query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active)
|
||||||
embedding_query = embedding.embed_query(query_text)
|
|
||||||
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
|
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
|
||||||
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
|
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
|
||||||
if exclude_id_list is not None and len(exclude_id_list) > 0:
|
if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
|
||||||
exclude_dict.__setitem__('id__in', exclude_id_list)
|
exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list)
|
||||||
query_set = query_set.exclude(**exclude_dict)
|
query_set = query_set.exclude(**exclude_dict)
|
||||||
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
|
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
|
||||||
select_string=get_file_content(
|
select_string=get_file_content(
|
||||||
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
|
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
|
||||||
'embedding_search.sql')),
|
'embedding_search.sql')),
|
||||||
with_table_name=True)
|
with_table_name=True)
|
||||||
embedding_model = select_one(exec_sql, (json.dumps(embedding_query), *exec_params, *exec_params))
|
embedding_model = select_list(exec_sql,
|
||||||
|
[json.dumps(query_embedding), *exec_params, similarity, top_n])
|
||||||
return embedding_model
|
return embedding_model
|
||||||
|
|
||||||
def update_by_source_id(self, source_id: str, instance: Dict):
|
def update_by_source_id(self, source_id: str, instance: Dict):
|
||||||
|
|||||||
@ -14,13 +14,23 @@ from langchain.chat_models.base import BaseChatModel
|
|||||||
from langchain.load import dumpd
|
from langchain.load import dumpd
|
||||||
from langchain.schema import LLMResult
|
from langchain.schema import LLMResult
|
||||||
from langchain.schema.language_model import LanguageModelInput
|
from langchain.schema.language_model import LanguageModelInput
|
||||||
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage
|
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string
|
||||||
from langchain.schema.output import ChatGenerationChunk
|
from langchain.schema.output import ChatGenerationChunk
|
||||||
from langchain.schema.runnable import RunnableConfig
|
from langchain.schema.runnable import RunnableConfig
|
||||||
|
from transformers import GPT2TokenizerFast
|
||||||
|
|
||||||
|
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
|
||||||
|
force_download=False)
|
||||||
|
|
||||||
|
|
||||||
class QianfanChatModel(QianfanChatEndpoint):
|
class QianfanChatModel(QianfanChatEndpoint):
|
||||||
|
|
||||||
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
|
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||||
|
|
||||||
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
return len(tokenizer.encode(text))
|
||||||
|
|
||||||
def stream(
|
def stream(
|
||||||
self,
|
self,
|
||||||
input: LanguageModelInput,
|
input: LanguageModelInput,
|
||||||
@ -30,7 +40,7 @@ class QianfanChatModel(QianfanChatEndpoint):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[BaseMessageChunk]:
|
) -> Iterator[BaseMessageChunk]:
|
||||||
if len(input) % 2 == 0:
|
if len(input) % 2 == 0:
|
||||||
input = [HumanMessage(content='占位'), *input]
|
input = [HumanMessage(content='padding'), *input]
|
||||||
input = [
|
input = [
|
||||||
HumanMessage(content=input[index].content) if index % 2 == 0 else AIMessage(content=input[index].content)
|
HumanMessage(content=input[index].content) if index % 2 == 0 else AIMessage(content=input[index].content)
|
||||||
for index in range(0, len(input))]
|
for index in range(0, len(input))]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user