feat: 优化对话逻辑

This commit is contained in:
shaohuzhang1 2024-01-16 16:46:54 +08:00
parent 7349f00c54
commit 3f87335c80
46 changed files with 1393 additions and 396 deletions

View File

@ -0,0 +1,55 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file I_base_chat_pipeline.py
@date2024/1/9 17:25
@desc:
"""
import time
from abc import abstractmethod
from typing import Type
from rest_framework import serializers
class IBaseChatPipelineStep:
    """Base class for a single step of the chat pipeline.

    A step validates its arguments against a serializer built from the
    manager's context, executes, and records its timing in a private context.
    """

    def __init__(self):
        # Per-step context holding this step's arguments, results and timings.
        self.context = {}

    @abstractmethod
    def get_step_serializer(self, manage) -> Type[serializers.Serializer]:
        """Return the serializer class used to validate this step's arguments."""
        pass

    def valid_args(self, manage):
        """Validate the manager's context and stash the cleaned step arguments."""
        serializer_clazz = self.get_step_serializer(manage)
        serializer = serializer_clazz(data=manage.context)
        serializer.is_valid(raise_exception=True)
        self.context['step_args'] = serializer.data

    def run(self, manage):
        """Validate arguments, execute the step, then record timings.

        :param manage: pipeline manager driving this step
        """
        started = time.time()
        # Validate parameters before executing.
        self.valid_args(manage)
        self._run(manage)
        self.context['start_time'] = started
        self.context['run_time'] = time.time() - started

    def _run(self, manage):
        # Default implementation does nothing; concrete steps override this.
        pass

    def execute(self, **kwargs):
        pass

    def get_details(self, manage, **kwargs):
        """Return execution details for this step, or None when there are none."""
        return None

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 17:23
@desc:
"""

View File

@ -0,0 +1,45 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file pipeline_manage.py
@date2024/1/9 17:40
@desc:
"""
import time
from functools import reduce
from typing import List, Type, Dict
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
class PiplineManage:
    """Drives an ordered list of pipeline steps over a shared context."""

    def __init__(self, step_list: List[Type[IBaseChatPipelineStep]]):
        # One executor instance per step class.
        self.step_list = [step_clazz() for step_clazz in step_list]
        # Shared context, pre-seeded with token counters.
        self.context = {'message_tokens': 0, 'answer_tokens': 0}

    def run(self, context: Dict = None):
        """Merge the given context (if any) and run every step in order."""
        self.context['start_time'] = time.time()
        if context is not None:
            for key, value in context.items():
                self.context[key] = value
        for step in self.step_list:
            step.run(self)

    def get_details(self):
        """Collect non-None step details into one dict keyed by 'step_type'."""
        detail_list = [step.get_details(self) for step in self.step_list]
        keyed = [{detail.get('step_type'): detail} for detail in detail_list if detail is not None]
        return reduce(lambda acc, item: {**acc, **item}, keyed, {})

    class builder:
        """Fluent builder assembling the ordered step list for a PiplineManage."""

        def __init__(self):
            self.step_list: List[Type[IBaseChatPipelineStep]] = []

        def append_step(self, step: Type[IBaseChatPipelineStep]):
            self.step_list.append(step)
            return self

        def build(self):
            return PiplineManage(step_list=self.step_list)

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:23
@desc:
"""

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:23
@desc:
"""

View File

@ -0,0 +1,88 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_chat_step.py
@date2024/1/9 18:17
@desc: 对话
"""
from abc import abstractmethod
from typing import Type, List
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.pipeline_manage import PiplineManage
from common.field.common import InstanceField
from dataset.models import Paragraph
class ModelField(serializers.Field):
    """Serializer field that passes a langchain BaseChatModel through unchanged."""

    def to_internal_value(self, data):
        if isinstance(data, BaseChatModel):
            return data
        # Not a chat model instance: raise a validation error.
        self.fail('模型类型错误', value=data)

    def to_representation(self, value):
        return value
class MessageField(serializers.Field):
    """Serializer field that passes a langchain BaseMessage through unchanged."""

    def to_internal_value(self, data):
        if isinstance(data, BaseMessage):
            return data
        # Not a langchain message instance: raise a validation error.
        self.fail('message类型错误', value=data)

    def to_representation(self, value):
        return value
class PostResponseHandler:
    """Hook invoked once the model answer has been fully streamed."""

    @abstractmethod
    def handler(self, chat_id, chat_record_id, paragraph_list: List[Paragraph], problem_text: str, answer_text,
                manage, step, padding_problem_text: str = None, **kwargs):
        """Persist or post-process the finished answer; implemented by subclasses."""
        pass
class IChatStep(IBaseChatPipelineStep):
    """Abstract chat step: sends the assembled message list to the LLM."""

    class InstanceSerializer(serializers.Serializer):
        # Conversation message list.
        message_list = serializers.ListField(required=True, child=MessageField(required=True))
        # Large language model.
        chat_model = ModelField()
        # Retrieved paragraph list.
        paragraph_list = serializers.ListField()
        # Conversation id.
        chat_id = serializers.UUIDField(required=True)
        # Original user question.
        problem_text = serializers.CharField(required=True)
        # Post-response hook.
        post_response_handler = InstanceField(model_type=PostResponseHandler)
        # Optimized (completed) question, when problem optimization ran.
        padding_problem_text = serializers.CharField(required=False)

        def is_valid(self, *, raise_exception=False):
            super().is_valid(raise_exception=True)
            # Every entry of message_list must be a langchain message.
            for message in self.initial_data.get('message_list'):
                if not isinstance(message, BaseMessage):
                    raise Exception("message 类型错误")

    def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
        return self.InstanceSerializer

    def _run(self, manage: PiplineManage):
        # Execute the chat and publish the (streaming) result on the manager.
        manage.context['chat_result'] = self.execute(**self.context['step_args'], manage=manage)

    @abstractmethod
    def execute(self, message_list: List[BaseMessage],
                chat_id, problem_text,
                post_response_handler: PostResponseHandler,
                chat_model: BaseChatModel = None,
                paragraph_list=None,
                manage: PiplineManage = None,
                padding_problem_text: str = None, **kwargs):
        """Send message_list to the model and build the chat response."""
        pass

View File

@ -0,0 +1,111 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_chat_step.py
@date2024/1/9 18:25
@desc: 对话step Base实现
"""
import json
import logging
import time
import traceback
import uuid
from typing import List
from django.http import StreamingHttpResponse
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from langchain.schema.messages import BaseMessageChunk, HumanMessage
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from dataset.models import Paragraph
def event_content(response,
                  chat_id,
                  chat_record_id,
                  paragraph_list: List[Paragraph],
                  post_response_handler: PostResponseHandler,
                  manage,
                  step,
                  chat_model,
                  message_list: List[BaseMessage],
                  problem_text: str,
                  padding_problem_text: str = None):
    """Yield one SSE 'data:' frame per streamed chunk, then a terminal frame.

    After streaming completes, token counts and timings are written back to
    the step and manager contexts and the post-response hook is invoked.
    """
    answer_text = ''
    try:
        for chunk in response:
            answer_text += chunk.content
            yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                                         'content': chunk.content, 'is_end': False}) + "\n\n"
        # Terminal frame marking the end of the stream.
        yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                                     'content': '', 'is_end': True}) + "\n\n"
        # Token accounting for the request and the accumulated answer.
        request_token = chat_model.get_num_tokens_from_messages(message_list)
        response_token = chat_model.get_num_tokens(answer_text)
        step.context['message_tokens'] = request_token
        step.context['answer_tokens'] = response_token
        now = time.time()
        step.context['answer_text'] = answer_text
        step.context['run_time'] = now - step.context['start_time']
        manage.context['run_time'] = now - manage.context['start_time']
        manage.context['message_tokens'] = manage.context['message_tokens'] + request_token
        manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
        post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                      answer_text, manage, step, padding_problem_text)
    except Exception as e:
        # Log the failure and emit one final error frame so the client terminates.
        logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
        yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                                     'content': '异常' + str(e), 'is_end': True}) + "\n\n"
class BaseChatStep(IChatStep):
    """Default chat step: streams the model answer back as server-sent events."""

    def execute(self, message_list: List[BaseMessage],
                chat_id,
                problem_text,
                post_response_handler: PostResponseHandler,
                chat_model: BaseChatModel = None,
                paragraph_list=None,
                manage: PiplineManage = None,
                padding_problem_text: str = None,
                **kwargs):
        # No model configured: stream the retrieved paragraphs back directly.
        if chat_model is None:
            chat_result = iter(
                [BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
        else:
            chat_result = chat_model.stream(message_list)
        chat_record_id = uuid.uuid1()
        r = StreamingHttpResponse(
            streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
                                            post_response_handler, manage, self, chat_model, message_list,
                                            problem_text, padding_problem_text),
            content_type='text/event-stream;charset=utf-8')
        r['Cache-Control'] = 'no-cache'
        return r

    def get_details(self, manage, **kwargs):
        """Execution details for this step.

        'model_id' may be absent from the manager context (e.g. when no model
        is configured), so guard the lookup instead of raising KeyError —
        consistent with BaseResetProblemStep.get_details.
        """
        return {
            'step_type': 'chat_step',
            'run_time': self.context['run_time'],
            'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None,
            'message_list': self.reset_message_list(self.context['step_args'].get('message_list'),
                                                    self.context['answer_text']),
            'message_tokens': self.context['message_tokens'],
            'answer_tokens': self.context['answer_tokens'],
            'cost': 0,
        }

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """Convert langchain messages into role/content dicts, appending the answer."""
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content}
                  for message in message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:23
@desc:
"""

View File

@ -0,0 +1,68 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_generate_human_message_step.py
@date2024/1/9 18:15
@desc: 生成对话模板
"""
from abc import abstractmethod
from typing import Type, List
from langchain.schema import BaseMessage
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.models import ChatRecord
from common.field.common import InstanceField
from dataset.models import Paragraph
class IGenerateHumanMessageStep(IBaseChatPipelineStep):
    """Abstract step that builds the message list to send to the LLM."""

    class InstanceSerializer(serializers.Serializer):
        # User question.
        problem_text = serializers.CharField(required=True)
        # Retrieved paragraphs.
        paragraph_list = serializers.ListField(child=InstanceField(model_type=Paragraph, required=True))
        # Historical question/answer records.
        history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
        # Number of history rounds to carry.
        dialogue_number = serializers.IntegerField(required=True)
        # Maximum number of knowledge-base characters to carry.
        max_paragraph_char_number = serializers.IntegerField(required=True)
        # Prompt template.
        prompt = serializers.CharField(required=True)
        # Optimized (completed) question.
        padding_problem_text = serializers.CharField(required=False)

    def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
        return self.InstanceSerializer

    def _run(self, manage: PiplineManage):
        # Build the message list and publish it for the chat step.
        manage.context['message_list'] = self.execute(**self.context['step_args'])

    @abstractmethod
    def execute(self,
                problem_text: str,
                paragraph_list: List[Paragraph],
                history_chat_record: List[ChatRecord],
                dialogue_number: int,
                max_paragraph_char_number: int,
                prompt: str,
                padding_problem_text: str = None,
                **kwargs) -> List[BaseMessage]:
        """
        :param problem_text: original user question text
        :param paragraph_list: retrieved paragraph list
        :param history_chat_record: historical chat records
        :param dialogue_number: number of history rounds to include
        :param max_paragraph_char_number: maximum paragraph characters to include
        :param prompt: prompt template
        :param padding_problem_text: optimized question text, if any
        :param kwargs: extra arguments
        :return: message list for the model
        """
        pass

View File

@ -0,0 +1,57 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_generate_human_message_step.py.py
@date2024/1/10 17:50
@desc:
"""
from typing import List
from langchain.schema import BaseMessage, HumanMessage
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
IGenerateHumanMessageStep
from application.models import ChatRecord
from common.util.split_model import flat_map
from dataset.models import Paragraph
class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
    """Builds the final message list from history, paragraphs and the template."""

    def execute(self, problem_text: str,
                paragraph_list: List[Paragraph],
                history_chat_record: List[ChatRecord],
                dialogue_number: int,
                max_paragraph_char_number: int,
                prompt: str,
                padding_problem_text: str = None,
                **kwargs) -> List[BaseMessage]:
        # Prefer the optimized question when one was produced.
        exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
        # Keep only the last `dialogue_number` rounds of history.
        start_index = max(len(history_chat_record) - dialogue_number, 0)
        history_message = [[record.get_human_message(), record.get_ai_message()]
                           for record in history_chat_record[start_index:]]
        return [*flat_map(history_message),
                self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list)]

    @staticmethod
    def to_human_message(prompt: str,
                         problem: str,
                         max_paragraph_char_number: int,
                         paragraph_list: List[Paragraph]):
        """Render the prompt template with length-capped paragraph data."""
        if not paragraph_list:
            # No knowledge-base context: send the bare question.
            return HumanMessage(content=problem)
        data_list = []
        total_chars = 0
        for paragraph in paragraph_list:
            content = f"{paragraph.title}:{paragraph.content}"
            total_chars += len(content)
            if total_chars > max_paragraph_char_number:
                # Budget exceeded: truncate this paragraph to fit, then stop.
                data_list.append(f"<data>{content[0:max_paragraph_char_number - total_chars]}</data>")
                break
            data_list.append(f"<data>{content}</data>")
        data = "\n".join(data_list)
        return HumanMessage(content=prompt.format(**{'data': data, 'question': problem}))

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:23
@desc:
"""

View File

@ -0,0 +1,49 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_reset_problem_step.py
@date2024/1/9 18:12
@desc: 重写处理问题
"""
from abc import abstractmethod
from typing import Type, List
from langchain.chat_models.base import BaseChatModel
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.step.chat_step.i_chat_step import ModelField
from application.models import ChatRecord
from common.field.common import InstanceField
class IResetProblemStep(IBaseChatPipelineStep):
    """Abstract step that rewrites (optimizes) the user question using history."""

    class InstanceSerializer(serializers.Serializer):
        # User question text.
        problem_text = serializers.CharField(required=True)
        # Historical question/answer records.
        history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
        # Large language model.
        chat_model = ModelField()

    def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
        return self.InstanceSerializer

    def _run(self, manage: PiplineManage):
        padding_problem = self.execute(**self.context.get('step_args'))
        # The question the user actually typed.
        source_problem_text = self.context.get('step_args').get('problem_text')
        self.context['problem_text'] = source_problem_text
        self.context['padding_problem_text'] = padding_problem
        manage.context['problem_text'] = source_problem_text
        manage.context['padding_problem_text'] = padding_problem
        # Accumulate this step's token usage onto the manager totals.
        manage.context['message_tokens'] = manage.context['message_tokens'] + self.context.get('message_tokens')
        manage.context['answer_tokens'] = manage.context['answer_tokens'] + self.context.get('answer_tokens')

    @abstractmethod
    def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
                **kwargs):
        pass

View File

@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_reset_problem_step.py
@date2024/1/10 14:35
@desc:
"""
from typing import List
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage
from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
from application.models import ChatRecord
from common.util.split_model import flat_map
# Prompt instructing the model to wrap the completed question in <data></data>.
prompt = (
    '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中')


class BaseResetProblemStep(IResetProblemStep):
    """Rewrites the user question using the last rounds of history for context."""

    def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
                **kwargs) -> str:
        # Carry at most the last 3 rounds of history.
        start_index = len(history_chat_record) - 3
        history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
                           for index in
                           range(start_index if start_index > 0 else 0, len(history_chat_record))]
        message_list = [*flat_map(history_message),
                        HumanMessage(content=prompt.format(**{'question': problem_text}))]
        response = chat_model(message_list)
        # The model is asked to wrap the rewritten question in <data></data>.
        # Be defensive: if the tags are missing or malformed, fall back to the
        # raw response content instead of crashing on str.index (ValueError).
        content = response.content
        start = content.find('<data>')
        end = content.find('</data>')
        if start != -1 and end != -1 and start + len('<data>') <= end:
            padding_problem = content[start + len('<data>'):end]
        else:
            padding_problem = content
        self.context['message_tokens'] = chat_model.get_num_tokens_from_messages(message_list)
        self.context['answer_tokens'] = chat_model.get_num_tokens(padding_problem)
        return padding_problem

    def get_details(self, manage, **kwargs):
        return {
            'step_type': 'problem_padding',
            'run_time': self.context['run_time'],
            'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None,
            'message_tokens': self.context['message_tokens'],
            'answer_tokens': self.context['answer_tokens'],
            'cost': 0,
            'padding_problem_text': self.context.get('padding_problem_text'),
            'problem_text': self.context.get("step_args").get('problem_text'),
        }

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:24
@desc:
"""

View File

@ -0,0 +1,58 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_search_dataset_step.py
@date2024/1/9 18:10
@desc: 检索知识库
"""
from abc import abstractmethod
from typing import List, Type
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.pipeline_manage import PiplineManage
from dataset.models import Paragraph
class ISearchDatasetStep(IBaseChatPipelineStep):
    """Abstract step that retrieves relevant paragraphs from the knowledge base."""

    class InstanceSerializer(serializers.Serializer):
        # Original question text.
        problem_text = serializers.CharField(required=True)
        # System-completed (optimized) question text.
        padding_problem_text = serializers.CharField(required=False)
        # Dataset ids to search.
        dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
        # Document ids to exclude.
        exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
        # Paragraph (vector) ids to exclude.
        exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
        # Number of hits to return.
        top_n = serializers.IntegerField(required=True)
        # Similarity threshold, in [0, 1].
        similarity = serializers.FloatField(required=True, max_value=1, min_value=0)

    def get_step_serializer(self, manage: PiplineManage) -> Type[InstanceSerializer]:
        return self.InstanceSerializer

    def _run(self, manage: PiplineManage):
        # Run the search and publish the hits for downstream steps.
        manage.context['paragraph_list'] = self.execute(**self.context['step_args'])

    @abstractmethod
    def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
                **kwargs) -> List[Paragraph]:
        """
        When the optimized question exists it is used for the query; otherwise
        the original user question is used.
        :param similarity: similarity threshold
        :param top_n: number of hits to return
        :param problem_text: user question
        :param dataset_id_list: dataset ids to search
        :param exclude_document_id_list: document ids to exclude
        :param exclude_paragraph_id_list: paragraph ids to exclude
        :param padding_problem_text: optimized question text, if any
        :return: paragraph list
        """
        pass

View File

@ -0,0 +1,58 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_search_dataset_step.py
@date2024/1/10 10:33
@desc:
"""
from typing import List
from django.db.models import QuerySet
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
from common.config.embedding_config import VectorStore, EmbeddingModel
from dataset.models import Paragraph
class BaseSearchDatasetStep(ISearchDatasetStep):
    """Vector-searches the datasets and resolves hits to Paragraph rows."""

    def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
                exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
                **kwargs) -> List[Paragraph]:
        # Query with the optimized question when one was produced.
        exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
        embedding_model = EmbeddingModel.get_embedding_model()
        embedding_value = embedding_model.embed_query(exec_problem_text)
        vector = VectorStore.get_embedding_vector()
        embedding_list = vector.query(embedding_value, dataset_id_list, exclude_document_id_list,
                                      exclude_paragraph_id_list, True, top_n, similarity)
        if embedding_list is None:
            return []
        return self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)

    @staticmethod
    def list_paragraph(paragraph_id_list: List, vector):
        """Load paragraphs by id; purge vector entries whose paragraph is gone.

        :param paragraph_id_list: paragraph ids returned by the vector store
        :param vector: vector store handle used to delete stale entries
        :return: the paragraphs that still exist in the database
        """
        if not paragraph_id_list:
            return []
        paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list)
        # Dirty data in the vector store: drop vectors whose paragraph no longer exists.
        if len(paragraph_list) != len(paragraph_id_list):
            # Set membership is O(1); the original list.__contains__ scan was O(n) per id.
            exist_paragraph_id_set = {str(row.id) for row in paragraph_list}
            for paragraph_id in paragraph_id_list:
                if paragraph_id not in exist_paragraph_id_set:
                    vector.delete_by_paragraph_id(paragraph_id)
        return paragraph_list

    def get_details(self, manage, **kwargs):
        step_args = self.context['step_args']
        return {
            'step_type': 'search_step',
            'run_time': self.context['run_time'],
            'problem_text': step_args.get(
                'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
            'model_name': EmbeddingModel.get_embedding_model().model_name,
            'message_tokens': 0,
            'answer_tokens': 0,
            'cost': 0
        }

View File

@ -0,0 +1,55 @@
# Generated by Django 4.1.10 on 2024-01-12 18:46
import django.contrib.postgres.fields
from django.db import migrations, models
class Migration(migrations.Migration):
    # Reworks ChatRecord for the pipeline refactor: drops the per-record
    # dataset/paragraph/source columns in favour of a paragraph_id_list array,
    # a JSON `details` blob, a cost counter and a run_time, and widens
    # answer_text to 4096 chars.
    # NOTE(review): the field name 'const' (verbose_name '总费用' = total cost)
    # looks like a typo of 'cost' — confirm; renaming would need a new migration.

    dependencies = [
        ('application', '0002_alter_chatrecord_dataset'),
    ]

    operations = [
        migrations.RemoveField(
            model_name='chatrecord',
            name='dataset',
        ),
        migrations.RemoveField(
            model_name='chatrecord',
            name='paragraph',
        ),
        migrations.RemoveField(
            model_name='chatrecord',
            name='source_id',
        ),
        migrations.RemoveField(
            model_name='chatrecord',
            name='source_type',
        ),
        migrations.AddField(
            model_name='chatrecord',
            name='const',
            field=models.IntegerField(default=0, verbose_name='总费用'),
        ),
        migrations.AddField(
            model_name='chatrecord',
            name='details',
            field=models.JSONField(default=list, verbose_name='对话详情'),
        ),
        migrations.AddField(
            model_name='chatrecord',
            name='paragraph_id_list',
            field=django.contrib.postgres.fields.ArrayField(base_field=models.UUIDField(blank=True), default=list, size=None, verbose_name='引用段落id列表'),
        ),
        migrations.AddField(
            model_name='chatrecord',
            name='run_time',
            field=models.FloatField(default=0, verbose_name='运行时长'),
        ),
        migrations.AlterField(
            model_name='chatrecord',
            name='answer_text',
            field=models.CharField(max_length=4096, verbose_name='答案'),
        ),
    ]

View File

@ -0,0 +1,38 @@
# Generated by Django 4.1.10 on 2024-01-15 16:07
import application.models.application
from django.db import migrations, models
class Migration(migrations.Migration):
    # Replaces the Application `example` list with the new pipeline settings:
    # dataset_setting / model_setting JSON blobs (with callable defaults from
    # application.models.application) and the problem_optimization flag.
    # NOTE(review): AlterField below uses the mutable literal `default={}` for
    # chatrecord.details — a shared-dict default; a later migration changes it
    # to the `dict` callable. Kept as-is here: applied migrations must not change.

    dependencies = [
        ('application', '0003_remove_chatrecord_dataset_and_more'),
    ]

    operations = [
        migrations.RemoveField(
            model_name='application',
            name='example',
        ),
        migrations.AddField(
            model_name='application',
            name='dataset_setting',
            field=models.JSONField(default=application.models.application.get_dataset_setting_dict, verbose_name='数据集参数设置'),
        ),
        migrations.AddField(
            model_name='application',
            name='model_setting',
            field=models.JSONField(default=application.models.application.get_model_setting_dict, verbose_name='模型参数相关设置'),
        ),
        migrations.AddField(
            model_name='application',
            name='problem_optimization',
            field=models.BooleanField(default=False, verbose_name='问题优化'),
        ),
        migrations.AlterField(
            model_name='chatrecord',
            name='details',
            field=models.JSONField(default={}, verbose_name='对话详情'),
        ),
    ]

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.10 on 2024-01-16 11:22
from django.db import migrations, models
class Migration(migrations.Migration):
    # Fixes chatrecord.details to use the `dict` callable as default instead of
    # the mutable `{}` literal introduced by migration 0004.

    dependencies = [
        ('application', '0004_remove_application_example_and_more'),
    ]

    operations = [
        migrations.AlterField(
            model_name='chatrecord',
            name='details',
            field=models.JSONField(default=dict, verbose_name='对话详情'),
        ),
    ]

View File

@ -10,23 +10,47 @@ import uuid
from django.contrib.postgres.fields import ArrayField from django.contrib.postgres.fields import ArrayField
from django.db import models from django.db import models
from langchain.schema import HumanMessage, AIMessage
from common.mixins.app_model_mixin import AppModelMixin from common.mixins.app_model_mixin import AppModelMixin
from dataset.models.data_set import DataSet, Paragraph from dataset.models.data_set import DataSet
from embedding.models import SourceType
from setting.models.model_management import Model from setting.models.model_management import Model
from users.models import User from users.models import User
def get_dataset_setting_dict():
    """Default dataset-related settings for a newly created Application."""
    return {
        'top_n': 3,
        'similarity': 0.6,
        'max_paragraph_char_number': 5000,
    }
def get_model_setting_dict():
    """Default model-related settings: the built-in prompt template."""
    return {'prompt': Application.get_default_model_prompt()}
class Application(AppModelMixin): class Application(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
name = models.CharField(max_length=128, verbose_name="应用名称") name = models.CharField(max_length=128, verbose_name="应用名称")
desc = models.CharField(max_length=128, verbose_name="引用描述", default="") desc = models.CharField(max_length=128, verbose_name="引用描述", default="")
prologue = models.CharField(max_length=1024, verbose_name="开场白", default="") prologue = models.CharField(max_length=1024, verbose_name="开场白", default="")
example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True), default=list)
dialogue_number = models.IntegerField(default=0, verbose_name="会话数量") dialogue_number = models.IntegerField(default=0, verbose_name="会话数量")
user = models.ForeignKey(User, on_delete=models.DO_NOTHING) user = models.ForeignKey(User, on_delete=models.DO_NOTHING)
model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True) model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
dataset_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict)
model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
problem_optimization = models.BooleanField(verbose_name="问题优化", default=False)
    @staticmethod
    def get_default_model_prompt():
        # Built-in RAG prompt template. The {data} and {question} placeholders
        # are filled by the message-generation step before the model is called.
        return ('已知信息:'
                '\n{data}'
                '\n回答要求:'
                '\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。'
                '\n- 避免提及你是从<data></data>中获得的知识。'
                '\n- 请保持答案与<data></data>中描述的一致。'
                '\n- 请使用markdown 语法优化答案的格式。'
                '\n- <data></data>中的图片链接、链接地址和脚本语言请完整返回。'
                '\n- 请使用与问题相同的语言来回答。'
                '\n问题:'
                '\n{question}')
class Meta: class Meta:
db_table = "application" db_table = "application"
@ -65,20 +89,28 @@ class ChatRecord(AppModelMixin):
chat = models.ForeignKey(Chat, on_delete=models.CASCADE) chat = models.ForeignKey(Chat, on_delete=models.CASCADE)
vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices, vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
default=VoteChoices.UN_VOTE) default=VoteChoices.UN_VOTE)
dataset = models.ForeignKey(DataSet, on_delete=models.SET_NULL, verbose_name="数据集", blank=True, null=True) paragraph_id_list = ArrayField(verbose_name="引用段落id列表",
paragraph = models.ForeignKey(Paragraph, on_delete=models.SET_NULL, verbose_name="段落id", blank=True, null=True) base_field=models.UUIDField(max_length=128, blank=True)
source_id = models.UUIDField(max_length=128, verbose_name="资源id 段落/问题 id ", null=True) , default=list)
source_type = models.CharField(verbose_name='资源类型', max_length=2, choices=SourceType.choices, problem_text = models.CharField(max_length=1024, verbose_name="问题")
default=SourceType.PROBLEM, blank=True, null=True) answer_text = models.CharField(max_length=4096, verbose_name="答案")
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0) message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0) answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
problem_text = models.CharField(max_length=1024, verbose_name="问题") const = models.IntegerField(verbose_name="总费用", default=0)
answer_text = models.CharField(max_length=1024, verbose_name="答案") details = models.JSONField(verbose_name="对话详情", default=dict)
improve_paragraph_id_list = ArrayField(verbose_name="改进标注列表", improve_paragraph_id_list = ArrayField(verbose_name="改进标注列表",
base_field=models.UUIDField(max_length=128, blank=True) base_field=models.UUIDField(max_length=128, blank=True)
, default=list) , default=list)
run_time = models.FloatField(verbose_name="运行时长", default=0)
index = models.IntegerField(verbose_name="对话下标") index = models.IntegerField(verbose_name="对话下标")
    def get_human_message(self):
        # Replay the question as a langchain HumanMessage. Prefer the optimized
        # question recorded by the problem-padding step in details, falling
        # back to the raw user question.
        if 'problem_padding' in self.details:
            return HumanMessage(content=self.details.get('problem_padding').get('padding_problem_text'))
        return HumanMessage(content=self.problem_text)
    def get_ai_message(self):
        # Wrap the stored answer as a langchain AIMessage for history replay.
        return AIMessage(content=self.answer_text)
class Meta: class Meta:
db_table = "application_chat_record" db_table = "application_chat_record"

View File

@ -63,15 +63,35 @@ class ApplicationSerializerModel(serializers.ModelSerializer):
fields = "__all__" fields = "__all__"
class DatasetSettingSerializer(serializers.Serializer):
    # Number of paragraphs to retrieve.
    # NOTE(review): declared as FloatField here, but ISearchDatasetStep
    # validates top_n with IntegerField — confirm the intended type; tightening
    # it here could reject existing clients, so left unchanged.
    top_n = serializers.FloatField(required=True)
    # Similarity threshold, in [0, 1].
    similarity = serializers.FloatField(required=True, max_value=1, min_value=0)
    # Character budget for knowledge-base paragraphs carried into the prompt.
    max_paragraph_char_number = serializers.IntegerField(required=True, max_value=10000)


class ModelSettingSerializer(serializers.Serializer):
    # Prompt template; presumably must keep the {data} and {question}
    # placeholders used by the message-generation step — verify against callers.
    prompt = serializers.CharField(required=True, max_length=4096)
class ApplicationSerializer(serializers.Serializer): class ApplicationSerializer(serializers.Serializer):
name = serializers.CharField(required=True) name = serializers.CharField(required=True)
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True) desc = serializers.CharField(required=False, allow_null=True, allow_blank=True)
model_id = serializers.CharField(required=True) model_id = serializers.CharField(required=True)
multiple_rounds_dialogue = serializers.BooleanField(required=True) multiple_rounds_dialogue = serializers.BooleanField(required=True)
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True) prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True), allow_null=True)
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True), dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
allow_null=True) allow_null=True)
# 数据集相关设置
dataset_setting = DatasetSettingSerializer(required=True)
# 模型相关设置
model_setting = ModelSettingSerializer(required=True)
# 问题补全
problem_optimization = serializers.BooleanField(required=True)
def is_valid(self, *, user_id=None, raise_exception=False):
super().is_valid(raise_exception=True)
ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
class AccessTokenSerializer(serializers.Serializer): class AccessTokenSerializer(serializers.Serializer):
application_id = serializers.UUIDField(required=True) application_id = serializers.UUIDField(required=True)
@ -135,13 +155,13 @@ class ApplicationSerializer(serializers.Serializer):
model_id = serializers.CharField(required=False) model_id = serializers.CharField(required=False)
multiple_rounds_dialogue = serializers.BooleanField(required=False) multiple_rounds_dialogue = serializers.BooleanField(required=False)
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True) prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
# 数据集相关设置
def is_valid(self, *, user_id=None, raise_exception=False): dataset_setting = serializers.JSONField(required=False, allow_null=True)
super().is_valid(raise_exception=True) # 模型相关设置
ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'), model_setting = serializers.JSONField(required=False, allow_null=True)
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid() # 问题补全
problem_optimization = serializers.BooleanField(required=False, allow_null=True)
class Create(serializers.Serializer): class Create(serializers.Serializer):
user_id = serializers.UUIDField(required=True) user_id = serializers.UUIDField(required=True)
@ -168,9 +188,12 @@ class ApplicationSerializer(serializers.Serializer):
@staticmethod @staticmethod
def to_application_model(user_id: str, application: Dict): def to_application_model(user_id: str, application: Dict):
return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'), return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'),
prologue=application.get('prologue'), example=application.get('example'), prologue=application.get('prologue'),
dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0, dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0,
user_id=user_id, model_id=application.get('model_id'), user_id=user_id, model_id=application.get('model_id'),
dataset_setting=application.get('dataset_setting'),
model_setting=application.get('model_setting'),
problem_optimization=application.get('problem_optimization')
) )
@staticmethod @staticmethod
@ -267,7 +290,7 @@ class ApplicationSerializer(serializers.Serializer):
class ApplicationModel(serializers.ModelSerializer): class ApplicationModel(serializers.ModelSerializer):
class Meta: class Meta:
model = Application model = Application
fields = ['id', 'name', 'desc', 'prologue', 'example', 'dialogue_number'] fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number']
class Operate(serializers.Serializer): class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True) application_id = serializers.UUIDField(required=True)
@ -317,8 +340,9 @@ class ApplicationSerializer(serializers.Serializer):
model = QuerySet(Model).get(id=instance.get('model_id') if 'model_id' in instance else application.model_id) model = QuerySet(Model).get(id=instance.get('model_id') if 'model_id' in instance else application.model_id)
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'example', 'status', update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
'api_key_is_active'] 'dataset_setting', 'model_setting', 'problem_optimization'
'api_key_is_active']
for update_key in update_keys: for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None: if update_key in instance and instance.get(update_key) is not None:
if update_key == 'multiple_rounds_dialogue': if update_key == 'multiple_rounds_dialogue':

View File

@ -7,194 +7,181 @@
@desc: @desc:
""" """
import json import json
import uuid
from typing import List from typing import List
from uuid import UUID
from django.db.models import QuerySet
from django.http import StreamingHttpResponse
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage
from rest_framework import serializers, status
from django.core.cache import cache from django.core.cache import cache
from common import event from django.db.models import QuerySet
from common.config.embedding_config import VectorStore, EmbeddingModel from langchain.chat_models.base import BaseChatModel
from common.response import result from rest_framework import serializers
from dataset.models import Paragraph
from embedding.models import SourceType from application.chat_pipeline.pipeline_manage import PiplineManage
from setting.models.model_management import Model from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep
from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
BaseGenerateHumanMessageStep
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
from common.exception.app_exception import AppApiException
from common.util.rsa_util import decrypt
from common.util.split_model import flat_map
from dataset.models import Paragraph, Document
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
chat_cache = cache chat_cache = cache
class MessageManagement:
@staticmethod
def get_message(title: str, content: str, message: str):
if content is None:
return HumanMessage(content=message)
return HumanMessage(content=(
f'已知信息:{title}:{content} '
'根据上述已知信息,请简洁和专业的来回答用户的问题。已知信息中的图片、链接地址和脚本语言请直接返回。如果无法从已知信息中得到答案,请说 “没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作” 或 “根据已知信息无法回答该问题,建议联系官方技术支持人员”,不允许在答案中添加编造成分,答案请使用中文。'
f'问题是:{message}'))
class ChatMessage:
def __init__(self, id: str, problem: str, title: str, paragraph: str, embedding_id: str, dataset_id: str,
document_id: str,
paragraph_id,
source_type: SourceType,
source_id: str,
answer: str,
message_tokens: int,
answer_token: int,
chat_model=None,
chat_message=None):
self.id = id
self.problem = problem
self.title = title
self.paragraph = paragraph
self.embedding_id = embedding_id
self.dataset_id = dataset_id
self.document_id = document_id
self.paragraph_id = paragraph_id
self.source_type = source_type
self.source_id = source_id
self.answer = answer
self.message_tokens = message_tokens
self.answer_token = answer_token
self.chat_model = chat_model
self.chat_message = chat_message
def get_chat_message(self):
return MessageManagement.get_message(self.problem, self.paragraph, self.problem)
class ChatInfo: class ChatInfo:
def __init__(self, def __init__(self,
chat_id: str, chat_id: str,
model: Model,
chat_model: BaseChatModel, chat_model: BaseChatModel,
application_id: str | None,
dataset_id_list: List[str], dataset_id_list: List[str],
exclude_document_id_list: list[str], exclude_document_id_list: list[str],
dialogue_number: int): application: Application):
"""
:param chat_id: 对话id
:param chat_model: 对话模型
:param dataset_id_list: 数据集列表
:param exclude_document_id_list: 排除的文档
:param application: 应用信息
"""
self.chat_id = chat_id self.chat_id = chat_id
self.application_id = application_id self.application = application
self.model = model
self.chat_model = chat_model self.chat_model = chat_model
self.dataset_id_list = dataset_id_list self.dataset_id_list = dataset_id_list
self.exclude_document_id_list = exclude_document_id_list self.exclude_document_id_list = exclude_document_id_list
self.dialogue_number = dialogue_number self.chat_record_list: List[ChatRecord] = []
self.chat_message_list: List[ChatMessage] = []
def append_chat_message(self, chat_message: ChatMessage): def to_base_pipeline_manage_params(self):
self.chat_message_list.append(chat_message) dataset_setting = self.application.dataset_setting
if self.application_id is not None: model_setting = self.application.model_setting
return {
'dataset_id_list': self.dataset_id_list,
'exclude_document_id_list': self.exclude_document_id_list,
'exclude_paragraph_id_list': [],
'top_n': dataset_setting.get('top_n') if 'top_n' in dataset_setting else 3,
'similarity': dataset_setting.get('similarity') if 'similarity' in dataset_setting else 0.6,
'max_paragraph_char_number': dataset_setting.get(
'max_paragraph_char_number') if 'max_paragraph_char_number' in dataset_setting else 5000,
'history_chat_record': self.chat_record_list,
'chat_id': self.chat_id,
'dialogue_number': self.application.dialogue_number,
'prompt': model_setting.get(
'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
'chat_model': self.chat_model,
'model_id': self.application.model.id if self.application.model is not None else None,
'problem_optimization': self.application.problem_optimization
}
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
exclude_paragraph_id_list):
params = self.to_base_pipeline_manage_params()
return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
'exclude_paragraph_id_list': exclude_paragraph_id_list}
def append_chat_record(self, chat_record: ChatRecord):
# 存入缓存中
self.chat_record_list.append(chat_record)
if self.application.id is not None:
# 插入数据库 # 插入数据库
event.ListenerChatMessage.record_chat_message_signal.send( if not QuerySet(Chat).filter(id=self.chat_id).exists():
event.RecordChatMessageArgs(len(self.chat_message_list) - 1, self.chat_id, self.application_id, Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text).save()
chat_message) # 插入会话记录
) chat_record.save()
# 异步更新token
event.ListenerChatMessage.update_chat_message_token_signal.send(chat_message)
def get_context_message(self):
start_index = len(self.chat_message_list) - self.dialogue_number def get_post_handler(chat_info: ChatInfo):
return [self.chat_message_list[index].get_chat_message() for index in class PostHandler(PostResponseHandler):
range(start_index if start_index > 0 else 0, len(self.chat_message_list))]
def handler(self,
chat_id: UUID,
chat_record_id,
paragraph_list: List[Paragraph],
problem_text: str,
answer_text,
manage: PiplineManage,
step: BaseChatStep,
padding_problem_text: str = None,
**kwargs):
chat_record = ChatRecord(id=chat_record_id,
chat_id=chat_id,
paragraph_id_list=[str(p.id) for p in paragraph_list],
problem_text=problem_text,
answer_text=answer_text,
details=manage.get_details(),
message_tokens=manage.context['message_tokens'],
answer_tokens=manage.context['answer_tokens'],
run_time=manage.context['run_time'],
index=len(chat_info.chat_record_list) + 1)
chat_info.append_chat_record(chat_record)
# 重新设置缓存
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)
return PostHandler()
class ChatMessageSerializer(serializers.Serializer): class ChatMessageSerializer(serializers.Serializer):
chat_id = serializers.UUIDField(required=True) chat_id = serializers.UUIDField(required=True)
def chat(self, message): def chat(self, message, re_chat: bool):
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
chat_id = self.data.get('chat_id') chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id) chat_info: ChatInfo = chat_cache.get(chat_id)
if chat_info is None: if chat_info is None:
return result.Result(response_status=status.HTTP_404_NOT_FOUND, code=404, message="会话过期") chat_info = self.re_open_chat(chat_id)
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)
chat_model = chat_info.chat_model pipline_manage_builder = PiplineManage.builder()
vector = VectorStore.get_embedding_vector() # 如果开启了问题优化,则添加上问题优化步骤
# 向量库检索 if chat_info.application.problem_optimization:
_value = vector.search(message, chat_info.dataset_id_list, chat_info.exclude_document_id_list, pipline_manage_builder.append_step(BaseResetProblemStep)
[chat_message.embedding_id for chat_message in # 构建流水线管理器
(list(filter(lambda row: row.problem == message, chat_info.chat_message_list)))], pipline_message = (pipline_manage_builder.append_step(BaseSearchDatasetStep)
True, .append_step(BaseGenerateHumanMessageStep)
EmbeddingModel.get_embedding_model()) .append_step(BaseChatStep)
# 查询段落id详情 .build())
paragraph = None exclude_paragraph_id_list = []
if _value is not None: # 相同问题是否需要排除已经查询到的段落
paragraph = QuerySet(Paragraph).get(id=_value.get('paragraph_id')) if re_chat:
if paragraph is None: paragraph_id_list = flat_map([row.paragraph_id_list for row in
vector.delete_by_paragraph_id(_value.get('paragraph_id')) filter(lambda chat_record: chat_record == message,
chat_info.chat_record_list)])
exclude_paragraph_id_list = list(set(paragraph_id_list))
# 构建运行参数
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list)
# 运行流水线作业
pipline_message.run(params)
return pipline_message.context['chat_result']
title, content = (None, None) if paragraph is None else (paragraph.title, paragraph.content) @staticmethod
_id = str(uuid.uuid1()) def re_open_chat(chat_id: str):
chat = QuerySet(Chat).filter(id=chat_id).first()
if chat is None:
raise AppApiException(500, "会话不存在")
application = QuerySet(Application).filter(id=chat.application_id).first()
if application is None:
raise AppApiException(500, "应用不存在")
model = QuerySet(Model).filter(id=application.model_id).first()
chat_model = None
if model is not None:
# 对话模型
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
streaming=True)
# 数据集id列表
dataset_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationDatasetMapping).filter(
application_id=application.id)]
embedding_id, dataset_id, document_id, paragraph_id, source_type, source_id = (_value.get( # 需要排除的文档
'id'), _value.get( exclude_document_id_list = [str(document.id) for document in
'dataset_id'), _value.get( QuerySet(Document).filter(
'document_id'), _value.get( dataset_id__in=dataset_id_list,
'paragraph_id'), _value.get( is_active=False)]
'source_type'), _value.get( return ChatInfo(chat_id, chat_model, dataset_id_list, exclude_document_id_list, application)
'source_id')) if _value is not None else (None, None, None, None, None, None)
if chat_model is None:
def event_block_content(c: str):
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
'is_end': True,
'content': c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~'}) + "\n\n"
chat_info.append_chat_message(
ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
paragraph_id,
source_type,
source_id,
c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~',
0,
0))
# 重新设置缓存
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)
r = StreamingHttpResponse(streaming_content=event_block_content(content),
content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache'
return r
# 获取上下文
history_message = chat_info.get_context_message()
# 构建会话请求问题
chat_message = [*history_message, MessageManagement.get_message(title, content, message)]
# 对话
result_data = chat_model.stream(chat_message)
def event_content(response):
all_text = ''
try:
for chunk in response:
all_text += chunk.content
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
'content': chunk.content, 'is_end': False}) + "\n\n"
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
'content': '', 'is_end': True}) + "\n\n"
chat_info.append_chat_message(
ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
paragraph_id,
source_type,
source_id, all_text,
0,
0,
chat_message=chat_message, chat_model=chat_model))
# 重新设置缓存
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)
except Exception as e:
yield e
r = StreamingHttpResponse(streaming_content=event_content(result_data),
content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache'
return r

View File

@ -10,7 +10,8 @@ import datetime
import json import json
import os import os
import uuid import uuid
from typing import Dict from functools import reduce
from typing import Dict, List
from django.core.cache import cache from django.core.cache import cache
from django.db import transaction from django.db import transaction
@ -18,7 +19,8 @@ from django.db.models import QuerySet
from rest_framework import serializers from rest_framework import serializers
from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
from application.serializers.application_serializers import ModelDatasetAssociation from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
ModelSettingSerializer
from application.serializers.chat_message_serializers import ChatInfo from application.serializers.chat_message_serializers import ChatInfo
from common.db.search import native_search, native_page_search, page_search from common.db.search import native_search, native_page_search, page_search
from common.event import ListenerManagement from common.event import ListenerManagement
@ -26,8 +28,8 @@ from common.exception.app_exception import AppApiException
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock from common.util.lock import try_lock, un_lock
from common.util.rsa_util import decrypt from common.util.rsa_util import decrypt
from common.util.split_model import flat_map
from dataset.models import Document, Problem, Paragraph from dataset.models import Document, Problem, Paragraph
from embedding.models import SourceType, Embedding
from setting.models import Model from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from smartdoc.conf import PROJECT_DIR from smartdoc.conf import PROJECT_DIR
@ -106,12 +108,12 @@ class ChatSerializers(serializers.Serializer):
chat_id = str(uuid.uuid1()) chat_id = str(uuid.uuid1())
chat_cache.set(chat_id, chat_cache.set(chat_id,
ChatInfo(chat_id, model, chat_model, application_id, dataset_id_list, ChatInfo(chat_id, chat_model, dataset_id_list,
[str(document.id) for document in [str(document.id) for document in
QuerySet(Document).filter( QuerySet(Document).filter(
dataset_id__in=dataset_id_list, dataset_id__in=dataset_id_list,
is_active=False)], is_active=False)],
application.dialogue_number), timeout=60 * 30) application), timeout=60 * 30)
return chat_id return chat_id
class OpenTempChat(serializers.Serializer): class OpenTempChat(serializers.Serializer):
@ -122,6 +124,12 @@ class ChatSerializers(serializers.Serializer):
multiple_rounds_dialogue = serializers.BooleanField(required=True) multiple_rounds_dialogue = serializers.BooleanField(required=True)
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True)) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
# 数据集相关设置
dataset_setting = DatasetSettingSerializer(required=True)
# 模型相关设置
model_setting = ModelSettingSerializer(required=True)
# 问题补全
problem_optimization = serializers.BooleanField(required=True)
def is_valid(self, *, raise_exception=False): def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True) super().is_valid(raise_exception=True)
@ -140,42 +148,62 @@ class ChatSerializers(serializers.Serializer):
json.loads( json.loads(
decrypt(model.credential)), decrypt(model.credential)),
streaming=True) streaming=True)
application = Application(id=None, dialogue_number=3, model=model,
dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'),
problem_optimization=self.data.get('problem_optimization'))
chat_cache.set(chat_id, chat_cache.set(chat_id,
ChatInfo(chat_id, model, chat_model, None, dataset_id_list, ChatInfo(chat_id, chat_model, dataset_id_list,
[str(document.id) for document in [str(document.id) for document in
QuerySet(Document).filter( QuerySet(Document).filter(
dataset_id__in=dataset_id_list, dataset_id__in=dataset_id_list,
is_active=False)], is_active=False)],
3 if self.data.get('multiple_rounds_dialogue') else 1), timeout=60 * 30) application), timeout=60 * 30)
return chat_id return chat_id
def vote_exec(source_type: SourceType, source_id: str, field: str, post_handler):
if source_type == SourceType.PROBLEM:
problem = QuerySet(Problem).get(id=source_id)
if problem is not None:
problem.__setattr__(field, post_handler(problem))
problem.save()
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
embedding.__setattr__(field, problem.__getattribute__(field))
embedding.save()
if source_type == SourceType.PARAGRAPH:
paragraph = QuerySet(Paragraph).get(id=source_id)
if paragraph is not None:
paragraph.__setattr__(field, post_handler(paragraph))
paragraph.save()
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
embedding.__setattr__(field, paragraph.__getattribute__(field))
embedding.save()
class ChatRecordSerializerModel(serializers.ModelSerializer): class ChatRecordSerializerModel(serializers.ModelSerializer):
class Meta: class Meta:
model = ChatRecord model = ChatRecord
fields = "__all__" fields = ['id', 'chat_id', 'vote_status', 'problem_text', 'answer_text',
'message_tokens', 'answer_tokens', 'const', 'improve_paragraph_id_list', 'run_time', 'index']
class ChatRecordSerializer(serializers.Serializer): class ChatRecordSerializer(serializers.Serializer):
class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
chat_record_id = serializers.UUIDField(required=True)
def one(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
chat_record_id = self.data.get('chat_record_id')
chat_id = self.data.get('chat_id')
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
dataset_list = []
paragraph_list = []
if len(chat_record.paragraph_id_list) > 0:
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=chat_record.paragraph_id_list),
get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
'list_dataset_paragraph_by_paragraph_id.sql')),
with_table_name=True)
dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
[{row.get(
'dataset_id'): row.get(
"dataset_name")} for
row in
paragraph_list],
{}).items()]
return {
**ChatRecordSerializerModel(chat_record).data,
'padding_problem_text': chat_record.details.get(
'padding_problem_text') if 'problem_padding' in chat_record.details else None,
'dataset_list': dataset_list,
'paragraph_list': paragraph_list}
class Query(serializers.Serializer): class Query(serializers.Serializer):
application_id = serializers.UUIDField(required=True) application_id = serializers.UUIDField(required=True)
chat_id = serializers.UUIDField(required=True) chat_id = serializers.UUIDField(required=True)
@ -183,15 +211,57 @@ class ChatRecordSerializer(serializers.Serializer):
def list(self, with_valid=True): def list(self, with_valid=True):
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))
return [ChatRecordSerializerModel(chat_record).data for chat_record in return [ChatRecordSerializerModel(chat_record).data for chat_record in
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))] QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))]
def reset_chat_record_list(self, chat_record_list: List[ChatRecord]):
paragraph_id_list = flat_map([chat_record.paragraph_id_list for chat_record in chat_record_list])
# 去重
paragraph_id_list = list(set(paragraph_id_list))
paragraph_list = self.search_paragraph(paragraph_id_list)
return [self.reset_chat_record(chat_record, paragraph_list) for chat_record in chat_record_list]
@staticmethod
def search_paragraph(paragraph_id_list: List[str]):
paragraph_list = []
if len(paragraph_id_list) > 0:
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
'list_dataset_paragraph_by_paragraph_id.sql')),
with_table_name=True)
return paragraph_list
@staticmethod
def reset_chat_record(chat_record, all_paragraph_list):
paragraph_list = list(
filter(lambda paragraph: chat_record.paragraph_id_list.__contains__(str(paragraph.get('id'))),
all_paragraph_list))
dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
[{row.get(
'dataset_id'): row.get(
"dataset_name")} for
row in
paragraph_list],
{}).items()]
return {
**ChatRecordSerializerModel(chat_record).data,
'padding_problem_text': chat_record.details.get('problem_padding').get(
'padding_problem_text') if 'problem_padding' in chat_record.details else None,
'dataset_list': dataset_list,
'paragraph_list': paragraph_list
}
def page(self, current_page: int, page_size: int, with_valid=True): def page(self, current_page: int, page_size: int, with_valid=True):
if with_valid: if with_valid:
self.is_valid(raise_exception=True) self.is_valid(raise_exception=True)
return page_search(current_page, page_size, page = page_search(current_page, page_size,
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"), QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"),
post_records_handler=lambda chat_record: ChatRecordSerializerModel(chat_record).data) post_records_handler=lambda chat_record: chat_record)
records = page.get('records')
page['records'] = self.reset_chat_record_list(records)
return page
class Vote(serializers.Serializer): class Vote(serializers.Serializer):
chat_id = serializers.UUIDField(required=True) chat_id = serializers.UUIDField(required=True)
@ -216,38 +286,20 @@ class ChatRecordSerializer(serializers.Serializer):
if vote_status == VoteChoices.STAR: if vote_status == VoteChoices.STAR:
# 点赞 # 点赞
chat_record_details_model.vote_status = VoteChoices.STAR chat_record_details_model.vote_status = VoteChoices.STAR
# 点赞数量 +1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'star_num',
lambda r: r.star_num + 1)
if vote_status == VoteChoices.TRAMPLE: if vote_status == VoteChoices.TRAMPLE:
# 点踩 # 点踩
chat_record_details_model.vote_status = VoteChoices.TRAMPLE chat_record_details_model.vote_status = VoteChoices.TRAMPLE
# 点踩数量+1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'trample_num',
lambda r: r.trample_num + 1)
chat_record_details_model.save() chat_record_details_model.save()
else: else:
if vote_status == VoteChoices.UN_VOTE: if vote_status == VoteChoices.UN_VOTE:
# 取消点赞 # 取消点赞
chat_record_details_model.vote_status = VoteChoices.UN_VOTE chat_record_details_model.vote_status = VoteChoices.UN_VOTE
chat_record_details_model.save() chat_record_details_model.save()
if chat_record_details_model.vote_status == VoteChoices.STAR:
# 点赞数量 -1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'star_num', lambda r: r.star_num - 1)
if chat_record_details_model.vote_status == VoteChoices.TRAMPLE:
# 点踩数量 -1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'trample_num', lambda r: r.trample_num - 1)
else: else:
raise AppApiException(500, "已经投票过,请先取消后再进行投票") raise AppApiException(500, "已经投票过,请先取消后再进行投票")
finally: finally:
un_lock(self.data.get('chat_record_id')) un_lock(self.data.get('chat_record_id'))
return True return True
class ImproveSerializer(serializers.Serializer): class ImproveSerializer(serializers.Serializer):

View File

@ -1,4 +1,4 @@
SELECT * FROM ( SELECT * FROM application ${application_custom_sql} UNION SELECT *,to_json(dataset_setting) as dataset_setting,to_json(model_setting) as model_setting FROM ( SELECT * FROM application ${application_custom_sql} UNION
SELECT SELECT
* *
FROM FROM

View File

@ -0,0 +1,6 @@
SELECT
paragraph.*,
dataset."name" AS "dataset_name"
FROM
paragraph paragraph
LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id

View File

@ -58,6 +58,7 @@ class ApplicationApi(ApiMixin):
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'), 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间'), 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间'),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING), items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联知识库Id列表", title="关联知识库Id列表",
@ -133,7 +134,7 @@ class ApplicationApi(ApiMixin):
def get_request_body_api(): def get_request_body_api():
return openapi.Schema( return openapi.Schema(
type=openapi.TYPE_OBJECT, type=openapi.TYPE_OBJECT,
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'], required=[],
properties={ properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"), 'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
@ -141,11 +142,52 @@ class ApplicationApi(ApiMixin):
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话", "multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
description="是否开启多轮对话"), description="是否开启多轮对话"),
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"), 'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
title="示例列表", description="示例列表"),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING), items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联知识库Id列表", description="关联知识库Id列表"), title="关联知识库Id列表", description="关联知识库Id列表"),
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
description="是否开启问题优化", default=True)
}
)
class DatasetSetting(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=[''],
properties={
'top_n': openapi.Schema(type=openapi.TYPE_NUMBER, title="引用分段数", description="引用分段数",
default=5),
'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title='相似度', description="相似度",
default=0.6),
'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='最多引用字符数',
description="最多引用字符数", default=3000),
}
)
class ModelSetting(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['prompt'],
properties={
'prompt': openapi.Schema(type=openapi.TYPE_STRING, title="提示词", description="提示词",
default=('已知信息:'
'\n{data}'
'\n回答要求:'
'\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。'
'\n- 避免提及你是从<data></data>中获得的知识。'
'\n- 请保持答案与<data></data>中描述的一致。'
'\n- 请使用markdown 语法优化答案的格式。'
'\n- <data></data>中的图片链接、链接地址和脚本语言请完整返回。'
'\n- 请使用与问题相同的语言来回答。'
'\n问题:'
'\n{question}')),
} }
) )
@ -155,7 +197,8 @@ class ApplicationApi(ApiMixin):
def get_request_body_api(): def get_request_body_api():
return openapi.Schema( return openapi.Schema(
type=openapi.TYPE_OBJECT, type=openapi.TYPE_OBJECT,
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'], required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting',
'problem_optimization'],
properties={ properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"), 'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
@ -163,11 +206,13 @@ class ApplicationApi(ApiMixin):
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话", "multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
description="是否开启多轮对话"), description="是否开启多轮对话"),
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"), 'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
title="示例列表", description="示例列表"),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING), items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联知识库Id列表", description="关联知识库Id列表") title="关联知识库Id列表", description="关联知识库Id列表"),
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
description="是否开启问题优化", default=True)
} }
) )

View File

@ -8,6 +8,7 @@
""" """
from drf_yasg import openapi from drf_yasg import openapi
from application.swagger_api.application_api import ApplicationApi
from common.mixins.api_mixin import ApiMixin from common.mixins.api_mixin import ApiMixin
@ -19,6 +20,7 @@ class ChatApi(ApiMixin):
required=['message'], required=['message'],
properties={ properties={
'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"), 'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default="重新生成")
} }
) )
@ -73,14 +75,19 @@ class ChatApi(ApiMixin):
def get_request_body_api(): def get_request_body_api():
return openapi.Schema( return openapi.Schema(
type=openapi.TYPE_OBJECT, type=openapi.TYPE_OBJECT,
required=['model_id', 'multiple_rounds_dialogue'], required=['model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting',
'problem_optimization'],
properties={ properties={
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"), 'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING), items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联知识库Id列表", description="关联知识库Id列表"), title="关联知识库Id列表", description="关联知识库Id列表"),
'multiple_rounds_dialogue': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮会话", 'multiple_rounds_dialogue': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮会话",
description="是否开启多轮会话") description="是否开启多轮会话"),
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
description="是否开启问题优化", default=True)
} }
) )

View File

@ -25,6 +25,8 @@ urlpatterns = [
path('application/<str:application_id>/chat/<chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()), path('application/<str:application_id>/chat/<chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()),
path('application/<str:application_id>/chat/<chat_id>/chat_record/<int:current_page>/<int:page_size>', path('application/<str:application_id>/chat/<chat_id>/chat_record/<int:current_page>/<int:page_size>',
views.ChatView.ChatRecord.Page.as_view()), views.ChatView.ChatRecord.Page.as_view()),
path('application/<str:application_id>/chat/<chat_id>/chat_record/<chat_record_id>',
views.ChatView.ChatRecord.Operate.as_view()),
path('application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/vote', path('application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/vote',
views.ChatView.ChatRecord.Vote.as_view(), views.ChatView.ChatRecord.Vote.as_view(),
name=''), name=''),

View File

@ -71,7 +71,8 @@ class ChatView(APIView):
dynamic_tag=keywords.get('application_id'))]) dynamic_tag=keywords.get('application_id'))])
) )
def post(self, request: Request, chat_id: str): def post(self, request: Request, chat_id: str):
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message')) return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), request.data.get(
're_chat') if 're_chat' in request.data else False)
@action(methods=['GET'], detail=False) @action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话列表", @swagger_auto_schema(operation_summary="获取对话列表",
@ -134,6 +135,27 @@ class ChatView(APIView):
class ChatRecord(APIView): class ChatRecord(APIView):
authentication_classes = [TokenAuth] authentication_classes = [TokenAuth]
class Operate(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话记录详情",
operation_id="获取对话记录详情",
manual_parameters=ChatRecordApi.get_request_params_api(),
responses=result.get_api_array_response(ChatRecordApi.get_response_body_api()),
tags=["应用/对话日志"]
)
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))])
)
def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str):
return result.success(ChatRecordSerializer.Operate(
data={'application_id': application_id,
'chat_id': chat_id,
'chat_record_id': chat_record_id}).one())
@action(methods=['GET'], detail=False) @action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话记录列表", @swagger_auto_schema(operation_summary="获取对话记录列表",
operation_id="获取对话记录列表", operation_id="获取对话记录列表",

View File

@ -83,7 +83,7 @@ def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, s
field_replace_dict = get_field_replace_dict(queryset) field_replace_dict = get_field_replace_dict(queryset)
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
field_replace_dict=field_replace_dict) field_replace_dict=field_replace_dict)
sql, params = app_sql_compiler.get_query_str(with_table_name) sql, params = app_sql_compiler.get_query_str(with_table_name=with_table_name)
return sql, params return sql, params

View File

@ -7,10 +7,8 @@
@desc: @desc:
""" """
from .listener_manage import * from .listener_manage import *
from .listener_chat_message import *
def run(): def run():
listener_manage.ListenerManagement().run() listener_manage.ListenerManagement().run()
listener_chat_message.ListenerChatMessage().run()
QuerySet(Document).filter(status=Status.embedding).update(**{'status': Status.error}) QuerySet(Document).filter(status=Status.embedding).update(**{'status': Status.error})

View File

@ -1,67 +0,0 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file listener_manage.py
@date2023/10/20 14:01
@desc:
"""
import logging
from blinker import signal
from django.db.models import QuerySet
from application.models import ChatRecord, Chat
from application.serializers.chat_message_serializers import ChatMessage
from common.event.common import poxy
class RecordChatMessageArgs:
def __init__(self, index: int, chat_id: str, application_id: str, chat_message: ChatMessage):
self.index = index
self.chat_id = chat_id
self.application_id = application_id
self.chat_message = chat_message
class ListenerChatMessage:
record_chat_message_signal = signal("record_chat_message")
update_chat_message_token_signal = signal("update_chat_message_token")
@staticmethod
def record_chat_message(args: RecordChatMessageArgs):
if not QuerySet(Chat).filter(id=args.chat_id).exists():
Chat(id=args.chat_id, application_id=args.application_id, abstract=args.chat_message.problem).save()
# 插入会话记录
try:
chat_record = ChatRecord(
id=args.chat_message.id,
chat_id=args.chat_id,
dataset_id=args.chat_message.dataset_id,
paragraph_id=args.chat_message.paragraph_id,
source_id=args.chat_message.source_id,
source_type=args.chat_message.source_type,
problem_text=args.chat_message.problem,
answer_text=args.chat_message.answer,
index=args.index,
message_tokens=args.chat_message.message_tokens,
answer_tokens=args.chat_message.answer_token)
chat_record.save()
except Exception as e:
print(e)
@staticmethod
@poxy
def update_token(chat_message: ChatMessage):
if chat_message.chat_model is not None:
logging.getLogger("max_kb").info("开始更新token")
message_token = chat_message.chat_model.get_num_tokens_from_messages(chat_message.chat_message)
answer_token = chat_message.chat_model.get_num_tokens(chat_message.answer)
# 修改token数量
QuerySet(ChatRecord).filter(id=chat_message.id).update(
**{'message_tokens': message_token, 'answer_tokens': answer_token})
def run(self):
# 记录会话
ListenerChatMessage.record_chat_message_signal.connect(self.record_chat_message)
ListenerChatMessage.update_chat_message_token_signal.connect(self.update_token)

View File

@ -0,0 +1,34 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file common.py
@date2024/1/11 18:44
@desc:
"""
from rest_framework import serializers
class InstanceField(serializers.Field):
def __init__(self, model_type, **kwargs):
self.model_type = model_type
super().__init__(**kwargs)
def to_internal_value(self, data):
if not isinstance(data, self.model_type):
self.fail('message类型错误', value=data)
return data
def to_representation(self, value):
return value
class FunctionField(serializers.Field):
def to_internal_value(self, data):
if not callable(data):
self.fail('不是一個函數', value=data)
return data
def to_representation(self, value):
return value

View File

@ -5,9 +5,7 @@ SELECT
problem.dataset_id AS dataset_id, problem.dataset_id AS dataset_id,
0 AS source_type, 0 AS source_type,
problem."content" AS "text", problem."content" AS "text",
paragraph.is_active AS is_active, paragraph.is_active AS is_active
problem.star_num as star_num,
problem.trample_num as trample_num
FROM FROM
problem problem problem problem
LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
@ -23,9 +21,7 @@ SELECT
concat_ws(' concat_ws('
',concat_ws(' ',concat_ws('
',paragraph.title,paragraph."content"),paragraph.title) AS "text", ',paragraph.title,paragraph."content"),paragraph.title) AS "text",
paragraph.is_active AS is_active, paragraph.is_active AS is_active
paragraph.star_num as star_num,
paragraph.trample_num as trample_num
FROM FROM
paragraph paragraph paragraph paragraph

View File

@ -0,0 +1,37 @@
# Generated by Django 4.1.10 on 2024-01-16 11:22
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('dataset', '0003_alter_paragraph_content'),
]
operations = [
migrations.RemoveField(
model_name='paragraph',
name='hit_num',
),
migrations.RemoveField(
model_name='paragraph',
name='star_num',
),
migrations.RemoveField(
model_name='paragraph',
name='trample_num',
),
migrations.RemoveField(
model_name='problem',
name='hit_num',
),
migrations.RemoveField(
model_name='problem',
name='star_num',
),
migrations.RemoveField(
model_name='problem',
name='trample_num',
),
]

View File

@ -74,9 +74,6 @@ class Paragraph(AppModelMixin):
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
content = models.CharField(max_length=4096, verbose_name="段落内容") content = models.CharField(max_length=4096, verbose_name="段落内容")
title = models.CharField(max_length=256, verbose_name="标题", default="") title = models.CharField(max_length=256, verbose_name="标题", default="")
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
star_num = models.IntegerField(verbose_name="点赞数", default=0)
trample_num = models.IntegerField(verbose_name="点踩数", default=0)
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices, status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
default=Status.embedding) default=Status.embedding)
is_active = models.BooleanField(default=True) is_active = models.BooleanField(default=True)
@ -94,9 +91,6 @@ class Problem(AppModelMixin):
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False) dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False) paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
content = models.CharField(max_length=256, verbose_name="问题内容") content = models.CharField(max_length=256, verbose_name="问题内容")
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
star_num = models.IntegerField(verbose_name="点赞数", default=0)
trample_num = models.IntegerField(verbose_name="点踩数", default=0)
class Meta: class Meta:
db_table = "problem" db_table = "problem"

View File

@ -26,12 +26,11 @@ from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore, EmbeddingModel from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import get_dynamics_model, native_page_search, native_search from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list from common.db.sql_execute import select_list
from common.event.listener_manage import ListenerManagement, SyncWebDatasetArgs
from common.exception.app_exception import AppApiException from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin from common.mixins.api_mixin import ApiMixin
from common.util.common import post from common.util.common import post
from common.util.file_util import get_file_content from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork, ForkManage from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
from dataset.serializers.common_serializers import list_paragraph from dataset.serializers.common_serializers import list_paragraph
@ -286,7 +285,8 @@ class DataSetSerializers(serializers.ModelSerializer):
properties={ properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), 'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", description="web站点url"), 'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url",
description="web站点url"),
'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器") 'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
} }
) )
@ -369,7 +369,8 @@ class DataSetSerializers(serializers.ModelSerializer):
dataset_id = uuid.uuid1() dataset_id = uuid.uuid1()
dataset = DataSet( dataset = DataSet(
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id, **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
'type': Type.web, 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}}) 'type': Type.web,
'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}})
dataset.save() dataset.save()
ListenerManagement.sync_web_dataset_signal.send( ListenerManagement.sync_web_dataset_signal.send(
SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'), SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),

View File

@ -28,7 +28,7 @@ from dataset.serializers.problem_serializers import ProblemInstanceSerializer, P
class ParagraphSerializer(serializers.ModelSerializer): class ParagraphSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Paragraph model = Paragraph
fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title', fields = ['id', 'content', 'is_active', 'document_id', 'title',
'create_time', 'update_time'] 'create_time', 'update_time']

View File

@ -24,7 +24,7 @@ from embedding.vector.pg_vector import PGVector
class ProblemSerializer(serializers.ModelSerializer): class ProblemSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Problem model = Problem
fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id', 'document_id', fields = ['id', 'content', 'dataset_id', 'document_id',
'create_time', 'update_time'] 'create_time', 'update_time']
@ -77,8 +77,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
'document_id': self.data.get('document_id'), 'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'), 'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'), 'dataset_id': self.data.get('dataset_id'),
'star_num': 0,
'trample_num': 0}) })
return ProblemSerializers.Operate( return ProblemSerializers.Operate(
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'), data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),

View File

@ -0,0 +1,37 @@
# Generated by Django 4.1.10 on 2024-01-16 11:22
import django.contrib.postgres.fields
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('embedding', '0001_initial'),
]
operations = [
migrations.RemoveField(
model_name='embedding',
name='star_num',
),
migrations.RemoveField(
model_name='embedding',
name='trample_num',
),
migrations.AddField(
model_name='embedding',
name='keywords',
field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=256), default=list, size=None, verbose_name='关键词列表'),
),
migrations.AddField(
model_name='embedding',
name='meta',
field=models.JSONField(default=dict, verbose_name='元数据'),
),
migrations.AlterField(
model_name='embedding',
name='source_type',
field=models.CharField(choices=[('0', '问题'), ('1', '段落'), ('2', '标题')], default='0', max_length=5, verbose_name='资源类型'),
),
]

View File

@ -6,6 +6,7 @@
@date2023/9/21 15:46 @date2023/9/21 15:46
@desc: @desc:
""" """
from django.contrib.postgres.fields import ArrayField
from django.db import models from django.db import models
from common.field.vector_field import VectorField from common.field.vector_field import VectorField
@ -16,6 +17,7 @@ class SourceType(models.TextChoices):
"""订单类型""" """订单类型"""
PROBLEM = 0, '问题' PROBLEM = 0, '问题'
PARAGRAPH = 1, '段落' PARAGRAPH = 1, '段落'
TITLE = 2, '标题'
class Embedding(models.Model): class Embedding(models.Model):
@ -36,10 +38,10 @@ class Embedding(models.Model):
embedding = VectorField(verbose_name="向量") embedding = VectorField(verbose_name="向量")
star_num = models.IntegerField(default=0, verbose_name="点赞数量") keywords = ArrayField(verbose_name="关键词列表",
base_field=models.CharField(max_length=256), default=list)
trample_num = models.IntegerField(default=0, meta = models.JSONField(verbose_name="元数据", default=dict)
verbose_name="点踩数量")
class Meta: class Meta:
db_table = "embedding" db_table = "embedding"

View File

@ -1,15 +1,17 @@
SELECT * FROM (SELECT SELECT
*, paragraph_id,
( 1 - ( embedding.embedding <=> %s ) ) AS similarity, comprehensive_score,
CASE comprehensive_score as similarity
WHEN embedding.star_num - embedding.trample_num = 0 THEN
0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
END AS score
FROM FROM
embedding, (
( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs SELECT DISTINCT ON
${embedding_query} ("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
) temp FROM
WHERE similarity>0.5 ( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query}) TEMP
ORDER BY (similarity + score) DESC LIMIT 1 ORDER BY
paragraph_id,
similarity DESC
) DISTINCT_TEMP
WHERE comprehensive_score>%s
ORDER BY comprehensive_score DESC
LIMIT %s

View File

@ -1,34 +1,17 @@
SELECT SELECT
similarity,
paragraph_id, paragraph_id,
comprehensive_score comprehensive_score,
comprehensive_score as similarity
FROM FROM
( (
SELECT DISTINCT ON SELECT DISTINCT ON
( "paragraph_id" ) ( similarity + score ),*, ("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
( similarity + score ) AS comprehensive_score
FROM FROM
( ( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query} ) TEMP
SELECT
*,
( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
CASE
WHEN embedding.star_num - embedding.trample_num = 0 THEN
0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
END AS score
FROM
embedding,
( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
${embedding_query}
) TEMP
WHERE
similarity > %s
ORDER BY ORDER BY
paragraph_id, paragraph_id,
( similarity + score ) similarity DESC
DESC ) DISTINCT_TEMP
) ss WHERE comprehensive_score>%s
ORDER BY ORDER BY comprehensive_score DESC
comprehensive_score DESC LIMIT %s
LIMIT %s

View File

@ -99,8 +99,6 @@ class BaseVectorStore(ABC):
@abstractmethod @abstractmethod
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool, is_active: bool,
star_num: int,
trample_num: int,
embedding: HuggingFaceEmbeddings): embedding: HuggingFaceEmbeddings):
pass pass
@ -108,11 +106,20 @@ class BaseVectorStore(ABC):
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings): def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
pass pass
@abstractmethod
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_id_list: list[str], exclude_paragraph_list: list[str],
is_active: bool, is_active: bool,
embedding: HuggingFaceEmbeddings): embedding: HuggingFaceEmbeddings):
if dataset_id_list is None or len(dataset_id_list) == 0:
return []
embedding_query = embedding.embed_query(query_text)
result = self.query(embedding_query, dataset_id_list, exclude_document_id_list, exclude_paragraph_list,
is_active, 1, 0.65)
return result[0]
@abstractmethod
def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
pass pass
@abstractmethod @abstractmethod

View File

@ -33,8 +33,6 @@ class PGVector(BaseVectorStore):
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool, is_active: bool,
star_num: int,
trample_num: int,
embedding: HuggingFaceEmbeddings): embedding: HuggingFaceEmbeddings):
text_embedding = embedding.embed_query(text) text_embedding = embedding.embed_query(text)
embedding = Embedding(id=uuid.uuid1(), embedding = Embedding(id=uuid.uuid1(),
@ -44,10 +42,7 @@ class PGVector(BaseVectorStore):
paragraph_id=paragraph_id, paragraph_id=paragraph_id,
source_id=source_id, source_id=source_id,
embedding=text_embedding, embedding=text_embedding,
source_type=source_type, source_type=source_type)
star_num=star_num,
trample_num=trample_num
)
embedding.save() embedding.save()
return True return True
@ -61,8 +56,6 @@ class PGVector(BaseVectorStore):
is_active=text_list[index].get('is_active', True), is_active=text_list[index].get('is_active', True),
source_id=text_list[index].get('source_id'), source_id=text_list[index].get('source_id'),
source_type=text_list[index].get('source_type'), source_type=text_list[index].get('source_type'),
star_num=text_list[index].get('star_num'),
trample_num=text_list[index].get('trample_num'),
embedding=embeddings[index]) for index in embedding=embeddings[index]) for index in
range(0, len(text_list))] range(0, len(text_list))]
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
@ -78,29 +71,27 @@ class PGVector(BaseVectorStore):
'hit_test.sql')), 'hit_test.sql')),
with_table_name=True) with_table_name=True)
embedding_model = select_list(exec_sql, embedding_model = select_list(exec_sql,
[json.dumps(embedding_query), *exec_params, *exec_params, similarity, top_number]) [json.dumps(embedding_query), *exec_params, similarity, top_number])
return embedding_model return embedding_model
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_id_list: list[str], exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
is_active: bool,
embedding: HuggingFaceEmbeddings):
exclude_dict = {} exclude_dict = {}
if dataset_id_list is None or len(dataset_id_list) == 0: if dataset_id_list is None or len(dataset_id_list) == 0:
return None return []
query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active) query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active)
embedding_query = embedding.embed_query(query_text)
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0: if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
exclude_dict.__setitem__('document_id__in', exclude_document_id_list) exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
if exclude_id_list is not None and len(exclude_id_list) > 0: if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
exclude_dict.__setitem__('id__in', exclude_id_list) exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list)
query_set = query_set.exclude(**exclude_dict) query_set = query_set.exclude(**exclude_dict)
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set}, exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
select_string=get_file_content( select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
'embedding_search.sql')), 'embedding_search.sql')),
with_table_name=True) with_table_name=True)
embedding_model = select_one(exec_sql, (json.dumps(embedding_query), *exec_params, *exec_params)) embedding_model = select_list(exec_sql,
[json.dumps(query_embedding), *exec_params, similarity, top_n])
return embedding_model return embedding_model
def update_by_source_id(self, source_id: str, instance: Dict): def update_by_source_id(self, source_id: str, instance: Dict):

View File

@ -14,13 +14,23 @@ from langchain.chat_models.base import BaseChatModel
from langchain.load import dumpd from langchain.load import dumpd
from langchain.schema import LLMResult from langchain.schema import LLMResult
from langchain.schema.language_model import LanguageModelInput from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string
from langchain.schema.output import ChatGenerationChunk from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig from langchain.schema.runnable import RunnableConfig
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
force_download=False)
class QianfanChatModel(QianfanChatEndpoint): class QianfanChatModel(QianfanChatEndpoint):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
return len(tokenizer.encode(text))
def stream( def stream(
self, self,
input: LanguageModelInput, input: LanguageModelInput,
@ -30,7 +40,7 @@ class QianfanChatModel(QianfanChatEndpoint):
**kwargs: Any, **kwargs: Any,
) -> Iterator[BaseMessageChunk]: ) -> Iterator[BaseMessageChunk]:
if len(input) % 2 == 0: if len(input) % 2 == 0:
input = [HumanMessage(content='占位'), *input] input = [HumanMessage(content='padding'), *input]
input = [ input = [
HumanMessage(content=input[index].content) if index % 2 == 0 else AIMessage(content=input[index].content) HumanMessage(content=input[index].content) if index % 2 == 0 else AIMessage(content=input[index].content)
for index in range(0, len(input))] for index in range(0, len(input))]