From b78b178c03fdfb03b671266d356ce553c8f258b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Wed, 17 Dec 2025 20:27:06 +0800 Subject: [PATCH] Remove agent manager MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent/agent_config.py | 33 ++-- agent/checkpoint_utils.py | 126 +++++++++++++ agent/deep_assistant.py | 22 ++- agent/sharded_agent_manager.py | 323 --------------------------------- docker-compose.yml | 7 +- fastapi_app.py | 3 - routes/chat.py | 65 ++++--- routes/system.py | 125 ------------- utils/__init__.py | 39 +--- utils/connection_pool.py | 212 ---------------------- utils/fastapi_utils.py | 3 - 11 files changed, 204 insertions(+), 754 deletions(-) create mode 100644 agent/checkpoint_utils.py delete mode 100644 agent/sharded_agent_manager.py delete mode 100644 utils/connection_pool.py diff --git a/agent/agent_config.py b/agent/agent_config.py index edee4e1..f3704ff 100644 --- a/agent/agent_config.py +++ b/agent/agent_config.py @@ -72,17 +72,18 @@ class AgentConfig: """从v1请求创建配置""" # 延迟导入避免循环依赖 from .logging_handler import LoggingCallbackHandler - + from utils.fastapi_utils import get_preamble_text if messages is None: messages = [] - - return cls( + + preamble_text, system_prompt = get_preamble_text(request.language, request.system_prompt) + config = cls( bot_id=request.bot_id, api_key=api_key, model_name=request.model, model_server=request.model_server, language=request.language, - system_prompt=request.system_prompt, + system_prompt=system_prompt, mcp_settings=request.mcp_settings, robot_type=request.robot_type, user_identifier=request.user_identifier, @@ -93,37 +94,45 @@ class AgentConfig: tool_response=request.tool_response, generate_cfg=generate_cfg, logging_handler=LoggingCallbackHandler(), - messages=messages + messages=messages, + preamble_text=preamble_text, ) + config.safe_print() + return config + @classmethod def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None): """从v2请求创建配置""" # 延迟导入避免循环依赖 from .logging_handler import LoggingCallbackHandler - + from utils.fastapi_utils import get_preamble_text if messages is None: messages = [] - - return cls( + language = request.language or bot_config.get("language", "zh") + preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt")) + config = cls( bot_id=request.bot_id, api_key=bot_config.get("api_key"), model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"), model_server=bot_config.get("model_server", ""), - language=request.language or bot_config.get("language", "zh"), - system_prompt=bot_config.get("system_prompt"), + language=language, + system_prompt=system_prompt, mcp_settings=bot_config.get("mcp_settings", []), robot_type=bot_config.get("robot_type", "general_agent"), user_identifier=request.user_identifier, session_id=request.session_id, enable_thinking=request.enable_thinking, project_dir=project_dir, - stream=request.stream, + stream=request.stream, tool_response=request.tool_response, generate_cfg={}, # v2接口不传递额外的generate_cfg logging_handler=LoggingCallbackHandler(), - messages=messages + messages=messages, + preamble_text=preamble_text, ) + config.safe_print() + return config def invoke_config(self): """返回Langchain需要的配置字典""" diff --git a/agent/checkpoint_utils.py b/agent/checkpoint_utils.py new file mode 100644 index 0000000..a346070 --- /dev/null +++ b/agent/checkpoint_utils.py @@ -0,0 +1,126 @@ +"""用于处理 
LangGraph checkpoint 相关的工具函数""" + +import logging +from typing import List, Dict, Any, Optional +from langgraph.checkpoint.memory import MemorySaver + +logger = logging.getLogger('app') + + +async def check_checkpoint_history(checkpointer: MemorySaver, thread_id: str) -> bool: + """ + 检查指定的 thread_id 在 checkpointer 中是否已有历史记录 + + Args: + checkpointer: MemorySaver 实例 + thread_id: 线程ID(通常是 session_id) + + Returns: + bool: True 表示有历史记录,False 表示没有 + """ + if not checkpointer or not thread_id: + logger.debug(f"No checkpointer or thread_id: checkpointer={bool(checkpointer)}, thread_id={thread_id}") + return False + + try: + # 获取配置 + config = {"configurable": {"thread_id": thread_id}} + + # 调试信息:检查 checkpointer 类型 + logger.debug(f"Checkpointer type: {type(checkpointer)}") + logger.debug(f"Checkpointer dir: {[attr for attr in dir(checkpointer) if not attr.startswith('_')]}") + + # 先尝试获取最新的 checkpoint + try: + latest_checkpoint = await checkpointer.aget_tuple(config) + logger.debug(f"aget_tuple result: {latest_checkpoint}") + + if latest_checkpoint is not None: + logger.info(f"Found latest checkpoint for thread_id: {thread_id}") + # CheckpointTuple 是 NamedTuple,按属性访问,避免字段数量随版本变化导致解包失败 + logger.debug(f"Checkpoint metadata: {latest_checkpoint.metadata}") + return True + except Exception as e: + logger.warning(f"aget_tuple failed: {e}") + + # 如果没有最新的,再列出所有 + logger.debug(f"No latest checkpoint for thread_id: {thread_id}, checking all checkpoints...") + try: + checkpoints = [] + async for c in checkpointer.alist(config): + checkpoints.append(c) + logger.debug(f"Found checkpoint: {c}") + + # 如果有至少一个 checkpoint,说明有历史记录 + has_history = len(checkpoints) > 0 + + if has_history: + logger.info(f"Found {len(checkpoints)} checkpoints in total for thread_id: {thread_id}") + else: + logger.info(f"No existing history for thread_id: {thread_id}") + + return has_history + except Exception as e: + logger.warning(f"alist failed: {e}") + return False + + except Exception as e: + import traceback + logger.error(f"Error checking checkpoint history for thread_id {thread_id}: {e}") + logger.error(f"Full traceback: {traceback.format_exc()}") + # 出错时保守处理,返回 False + return False + + +def prepare_messages_for_agent( + messages: List[Dict[str, Any]], + has_history: bool +) -> List[Dict[str, Any]]: + """ + 根据是否有历史记录来准备要发送给 agent 的消息 + + Args: + messages: 完整的消息列表 + has_history: 是否已有历史记录 + + Returns: + List[Dict]: 要发送给 agent 的消息列表 + """ + if not messages: + return [] + + # 如果有历史记录,只发送最后一条用户消息 + if has_history: + # 找到最后一条用户消息 + for msg in reversed(messages): + if msg.get('role') == 'user': + logger.info(f"Has history, sending only last user message: {msg.get('content', '')[:50]}...") + return [msg] + + # 如果没有用户消息(理论上不应该发生),保守起见返回原始消息列表 + logger.warning("No user message found in messages") + return messages + + # 如果没有历史记录,发送所有消息 + logger.info(f"No history, sending all {len(messages)} messages") + return messages + + +def update_agent_config_for_checkpoint( + config_messages: List[Dict[str, Any]], + has_history: bool +) -> List[Dict[str, Any]]: + """ + 更新 AgentConfig 中的 messages,根据是否有历史记录决定发送哪些消息 + + 这个函数可以在调用 agent 之前使用,避免重复处理消息历史 + + Args: + config_messages: AgentConfig 中的原始消息列表 + has_history: 是否已有历史记录 + + Returns: + List[Dict]: 更新后的消息列表 + """ + return prepare_messages_for_agent(config_messages, has_history) \ No newline at end of file diff --git a/agent/deep_assistant.py b/agent/deep_assistant.py index 5cac35e..4373888 100644 --- a/agent/deep_assistant.py +++ b/agent/deep_assistant.py @@ -15,7 +15,7 @@ from 
.guideline_middleware import GuidelineMiddleware from .tool_output_length_middleware import ToolOutputLengthMiddleware from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH from agent.agent_config import AgentConfig - +from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async # Utility functions def read_system_prompt(): @@ -57,7 +57,6 @@ async def get_tools_from_mcp(mcp): # 发生异常时返回空列表,避免上层调用报错 return [] - async def init_agent(config: AgentConfig): """ 初始化 Agent,支持持久化内存和对话摘要 @@ -66,9 +65,20 @@ async def init_agent(config: AgentConfig): config: AgentConfig 对象,包含所有初始化参数 mcp: MCP配置(如果为None则使用配置中的mcp_settings) """ + final_system_prompt = await load_system_prompt_async( + config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier + ) + final_mcp_settings = await load_mcp_settings_async( + config.project_dir, config.mcp_settings, config.bot_id, config.robot_type + ) + # 如果没有提供mcp,使用config中的mcp_settings - mcp_settings = config.mcp_settings if config.mcp_settings else read_mcp_settings() - system_prompt = config.system_prompt if config.system_prompt else read_system_prompt() + mcp_settings = final_mcp_settings if final_mcp_settings else read_mcp_settings() + system_prompt = final_system_prompt if final_system_prompt else read_system_prompt() + + config.system_prompt = system_prompt + config.mcp_settings = mcp_settings + mcp_tools = await get_tools_from_mcp(mcp_settings) # 检测或使用指定的提供商 @@ -124,4 +134,8 @@ async def init_agent(config: AgentConfig): checkpointer=checkpointer # 传入 checkpointer 以启用持久化 ) + # 将 checkpointer 作为属性附加到 agent 上,方便访问 + if checkpointer: + agent._checkpointer = checkpointer + return agent \ No newline at end of file diff --git a/agent/sharded_agent_manager.py b/agent/sharded_agent_manager.py deleted file mode 100644 index 0ae424f..0000000 --- a/agent/sharded_agent_manager.py +++ /dev/null @@ -1,323 +0,0 @@ -# Copyright 2023 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""分片助手管理器 - 减少锁竞争的高并发agent缓存系统""" - -import hashlib -import time -import json -import asyncio -from typing import Dict, List, Optional -import threading -import logging - -logger = logging.getLogger('app') - -from agent.deep_assistant import init_agent -from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async -from agent.agent_config import AgentConfig - - -class ShardedAgentManager: - """分片助手管理器 - - 使用分片技术减少锁竞争,支持高并发访问 - """ - - def __init__(self, max_cached_agents: int = 20, shard_count: int = 16): - self.max_cached_agents = max_cached_agents - self.shard_count = shard_count - - # 创建分片 - self.shards = [] - for i in range(shard_count): - shard = { - 'agents': {}, # {cache_key: assistant_instance} - 'unique_ids': {}, # {cache_key: unique_id} - 'access_times': {}, # LRU 访问时间管理 - 'creation_times': {}, # 创建时间记录 - 'lock': asyncio.Lock(), # 每个分片独立锁 - 'creation_locks': {}, # 防止并发创建相同agent的锁 - } - self.shards.append(shard) - - # 用于统计的全局锁(读写分离) - self._stats_lock = threading.RLock() - self._global_stats = { - 'total_requests': 0, - 'cache_hits': 0, - 'cache_misses': 0, - 'agent_creations': 0 - } - - def _get_shard_index(self, cache_key: str) -> int: - """根据缓存键获取分片索引""" - hash_value = int(hashlib.md5(cache_key.encode('utf-8')).hexdigest(), 16) - return hash_value % self.shard_count - - def _get_cache_key(self, bot_id: str, model_name: str = None, api_key: str = None, - model_server: str = None, generate_cfg: Dict = None, - system_prompt: str = None, mcp_settings: List[Dict] = None, - enable_thinking: bool = False) -> str: - """获取包含所有相关参数的哈希值作为缓存键""" - cache_data = { - 'bot_id': bot_id, - 'model_name': model_name or '', - 'api_key': api_key or '', - 'model_server': model_server or '', - 'generate_cfg': json.dumps(generate_cfg or {}, sort_keys=True), - 'system_prompt': system_prompt or '', - 'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True), - 'enable_thinking': enable_thinking - } - - cache_str = json.dumps(cache_data, sort_keys=True) - return hashlib.md5(cache_str.encode('utf-8')).hexdigest()[:16] - - def _update_access_time(self, shard: dict, cache_key: str): - """更新访问时间(LRU 管理)""" - shard['access_times'][cache_key] = time.time() - - def _cleanup_old_agents(self, shard: dict): - """清理分片中的旧助手实例,基于 LRU 策略""" - # 计算每个分片的最大容量 - shard_max_capacity = max(1, self.max_cached_agents // self.shard_count) - - if len(shard['agents']) <= shard_max_capacity: - return - - # 按 LRU 顺序排序,删除最久未访问的实例 - sorted_keys = sorted(shard['access_times'].keys(), - key=lambda k: shard['access_times'][k]) - - keys_to_remove = sorted_keys[:-shard_max_capacity] - removed_count = 0 - - for cache_key in keys_to_remove: - try: - del shard['agents'][cache_key] - del shard['unique_ids'][cache_key] - del shard['access_times'][cache_key] - del shard['creation_times'][cache_key] - shard['creation_locks'].pop(cache_key, None) - removed_count += 1 - logger.info(f"分片清理过期的助手实例缓存: {cache_key}") - except KeyError: - continue - - if removed_count > 0: - logger.info(f"分片已清理 {removed_count} 个过期的助手实例缓存") - - async def get_or_create_agent(self, config: AgentConfig): - """获取或创建文件预加载的助手实例 - - Args: - config: AgentConfig 对象,包含所有初始化参数 - """ - - # 更新请求统计 - with self._stats_lock: - self._global_stats['total_requests'] += 1 - - # 异步加载配置文件(带缓存) - final_system_prompt = await load_system_prompt_async( - config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier - ) - final_mcp_settings = await load_mcp_settings_async( - config.project_dir, config.mcp_settings, 
config.bot_id, config.robot_type - ) - config.system_prompt = final_system_prompt - config.mcp_settings = final_mcp_settings - cache_key = self._get_cache_key(config.bot_id, config.model_name, config.api_key, config.model_server, - config.generate_cfg, final_system_prompt, final_mcp_settings, - config.enable_thinking) - - # 获取分片 - shard_index = self._get_shard_index(cache_key) - shard = self.shards[shard_index] - - # 使用分片级异步锁防止并发创建相同的agent - async with shard['lock']: - # 检查是否已存在该助手实例 - if cache_key in shard['agents']: - self._update_access_time(shard, cache_key) - agent = shard['agents'][cache_key] - # 更新缓存命中统计 - with self._stats_lock: - self._global_stats['cache_hits'] += 1 - - logger.info(f"分片复用现有的助手实例缓存: {cache_key} (bot_id: {config.bot_id}, shard: {shard_index})") - return agent - - # 更新缓存未命中统计 - with self._stats_lock: - self._global_stats['cache_misses'] += 1 - - # 使用更细粒度的创建锁 - creation_lock = shard['creation_locks'].setdefault(cache_key, asyncio.Lock()) - - # 在分片锁外创建agent,减少锁持有时间 - async with creation_lock: - # 再次检查是否已存在(获取锁后可能有其他请求已创建) - async with shard['lock']: - if cache_key in shard['agents']: - self._update_access_time(shard, cache_key) - agent = shard['agents'][cache_key] - - with self._stats_lock: - self._global_stats['cache_hits'] += 1 - - return agent - - # 清理过期实例 - async with shard['lock']: - self._cleanup_old_agents(shard) - - # 创建新的助手实例 - logger.info(f"分片创建新的助手实例缓存: {cache_key}, bot_id: {config.bot_id}, shard: {shard_index}") - current_time = time.time() - - agent = await init_agent(config) - - # 缓存实例 - async with shard['lock']: - shard['agents'][cache_key] = agent - shard['unique_ids'][cache_key] = config.bot_id - shard['access_times'][cache_key] = current_time - shard['creation_times'][cache_key] = current_time - - # 清理创建锁 - shard['creation_locks'].pop(cache_key, None) - - # 更新创建统计 - with self._stats_lock: - self._global_stats['agent_creations'] += 1 - - logger.info(f"分片助手实例缓存创建完成: {cache_key}, shard: {shard_index}") - return agent - - def get_cache_stats(self) -> Dict: - """获取缓存统计信息""" - current_time = time.time() - total_agents = 0 - agents_info = [] - - for i, shard in enumerate(self.shards): - for cache_key, agent in shard['agents'].items(): - total_agents += 1 - agents_info.append({ - "cache_key": cache_key, - "unique_id": shard['unique_ids'].get(cache_key, "unknown"), - "shard": i, - "created_at": shard['creation_times'].get(cache_key, 0), - "last_accessed": shard['access_times'].get(cache_key, 0), - "age_seconds": int(current_time - shard['creation_times'].get(cache_key, current_time)), - "idle_seconds": int(current_time - shard['access_times'].get(cache_key, current_time)) - }) - - stats = { - "total_cached_agents": total_agents, - "max_cached_agents": self.max_cached_agents, - "shard_count": self.shard_count, - "agents": agents_info - } - - # 添加全局统计 - with self._stats_lock: - stats.update({ - "total_requests": self._global_stats['total_requests'], - "cache_hits": self._global_stats['cache_hits'], - "cache_misses": self._global_stats['cache_misses'], - "cache_hit_rate": ( - self._global_stats['cache_hits'] / max(1, self._global_stats['total_requests']) * 100 - ), - "agent_creations": self._global_stats['agent_creations'] - }) - - return stats - - def clear_cache(self) -> int: - """清空所有缓存""" - cache_count = 0 - - for shard in self.shards: - cache_count += len(shard['agents']) - shard['agents'].clear() - shard['unique_ids'].clear() - shard['access_times'].clear() - shard['creation_times'].clear() - shard['creation_locks'].clear() - - # 重置统计 - with 
self._stats_lock: - self._global_stats = { - 'total_requests': 0, - 'cache_hits': 0, - 'cache_misses': 0, - 'agent_creations': 0 - } - - logger.info(f"分片管理器已清空所有助手实例缓存,共清理 {cache_count} 个实例") - return cache_count - - def remove_cache_by_unique_id(self, unique_id: str) -> int: - """根据 unique_id 移除所有相关的缓存""" - removed_count = 0 - - for i, shard in enumerate(self.shards): - keys_to_remove = [] - - # 找到所有匹配的 unique_id 的缓存键 - for cache_key, stored_unique_id in shard['unique_ids'].items(): - if stored_unique_id == unique_id: - keys_to_remove.append(cache_key) - - # 移除找到的缓存 - for cache_key in keys_to_remove: - try: - del shard['agents'][cache_key] - del shard['unique_ids'][cache_key] - del shard['access_times'][cache_key] - del shard['creation_times'][cache_key] - shard['creation_locks'].pop(cache_key, None) - removed_count += 1 - logger.info(f"分片 {i} 已移除助手实例缓存: {cache_key} (unique_id: {unique_id})") - except KeyError: - continue - - if removed_count > 0: - logger.info(f"分片管理器已移除 unique_id={unique_id} 的 {removed_count} 个助手实例缓存") - else: - logger.warning(f"分片管理器未找到 unique_id={unique_id} 的缓存实例") - - return removed_count - - -# 全局分片助手管理器实例 -_global_sharded_agent_manager: Optional[ShardedAgentManager] = None - - -def get_global_sharded_agent_manager() -> ShardedAgentManager: - """获取全局分片助手管理器实例""" - global _global_sharded_agent_manager - if _global_sharded_agent_manager is None: - _global_sharded_agent_manager = ShardedAgentManager() - return _global_sharded_agent_manager - - -def init_global_sharded_agent_manager(max_cached_agents: int = 20, shard_count: int = 16): - """初始化全局分片助手管理器""" - global _global_sharded_agent_manager - _global_sharded_agent_manager = ShardedAgentManager(max_cached_agents, shard_count) - return _global_sharded_agent_manager diff --git a/docker-compose.yml b/docker-compose.yml index 0ad2818..4f2efde 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -10,13 +10,12 @@ services: - "8001:8001" environment: # 应用配置 - - PYTHONPATH=/app - - PYTHONUNBUFFERED=1 - - AGENT_POOL_SIZE=2 + - BACKEND_HOST=http://api-dev.gbase.ai + - MAX_CONTEXT_TOKENS=262144 + - DEFAULT_THINKING_ENABLE=true volumes: # 挂载项目数据目录 - ./projects:/app/projects - - ./public:/app/public restart: unless-stopped healthcheck: diff --git a/fastapi_app.py b/fastapi_app.py index d1d721a..38bc72d 100644 --- a/fastapi_app.py +++ b/fastapi_app.py @@ -19,9 +19,6 @@ from utils.log_util.logger import init_with_fastapi # Import route modules from routes import chat, files, projects, system -# Import the system manager from routes.system to access the initialized components -from routes.system import agent_manager, connection_pool, file_cache - app = FastAPI(title="Database Assistant API", version="1.0.0") init_with_fastapi(app) diff --git a/routes/chat.py b/routes/chat.py index 709622a..692d71f 100644 --- a/routes/chat.py +++ b/routes/chat.py @@ -4,32 +4,26 @@ import asyncio from typing import Union, Optional from fastapi import APIRouter, HTTPException, Header from fastapi.responses import StreamingResponse -from pydantic import BaseModel import logging logger = logging.getLogger('app') from utils import ( Message, ChatRequest, ChatResponse ) -from agent.sharded_agent_manager import init_global_sharded_agent_manager from utils.api_models import ChatRequestV2 from utils.fastapi_utils import ( process_messages, create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config, - call_preamble_llm, get_preamble_text, + call_preamble_llm, create_stream_chunk ) from langchain_core.messages import 
AIMessageChunk, HumanMessage, ToolMessage, AIMessage from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT from agent.agent_config import AgentConfig +from agent.deep_assistant import init_agent router = APIRouter() -# 初始化全局助手管理器 -agent_manager = init_global_sharded_agent_manager( - max_cached_agents=MAX_CACHED_AGENTS, - shard_count=SHARD_COUNT -) def append_user_last_message(messages: list, content: str) -> bool: @@ -97,13 +91,13 @@ def format_messages_to_chat_history(messages: list) -> str: async def enhanced_generate_stream_response( - agent_manager, + agent, config: AgentConfig ): """增强的渐进式流式响应生成器 - 并发优化版本 Args: - agent_manager: agent管理器 + agent: LangChain agent 对象 config: AgentConfig 对象,包含所有参数 """ try: @@ -137,9 +131,6 @@ async def enhanced_generate_stream_response( # Agent 任务(准备 + 流式处理) async def agent_task(): try: - # 准备 agent - agent = await agent_manager.get_or_create_agent(config) - # 开始流式处理 logger.info(f"Starting agent stream response") chunk_id = 0 @@ -270,25 +261,43 @@ async def create_agent_and_generate_response( Args: config: AgentConfig 对象,包含所有参数 """ - config.safe_print() - config.preamble_text, config.system_prompt = get_preamble_text(config.language, config.system_prompt) + # 获取或创建 agent(需要先创建 agent 才能访问 checkpointer) + agent = await init_agent(config) + + # 如果有 checkpointer,检查是否有历史记录 + if config.session_id: + # 检查 checkpointer + checkpointer = None + if hasattr(agent, '_checkpointer'): + checkpointer = agent._checkpointer + + if checkpointer: + from agent.checkpoint_utils import check_checkpoint_history, prepare_messages_for_agent + has_history = await check_checkpoint_history(checkpointer, config.session_id) + + # 更新 config.messages + config.messages = prepare_messages_for_agent(config.messages, has_history) + + logger.info(f"Session {config.session_id}: has_history={has_history}, sending {len(config.messages)} messages") + else: + logger.warning(f"No checkpointer found for session {config.session_id}") + else: + logger.debug(f"No session_id provided, skipping checkpoint check") # 如果是流式模式,使用增强的流式响应生成器 if config.stream: return StreamingResponse( enhanced_generate_stream_response( - agent_manager=agent_manager, + agent=agent, config=config ), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"} ) - - messages = config.messages - # 使用公共函数处理所有逻辑 - agent = await agent_manager.get_or_create_agent(config) - agent_responses = await agent.ainvoke({"messages": messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS) - append_messages = agent_responses["messages"][len(messages):] + + # 使用更新后的 messages + agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS) + append_messages = agent_responses["messages"][len(config.messages):] response_text = "" for msg in append_messages: if isinstance(msg,AIMessage): @@ -313,9 +322,9 @@ async def create_agent_and_generate_response( "finish_reason": "stop" }], usage={ - "prompt_tokens": sum(len(msg.get("content", "")) for msg in messages), + "prompt_tokens": sum(len(msg.get("content", "")) for msg in config.messages), "completion_tokens": len(response_text), - "total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(response_text) + "total_tokens": sum(len(msg.get("content", "")) for msg in config.messages) + len(response_text) } ) else: @@ -375,9 +384,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] = # 创建 AgentConfig 对象 config = 
AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg, messages) # 调用公共的agent创建和响应生成逻辑 - return await create_agent_and_generate_response( - config=config - ) + return await create_agent_and_generate_response(config) except Exception as e: import traceback @@ -451,9 +458,7 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st # 创建 AgentConfig 对象 config = AgentConfig.from_v2_request(request, bot_config, project_dir, messages) # 调用公共的agent创建和响应生成逻辑 - return await create_agent_and_generate_response( - config=config - ) + return await create_agent_and_generate_response(config) except HTTPException: raise diff --git a/routes/system.py b/routes/system.py index f6b44a8..15b9ea8 100644 --- a/routes/system.py +++ b/routes/system.py @@ -6,25 +6,12 @@ from fastapi import APIRouter, HTTPException from pydantic import BaseModel from utils import ( - get_global_connection_pool, init_global_connection_pool, - get_global_file_cache, init_global_file_cache, setup_system_optimizations ) -from agent.sharded_agent_manager import init_global_sharded_agent_manager -try: - from utils.system_optimizer import apply_optimization_profile -except ImportError: - def apply_optimization_profile(profile): - return {"profile": profile, "status": "system_optimizer not available"} from embedding import get_model_manager from pydantic import BaseModel import logging -from utils.settings import ( - MAX_CACHED_AGENTS, SHARD_COUNT, MAX_CONNECTIONS_PER_HOST, MAX_CONNECTIONS_TOTAL, - KEEPALIVE_TIMEOUT, CONNECT_TIMEOUT, TOTAL_TIMEOUT, FILE_CACHE_SIZE, FILE_CACHE_TTL, - TOKENIZERS_PARALLELISM -) logger = logging.getLogger('app') @@ -49,35 +36,8 @@ class EncodeResponse(BaseModel): logger.info("正在初始化系统优化...") system_optimizer = setup_system_optimizations() -# 全局助手管理器配置(使用优化后的配置) -max_cached_agents = MAX_CACHED_AGENTS # 增加缓存大小 -shard_count = SHARD_COUNT # 分片数量 - -# 初始化优化的全局助手管理器 -agent_manager = init_global_sharded_agent_manager( - max_cached_agents=max_cached_agents, - shard_count=shard_count -) - -# 初始化连接池 -connection_pool = init_global_connection_pool( - max_connections_per_host=MAX_CONNECTIONS_PER_HOST, - max_connections_total=MAX_CONNECTIONS_TOTAL, - keepalive_timeout=KEEPALIVE_TIMEOUT, - connect_timeout=CONNECT_TIMEOUT, - total_timeout=TOTAL_TIMEOUT -) - -# 初始化文件缓存 -file_cache = init_global_file_cache( - cache_size=FILE_CACHE_SIZE, - ttl=FILE_CACHE_TTL -) logger.info("系统优化初始化完成") -logger.info(f"- 分片Agent管理器: {shard_count} 个分片,最多缓存 {max_cached_agents} 个agent") -logger.info(f"- 连接池: 每主机100连接,总计500连接") -logger.info(f"- 文件缓存: 1000个文件,TTL 300秒") @router.get("/api/health") @@ -90,9 +50,6 @@ async def health_check(): async def get_performance_stats(): """获取系统性能统计信息""" try: - # 获取agent管理器统计 - agent_stats = agent_manager.get_cache_stats() - # 获取连接池统计(简化版) pool_stats = { "connection_pool": "active", @@ -128,7 +85,6 @@ async def get_performance_stats(): "success": True, "timestamp": int(time.time()), "performance": { - "agent_manager": agent_stats, "connection_pool": pool_stats, "file_cache": file_cache_stats, "system": system_stats @@ -140,87 +96,6 @@ async def get_performance_stats(): raise HTTPException(status_code=500, detail=f"获取性能统计失败: {str(e)}") -@router.post("/api/v1/system/optimize") -async def optimize_system(profile: str = "balanced"): - """应用系统优化配置""" - try: - # 应用优化配置 - config = apply_optimization_profile(profile) - - return { - "success": True, - "message": f"已应用 {profile} 优化配置", - "config": config - } - - except Exception as e: - logger.error(f"Error applying optimization 
profile: {str(e)}") - raise HTTPException(status_code=500, detail=f"应用优化配置失败: {str(e)}") - - -@router.post("/api/v1/system/clear-cache") -async def clear_system_cache(cache_type: Optional[str] = None): - """清理系统缓存""" - try: - cleared_counts = {} - - if cache_type is None or cache_type == "agent": - # 清理agent缓存 - agent_count = agent_manager.clear_cache() - cleared_counts["agent_cache"] = agent_count - - if cache_type is None or cache_type == "file": - # 清理文件缓存 - if hasattr(file_cache, '_cache'): - file_count = len(file_cache._cache) - file_cache._cache.clear() - cleared_counts["file_cache"] = file_count - - return { - "success": True, - "message": f"已清理指定类型的缓存", - "cleared_counts": cleared_counts - } - - except Exception as e: - logger.error(f"Error clearing cache: {str(e)}") - raise HTTPException(status_code=500, detail=f"清理缓存失败: {str(e)}") - - -@router.get("/api/v1/system/config") -async def get_system_config(): - """获取当前系统配置""" - try: - return { - "success": True, - "config": { - "max_cached_agents": max_cached_agents, - "shard_count": shard_count, - "tokenizer_parallelism": TOKENIZERS_PARALLELISM, - "max_connections_per_host": str(MAX_CONNECTIONS_PER_HOST), - "max_connections_total": str(MAX_CONNECTIONS_TOTAL), - "file_cache_size": str(FILE_CACHE_SIZE), - "file_cache_ttl": str(FILE_CACHE_TTL) - } - } - - except Exception as e: - logger.error(f"Error getting system config: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取系统配置失败: {str(e)}") - - -@router.post("/system/remove-project-cache") -async def remove_project_cache(dataset_id: str): - """移除特定项目的缓存""" - try: - removed_count = agent_manager.remove_cache_by_unique_id(dataset_id) - if removed_count > 0: - return {"message": f"项目缓存移除成功: {dataset_id}", "removed_count": removed_count} - else: - return {"message": f"未找到项目缓存: {dataset_id}", "removed_count": 0} - except Exception as e: - raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}") - @router.post("/api/v1/embedding/encode", response_model=EncodeResponse) async def encode_texts(request: EncodeRequest): diff --git a/utils/__init__.py b/utils/__init__.py index 7c1037f..c0ae1aa 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -31,47 +31,10 @@ from .project_manager import ( ) - -from .connection_pool import ( - HTTPConnectionPool, - get_global_connection_pool, - init_global_connection_pool, - OAIWithConnectionPool -) - -from .async_file_ops import ( - AsyncFileCache, - get_global_file_cache, - init_global_file_cache, - async_read_file, - async_read_json, - async_write_file, - async_write_json, - async_file_exists, - async_get_file_mtime, - ParallelFileReader, - get_global_parallel_reader -) - - from .system_optimizer import ( - SystemOptimizer, - AsyncioOptimizer, - setup_system_optimizations, - create_performance_monitor, - get_optimized_worker_config, - OPTIMIZATION_CONFIGS, - apply_optimization_profile, - get_global_system_optimizer + setup_system_optimizations ) -# Import config cache module -# Note: This has been moved to agent package -# from .config_cache import ( -# config_cache, -# ConfigFileCache -# ) - from .agent_pool import ( AgentPool, get_agent_pool, diff --git a/utils/connection_pool.py b/utils/connection_pool.py deleted file mode 100644 index 66c90a9..0000000 --- a/utils/connection_pool.py +++ /dev/null @@ -1,212 +0,0 @@ -#!/usr/bin/env python3 -""" -HTTP连接池管理器 - 提供高效的HTTP连接复用 -""" - -import aiohttp -import asyncio -from typing import Dict, Optional, Any -import threading -import time -import weakref - - -class HTTPConnectionPool: - 
"""HTTP连接池管理器 - - 提供连接复用、 Keep-Alive 连接、合理超时设置等功能 - """ - - def __init__(self, - max_connections_per_host: int = 100, - max_connections_total: int = 500, - keepalive_timeout: int = 30, - connect_timeout: int = 10, - total_timeout: int = 60): - """ - 初始化连接池 - - Args: - max_connections_per_host: 每个主机的最大连接数 - max_connections_total: 总连接数限制 - keepalive_timeout: Keep-Alive超时时间(秒) - connect_timeout: 连接超时时间(秒) - total_timeout: 总请求超时时间(秒) - """ - self.max_connections_per_host = max_connections_per_host - self.max_connections_total = max_connections_total - self.keepalive_timeout = keepalive_timeout - self.connect_timeout = connect_timeout - self.total_timeout = total_timeout - - # 创建连接器配置 - self.connector_config = { - 'limit': max_connections_total, - 'limit_per_host': max_connections_per_host, - 'keepalive_timeout': keepalive_timeout, - 'enable_cleanup_closed': True, # 自动清理关闭的连接 - 'force_close': False, # 不强制关闭连接 - 'use_dns_cache': True, # 使用DNS缓存 - 'ttl_dns_cache': 300, # DNS缓存TTL - } - - # 使用线程本地存储来管理事件循环间的session - self._sessions = weakref.WeakKeyDictionary() - self._lock = threading.RLock() - - def _create_session(self) -> aiohttp.ClientSession: - """创建新的aiohttp会话""" - timeout = aiohttp.ClientTimeout( - total=self.total_timeout, - connect=self.connect_timeout, - sock_connect=self.connect_timeout, - sock_read=self.total_timeout - ) - - connector = aiohttp.TCPConnector(**self.connector_config) - - return aiohttp.ClientSession( - connector=connector, - timeout=timeout, - headers={ - 'User-Agent': 'QwenAgent/1.0', - 'Connection': 'keep-alive', - 'Accept-Encoding': 'gzip, deflate, br', - } - ) - - def get_session(self) -> aiohttp.ClientSession: - """获取当前事件循环的session""" - loop = asyncio.get_running_loop() - - with self._lock: - if loop not in self._sessions: - self._sessions[loop] = self._create_session() - return self._sessions[loop] - - async def request(self, method: str, url: str, **kwargs) -> aiohttp.ClientResponse: - """发送HTTP请求,自动处理连接复用""" - session = self.get_session() - return await session.request(method, url, **kwargs) - - async def get(self, url: str, **kwargs) -> aiohttp.ClientResponse: - """发送GET请求""" - return await self.request('GET', url, **kwargs) - - async def post(self, url: str, **kwargs) -> aiohttp.ClientResponse: - """发送POST请求""" - return await self.request('POST', url, **kwargs) - - async def close(self): - """关闭所有session""" - with self._lock: - for loop, session in list(self._sessions.items()): - if not session.closed: - await session.close() - self._sessions.clear() - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - - -# 全局连接池实例 -_global_connection_pool: Optional[HTTPConnectionPool] = None -_pool_lock = threading.Lock() - - -def get_global_connection_pool() -> HTTPConnectionPool: - """获取全局连接池实例""" - global _global_connection_pool - if _global_connection_pool is None: - with _pool_lock: - if _global_connection_pool is None: - _global_connection_pool = HTTPConnectionPool() - return _global_connection_pool - - -def init_global_connection_pool( - max_connections_per_host: int = 100, - max_connections_total: int = 500, - keepalive_timeout: int = 30, - connect_timeout: int = 10, - total_timeout: int = 60 -) -> HTTPConnectionPool: - """初始化全局连接池""" - global _global_connection_pool - with _pool_lock: - _global_connection_pool = HTTPConnectionPool( - max_connections_per_host=max_connections_per_host, - max_connections_total=max_connections_total, - keepalive_timeout=keepalive_timeout, - 
connect_timeout=connect_timeout, - total_timeout=total_timeout - ) - return _global_connection_pool - - -class OAIWithConnectionPool: - """带有连接池的OpenAI API客户端""" - - def __init__(self, - config: Dict[str, Any], - connection_pool: Optional[HTTPConnectionPool] = None): - """ - 初始化客户端 - - Args: - config: OpenAI API配置 - connection_pool: 可选的连接池实例 - """ - self.config = config - self.pool = connection_pool or get_global_connection_pool() - self.base_url = config.get('model_server', '').rstrip('/') - if not self.base_url: - self.base_url = "https://api.openai.com/v1" - - self.api_key = config.get('api_key', '') - self.model = config.get('model', 'gpt-3.5-turbo') - self.generate_cfg = config.get('generate_cfg', {}) - - async def chat_completions(self, messages: list, stream: bool = False, **kwargs): - """发送聊天完成请求""" - url = f"{self.base_url}/chat/completions" - - headers = { - 'Authorization': f'Bearer {self.api_key}', - 'Content-Type': 'application/json', - } - - data = { - 'model': self.model, - 'messages': messages, - 'stream': stream, - **self.generate_cfg, - **kwargs - } - - async with self.pool.post(url, json=data, headers=headers) as response: - if response.status == 200: - if stream: - return self._handle_stream_response(response) - else: - return await response.json() - else: - error_text = await response.text() - raise Exception(f"API request failed with status {response.status}: {error_text}") - - async def _handle_stream_response(self, response: aiohttp.ClientResponse): - """处理流式响应""" - async for line in response.content: - line = line.decode('utf-8').strip() - if line.startswith('data: '): - data = line[6:] # 移除 'data: ' 前缀 - if data == '[DONE]': - break - try: - import json - yield json.loads(data) - except json.JSONDecodeError: - continue \ No newline at end of file diff --git a/utils/fastapi_utils.py b/utils/fastapi_utils.py index 282ad51..6145f03 100644 --- a/utils/fastapi_utils.py +++ b/utils/fastapi_utils.py @@ -249,9 +249,6 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li parts = re.split(r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"]) current_tag = None - assistant_content = "" - function_calls = [] - tool_responses = [] tool_id_counter = 0 # 添加唯一的工具调用计数器 tool_id_list = [] for i in range(0, len(parts)):
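
Note: for reference, the end-to-end flow after this patch (per-request agent construction plus checkpointer-based history, replacing the sharded agent manager cache) can be exercised roughly as below. This is a minimal sketch, not part of the patch: run_turn is a hypothetical harness, and it assumes AgentConfig.invoke_config() carries session_id as the LangGraph thread_id that the checkpointer keys on, which is how routes/chat.py wires it up.

from agent.agent_config import AgentConfig
from agent.checkpoint_utils import check_checkpoint_history, prepare_messages_for_agent
from agent.deep_assistant import init_agent


async def run_turn(config: AgentConfig) -> dict:
    # Each request builds its own agent; conversation state now lives in the
    # LangGraph checkpointer rather than in a cross-request agent cache.
    agent = await init_agent(config)

    # init_agent attaches the checkpointer as agent._checkpointer when enabled.
    checkpointer = getattr(agent, "_checkpointer", None)
    if checkpointer and config.session_id:
        # With existing history, only the last user message is re-sent;
        # the checkpointer replays the earlier turns for this thread_id.
        has_history = await check_checkpoint_history(checkpointer, config.session_id)
        config.messages = prepare_messages_for_agent(config.messages, has_history)

    return await agent.ainvoke({"messages": config.messages}, config=config.invoke_config())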