qwen_agent/agent/agent_memory_cache.py

379 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
基于内存的 Agent 缓存管理模块
使用 cachetools 库实现 TTLCache 和 LRUCache
"""
import hashlib
import json
import logging
import time
import threading
from typing import Any, Optional, Dict, Tuple, List
from collections import OrderedDict
from datetime import datetime, timedelta
import cachetools
from utils.settings import AGENT_CACHE_MAX_SIZE, AGENT_CACHE_TTL, AGENT_CACHE_AUTO_RENEW
# Module-level logger; it is bound to the 'app' logger name, not to this module's __name__.
logger = logging.getLogger('app')
class AgentMemoryCacheManager:
    """Thread-safe in-memory cache for Agent objects and MCP tool lists.

    Values are stored in a ``cachetools.LRUCache`` (bounded size, least
    recently used eviction).  Expiry is tracked by this class itself in
    ``_expire_times`` rather than by a ``cachetools.TTLCache``: a TTLCache
    applies one fixed TTL counted from insertion, which would silently
    override both the per-entry ``ttl`` given to :meth:`set` and the renewal
    performed by :meth:`get`.

    Features:
        - per-entry TTL measured with ``time.monotonic``
        - size limit with LRU eviction (expired entries are dropped first)
        - optional sliding expiration on access (``auto_renew``)
        - every public operation runs under one re-entrant lock
        - hit / miss / set / eviction counters

    Expired entries are removed lazily: when they are looked up, when room is
    needed, and whenever sizes, keys or statistics are requested.
    """

    def __init__(
        self,
        max_size: int = 1000,
        default_ttl: int = 180,
        auto_renew: bool = True,
    ):
        """Create an empty cache manager.

        Args:
            max_size: Maximum number of cached entries (default 1000).
            default_ttl: Lifetime in seconds used when ``set`` receives no
                ``ttl`` and when an entry is renewed on access (default 180,
                i.e. 3 minutes).
            auto_renew: If true, a successful ``get`` extends the entry's
                expiry time.
        """
        # LRU-bounded value store; expiry is enforced by this class, not by
        # the cachetools container.
        self.cache = cachetools.LRUCache(maxsize=max_size)
        # Absolute expiry time (time.monotonic() based) for each key.
        self._expire_times: Dict[str, float] = {}
        # Insertion time for each key, used by cleanup_old_entries().
        self._create_times: Dict[str, float] = {}
        # Re-entrant so public methods may call each other while holding it.
        self._lock = threading.RLock()
        self.default_ttl = default_ttl
        self.auto_renew = auto_renew
        self.max_size = max_size
        # Usage counters reported by get_stats().
        self._hits = 0
        self._misses = 0
        self._sets = 0
        self._evictions = 0
        logger.info(f"AgentMemoryCacheManager initialized with max_size: {max_size}, "
                    f"default_ttl: {default_ttl}s, auto_renew: {auto_renew}")

    def get(self, cache_key: str) -> Optional[Any]:
        """Return the cached value for ``cache_key``.

        An expired entry is removed and reported as a miss.  On a hit with
        ``auto_renew`` enabled the expiry is pushed to ``now + default_ttl``,
        but never pulled earlier than the entry's current expiry, so a longer
        per-entry TTL is not shortened by being read.

        Args:
            cache_key: Key to look up.

        Returns:
            The cached object, or ``None`` if the key is absent or expired.
        """
        with self._lock:
            now = time.monotonic()
            expires_at = self._expire_times.get(cache_key)
            if expires_at is not None and now > expires_at:
                self._remove_expired(cache_key)
                self._misses += 1
                logger.debug(f"Cache miss (expired) for key: {cache_key}")
                return None

            try:
                # The LRUCache lookup also marks the key as most recently used.
                value = self.cache[cache_key]
            except KeyError:
                self._misses += 1
                logger.debug(f"Cache miss for key: {cache_key}")
                return None

            if self.auto_renew:
                renewed = now + self.default_ttl
                if expires_at is None or renewed > expires_at:
                    self._expire_times[cache_key] = renewed
                logger.debug(f"Cache hit and renewed for key: {cache_key}")
            else:
                logger.debug(f"Cache hit for key: {cache_key}")
            self._hits += 1
            return value

    def set(self, cache_key: str, agent: Any, ttl: Optional[int] = None) -> bool:
        """Store ``agent`` under ``cache_key``.

        When a new key arrives and the cache is full, expired entries are
        dropped first; only if that frees no room is the least recently used
        entry evicted.  The eviction is performed explicitly so the evicted
        key is known and its metadata can be removed with it.

        Args:
            cache_key: Key to store the value under.
            agent: Object to cache.
            ttl: Lifetime in seconds; ``None`` means ``default_ttl``.

        Returns:
            ``True`` if the value was stored, ``False`` if an error occurred
            (the error is logged, not raised).
        """
        with self._lock:
            try:
                if ttl is None:
                    ttl = self.default_ttl
                now = time.monotonic()

                if cache_key not in self.cache and len(self.cache) >= self.max_size:
                    self._purge_expired(now)
                    if self.cache and len(self.cache) >= self.max_size:
                        evicted_key, _ = self.cache.popitem()
                        self._cleanup_metadata(evicted_key)
                        self._evictions += 1
                        logger.debug(f"Evicted cache key: {evicted_key}")

                self.cache[cache_key] = agent
                self._expire_times[cache_key] = now + ttl
                self._create_times[cache_key] = now
                self._sets += 1
                logger.info(f"Cached agent for key: {cache_key}, ttl: {ttl}s")
                return True
            except Exception as e:
                logger.error(f"Error setting cache for key {cache_key}: {e}")
                return False

    def delete(self, cache_key: str) -> bool:
        """Remove one entry and its metadata.

        An entry that has already expired is dropped and treated as absent,
        matching what ``get`` would report for it.

        Args:
            cache_key: Key to remove.

        Returns:
            ``True`` if a live entry was removed, ``False`` if the key was not
            present (or an error occurred; errors are logged, not raised).
        """
        with self._lock:
            try:
                if self._is_expired(cache_key, time.monotonic()):
                    self._remove_expired(cache_key)
                deleted = cache_key in self.cache
                if deleted:
                    del self.cache[cache_key]
                # Always drop metadata so no orphaned bookkeeping remains.
                self._cleanup_metadata(cache_key)
                if deleted:
                    logger.info(f"Deleted cache for key: {cache_key}")
                else:
                    logger.warning(f"Cache key not found for deletion: {cache_key}")
                return deleted
            except Exception as e:
                logger.error(f"Error deleting cache for key {cache_key}: {e}")
                return False

    def _is_expired(self, cache_key: str, now: float) -> bool:
        """Return whether ``cache_key`` has a recorded expiry earlier than ``now``."""
        expires_at = self._expire_times.get(cache_key)
        return expires_at is not None and now > expires_at

    def _purge_expired(self, now: Optional[float] = None) -> int:
        """Remove every expired entry; the caller must hold ``self._lock``.

        Args:
            now: Reference ``time.monotonic()`` value; read fresh if ``None``.

        Returns:
            Number of entries removed.
        """
        if now is None:
            now = time.monotonic()
        # Collect first: the dict must not change size while being iterated.
        expired_keys = [
            key for key, expires_at in self._expire_times.items() if now > expires_at
        ]
        for key in expired_keys:
            self._remove_expired(key)
        return len(expired_keys)

    def _remove_expired(self, cache_key: str):
        """Drop an expired entry's value and metadata."""
        if cache_key in self.cache:
            del self.cache[cache_key]
        self._cleanup_metadata(cache_key)

    def _cleanup_metadata(self, cache_key: str):
        """Forget the expiry and creation timestamps recorded for a key."""
        self._expire_times.pop(cache_key, None)
        self._create_times.pop(cache_key, None)

    def clear_all(self) -> bool:
        """Remove every entry and reset all counters to zero.

        Returns:
            ``True`` on success, ``False`` if an error occurred (logged).
        """
        with self._lock:
            try:
                count = len(self.cache)
                self.cache.clear()
                self._expire_times.clear()
                self._create_times.clear()
                # Statistics restart together with the cache contents.
                self._hits = 0
                self._misses = 0
                self._sets = 0
                self._evictions = 0
                logger.info(f"Cleared all cache entries, total: {count}")
                return True
            except Exception as e:
                logger.error(f"Error clearing all cache: {e}")
                return False

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of configuration and usage counters.

        Expired entries are purged first so ``total_items`` counts only live
        entries.

        Returns:
            Dict with keys ``type``, ``total_items``, ``max_size``,
            ``default_ttl``, ``auto_renew``, ``hits``, ``misses``,
            ``hit_rate_percent``, ``sets``, ``evictions`` and
            ``memory_usage_mb``.
        """
        with self._lock:
            self._purge_expired()
            total_requests = self._hits + self._misses
            hit_rate = (self._hits / total_requests * 100) if total_requests > 0 else 0
            return {
                "type": "memory",
                "total_items": len(self.cache),
                "max_size": self.max_size,
                "default_ttl": self.default_ttl,
                "auto_renew": self.auto_renew,
                "hits": self._hits,
                "misses": self._misses,
                "hit_rate_percent": round(hit_rate, 2),
                "sets": self._sets,
                "evictions": self._evictions,
                "memory_usage_mb": round(self._estimate_memory_usage() / 1024 / 1024, 2)
            }

    def _estimate_memory_usage(self) -> int:
        """Return a rough size estimate in bytes.

        ``sys.getsizeof`` is shallow: it measures each key, each value object
        and the two metadata dicts themselves, but not anything those objects
        reference, so the real footprint of cached agents is larger.
        """
        import sys

        total_size = 0
        for key, value in self.cache.items():
            total_size += sys.getsizeof(key)
            total_size += sys.getsizeof(value)
        total_size += sys.getsizeof(self._expire_times)
        total_size += sys.getsizeof(self._create_times)
        return total_size

    def cleanup_old_entries(self, max_age_seconds: int = 3600) -> int:
        """Delete entries created more than ``max_age_seconds`` ago.

        Age is counted from the last ``set`` of the key; renewals on access do
        not reset it.  Entries that have already expired are purged first and
        are not included in the returned count.

        Args:
            max_age_seconds: Maximum allowed age in seconds.

        Returns:
            Number of live entries deleted for being too old.
        """
        with self._lock:
            now = time.monotonic()
            self._purge_expired(now)
            stale_keys = [
                key
                for key, created_at in self._create_times.items()
                if now - created_at > max_age_seconds
            ]
            # delete() re-acquires the re-entrant lock and logs each removal.
            deleted_count = sum(1 for key in stale_keys if self.delete(key))
            logger.info(f"Cleaned up {deleted_count} old cache entries older than {max_age_seconds}s")
            return deleted_count

    def get_keys(self) -> list:
        """Return the keys of all live (non-expired) entries."""
        with self._lock:
            self._purge_expired()
            return list(self.cache.keys())

    def __len__(self) -> int:
        """Return the number of live (non-expired) entries."""
        with self._lock:
            self._purge_expired()
            return len(self.cache)

    def get_mcp_tools(self, mcp_settings: dict) -> Optional[List]:
        """Return the tools cached for ``mcp_settings``.

        Args:
            mcp_settings: MCP configuration dict; it is hashed into the key.

        Returns:
            The cached tools list, or ``None`` if absent or expired.
        """
        cache_key = self._get_mcp_cache_key(mcp_settings)
        return self.get(cache_key)

    def set_mcp_tools(self, mcp_settings: dict, tools: List, ttl: Optional[int] = None) -> bool:
        """Cache ``tools`` under a key derived from ``mcp_settings``.

        Args:
            mcp_settings: MCP configuration dict; it is hashed into the key.
            tools: Tools list to cache.
            ttl: Lifetime in seconds; ``None`` means ``default_ttl``.

        Returns:
            ``True`` if the value was stored, ``False`` otherwise.
        """
        cache_key = self._get_mcp_cache_key(mcp_settings)
        return self.set(cache_key, tools, ttl=ttl)

    def _get_mcp_cache_key(self, mcp_settings: dict) -> str:
        """Derive a cache key from an MCP configuration dict.

        The dict is serialised to JSON with sorted keys, so two dicts with the
        same content give the same key regardless of insertion order.  MD5 is
        used only as a compact fingerprint here, not for security.

        Args:
            mcp_settings: MCP configuration dict (must be JSON-serialisable).

        Returns:
            A string of the form ``"mcp_tools:<md5 hex digest>"``.
        """
        settings_str = json.dumps(mcp_settings, sort_keys=True)
        return f"mcp_tools:{hashlib.md5(settings_str.encode()).hexdigest()}"
# Process-wide cache manager instance, created lazily on first use.
_global_cache_manager: Optional[AgentMemoryCacheManager] = None
# Serialises the lazy creation so concurrent first callers build only one manager.
_global_cache_manager_lock = threading.Lock()


def get_memory_cache_manager() -> AgentMemoryCacheManager:
    """Return the shared ``AgentMemoryCacheManager``, creating it on first call.

    The instance is configured from ``AGENT_CACHE_MAX_SIZE``,
    ``AGENT_CACHE_TTL`` and ``AGENT_CACHE_AUTO_RENEW`` imported from
    ``utils.settings``.  Creation happens under a lock and the global is
    checked again inside it, so two threads racing on the first call cannot
    end up with two different managers.

    Returns:
        The single module-level ``AgentMemoryCacheManager`` instance.
    """
    global _global_cache_manager
    if _global_cache_manager is None:
        with _global_cache_manager_lock:
            # Another thread may have created it while this one waited.
            if _global_cache_manager is None:
                _global_cache_manager = AgentMemoryCacheManager(
                    max_size=AGENT_CACHE_MAX_SIZE,
                    default_ttl=AGENT_CACHE_TTL,
                    auto_renew=AGENT_CACHE_AUTO_RENEW,
                )
    return _global_cache_manager