fix: 修复工作流节点输出等问题 (#1716)

This commit is contained in:
shaohuzhang1 2024-11-29 19:26:16 +08:00 committed by GitHub
parent bce2558951
commit b8aa4756c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 313 additions and 73 deletions

View File

@ -9,6 +9,7 @@
import time import time
import uuid import uuid
from abc import abstractmethod from abc import abstractmethod
from hashlib import sha1
from typing import Type, Dict, List from typing import Type, Dict, List
from django.core import cache from django.core import cache
@ -131,6 +132,7 @@ class FlowParamsSerializer(serializers.Serializer):
class INode: class INode:
view_type = 'many_view'
@abstractmethod @abstractmethod
def save_context(self, details, workflow_manage): def save_context(self, details, workflow_manage):
@ -139,7 +141,7 @@ class INode:
def get_answer_text(self): def get_answer_text(self):
return self.answer_text return self.answer_text
def __init__(self, node, workflow_params, workflow_manage, runtime_node_id=None): def __init__(self, node, workflow_params, workflow_manage, up_node_id_list=None):
# 当前步骤上下文,用于存储当前步骤信息 # 当前步骤上下文,用于存储当前步骤信息
self.status = 200 self.status = 200
self.err_message = '' self.err_message = ''
@ -152,10 +154,13 @@ class INode:
self.context = {} self.context = {}
self.answer_text = None self.answer_text = None
self.id = node.id self.id = node.id
if runtime_node_id is None: if up_node_id_list is None:
self.runtime_node_id = str(uuid.uuid1()) up_node_id_list = []
else: self.up_node_id_list = up_node_id_list
self.runtime_node_id = runtime_node_id self.runtime_node_id = sha1(uuid.NAMESPACE_DNS.bytes + bytes(str(uuid.uuid5(uuid.NAMESPACE_DNS,
"".join([*sorted(up_node_id_list),
node.id]))),
"utf-8")).hexdigest()
def valid_args(self, node_params, flow_params): def valid_args(self, node_params, flow_params):
flow_params_serializer_class = self.get_flow_params_serializer_class() flow_params_serializer_class = self.get_flow_params_serializer_class()

View File

@ -21,6 +21,7 @@ class FormNodeParamsSerializer(serializers.Serializer):
class IFormNode(INode): class IFormNode(INode):
type = 'form-node' type = 'form-node'
view_type = 'single_view'
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return FormNodeParamsSerializer return FormNodeParamsSerializer

View File

@ -34,6 +34,8 @@ class BaseFormNode(IFormNode):
self.context['form_field_list'] = details.get('form_field_list') self.context['form_field_list'] = details.get('form_field_list')
self.context['run_time'] = details.get('run_time') self.context['run_time'] = details.get('run_time')
self.context['start_time'] = details.get('start_time') self.context['start_time'] = details.get('start_time')
self.context['form_data'] = details.get('form_data')
self.context['is_submit'] = details.get('is_submit')
self.answer_text = details.get('result') self.answer_text = details.get('result')
def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult: def execute(self, form_field_list, form_content_format, **kwargs) -> NodeResult:
@ -77,6 +79,7 @@ class BaseFormNode(IFormNode):
"form_field_list": self.context.get('form_field_list'), "form_field_list": self.context.get('form_field_list'),
'form_data': self.context.get('form_data'), 'form_data': self.context.get('form_data'),
'start_time': self.context.get('start_time'), 'start_time': self.context.get('start_time'),
'is_submit': self.context.get('is_submit'),
'run_time': self.context.get('run_time'), 'run_time': self.context.get('run_time'),
'type': self.node.type, 'type': self.node.type,
'status': self.status, 'status': self.status,

View File

@ -52,7 +52,8 @@ class Node:
self.__setattr__(keyword, kwargs.get(keyword)) self.__setattr__(keyword, kwargs.get(keyword))
end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node', 'image-understand-node'] end_nodes = ['ai-chat-node', 'reply-node', 'function-node', 'function-lib-node', 'application-node',
'image-understand-node']
class Flow: class Flow:
@ -229,7 +230,9 @@ class NodeChunk:
def add_chunk(self, chunk): def add_chunk(self, chunk):
self.chunk_list.append(chunk) self.chunk_list.append(chunk)
def end(self): def end(self, chunk=None):
if chunk is not None:
self.add_chunk(chunk)
self.status = 200 self.status = 200
def is_end(self): def is_end(self):
@ -266,6 +269,7 @@ class WorkflowManage:
self.status = 0 self.status = 0
self.base_to_response = base_to_response self.base_to_response = base_to_response
self.chat_record = chat_record self.chat_record = chat_record
self.await_future_map = {}
if start_node_id is not None: if start_node_id is not None:
self.load_node(chat_record, start_node_id, start_node_data) self.load_node(chat_record, start_node_id, start_node_data)
else: else:
@ -286,14 +290,16 @@ class WorkflowManage:
for node_details in sorted(chat_record.details.values(), key=lambda d: d.get('index')): for node_details in sorted(chat_record.details.values(), key=lambda d: d.get('index')):
node_id = node_details.get('node_id') node_id = node_details.get('node_id')
if node_details.get('runtime_node_id') == start_node_id: if node_details.get('runtime_node_id') == start_node_id:
self.start_node = self.get_node_cls_by_id(node_id, node_details.get('runtime_node_id')) self.start_node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'))
self.start_node.valid_args(self.start_node.node_params, self.start_node.workflow_params) self.start_node.valid_args(self.start_node.node_params, self.start_node.workflow_params)
self.start_node.save_context(node_details, self) self.start_node.save_context(node_details, self)
node_result = NodeResult({**start_node_data, 'form_data': start_node_data, 'is_submit': True}, {}) node_result = NodeResult({**start_node_data, 'form_data': start_node_data, 'is_submit': True}, {})
self.start_node_result_future = NodeResultFuture(node_result, None) self.start_node_result_future = NodeResultFuture(node_result, None)
return self.node_context.append(self.start_node)
continue
node_id = node_details.get('node_id') node_id = node_details.get('node_id')
node = self.get_node_cls_by_id(node_id, node_details.get('runtime_node_id')) node = self.get_node_cls_by_id(node_id, node_details.get('up_node_id_list'))
node.valid_args(node.node_params, node.workflow_params) node.valid_args(node.node_params, node.workflow_params)
node.save_context(node_details, self) node.save_context(node_details, self)
self.node_context.append(node) self.node_context.append(node)
@ -345,17 +351,22 @@ class WorkflowManage:
if chunk is None: if chunk is None:
break break
yield chunk yield chunk
yield self.get_chunk_content('', True)
finally: finally:
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
self.answer, self.answer,
self) self)
yield self.get_chunk_content('', True)
def run_chain_async(self, current_node, node_result_future): def run_chain_async(self, current_node, node_result_future):
future = executor.submit(self.run_chain, current_node, node_result_future) future = executor.submit(self.run_chain, current_node, node_result_future)
return future return future
def set_await_map(self, node_run_list):
sorted_node_run_list = sorted(node_run_list, key=lambda n: n.get('node').node.y)
for index in range(len(sorted_node_run_list)):
self.await_future_map[sorted_node_run_list[index].get('node').runtime_node_id] = [
sorted_node_run_list[i].get('future')
for i in range(index)]
def run_chain(self, current_node, node_result_future=None): def run_chain(self, current_node, node_result_future=None):
if current_node is None: if current_node is None:
start_node = self.get_start_node() start_node = self.get_start_node()
@ -365,6 +376,9 @@ class WorkflowManage:
try: try:
is_stream = self.params.get('stream', True) is_stream = self.params.get('stream', True)
# 处理节点响应 # 处理节点响应
await_future_list = self.await_future_map.get(current_node.runtime_node_id, None)
if await_future_list is not None:
[f.result() for f in await_future_list]
result = self.hand_event_node_result(current_node, result = self.hand_event_node_result(current_node,
node_result_future) if is_stream else self.hand_node_result( node_result_future) if is_stream else self.hand_node_result(
current_node, node_result_future) current_node, node_result_future)
@ -373,11 +387,9 @@ class WorkflowManage:
return return
node_list = self.get_next_node_list(current_node, result) node_list = self.get_next_node_list(current_node, result)
# 获取到可执行的子节点 # 获取到可执行的子节点
result_list = [] result_list = [{'node': node, 'future': self.run_chain_async(node, None)} for node in node_list]
for node in node_list: self.set_await_map(result_list)
result = self.run_chain_async(node, None) [r.get('future').result() for r in result_list]
result_list.append(result)
[r.result() for r in result_list]
if self.status == 0: if self.status == 0:
self.status = 200 self.status = 200
except Exception as e: except Exception as e:
@ -401,6 +413,14 @@ class WorkflowManage:
current_node.get_write_error_context(e) current_node.get_write_error_context(e)
self.answer += str(e) self.answer += str(e)
def append_node(self, current_node):
for index in range(len(self.node_context)):
n = self.node_context[index]
if current_node.id == n.node.id and current_node.runtime_node_id == n.runtime_node_id:
self.node_context[index] = current_node
return
self.node_context.append(current_node)
def hand_event_node_result(self, current_node, node_result_future): def hand_event_node_result(self, current_node, node_result_future):
node_chunk = NodeChunk() node_chunk = NodeChunk()
try: try:
@ -412,22 +432,35 @@ class WorkflowManage:
for r in result: for r in result:
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'], self.params['chat_record_id'],
r, False, 0, 0) current_node.id,
current_node.up_node_id_list,
r, False, 0, 0,
{'node_type': current_node.type,
'view_type': current_node.view_type})
node_chunk.add_chunk(chunk) node_chunk.add_chunk(chunk)
node_chunk.end() chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
current_node.id,
current_node.up_node_id_list,
'', False, 0, 0, {'node_is_end': True,
'node_type': current_node.type,
'view_type': current_node.view_type})
node_chunk.end(chunk)
else: else:
list(result) list(result)
# 添加节点 # 添加节点
self.node_context.append(current_node) self.append_node(current_node)
return current_result return current_result
except Exception as e: except Exception as e:
# 添加节点 # 添加节点
self.node_context.append(current_node) self.append_node(current_node)
traceback.print_exc() traceback.print_exc()
self.answer += str(e) self.answer += str(e)
chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'], self.params['chat_record_id'],
str(e), False, 0, 0) current_node.id,
current_node.up_node_id_list,
str(e), False, 0, 0, {'node_is_end': True})
if not self.node_chunk_manage.contains(node_chunk): if not self.node_chunk_manage.contains(node_chunk):
self.node_chunk_manage.add_node_chunk(node_chunk) self.node_chunk_manage.add_node_chunk(node_chunk)
node_chunk.add_chunk(chunk) node_chunk.add_chunk(chunk)
@ -492,32 +525,36 @@ class WorkflowManage:
continue continue
details = node.get_details(index) details = node.get_details(index)
details['node_id'] = node.id details['node_id'] = node.id
details['up_node_id_list'] = node.up_node_id_list
details['runtime_node_id'] = node.runtime_node_id details['runtime_node_id'] = node.runtime_node_id
details_result[node.runtime_node_id] = details details_result[node.runtime_node_id] = details
return details_result return details_result
def get_answer_text_list(self): def get_answer_text_list(self):
answer_text_list = [] result = []
next_node_id_list = []
if self.start_node is not None:
next_node_id_list = [edge.targetNodeId for edge in self.flow.edges if
edge.sourceNodeId == self.start_node.id]
for index in range(len(self.node_context)): for index in range(len(self.node_context)):
node = self.node_context[index] node = self.node_context[index]
up_node = None
if index > 0:
up_node = self.node_context[index - 1]
answer_text = node.get_answer_text() answer_text = node.get_answer_text()
if answer_text is not None: if answer_text is not None:
if self.chat_record is not None and self.chat_record.details is not None: if up_node is None or node.view_type == 'single_view' or (
details = self.chat_record.details.get(node.runtime_node_id) node.view_type == 'many_view' and up_node.view_type == 'single_view'):
if details is not None and self.start_node.runtime_node_id != node.runtime_node_id: result.append(node.get_answer_text())
continue elif self.chat_record is not None and next_node_id_list.__contains__(
answer_text_list.append( node.id) and up_node is not None and not next_node_id_list.__contains__(
{'content': answer_text, 'type': 'form' if node.type == 'form-node' else 'md'}) up_node.id):
result = [] result.append(node.get_answer_text())
for index in range(len(answer_text_list)): else:
answer = answer_text_list[index] content = result[len(result) - 1]
if index == 0: answer_text = node.get_answer_text()
result.append(answer.get('content')) result[len(result) - 1] += answer_text if len(
continue content) == 0 else ('\n\n' + answer_text)
if answer.get('type') != answer_text_list[index - 1].get('type'):
result.append(answer.get('content'))
else:
result[-1] += answer.get('content')
return result return result
def get_next_node(self): def get_next_node(self):
@ -540,6 +577,15 @@ class WorkflowManage:
return None return None
@staticmethod
def dependent_node(up_node_id, node):
if node.id == up_node_id:
if node.type == 'form-node':
if node.context.get('form_data', None) is not None:
return True
return False
return True
def dependent_node_been_executed(self, node_id): def dependent_node_been_executed(self, node_id):
""" """
判断依赖节点是否都已执行 判断依赖节点是否都已执行
@ -547,7 +593,12 @@ class WorkflowManage:
@return: @return:
""" """
up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id] up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
return all([any([node.id == up_node_id for node in self.node_context]) for up_node_id in up_node_id_list]) return all([any([self.dependent_node(up_node_id, node) for node in self.node_context]) for up_node_id in
up_node_id_list])
def get_up_node_id_list(self, node_id):
up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
return up_node_id_list
def get_next_node_list(self, current_node, current_node_result): def get_next_node_list(self, current_node, current_node_result):
""" """
@ -556,6 +607,7 @@ class WorkflowManage:
@param current_node_result: 当前可执行节点结果 @param current_node_result: 当前可执行节点结果
@return: 可执行节点列表 @return: 可执行节点列表
""" """
if current_node.type == 'form-node' and 'form_data' not in current_node_result.node_variable: if current_node.type == 'form-node' and 'form_data' not in current_node_result.node_variable:
return [] return []
node_list = [] node_list = []
@ -564,11 +616,13 @@ class WorkflowManage:
if (edge.sourceNodeId == current_node.id and if (edge.sourceNodeId == current_node.id and
f"{edge.sourceNodeId}_{current_node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): f"{edge.sourceNodeId}_{current_node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
if self.dependent_node_been_executed(edge.targetNodeId): if self.dependent_node_been_executed(edge.targetNodeId):
node_list.append(self.get_node_cls_by_id(edge.targetNodeId)) node_list.append(
self.get_node_cls_by_id(edge.targetNodeId, self.get_up_node_id_list(edge.targetNodeId)))
else: else:
for edge in self.flow.edges: for edge in self.flow.edges:
if edge.sourceNodeId == current_node.id and self.dependent_node_been_executed(edge.targetNodeId): if edge.sourceNodeId == current_node.id and self.dependent_node_been_executed(edge.targetNodeId):
node_list.append(self.get_node_cls_by_id(edge.targetNodeId)) node_list.append(
self.get_node_cls_by_id(edge.targetNodeId, self.get_up_node_id_list(edge.targetNodeId)))
return node_list return node_list
def get_reference_field(self, node_id: str, fields: List[str]): def get_reference_field(self, node_id: str, fields: List[str]):
@ -629,11 +683,11 @@ class WorkflowManage:
base_node_list = [node for node in self.flow.nodes if node.type == 'base-node'] base_node_list = [node for node in self.flow.nodes if node.type == 'base-node']
return base_node_list[0] return base_node_list[0]
def get_node_cls_by_id(self, node_id, runtime_node_id=None): def get_node_cls_by_id(self, node_id, up_node_id_list=None):
for node in self.flow.nodes: for node in self.flow.nodes:
if node.id == node_id: if node.id == node_id:
node_instance = get_node(node.type)(node, node_instance = get_node(node.type)(node,
self.params, self, runtime_node_id) self.params, self, up_node_id_list)
return node_instance return node_instance
return None return None

View File

@ -224,8 +224,13 @@ class ChatMessageSerializer(serializers.Serializer):
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答")) re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答"))
chat_record_id = serializers.UUIDField(required=False, allow_null=True, chat_record_id = serializers.UUIDField(required=False, allow_null=True,
error_messages=ErrMessage.uuid("对话记录id")) error_messages=ErrMessage.uuid("对话记录id"))
node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("节点id"))
runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("节点id")) error_messages=ErrMessage.char("运行时节点id"))
node_data = serializers.DictField(required=False, error_messages=ErrMessage.char("节点参数")) node_data = serializers.DictField(required=False, error_messages=ErrMessage.char("节点参数"))
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id")) client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
@ -339,7 +344,8 @@ class ChatMessageSerializer(serializers.Serializer):
'client_id': client_id, 'client_id': client_id,
'client_type': client_type, 'client_type': client_type,
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type), 'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
base_to_response, form_data, image_list, document_list, self.data.get('runtime_node_id'), base_to_response, form_data, image_list, document_list,
self.data.get('runtime_node_id'),
self.data.get('node_data'), chat_record) self.data.get('node_data'), chat_record)
r = work_flow_manage.run() r = work_flow_manage.run()
return r return r

View File

@ -135,6 +135,7 @@ class ChatView(APIView):
'document_list': request.data.get( 'document_list': request.data.get(
'document_list') if 'document_list' in request.data else [], 'document_list') if 'document_list' in request.data else [],
'client_type': request.auth.client_type, 'client_type': request.auth.client_type,
'node_id': request.data.get('node_id', None),
'runtime_node_id': request.data.get('runtime_node_id', None), 'runtime_node_id': request.data.get('runtime_node_id', None),
'node_data': request.data.get('node_data', {}), 'node_data': request.data.get('node_data', {}),
'chat_record_id': request.data.get('chat_record_id')} 'chat_record_id': request.data.get('chat_record_id')}

View File

@ -14,12 +14,15 @@ from rest_framework import status
class BaseToResponse(ABC): class BaseToResponse(ABC):
@abstractmethod @abstractmethod
def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens, def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens,
prompt_tokens, other_params: dict = None,
_status=status.HTTP_200_OK): _status=status.HTTP_200_OK):
pass pass
@abstractmethod @abstractmethod
def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens): def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end,
completion_tokens,
prompt_tokens, other_params: dict = None):
pass pass
@staticmethod @staticmethod

View File

@ -20,6 +20,7 @@ from common.handle.base_to_response import BaseToResponse
class OpenaiToResponse(BaseToResponse): class OpenaiToResponse(BaseToResponse):
def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens, def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
other_params: dict = None,
_status=status.HTTP_200_OK): _status=status.HTTP_200_OK):
data = ChatCompletion(id=chat_record_id, choices=[ data = ChatCompletion(id=chat_record_id, choices=[
BlockChoice(finish_reason='stop', index=0, chat_id=chat_id, BlockChoice(finish_reason='stop', index=0, chat_id=chat_id,
@ -31,7 +32,8 @@ class OpenaiToResponse(BaseToResponse):
).dict() ).dict()
return JsonResponse(data=data, status=_status) return JsonResponse(data=data, status=_status)
def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens): def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, completion_tokens,
prompt_tokens, other_params: dict = None):
chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk', chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk',
created=datetime.datetime.now().second, choices=[ created=datetime.datetime.now().second, choices=[
Choice(delta=ChoiceDelta(content=content, chat_id=chat_id), finish_reason='stop' if is_end else None, Choice(delta=ChoiceDelta(content=content, chat_id=chat_id), finish_reason='stop' if is_end else None,

View File

@ -15,12 +15,23 @@ from common.response import result
class SystemToResponse(BaseToResponse): class SystemToResponse(BaseToResponse):
def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens, def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens,
prompt_tokens, other_params: dict = None,
_status=status.HTTP_200_OK): _status=status.HTTP_200_OK):
if other_params is None:
other_params = {}
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': content, 'is_end': is_end}, response_status=_status, code=_status) 'content': content, 'is_end': is_end, **other_params}, response_status=_status,
code=_status)
def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens): def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, completion_tokens,
prompt_tokens, other_params: dict = None):
if other_params is None:
other_params = {}
chunk = json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, chunk = json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': content, 'is_end': is_end}) 'content': content, 'node_id': node_id, 'up_node_id_list': up_node_id_list, 'is_end': is_end,
'usage': {'completion_tokens': completion_tokens,
'prompt_tokens': prompt_tokens,
'total_tokens': completion_tokens + prompt_tokens},
**other_params})
return super().format_stream_chunk(chunk) return super().format_stream_chunk(chunk)

View File

@ -17,6 +17,8 @@ def updateDocumentStatus(apps, schema_editor):
ParagraphModel = apps.get_model('dataset', 'Paragraph') ParagraphModel = apps.get_model('dataset', 'Paragraph')
DocumentModel = apps.get_model('dataset', 'Document') DocumentModel = apps.get_model('dataset', 'Document')
success_list = QuerySet(DocumentModel).filter(status='2') success_list = QuerySet(DocumentModel).filter(status='2')
if len(success_list) == 0:
return
ListenerManagement.update_status(QuerySet(ParagraphModel).filter(document_id__in=[d.id for d in success_list]), ListenerManagement.update_status(QuerySet(ParagraphModel).filter(document_id__in=[d.id for d in success_list]),
TaskType.EMBEDDING, State.SUCCESS) TaskType.EMBEDDING, State.SUCCESS)
ListenerManagement.get_aggregation_document_status_by_query_set(QuerySet(DocumentModel))() ListenerManagement.get_aggregation_document_status_by_query_set(QuerySet(DocumentModel))()

View File

@ -22,6 +22,17 @@ interface ApplicationFormType {
tts_model_enable?: boolean tts_model_enable?: boolean
tts_type?: string tts_type?: string
} }
interface Chunk {
chat_id: string
id: string
content: string
node_id: string
up_node_id: string
is_end: boolean
node_is_end: boolean
node_type: string
view_type: string
}
interface chatType { interface chatType {
id: string id: string
problem_text: string problem_text: string
@ -47,6 +58,21 @@ interface chatType {
} }
} }
interface Node {
buffer: Array<string>
node_id: string
up_node_id: string
node_type: string
view_type: string
index: number
is_end: boolean
}
interface WriteNodeInfo {
current_node: any
answer_text_list_index: number
current_up_node?: any
divider_content?: Array<string>
}
export class ChatRecordManage { export class ChatRecordManage {
id?: any id?: any
ms: number ms: number
@ -55,6 +81,8 @@ export class ChatRecordManage {
write_ed?: boolean write_ed?: boolean
is_stop?: boolean is_stop?: boolean
loading?: Ref<boolean> loading?: Ref<boolean>
node_list: Array<any>
write_node_info?: WriteNodeInfo
constructor(chat: chatType, ms?: number, loading?: Ref<boolean>) { constructor(chat: chatType, ms?: number, loading?: Ref<boolean>) {
this.ms = ms ? ms : 10 this.ms = ms ? ms : 10
this.chat = chat this.chat = chat
@ -62,12 +90,82 @@ export class ChatRecordManage {
this.is_stop = false this.is_stop = false
this.is_close = false this.is_close = false
this.write_ed = false this.write_ed = false
this.node_list = []
} }
append_answer(chunk_answer: String) { append_answer(chunk_answer: string, index?: number) {
this.chat.answer_text_list[this.chat.answer_text_list.length - 1] = this.chat.answer_text_list[index != undefined ? index : this.chat.answer_text_list.length - 1] =
this.chat.answer_text_list[this.chat.answer_text_list.length - 1] + chunk_answer this.chat.answer_text_list[
index !== undefined ? index : this.chat.answer_text_list.length - 1
]
? this.chat.answer_text_list[
index !== undefined ? index : this.chat.answer_text_list.length - 1
] + chunk_answer
: chunk_answer
this.chat.answer_text = this.chat.answer_text + chunk_answer this.chat.answer_text = this.chat.answer_text + chunk_answer
} }
get_run_node() {
if (
this.write_node_info &&
(this.write_node_info.current_node.buffer.length > 0 ||
!this.write_node_info.current_node.is_end)
) {
return this.write_node_info
}
const run_node = this.node_list.filter((item) => item.buffer.length > 0 || !item.is_end).at(0)
if (run_node) {
const index = this.node_list.indexOf(run_node)
let current_up_node = undefined
if (index > 0) {
current_up_node = this.node_list[index - 1]
}
let answer_text_list_index = 0
if (
current_up_node == undefined ||
run_node.view_type == 'single_view' ||
(run_node.view_type == 'many_view' && current_up_node.view_type == 'single_view')
) {
const none_index = this.chat.answer_text_list.indexOf('')
if (none_index > -1) {
answer_text_list_index = none_index
} else {
answer_text_list_index = this.chat.answer_text_list.length
}
} else {
const none_index = this.chat.answer_text_list.indexOf('')
if (none_index > -1) {
answer_text_list_index = none_index
} else {
answer_text_list_index = this.chat.answer_text_list.length - 1
}
}
this.write_node_info = {
current_node: run_node,
divider_content: ['\n\n'],
current_up_node: current_up_node,
answer_text_list_index: answer_text_list_index
}
return this.write_node_info
}
return undefined
}
closeInterval() {
this.chat.write_ed = true
this.write_ed = true
if (this.loading) {
this.loading.value = false
}
if (this.id) {
clearInterval(this.id)
}
const last_index = this.chat.answer_text_list.lastIndexOf('')
if (last_index > 0) {
this.chat.answer_text_list.splice(last_index, 1)
}
}
write() { write() {
this.chat.is_stop = false this.chat.is_stop = false
this.is_stop = false this.is_stop = false
@ -78,22 +176,45 @@ export class ChatRecordManage {
this.loading.value = true this.loading.value = true
} }
this.id = setInterval(() => { this.id = setInterval(() => {
if (this.chat.buffer.length > 20) { const node_info = this.get_run_node()
this.append_answer(this.chat.buffer.splice(0, this.chat.buffer.length - 20).join('')) if (node_info == undefined) {
if (this.is_close) {
this.closeInterval()
}
return
}
const { current_node, answer_text_list_index, divider_content } = node_info
if (current_node.buffer.length > 20) {
const context = current_node.is_end
? current_node.buffer.splice(0)
: current_node.buffer.splice(
0,
current_node.is_end ? undefined : current_node.buffer.length - 20
)
this.append_answer(
(divider_content ? divider_content.splice(0).join('') : '') + context.join(''),
answer_text_list_index
)
} else if (this.is_close) { } else if (this.is_close) {
this.append_answer(this.chat.buffer.splice(0).join('')) while (true) {
this.chat.write_ed = true const node_info = this.get_run_node()
this.write_ed = true if (node_info == undefined) {
if (this.loading) { break
this.loading.value = false }
} this.append_answer(
if (this.id) { (node_info.divider_content ? node_info.divider_content.splice(0).join('') : '') +
clearInterval(this.id) node_info.current_node.buffer.splice(0).join(''),
node_info.answer_text_list_index
)
} }
this.closeInterval()
} else { } else {
const s = this.chat.buffer.shift() const s = current_node.buffer.shift()
if (s !== undefined) { if (s !== undefined) {
this.append_answer(s) this.append_answer(
(divider_content ? divider_content.splice(0).join('') : '') + s,
answer_text_list_index
)
} }
} }
}, this.ms) }, this.ms)
@ -113,6 +234,28 @@ export class ChatRecordManage {
this.is_close = false this.is_close = false
this.is_stop = false this.is_stop = false
} }
appendChunk(chunk: Chunk) {
let n = this.node_list.find(
(item) => item.node_id == chunk.node_id && item.up_node_id === chunk.up_node_id
)
if (n) {
n.buffer.push(...chunk.content)
} else {
n = {
buffer: [...chunk.content],
node_id: chunk.node_id,
up_node_id: chunk.up_node_id,
node_type: chunk.node_type,
index: this.node_list.length,
view_type: chunk.view_type,
is_end: false
}
this.node_list.push(n)
}
if (chunk.node_is_end) {
n['is_end'] = true
}
}
append(answer_text_block: string) { append(answer_text_block: string) {
for (let index = 0; index < answer_text_block.length; index++) { for (let index = 0; index < answer_text_block.length; index++) {
this.chat.buffer.push(answer_text_block[index]) this.chat.buffer.push(answer_text_block[index])
@ -126,6 +269,12 @@ export class ChatManagement {
static addChatRecord(chat: chatType, ms: number, loading?: Ref<boolean>) { static addChatRecord(chat: chatType, ms: number, loading?: Ref<boolean>) {
this.chatMessageContainer[chat.id] = new ChatRecordManage(chat, ms, loading) this.chatMessageContainer[chat.id] = new ChatRecordManage(chat, ms, loading)
} }
static appendChunk(chatRecordId: string, chunk: Chunk) {
const chatRecord = this.chatMessageContainer[chatRecordId]
if (chatRecord) {
chatRecord.appendChunk(chunk)
}
}
static append(chatRecordId: string, content: string) { static append(chatRecordId: string, content: string) {
const chatRecord = this.chatMessageContainer[chatRecordId] const chatRecord = this.chatMessageContainer[chatRecordId]
if (chatRecord) { if (chatRecord) {
@ -144,6 +293,7 @@ export class ChatManagement {
*/ */
static write(chatRecordId: string) { static write(chatRecordId: string) {
const chatRecord = this.chatMessageContainer[chatRecordId] const chatRecord = this.chatMessageContainer[chatRecordId]
console.log('chatRecord', chatRecordId, this.chatMessageContainer, chatRecord)
if (chatRecord) { if (chatRecord) {
chatRecord.write() chatRecord.write()
} }

View File

@ -223,10 +223,8 @@ const getWrite = (chat: any, reader: any, stream: boolean) => {
const chunk = JSON?.parse(split[index].replace('data:', '')) const chunk = JSON?.parse(split[index].replace('data:', ''))
chat.chat_id = chunk.chat_id chat.chat_id = chunk.chat_id
chat.record_id = chunk.id chat.record_id = chunk.id
const content = chunk?.content ChatManagement.appendChunk(chat.id, chunk)
if (content) {
ChatManagement.append(chat.id, content)
}
if (chunk.is_end) { if (chunk.is_end) {
// //
return Promise.resolve() return Promise.resolve()
@ -275,6 +273,7 @@ const errorWrite = (chat: any, message?: string) => {
function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_params_data?: any) { function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_params_data?: any) {
loading.value = true loading.value = true
console.log(chat)
if (!chat) { if (!chat) {
chat = reactive({ chat = reactive({
id: randomId(), id: randomId(),
@ -306,6 +305,10 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_para
scrollDiv.value.setScrollTop(getMaxHeight()) scrollDiv.value.setScrollTop(getMaxHeight())
}) })
} }
if (chat.run_time) {
ChatManagement.addChatRecord(chat, 50, loading)
ChatManagement.write(chat.id)
}
if (!chartOpenId.value) { if (!chartOpenId.value) {
getChartOpenId(chat).catch(() => { getChartOpenId(chat).catch(() => {
errorWrite(chat) errorWrite(chat)

View File

@ -52,7 +52,6 @@ const is_submit = computed(() => {
const _form_data = ref<any>({}) const _form_data = ref<any>({})
const form_data = computed({ const form_data = computed({
get: () => { get: () => {
console.log(form_setting_data.value)
if (form_setting_data.value.is_submit) { if (form_setting_data.value.is_submit) {
return form_setting_data.value.form_data return form_setting_data.value.form_data
} else { } else {