# qwen_agent/agent/sharded_agent_manager.py

# Copyright 2023
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""分片助手管理器 - 减少锁竞争的高并发agent缓存系统"""
import hashlib
import time
import json
import asyncio
from typing import Dict, List, Optional
import threading
import logging
logger = logging.getLogger('app')
from agent.deep_assistant import init_agent
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
from utils.agent_config import AgentConfig
class ShardedAgentManager:
    """Sharded agent cache manager for high-concurrency access.

    The agent cache is partitioned into ``shard_count`` independent shards,
    each guarded by its own :class:`asyncio.Lock`, so concurrent requests for
    different cache keys rarely contend on the same lock.  Each shard applies
    LRU eviction so the total number of cached agents stays near
    ``max_cached_agents``.
    """

    def __init__(self, max_cached_agents: int = 20, shard_count: int = 16):
        """Create the shards and the global statistics counters.

        Args:
            max_cached_agents: Soft global cap on cached agent instances;
                each shard keeps at most
                ``max(1, max_cached_agents // shard_count)`` entries.
            shard_count: Number of independent cache shards.
        """
        self.max_cached_agents = max_cached_agents
        self.shard_count = shard_count
        # Each shard bundles its bookkeeping maps with a dedicated lock.
        self.shards: List[Dict] = []
        for _ in range(shard_count):
            shard = {
                'agents': {},          # {cache_key: assistant_instance}
                'unique_ids': {},      # {cache_key: bot_id}
                'access_times': {},    # {cache_key: last-access ts} for LRU
                'creation_times': {},  # {cache_key: creation ts}
                'lock': asyncio.Lock(),  # guards this shard's maps
                'creation_locks': {},  # per-key locks to serialize creation
            }
            self.shards.append(shard)
        # Stats counters use a threading lock so they are safe to touch
        # from synchronous code paths as well.
        self._stats_lock = threading.RLock()
        self._global_stats = {
            'total_requests': 0,
            'cache_hits': 0,
            'cache_misses': 0,
            'agent_creations': 0,
        }

    def _get_shard_index(self, cache_key: str) -> int:
        """Map ``cache_key`` onto a shard index with a stable (MD5) hash."""
        hash_value = int(hashlib.md5(cache_key.encode('utf-8')).hexdigest(), 16)
        return hash_value % self.shard_count

    def _get_cache_key(self, bot_id: str, model_name: Optional[str] = None,
                       api_key: Optional[str] = None,
                       model_server: Optional[str] = None,
                       generate_cfg: Optional[Dict] = None,
                       system_prompt: Optional[str] = None,
                       mcp_settings: Optional[List[Dict]] = None,
                       enable_thinking: bool = False) -> str:
        """Build a deterministic 16-hex-char cache key.

        Every parameter that influences agent construction is folded into the
        hash, so two configs collide only if they would build the same agent.
        """
        cache_data = {
            'bot_id': bot_id,
            'model_name': model_name or '',
            'api_key': api_key or '',
            'model_server': model_server or '',
            'generate_cfg': json.dumps(generate_cfg or {}, sort_keys=True),
            'system_prompt': system_prompt or '',
            'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True),
            'enable_thinking': enable_thinking,
        }
        cache_str = json.dumps(cache_data, sort_keys=True)
        return hashlib.md5(cache_str.encode('utf-8')).hexdigest()[:16]

    def _update_access_time(self, shard: dict, cache_key: str):
        """Refresh the LRU timestamp for ``cache_key`` in ``shard``."""
        shard['access_times'][cache_key] = time.time()

    def _cleanup_old_agents(self, shard: dict):
        """Evict least-recently-used agents once the shard is over capacity.

        NOTE(review): caller is expected to hold the shard lock.
        """
        shard_max_capacity = max(1, self.max_cached_agents // self.shard_count)
        if len(shard['agents']) <= shard_max_capacity:
            return
        # Oldest access first; drop everything beyond the newest slots.
        sorted_keys = sorted(shard['access_times'],
                             key=lambda k: shard['access_times'][k])
        keys_to_remove = sorted_keys[:-shard_max_capacity]
        removed_count = 0
        for cache_key in keys_to_remove:
            try:
                del shard['agents'][cache_key]
                del shard['unique_ids'][cache_key]
                del shard['access_times'][cache_key]
                del shard['creation_times'][cache_key]
                shard['creation_locks'].pop(cache_key, None)
                removed_count += 1
                logger.info(f"分片清理过期的助手实例缓存: {cache_key}")
            except KeyError:
                # Maps can be partially populated; skip inconsistent entries.
                continue
        if removed_count > 0:
            logger.info(f"分片已清理 {removed_count} 个过期的助手实例缓存")

    async def get_or_create_agent(self, config: "AgentConfig"):
        """Return the cached agent for ``config``, creating it on a miss.

        Args:
            config: AgentConfig carrying all initialization parameters.  Its
                ``system_prompt`` and ``mcp_settings`` fields are replaced in
                place by the fully resolved values loaded here.

        Returns:
            The cached or newly created assistant instance.
        """
        with self._stats_lock:
            self._global_stats['total_requests'] += 1
        # Resolve prompt / MCP configuration asynchronously (loaders cache).
        final_system_prompt = await load_system_prompt_async(
            config.project_dir, config.language, config.system_prompt,
            config.robot_type, config.bot_id, config.user_identifier
        )
        final_mcp_settings = await load_mcp_settings_async(
            config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
        )
        config.system_prompt = final_system_prompt
        config.mcp_settings = final_mcp_settings
        cache_key = self._get_cache_key(config.bot_id, config.model_name,
                                        config.api_key, config.model_server,
                                        config.generate_cfg, final_system_prompt,
                                        final_mcp_settings, config.enable_thinking)
        shard_index = self._get_shard_index(cache_key)
        shard = self.shards[shard_index]
        # Fast path: lookup under the shard lock only.
        async with shard['lock']:
            if cache_key in shard['agents']:
                self._update_access_time(shard, cache_key)
                agent = shard['agents'][cache_key]
                with self._stats_lock:
                    self._global_stats['cache_hits'] += 1
                logger.info(f"分片复用现有的助手实例缓存: {cache_key} (bot_id: {config.bot_id}, shard: {shard_index})")
                return agent
            with self._stats_lock:
                self._global_stats['cache_misses'] += 1
            # Finer-grained per-key lock so concurrent misses for the same
            # key create the agent only once.
            creation_lock = shard['creation_locks'].setdefault(cache_key, asyncio.Lock())
        # Create the agent outside the shard lock to keep hold times short.
        async with creation_lock:
            # Double-check: another waiter may have created it meanwhile.
            async with shard['lock']:
                if cache_key in shard['agents']:
                    self._update_access_time(shard, cache_key)
                    agent = shard['agents'][cache_key]
                    with self._stats_lock:
                        self._global_stats['cache_hits'] += 1
                    return agent
            # Make room before inserting the new instance.
            async with shard['lock']:
                self._cleanup_old_agents(shard)
            logger.info(f"分片创建新的助手实例缓存: {cache_key}, bot_id: {config.bot_id}, shard: {shard_index}")
            current_time = time.time()
            try:
                agent = await init_agent(config)
            except Exception:
                # Fix: drop the stale per-key lock on failure so later
                # callers can retry instead of leaking lock entries.
                async with shard['lock']:
                    shard['creation_locks'].pop(cache_key, None)
                raise
            # Publish the new instance.
            async with shard['lock']:
                shard['agents'][cache_key] = agent
                shard['unique_ids'][cache_key] = config.bot_id
                shard['access_times'][cache_key] = current_time
                shard['creation_times'][cache_key] = current_time
                # Creation is done; the per-key lock is no longer needed.
                shard['creation_locks'].pop(cache_key, None)
        with self._stats_lock:
            self._global_stats['agent_creations'] += 1
        logger.info(f"分片助手实例缓存创建完成: {cache_key}, shard: {shard_index}")
        return agent

    def get_cache_stats(self) -> Dict:
        """Snapshot per-agent cache info plus the global hit/miss counters."""
        current_time = time.time()
        total_agents = 0
        agents_info = []
        for i, shard in enumerate(self.shards):
            for cache_key in shard['agents']:
                total_agents += 1
                agents_info.append({
                    "cache_key": cache_key,
                    "unique_id": shard['unique_ids'].get(cache_key, "unknown"),
                    "shard": i,
                    "created_at": shard['creation_times'].get(cache_key, 0),
                    "last_accessed": shard['access_times'].get(cache_key, 0),
                    "age_seconds": int(current_time - shard['creation_times'].get(cache_key, current_time)),
                    "idle_seconds": int(current_time - shard['access_times'].get(cache_key, current_time)),
                })
        stats = {
            "total_cached_agents": total_agents,
            "max_cached_agents": self.max_cached_agents,
            "shard_count": self.shard_count,
            "agents": agents_info,
        }
        # Merge in the global counters under the stats lock.
        with self._stats_lock:
            stats.update({
                "total_requests": self._global_stats['total_requests'],
                "cache_hits": self._global_stats['cache_hits'],
                "cache_misses": self._global_stats['cache_misses'],
                # Percentage; max(1, ...) guards a fresh manager against /0.
                "cache_hit_rate": (
                    self._global_stats['cache_hits'] / max(1, self._global_stats['total_requests']) * 100
                ),
                "agent_creations": self._global_stats['agent_creations'],
            })
        return stats

    def clear_cache(self) -> int:
        """Drop every cached agent and reset the counters.

        Returns:
            Number of agent instances that were removed.
        """
        cache_count = 0
        for shard in self.shards:
            cache_count += len(shard['agents'])
            shard['agents'].clear()
            shard['unique_ids'].clear()
            shard['access_times'].clear()
            shard['creation_times'].clear()
            shard['creation_locks'].clear()
        with self._stats_lock:
            self._global_stats = {
                'total_requests': 0,
                'cache_hits': 0,
                'cache_misses': 0,
                'agent_creations': 0,
            }
        logger.info(f"分片管理器已清空所有助手实例缓存,共清理 {cache_count} 个实例")
        return cache_count

    def remove_cache_by_unique_id(self, unique_id: str) -> int:
        """Remove every cached agent whose stored id equals ``unique_id``.

        Returns:
            Number of agent instances removed across all shards.
        """
        removed_count = 0
        for i, shard in enumerate(self.shards):
            # Collect matches first to avoid mutating while iterating.
            keys_to_remove = [cache_key
                              for cache_key, stored_unique_id in shard['unique_ids'].items()
                              if stored_unique_id == unique_id]
            for cache_key in keys_to_remove:
                try:
                    del shard['agents'][cache_key]
                    del shard['unique_ids'][cache_key]
                    del shard['access_times'][cache_key]
                    del shard['creation_times'][cache_key]
                    shard['creation_locks'].pop(cache_key, None)
                    removed_count += 1
                    logger.info(f"分片 {i} 已移除助手实例缓存: {cache_key} (unique_id: {unique_id})")
                except KeyError:
                    continue
        if removed_count > 0:
            logger.info(f"分片管理器已移除 unique_id={unique_id} 的 {removed_count} 个助手实例缓存")
        else:
            logger.warning(f"分片管理器未找到 unique_id={unique_id} 的缓存实例")
        return removed_count
# Global sharded agent manager singleton instance (lazily created).
_global_sharded_agent_manager: Optional[ShardedAgentManager] = None
def get_global_sharded_agent_manager() -> ShardedAgentManager:
    """Lazily initialize and return the process-wide manager singleton."""
    global _global_sharded_agent_manager
    if _global_sharded_agent_manager is not None:
        return _global_sharded_agent_manager
    _global_sharded_agent_manager = ShardedAgentManager()
    return _global_sharded_agent_manager
def init_global_sharded_agent_manager(max_cached_agents: int = 20, shard_count: int = 16):
    """Replace the process-wide manager with a freshly configured instance."""
    global _global_sharded_agent_manager
    manager = ShardedAgentManager(max_cached_agents, shard_count)
    _global_sharded_agent_manager = manager
    return manager