qwen_agent/agent/custom_filesystem_middleware.py
2026-04-11 11:40:43 +08:00

162 lines
6.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Custom FilesystemMiddleware whose read_file tool reads SKILL.md files in full.
"""
from pathlib import Path
from typing import Annotated, cast
import mimetypes
import warnings
from langchain.tools import ToolRuntime
from langchain_core.messages import ToolMessage
from langchain_core.tools import BaseTool, StructuredTool
from typing_extensions import override
from deepagents.backends import StateBackend
from deepagents.backends.composite import CompositeBackend
from deepagents.backends.protocol import (
BACKEND_TYPES,
ReadResult,
)
from deepagents.backends.utils import _get_file_type, validate_path
from langchain_core.messages.content import ContentBlock
from deepagents.middleware.filesystem import (
DEFAULT_READ_OFFSET,
DEFAULT_READ_LIMIT,
FilesystemMiddleware,
FilesystemState,
READ_FILE_TOOL_DESCRIPTION,
READ_FILE_TRUNCATION_MSG,
NUM_CHARS_PER_TOKEN,
ReadFileSchema,
check_empty_content,
format_content_with_line_numbers,
)
from langgraph.types import Command
# Maximum number of lines read from a SKILL.md file. Deliberately large so the
# whole file is returned in one call instead of being paginated.
SKILL_MD_READ_LIMIT = 100_000
class CustomFilesystemMiddleware(FilesystemMiddleware):
    """FilesystemMiddleware whose ``read_file`` tool reads SKILL.md files in full.

    Extends ``deepagents.middleware.filesystem.FilesystemMiddleware`` and overrides
    the ``read_file`` tool. When the requested file's name is exactly ``SKILL.md``,
    the caller-supplied ``limit`` is replaced by ``SKILL_MD_READ_LIMIT`` so the
    whole file (up to that many lines) is returned. Other files behave like the
    stock tool.

    The token-based truncation driven by ``_tool_token_limit_before_evict``
    still applies to every file, including SKILL.md.
    """

    # Appended to the stock tool description so the model knows why `limit`
    # has no effect on SKILL.md files.
    _SKILL_MD_DESCRIPTION_NOTE = (
        "\n\nNote: files named SKILL.md are always read in full; "
        "the `limit` argument is ignored for them."
    )

    @staticmethod
    def _is_skill_md(path: str) -> bool:
        """Return True when the final component of *path* is exactly ``SKILL.md``.

        A plain ``endswith("SKILL.md")`` would also match names such as
        ``MYSKILL.md``. Comparing the basename avoids those false positives.
        ``validate_path`` output is assumed to use ``/`` as its separator.
        """
        return path.rsplit("/", 1)[-1] == "SKILL.md"

    @override
    def _create_read_file_tool(self) -> BaseTool:
        """Build the ``read_file`` tool with SKILL.md full-read handling.

        Returns:
            A ``StructuredTool`` named ``read_file`` that exposes both a sync and
            an async implementation and uses ``ReadFileSchema`` for its arguments.
            A custom ``read_file`` description takes precedence. Otherwise the
            default description plus the SKILL.md note is used.
        """
        tool_description = (
            self._custom_tool_descriptions.get("read_file")
            or self._get_read_file_description()
        )
        token_limit = self._tool_token_limit_before_evict

        def _truncate(content: str, file_path: str, limit: int) -> str:
            """Cap *content* at *limit* lines, then at the configured token budget."""
            lines = content.splitlines(keepends=True)
            if len(lines) > limit:
                content = "".join(lines[:limit])
            if token_limit and len(content) >= NUM_CHARS_PER_TOKEN * token_limit:
                truncation_msg = READ_FILE_TRUNCATION_MSG.format(file_path=file_path)
                # Clamp at 0: if the message alone exceeds the budget, a negative
                # slice bound would keep almost all of the content instead of none.
                max_content_length = max(
                    0, NUM_CHARS_PER_TOKEN * token_limit - len(truncation_msg)
                )
                content = content[:max_content_length] + truncation_msg
            return content

        def _handle_read_result(
            read_result: ReadResult | str,
            validated_path: str,
            tool_call_id: str | None,
            offset: int,
            limit: int,
        ) -> ToolMessage | str:
            """Convert a backend read result into the tool's response.

            - Error results are returned as ``"Error: ..."`` strings, not raised.
            - Non-text files (per ``_get_file_type``) are returned as a
              ``ToolMessage``, with the stored content passed through as the
              ``base64`` field of a content block.
            - Text is numbered starting at ``offset + 1`` and then truncated.
            """
            if isinstance(read_result, str):
                # Legacy backends returned raw strings; still accepted, but warn.
                warnings.warn(
                    "Returning a plain `str` from `backend.read()` is deprecated. ",
                    DeprecationWarning,
                    stacklevel=2,
                )
                return _truncate(read_result, validated_path, limit)
            if read_result.error:
                return f"Error: {read_result.error}"
            if read_result.file_data is None:
                return f"Error: no data returned for '{validated_path}'"

            file_type = _get_file_type(validated_path)
            content = read_result.file_data["content"]
            if file_type != "text":
                # Guess from the suffix only; a dummy stem keeps the lookup path-independent.
                mime_type = (
                    mimetypes.guess_type("file" + Path(validated_path).suffix)[0]
                    or "application/octet-stream"
                )
                return ToolMessage(
                    content_blocks=cast(
                        "list[ContentBlock]",
                        [{"type": file_type, "base64": content, "mime_type": mime_type}],
                    ),
                    name="read_file",
                    tool_call_id=tool_call_id,
                    additional_kwargs={
                        "read_file_path": validated_path,
                        "read_file_media_type": mime_type,
                    },
                )

            empty_msg = check_empty_content(content)
            if empty_msg:
                return empty_msg
            content = format_content_with_line_numbers(content, start_line=offset + 1)
            return _truncate(content, validated_path, limit)

        def _resolve_path_and_limit(file_path: str, limit: int) -> tuple[str, int]:
            """Validate *file_path* and return it together with the effective line limit.

            Raises:
                ValueError: propagated from ``validate_path`` for rejected paths.
            """
            validated_path = validate_path(file_path)
            if self._is_skill_md(validated_path):
                # SKILL.md must be read in full, so the caller's limit is overridden.
                limit = SKILL_MD_READ_LIMIT
            return validated_path, limit

        def sync_read_file(
            file_path: Annotated[str, "Absolute path to the file to read. Must be absolute, not relative."],
            runtime: ToolRuntime[None, FilesystemState],
            offset: Annotated[int, "Line number to start reading from (0-indexed). Use for pagination of large files."] = DEFAULT_READ_OFFSET,
            limit: Annotated[int, "Maximum number of lines to read. Use for pagination of large files."] = DEFAULT_READ_LIMIT,
        ) -> ToolMessage | str:
            """Synchronous wrapper for read_file tool with SKILL.md special handling."""
            resolved_backend = self._get_backend(runtime)
            try:
                validated_path, limit = _resolve_path_and_limit(file_path, limit)
            except ValueError as e:
                return f"Error: {e}"
            read_result = resolved_backend.read(validated_path, offset=offset, limit=limit)
            return _handle_read_result(read_result, validated_path, runtime.tool_call_id, offset, limit)

        async def async_read_file(
            file_path: Annotated[str, "Absolute path to the file to read. Must be absolute, not relative."],
            runtime: ToolRuntime[None, FilesystemState],
            offset: Annotated[int, "Line number to start reading from (0-indexed). Use for pagination of large files."] = DEFAULT_READ_OFFSET,
            limit: Annotated[int, "Maximum number of lines to read. Use for pagination of large files."] = DEFAULT_READ_LIMIT,
        ) -> ToolMessage | str:
            """Asynchronous wrapper for read_file tool with SKILL.md special handling."""
            resolved_backend = self._get_backend(runtime)
            try:
                validated_path, limit = _resolve_path_and_limit(file_path, limit)
            except ValueError as e:
                return f"Error: {e}"
            read_result = await resolved_backend.aread(validated_path, offset=offset, limit=limit)
            return _handle_read_result(read_result, validated_path, runtime.tool_call_id, offset, limit)

        return StructuredTool.from_function(
            name="read_file",
            description=tool_description,
            func=sync_read_file,
            coroutine=async_read_file,
            infer_schema=False,
            args_schema=ReadFileSchema,
        )

    def _get_read_file_description(self) -> str:
        """Return the default ``read_file`` description plus a note on SKILL.md full reads.

        Used as the tool description when no custom ``read_file`` description
        has been configured.
        """
        return READ_FILE_TOOL_DESCRIPTION + self._SKILL_MD_DESCRIPTION_NOTE