feat: 对话支持全部返回
This commit is contained in:
parent
7eb18fbf30
commit
b7a406db56
@ -116,10 +116,10 @@ class IBaseChatPipelineStep:
|
|||||||
:return: 执行结果
|
:return: 执行结果
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
self.context['start_time'] = start_time
|
||||||
# 校验参数,
|
# 校验参数,
|
||||||
self.valid_args(manage)
|
self.valid_args(manage)
|
||||||
self._run(manage)
|
self._run(manage)
|
||||||
self.context['start_time'] = start_time
|
|
||||||
self.context['run_time'] = time.time() - start_time
|
self.context['run_time'] = time.time() - start_time
|
||||||
|
|
||||||
def _run(self, manage):
|
def _run(self, manage):
|
||||||
|
|||||||
@ -63,6 +63,8 @@ class IChatStep(IBaseChatPipelineStep):
|
|||||||
post_response_handler = InstanceField(model_type=PostResponseHandler)
|
post_response_handler = InstanceField(model_type=PostResponseHandler)
|
||||||
# 补全问题
|
# 补全问题
|
||||||
padding_problem_text = serializers.CharField(required=False)
|
padding_problem_text = serializers.CharField(required=False)
|
||||||
|
# 是否使用流的形式输出
|
||||||
|
stream = serializers.BooleanField(required=False)
|
||||||
|
|
||||||
def is_valid(self, *, raise_exception=False):
|
def is_valid(self, *, raise_exception=False):
|
||||||
super().is_valid(raise_exception=True)
|
super().is_valid(raise_exception=True)
|
||||||
@ -85,5 +87,5 @@ class IChatStep(IBaseChatPipelineStep):
|
|||||||
chat_model: BaseChatModel = None,
|
chat_model: BaseChatModel = None,
|
||||||
paragraph_list=None,
|
paragraph_list=None,
|
||||||
manage: PiplineManage = None,
|
manage: PiplineManage = None,
|
||||||
padding_problem_text: str = None, **kwargs):
|
padding_problem_text: str = None, stream: bool = True, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -16,11 +16,12 @@ from typing import List
|
|||||||
from django.http import StreamingHttpResponse
|
from django.http import StreamingHttpResponse
|
||||||
from langchain.chat_models.base import BaseChatModel
|
from langchain.chat_models.base import BaseChatModel
|
||||||
from langchain.schema import BaseMessage
|
from langchain.schema import BaseMessage
|
||||||
from langchain.schema.messages import BaseMessageChunk, HumanMessage
|
from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage
|
||||||
|
|
||||||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||||
from application.chat_pipeline.pipeline_manage import PiplineManage
|
from application.chat_pipeline.pipeline_manage import PiplineManage
|
||||||
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
|
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
|
||||||
|
from common.response import result
|
||||||
|
|
||||||
|
|
||||||
def event_content(response,
|
def event_content(response,
|
||||||
@ -71,23 +72,16 @@ class BaseChatStep(IChatStep):
|
|||||||
paragraph_list=None,
|
paragraph_list=None,
|
||||||
manage: PiplineManage = None,
|
manage: PiplineManage = None,
|
||||||
padding_problem_text: str = None,
|
padding_problem_text: str = None,
|
||||||
|
stream: bool = True,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
# 调用模型
|
if stream:
|
||||||
if chat_model is None:
|
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
||||||
chat_result = iter(
|
paragraph_list,
|
||||||
[BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
|
manage, padding_problem_text)
|
||||||
else:
|
else:
|
||||||
chat_result = chat_model.stream(message_list)
|
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
||||||
|
paragraph_list,
|
||||||
chat_record_id = uuid.uuid1()
|
manage, padding_problem_text)
|
||||||
r = StreamingHttpResponse(
|
|
||||||
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
|
|
||||||
post_response_handler, manage, self, chat_model, message_list, problem_text,
|
|
||||||
padding_problem_text),
|
|
||||||
content_type='text/event-stream;charset=utf-8')
|
|
||||||
|
|
||||||
r['Cache-Control'] = 'no-cache'
|
|
||||||
return r
|
|
||||||
|
|
||||||
def get_details(self, manage, **kwargs):
|
def get_details(self, manage, **kwargs):
|
||||||
return {
|
return {
|
||||||
@ -109,3 +103,58 @@ class BaseChatStep(IChatStep):
|
|||||||
message_list]
|
message_list]
|
||||||
result.append({'role': 'ai', 'content': answer_text})
|
result.append({'role': 'ai', 'content': answer_text})
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def execute_stream(self, message_list: List[BaseMessage],
|
||||||
|
chat_id,
|
||||||
|
problem_text,
|
||||||
|
post_response_handler: PostResponseHandler,
|
||||||
|
chat_model: BaseChatModel = None,
|
||||||
|
paragraph_list=None,
|
||||||
|
manage: PiplineManage = None,
|
||||||
|
padding_problem_text: str = None):
|
||||||
|
# 调用模型
|
||||||
|
if chat_model is None:
|
||||||
|
chat_result = iter(
|
||||||
|
[BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
|
||||||
|
else:
|
||||||
|
chat_result = chat_model.stream(message_list)
|
||||||
|
|
||||||
|
chat_record_id = uuid.uuid1()
|
||||||
|
r = StreamingHttpResponse(
|
||||||
|
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
|
||||||
|
post_response_handler, manage, self, chat_model, message_list, problem_text,
|
||||||
|
padding_problem_text),
|
||||||
|
content_type='text/event-stream;charset=utf-8')
|
||||||
|
|
||||||
|
r['Cache-Control'] = 'no-cache'
|
||||||
|
return r
|
||||||
|
|
||||||
|
def execute_block(self, message_list: List[BaseMessage],
|
||||||
|
chat_id,
|
||||||
|
problem_text,
|
||||||
|
post_response_handler: PostResponseHandler,
|
||||||
|
chat_model: BaseChatModel = None,
|
||||||
|
paragraph_list=None,
|
||||||
|
manage: PiplineManage = None,
|
||||||
|
padding_problem_text: str = None):
|
||||||
|
# 调用模型
|
||||||
|
if chat_model is None:
|
||||||
|
chat_result = AIMessage(
|
||||||
|
content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
|
||||||
|
else:
|
||||||
|
chat_result = chat_model(message_list)
|
||||||
|
chat_record_id = uuid.uuid1()
|
||||||
|
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||||
|
response_token = chat_model.get_num_tokens(chat_result.content)
|
||||||
|
self.context['message_tokens'] = request_token
|
||||||
|
self.context['answer_tokens'] = response_token
|
||||||
|
current_time = time.time()
|
||||||
|
self.context['answer_text'] = chat_result.content
|
||||||
|
self.context['run_time'] = current_time - self.context['start_time']
|
||||||
|
manage.context['run_time'] = current_time - manage.context['start_time']
|
||||||
|
manage.context['message_tokens'] = manage.context['message_tokens'] + request_token
|
||||||
|
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
|
||||||
|
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||||
|
chat_result.content, manage, self, padding_problem_text)
|
||||||
|
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
|
||||||
|
'content': chat_result.content, 'is_end': True})
|
||||||
|
|||||||
@ -72,15 +72,16 @@ class ChatInfo:
|
|||||||
'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
|
'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
|
||||||
'chat_model': self.chat_model,
|
'chat_model': self.chat_model,
|
||||||
'model_id': self.application.model.id if self.application.model is not None else None,
|
'model_id': self.application.model.id if self.application.model is not None else None,
|
||||||
'problem_optimization': self.application.problem_optimization
|
'problem_optimization': self.application.problem_optimization,
|
||||||
|
'stream': True
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
|
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
|
||||||
exclude_paragraph_id_list):
|
exclude_paragraph_id_list, stream=True):
|
||||||
params = self.to_base_pipeline_manage_params()
|
params = self.to_base_pipeline_manage_params()
|
||||||
return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
|
return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
|
||||||
'exclude_paragraph_id_list': exclude_paragraph_id_list}
|
'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream}
|
||||||
|
|
||||||
def append_chat_record(self, chat_record: ChatRecord):
|
def append_chat_record(self, chat_record: ChatRecord):
|
||||||
# 存入缓存中
|
# 存入缓存中
|
||||||
@ -126,7 +127,7 @@ def get_post_handler(chat_info: ChatInfo):
|
|||||||
class ChatMessageSerializer(serializers.Serializer):
|
class ChatMessageSerializer(serializers.Serializer):
|
||||||
chat_id = serializers.UUIDField(required=True)
|
chat_id = serializers.UUIDField(required=True)
|
||||||
|
|
||||||
def chat(self, message, re_chat: bool):
|
def chat(self, message, re_chat: bool, stream: bool):
|
||||||
self.is_valid(raise_exception=True)
|
self.is_valid(raise_exception=True)
|
||||||
chat_id = self.data.get('chat_id')
|
chat_id = self.data.get('chat_id')
|
||||||
chat_info: ChatInfo = chat_cache.get(chat_id)
|
chat_info: ChatInfo = chat_cache.get(chat_id)
|
||||||
@ -152,7 +153,8 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||||||
chat_info.chat_record_list)])
|
chat_info.chat_record_list)])
|
||||||
exclude_paragraph_id_list = list(set(paragraph_id_list))
|
exclude_paragraph_id_list = list(set(paragraph_id_list))
|
||||||
# 构建运行参数
|
# 构建运行参数
|
||||||
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list)
|
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
|
||||||
|
stream)
|
||||||
# 运行流水线作业
|
# 运行流水线作业
|
||||||
pipline_message.run(params)
|
pipline_message.run(params)
|
||||||
return pipline_message.context['chat_result']
|
return pipline_message.context['chat_result']
|
||||||
|
|||||||
@ -20,7 +20,8 @@ class ChatApi(ApiMixin):
|
|||||||
required=['message'],
|
required=['message'],
|
||||||
properties={
|
properties={
|
||||||
'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
|
'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
|
||||||
're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default="重新生成")
|
're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default=False),
|
||||||
|
'stream': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default=True)
|
||||||
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@ -72,7 +72,8 @@ class ChatView(APIView):
|
|||||||
)
|
)
|
||||||
def post(self, request: Request, chat_id: str):
|
def post(self, request: Request, chat_id: str):
|
||||||
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), request.data.get(
|
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), request.data.get(
|
||||||
're_chat') if 're_chat' in request.data else False)
|
're_chat') if 're_chat' in request.data else False, request.data.get(
|
||||||
|
'stream') if 'stream' in request.data else True)
|
||||||
|
|
||||||
@action(methods=['GET'], detail=False)
|
@action(methods=['GET'], detail=False)
|
||||||
@swagger_auto_schema(operation_summary="获取对话列表",
|
@swagger_auto_schema(operation_summary="获取对话列表",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user