Remove agent manager

This commit is contained in:
朱潮 2025-12-17 20:27:06 +08:00
parent 23bc62a2b8
commit b78b178c03
11 changed files with 204 additions and 754 deletions

View File

@@ -72,17 +72,18 @@ class AgentConfig:
"""从v1请求创建配置"""
# 延迟导入避免循环依赖
from .logging_handler import LoggingCallbackHandler
from utils.fastapi_utils import get_preamble_text
if messages is None:
messages = []
return cls(
preamble_text, system_prompt = get_preamble_text(request.language, request.system_prompt)
config = cls(
bot_id=request.bot_id,
api_key=api_key,
model_name=request.model,
model_server=request.model_server,
language=request.language,
system_prompt=request.system_prompt,
system_prompt=system_prompt,
mcp_settings=request.mcp_settings,
robot_type=request.robot_type,
user_identifier=request.user_identifier,
@@ -93,37 +94,45 @@
tool_response=request.tool_response,
generate_cfg=generate_cfg,
logging_handler=LoggingCallbackHandler(),
messages=messages
messages=messages,
preamble_text=preamble_text,
)
config.safe_print()
return config
@classmethod
def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None):
"""从v2请求创建配置"""
# 延迟导入避免循环依赖
from .logging_handler import LoggingCallbackHandler
from utils.fastapi_utils import get_preamble_text
if messages is None:
messages = []
return cls(
language = request.language or bot_config.get("language", "zh")
preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt"))
config = cls(
bot_id=request.bot_id,
api_key=bot_config.get("api_key"),
model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
model_server=bot_config.get("model_server", ""),
language=request.language or bot_config.get("language", "zh"),
system_prompt=bot_config.get("system_prompt"),
language=language,
system_prompt=system_prompt,
mcp_settings=bot_config.get("mcp_settings", []),
robot_type=bot_config.get("robot_type", "general_agent"),
user_identifier=request.user_identifier,
session_id=request.session_id,
enable_thinking=request.enable_thinking,
project_dir=project_dir,
stream=request.stream,
tool_response=request.tool_response,
generate_cfg={}, # the v2 API does not pass an extra generate_cfg
logging_handler=LoggingCallbackHandler(),
messages=messages
messages=messages,
preamble_text=preamble_text,
)
config.safe_print()
return config
def invoke_config(self):
"""返回Langchain需要的配置字典"""

126 agent/checkpoint_utils.py Normal file
View File

@@ -0,0 +1,126 @@
"""用于处理 LangGraph checkpoint 相关的工具函数"""
import logging
from typing import List, Dict, Any, Optional
from langgraph.checkpoint.memory import MemorySaver
logger = logging.getLogger('app')
async def check_checkpoint_history(checkpointer: MemorySaver, thread_id: str) -> bool:
"""
Check whether the checkpointer already has history for the given thread_id
Args:
checkpointer: MemorySaver instance
thread_id: the thread ID, usually the session_id
Returns:
bool: True if history exists, False otherwise
"""
if not checkpointer or not thread_id:
logger.debug(f"No checkpointer or thread_id: checkpointer={bool(checkpointer)}, thread_id={thread_id}")
return False
try:
# Build the config
config = {"configurable": {"thread_id": thread_id}}
# Debug info: check the checkpointer type
logger.debug(f"Checkpointer type: {type(checkpointer)}")
logger.debug(f"Checkpointer dir: {[attr for attr in dir(checkpointer) if not attr.startswith('_')]}")
# First, try to fetch the latest checkpoint
try:
latest_checkpoint = await checkpointer.aget_tuple(config)
logger.debug(f"aget_tuple result: {latest_checkpoint}")
if latest_checkpoint is not None:
logger.info(f"Found latest checkpoint for thread_id: {thread_id}")
# Read the checkpoint metadata (attribute access, since CheckpointTuple can carry more than three fields)
metadata = latest_checkpoint.metadata
logger.debug(f"Checkpoint metadata: {metadata}")
return True
except Exception as e:
logger.warning(f"aget_tuple failed: {e}")
# If there is no latest checkpoint, list all of them
logger.debug(f"No latest checkpoint for thread_id: {thread_id}, checking all checkpoints...")
try:
checkpoints = []
async for c in checkpointer.alist(config):
checkpoints.append(c)
logger.debug(f"Found checkpoint: {c}")
# At least one checkpoint means there is history
has_history = len(checkpoints) > 0
if has_history:
logger.info(f"Found {len(checkpoints)} checkpoints in total for thread_id: {thread_id}")
else:
logger.info(f"No existing history for thread_id: {thread_id}")
return has_history
except Exception as e:
logger.warning(f"alist failed: {e}")
return False
except Exception as e:
import traceback
logger.error(f"Error checking checkpoint history for thread_id {thread_id}: {e}")
logger.error(f"Full traceback: {traceback.format_exc()}")
# Be conservative on errors and return False
return False
def prepare_messages_for_agent(
messages: List[Dict[str, Any]],
has_history: bool
) -> List[Dict[str, Any]]:
"""
Prepare the messages to send to the agent based on whether history exists
Args:
messages: the full message list
has_history: whether history already exists
Returns:
List[Dict]: the messages to send to the agent
"""
if not messages:
return []
# If there is history, send only the last user message
if has_history:
# Find the last user message
for msg in reversed(messages):
if msg.get('role') == 'user':
logger.info(f"Has history, sending only last user message: {msg.get('content', '')[:50]}...")
return [msg]
# No user message was found (should not happen in practice); fall back to sending everything
logger.warning("No user message found in messages")
return messages
# If there is no history, send all messages
logger.info(f"No history, sending all {len(messages)} messages")
return messages
def update_agent_config_for_checkpoint(
config_messages: List[Dict[str, Any]],
has_history: bool
) -> List[Dict[str, Any]]:
"""
Update the messages from AgentConfig: decide which messages to send based on whether history exists
This helper can be used before invoking the agent to avoid reprocessing the message history
Args:
config_messages: the original message list from AgentConfig
has_history: whether history already exists
Returns:
List[Dict]: the updated message list
"""
return prepare_messages_for_agent(config_messages, has_history)
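A minimal usage sketch for these helpers, assuming an agent whose init_agent attached a MemorySaver as agent._checkpointer (run_turn and its arguments are illustrative names, not part of this codebase):

from agent.checkpoint_utils import check_checkpoint_history, prepare_messages_for_agent

async def run_turn(agent, messages, session_id):
    # Reuse the checkpointer attached by init_agent, if any
    checkpointer = getattr(agent, "_checkpointer", None)
    has_history = False
    if checkpointer and session_id:
        has_history = await check_checkpoint_history(checkpointer, session_id)
    # With history, only the newest user message is sent; the checkpointer
    # replays the rest of the conversation from its saved state
    to_send = prepare_messages_for_agent(messages, has_history)
    return await agent.ainvoke(
        {"messages": to_send},
        config={"configurable": {"thread_id": session_id}},
    )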

View File

@@ -15,7 +15,7 @@ from .guideline_middleware import GuidelineMiddleware
from .tool_output_length_middleware import ToolOutputLengthMiddleware
from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
from agent.agent_config import AgentConfig
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
# Utility functions
def read_system_prompt():
@@ -57,7 +57,6 @@ async def get_tools_from_mcp(mcp):
# Return an empty list on exceptions so upstream callers do not fail
return []
async def init_agent(config: AgentConfig):
"""
Initialize the Agent, with support for persistent memory and conversation summarization
@@ -66,9 +65,20 @@ async def init_agent(config: AgentConfig):
config: AgentConfig object containing all initialization parameters
mcp: MCP configuration; if None, the mcp_settings from the config are used
"""
final_system_prompt = await load_system_prompt_async(
config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
)
final_mcp_settings = await load_mcp_settings_async(
config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
)
# If no mcp is provided, use the mcp_settings from the config
mcp_settings = config.mcp_settings if config.mcp_settings else read_mcp_settings()
system_prompt = config.system_prompt if config.system_prompt else read_system_prompt()
mcp_settings = final_mcp_settings if final_mcp_settings else read_mcp_settings()
system_prompt = final_system_prompt if final_system_prompt else read_system_prompt()
config.system_prompt = system_prompt
config.mcp_settings = mcp_settings
mcp_tools = await get_tools_from_mcp(mcp_settings)
# Detect the provider, or use the one specified
@@ -124,4 +134,8 @@ async def init_agent(config: AgentConfig):
checkpointer=checkpointer # pass the checkpointer to enable persistence
)
# Attach the checkpointer to the agent as an attribute for convenient access
if checkpointer:
agent._checkpointer = checkpointer
return agent
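Why the thread_id matters: once a checkpointer is compiled into a LangGraph graph, invoking with the same thread_id resumes the saved state. A small self-contained demo of the mechanism, independent of this codebase:

from typing import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver

class State(TypedDict):
    count: int

def bump(state: State) -> State:
    return {"count": state["count"] + 1}

builder = StateGraph(State)
builder.add_node("bump", bump)
builder.add_edge(START, "bump")
builder.add_edge("bump", END)
graph = builder.compile(checkpointer=MemorySaver())

cfg = {"configurable": {"thread_id": "session-1"}}
print(graph.invoke({"count": 0}, cfg))  # {'count': 1}
print(graph.get_state(cfg).values)      # the saved state survives the call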

View File

@@ -1,323 +0,0 @@
# Copyright 2023
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""分片助手管理器 - 减少锁竞争的高并发agent缓存系统"""
import hashlib
import time
import json
import asyncio
from typing import Dict, List, Optional
import threading
import logging
logger = logging.getLogger('app')
from agent.deep_assistant import init_agent
from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
from agent.agent_config import AgentConfig
class ShardedAgentManager:
"""分片助手管理器
使用分片技术减少锁竞争支持高并发访问
"""
def __init__(self, max_cached_agents: int = 20, shard_count: int = 16):
self.max_cached_agents = max_cached_agents
self.shard_count = shard_count
# Create the shards
self.shards = []
for i in range(shard_count):
shard = {
'agents': {}, # {cache_key: assistant_instance}
'unique_ids': {}, # {cache_key: unique_id}
'access_times': {}, # LRU access-time bookkeeping
'creation_times': {}, # creation timestamps
'lock': asyncio.Lock(), # an independent lock per shard
'creation_locks': {}, # locks preventing concurrent creation of the same agent
}
self.shards.append(shard)
# Global lock used for statistics (kept separate from the data path)
self._stats_lock = threading.RLock()
self._global_stats = {
'total_requests': 0,
'cache_hits': 0,
'cache_misses': 0,
'agent_creations': 0
}
def _get_shard_index(self, cache_key: str) -> int:
"""根据缓存键获取分片索引"""
hash_value = int(hashlib.md5(cache_key.encode('utf-8')).hexdigest(), 16)
return hash_value % self.shard_count
def _get_cache_key(self, bot_id: str, model_name: str = None, api_key: str = None,
model_server: str = None, generate_cfg: Dict = None,
system_prompt: str = None, mcp_settings: List[Dict] = None,
enable_thinking: bool = False) -> str:
"""获取包含所有相关参数的哈希值作为缓存键"""
cache_data = {
'bot_id': bot_id,
'model_name': model_name or '',
'api_key': api_key or '',
'model_server': model_server or '',
'generate_cfg': json.dumps(generate_cfg or {}, sort_keys=True),
'system_prompt': system_prompt or '',
'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True),
'enable_thinking': enable_thinking
}
cache_str = json.dumps(cache_data, sort_keys=True)
return hashlib.md5(cache_str.encode('utf-8')).hexdigest()[:16]
def _update_access_time(self, shard: dict, cache_key: str):
"""更新访问时间LRU 管理)"""
shard['access_times'][cache_key] = time.time()
def _cleanup_old_agents(self, shard: dict):
"""清理分片中的旧助手实例,基于 LRU 策略"""
# 计算每个分片的最大容量
shard_max_capacity = max(1, self.max_cached_agents // self.shard_count)
if len(shard['agents']) <= shard_max_capacity:
return
# Sort in LRU order and drop the least recently accessed instances
sorted_keys = sorted(shard['access_times'].keys(),
key=lambda k: shard['access_times'][k])
keys_to_remove = sorted_keys[:-shard_max_capacity]
removed_count = 0
for cache_key in keys_to_remove:
try:
del shard['agents'][cache_key]
del shard['unique_ids'][cache_key]
del shard['access_times'][cache_key]
del shard['creation_times'][cache_key]
shard['creation_locks'].pop(cache_key, None)
removed_count += 1
logger.info(f"分片清理过期的助手实例缓存: {cache_key}")
except KeyError:
continue
if removed_count > 0:
logger.info(f"分片已清理 {removed_count} 个过期的助手实例缓存")
async def get_or_create_agent(self, config: AgentConfig):
"""获取或创建文件预加载的助手实例
Args:
config: AgentConfig 对象包含所有初始化参数
"""
# Update request statistics
with self._stats_lock:
self._global_stats['total_requests'] += 1
# Asynchronously load the config files (with caching)
final_system_prompt = await load_system_prompt_async(
config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
)
final_mcp_settings = await load_mcp_settings_async(
config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
)
config.system_prompt = final_system_prompt
config.mcp_settings = final_mcp_settings
cache_key = self._get_cache_key(config.bot_id, config.model_name, config.api_key, config.model_server,
config.generate_cfg, final_system_prompt, final_mcp_settings,
config.enable_thinking)
# Pick the shard
shard_index = self._get_shard_index(cache_key)
shard = self.shards[shard_index]
# Use the shard-level async lock to prevent concurrent creation of the same agent
async with shard['lock']:
# Check whether this assistant instance already exists
if cache_key in shard['agents']:
self._update_access_time(shard, cache_key)
agent = shard['agents'][cache_key]
# Update cache-hit statistics
with self._stats_lock:
self._global_stats['cache_hits'] += 1
logger.info(f"分片复用现有的助手实例缓存: {cache_key} (bot_id: {config.bot_id}, shard: {shard_index})")
return agent
# Update cache-miss statistics
with self._stats_lock:
self._global_stats['cache_misses'] += 1
# Use a finer-grained creation lock
creation_lock = shard['creation_locks'].setdefault(cache_key, asyncio.Lock())
# Create the agent outside the shard lock to shorten the time the lock is held
async with creation_lock:
# Check again: another request may have created it while we waited for the lock
async with shard['lock']:
if cache_key in shard['agents']:
self._update_access_time(shard, cache_key)
agent = shard['agents'][cache_key]
with self._stats_lock:
self._global_stats['cache_hits'] += 1
return agent
# Evict stale instances
async with shard['lock']:
self._cleanup_old_agents(shard)
# Create a new assistant instance
logger.info(f"Shard creating a new cached assistant instance: {cache_key}, bot_id: {config.bot_id}, shard: {shard_index}")
current_time = time.time()
agent = await init_agent(config)
# Cache the instance
async with shard['lock']:
shard['agents'][cache_key] = agent
shard['unique_ids'][cache_key] = config.bot_id
shard['access_times'][cache_key] = current_time
shard['creation_times'][cache_key] = current_time
# Drop the creation lock
shard['creation_locks'].pop(cache_key, None)
# Update creation statistics
with self._stats_lock:
self._global_stats['agent_creations'] += 1
logger.info(f"分片助手实例缓存创建完成: {cache_key}, shard: {shard_index}")
return agent
def get_cache_stats(self) -> Dict:
"""获取缓存统计信息"""
current_time = time.time()
total_agents = 0
agents_info = []
for i, shard in enumerate(self.shards):
for cache_key, agent in shard['agents'].items():
total_agents += 1
agents_info.append({
"cache_key": cache_key,
"unique_id": shard['unique_ids'].get(cache_key, "unknown"),
"shard": i,
"created_at": shard['creation_times'].get(cache_key, 0),
"last_accessed": shard['access_times'].get(cache_key, 0),
"age_seconds": int(current_time - shard['creation_times'].get(cache_key, current_time)),
"idle_seconds": int(current_time - shard['access_times'].get(cache_key, current_time))
})
stats = {
"total_cached_agents": total_agents,
"max_cached_agents": self.max_cached_agents,
"shard_count": self.shard_count,
"agents": agents_info
}
# Merge in the global statistics
with self._stats_lock:
stats.update({
"total_requests": self._global_stats['total_requests'],
"cache_hits": self._global_stats['cache_hits'],
"cache_misses": self._global_stats['cache_misses'],
"cache_hit_rate": (
self._global_stats['cache_hits'] / max(1, self._global_stats['total_requests']) * 100
),
"agent_creations": self._global_stats['agent_creations']
})
return stats
def clear_cache(self) -> int:
"""清空所有缓存"""
cache_count = 0
for shard in self.shards:
cache_count += len(shard['agents'])
shard['agents'].clear()
shard['unique_ids'].clear()
shard['access_times'].clear()
shard['creation_times'].clear()
shard['creation_locks'].clear()
# Reset statistics
with self._stats_lock:
self._global_stats = {
'total_requests': 0,
'cache_hits': 0,
'cache_misses': 0,
'agent_creations': 0
}
logger.info(f"分片管理器已清空所有助手实例缓存,共清理 {cache_count} 个实例")
return cache_count
def remove_cache_by_unique_id(self, unique_id: str) -> int:
"""根据 unique_id 移除所有相关的缓存"""
removed_count = 0
for i, shard in enumerate(self.shards):
keys_to_remove = []
# Collect all cache keys matching this unique_id
for cache_key, stored_unique_id in shard['unique_ids'].items():
if stored_unique_id == unique_id:
keys_to_remove.append(cache_key)
# Remove the matched cache entries
for cache_key in keys_to_remove:
try:
del shard['agents'][cache_key]
del shard['unique_ids'][cache_key]
del shard['access_times'][cache_key]
del shard['creation_times'][cache_key]
shard['creation_locks'].pop(cache_key, None)
removed_count += 1
logger.info(f"分片 {i} 已移除助手实例缓存: {cache_key} (unique_id: {unique_id})")
except KeyError:
continue
if removed_count > 0:
logger.info(f"分片管理器已移除 unique_id={unique_id}{removed_count} 个助手实例缓存")
else:
logger.warning(f"分片管理器未找到 unique_id={unique_id} 的缓存实例")
return removed_count
# Global sharded assistant manager instance
_global_sharded_agent_manager: Optional[ShardedAgentManager] = None
def get_global_sharded_agent_manager() -> ShardedAgentManager:
"""获取全局分片助手管理器实例"""
global _global_sharded_agent_manager
if _global_sharded_agent_manager is None:
_global_sharded_agent_manager = ShardedAgentManager()
return _global_sharded_agent_manager
def init_global_sharded_agent_manager(max_cached_agents: int = 20, shard_count: int = 16):
"""初始化全局分片助手管理器"""
global _global_sharded_agent_manager
_global_sharded_agent_manager = ShardedAgentManager(max_cached_agents, shard_count)
return _global_sharded_agent_manager
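For reference, the core idea of the deleted manager (hash a composite key, map it to a shard, and give each shard its own lock so unrelated requests rarely contend) reduces to a sketch like this, simplified from the code above; unlike the original, it holds the shard lock during creation:

import asyncio
import hashlib
import json

SHARD_COUNT = 16
shards = [{"agents": {}, "lock": asyncio.Lock()} for _ in range(SHARD_COUNT)]

def make_cache_key(**params) -> str:
    # Stable serialization so equal configs always map to the same key
    blob = json.dumps(params, sort_keys=True, default=str)
    return hashlib.md5(blob.encode("utf-8")).hexdigest()[:16]

async def get_or_create(key: str, factory):
    shard = shards[int(hashlib.md5(key.encode("utf-8")).hexdigest(), 16) % SHARD_COUNT]
    async with shard["lock"]:  # contention is confined to one shard
        if key not in shard["agents"]:
            shard["agents"][key] = await factory()
        return shard["agents"][key]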

View File

@@ -10,13 +10,12 @@ services:
- "8001:8001"
environment:
# Application configuration
- PYTHONPATH=/app
- PYTHONUNBUFFERED=1
- AGENT_POOL_SIZE=2
- BACKEND_HOST=http://api-dev.gbase.ai
- MAX_CONTEXT_TOKENS=262144
- DEFAULT_THINKING_ENABLE=true
volumes:
# Mount the project data directory
- ./projects:/app/projects
- ./public:/app/public
restart: unless-stopped
healthcheck:

View File

@@ -19,9 +19,6 @@ from utils.log_util.logger import init_with_fastapi
# Import route modules
from routes import chat, files, projects, system
# Import the system manager from routes.system to access the initialized components
from routes.system import agent_manager, connection_pool, file_cache
app = FastAPI(title="Database Assistant API", version="1.0.0")
init_with_fastapi(app)

View File

@@ -4,32 +4,26 @@ import asyncio
from typing import Union, Optional
from fastapi import APIRouter, HTTPException, Header
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import logging
logger = logging.getLogger('app')
from utils import (
Message, ChatRequest, ChatResponse
)
from agent.sharded_agent_manager import init_global_sharded_agent_manager
from utils.api_models import ChatRequestV2
from utils.fastapi_utils import (
process_messages,
create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
call_preamble_llm, get_preamble_text,
call_preamble_llm,
create_stream_chunk
)
from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT
from agent.agent_config import AgentConfig
from agent.deep_assistant import init_agent
router = APIRouter()
# Initialize the global assistant manager
agent_manager = init_global_sharded_agent_manager(
max_cached_agents=MAX_CACHED_AGENTS,
shard_count=SHARD_COUNT
)
def append_user_last_message(messages: list, content: str) -> bool:
@@ -97,13 +91,13 @@ def format_messages_to_chat_history(messages: list) -> str:
async def enhanced_generate_stream_response(
agent_manager,
agent,
config: AgentConfig
):
"""增强的渐进式流式响应生成器 - 并发优化版本
Args:
agent_manager: agent管理器
agent: LangChain agent 对象
config: AgentConfig 对象包含所有参数
"""
try:
@@ -137,9 +131,6 @@ async def enhanced_generate_stream_response(
# Agent task (preparation + streaming)
async def agent_task():
try:
# Prepare the agent
agent = await agent_manager.get_or_create_agent(config)
# Start streaming
logger.info(f"Starting agent stream response")
chunk_id = 0
@@ -270,25 +261,43 @@ async def create_agent_and_generate_response(
Args:
config: AgentConfig object containing all parameters
"""
config.safe_print()
config.preamble_text, config.system_prompt = get_preamble_text(config.language, config.system_prompt)
# Get or create the agent (it must exist before its checkpointer can be accessed)
agent = await init_agent(config)
# If there is a checkpointer, check whether history already exists
if config.session_id:
# Look up the checkpointer
checkpointer = None
if hasattr(agent, '_checkpointer'):
checkpointer = agent._checkpointer
if checkpointer:
from agent.checkpoint_utils import check_checkpoint_history, prepare_messages_for_agent
has_history = await check_checkpoint_history(checkpointer, config.session_id)
# Update config.messages
config.messages = prepare_messages_for_agent(config.messages, has_history)
logger.info(f"Session {config.session_id}: has_history={has_history}, sending {len(config.messages)} messages")
else:
logger.warning(f"No checkpointer found for session {config.session_id}")
else:
logger.debug(f"No session_id provided, skipping checkpoint check")
# In streaming mode, use the enhanced streaming response generator
if config.stream:
return StreamingResponse(
enhanced_generate_stream_response(
agent_manager=agent_manager,
agent=agent,
config=config
),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
)
messages = config.messages
# Use the shared helper for all of the logic
agent = await agent_manager.get_or_create_agent(config)
agent_responses = await agent.ainvoke({"messages": messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
append_messages = agent_responses["messages"][len(messages):]
# Use the updated messages
agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
append_messages = agent_responses["messages"][len(config.messages):]
response_text = ""
for msg in append_messages:
if isinstance(msg,AIMessage):
@@ -313,9 +322,9 @@ async def create_agent_and_generate_response(
"finish_reason": "stop"
}],
usage={
"prompt_tokens": sum(len(msg.get("content", "")) for msg in messages),
"prompt_tokens": sum(len(msg.get("content", "")) for msg in config.messages),
"completion_tokens": len(response_text),
"total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(response_text)
"total_tokens": sum(len(msg.get("content", "")) for msg in config.messages) + len(response_text)
}
)
else:
@@ -375,9 +384,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
# Build the AgentConfig object
config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg, messages)
# Invoke the shared agent-creation and response-generation logic
return await create_agent_and_generate_response(
config=config
)
return await create_agent_and_generate_response(config)
except Exception as e:
import traceback
@@ -451,9 +458,7 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
# Build the AgentConfig object
config = AgentConfig.from_v2_request(request, bot_config, project_dir, messages)
# Invoke the shared agent-creation and response-generation logic
return await create_agent_and_generate_response(
config=config
)
return await create_agent_and_generate_response(config)
except HTTPException:
raise
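Taken together, the new request path is: build a fresh agent, consult its checkpointer, trim the messages, then stream or invoke. A condensed sketch of that control flow (handle_request is an illustrative name; error handling omitted):

async def handle_request(config: AgentConfig):
    agent = await init_agent(config)  # fresh agent per request, no cache lookup
    checkpointer = getattr(agent, "_checkpointer", None)
    if config.session_id and checkpointer:
        has_history = await check_checkpoint_history(checkpointer, config.session_id)
        config.messages = prepare_messages_for_agent(config.messages, has_history)
    if config.stream:
        return StreamingResponse(
            enhanced_generate_stream_response(agent=agent, config=config),
            media_type="text/event-stream",
        )
    return await agent.ainvoke(
        {"messages": config.messages},
        config=config.invoke_config(),
        max_tokens=MAX_OUTPUT_TOKENS,
    )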

View File

@@ -6,25 +6,12 @@ from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from utils import (
get_global_connection_pool, init_global_connection_pool,
get_global_file_cache, init_global_file_cache,
setup_system_optimizations
)
from agent.sharded_agent_manager import init_global_sharded_agent_manager
try:
from utils.system_optimizer import apply_optimization_profile
except ImportError:
def apply_optimization_profile(profile):
return {"profile": profile, "status": "system_optimizer not available"}
from embedding import get_model_manager
from pydantic import BaseModel
import logging
from utils.settings import (
MAX_CACHED_AGENTS, SHARD_COUNT, MAX_CONNECTIONS_PER_HOST, MAX_CONNECTIONS_TOTAL,
KEEPALIVE_TIMEOUT, CONNECT_TIMEOUT, TOTAL_TIMEOUT, FILE_CACHE_SIZE, FILE_CACHE_TTL,
TOKENIZERS_PARALLELISM
)
logger = logging.getLogger('app')
@@ -49,35 +36,8 @@ class EncodeResponse(BaseModel):
logger.info("正在初始化系统优化...")
system_optimizer = setup_system_optimizations()
# Global assistant manager configuration (using the optimized settings)
max_cached_agents = MAX_CACHED_AGENTS # larger cache size
shard_count = SHARD_COUNT # number of shards
# Initialize the optimized global assistant manager
agent_manager = init_global_sharded_agent_manager(
max_cached_agents=max_cached_agents,
shard_count=shard_count
)
# Initialize the connection pool
connection_pool = init_global_connection_pool(
max_connections_per_host=MAX_CONNECTIONS_PER_HOST,
max_connections_total=MAX_CONNECTIONS_TOTAL,
keepalive_timeout=KEEPALIVE_TIMEOUT,
connect_timeout=CONNECT_TIMEOUT,
total_timeout=TOTAL_TIMEOUT
)
# Initialize the file cache
file_cache = init_global_file_cache(
cache_size=FILE_CACHE_SIZE,
ttl=FILE_CACHE_TTL
)
logger.info("系统优化初始化完成")
logger.info(f"- 分片Agent管理器: {shard_count} 个分片,最多缓存 {max_cached_agents} 个agent")
logger.info(f"- 连接池: 每主机100连接总计500连接")
logger.info(f"- 文件缓存: 1000个文件TTL 300秒")
@router.get("/api/health")
@@ -90,9 +50,6 @@ async def health_check():
async def get_performance_stats():
"""获取系统性能统计信息"""
try:
# Gather agent manager statistics
agent_stats = agent_manager.get_cache_stats()
# Gather connection pool statistics (simplified)
pool_stats = {
"connection_pool": "active",
@@ -128,7 +85,6 @@ async def get_performance_stats():
"success": True,
"timestamp": int(time.time()),
"performance": {
"agent_manager": agent_stats,
"connection_pool": pool_stats,
"file_cache": file_cache_stats,
"system": system_stats
@@ -140,87 +96,6 @@ async def get_performance_stats():
raise HTTPException(status_code=500, detail=f"Failed to get performance statistics: {str(e)}")
@router.post("/api/v1/system/optimize")
async def optimize_system(profile: str = "balanced"):
"""应用系统优化配置"""
try:
# Apply the optimization profile
config = apply_optimization_profile(profile)
return {
"success": True,
"message": f"已应用 {profile} 优化配置",
"config": config
}
except Exception as e:
logger.error(f"Error applying optimization profile: {str(e)}")
raise HTTPException(status_code=500, detail=f"应用优化配置失败: {str(e)}")
@router.post("/api/v1/system/clear-cache")
async def clear_system_cache(cache_type: Optional[str] = None):
"""清理系统缓存"""
try:
cleared_counts = {}
if cache_type is None or cache_type == "agent":
# Clear the agent cache
agent_count = agent_manager.clear_cache()
cleared_counts["agent_cache"] = agent_count
if cache_type is None or cache_type == "file":
# Clear the file cache
if hasattr(file_cache, '_cache'):
file_count = len(file_cache._cache)
file_cache._cache.clear()
cleared_counts["file_cache"] = file_count
return {
"success": True,
"message": f"已清理指定类型的缓存",
"cleared_counts": cleared_counts
}
except Exception as e:
logger.error(f"Error clearing cache: {str(e)}")
raise HTTPException(status_code=500, detail=f"清理缓存失败: {str(e)}")
@router.get("/api/v1/system/config")
async def get_system_config():
"""获取当前系统配置"""
try:
return {
"success": True,
"config": {
"max_cached_agents": max_cached_agents,
"shard_count": shard_count,
"tokenizer_parallelism": TOKENIZERS_PARALLELISM,
"max_connections_per_host": str(MAX_CONNECTIONS_PER_HOST),
"max_connections_total": str(MAX_CONNECTIONS_TOTAL),
"file_cache_size": str(FILE_CACHE_SIZE),
"file_cache_ttl": str(FILE_CACHE_TTL)
}
}
except Exception as e:
logger.error(f"Error getting system config: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取系统配置失败: {str(e)}")
@router.post("/system/remove-project-cache")
async def remove_project_cache(dataset_id: str):
"""移除特定项目的缓存"""
try:
removed_count = agent_manager.remove_cache_by_unique_id(dataset_id)
if removed_count > 0:
return {"message": f"项目缓存移除成功: {dataset_id}", "removed_count": removed_count}
else:
return {"message": f"未找到项目缓存: {dataset_id}", "removed_count": 0}
except Exception as e:
raise HTTPException(status_code=500, detail=f"移除项目缓存失败: {str(e)}")
@router.post("/api/v1/embedding/encode", response_model=EncodeResponse)
async def encode_texts(request: EncodeRequest):

View File

@@ -31,47 +31,10 @@ from .project_manager import (
)
from .connection_pool import (
HTTPConnectionPool,
get_global_connection_pool,
init_global_connection_pool,
OAIWithConnectionPool
)
from .async_file_ops import (
AsyncFileCache,
get_global_file_cache,
init_global_file_cache,
async_read_file,
async_read_json,
async_write_file,
async_write_json,
async_file_exists,
async_get_file_mtime,
ParallelFileReader,
get_global_parallel_reader
)
from .system_optimizer import (
SystemOptimizer,
AsyncioOptimizer,
setup_system_optimizations,
create_performance_monitor,
get_optimized_worker_config,
OPTIMIZATION_CONFIGS,
apply_optimization_profile,
get_global_system_optimizer
setup_system_optimizations
)
# Import config cache module
# Note: This has been moved to agent package
# from .config_cache import (
# config_cache,
# ConfigFileCache
# )
from .agent_pool import (
AgentPool,
get_agent_pool,

View File

@@ -1,212 +0,0 @@
#!/usr/bin/env python3
"""
HTTP connection pool manager - efficient HTTP connection reuse
"""
import aiohttp
import asyncio
from typing import Dict, Optional, Any
import threading
import time
import weakref
class HTTPConnectionPool:
"""HTTP连接池管理器
提供连接复用 Keep-Alive 连接合理超时设置等功能
"""
def __init__(self,
max_connections_per_host: int = 100,
max_connections_total: int = 500,
keepalive_timeout: int = 30,
connect_timeout: int = 10,
total_timeout: int = 60):
"""
Initialize the connection pool
Args:
max_connections_per_host: maximum number of connections per host
max_connections_total: limit on the total number of connections
keepalive_timeout: keep-alive timeout (seconds)
connect_timeout: connection timeout (seconds)
total_timeout: total request timeout (seconds)
"""
self.max_connections_per_host = max_connections_per_host
self.max_connections_total = max_connections_total
self.keepalive_timeout = keepalive_timeout
self.connect_timeout = connect_timeout
self.total_timeout = total_timeout
# Build the connector configuration
self.connector_config = {
'limit': max_connections_total,
'limit_per_host': max_connections_per_host,
'keepalive_timeout': keepalive_timeout,
'enable_cleanup_closed': True, # automatically clean up closed connections
'force_close': False, # do not force-close connections
'use_dns_cache': True, # use the DNS cache
'ttl_dns_cache': 300, # DNS cache TTL
}
# Track one session per event loop via a weakly keyed mapping
self._sessions = weakref.WeakKeyDictionary()
self._lock = threading.RLock()
def _create_session(self) -> aiohttp.ClientSession:
"""创建新的aiohttp会话"""
timeout = aiohttp.ClientTimeout(
total=self.total_timeout,
connect=self.connect_timeout,
sock_connect=self.connect_timeout,
sock_read=self.total_timeout
)
connector = aiohttp.TCPConnector(**self.connector_config)
return aiohttp.ClientSession(
connector=connector,
timeout=timeout,
headers={
'User-Agent': 'QwenAgent/1.0',
'Connection': 'keep-alive',
'Accept-Encoding': 'gzip, deflate, br',
}
)
def get_session(self) -> aiohttp.ClientSession:
"""获取当前事件循环的session"""
loop = asyncio.get_running_loop()
with self._lock:
if loop not in self._sessions:
self._sessions[loop] = self._create_session()
return self._sessions[loop]
async def request(self, method: str, url: str, **kwargs) -> aiohttp.ClientResponse:
"""发送HTTP请求自动处理连接复用"""
session = self.get_session()
return await session.request(method, url, **kwargs)
async def get(self, url: str, **kwargs) -> aiohttp.ClientResponse:
"""发送GET请求"""
return await self.request('GET', url, **kwargs)
async def post(self, url: str, **kwargs) -> aiohttp.ClientResponse:
"""发送POST请求"""
return await self.request('POST', url, **kwargs)
async def close(self):
"""关闭所有session"""
with self._lock:
for loop, session in list(self._sessions.items()):
if not session.closed:
await session.close()
self._sessions.clear()
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
# Global connection pool instance
_global_connection_pool: Optional[HTTPConnectionPool] = None
_pool_lock = threading.Lock()
def get_global_connection_pool() -> HTTPConnectionPool:
"""获取全局连接池实例"""
global _global_connection_pool
if _global_connection_pool is None:
with _pool_lock:
if _global_connection_pool is None:
_global_connection_pool = HTTPConnectionPool()
return _global_connection_pool
def init_global_connection_pool(
max_connections_per_host: int = 100,
max_connections_total: int = 500,
keepalive_timeout: int = 30,
connect_timeout: int = 10,
total_timeout: int = 60
) -> HTTPConnectionPool:
"""初始化全局连接池"""
global _global_connection_pool
with _pool_lock:
_global_connection_pool = HTTPConnectionPool(
max_connections_per_host=max_connections_per_host,
max_connections_total=max_connections_total,
keepalive_timeout=keepalive_timeout,
connect_timeout=connect_timeout,
total_timeout=total_timeout
)
return _global_connection_pool
class OAIWithConnectionPool:
"""带有连接池的OpenAI API客户端"""
def __init__(self,
config: Dict[str, Any],
connection_pool: Optional[HTTPConnectionPool] = None):
"""
Initialize the client
Args:
config: OpenAI API configuration
connection_pool: optional connection pool instance
"""
self.config = config
self.pool = connection_pool or get_global_connection_pool()
self.base_url = config.get('model_server', '').rstrip('/')
if not self.base_url:
self.base_url = "https://api.openai.com/v1"
self.api_key = config.get('api_key', '')
self.model = config.get('model', 'gpt-3.5-turbo')
self.generate_cfg = config.get('generate_cfg', {})
async def chat_completions(self, messages: list, stream: bool = False, **kwargs):
"""发送聊天完成请求"""
url = f"{self.base_url}/chat/completions"
headers = {
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
}
data = {
'model': self.model,
'messages': messages,
'stream': stream,
**self.generate_cfg,
**kwargs
}
async with self.pool.post(url, json=data, headers=headers) as response:
if response.status == 200:
if stream:
return self._handle_stream_response(response)
else:
return await response.json()
else:
error_text = await response.text()
raise Exception(f"API request failed with status {response.status}: {error_text}")
async def _handle_stream_response(self, response: aiohttp.ClientResponse):
"""处理流式响应"""
async for line in response.content:
line = line.decode('utf-8').strip()
if line.startswith('data: '):
data = line[6:] # strip the 'data: ' prefix
if data == '[DONE]':
break
try:
import json
yield json.loads(data)
except json.JSONDecodeError:
continue
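Most of what this deleted pool provided is already built into aiohttp's ClientSession. A minimal sketch of an equivalent shared session, assuming single-event-loop usage (the connector limits mirror the old defaults):

import aiohttp

_session = None  # one long-lived session per event loop captures most of the benefit

async def get_session() -> aiohttp.ClientSession:
    global _session
    if _session is None or _session.closed:
        _session = aiohttp.ClientSession(
            connector=aiohttp.TCPConnector(limit=500, limit_per_host=100, ttl_dns_cache=300),
            timeout=aiohttp.ClientTimeout(total=60, connect=10),
        )
    return _session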

View File

@@ -249,9 +249,6 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> Li
parts = re.split(r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"])
current_tag = None
assistant_content = ""
function_calls = []
tool_responses = []
tool_id_counter = 0 # a unique counter for tool-call IDs
tool_id_list = []
for i in range(0, len(parts)):
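For context on the loop above: re.split with a capturing group keeps the captured tag names in the result list, alternating with the text between them, which is exactly what the current_tag state machine walks over. A quick illustration:

import re

parts = re.split(r'\[(THINK|ANSWER)\]', "[THINK]plan steps[ANSWER]done")
print(parts)  # ['', 'THINK', 'plan steps', 'ANSWER', 'done']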