catalog-agent/file_loaded_agent_manager.py
2025-10-14 08:59:19 +08:00

249 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright 2023
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""文件预加载助手管理器 - 管理基于ZIP URL的助手实例缓存"""
import hashlib
import time
from typing import Dict, List, Optional
from qwen_agent.agents import Assistant
from qwen_agent.log import logger
from gbase_agent import init_agent_service_with_files, update_agent_llm
class FileLoadedAgentManager:
    """Cache of file-preloaded assistant instances, keyed by ZIP URL.

    Each distinct ZIP URL maps to one assistant instance so that repeated
    requests for the same URL reuse it instead of building a new one. The
    cache is bounded by ``max_cached_agents`` and evicts the least recently
    accessed entries first.
    """

    def __init__(self, max_cached_agents: int = 20):
        """Create an empty cache.

        Args:
            max_cached_agents: Maximum number of assistants kept after an
                eviction pass. ``0`` (or a negative value) keeps nothing.
        """
        # All five maps share one key: the value from _get_zip_url_hash().
        self.agents: Dict[str, Assistant] = {}  # cache key -> assistant instance
        self.zip_urls: Dict[str, str] = {}  # cache key -> original ZIP URL
        self.access_times: Dict[str, float] = {}  # cache key -> last access (time.time()); drives LRU
        self.creation_times: Dict[str, float] = {}  # cache key -> creation time (time.time())
        self.file_counts: Dict[str, int] = {}  # cache key -> number of files passed at creation
        self.max_cached_agents = max_cached_agents

    def _get_zip_url_hash(self, zip_url: str) -> str:
        """Return the cache key for ``zip_url``.

        The key is the first 16 hex characters of the URL's MD5 digest. It is
        used purely as a lookup key; two URLs that collide on those 16
        characters would share one cache entry.
        """
        return hashlib.md5(zip_url.encode('utf-8')).hexdigest()[:16]

    def _update_access_time(self, cache_key: str):
        """Record "now" as the last access time of ``cache_key`` (LRU bookkeeping)."""
        self.access_times[cache_key] = time.time()

    def _drop_entry(self, cache_key: str) -> bool:
        """Remove ``cache_key`` from every bookkeeping map.

        Uses ``pop`` with a default so a key missing from one map can never
        abort the removal halfway and leave the maps inconsistent.

        Returns:
            True if an assistant was cached under ``cache_key``, else False.
        """
        existed = cache_key in self.agents
        for mapping in (self.agents, self.zip_urls, self.access_times,
                        self.creation_times, self.file_counts):
            mapping.pop(cache_key, None)
        return existed

    def _select_lru_victims(self) -> List[str]:
        """Return the cache keys to evict, least recently accessed first.

        The number of victims is computed explicitly rather than with a
        negative slice bound: ``keys[:-limit]`` silently selects nothing when
        ``limit`` is 0.
        """
        limit = max(self.max_cached_agents, 0)
        excess = len(self.agents) - limit
        if excess <= 0:
            return []
        # Order by the agents map (the source of truth); an agent without a
        # recorded access time is treated as the oldest.
        by_last_access = sorted(
            self.agents, key=lambda key: self.access_times.get(key, 0.0))
        return by_last_access[:excess]

    def _cleanup_old_agents(self):
        """Evict least recently used assistants until the cache fits its limit."""
        removed_count = 0
        for cache_key in self._select_lru_victims():
            if self._drop_entry(cache_key):
                removed_count += 1
                logger.info(f"清理过期的助手实例缓存: {cache_key}")
        if removed_count > 0:
            logger.info(f"已清理 {removed_count} 个过期的助手实例缓存")

    @staticmethod
    def _read_text_if_exists(path: str) -> str:
        """Return the stripped UTF-8 text of ``path``, or ``""`` if it does not exist."""
        import os

        if not os.path.exists(path):
            return ""
        with open(path, "r", encoding="utf-8") as f:
            return f.read().strip()

    def _load_project_config(self, project_dir: str) -> tuple:
        """Read the system prompt and MCP settings from ``project_dir``.

        ``system_prompt.md`` is used as a template in which ``{dataset_dir}``
        is replaced by ``<project_dir>/dataset`` and ``{readme}`` by the
        contents of ``README.md``. ``mcp_settings.json`` is parsed as JSON.
        Missing files yield an empty prompt / empty settings dict.

        Returns:
            A ``(system_prompt, mcp_settings)`` pair.

        Raises:
            json.JSONDecodeError: If ``mcp_settings.json`` exists but is not
                valid JSON.
        """
        import json
        import os

        template = self._read_text_if_exists(
            os.path.join(project_dir, "system_prompt.md"))
        readme = self._read_text_if_exists(
            os.path.join(project_dir, "README.md"))
        dataset_dir = os.path.join(project_dir, "dataset")
        # Replacement order matters: {dataset_dir} first, then {readme}, so a
        # literal "{dataset_dir}" inside the README is left untouched.
        system_prompt = template.replace(
            "{dataset_dir}", str(dataset_dir)).replace("{readme}", str(readme))

        mcp_settings = {}
        mcp_settings_path = os.path.join(project_dir, "mcp_settings.json")
        if os.path.exists(mcp_settings_path):
            with open(mcp_settings_path, "r", encoding="utf-8") as f:
                mcp_settings = json.load(f)
        return system_prompt, mcp_settings

    async def get_or_create_agent(self,
                                  zip_url: str,
                                  files: List[str],
                                  project_dir: str,
                                  model_name: str = "qwen3-next",
                                  api_key: Optional[str] = None,
                                  model_server: Optional[str] = None,
                                  generate_cfg: Optional[Dict] = None) -> "Assistant":
        """Return the cached assistant for ``zip_url``, creating it if needed.

        The project files are re-read on every call. On a cache hit the
        existing assistant gets its LLM settings updated and, when a non-empty
        system prompt was read, its ``system_message`` replaced; ``files`` is
        only used when a new assistant is created.

        Args:
            zip_url: URL of the ZIP file; determines the cache key.
            files: Paths of the files to preload into a new assistant.
            project_dir: Directory holding ``system_prompt.md``, ``README.md``
                and ``mcp_settings.json``.
            model_name: Model name.
            api_key: API key.
            model_server: Model server address.
            generate_cfg: Generation configuration.

        Returns:
            The configured assistant instance.
        """
        system_prompt, mcp_settings = self._load_project_config(project_dir)
        cache_key = self._get_zip_url_hash(zip_url)

        # Cache hit: refresh LRU bookkeeping and per-request configuration.
        if cache_key in self.agents:
            self._update_access_time(cache_key)
            agent = self.agents[cache_key]
            update_agent_llm(agent, model_name, api_key, model_server, generate_cfg)
            # Only override the system message when the project supplies one.
            if system_prompt:
                agent.system_message = system_prompt
            logger.info(f"复用现有的助手实例缓存: {cache_key} (文件数: {len(files)})")
            return agent

        # Cache miss: build a new assistant with the files preloaded.
        logger.info(f"创建新的助手实例缓存: {cache_key}, 预加载文件数: {len(files)}")
        current_time = time.time()
        agent = init_agent_service_with_files(
            files=files,
            model_name=model_name,
            api_key=api_key,
            model_server=model_server,
            generate_cfg=generate_cfg,
            system_prompt=system_prompt,
            mcp=mcp_settings
        )

        self.agents[cache_key] = agent
        self.zip_urls[cache_key] = zip_url
        self.access_times[cache_key] = current_time
        self.creation_times[cache_key] = current_time
        self.file_counts[cache_key] = len(files)
        logger.info(f"助手实例缓存创建完成: {cache_key}")

        # Evict after inserting so the cache never settles above its limit.
        # The new entry has the latest access time, so it is only evicted
        # itself when the limit is 0 (caching disabled); it is still returned.
        self._cleanup_old_agents()
        return agent

    def get_cache_stats(self) -> Dict:
        """Return a snapshot of the cache.

        Returns:
            A dict with ``total_cached_agents``, ``max_cached_agents`` and an
            ``agents`` mapping of cache key to per-entry details (URL, file
            count, creation/access timestamps, age and idle time in seconds).
        """
        current_time = time.time()
        stats = {
            "total_cached_agents": len(self.agents),
            "max_cached_agents": self.max_cached_agents,
            "agents": {}
        }
        for cache_key in self.agents:
            created_at = self.creation_times.get(cache_key, current_time)
            last_accessed = self.access_times.get(cache_key, current_time)
            stats["agents"][cache_key] = {
                "zip_url": self.zip_urls.get(cache_key, "unknown"),
                "file_count": self.file_counts.get(cache_key, 0),
                "created_at": self.creation_times.get(cache_key, 0),
                "last_accessed": self.access_times.get(cache_key, 0),
                "age_seconds": int(current_time - created_at),
                "idle_seconds": int(current_time - last_accessed)
            }
        return stats

    def clear_cache(self) -> int:
        """Remove every cached assistant.

        Returns:
            The number of assistants that were cached before clearing.
        """
        cache_count = len(self.agents)
        self.agents.clear()
        self.zip_urls.clear()
        self.access_times.clear()
        self.creation_times.clear()
        self.file_counts.clear()
        logger.info(f"已清空所有助手实例缓存,共清理 {cache_count} 个实例")
        return cache_count

    def remove_cache_by_url(self, zip_url: str) -> bool:
        """Remove the cached assistant for ``zip_url``, if any.

        Args:
            zip_url: URL of the ZIP file.

        Returns:
            True if an entry was removed, False if nothing was cached.
        """
        cache_key = self._get_zip_url_hash(zip_url)
        if not self._drop_entry(cache_key):
            return False
        logger.info(f"已移除特定 ZIP URL 的助手实例缓存: {zip_url}")
        return True

    def list_cached_zip_urls(self) -> List[str]:
        """Return the original ZIP URLs of all cached assistants."""
        return list(self.zip_urls.values())
# Module-level slot for the shared FileLoadedAgentManager instance.
# It starts as None and is assigned by the module-level accessor/initializer functions.
_global_agent_manager: Optional[FileLoadedAgentManager] = None
def get_global_agent_manager() -> FileLoadedAgentManager:
    """Return the shared manager, creating it with default settings on first use."""
    global _global_agent_manager
    if _global_agent_manager is not None:
        return _global_agent_manager
    # First call: no manager exists yet, so build one with the default limit.
    _global_agent_manager = FileLoadedAgentManager()
    return _global_agent_manager
def init_global_agent_manager(max_cached_agents: int = 20):
    """Replace the shared manager with a fresh one and return it.

    Args:
        max_cached_agents: Cache size limit passed to the new manager.

    Returns:
        The newly created FileLoadedAgentManager, which is also stored as the
        module-level shared instance (any previous instance is discarded).
    """
    global _global_agent_manager
    fresh_manager = FileLoadedAgentManager(max_cached_agents)
    _global_agent_manager = fresh_manager
    return fresh_manager