99 lines
3.2 KiB
Python
99 lines
3.2 KiB
Python
import logging
|
|
from functools import wraps
|
|
from typing import Any
|
|
|
|
try:
|
|
from mcp import ClientSession, types
|
|
except ImportError:
|
|
from mcp.client.session import ClientSession
|
|
from mcp import types
|
|
|
|
from utils.log_util.context import g
|
|
|
|
logger = logging.getLogger("app")

# Marker attribute set on the patched ClientSession.call_tool so that
# patch_mcp_client_session_trace_meta() does not wrap it twice.
_PATCHED_ATTR = "_catalog_trace_meta_patched"

# Only calls to these MCP tools get the catalog trace id injected into
# tools/call params._meta. A frozenset keeps this module-level constant
# immutable; membership tests behave exactly as with a set.
_TRACE_META_TOOL_NAMES = frozenset({"rag_retrieve", "table_rag_retrieve"})
|
|
|
|
|
|
def _get_trace_id() -> str:
    """Return the trace id stored on the context object ``g`` as a string.

    Returns "" when ``g`` has no ``trace_id`` attribute, when the value is
    falsy, or when reading the attribute raises a ``LookupError``.
    """
    try:
        raw_value = getattr(g, "trace_id", "")
    except LookupError:
        # KeyError subclasses LookupError, so this single clause covers both
        # exception types the context object may raise on attribute access.
        return ""
    if not raw_value:
        return ""
    return str(raw_value)
|
|
|
|
|
|
def _get_tool_name(args: tuple[Any, ...], kwargs: dict[str, Any]) -> str:
    """Extract the tool name from ``call_tool``-style arguments.

    The first positional argument wins whenever any positional arguments are
    present; otherwise the ``name`` keyword is used. A missing or falsy name
    yields "", anything else is converted with ``str``.
    """
    if args:
        candidate = args[0]
    else:
        candidate = kwargs.get("name")
    if not candidate:
        return ""
    return str(candidate)
|
|
|
|
|
|
def patch_mcp_client_session_trace_meta() -> None:
    """Attach catalog trace id to MCP tools/call params._meta.

    Replaces ``ClientSession.call_tool`` with a wrapper that, for tools named
    in ``_TRACE_META_TOOL_NAMES``, puts the current trace id into the ``meta``
    keyword argument. A ``trace_id`` the caller already placed in a ``meta``
    dict is kept. A non-dict ``meta`` is replaced by ``{"trace_id": ...}``.

    If the installed SDK's ``call_tool`` rejects the ``meta`` keyword, the
    request is sent instead through ``_call_tool_with_meta_compat``.

    Calling this more than once is safe: the wrapper carries ``_PATCHED_ATTR``
    and a second call returns without wrapping again.
    """
    if getattr(ClientSession.call_tool, _PATCHED_ATTR, False):
        return

    original_call_tool = ClientSession.call_tool

    @wraps(original_call_tool)
    async def call_tool_with_trace_meta(self: ClientSession, *args: Any, **kwargs: Any) -> Any:
        tool_name = _get_tool_name(args, kwargs)
        trace_id = _get_trace_id() if tool_name in _TRACE_META_TOOL_NAMES else ""
        if trace_id:
            meta = kwargs.get("meta")
            if isinstance(meta, dict):
                # Copy rather than mutate the caller's dict, and never
                # overwrite a trace id the caller set explicitly.
                meta = {**meta, "trace_id": meta.get("trace_id") or trace_id}
            else:
                meta = {"trace_id": trace_id}
            kwargs["meta"] = meta

        try:
            return await original_call_tool(self, *args, **kwargs)
        except TypeError as exc:
            # Fall back only when the SDK rejected the ``meta`` keyword itself.
            # Matching any "unexpected keyword argument" message would also
            # catch unrelated TypeErrors raised after the request was sent,
            # re-issuing tools/call (running the tool twice) and hiding the
            # real error.
            message = str(exc)
            rejected_meta = "unexpected keyword argument" in message and "'meta'" in message
            if trace_id and "meta" in kwargs and rejected_meta:
                return await _call_tool_with_meta_compat(self, *args, **kwargs)
            raise

    setattr(call_tool_with_trace_meta, _PATCHED_ATTR, True)
    ClientSession.call_tool = call_tool_with_trace_meta
|
|
|
|
|
|
async def _call_tool_with_meta_compat(self: ClientSession, *args: Any, **kwargs: Any) -> Any:
    """Call tools/call with _meta for MCP SDK versions before call_tool(meta=...).

    Rebuilds the ``tools/call`` request by hand and sends it via
    ``self.send_request``, so that the ``meta`` dict lands in
    ``CallToolRequestParams._meta``.

    Positional arguments follow ``call_tool`` order: ``name``, ``arguments``,
    ``read_timeout_seconds``, ``progress_callback``. Each may instead be given
    by keyword, and ``args`` is accepted as an alias for ``arguments``. A
    ``meta`` value that is not a dict is dropped.

    Raises:
        TypeError: if no (truthy) tool name was supplied.
    """
    name = _get_tool_name(args, kwargs)
    if not name:
        raise TypeError("call_tool() missing required argument: 'name'")

    arguments = args[1] if len(args) > 1 else kwargs.get("arguments", kwargs.get("args"))
    read_timeout_seconds = args[2] if len(args) > 2 else kwargs.get("read_timeout_seconds")
    progress_callback = args[3] if len(args) > 3 else kwargs.get("progress_callback")
    meta = kwargs.get("meta")
    request_meta = meta if isinstance(meta, dict) else None

    # Forward the optional send_request keywords only when they were supplied.
    # This path exists for older SDK releases, whose send_request may not
    # accept these keywords at all. Passing an explicit None would then raise
    # TypeError, while omitting it is equivalent on releases that accept them
    # (both default to None).
    send_options: dict[str, Any] = {}
    if read_timeout_seconds is not None:
        send_options["request_read_timeout_seconds"] = read_timeout_seconds
    if progress_callback is not None:
        send_options["progress_callback"] = progress_callback

    result = await self.send_request(
        types.ClientRequest(
            types.CallToolRequest(
                method="tools/call",
                params=types.CallToolRequestParams(
                    name=name,
                    arguments=arguments,
                    _meta=request_meta,
                ),
            )
        ),
        types.CallToolResult,
        **send_options,
    )

    # Run the session's own tool-result validation when this SDK version
    # provides it; error results are returned without validation.
    validate_tool_result = getattr(self, "_validate_tool_result", None)
    if validate_tool_result and not result.isError:
        await validate_tool_result(name, result)

    return result
|