删除agent manager

This commit is contained in:
朱潮 2025-12-17 20:27:06 +08:00
parent 23bc62a2b8
commit b78b178c03
11 changed files with 204 additions and 754 deletions

View File

@ -72,17 +72,18 @@ class AgentConfig:
"""从v1请求创建配置""" """从v1请求创建配置"""
# 延迟导入避免循环依赖 # 延迟导入避免循环依赖
from .logging_handler import LoggingCallbackHandler from .logging_handler import LoggingCallbackHandler
from utils.fastapi_utils import get_preamble_text
if messages is None: if messages is None:
messages = [] messages = []
return cls( preamble_text, system_prompt = get_preamble_text(request.language, request.system_prompt)
config = cls(
bot_id=request.bot_id, bot_id=request.bot_id,
api_key=api_key, api_key=api_key,
model_name=request.model, model_name=request.model,
model_server=request.model_server, model_server=request.model_server,
language=request.language, language=request.language,
system_prompt=request.system_prompt, system_prompt=system_prompt,
mcp_settings=request.mcp_settings, mcp_settings=request.mcp_settings,
robot_type=request.robot_type, robot_type=request.robot_type,
user_identifier=request.user_identifier, user_identifier=request.user_identifier,
@ -93,25 +94,30 @@ class AgentConfig:
tool_response=request.tool_response, tool_response=request.tool_response,
generate_cfg=generate_cfg, generate_cfg=generate_cfg,
logging_handler=LoggingCallbackHandler(), logging_handler=LoggingCallbackHandler(),
messages=messages messages=messages,
preamble_text=preamble_text,
) )
config.safe_print()
return config
@classmethod @classmethod
def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None): def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None):
"""从v2请求创建配置""" """从v2请求创建配置"""
# 延迟导入避免循环依赖 # 延迟导入避免循环依赖
from .logging_handler import LoggingCallbackHandler from .logging_handler import LoggingCallbackHandler
from utils.fastapi_utils import get_preamble_text
if messages is None: if messages is None:
messages = [] messages = []
language = request.language or bot_config.get("language", "zh")
return cls( preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt"))
config = cls(
bot_id=request.bot_id, bot_id=request.bot_id,
api_key=bot_config.get("api_key"), api_key=bot_config.get("api_key"),
model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"), model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
model_server=bot_config.get("model_server", ""), model_server=bot_config.get("model_server", ""),
language=request.language or bot_config.get("language", "zh"), language=language,
system_prompt=bot_config.get("system_prompt"), system_prompt=system_prompt,
mcp_settings=bot_config.get("mcp_settings", []), mcp_settings=bot_config.get("mcp_settings", []),
robot_type=bot_config.get("robot_type", "general_agent"), robot_type=bot_config.get("robot_type", "general_agent"),
user_identifier=request.user_identifier, user_identifier=request.user_identifier,
@ -122,8 +128,11 @@ class AgentConfig:
tool_response=request.tool_response, tool_response=request.tool_response,
generate_cfg={}, # v2接口不传递额外的generate_cfg generate_cfg={}, # v2接口不传递额外的generate_cfg
logging_handler=LoggingCallbackHandler(), logging_handler=LoggingCallbackHandler(),
messages=messages messages=messages,
preamble_text=preamble_text,
) )
config.safe_print()
return config
def invoke_config(self): def invoke_config(self):
"""返回Langchain需要的配置字典""" """返回Langchain需要的配置字典"""

126
agent/checkpoint_utils.py Normal file
View File

@ -0,0 +1,126 @@
"""用于处理 LangGraph checkpoint 相关的工具函数"""
import logging
from typing import List, Dict, Any, Optional
from langgraph.checkpoint.memory import MemorySaver
logger = logging.getLogger('app')
async def check_checkpoint_history(checkpointer: "MemorySaver", thread_id: str) -> bool:
    """Check whether the checkpointer already holds history for *thread_id*.

    Args:
        checkpointer: MemorySaver instance (or any compatible async
            checkpointer exposing ``aget_tuple`` / ``alist``).
        thread_id: Thread ID, usually the session_id.

    Returns:
        bool: True if at least one checkpoint exists for the thread,
        False otherwise (including on any error — conservative default).
    """
    log = logging.getLogger('app')
    if not checkpointer or not thread_id:
        log.debug("No checkpointer or thread_id: checkpointer=%s, thread_id=%s",
                  bool(checkpointer), thread_id)
        return False

    config = {"configurable": {"thread_id": thread_id}}

    # Fast path: the latest checkpoint alone proves history exists.
    try:
        latest_checkpoint = await checkpointer.aget_tuple(config)
        if latest_checkpoint is not None:
            # NOTE: CheckpointTuple carries more than three fields in recent
            # langgraph versions (config, checkpoint, metadata, parent_config,
            # pending_writes), so we must NOT tuple-unpack it — doing so raised
            # ValueError and silently forced the slow `alist` path below.
            log.info("Found latest checkpoint for thread_id: %s", thread_id)
            return True
    except Exception as e:
        log.warning("aget_tuple failed: %s", e)

    # Fallback: any listed checkpoint also counts as history. Stop at the
    # first one — no need to materialize the full list just to test emptiness.
    log.debug("No latest checkpoint for thread_id: %s, checking all checkpoints...", thread_id)
    try:
        async for _ in checkpointer.alist(config):
            log.info("Found existing checkpoints for thread_id: %s", thread_id)
            return True
        log.info("No existing history for thread_id: %s", thread_id)
        return False
    except Exception as e:
        log.warning("alist failed: %s", e)
        return False
def prepare_messages_for_agent(
    messages: List[Dict[str, Any]],
    has_history: bool
) -> List[Dict[str, Any]]:
    """Select which messages to forward to the agent.

    When checkpoint history already exists for the session, only the most
    recent user message is sent (the rest is restored from the checkpoint);
    otherwise the full message list is forwarded.

    Args:
        messages: Full conversation message list.
        has_history: Whether checkpoint history already exists.

    Returns:
        List[Dict]: The messages to send to the agent.
    """
    log = logging.getLogger('app')
    if not messages:
        return []

    if not has_history:
        # Fresh session: the agent needs the whole conversation.
        log.info(f"No history, sending all {len(messages)} messages")
        return messages

    # Existing history: forward only the latest user turn.
    last_user = next(
        (msg for msg in reversed(messages) if msg.get('role') == 'user'),
        None,
    )
    if last_user is None:
        # Should not happen in practice; fall back to the full list.
        log.warning("No user message found in messages")
        return messages

    log.info(f"Has history, sending only last user message: {last_user.get('content', '')[:50]}...")
    return [last_user]
def update_agent_config_for_checkpoint(
    config_messages: List[Dict[str, Any]],
    has_history: bool
) -> List[Dict[str, Any]]:
    """Update AgentConfig messages based on checkpoint history.

    Thin wrapper around prepare_messages_for_agent so callers working with an
    AgentConfig can apply the same message selection before invoking the
    agent, avoiding duplicate processing of conversation history.

    Args:
        config_messages: Original message list from AgentConfig.
        has_history: Whether checkpoint history already exists.

    Returns:
        List[Dict]: The updated message list.
    """
    # Delegate to the shared selection logic.
    selected = prepare_messages_for_agent(config_messages, has_history)
    return selected

View File

@ -15,7 +15,7 @@ from .guideline_middleware import GuidelineMiddleware
from .tool_output_length_middleware import ToolOutputLengthMiddleware from .tool_output_length_middleware import ToolOutputLengthMiddleware
from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
from agent.agent_config import AgentConfig from agent.agent_config import AgentConfig
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
# Utility functions # Utility functions
def read_system_prompt(): def read_system_prompt():
@ -57,7 +57,6 @@ async def get_tools_from_mcp(mcp):
# 发生异常时返回空列表,避免上层调用报错 # 发生异常时返回空列表,避免上层调用报错
return [] return []
async def init_agent(config: AgentConfig): async def init_agent(config: AgentConfig):
""" """
初始化 Agent支持持久化内存和对话摘要 初始化 Agent支持持久化内存和对话摘要
@ -66,9 +65,20 @@ async def init_agent(config: AgentConfig):
config: AgentConfig 对象包含所有初始化参数 config: AgentConfig 对象包含所有初始化参数
mcp: MCP配置如果为None则使用配置中的mcp_settings mcp: MCP配置如果为None则使用配置中的mcp_settings
""" """
final_system_prompt = await load_system_prompt_async(
config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
)
final_mcp_settings = await load_mcp_settings_async(
config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
)
# 如果没有提供mcp使用config中的mcp_settings # 如果没有提供mcp使用config中的mcp_settings
mcp_settings = config.mcp_settings if config.mcp_settings else read_mcp_settings() mcp_settings = final_mcp_settings if final_mcp_settings else read_mcp_settings()
system_prompt = config.system_prompt if config.system_prompt else read_system_prompt() system_prompt = final_system_prompt if final_system_prompt else read_system_prompt()
config.system_prompt = system_prompt
config.mcp_settings = mcp_settings
mcp_tools = await get_tools_from_mcp(mcp_settings) mcp_tools = await get_tools_from_mcp(mcp_settings)
# 检测或使用指定的提供商 # 检测或使用指定的提供商
@ -124,4 +134,8 @@ async def init_agent(config: AgentConfig):
checkpointer=checkpointer # 传入 checkpointer 以启用持久化 checkpointer=checkpointer # 传入 checkpointer 以启用持久化
) )
# 将 checkpointer 作为属性附加到 agent 上,方便访问
if checkpointer:
agent._checkpointer = checkpointer
return agent return agent

View File

@ -1,323 +0,0 @@
# Copyright 2023
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""分片助手管理器 - 减少锁竞争的高并发agent缓存系统"""
import hashlib
import time
import json
import asyncio
from typing import Dict, List, Optional
import threading
import logging
logger = logging.getLogger('app')
from agent.deep_assistant import init_agent
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
from agent.agent_config import AgentConfig
class ShardedAgentManager:
"""分片助手管理器
使用分片技术减少锁竞争支持高并发访问
"""
def __init__(self, max_cached_agents: int = 20, shard_count: int = 16):
self.max_cached_agents = max_cached_agents
self.shard_count = shard_count
# 创建分片
self.shards = []
for i in range(shard_count):
shard = {
'agents': {}, # {cache_key: assistant_instance}
'unique_ids': {}, # {cache_key: unique_id}
'access_times': {}, # LRU 访问时间管理
'creation_times': {}, # 创建时间记录
'lock': asyncio.Lock(), # 每个分片独立锁
'creation_locks': {}, # 防止并发创建相同agent的锁
}
self.shards.append(shard)
# 用于统计的全局锁(读写分离)
self._stats_lock = threading.RLock()
self._global_stats = {
'total_requests': 0,
'cache_hits': 0,
'cache_misses': 0,
'agent_creations': 0
}
def _get_shard_index(self, cache_key: str) -> int:
"""根据缓存键获取分片索引"""
hash_value = int(hashlib.md5(cache_key.encode('utf-8')).hexdigest(), 16)
return hash_value % self.shard_count
def _get_cache_key(self, bot_id: str, model_name: str = None, api_key: str = None,
model_server: str = None, generate_cfg: Dict = None,
system_prompt: str = None, mcp_settings: List[Dict] = None,
enable_thinking: bool = False) -> str:
"""获取包含所有相关参数的哈希值作为缓存键"""
cache_data = {
'bot_id': bot_id,
'model_name': model_name or '',
'api_key': api_key or '',
'model_server': model_server or '',
'generate_cfg': json.dumps(generate_cfg or {}, sort_keys=True),
'system_prompt': system_prompt or '',
'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True),
'enable_thinking': enable_thinking
}
cache_str = json.dumps(cache_data, sort_keys=True)
return hashlib.md5(cache_str.encode('utf-8')).hexdigest()[:16]
def _update_access_time(self, shard: dict, cache_key: str):
"""更新访问时间LRU 管理)"""
shard['access_times'][cache_key] = time.time()
def _cleanup_old_agents(self, shard: dict):
"""清理分片中的旧助手实例,基于 LRU 策略"""
# 计算每个分片的最大容量
shard_max_capacity = max(1, self.max_cached_agents // self.shard_count)
if len(shard['agents']) <= shard_max_capacity:
return
# 按 LRU 顺序排序,删除最久未访问的实例
sorted_keys = sorted(shard['access_times'].keys(),
key=lambda k: shard['access_times'][k])
keys_to_remove = sorted_keys[:-shard_max_capacity]
removed_count = 0
for cache_key in keys_to_remove:
try:
del shard['agents'][cache_key]
del shard['unique_ids'][cache_key]
del shard['access_times'][cache_key]
del shard['creation_times'][cache_key]
shard['creation_locks'].pop(cache_key, None)
removed_count += 1
logger.info(f"分片清理过期的助手实例缓存: {cache_key}")
except KeyError:
continue
if removed_count > 0:
logger.info(f"分片已清理 {removed_count} 个过期的助手实例缓存")
async def get_or_create_agent(self, config: AgentConfig):
"""获取或创建文件预加载的助手实例
Args:
config: AgentConfig 对象包含所有初始化参数
"""
# 更新请求统计
with self._stats_lock:
self._global_stats['total_requests'] += 1
# 异步加载配置文件(带缓存)
final_system_prompt = await load_system_prompt_async(
config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
)
final_mcp_settings = await load_mcp_settings_async(
config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
)
config.system_prompt = final_system_prompt
config.mcp_settings = final_mcp_settings
cache_key = self._get_cache_key(config.bot_id, config.model_name, config.api_key, config.model_server,
config.generate_cfg, final_system_prompt, final_mcp_settings,
config.enable_thinking)
# 获取分片
shard_index = self._get_shard_index(cache_key)
shard = self.shards[shard_index]
# 使用分片级异步锁防止并发创建相同的agent
async with shard['lock']:
# 检查是否已存在该助手实例
if cache_key in shard['agents']:
self._update_access_time(shard, cache_key)
agent = shard['agents'][cache_key]
# 更新缓存命中统计
with self._stats_lock:
self._global_stats['cache_hits'] += 1
logger.info(f"分片复用现有的助手实例缓存: {cache_key} (bot_id: {config.bot_id}, shard: {shard_index})")
return agent
# 更新缓存未命中统计
with self._stats_lock:
self._global_stats['cache_misses'] += 1
# 使用更细粒度的创建锁
creation_lock = shard['creation_locks'].setdefault(cache_key, asyncio.Lock())
# 在分片锁外创建agent减少锁持有时间
async with creation_lock:
# 再次检查是否已存在(获取锁后可能有其他请求已创建)
async with shard['lock']:
if cache_key in shard['agents']:
self._update_access_time(shard, cache_key)
agent = shard['agents'][cache_key]
with self._stats_lock:
self._global_stats['cache_hits'] += 1
return agent
# 清理过期实例
async with shard['lock']:
self._cleanup_old_agents(shard)
# 创建新的助手实例
logger.info(f"分片创建新的助手实例缓存: {cache_key}, bot_id: {config.bot_id}, shard: {shard_index}")
current_time = time.time()
agent = await init_agent(config)
# 缓存实例
async with shard['lock']:
shard['agents'][cache_key] = agent
shard['unique_ids'][cache_key] = config.bot_id
shard['access_times'][cache_key] = current_time
shard['creation_times'][cache_key] = current_time
# 清理创建锁
shard['creation_locks'].pop(cache_key, None)
# 更新创建统计
with self._stats_lock:
self._global_stats['agent_creations'] += 1
logger.info(f"分片助手实例缓存创建完成: {cache_key}, shard: {shard_index}")
return agent
def get_cache_stats(self) -> Dict:
"""获取缓存统计信息"""
current_time = time.time()
total_agents = 0
agents_info = []
for i, shard in enumerate(self.shards):
for cache_key, agent in shard['agents'].items():
total_agents += 1
agents_info.append({
"cache_key": cache_key,
"unique_id": shard['unique_ids'].get(cache_key, "unknown"),
"shard": i,
"created_at": shard['creation_times'].get(cache_key, 0),
"last_accessed": shard['access_times'].get(cache_key, 0),
"age_seconds": int(current_time - shard['creation_times'].get(cache_key, current_time)),
"idle_seconds": int(current_time - shard['access_times'].get(cache_key, current_time))
})
stats = {
"total_cached_agents": total_agents,
"max_cached_agents": self.max_cached_agents,
"shard_count": self.shard_count,
"agents": agents_info
}
# 添加全局统计
with self._stats_lock:
stats.update({
"total_requests": self._global_stats['total_requests'],
"cache_hits": self._global_stats['cache_hits'],
"cache_misses": self._global_stats['cache_misses'],
"cache_hit_rate": (
self._global_stats['cache_hits'] / max(1, self._global_stats['total_requests']) * 100
),
"agent_creations": self._global_stats['agent_creations']
})
return stats
def clear_cache(self) -> int:
"""清空所有缓存"""
cache_count = 0
for shard in self.shards:
cache_count += len(shard['agents'])
shard['agents'].clear()
shard['unique_ids'].clear()
shard['access_times'].clear()
shard['creation_times'].clear()
shard['creation_locks'].clear()
# 重置统计
with self._stats_lock:
self._global_stats = {
'total_requests': 0,
'cache_hits': 0,
'cache_misses': 0,
'agent_creations': 0
}
logger.info(f"分片管理器已清空所有助手实例缓存,共清理 {cache_count} 个实例")
return cache_count
def remove_cache_by_unique_id(self, unique_id: str) -> int:
"""根据 unique_id 移除所有相关的缓存"""
removed_count = 0
for i, shard in enumerate(self.shards):
keys_to_remove = []
# 找到所有匹配的 unique_id 的缓存键
for cache_key, stored_unique_id in shard['unique_ids'].items():
if stored_unique_id == unique_id:
keys_to_remove.append(cache_key)
# 移除找到的缓存
for cache_key in keys_to_remove:
try:
del shard['agents'][cache_key]
del shard['unique_ids'][cache_key]
del shard['access_times'][cache_key]
del shard['creation_times'][cache_key]
shard['creation_locks'].pop(cache_key, None)
removed_count += 1
logger.info(f"分片 {i} 已移除助手实例缓存: {cache_key} (unique_id: {unique_id})")
except KeyError:
continue
if removed_count > 0:
logger.info(f"分片管理器已移除 unique_id={unique_id}{removed_count} 个助手实例缓存")
else:
logger.warning(f"分片管理器未找到 unique_id={unique_id} 的缓存实例")
return removed_count
# 全局分片助手管理器实例
_global_sharded_agent_manager: Optional[ShardedAgentManager] = None
def get_global_sharded_agent_manager() -> ShardedAgentManager:
"""获取全局分片助手管理器实例"""
global _global_sharded_agent_manager
if _global_sharded_agent_manager is None:
_global_sharded_agent_manager = ShardedAgentManager()
return _global_sharded_agent_manager
def init_global_sharded_agent_manager(max_cached_agents: int = 20, shard_count: int = 16):
"""初始化全局分片助手管理器"""
global _global_sharded_agent_manager
_global_sharded_agent_manager = ShardedAgentManager(max_cached_agents, shard_count)
return _global_sharded_agent_manager

View File

@ -10,13 +10,12 @@ services:
- "8001:8001" - "8001:8001"
environment: environment:
# 应用配置 # 应用配置
- PYTHONPATH=/app - BACKEND_HOST=http://api-dev.gbase.ai
- PYTHONUNBUFFERED=1 - MAX_CONTEXT_TOKENS=262144
- AGENT_POOL_SIZE=2 - DEFAULT_THINKING_ENABLE=true
volumes: volumes:
# 挂载项目数据目录 # 挂载项目数据目录
- ./projects:/app/projects - ./projects:/app/projects
- ./public:/app/public
restart: unless-stopped restart: unless-stopped
healthcheck: healthcheck:

View File

@ -19,9 +19,6 @@ from utils.log_util.logger import init_with_fastapi
# Import route modules # Import route modules
from routes import chat, files, projects, system from routes import chat, files, projects, system
# Import the system manager from routes.system to access the initialized components
from routes.system import agent_manager, connection_pool, file_cache
app = FastAPI(title="Database Assistant API", version="1.0.0") app = FastAPI(title="Database Assistant API", version="1.0.0")
init_with_fastapi(app) init_with_fastapi(app)

View File

@ -4,32 +4,26 @@ import asyncio
from typing import Union, Optional from typing import Union, Optional
from fastapi import APIRouter, HTTPException, Header from fastapi import APIRouter, HTTPException, Header
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import logging import logging
logger = logging.getLogger('app') logger = logging.getLogger('app')
from utils import ( from utils import (
Message, ChatRequest, ChatResponse Message, ChatRequest, ChatResponse
) )
from agent.sharded_agent_manager import init_global_sharded_agent_manager
from utils.api_models import ChatRequestV2 from utils.api_models import ChatRequestV2
from utils.fastapi_utils import ( from utils.fastapi_utils import (
process_messages, process_messages,
create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config, create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
call_preamble_llm, get_preamble_text, call_preamble_llm,
create_stream_chunk create_stream_chunk
) )
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT
from agent.agent_config import AgentConfig from agent.agent_config import AgentConfig
from agent.deep_assistant import init_agent
router = APIRouter() router = APIRouter()
# 初始化全局助手管理器
agent_manager = init_global_sharded_agent_manager(
max_cached_agents=MAX_CACHED_AGENTS,
shard_count=SHARD_COUNT
)
def append_user_last_message(messages: list, content: str) -> bool: def append_user_last_message(messages: list, content: str) -> bool:
@ -97,13 +91,13 @@ def format_messages_to_chat_history(messages: list) -> str:
async def enhanced_generate_stream_response( async def enhanced_generate_stream_response(
agent_manager, agent,
config: AgentConfig config: AgentConfig
): ):
"""增强的渐进式流式响应生成器 - 并发优化版本 """增强的渐进式流式响应生成器 - 并发优化版本
Args: Args:
agent_manager: agent管理器 agent: LangChain agent 对象
config: AgentConfig 对象包含所有参数 config: AgentConfig 对象包含所有参数
""" """
try: try:
@ -137,9 +131,6 @@ async def enhanced_generate_stream_response(
# Agent 任务(准备 + 流式处理) # Agent 任务(准备 + 流式处理)
async def agent_task(): async def agent_task():
try: try:
# 准备 agent
agent = await agent_manager.get_or_create_agent(config)
# 开始流式处理 # 开始流式处理
logger.info(f"Starting agent stream response") logger.info(f"Starting agent stream response")
chunk_id = 0 chunk_id = 0
@ -270,25 +261,43 @@ async def create_agent_and_generate_response(
Args: Args:
config: AgentConfig 对象包含所有参数 config: AgentConfig 对象包含所有参数
""" """
config.safe_print() # 获取或创建 agent需要先创建 agent 才能访问 checkpointer
config.preamble_text, config.system_prompt = get_preamble_text(config.language, config.system_prompt) agent = await init_agent(config)
# 如果有 checkpointer检查是否有历史记录
if config.session_id:
# 检查 checkpointer
checkpointer = None
if hasattr(agent, '_checkpointer'):
checkpointer = agent._checkpointer
if checkpointer:
from agent.checkpoint_utils import check_checkpoint_history, prepare_messages_for_agent
has_history = await check_checkpoint_history(checkpointer, config.session_id)
# 更新 config.messages
config.messages = prepare_messages_for_agent(config.messages, has_history)
logger.info(f"Session {config.session_id}: has_history={has_history}, sending {len(config.messages)} messages")
else:
logger.warning(f"No checkpointer found for session {config.session_id}")
else:
logger.debug(f"No session_id provided, skipping checkpoint check")
# 如果是流式模式,使用增强的流式响应生成器 # 如果是流式模式,使用增强的流式响应生成器
if config.stream: if config.stream:
return StreamingResponse( return StreamingResponse(
enhanced_generate_stream_response( enhanced_generate_stream_response(
agent_manager=agent_manager, agent=agent,
config=config config=config
), ),
media_type="text/event-stream", media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"} headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
) )
messages = config.messages # 使用更新后的 messages
# 使用公共函数处理所有逻辑 agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
agent = await agent_manager.get_or_create_agent(config) append_messages = agent_responses["messages"][len(config.messages):]
agent_responses = await agent.ainvoke({"messages": messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
append_messages = agent_responses["messages"][len(messages):]
response_text = "" response_text = ""
for msg in append_messages: for msg in append_messages:
if isinstance(msg,AIMessage): if isinstance(msg,AIMessage):
@ -313,9 +322,9 @@ async def create_agent_and_generate_response(
"finish_reason": "stop" "finish_reason": "stop"
}], }],
usage={ usage={
"prompt_tokens": sum(len(msg.get("content", "")) for msg in messages), "prompt_tokens": sum(len(msg.get("content", "")) for msg in config.messages),
"completion_tokens": len(response_text), "completion_tokens": len(response_text),
"total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(response_text) "total_tokens": sum(len(msg.get("content", "")) for msg in config.messages) + len(response_text)
} }
) )
else: else:
@ -375,9 +384,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
# 创建 AgentConfig 对象 # 创建 AgentConfig 对象
config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg, messages) config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg, messages)
# 调用公共的agent创建和响应生成逻辑 # 调用公共的agent创建和响应生成逻辑
return await create_agent_and_generate_response( return await create_agent_and_generate_response(config)
config=config
)
except Exception as e: except Exception as e:
import traceback import traceback
@ -451,9 +458,7 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
# 创建 AgentConfig 对象 # 创建 AgentConfig 对象
config = AgentConfig.from_v2_request(request, bot_config, project_dir, messages) config = AgentConfig.from_v2_request(request, bot_config, project_dir, messages)
# 调用公共的agent创建和响应生成逻辑 # 调用公共的agent创建和响应生成逻辑
return await create_agent_and_generate_response( return await create_agent_and_generate_response(config)
config=config
)
except HTTPException: except HTTPException:
raise raise

View File

@ -6,25 +6,12 @@ from fastapi import APIRouter, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from utils import ( from utils import (
get_global_connection_pool, init_global_connection_pool,
get_global_file_cache, init_global_file_cache,
setup_system_optimizations setup_system_optimizations
) )
from agent.sharded_agent_manager import init_global_sharded_agent_manager
try:
from utils.system_optimizer import apply_optimization_profile
except ImportError:
def apply_optimization_profile(profile):
return {"profile": profile, "status": "system_optimizer not available"}
from embedding import get_model_manager from embedding import get_model_manager
from pydantic import BaseModel from pydantic import BaseModel
import logging import logging
from utils.settings import (
MAX_CACHED_AGENTS, SHARD_COUNT, MAX_CONNECTIONS_PER_HOST, MAX_CONNECTIONS_TOTAL,
KEEPALIVE_TIMEOUT, CONNECT_TIMEOUT, TOTAL_TIMEOUT, FILE_CACHE_SIZE, FILE_CACHE_TTL,
TOKENIZERS_PARALLELISM
)
logger = logging.getLogger('app') logger = logging.getLogger('app')
@ -49,35 +36,8 @@ class EncodeResponse(BaseModel):
logger.info("正在初始化系统优化...") logger.info("正在初始化系统优化...")
system_optimizer = setup_system_optimizations() system_optimizer = setup_system_optimizations()
# 全局助手管理器配置(使用优化后的配置)
max_cached_agents = MAX_CACHED_AGENTS # 增加缓存大小
shard_count = SHARD_COUNT # 分片数量
# 初始化优化的全局助手管理器
agent_manager = init_global_sharded_agent_manager(
max_cached_agents=max_cached_agents,
shard_count=shard_count
)
# 初始化连接池
connection_pool = init_global_connection_pool(
max_connections_per_host=MAX_CONNECTIONS_PER_HOST,
max_connections_total=MAX_CONNECTIONS_TOTAL,
keepalive_timeout=KEEPALIVE_TIMEOUT,
connect_timeout=CONNECT_TIMEOUT,
total_timeout=TOTAL_TIMEOUT
)
# 初始化文件缓存
file_cache = init_global_file_cache(
cache_size=FILE_CACHE_SIZE,
ttl=FILE_CACHE_TTL
)
logger.info("系统优化初始化完成") logger.info("系统优化初始化完成")
logger.info(f"- 分片Agent管理器: {shard_count} 个分片,最多缓存 {max_cached_agents} 个agent")
logger.info(f"- 连接池: 每主机100连接总计500连接")
logger.info(f"- 文件缓存: 1000个文件TTL 300秒")
@router.get("/api/health") @router.get("/api/health")
@ -90,9 +50,6 @@ async def health_check():
async def get_performance_stats(): async def get_performance_stats():
"""获取系统性能统计信息""" """获取系统性能统计信息"""
try: try:
# 获取agent管理器统计
agent_stats = agent_manager.get_cache_stats()
# 获取连接池统计(简化版) # 获取连接池统计(简化版)
pool_stats = { pool_stats = {
"connection_pool": "active", "connection_pool": "active",
@ -128,7 +85,6 @@ async def get_performance_stats():
"success": True, "success": True,
"timestamp": int(time.time()), "timestamp": int(time.time()),
"performance": { "performance": {
"agent_manager": agent_stats,
"connection_pool": pool_stats, "connection_pool": pool_stats,
"file_cache": file_cache_stats, "file_cache": file_cache_stats,
"system": system_stats "system": system_stats
@ -140,87 +96,6 @@ async def get_performance_stats():
raise HTTPException(status_code=500, detail=f"获取性能统计失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取性能统计失败: {str(e)}")
@router.post("/api/v1/system/optimize")
async def optimize_system(profile: str = "balanced"):
"""应用系统优化配置"""
try:
# 应用优化配置
config = apply_optimization_profile(profile)
return {
"success": True,
"message": f"已应用 {profile} 优化配置",
"config": config
}
except Exception as e:
logger.error(f"Error applying optimization profile: {str(e)}")
raise HTTPException(status_code=500, detail=f"应用优化配置失败: {str(e)}")
@router.post("/api/v1/system/clear-cache")
async def clear_system_cache(cache_type: Optional[str] = None):
"""清理系统缓存"""
try:
cleared_counts = {}
if cache_type is None or cache_type == "agent":
# 清理agent缓存
agent_count = agent_manager.clear_cache()
cleared_counts["agent_cache"] = agent_count
if cache_type is None or cache_type == "file":
# 清理文件缓存
if hasattr(file_cache, '_cache'):
file_count = len(file_cache._cache)
file_cache._cache.clear()
cleared_counts["file_cache"] = file_count
return {
"success": True,
"message": f"已清理指定类型的缓存",
"cleared_counts": cleared_counts
}
except Exception as e:
logger.error(f"Error clearing cache: {str(e)}")
raise HTTPException(status_code=500, detail=f"清理缓存失败: {str(e)}")
@router.get("/api/v1/system/config")
async def get_system_config():
"""获取当前系统配置"""
try:
return {
"success": True,
"config": {
"max_cached_agents": max_cached_agents,
"shard_count": shard_count,
"tokenizer_parallelism": TOKENIZERS_PARALLELISM,
"max_connections_per_host": str(MAX_CONNECTIONS_PER_HOST),
"max_connections_total": str(MAX_CONNECTIONS_TOTAL),
"file_cache_size": str(FILE_CACHE_SIZE),
"file_cache_ttl": str(FILE_CACHE_TTL)
}
}
except Exception as e:
logger.error(f"Error getting system config: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取系统配置失败: {str(e)}")
@router.post("/system/remove-project-cache")
async def remove_project_cache(dataset_id: str):
"""移除特定项目的缓存"""
try:
removed_count = agent_manager.remove_cache_by_unique_id(dataset_id)
if removed_count > 0:
return {"message": f"项目缓存移除成功: {dataset_id}", "removed_count": removed_count}
else:
return {"message": f"未找到项目缓存: {dataset_id}", "removed_count": 0}
except Exception as e:
raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}")
@router.post("/api/v1/embedding/encode", response_model=EncodeResponse) @router.post("/api/v1/embedding/encode", response_model=EncodeResponse)
async def encode_texts(request: EncodeRequest): async def encode_texts(request: EncodeRequest):

View File

@ -31,47 +31,10 @@ from .project_manager import (
) )
from .connection_pool import (
HTTPConnectionPool,
get_global_connection_pool,
init_global_connection_pool,
OAIWithConnectionPool
)
from .async_file_ops import (
AsyncFileCache,
get_global_file_cache,
init_global_file_cache,
async_read_file,
async_read_json,
async_write_file,
async_write_json,
async_file_exists,
async_get_file_mtime,
ParallelFileReader,
get_global_parallel_reader
)
from .system_optimizer import ( from .system_optimizer import (
SystemOptimizer, setup_system_optimizations
AsyncioOptimizer,
setup_system_optimizations,
create_performance_monitor,
get_optimized_worker_config,
OPTIMIZATION_CONFIGS,
apply_optimization_profile,
get_global_system_optimizer
) )
# Import config cache module
# Note: This has been moved to agent package
# from .config_cache import (
# config_cache,
# ConfigFileCache
# )
from .agent_pool import ( from .agent_pool import (
AgentPool, AgentPool,
get_agent_pool, get_agent_pool,

View File

@ -1,212 +0,0 @@
#!/usr/bin/env python3
"""
HTTP连接池管理器 - 提供高效的HTTP连接复用
"""
import aiohttp
import asyncio
from typing import Dict, Optional, Any
import threading
import time
import weakref
class HTTPConnectionPool:
    """Reusable HTTP connection pool built on aiohttp.

    Maintains one ClientSession per running event loop (sessions must not be
    shared across loops) and configures the underlying TCPConnector for
    keep-alive reuse, DNS caching and bounded connection counts.
    """

    def __init__(self,
                 max_connections_per_host: int = 100,
                 max_connections_total: int = 500,
                 keepalive_timeout: int = 30,
                 connect_timeout: int = 10,
                 total_timeout: int = 60):
        """Record the pool limits and timeouts (seconds) applied to every session.

        Args:
            max_connections_per_host: cap on concurrent connections per host
            max_connections_total: cap on concurrent connections overall
            keepalive_timeout: how long idle keep-alive connections are held
            connect_timeout: TCP connect timeout
            total_timeout: whole-request timeout
        """
        self.max_connections_per_host = max_connections_per_host
        self.max_connections_total = max_connections_total
        self.keepalive_timeout = keepalive_timeout
        self.connect_timeout = connect_timeout
        self.total_timeout = total_timeout
        # Keyword arguments forwarded verbatim to aiohttp.TCPConnector.
        self.connector_config = dict(
            limit=max_connections_total,
            limit_per_host=max_connections_per_host,
            keepalive_timeout=keepalive_timeout,
            enable_cleanup_closed=True,   # reap half-closed connections automatically
            force_close=False,            # keep connections alive for reuse
            use_dns_cache=True,
            ttl_dns_cache=300,            # DNS cache TTL in seconds
        )
        # Maps event loop -> session; weak keys let dead loops drop their entry.
        self._sessions = weakref.WeakKeyDictionary()
        self._lock = threading.RLock()

    def _create_session(self) -> aiohttp.ClientSession:
        """Build a fresh session wired to this pool's connector, timeouts and headers."""
        return aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(**self.connector_config),
            timeout=aiohttp.ClientTimeout(
                total=self.total_timeout,
                connect=self.connect_timeout,
                sock_connect=self.connect_timeout,
                sock_read=self.total_timeout,
            ),
            headers={
                'User-Agent': 'QwenAgent/1.0',
                'Connection': 'keep-alive',
                'Accept-Encoding': 'gzip, deflate, br',
            },
        )

    def get_session(self) -> aiohttp.ClientSession:
        """Return the session bound to the running loop, creating it on first use."""
        loop = asyncio.get_running_loop()
        with self._lock:
            session = self._sessions.get(loop)
            if session is None:
                session = self._create_session()
                self._sessions[loop] = session
            return session

    async def request(self, method: str, url: str, **kwargs) -> aiohttp.ClientResponse:
        """Issue an HTTP request through the loop-local pooled session."""
        return await self.get_session().request(method, url, **kwargs)

    async def get(self, url: str, **kwargs) -> aiohttp.ClientResponse:
        """Shorthand for request('GET', ...)."""
        return await self.request('GET', url, **kwargs)

    async def post(self, url: str, **kwargs) -> aiohttp.ClientResponse:
        """Shorthand for request('POST', ...)."""
        return await self.request('POST', url, **kwargs)

    async def close(self):
        """Close every session this pool has handed out and forget them."""
        with self._lock:
            for _loop, session in list(self._sessions.items()):
                if not session.closed:
                    await session.close()
            self._sessions.clear()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
# 全局连接池实例
_global_connection_pool: Optional[HTTPConnectionPool] = None
_pool_lock = threading.Lock()
def get_global_connection_pool() -> HTTPConnectionPool:
    """Return the process-wide pool, building it lazily on first call.

    Uses double-checked locking: the unsynchronized fast path avoids taking
    the lock once the singleton exists.
    """
    global _global_connection_pool
    if _global_connection_pool is not None:
        return _global_connection_pool
    with _pool_lock:
        if _global_connection_pool is None:
            _global_connection_pool = HTTPConnectionPool()
    return _global_connection_pool
def init_global_connection_pool(
    max_connections_per_host: int = 100,
    max_connections_total: int = 500,
    keepalive_timeout: int = 30,
    connect_timeout: int = 10,
    total_timeout: int = 60
) -> HTTPConnectionPool:
    """Replace the global pool with one built from explicit settings.

    Unconditionally overwrites any previously created singleton and
    returns the new instance.
    """
    global _global_connection_pool
    settings = {
        'max_connections_per_host': max_connections_per_host,
        'max_connections_total': max_connections_total,
        'keepalive_timeout': keepalive_timeout,
        'connect_timeout': connect_timeout,
        'total_timeout': total_timeout,
    }
    with _pool_lock:
        _global_connection_pool = HTTPConnectionPool(**settings)
    return _global_connection_pool
class OAIWithConnectionPool:
    """OpenAI-compatible chat client that reuses a shared HTTPConnectionPool."""

    def __init__(self,
                 config: Dict[str, Any],
                 connection_pool: Optional[HTTPConnectionPool] = None):
        """
        Initialize the client.

        Args:
            config: OpenAI-style configuration; recognised keys are
                'model_server', 'api_key', 'model' and 'generate_cfg'.
            connection_pool: pool to send requests through; defaults to the
                process-wide global pool.
        """
        self.config = config
        self.pool = connection_pool or get_global_connection_pool()
        # Normalise the base URL; fall back to the public OpenAI endpoint.
        self.base_url = config.get('model_server', '').rstrip('/')
        if not self.base_url:
            self.base_url = "https://api.openai.com/v1"
        self.api_key = config.get('api_key', '')
        self.model = config.get('model', 'gpt-3.5-turbo')
        self.generate_cfg = config.get('generate_cfg', {})

    async def chat_completions(self, messages: list, stream: bool = False, **kwargs):
        """Call the /chat/completions endpoint.

        Args:
            messages: chat messages in OpenAI format.
            stream: when True, return an async generator yielding one decoded
                SSE chunk at a time; otherwise return the decoded JSON body.

        Raises:
            Exception: when the server answers with a non-200 status.
        """
        url = f"{self.base_url}/chat/completions"
        headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json',
        }
        data = {
            'model': self.model,
            'messages': messages,
            'stream': stream,
            **self.generate_cfg,
            **kwargs
        }
        # BUGFIX: pool.post is `async def`, so calling it yields a coroutine;
        # `async with self.pool.post(...)` raised TypeError because a coroutine
        # object is not an async context manager. Await it first — the
        # resulting ClientResponse *is* a context manager.
        response = await self.pool.post(url, json=data, headers=headers)
        if response.status != 200:
            try:
                error_text = await response.text()
            finally:
                # Release the connection back to the pool even if text() fails.
                response.release()
            raise Exception(f"API request failed with status {response.status}: {error_text}")
        if stream:
            # BUGFIX: hand ownership of the response to the generator. Closing
            # it here (e.g. via `async with`) before the caller iterates would
            # tear down the SSE stream prematurely.
            return self._handle_stream_response(response)
        async with response:
            return await response.json()

    async def _handle_stream_response(self, response: aiohttp.ClientResponse):
        """Yield decoded JSON objects from an SSE ('data: ...') stream.

        Releases the response connection when iteration finishes, errors out,
        or is abandoned by the consumer.
        """
        try:
            async for raw in response.content:
                line = raw.decode('utf-8').strip()
                if not line.startswith('data: '):
                    continue
                payload = line[6:]  # strip the 'data: ' prefix
                if payload == '[DONE]':
                    break
                try:
                    # json is imported at module level (hoisted out of the loop).
                    yield json.loads(payload)
                except json.JSONDecodeError:
                    # Skip malformed keep-alive / partial chunks.
                    continue
        finally:
            response.release()

View File

@ -249,9 +249,6 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
parts = re.split(r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"])
current_tag = None
assistant_content = ""
function_calls = []
tool_responses = []
tool_id_counter = 0  # 添加唯一的工具调用计数器
tool_id_list = []
for i in range(0, len(parts)):