qwen_agent/agent/custom_filesystem_middleware.py
2026-03-02 12:55:39 +08:00

180 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
自定义 FilesystemMiddleware - 支持 SKILL.md 文件完整读取
"""
from pathlib import Path
from typing import Annotated, Literal, cast
from langchain.tools import ToolRuntime
from langchain_core.messages import ToolMessage
from langchain_core.tools import BaseTool, StructuredTool
from typing_extensions import override
from deepagents.backends import StateBackend
from deepagents.backends.composite import CompositeBackend
from deepagents.backends.protocol import (
BACKEND_TYPES,
)
from deepagents.middleware.filesystem import (
DEFAULT_READ_OFFSET,
DEFAULT_READ_LIMIT,
IMAGE_EXTENSIONS,
IMAGE_MEDIA_TYPES,
FilesystemMiddleware,
FilesystemState,
)
from langgraph.types import Command
import base64
from langchain_core.messages.content import create_image_block
# Maximum number of lines read from a SKILL.md file. Chosen far larger than any
# realistic skill definition so the whole file comes back in one read_file call.
SKILL_MD_READ_LIMIT = 100_000
class CustomFilesystemMiddleware(FilesystemMiddleware):
    """FilesystemMiddleware whose ``read_file`` tool reads SKILL.md files in full.

    Subclasses ``deepagents.middleware.filesystem.FilesystemMiddleware`` and
    overrides the ``read_file`` tool factory. For a file whose name is exactly
    ``SKILL.md``, the caller-supplied line ``limit`` is replaced by
    ``SKILL_MD_READ_LIMIT`` so the whole file can be returned in one call.
    Images and all other files are handled as in the upstream tool. The
    token-based truncation (``_tool_token_limit_before_evict``) still applies
    to every text result, including SKILL.md.
    """

    @override
    def _create_read_file_tool(self) -> BaseTool:
        """Build the ``read_file`` tool with full-length reads for SKILL.md files.

        Returns:
            A ``StructuredTool`` named ``read_file`` that has both a sync
            (``func``) and an async (``coroutine``) implementation. Both share
            the same image handling and text post-processing.
        """
        # Imported once per tool creation instead of on every tool call.
        # NOTE(review): these come from deepagents internals; confirm they still
        # exist when upgrading deepagents.
        from deepagents.backends.utils import validate_path
        from deepagents.middleware.filesystem import NUM_CHARS_PER_TOKEN, READ_FILE_TRUNCATION_MSG

        # A user-supplied description wins; otherwise fall back to the upstream text.
        tool_description = self._custom_tool_descriptions.get("read_file") or self._get_read_file_description()
        token_limit = self._tool_token_limit_before_evict

        def effective_limit(path: str, limit: int) -> int:
            """Return SKILL_MD_READ_LIMIT for files named exactly SKILL.md, else *limit*."""
            # Compare the basename exactly: a plain suffix test would also match
            # e.g. "MYSKILL.md".
            return SKILL_MD_READ_LIMIT if Path(path).name == "SKILL.md" else limit

        def image_result(path: str, ext: str, responses: list, tool_call_id: str) -> ToolMessage | str:
            """Convert a ``download_files`` response for an image into a ToolMessage or an error string."""
            if responses and responses[0].content is not None:
                media_type = IMAGE_MEDIA_TYPES.get(ext, "image/png")
                image_b64 = base64.standard_b64encode(responses[0].content).decode("utf-8")
                return ToolMessage(
                    content_blocks=[create_image_block(base64=image_b64, mime_type=media_type)],
                    name="read_file",
                    tool_call_id=tool_call_id,
                    additional_kwargs={
                        "read_file_path": path,
                        "read_file_media_type": media_type,
                    },
                )
            if responses and responses[0].error:
                return f"Error reading image: {responses[0].error}"
            return "Error reading image: unknown error"

        def finalize_text(path: str, result: str, limit: int) -> str:
            """Clip *result* to at most *limit* lines, then to the configured token budget."""
            lines = result.splitlines(keepends=True)
            if len(lines) > limit:
                result = "".join(lines[:limit])
            # Character-based approximation of the token limit, as in the upstream tool.
            if token_limit and len(result) >= NUM_CHARS_PER_TOKEN * token_limit:
                truncation_msg = READ_FILE_TRUNCATION_MSG.format(file_path=path)
                # Reserve room for the message so the final text stays under the
                # threshold. Clamp at 0, because a negative slice bound would keep
                # most of the text instead of dropping it.
                max_content_length = max(0, NUM_CHARS_PER_TOKEN * token_limit - len(truncation_msg))
                result = result[:max_content_length] + truncation_msg
            return result

        def sync_read_file(
            file_path: Annotated[str, "Absolute path to the file to read. Must be absolute, not relative."],
            runtime: ToolRuntime[None, FilesystemState],
            offset: Annotated[int, "Line number to start reading from (0-indexed). Use for pagination of large files."] = DEFAULT_READ_OFFSET,
            limit: Annotated[int, "Maximum number of lines to read. Use for pagination of large files."] = DEFAULT_READ_LIMIT,
        ) -> ToolMessage | str:
            """Synchronous wrapper for read_file tool with SKILL.md special handling."""
            resolved_backend = self._get_backend(runtime)
            try:
                validated_path = validate_path(file_path)
            except ValueError as e:
                return f"Error: {e}"
            ext = Path(validated_path).suffix.lower()
            if ext in IMAGE_EXTENSIONS:
                responses = resolved_backend.download_files([validated_path])
                return image_result(validated_path, ext, responses, runtime.tool_call_id)
            actual_limit = effective_limit(validated_path, limit)
            result = resolved_backend.read(validated_path, offset=offset, limit=actual_limit)
            return finalize_text(validated_path, result, actual_limit)

        async def async_read_file(
            file_path: Annotated[str, "Absolute path to the file to read. Must be absolute, not relative."],
            runtime: ToolRuntime[None, FilesystemState],
            offset: Annotated[int, "Line number to start reading from (0-indexed). Use for pagination of large files."] = DEFAULT_READ_OFFSET,
            limit: Annotated[int, "Maximum number of lines to read. Use for pagination of large files."] = DEFAULT_READ_LIMIT,
        ) -> ToolMessage | str:
            """Asynchronous wrapper for read_file tool with SKILL.md special handling."""
            resolved_backend = self._get_backend(runtime)
            try:
                validated_path = validate_path(file_path)
            except ValueError as e:
                return f"Error: {e}"
            ext = Path(validated_path).suffix.lower()
            if ext in IMAGE_EXTENSIONS:
                responses = await resolved_backend.adownload_files([validated_path])
                return image_result(validated_path, ext, responses, runtime.tool_call_id)
            actual_limit = effective_limit(validated_path, limit)
            result = await resolved_backend.aread(validated_path, offset=offset, limit=actual_limit)
            return finalize_text(validated_path, result, actual_limit)

        return StructuredTool.from_function(
            name="read_file",
            description=tool_description,
            func=sync_read_file,
            coroutine=async_read_file,
        )

    def _get_read_file_description(self) -> str:
        """Return the fallback description for the ``read_file`` tool.

        This is the upstream ``READ_FILE_TOOL_DESCRIPTION``, returned unchanged.
        No SKILL.md-specific note is appended.
        """
        from deepagents.middleware.filesystem import READ_FILE_TOOL_DESCRIPTION

        return READ_FILE_TOOL_DESCRIPTION