feat: add MCP tool ID and source fields to chat node for enhanced configuration
This commit is contained in:
parent
1875368ea8
commit
e9c8c9581f
@ -31,9 +31,14 @@ class ChatNodeSerializer(serializers.Serializer):
|
|||||||
label='Model settings')
|
label='Model settings')
|
||||||
dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
dialogue_type = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||||
label=_("Context Type"))
|
label=_("Context Type"))
|
||||||
mcp_enable = serializers.BooleanField(required=False,
|
mcp_enable = serializers.BooleanField(required=False, label=_("Whether to enable MCP"))
|
||||||
label=_("Whether to enable MCP"))
|
|
||||||
mcp_servers = serializers.JSONField(required=False, label=_("MCP Server"))
|
mcp_servers = serializers.JSONField(required=False, label=_("MCP Server"))
|
||||||
|
mcp_tool_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Tool ID"))
|
||||||
|
mcp_source = serializers.CharField(required=False, allow_blank=True, allow_null=True, label=_("MCP Source"))
|
||||||
|
|
||||||
|
tool_enable = serializers.BooleanField(required=False, default=False, label=_("Whether to enable tools"))
|
||||||
|
tool_ids = serializers.ListField(child=serializers.UUIDField(), required=False, allow_empty=True,
|
||||||
|
label=_("Tool IDs"), )
|
||||||
|
|
||||||
|
|
||||||
class IChatNode(INode):
|
class IChatNode(INode):
|
||||||
@ -52,5 +57,9 @@ class IChatNode(INode):
|
|||||||
model_setting=None,
|
model_setting=None,
|
||||||
mcp_enable=False,
|
mcp_enable=False,
|
||||||
mcp_servers=None,
|
mcp_servers=None,
|
||||||
|
mcp_tool_id=None,
|
||||||
|
mcp_source=None,
|
||||||
|
tool_enable=False,
|
||||||
|
tool_ids=None,
|
||||||
**kwargs) -> NodeResult:
|
**kwargs) -> NodeResult:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -8,7 +8,7 @@
|
|||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
@ -23,9 +23,11 @@ from langgraph.prebuilt import create_react_agent
|
|||||||
from application.flow.i_step_node import NodeResult, INode
|
from application.flow.i_step_node import NodeResult, INode
|
||||||
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
|
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
|
||||||
from application.flow.tools import Reasoning
|
from application.flow.tools import Reasoning
|
||||||
|
from common.utils.logger import maxkb_logger
|
||||||
|
from common.utils.tool_code import ToolExecutor
|
||||||
from models_provider.models import Model
|
from models_provider.models import Model
|
||||||
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
|
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
|
||||||
from common.utils.logger import maxkb_logger
|
from tools.models import Tool
|
||||||
|
|
||||||
tool_message_template = """
|
tool_message_template = """
|
||||||
<details>
|
<details>
|
||||||
@ -211,6 +213,10 @@ class BaseChatNode(IChatNode):
|
|||||||
model_setting=None,
|
model_setting=None,
|
||||||
mcp_enable=False,
|
mcp_enable=False,
|
||||||
mcp_servers=None,
|
mcp_servers=None,
|
||||||
|
mcp_tool_id=None,
|
||||||
|
mcp_source=None,
|
||||||
|
tool_enable=False,
|
||||||
|
tool_ids=None,
|
||||||
**kwargs) -> NodeResult:
|
**kwargs) -> NodeResult:
|
||||||
if dialogue_type is None:
|
if dialogue_type is None:
|
||||||
dialogue_type = 'WORKFLOW'
|
dialogue_type = 'WORKFLOW'
|
||||||
@ -234,12 +240,13 @@ class BaseChatNode(IChatNode):
|
|||||||
message_list = self.generate_message_list(system, prompt, history_message)
|
message_list = self.generate_message_list(system, prompt, history_message)
|
||||||
self.context['message_list'] = message_list
|
self.context['message_list'] = message_list
|
||||||
|
|
||||||
if mcp_enable and mcp_servers is not None and '"stdio"' not in mcp_servers:
|
# 处理 MCP 请求
|
||||||
r = mcp_response_generator(chat_model, message_list, mcp_servers)
|
mcp_result = self._handle_mcp_request(
|
||||||
return NodeResult(
|
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, tool_ids, chat_model, message_list,
|
||||||
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
history_message, question
|
||||||
'history_message': history_message, 'question': question.content}, {},
|
)
|
||||||
_write_context=write_context_stream)
|
if mcp_result:
|
||||||
|
return mcp_result
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
r = chat_model.stream(message_list)
|
r = chat_model.stream(message_list)
|
||||||
@ -252,6 +259,48 @@ class BaseChatNode(IChatNode):
|
|||||||
'history_message': history_message, 'question': question.content}, {},
|
'history_message': history_message, 'question': question.content}, {},
|
||||||
_write_context=write_context)
|
_write_context=write_context)
|
||||||
|
|
||||||
|
def _handle_mcp_request(self, mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_id, tool_ids,
|
||||||
|
chat_model, message_list, history_message, question):
|
||||||
|
if not mcp_enable and not tool_enable:
|
||||||
|
return None
|
||||||
|
|
||||||
|
mcp_servers_config = {}
|
||||||
|
|
||||||
|
if mcp_enable:
|
||||||
|
if mcp_source == 'custom' and mcp_servers is not None and '"stdio"' not in mcp_servers:
|
||||||
|
mcp_servers_config = json.loads(mcp_servers)
|
||||||
|
elif mcp_tool_id:
|
||||||
|
mcp_tool = QuerySet(Tool).filter(id=mcp_tool_id).first()
|
||||||
|
if mcp_tool:
|
||||||
|
mcp_servers_config = json.loads(mcp_tool.code)
|
||||||
|
|
||||||
|
if tool_enable:
|
||||||
|
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
|
||||||
|
self.context['tool_ids'] = tool_ids
|
||||||
|
for tool_id in tool_ids:
|
||||||
|
tool = QuerySet(Tool).filter(id=tool_id).first()
|
||||||
|
executor = ToolExecutor()
|
||||||
|
code = executor.generate_mcp_server_code(tool.code)
|
||||||
|
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
|
||||||
|
with open(code_path, 'w') as f:
|
||||||
|
f.write(code)
|
||||||
|
|
||||||
|
tool_config = {
|
||||||
|
'command': 'python',
|
||||||
|
'args': [code_path],
|
||||||
|
'transport': 'stdio',
|
||||||
|
}
|
||||||
|
mcp_servers_config[str(tool.id)] = tool_config
|
||||||
|
|
||||||
|
if len(mcp_servers_config) > 0:
|
||||||
|
r = mcp_response_generator(chat_model, message_list, json.dumps(mcp_servers_config))
|
||||||
|
return NodeResult(
|
||||||
|
{'result': r, 'chat_model': chat_model, 'message_list': message_list,
|
||||||
|
'history_message': history_message, 'question': question.content}, {},
|
||||||
|
_write_context=write_context_stream)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
|
def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
|
||||||
start_index = len(history_chat_record) - dialogue_number
|
start_index = len(history_chat_record) - dialogue_number
|
||||||
@ -284,6 +333,14 @@ class BaseChatNode(IChatNode):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def get_details(self, index: int, **kwargs):
|
def get_details(self, index: int, **kwargs):
|
||||||
|
# 删除临时生成的MCP代码文件
|
||||||
|
if self.context.get('tool_ids'):
|
||||||
|
executor = ToolExecutor()
|
||||||
|
# 清理工具代码文件,延时删除,避免文件被占用
|
||||||
|
for tool_id in self.context.get('tool_ids'):
|
||||||
|
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
|
||||||
|
if os.path.exists(code_path):
|
||||||
|
os.remove(code_path)
|
||||||
return {
|
return {
|
||||||
'name': self.node.properties.get('stepName'),
|
'name': self.node.properties.get('stepName'),
|
||||||
"index": index,
|
"index": index,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user