337 lines
14 KiB
Python
337 lines
14 KiB
Python
# Copyright 2023
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
"""分片助手管理器 - 减少锁竞争的高并发agent缓存系统"""
|
||
|
||
import hashlib
|
||
import time
|
||
import json
|
||
import asyncio
|
||
from typing import Dict, List, Optional
|
||
import threading
|
||
import logging
|
||
|
||
logger = logging.getLogger('app')
|
||
|
||
from agent.deep_assistant import init_agent
|
||
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
|
||
|
||
|
||
class ShardedAgentManager:
    """Sharded cache of agent instances.

    Cached agents are spread over ``shard_count`` independent shards, each
    guarded by its own ``asyncio.Lock``, so requests for different cache keys
    rarely wait on the same lock. Every shard applies LRU eviction with a
    capacity of ``max(1, max_cached_agents // shard_count)``.
    """

    def __init__(self, max_cached_agents: int = 20, shard_count: int = 16):
        """Create the shards and the request counters.

        Args:
            max_cached_agents: Target upper bound for cached agents across
                all shards (divided evenly between the shards).
            shard_count: Number of independent shards.

        Raises:
            ValueError: If ``shard_count`` is smaller than 1; without a shard
                the index computation would divide by zero on first use.
        """
        if shard_count < 1:
            raise ValueError(f"shard_count must be >= 1, got {shard_count}")

        self.max_cached_agents = max_cached_agents
        self.shard_count = shard_count
        self.shards = [self._new_shard() for _ in range(shard_count)]

        # Counters are guarded by a thread lock so that they can also be
        # read or reset safely from outside the event loop thread.
        self._stats_lock = threading.RLock()
        self._global_stats = self._new_stats()

    @staticmethod
    def _new_shard() -> dict:
        """Return an empty shard record."""
        return {
            'agents': {},            # {cache_key: agent instance}
            'unique_ids': {},        # {cache_key: bot_id used at creation}
            'access_times': {},      # {cache_key: last access timestamp} for LRU
            'creation_times': {},    # {cache_key: creation timestamp}
            'lock': asyncio.Lock(),  # protects the dictionaries of this shard
            'creation_locks': {},    # {cache_key: lock serializing creation of that key}
        }

    @staticmethod
    def _new_stats() -> Dict[str, int]:
        """Return a zeroed counter dictionary."""
        return {
            'total_requests': 0,
            'cache_hits': 0,
            'cache_misses': 0,
            'agent_creations': 0,
        }

    def _bump_stat(self, name: str) -> None:
        """Increment the counter ``name`` under the statistics lock."""
        with self._stats_lock:
            self._global_stats[name] += 1

    def _get_shard_index(self, cache_key: str) -> int:
        """Map ``cache_key`` to a shard index in ``[0, shard_count)``."""
        digest = hashlib.md5(cache_key.encode('utf-8')).hexdigest()
        return int(digest, 16) % self.shard_count

    def _get_cache_key(self, bot_id: str, model_name: Optional[str] = None,
                       api_key: Optional[str] = None,
                       model_server: Optional[str] = None,
                       generate_cfg: Optional[Dict] = None,
                       system_prompt: Optional[str] = None,
                       mcp_settings: Optional[List[Dict]] = None) -> str:
        """Build the cache key from every parameter that identifies an agent.

        ``None`` values are normalised to empty values and dictionaries are
        serialised with sorted keys, so logically equal inputs always yield
        the same key.

        Returns:
            The first 16 hex characters of the MD5 digest of the serialised
            parameters.
        """
        cache_data = {
            'bot_id': bot_id,
            'model_name': model_name or '',
            'api_key': api_key or '',
            'model_server': model_server or '',
            'generate_cfg': json.dumps(generate_cfg or {}, sort_keys=True),
            'system_prompt': system_prompt or '',
            'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True),
        }
        cache_str = json.dumps(cache_data, sort_keys=True)
        return hashlib.md5(cache_str.encode('utf-8')).hexdigest()[:16]

    def _update_access_time(self, shard: dict, cache_key: str):
        """Record "now" as the last access time of ``cache_key`` (LRU bookkeeping)."""
        shard['access_times'][cache_key] = time.time()

    @staticmethod
    def _remove_entry(shard: dict, cache_key: str) -> bool:
        """Drop every record of ``cache_key`` from ``shard``.

        All bookkeeping dictionaries are cleaned even when they are out of
        sync with each other, so a partially present entry never survives.

        Returns:
            True if an agent was cached under ``cache_key``.
        """
        existed = cache_key in shard['agents']
        for field in ('agents', 'unique_ids', 'access_times',
                      'creation_times', 'creation_locks'):
            shard[field].pop(cache_key, None)
        return existed

    def _cleanup_old_agents(self, shard: dict):
        """Evict least-recently-used agents until ``shard`` is within capacity.

        The per-shard capacity is ``max(1, max_cached_agents // shard_count)``.
        Callers in this class invoke it while holding the shard lock.
        """
        shard_max_capacity = max(1, self.max_cached_agents // self.shard_count)
        excess = len(shard['agents']) - shard_max_capacity
        if excess <= 0:
            return

        # Oldest access first; an agent without a recorded access time is
        # treated as the oldest so that it is evicted before live entries.
        access_times = shard['access_times']
        lru_order = sorted(shard['agents'],
                           key=lambda key: access_times.get(key, 0.0))

        removed_count = 0
        for cache_key in lru_order[:excess]:
            if self._remove_entry(shard, cache_key):
                removed_count += 1
                logger.info(f"分片清理过期的助手实例缓存: {cache_key}")

        if removed_count > 0:
            logger.info(f"分片已清理 {removed_count} 个过期的助手实例缓存")

    async def get_or_create_agent(self,
                                  bot_id: str,
                                  project_dir: Optional[str],
                                  model_name: str = "qwen3-next",
                                  api_key: Optional[str] = None,
                                  model_server: Optional[str] = None,
                                  generate_cfg: Optional[Dict] = None,
                                  language: Optional[str] = None,
                                  system_prompt: Optional[str] = None,
                                  mcp_settings: Optional[List[Dict]] = None,
                                  robot_type: Optional[str] = "general_agent",
                                  user_identifier: Optional[str] = None):
        """Return the cached agent for these parameters, creating it on a miss.

        The system prompt and MCP settings are resolved first because their
        final values are part of the cache key. Creation of a given key is
        serialised by a per-key lock, and ``init_agent`` is awaited without
        holding the shard lock so that other keys of the shard stay usable.

        Returns:
            The agent produced by ``init_agent`` (either cached or new).
        """
        self._bump_stat('total_requests')

        final_system_prompt = await load_system_prompt_async(
            project_dir, language, system_prompt, robot_type, bot_id, user_identifier
        )
        final_mcp_settings = await load_mcp_settings_async(
            project_dir, mcp_settings, bot_id, robot_type
        )

        # NOTE(review): language, robot_type and user_identifier are passed to
        # init_agent below but are not part of the cache key themselves; they
        # only influence it through the resolved prompt/settings. Confirm
        # that this is sufficient to keep agents from being shared wrongly.
        cache_key = self._get_cache_key(bot_id, model_name, api_key, model_server,
                                        generate_cfg, final_system_prompt,
                                        final_mcp_settings)
        shard_index = self._get_shard_index(cache_key)
        shard = self.shards[shard_index]

        # Fast path: the agent is already cached.
        async with shard['lock']:
            if cache_key in shard['agents']:
                self._update_access_time(shard, cache_key)
                self._bump_stat('cache_hits')
                logger.info(f"分片复用现有的助手实例缓存: {cache_key} (bot_id: {bot_id}, shard: {shard_index})")
                return shard['agents'][cache_key]

        # No await happens between the check above and this statement, so
        # concurrent requests for the same key end up sharing one lock.
        creation_lock = shard['creation_locks'].setdefault(cache_key, asyncio.Lock())

        async with creation_lock:
            # Another request may have finished creating the agent while this
            # one was waiting for the creation lock.
            async with shard['lock']:
                if cache_key in shard['agents']:
                    self._update_access_time(shard, cache_key)
                    self._bump_stat('cache_hits')
                    return shard['agents'][cache_key]

            # Counted only here so that a request served by the re-check
            # above is not recorded as both a miss and a hit.
            self._bump_stat('cache_misses')

            logger.info(f"分片创建新的助手实例缓存: {cache_key}, bot_id: {bot_id}, shard: {shard_index}")

            # If init_agent raises, the creation lock entry is intentionally
            # left registered: requests already waiting on it keep
            # serialising their own attempts instead of racing a new lock.
            agent = await init_agent(
                bot_id=bot_id,
                model_name=model_name,
                api_key=api_key,
                model_server=model_server,
                generate_cfg=generate_cfg,
                system_prompt=final_system_prompt,
                mcp=final_mcp_settings,
                robot_type=robot_type,
                language=language,
                user_identifier=user_identifier,
            )

            async with shard['lock']:
                # Timestamp after creation: init_agent may take a while and
                # the new entry must be the most recently used one.
                current_time = time.time()
                shard['agents'][cache_key] = agent
                shard['unique_ids'][cache_key] = bot_id
                shard['access_times'][cache_key] = current_time
                shard['creation_times'][cache_key] = current_time

                # Evict after inserting so that the shard never ends up
                # above its capacity once the new agent is stored.
                self._cleanup_old_agents(shard)

                # The agent is cached; later requests hit the fast path.
                shard['creation_locks'].pop(cache_key, None)

            self._bump_stat('agent_creations')
            logger.info(f"分片助手实例缓存创建完成: {cache_key}, shard: {shard_index}")
            return agent

    def get_cache_stats(self) -> Dict:
        """Return a snapshot of the cached agents and the request counters.

        Returns:
            A dictionary with the totals, one entry per cached agent (key,
            bot id, shard, timestamps, age and idle time in whole seconds)
            and the hit rate in percent of all requests.
        """
        current_time = time.time()
        agents_info = []

        for shard_index, shard in enumerate(self.shards):
            creation_times = shard['creation_times']
            access_times = shard['access_times']
            for cache_key in shard['agents']:
                created_at = creation_times.get(cache_key)
                last_accessed = access_times.get(cache_key)
                agents_info.append({
                    "cache_key": cache_key,
                    "unique_id": shard['unique_ids'].get(cache_key, "unknown"),
                    "shard": shard_index,
                    "created_at": 0 if created_at is None else created_at,
                    "last_accessed": 0 if last_accessed is None else last_accessed,
                    # A missing timestamp is reported as zero elapsed time.
                    "age_seconds": int(
                        current_time - (current_time if created_at is None else created_at)
                    ),
                    "idle_seconds": int(
                        current_time - (current_time if last_accessed is None else last_accessed)
                    ),
                })

        stats = {
            "total_cached_agents": len(agents_info),
            "max_cached_agents": self.max_cached_agents,
            "shard_count": self.shard_count,
            "agents": agents_info,
        }

        with self._stats_lock:
            counters = dict(self._global_stats)

        stats.update({
            "total_requests": counters['total_requests'],
            "cache_hits": counters['cache_hits'],
            "cache_misses": counters['cache_misses'],
            # max(1, ...) avoids a division by zero before the first request.
            "cache_hit_rate": (
                counters['cache_hits'] / max(1, counters['total_requests']) * 100
            ),
            "agent_creations": counters['agent_creations'],
        })
        return stats

    def clear_cache(self) -> int:
        """Drop every cached agent and reset the counters.

        This method is synchronous and does not acquire the shard locks.

        Returns:
            The number of agents that were cached before the call.
        """
        cache_count = 0
        for shard in self.shards:
            cache_count += len(shard['agents'])
            for field in ('agents', 'unique_ids', 'access_times',
                          'creation_times', 'creation_locks'):
                shard[field].clear()

        with self._stats_lock:
            self._global_stats = self._new_stats()

        logger.info(f"分片管理器已清空所有助手实例缓存,共清理 {cache_count} 个实例")
        return cache_count

    def remove_cache_by_unique_id(self, unique_id: str) -> int:
        """Remove every cached agent that was created for ``unique_id``.

        ``unique_id`` is compared with the ``bot_id`` stored when the agent
        was cached. This method is synchronous and does not acquire the
        shard locks.

        Returns:
            The number of agents removed.
        """
        removed_count = 0

        for shard_index, shard in enumerate(self.shards):
            # Collect the keys first: the dictionary must not change size
            # while it is being iterated.
            matching_keys = [
                cache_key
                for cache_key, stored_unique_id in shard['unique_ids'].items()
                if stored_unique_id == unique_id
            ]

            for cache_key in matching_keys:
                if self._remove_entry(shard, cache_key):
                    removed_count += 1
                    logger.info(f"分片 {shard_index} 已移除助手实例缓存: {cache_key} (unique_id: {unique_id})")

        if removed_count > 0:
            logger.info(f"分片管理器已移除 unique_id={unique_id} 的 {removed_count} 个助手实例缓存")
        else:
            logger.warning(f"分片管理器未找到 unique_id={unique_id} 的缓存实例")

        return removed_count
|
||
|
||
|
||
# Global sharded agent manager instance (created lazily or via init_global_sharded_agent_manager)
|
||
_global_sharded_agent_manager: Optional[ShardedAgentManager] = None
|
||
|
||
|
||
def get_global_sharded_agent_manager() -> ShardedAgentManager:
    """Return the module-wide manager, building a default one on first use."""
    global _global_sharded_agent_manager
    manager = _global_sharded_agent_manager
    if manager is None:
        # First call: create the manager with its default settings.
        manager = ShardedAgentManager()
        _global_sharded_agent_manager = manager
    return manager
|
||
|
||
|
||
def init_global_sharded_agent_manager(max_cached_agents: int = 20, shard_count: int = 16):
    """Replace the module-wide manager with a freshly configured one.

    Args:
        max_cached_agents: Forwarded to ``ShardedAgentManager``.
        shard_count: Forwarded to ``ShardedAgentManager``.

    Returns:
        The newly created manager, which is also stored as the global one.
    """
    global _global_sharded_agent_manager
    manager = ShardedAgentManager(max_cached_agents, shard_count)
    _global_sharded_agent_manager = manager
    return manager
|