catalog-agent/utils/agent_pool.py
2025-10-17 22:04:10 +08:00

178 lines
5.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
from typing import List, Optional
import logging
logger = logging.getLogger(__name__)
class AgentPool:
"""助手实例池管理器"""
def __init__(self, pool_size: int = 5):
"""
初始化助手实例池
Args:
pool_size: 池中实例的数量默认5个
"""
self.pool_size = pool_size
self.pool: asyncio.Queue = asyncio.Queue(maxsize=pool_size)
self.semaphore = asyncio.Semaphore(pool_size)
self.agents = [] # 保存所有创建的实例引用
async def initialize(self, agent_factory):
"""
初始化实例池,使用工厂函数创建助手实例
Args:
agent_factory: 创建助手实例的工厂函数
"""
logger.info(f"正在初始化助手实例池,大小: {self.pool_size}")
for i in range(self.pool_size):
try:
agent = agent_factory()
await self.pool.put(agent)
self.agents.append(agent)
logger.info(f"助手实例 {i+1}/{self.pool_size} 创建成功")
except Exception as e:
logger.error(f"创建助手实例 {i+1} 失败: {e}")
raise
logger.info("助手实例池初始化完成")
async def get_agent(self, timeout: Optional[float] = 30.0):
"""
获取空闲的助手实例
Args:
timeout: 获取超时时间默认30秒
Returns:
助手实例
Raises:
asyncio.TimeoutError: 获取超时
"""
try:
# 使用信号量控制并发
await asyncio.wait_for(self.semaphore.acquire(), timeout=timeout)
# 从池中获取实例
agent = await asyncio.wait_for(self.pool.get(), timeout=timeout)
logger.debug(f"成功获取助手实例,剩余池大小: {self.pool.qsize()}")
return agent
except asyncio.TimeoutError:
logger.error(f"获取助手实例超时 ({timeout}秒)")
raise
async def release_agent(self, agent):
"""
释放助手实例回池
Args:
agent: 要释放的助手实例
"""
try:
await self.pool.put(agent)
self.semaphore.release()
logger.debug(f"释放助手实例,当前池大小: {self.pool.qsize()}")
except Exception as e:
logger.error(f"释放助手实例失败: {e}")
# 即使释放失败也要释放信号量
self.semaphore.release()
def get_pool_stats(self) -> dict:
"""
获取池状态统计信息
Returns:
包含池状态信息的字典
"""
return {
"pool_size": self.pool_size,
"available_agents": self.pool.qsize(),
"total_agents": len(self.agents),
"in_use_agents": len(self.agents) - self.pool.qsize()
}
async def shutdown(self):
"""关闭实例池,清理资源"""
logger.info("正在关闭助手实例池...")
# 清空队列
while not self.pool.empty():
try:
agent = self.pool.get_nowait()
# 如果有清理方法,调用清理
if hasattr(agent, 'cleanup'):
await agent.cleanup()
except asyncio.QueueEmpty:
break
logger.info("助手实例池已关闭")
# 全局实例池单例
_global_agent_pool: Optional[AgentPool] = None
def get_agent_pool() -> Optional[AgentPool]:
"""获取全局助手实例池"""
return _global_agent_pool
def set_agent_pool(pool: AgentPool):
"""设置全局助手实例池"""
global _global_agent_pool
_global_agent_pool = pool
async def init_global_agent_pool(pool_size: int = 5, agent_factory=None):
"""
初始化全局助手实例池
Args:
pool_size: 池大小
agent_factory: 实例工厂函数
"""
global _global_agent_pool
if _global_agent_pool is not None:
logger.warning("全局助手实例池已存在,跳过初始化")
return
if agent_factory is None:
raise ValueError("必须提供 agent_factory 参数")
_global_agent_pool = AgentPool(pool_size=pool_size)
await _global_agent_pool.initialize(agent_factory)
logger.info("全局助手实例池初始化完成")
async def get_agent_from_pool(timeout: Optional[float] = 30.0):
"""
从全局池获取助手实例
Args:
timeout: 获取超时时间
Returns:
助手实例
"""
if _global_agent_pool is None:
raise RuntimeError("全局助手实例池未初始化")
return await _global_agent_pool.get_agent(timeout)
async def release_agent_to_pool(agent):
"""
释放助手实例到全局池
Args:
agent: 要释放的助手实例
"""
if _global_agent_pool is None:
raise RuntimeError("全局助手实例池未初始化")
await _global_agent_pool.release_agent(agent)