qwen_agent/utils/async_file_ops.py
#!/usr/bin/env python3
"""
Async file operation utilities - efficient asynchronous file reading and writing.
"""
import asyncio
import json
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

import aiofiles
import aiofiles.os

# Configure logger
logger = logging.getLogger('app')


class AsyncFileCache:
    """Asynchronous file cache manager."""

    def __init__(self, cache_size: int = 1000, ttl: int = 300):
        """
        Initialize the file cache.

        Args:
            cache_size: Maximum number of cached files.
            ttl: Cache time-to-live in seconds.
        """
        self.cache_size = cache_size
        self.ttl = ttl
        self._cache = {}  # {file_path: (content, timestamp)}
        self._lock = asyncio.Lock()
        self._executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="async_file_io")

    async def read_file(self, file_path: str, encoding: str = 'utf-8') -> str:
        """Asynchronously read file content, with caching."""
        abs_path = os.path.abspath(file_path)
        async with self._lock:
            # Serve from cache if the entry is still fresh
            if abs_path in self._cache:
                content, timestamp = self._cache[abs_path]
                if time.time() - timestamp < self.ttl:
                    return content
            # Read in the thread pool so the event loop is not blocked
            loop = asyncio.get_running_loop()
            try:
                # Check whether the file exists
                exists = await loop.run_in_executor(
                    self._executor, os.path.exists, abs_path
                )
                if not exists:
                    return ""
                # Read the file content
                content = await loop.run_in_executor(
                    self._executor, self._read_text_file, abs_path, encoding
                )
                # When the cache is full, evict the entry with the oldest
                # timestamp (least recently written; an approximation of LRU)
                if len(self._cache) >= self.cache_size:
                    oldest_key = min(self._cache.keys(),
                                     key=lambda k: self._cache[k][1])
                    del self._cache[oldest_key]
                self._cache[abs_path] = (content, time.time())
                return content
            except Exception as e:
                logger.error(f"Error reading file {abs_path}: {e}")
                return ""

    def _read_text_file(self, file_path: str, encoding: str) -> str:
        """Synchronously read a text file; runs inside the thread pool."""
        try:
            with open(file_path, 'r', encoding=encoding) as f:
                return f.read()
        except Exception:
            return ""

    async def read_json(self, file_path: str) -> Dict[str, Any]:
        """Asynchronously read a JSON file."""
        content = await self.read_file(file_path)
        if not content.strip():
            return {}
        try:
            return json.loads(content)
        except json.JSONDecodeError as e:
            logger.error(f"Error parsing JSON from {file_path}: {e}")
            return {}

    async def write_file(self, file_path: str, content: str, encoding: str = 'utf-8'):
        """Asynchronously write file content."""
        abs_path = os.path.abspath(file_path)
        # Make sure the target directory exists
        dir_path = os.path.dirname(abs_path)
        if dir_path:
            await aiofiles.os.makedirs(dir_path, exist_ok=True)
        # Write asynchronously via aiofiles (use abs_path so the write and
        # the cache entry refer to the same key)
        async with aiofiles.open(abs_path, 'w', encoding=encoding) as f:
            await f.write(content)
        # Refresh the cache entry
        async with self._lock:
            self._cache[abs_path] = (content, time.time())

    async def write_json(self, file_path: str, data: Dict[str, Any], indent: int = 2):
        """Asynchronously write a JSON file."""
        content = json.dumps(data, ensure_ascii=False, indent=indent)
        await self.write_file(file_path, content)

    async def exists(self, file_path: str) -> bool:
        """Asynchronously check whether a file exists."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self._executor, os.path.exists, file_path
        )

    async def getmtime(self, file_path: str) -> float:
        """Asynchronously get a file's modification time."""
        loop = asyncio.get_running_loop()
        try:
            return await loop.run_in_executor(
                self._executor, os.path.getmtime, file_path
            )
        except OSError:
            return 0.0

    def invalidate_cache(self, file_path: Optional[str] = None):
        """Invalidate the cache.

        Note: this schedules the invalidation with asyncio.create_task, so it
        must be called from within a running event loop.
        """
        if file_path:
            abs_path = os.path.abspath(file_path)
            asyncio.create_task(self._invalidate_single(abs_path))
        else:
            asyncio.create_task(self._clear_all_cache())

    async def _invalidate_single(self, file_path: str):
        """Invalidate the cache entry for a single file."""
        async with self._lock:
            self._cache.pop(file_path, None)

    async def _clear_all_cache(self):
        """Clear the entire cache."""
        async with self._lock:
            self._cache.clear()


# Global file cache instance
_global_file_cache: Optional[AsyncFileCache] = None
_cache_lock = threading.Lock()


def get_global_file_cache() -> AsyncFileCache:
    """Get the global file cache instance (created lazily, thread-safe)."""
    global _global_file_cache
    if _global_file_cache is None:
        with _cache_lock:
            if _global_file_cache is None:
                _global_file_cache = AsyncFileCache()
    return _global_file_cache


def init_global_file_cache(cache_size: int = 1000, ttl: int = 300) -> AsyncFileCache:
    """Initialize (or reinitialize) the global file cache."""
    global _global_file_cache
    with _cache_lock:
        _global_file_cache = AsyncFileCache(cache_size, ttl)
    return _global_file_cache


async def async_read_file(file_path: str, encoding: str = 'utf-8') -> str:
    """Convenience function: asynchronously read a file."""
    cache = get_global_file_cache()
    return await cache.read_file(file_path, encoding)


async def async_read_json(file_path: str) -> Dict[str, Any]:
    """Convenience function: asynchronously read a JSON file."""
    cache = get_global_file_cache()
    return await cache.read_json(file_path)


async def async_write_file(file_path: str, content: str, encoding: str = 'utf-8'):
    """Convenience function: asynchronously write a file."""
    cache = get_global_file_cache()
    await cache.write_file(file_path, content, encoding)


async def async_write_json(file_path: str, data: Dict[str, Any], indent: int = 2):
    """Convenience function: asynchronously write a JSON file."""
    cache = get_global_file_cache()
    await cache.write_json(file_path, data, indent)


async def async_file_exists(file_path: str) -> bool:
    """Convenience function: asynchronously check whether a file exists."""
    cache = get_global_file_cache()
    return await cache.exists(file_path)


async def async_get_file_mtime(file_path: str) -> float:
    """Convenience function: asynchronously get a file's modification time."""
    cache = get_global_file_cache()
    return await cache.getmtime(file_path)
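

# A minimal sketch of the module-level convenience API, assuming the default
# global cache. The function name and the path 'demo_state.json' are
# hypothetical placeholders.
async def _example_convenience_api():
    await async_write_json('demo_state.json', {'step': 1})
    if await async_file_exists('demo_state.json'):
        state = await async_read_json('demo_state.json')
        mtime = await async_get_file_mtime('demo_state.json')
        print(state, mtime)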


class ParallelFileReader:
    """Parallel file reader."""

    def __init__(self, max_workers: int = 8):
        """
        Initialize the parallel reader.

        Args:
            max_workers: Maximum number of worker threads.
        """
        self.max_workers = max_workers
        self._executor = ThreadPoolExecutor(max_workers=max_workers,
                                            thread_name_prefix="parallel_file_reader")

    async def read_multiple_files(self, file_paths: List[str],
                                  encoding: str = 'utf-8') -> Dict[str, str]:
        """Read multiple files in parallel."""
        loop = asyncio.get_running_loop()
        # Schedule one thread-pool task per file
        tasks = []
        for file_path in file_paths:
            task = loop.run_in_executor(
                self._executor, self._read_text_file_sync, file_path, encoding
            )
            tasks.append((file_path, task))
        # Wait for all tasks to complete
        results = {}
        for file_path, task in tasks:
            try:
                content = await task
                results[file_path] = content
            except Exception as e:
                logger.error(f"Error reading {file_path}: {e}")
                results[file_path] = ""
        return results

    def _read_text_file_sync(self, file_path: str, encoding: str) -> str:
        """Synchronously read a text file; runs inside the thread pool."""
        try:
            if not os.path.exists(file_path):
                return ""
            with open(file_path, 'r', encoding=encoding) as f:
                return f.read()
        except Exception:
            return ""

    async def read_multiple_json(self, file_paths: List[str]) -> Dict[str, Dict[str, Any]]:
        """Read multiple JSON files in parallel."""
        contents = await self.read_multiple_files(file_paths)
        results = {}
        for file_path, content in contents.items():
            if content.strip():
                try:
                    results[file_path] = json.loads(content)
                except json.JSONDecodeError as e:
                    logger.error(f"Error parsing JSON from {file_path}: {e}")
                    results[file_path] = {}
            else:
                results[file_path] = {}
        return results


# Global parallel reader instance
_global_parallel_reader: Optional[ParallelFileReader] = None
_reader_lock = threading.Lock()


def get_global_parallel_reader() -> ParallelFileReader:
    """Get the global parallel reader instance (created lazily, thread-safe)."""
    global _global_parallel_reader
    if _global_parallel_reader is None:
        with _reader_lock:
            if _global_parallel_reader is None:
                _global_parallel_reader = ParallelFileReader()
    return _global_parallel_reader
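

# A minimal sketch of batch reading via the global parallel reader. The
# function name and the listed paths are hypothetical placeholders; missing
# or unparseable files come back as empty dicts per read_multiple_json above.
async def _example_parallel_read():
    reader = get_global_parallel_reader()
    paths = ['a.json', 'b.json', 'c.json']
    docs = await reader.read_multiple_json(paths)  # {path: parsed dict or {}}
    for path, doc in docs.items():
        print(path, len(doc))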