162 lines
6.7 KiB
Python
162 lines
6.7 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
自定义 FilesystemMiddleware - 支持 SKILL.md 文件完整读取
|
||
"""
|
||
|
||
from pathlib import Path
|
||
from typing import Annotated, cast
|
||
import mimetypes
|
||
import warnings
|
||
|
||
from langchain.tools import ToolRuntime
|
||
from langchain_core.messages import ToolMessage
|
||
from langchain_core.tools import BaseTool, StructuredTool
|
||
from typing_extensions import override
|
||
|
||
from deepagents.backends import StateBackend
|
||
from deepagents.backends.composite import CompositeBackend
|
||
from deepagents.backends.protocol import (
|
||
BACKEND_TYPES,
|
||
ReadResult,
|
||
)
|
||
from deepagents.backends.utils import _get_file_type, validate_path
|
||
from langchain_core.messages.content import ContentBlock
|
||
from deepagents.middleware.filesystem import (
|
||
DEFAULT_READ_OFFSET,
|
||
DEFAULT_READ_LIMIT,
|
||
FilesystemMiddleware,
|
||
FilesystemState,
|
||
READ_FILE_TOOL_DESCRIPTION,
|
||
READ_FILE_TRUNCATION_MSG,
|
||
NUM_CHARS_PER_TOKEN,
|
||
ReadFileSchema,
|
||
check_empty_content,
|
||
format_content_with_line_numbers,
|
||
)
|
||
|
||
from langgraph.types import Command
|
||
|
||
|
||
# Line limit used when the read_file tool reads a SKILL.md file. It is set to a
# large value so that the whole file is returned by a single read.
SKILL_MD_READ_LIMIT = 100_000
|
||
|
||
|
||
class CustomFilesystemMiddleware(FilesystemMiddleware):
    """FilesystemMiddleware whose ``read_file`` tool returns SKILL.md files in full.

    Inherits from ``deepagents.middleware.filesystem.FilesystemMiddleware`` and
    overrides the ``read_file`` tool:
    - For ordinary files, the caller's ``offset``/``limit`` are used as given.
    - For a file whose name is exactly ``SKILL.md``, the line limit is raised to
      at least ``SKILL_MD_READ_LIMIT``, so the skill file is read completely.
    """

    @override
    def _create_read_file_tool(self) -> BaseTool:
        """Create the ``read_file`` tool with SKILL.md full-read support.

        Returns:
            A ``StructuredTool`` named ``read_file`` with both a sync (``func``)
            and an async (``coroutine``) implementation, validated against
            ``ReadFileSchema``. Its description is the custom ``read_file``
            description if one is configured, otherwise
            ``self._get_read_file_description()``.
        """
        tool_description = (
            self._custom_tool_descriptions.get("read_file")
            or self._get_read_file_description()
        )
        # Output character budget is NUM_CHARS_PER_TOKEN * token_limit; a falsy
        # value disables the character-based truncation below.
        token_limit = self._tool_token_limit_before_evict

        def _is_skill_md(path: str) -> bool:
            # Compare the basename exactly: a bare ``endswith("SKILL.md")`` would
            # also match files such as ``MYSKILL.md`` or ``notes_SKILL.md``.
            return Path(path).name == "SKILL.md"

        def _effective_limit(validated_path: str, limit: int) -> int:
            # SKILL.md files are read in full: raise the line limit, but never
            # shrink a larger limit that the caller requested explicitly.
            if _is_skill_md(validated_path):
                return max(limit, SKILL_MD_READ_LIMIT)
            return limit

        def _truncate(content: str, file_path: str, limit: int) -> str:
            """Cap ``content`` at ``limit`` lines, then at the token-derived character budget.

            When the character budget applies, ``READ_FILE_TRUNCATION_MSG`` (formatted
            with ``file_path``) is appended and the total length stays within the budget.
            """
            lines = content.splitlines(keepends=True)
            if len(lines) > limit:
                content = "".join(lines[:limit])

            if token_limit:
                max_chars = NUM_CHARS_PER_TOKEN * token_limit
                if len(content) >= max_chars:
                    truncation_msg = READ_FILE_TRUNCATION_MSG.format(file_path=file_path)
                    # Clamp at 0: with a tiny token limit a negative slice index
                    # would keep almost all of the content instead of cutting it.
                    keep = max(0, max_chars - len(truncation_msg))
                    content = content[:keep] + truncation_msg

            return content

        def _handle_read_result(
            read_result: ReadResult | str,
            validated_path: str,
            tool_call_id: str | None,
            offset: int,
            limit: int,
        ) -> ToolMessage | str:
            """Convert a backend read result into the tool's return value.

            Returns:
                - an ``"Error: ..."`` string when the backend reports an error or
                  returns no file data;
                - a ``ToolMessage`` with a base64 content block for non-text files;
                - the ``check_empty_content`` message for empty text content;
                - otherwise the text with line numbers (starting at ``offset + 1``,
                  since ``offset`` is 0-indexed), truncated by ``_truncate``.
                A legacy plain-``str`` result is truncated but not line-numbered.
            """
            if isinstance(read_result, str):
                warnings.warn(
                    "Returning a plain `str` from `backend.read()` is deprecated. ",
                    DeprecationWarning,
                    stacklevel=2,
                )
                return _truncate(read_result, validated_path, limit)

            if read_result.error:
                return f"Error: {read_result.error}"

            if read_result.file_data is None:
                return f"Error: no data returned for '{validated_path}'"

            file_type = _get_file_type(validated_path)
            content = read_result.file_data["content"]

            if file_type != "text":
                # The dummy stem "file" lets mimetypes guess from the suffix alone.
                mime_type = (
                    mimetypes.guess_type("file" + Path(validated_path).suffix)[0]
                    or "application/octet-stream"
                )
                return ToolMessage(
                    content_blocks=cast(
                        "list[ContentBlock]",
                        [{"type": file_type, "base64": content, "mime_type": mime_type}],
                    ),
                    name="read_file",
                    tool_call_id=tool_call_id,
                    additional_kwargs={
                        "read_file_path": validated_path,
                        "read_file_media_type": mime_type,
                    },
                )

            empty_msg = check_empty_content(content)
            if empty_msg:
                return empty_msg

            content = format_content_with_line_numbers(content, start_line=offset + 1)
            return _truncate(content, validated_path, limit)

        def sync_read_file(
            file_path: Annotated[str, "Absolute path to the file to read. Must be absolute, not relative."],
            runtime: ToolRuntime[None, FilesystemState],
            offset: Annotated[int, "Line number to start reading from (0-indexed). Use for pagination of large files."] = DEFAULT_READ_OFFSET,
            limit: Annotated[int, "Maximum number of lines to read. Use for pagination of large files."] = DEFAULT_READ_LIMIT,
        ) -> ToolMessage | str:
            """Synchronous wrapper for read_file tool with SKILL.md special handling."""
            resolved_backend = self._get_backend(runtime)
            try:
                validated_path = validate_path(file_path)
            except ValueError as e:
                return f"Error: {e}"

            limit = _effective_limit(validated_path, limit)
            read_result = resolved_backend.read(validated_path, offset=offset, limit=limit)
            return _handle_read_result(read_result, validated_path, runtime.tool_call_id, offset, limit)

        async def async_read_file(
            file_path: Annotated[str, "Absolute path to the file to read. Must be absolute, not relative."],
            runtime: ToolRuntime[None, FilesystemState],
            offset: Annotated[int, "Line number to start reading from (0-indexed). Use for pagination of large files."] = DEFAULT_READ_OFFSET,
            limit: Annotated[int, "Maximum number of lines to read. Use for pagination of large files."] = DEFAULT_READ_LIMIT,
        ) -> ToolMessage | str:
            """Asynchronous wrapper for read_file tool with SKILL.md special handling."""
            resolved_backend = self._get_backend(runtime)
            try:
                validated_path = validate_path(file_path)
            except ValueError as e:
                return f"Error: {e}"

            limit = _effective_limit(validated_path, limit)
            read_result = await resolved_backend.aread(validated_path, offset=offset, limit=limit)
            return _handle_read_result(read_result, validated_path, runtime.tool_call_id, offset, limit)

        return StructuredTool.from_function(
            name="read_file",
            description=tool_description,
            func=sync_read_file,
            coroutine=async_read_file,
            infer_schema=False,
            args_schema=ReadFileSchema,
        )

    def _get_read_file_description(self) -> str:
        """Return the default ``read_file`` description plus a note on SKILL.md handling.

        Used by ``_create_read_file_tool`` when no custom ``read_file``
        description is configured.
        """
        return (
            f"{READ_FILE_TOOL_DESCRIPTION}\n\n"
            "Files named SKILL.md are read with a line limit of at least "
            f"{SKILL_MD_READ_LIMIT} lines, so a smaller `limit` argument is ignored "
            "for them and the whole file is normally returned in one call."
        )
|