feat: 工作流编排支持并行 (workflow orchestration supports parallel execution) #1154 (#1362)

shaohuzhang1 2024-10-13 12:00:13 +08:00 committed by GitHub
parent 277e13513f
commit 2e331dcf56
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 289 additions and 116 deletions

View File

@@ -28,7 +28,7 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
     if step_variable is not None:
         for key in step_variable:
             node.context[key] = step_variable[key]
-        if workflow.is_result() and 'answer' in step_variable:
+        if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'answer' in step_variable:
             answer = step_variable['answer']
             yield answer
             workflow.answer += answer
@@ -166,6 +166,7 @@ class INode:
     def get_write_error_context(self, e):
         self.status = 500
         self.err_message = str(e)
+        self.context['run_time'] = time.time() - self.context['start_time']

     def write_error_context(answer, status=200):
         pass
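For context on this recurring change: is_result() previously decided from workflow-level state (the shared current node) whether a node's output is the final answer; once several nodes can run concurrently that decision has to be made per node, so the node and its NodeResult are now passed in explicitly. A minimal sketch of the calling pattern, using simplified stand-in classes (MiniNodeResult/MiniWorkflow are illustrative, not the NodeResult class in this diff):

```python
# Illustrative sketch only: why is_result() now needs the node and its result.
class MiniNodeResult:
    def __init__(self, node_variable, workflow_variable):
        self.node_variable = node_variable
        self.workflow_variable = workflow_variable

class MiniWorkflow:
    def __init__(self):
        self.answer = ""

    def is_result(self, node, node_result):
        # Per-node decision: safe even when several nodes run concurrently.
        return node.get("is_result", False)

def write_context(step_variable, global_variable, node, workflow):
    # Same shape as the generators in this commit: yield chunks, accumulate the answer.
    if step_variable is not None:
        if workflow.is_result(node, MiniNodeResult(step_variable, global_variable)) and "answer" in step_variable:
            answer = step_variable["answer"]
            yield answer
            workflow.answer += answer

wf = MiniWorkflow()
chunks = list(write_context({"answer": "hi"}, {}, {"is_result": True}, wf))
print(chunks, wf.answer)  # ['hi'] hi
```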

View File

@@ -31,7 +31,7 @@ def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wo
     node.context['history_message'] = node_variable['history_message']
     node.context['question'] = node_variable['question']
     node.context['run_time'] = time.time() - node.context['start_time']
-    if workflow.is_result():
+    if workflow.is_result(node, NodeResult(node_variable, workflow_variable)):
         workflow.answer += answer

View File

@@ -26,7 +26,7 @@ def write_context(step_variable: Dict, global_variable: Dict, node, workflow):
     if step_variable is not None:
         for key in step_variable:
             node.context[key] = step_variable[key]
-        if workflow.is_result() and 'result' in step_variable:
+        if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable:
             result = str(step_variable['result']) + '\n'
             yield result
             workflow.answer += result

View File

@@ -7,7 +7,9 @@
 @desc:
 """
 import json
+import threading
 import traceback
+from concurrent.futures import ThreadPoolExecutor
 from functools import reduce
 from typing import List, Dict
@@ -26,6 +28,8 @@ from function_lib.models.function import FunctionLib
 from setting.models import Model
 from setting.models_provider import get_model_credential

+executor = ThreadPoolExecutor(max_workers=50)
+

 class Edge:
     def __init__(self, _id: str, _type: str, sourceNodeId: str, targetNodeId: str, **keywords):
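The module-level ThreadPoolExecutor is what the run_chain_async/run_node_async helpers introduced later in this commit submit work to. As a reminder of the concurrent.futures semantics the new code relies on (a sketch, not MaxKB code; the pool size here is illustrative, the commit uses 50): future.result() blocks until the worker finishes and re-raises any exception from the worker thread, which is how branch failures propagate back to the caller.

```python
# Minimal sketch of the concurrent.futures behaviour the parallel workflow relies on.
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=4)  # illustrative size, not the commit's 50

def run_branch(name: str) -> str:
    if name == "bad":
        raise ValueError("branch failed")
    return f"{name} done"

futures = [executor.submit(run_branch, name) for name in ("a", "b")]
print([f.result() for f in futures])  # blocks until both branches finish: ['a done', 'b done']

failed = executor.submit(run_branch, "bad")
try:
    failed.result()                   # the worker's exception is re-raised here
except ValueError as e:
    print("caught:", e)
```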
@@ -95,17 +99,11 @@ class Flow:
                 if len(edge_list) == 0:
                     raise AppApiException(500,
                                           f'{node.properties.get("stepName")} 节点的{branch.get("type")}分支需要连接')
-                elif len(edge_list) > 1:
-                    raise AppApiException(500,
-                                          f'{node.properties.get("stepName")} 节点的{branch.get("type")}分支不能连接俩个节点')
         else:
             edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
             if len(edge_list) == 0 and not end_nodes.__contains__(node.type):
                 raise AppApiException(500, f'{node.properties.get("stepName")} 节点不能当做结束节点')
-            elif len(edge_list) > 1:
-                raise AppApiException(500,
-                                      f'{node.properties.get("stepName")} 节点不能连接俩个节点')

     def get_next_nodes(self, node: Node):
         edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
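Dropping the "cannot connect two nodes" checks (here and in the front-end validator at the end of this commit) is what allows a node to fan out into parallel branches; only the "non-end node with no outgoing edge" case remains an error. A toy topology the relaxed validation now accepts, expressed as plain data rather than the Flow/Edge classes above (node names are hypothetical):

```python
# Hypothetical edge list: "start" fans out to two branches that later join at "merge".
edges = [
    ("start", "search_a"),
    ("start", "search_b"),   # a second outgoing edge from the same node: previously rejected
    ("search_a", "merge"),
    ("search_b", "merge"),
]

outgoing = {}
for source, target in edges:
    outgoing.setdefault(source, []).append(target)

# Fan-out is fine now; only a missing outgoing edge on a non-end node would be an error.
assert len(outgoing["start"]) == 2
print(outgoing)
```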
@@ -165,6 +163,77 @@ class Flow:
                 raise AppApiException(500, '基本信息节点只能有一个')


+class NodeResultFuture:
+    def __init__(self, r, e, status=200):
+        self.r = r
+        self.e = e
+        self.status = status
+
+    def result(self):
+        if self.status == 200:
+            return self.r
+        else:
+            raise self.e
+
+
+def await_result(result, timeout=1):
+    try:
+        result.result(timeout)
+        return False
+    except Exception as e:
+        return True
+
+
+class NodeChunkManage:
+    def __init__(self, work_flow):
+        self.node_chunk_list = []
+        self.current_node_chunk = None
+        self.work_flow = work_flow
+
+    def add_node_chunk(self, node_chunk):
+        self.node_chunk_list.append(node_chunk)
+
+    def pop(self):
+        if self.current_node_chunk is None:
+            try:
+                current_node_chunk = self.node_chunk_list.pop(0)
+                self.current_node_chunk = current_node_chunk
+            except IndexError as e:
+                pass
+        if self.current_node_chunk is not None:
+            try:
+                chunk = self.current_node_chunk.chunk_list.pop(0)
+                return chunk
+            except IndexError as e:
+                if self.current_node_chunk.is_end():
+                    self.current_node_chunk = None
+                    if len(self.work_flow.answer) > 0:
+                        chunk = self.work_flow.base_to_response.to_stream_chunk_response(
+                            self.work_flow.params['chat_id'],
+                            self.work_flow.params['chat_record_id'],
+                            '\n\n', False, 0, 0)
+                        self.work_flow.answer += '\n\n'
+                        return chunk
+                    return self.pop()
+        return None
+
+
+class NodeChunk:
+    def __init__(self):
+        self.status = 0
+        self.chunk_list = []
+
+    def add_chunk(self, chunk):
+        self.chunk_list.append(chunk)
+
+    def end(self):
+        self.status = 200
+
+    def is_end(self):
+        return self.status == 200
+
+
 class WorkflowManage:
     def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
                  base_to_response: BaseToResponse = SystemToResponse(), form_data=None):
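NodeChunk/NodeChunkManage is the queueing layer that keeps streamed output readable when several nodes produce answer chunks concurrently: chunks are appended per node, and pop() drains one node's queue completely before moving on to the next (the real code also emits a '\n\n' separator once a finished node had produced output). A self-contained sketch of that draining behaviour, using simplified stand-in classes rather than the SSE chunk objects built by base_to_response:

```python
# Illustrative sketch of ordered draining across concurrently producing nodes.
from collections import deque

class MiniNodeChunk:
    def __init__(self):
        self.chunks = deque()
        self.done = False
    def add(self, chunk):
        self.chunks.append(chunk)
    def end(self):
        self.done = True

class MiniChunkManage:
    """Drain nodes strictly in the order they were registered."""
    def __init__(self):
        self.nodes = deque()
        self.current = None
    def add_node(self, node):
        self.nodes.append(node)
    def pop(self):
        if self.current is None and self.nodes:
            self.current = self.nodes.popleft()
        if self.current is None:
            return None
        if self.current.chunks:
            return self.current.chunks.popleft()
        if self.current.done:       # finished node: move on (the real code also emits '\n\n')
            self.current = None
            return self.pop()
        return None                 # current node still running but has produced nothing yet

m = MiniChunkManage()
a, b = MiniNodeChunk(), MiniNodeChunk()
m.add_node(a)
m.add_node(b)
b.add("B1"); a.add("A1"); a.end(); b.add("B2"); b.end()   # interleaved production
out = []
while (chunk := m.pop()) is not None:
    out.append(chunk)
print(out)  # ['A1', 'B1', 'B2'] -- node A's output is flushed before node B's
```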
@@ -173,12 +242,15 @@ class WorkflowManage:
         self.form_data = form_data
         self.params = params
         self.flow = flow
+        self.lock = threading.Lock()
         self.context = {}
         self.node_context = []
+        self.node_chunk_manage = NodeChunkManage(self)
         self.work_flow_post_handler = work_flow_post_handler
         self.current_node = None
         self.current_result = None
         self.answer = ""
+        self.status = 0
         self.base_to_response = base_to_response

     def run(self):
@@ -187,16 +259,12 @@ class WorkflowManage:
             return self.run_block()

     def run_block(self):
-        try:
-            while self.has_next_node(self.current_result):
-                self.current_node = self.get_next_node()
-                self.current_node.valid_args(self.current_node.node_params, self.current_node.workflow_params)
-                self.node_context.append(self.current_node)
-                self.current_result = self.current_node.run()
-                result = self.current_result.write_context(self.current_node, self)
-                if result is not None:
-                    list(result)
-                if not self.has_next_node(self.current_result):
+        """
+        非流式响应
+        @return: 结果
+        """
+        result = self.run_chain_async(None)
+        result.result()
         details = self.get_runtime_details()
         message_tokens = sum([row.get('message_tokens') for row in details.values() if
                               'message_tokens' in row and row.get('message_tokens') is not None])
@@ -207,95 +275,164 @@ class WorkflowManage:
                                             self)
         return self.base_to_response.to_block_response(self.params['chat_id'],
                                                        self.params['chat_record_id'], self.answer, True
-                                                       , message_tokens, answer_tokens)
-        except Exception as e:
-            traceback.print_exc()
-            self.current_node.get_write_error_context(e)
-            self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
-                                                self.answer,
-                                                self)
-            return self.base_to_response.to_block_response(self.params['chat_id'], self.params['chat_record_id'],
-                                                           str(e), True,
-                                                           0, 0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+                                                       , message_tokens, answer_tokens,
+                                                       _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR)

     def run_stream(self):
-        return tools.to_stream_response_simple(self.stream_event())
-
-    def stream_event(self):
-        try:
-            while self.has_next_node(self.current_result):
-                self.current_node = self.get_next_node()
-                self.node_context.append(self.current_node)
-                self.current_node.valid_args(self.current_node.node_params, self.current_node.workflow_params)
-                self.current_result = self.current_node.run()
-                result = self.current_result.write_context(self.current_node, self)
-                has_next_node = self.has_next_node(self.current_result)
-                if result is not None:
-                    if self.is_result():
-                        for r in result:
-                            yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
-                                                                                 self.params['chat_record_id'],
-                                                                                 r, False, 0, 0)
-                        if has_next_node:
-                            yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
-                                                                                 self.params['chat_record_id'],
-                                                                                 '\n', False, 0, 0)
-                            self.answer += '\n'
-                    else:
-                        list(result)
-                if not has_next_node:
-                    details = self.get_runtime_details()
-                    message_tokens = sum([row.get('message_tokens') for row in details.values() if
-                                          'message_tokens' in row and row.get('message_tokens') is not None])
-                    answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
-                                         'answer_tokens' in row and row.get('answer_tokens') is not None])
-                    yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
-                                                                         self.params['chat_record_id'],
-                                                                         '', True, message_tokens, answer_tokens)
-                    break
-            self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
-                                                self.answer,
-                                                self)
-        except Exception as e:
-            self.current_node.get_write_error_context(e)
-            self.answer += str(e)
-            self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
-                                                self.answer,
-                                                self)
-            yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'], self.params['chat_record_id'],
-                                                                 str(e), True, 0, 0)
-
-    def is_result(self):
-        """
-        判断是否是返回节点
-        @return:
-        """
-        return self.current_node.node_params.get('is_result', not self.has_next_node(
-            self.current_result)) if self.current_node.node_params is not None else False
+        """
+        流式响应
+        @return:
+        """
+        result = self.run_chain_async(None)
+        return tools.to_stream_response_simple(self.await_result(result))
+
+    def await_result(self, result):
+        try:
+            while await_result(result):
+                while True:
+                    chunk = self.node_chunk_manage.pop()
+                    if chunk is not None:
+                        yield chunk
+                    else:
+                        break
+            while True:
+                chunk = self.node_chunk_manage.pop()
+                if chunk is None:
+                    break
+                yield chunk
+        finally:
+            self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
+                                                self.answer,
+                                                self)
+
+    def run_chain_async(self, current_node):
+        future = executor.submit(self.run_chain, current_node)
+        return future
+
+    def run_chain(self, current_node):
+        if current_node is None:
+            start_node = self.get_start_node()
+            current_node = get_node(start_node.type)(start_node, self.params, self)
+        node_result_future = self.run_node_future(current_node)
+        try:
+            is_stream = self.params.get('stream', True)
+            # 处理节点响应
+            result = self.hand_event_node_result(current_node,
+                                                 node_result_future) if is_stream else self.hand_node_result(
+                current_node, node_result_future)
+            with self.lock:
+                if current_node.status == 500:
+                    return
+            node_list = self.get_next_node_list(current_node, result)
+            # 获取到可执行的子节点
+            result_list = []
+            for node in node_list:
+                result = self.run_chain_async(node)
+                result_list.append(result)
+            [r.result() for r in result_list]
+            if self.status == 0:
+                self.status = 200
+        except Exception as e:
+            traceback.print_exc()
+
+    def hand_node_result(self, current_node, node_result_future):
+        try:
+            current_result = node_result_future.result()
+            result = current_result.write_context(current_node, self)
+            if result is not None:
+                # 阻塞获取结果
+                list(result)
+            # 添加节点
+            self.node_context.append(current_node)
+            return current_result
+        except Exception as e:
+            # 添加节点
+            self.node_context.append(current_node)
+            traceback.print_exc()
+            self.status = 500
+            current_node.get_write_error_context(e)
+            self.answer += str(e)
+
+    def hand_event_node_result(self, current_node, node_result_future):
+        try:
+            current_result = node_result_future.result()
+            result = current_result.write_context(current_node, self)
+            if result is not None:
+                if self.is_result(current_node, current_result):
+                    node_chunk = NodeChunk()
+                    self.node_chunk_manage.add_node_chunk(node_chunk)
+                    for r in result:
+                        chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
+                                                                               self.params['chat_record_id'],
+                                                                               r, False, 0, 0)
+                        node_chunk.add_chunk(chunk)
+                    node_chunk.end()
+                else:
+                    list(result)
+            # 添加节点
+            self.node_context.append(current_node)
+            return current_result
+        except Exception as e:
+            # 添加节点
+            self.node_context.append(current_node)
+            traceback.print_exc()
+            self.status = 500
+            current_node.get_write_error_context(e)
+            self.answer += str(e)
+            chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
+                                                                   self.params['chat_record_id'],
+                                                                   str(e), False, 0, 0)
+            node_chunk = NodeChunk()
+            self.node_chunk_manage.add_node_chunk(node_chunk)
+            node_chunk.add_chunk(chunk)
+            node_chunk.end()
+
+    def run_node_async(self, node):
+        future = executor.submit(self.run_node, node)
+        return future
+
+    def run_node_future(self, node):
+        try:
+            node.valid_args(node.node_params, node.workflow_params)
+            result = self.run_node(node)
+            return NodeResultFuture(result, None, 200)
+        except Exception as e:
+            return NodeResultFuture(None, e, 500)
+
+    def run_node(self, node):
+        result = node.run()
+        result.write_context(node, self)
+        return result
+
+    def is_result(self, current_node, current_node_result):
+        return current_node.node_params.get('is_result', not self._has_next_node(
+            current_node, current_node_result)) if current_node.node_params is not None else False

     def get_chunk_content(self, chunk, is_end=False):
         return 'data: ' + json.dumps(
             {'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 'operate': True,
              'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n"

-    def has_next_node(self, node_result: NodeResult | None):
+    def _has_next_node(self, current_node, node_result: NodeResult | None):
         """
         是否有下一个可运行的节点
         """
-        if self.current_node is None:
-            if self.get_start_node() is not None:
-                return True
-        else:
         if node_result is not None and node_result.is_assertion_result():
             for edge in self.flow.edges:
-                if (edge.sourceNodeId == self.current_node.id and
+                if (edge.sourceNodeId == current_node.id and
                         f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
                     return True
         else:
             for edge in self.flow.edges:
-                if edge.sourceNodeId == self.current_node.id:
+                if edge.sourceNodeId == current_node.id:
                     return True
         return False

+    def has_next_node(self, node_result: NodeResult | None):
+        """
+        是否有下一个可运行的节点
+        """
+        return self._has_next_node(self.get_start_node() if self.current_node is None else self.current_node,
+                                   node_result)
+
     def get_runtime_details(self):
         details_result = {}
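Taken together, run_chain_async/run_chain implement the parallel traversal: each node runs on the shared executor, and when it finishes, every successor whose dependencies are satisfied is submitted as its own future, with the parent waiting on all child futures so the root future only resolves once the whole graph is done. A stripped-down sketch of that fan-out/fan-in pattern over a toy DAG (plain dicts and functions, not the MaxKB classes above):

```python
# Illustrative sketch of the fan-out/fan-in scheduling used by run_chain.
import threading
from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=8)

# Toy DAG: start fans out to a and b, both feed the join node.
successors = {"start": ["a", "b"], "a": ["join"], "b": ["join"], "join": []}
predecessors = {"start": [], "a": ["start"], "b": ["start"], "join": ["a", "b"]}

executed, lock = set(), threading.Lock()

def run_node(node_id):
    print("run", node_id)

def run_chain(node_id):
    run_node(node_id)
    with lock:
        executed.add(node_id)
        # A successor is ready only once all of its predecessors have executed
        # (this mirrors the dependent_node_been_executed check below).
        ready = [n for n in successors[node_id]
                 if all(p in executed for p in predecessors[n]) and n not in executed]
    futures = [executor.submit(run_chain, n) for n in ready]
    for f in futures:
        f.result()  # fan-in: wait for every child chain before returning

run_chain("start")  # "join" runs exactly once, after both "a" and "b"
```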
@@ -325,9 +462,37 @@ class WorkflowManage:
         return None

+    def dependent_node_been_executed(self, node_id):
+        """
+        判断依赖节点是否都已执行
+        @param node_id: 需要判断的节点id
+        @return:
+        """
+        up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id]
+        return all([any([node.id == up_node_id for node in self.node_context]) for up_node_id in up_node_id_list])
+
+    def get_next_node_list(self, current_node, current_node_result):
+        """
+        获取下一个可执行节点列表
+        @param current_node: 当前可执行节点
+        @param current_node_result: 当前可执行节点结果
+        @return: 可执行节点列表
+        """
+        node_list = []
+        if current_node_result is not None and current_node_result.is_assertion_result():
+            for edge in self.flow.edges:
+                if (edge.sourceNodeId == current_node.id and
+                        f"{edge.sourceNodeId}_{current_node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId):
+                    if self.dependent_node_been_executed(edge.targetNodeId):
+                        node_list.append(self.get_node_cls_by_id(edge.targetNodeId))
+        else:
+            for edge in self.flow.edges:
+                if edge.sourceNodeId == current_node.id and self.dependent_node_been_executed(edge.targetNodeId):
+                    node_list.append(self.get_node_cls_by_id(edge.targetNodeId))
+        return node_list
+
     def get_reference_field(self, node_id: str, fields: List[str]):
         """
         @param node_id: 节点id
         @param fields: 字段
         @return:
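dependent_node_been_executed is the join gate for converging branches: a node only becomes runnable once every upstream node that points at it appears in node_context (the list of executed nodes). The same check, isolated as a sketch over a plain edge list with hypothetical node names:

```python
# Sketch of the join-gating check behind get_next_node_list.
edges = [("a", "join"), ("b", "join")]  # hypothetical (source, target) pairs

def dependent_node_been_executed(node_id, executed_ids):
    up_node_ids = [s for s, t in edges if t == node_id]
    return all(up in executed_ids for up in up_node_ids)

print(dependent_node_been_executed("join", {"a"}))       # False: branch "b" has not run yet
print(dependent_node_been_executed("join", {"a", "b"}))  # True: both branches finished, join may run
```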

View File

@@ -73,7 +73,7 @@ class AppNode extends HtmlResize.view {
           lh('div', {
             style: { zindex: 0 },
             onClick: () => {
-              if (!isConnect && type == 'right') {
+              if (type == 'right') {
                 this.props.model.openNodeMenu(anchorData)
               }
             },
@@ -193,23 +193,34 @@ class AppNodeModel extends HtmlResize.model {
   get_width() {
     return this.properties?.width || 340
   }

   setAttributes() {
     this.width = this.get_width()
+    const isLoop=(node_id:string,target_node_id:string)=>{
+      const up_node_list=this.graphModel.getNodeIncomingNode(node_id)
+      for (const index in up_node_list) {
+        const item=up_node_list[index]
+        if(item.id===target_node_id){
+          return true
+        }else{
+          const result= isLoop(item.id,target_node_id)
+          if(result){
+            return true
+          }
+        }
+      }
+      return false
+    }
     const circleOnlyAsTarget = {
       message: '只允许从右边的锚点连出',
       validate: (sourceNode: any, targetNode: any, sourceAnchor: any) => {
         return sourceAnchor.type === 'right'
       }
     }

     this.sourceRules.push({
-      message: '只允许连一个节点',
+      message: '不可循环连线',
       validate: (sourceNode: any, targetNode: any, sourceAnchor: any, targetAnchor: any) => {
-        return !this.graphModel.edges.some(
-          (item) =>
-            item.sourceAnchorId === sourceAnchor.id || item.targetAnchorId === targetAnchor.id
-        )
+        return !isLoop(sourceNode.id,targetNode.id)
       }
     })
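On the front end, the "only one connection per anchor" rule is replaced by a cycle check: isLoop walks upstream from the source node through its incoming nodes and rejects the connection if it ever reaches the target, since adding that edge would close a loop. The same upward-walk idea, sketched in Python over a plain adjacency map (hypothetical helper and data; the committed code uses LogicFlow's getNodeIncomingNode):

```python
# incoming[n] lists the nodes that already have an edge into n.
incoming = {"a": [], "b": ["a"], "c": ["b"]}  # existing chain: a -> b -> c

def is_loop(node_id: str, target_node_id: str) -> bool:
    """True if target_node_id is reachable by walking upstream from node_id."""
    for up in incoming.get(node_id, []):
        if up == target_node_id or is_loop(up, target_node_id):
            return True
    return False

# Connecting c -> a would close the cycle a -> b -> c -> a, so the rule rejects it.
print(is_loop("c", "a"))  # True  -> connection refused
print(is_loop("a", "c"))  # False -> connection allowed
```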

View File

@@ -129,16 +129,12 @@ export class WorkFlowInstance {
         const edge_list = this.edges.filter((edge) => edge.sourceAnchorId == source_anchor_id)
         if (edge_list.length == 0) {
           throw `${node.properties.stepName} 节点的${branch.type}分支需要连接`
-        } else if (edge_list.length > 1) {
-          throw `${node.properties.stepName} 节点的${branch.type}分支不能连接俩个节点`
         }
       }
     } else {
       const edge_list = this.edges.filter((edge) => edge.sourceNodeId == node.id)
       if (edge_list.length == 0 && !end_nodes.includes(node.type)) {
         throw `${node.properties.stepName} 节点不能当做结束节点`
-      } else if (edge_list.length > 1) {
-        throw `${node.properties.stepName} 节点不能连接俩个节点`
       }
     }
     if (node.properties.status && node.properties.status !== 200) {