Remove agent manager
This commit is contained in:
parent
23bc62a2b8
commit
b78b178c03
agent/agent_config.py
@@ -72,17 +72,18 @@ class AgentConfig:
         """Create a config from a v1 request."""
         # Lazy import to avoid circular dependencies
         from .logging_handler import LoggingCallbackHandler
-
+        from utils.fastapi_utils import get_preamble_text
         if messages is None:
             messages = []
 
-        return cls(
+        preamble_text, system_prompt = get_preamble_text(request.language, request.system_prompt)
+        config = cls(
             bot_id=request.bot_id,
             api_key=api_key,
             model_name=request.model,
             model_server=request.model_server,
             language=request.language,
-            system_prompt=request.system_prompt,
+            system_prompt=system_prompt,
             mcp_settings=request.mcp_settings,
             robot_type=request.robot_type,
             user_identifier=request.user_identifier,
@@ -93,37 +94,45 @@ class AgentConfig:
             tool_response=request.tool_response,
             generate_cfg=generate_cfg,
             logging_handler=LoggingCallbackHandler(),
-            messages=messages
+            messages=messages,
+            preamble_text=preamble_text,
         )
+        config.safe_print()
+        return config
 
 
     @classmethod
     def from_v2_request(cls, request, bot_config: Dict, project_dir: Optional[str] = None, messages: Optional[List] = None):
         """Create a config from a v2 request."""
         # Lazy import to avoid circular dependencies
         from .logging_handler import LoggingCallbackHandler
-
+        from utils.fastapi_utils import get_preamble_text
         if messages is None:
             messages = []
 
-        return cls(
+        language = request.language or bot_config.get("language", "zh")
+        preamble_text, system_prompt = get_preamble_text(language, bot_config.get("system_prompt"))
+        config = cls(
             bot_id=request.bot_id,
             api_key=bot_config.get("api_key"),
             model_name=bot_config.get("model", "qwen/qwen3-next-80b-a3b-instruct"),
             model_server=bot_config.get("model_server", ""),
-            language=request.language or bot_config.get("language", "zh"),
-            system_prompt=bot_config.get("system_prompt"),
+            language=language,
+            system_prompt=system_prompt,
             mcp_settings=bot_config.get("mcp_settings", []),
             robot_type=bot_config.get("robot_type", "general_agent"),
             user_identifier=request.user_identifier,
             session_id=request.session_id,
             enable_thinking=request.enable_thinking,
             project_dir=project_dir,
             stream=request.stream,
             tool_response=request.tool_response,
             generate_cfg={},  # the v2 endpoint passes no extra generate_cfg
             logging_handler=LoggingCallbackHandler(),
-            messages=messages
+            messages=messages,
+            preamble_text=preamble_text,
         )
+        config.safe_print()
+        return config
 
     def invoke_config(self):
         """Return the config dict that LangChain expects."""
 
126
agent/checkpoint_utils.py
Normal file
@@ -0,0 +1,126 @@
+"""Utility functions for working with LangGraph checkpoints."""
+
+import logging
+from typing import List, Dict, Any, Optional
+from langgraph.checkpoint.memory import MemorySaver
+
+logger = logging.getLogger('app')
+
+
+async def check_checkpoint_history(checkpointer: MemorySaver, thread_id: str) -> bool:
+    """
+    Check whether the given thread_id already has history in the checkpointer.
+
+    Args:
+        checkpointer: MemorySaver instance
+        thread_id: thread ID (usually the session_id)
+
+    Returns:
+        bool: True if history exists, False otherwise
+    """
+    if not checkpointer or not thread_id:
+        logger.debug(f"No checkpointer or thread_id: checkpointer={bool(checkpointer)}, thread_id={thread_id}")
+        return False
+
+    try:
+        # Build the config
+        config = {"configurable": {"thread_id": thread_id}}
+
+        # Debug info: inspect the checkpointer type
+        logger.debug(f"Checkpointer type: {type(checkpointer)}")
+        logger.debug(f"Checkpointer dir: {[attr for attr in dir(checkpointer) if not attr.startswith('_')]}")
+
+        # First try to fetch the latest checkpoint
+        try:
+            latest_checkpoint = await checkpointer.aget_tuple(config)
+            logger.debug(f"aget_tuple result: {latest_checkpoint}")
+
+            if latest_checkpoint is not None:
+                logger.info(f"Found latest checkpoint for thread_id: {thread_id}")
+                # Unpack the checkpoint tuple
+                checkpoint_config, checkpoint, metadata = latest_checkpoint
+                logger.debug(f"Checkpoint metadata: {metadata}")
+                return True
+        except Exception as e:
+            logger.warning(f"aget_tuple failed: {e}")
+
+        # If there is no latest checkpoint, list all of them
+        logger.debug(f"No latest checkpoint for thread_id: {thread_id}, checking all checkpoints...")
+        try:
+            checkpoints = []
+            async for c in checkpointer.alist(config):
+                checkpoints.append(c)
+                logger.debug(f"Found checkpoint: {c}")
+
+            # At least one checkpoint means history exists
+            has_history = len(checkpoints) > 0
+
+            if has_history:
+                logger.info(f"Found {len(checkpoints)} checkpoints in total for thread_id: {thread_id}")
+            else:
+                logger.info(f"No existing history for thread_id: {thread_id}")
+
+            return has_history
+        except Exception as e:
+            logger.warning(f"alist failed: {e}")
+            return False
+
+    except Exception as e:
+        import traceback
+        logger.error(f"Error checking checkpoint history for thread_id {thread_id}: {e}")
+        logger.error(f"Full traceback: {traceback.format_exc()}")
+        # Be conservative on errors and return False
+        return False
+
+
+def prepare_messages_for_agent(
+    messages: List[Dict[str, Any]],
+    has_history: bool
+) -> List[Dict[str, Any]]:
+    """
+    Prepare the messages to send to the agent based on whether history exists.
+
+    Args:
+        messages: the full message list
+        has_history: whether history already exists
+
+    Returns:
+        List[Dict]: the messages to send to the agent
+    """
+    if not messages:
+        return []
+
+    # If history exists, send only the last user message
+    if has_history:
+        # Find the last user message
+        for msg in reversed(messages):
+            if msg.get('role') == 'user':
+                logger.info(f"Has history, sending only last user message: {msg.get('content', '')[:50]}...")
+                return [msg]
+
+        # No user message found (should not happen in practice); fall back to the full list
+        logger.warning("No user message found in messages")
+        return messages
+
+    # Without history, send all messages
+    logger.info(f"No history, sending all {len(messages)} messages")
+    return messages
+
+
+def update_agent_config_for_checkpoint(
+    config_messages: List[Dict[str, Any]],
+    has_history: bool
+) -> List[Dict[str, Any]]:
+    """
+    Update the messages in AgentConfig, deciding which ones to send based on whether history exists.
+
+    This can be called before invoking the agent to avoid reprocessing message history.
+
+    Args:
+        config_messages: the original message list from AgentConfig
+        has_history: whether history already exists
+
+    Returns:
+        List[Dict]: the updated message list
+    """
+    return prepare_messages_for_agent(config_messages, has_history)
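A minimal sketch of how these helpers compose, assuming an agent whose checkpointer was attached as `_checkpointer` (as init_agent does below); illustrative only, not code from this commit:

from agent.checkpoint_utils import check_checkpoint_history, prepare_messages_for_agent

async def invoke_with_history(agent, session_id: str, messages: list):
    # First turn of a session: replay the full history. Later turns: the
    # checkpointer already holds it, so only the last user message is sent.
    checkpointer = getattr(agent, '_checkpointer', None)
    has_history = await check_checkpoint_history(checkpointer, session_id)
    to_send = prepare_messages_for_agent(messages, has_history)
    return await agent.ainvoke(
        {"messages": to_send},
        config={"configurable": {"thread_id": session_id}},
    )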
agent/deep_assistant.py
@@ -15,7 +15,7 @@ from .guideline_middleware import GuidelineMiddleware
 from .tool_output_length_middleware import ToolOutputLengthMiddleware
 from utils.settings import SUMMARIZATION_MAX_TOKENS, TOOL_OUTPUT_MAX_LENGTH
 from agent.agent_config import AgentConfig
-
+from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
 
 # Utility functions
 def read_system_prompt():
@@ -57,7 +57,6 @@ async def get_tools_from_mcp(mcp):
         # Return an empty list on exceptions so upstream callers don't fail
         return []
 
-
 async def init_agent(config: AgentConfig):
     """
     Initialize the agent with support for persistent memory and conversation summarization
@@ -66,9 +65,20 @@ async def init_agent(config: AgentConfig):
         config: AgentConfig object with all initialization parameters
         mcp: MCP config (when None, the mcp_settings in config are used)
     """
+    final_system_prompt = await load_system_prompt_async(
+        config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
+    )
+    final_mcp_settings = await load_mcp_settings_async(
+        config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
+    )
 
     # When no mcp is provided, fall back to the mcp_settings in config
-    mcp_settings = config.mcp_settings if config.mcp_settings else read_mcp_settings()
-    system_prompt = config.system_prompt if config.system_prompt else read_system_prompt()
+    mcp_settings = final_mcp_settings if final_mcp_settings else read_mcp_settings()
+    system_prompt = final_system_prompt if final_system_prompt else read_system_prompt()
+
+    config.system_prompt = system_prompt
+    config.mcp_settings = mcp_settings
 
     mcp_tools = await get_tools_from_mcp(mcp_settings)
 
     # Detect or use the specified provider
@@ -124,4 +134,8 @@ async def init_agent(config: AgentConfig):
         checkpointer=checkpointer  # pass in the checkpointer to enable persistence
     )
 
+    # Attach the checkpointer to the agent as an attribute for convenient access
+    if checkpointer:
+        agent._checkpointer = checkpointer
+
     return agent
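A compact sketch of the persistence pattern init_agent relies on, assuming LangGraph's prebuilt create_react_agent (the builder actually used above is outside this hunk); illustrative only:

from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent

def build_agent(model, tools):
    checkpointer = MemorySaver()  # in-process, per-worker conversation state
    agent = create_react_agent(model, tools, checkpointer=checkpointer)
    agent._checkpointer = checkpointer  # mirror the attribute attached in init_agent
    return agent

# Each session then resumes from its own thread:
#   await agent.ainvoke({"messages": [...]},
#                       config={"configurable": {"thread_id": session_id}})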
agent/sharded_agent_manager.py
@@ -1,323 +0,0 @@
-# Copyright 2023
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Sharded agent manager - a high-concurrency agent cache that reduces lock contention"""
-
-import hashlib
-import time
-import json
-import asyncio
-from typing import Dict, List, Optional
-import threading
-import logging
-
-logger = logging.getLogger('app')
-
-from agent.deep_assistant import init_agent
-from agent.prompt_loader import load_system_prompt_async, load_mcp_settings_async
-from agent.agent_config import AgentConfig
-
-
-class ShardedAgentManager:
-    """Sharded agent manager
-
-    Uses sharding to reduce lock contention and support highly concurrent access
-    """
-
-    def __init__(self, max_cached_agents: int = 20, shard_count: int = 16):
-        self.max_cached_agents = max_cached_agents
-        self.shard_count = shard_count
-
-        # Create the shards
-        self.shards = []
-        for i in range(shard_count):
-            shard = {
-                'agents': {},          # {cache_key: assistant_instance}
-                'unique_ids': {},      # {cache_key: unique_id}
-                'access_times': {},    # LRU access-time bookkeeping
-                'creation_times': {},  # creation timestamps
-                'lock': asyncio.Lock(),  # one independent lock per shard
-                'creation_locks': {},  # per-key locks preventing concurrent creation of the same agent
-            }
-            self.shards.append(shard)
-
-        # Global lock for statistics (kept separate from the shard locks)
-        self._stats_lock = threading.RLock()
-        self._global_stats = {
-            'total_requests': 0,
-            'cache_hits': 0,
-            'cache_misses': 0,
-            'agent_creations': 0
-        }
-
-    def _get_shard_index(self, cache_key: str) -> int:
-        """Map a cache key to a shard index"""
-        hash_value = int(hashlib.md5(cache_key.encode('utf-8')).hexdigest(), 16)
-        return hash_value % self.shard_count
-
-    def _get_cache_key(self, bot_id: str, model_name: str = None, api_key: str = None,
-                       model_server: str = None, generate_cfg: Dict = None,
-                       system_prompt: str = None, mcp_settings: List[Dict] = None,
-                       enable_thinking: bool = False) -> str:
-        """Build a cache key from a hash over all relevant parameters"""
-        cache_data = {
-            'bot_id': bot_id,
-            'model_name': model_name or '',
-            'api_key': api_key or '',
-            'model_server': model_server or '',
-            'generate_cfg': json.dumps(generate_cfg or {}, sort_keys=True),
-            'system_prompt': system_prompt or '',
-            'mcp_settings': json.dumps(mcp_settings or [], sort_keys=True),
-            'enable_thinking': enable_thinking
-        }
-
-        cache_str = json.dumps(cache_data, sort_keys=True)
-        return hashlib.md5(cache_str.encode('utf-8')).hexdigest()[:16]
-
-    def _update_access_time(self, shard: dict, cache_key: str):
-        """Update the access time (LRU bookkeeping)"""
-        shard['access_times'][cache_key] = time.time()
-
-    def _cleanup_old_agents(self, shard: dict):
-        """Evict old agent instances from a shard using an LRU policy"""
-        # Maximum capacity per shard
-        shard_max_capacity = max(1, self.max_cached_agents // self.shard_count)
-
-        if len(shard['agents']) <= shard_max_capacity:
-            return
-
-        # Sort in LRU order and drop the least recently accessed instances
-        sorted_keys = sorted(shard['access_times'].keys(),
-                             key=lambda k: shard['access_times'][k])
-
-        keys_to_remove = sorted_keys[:-shard_max_capacity]
-        removed_count = 0
-
-        for cache_key in keys_to_remove:
-            try:
-                del shard['agents'][cache_key]
-                del shard['unique_ids'][cache_key]
-                del shard['access_times'][cache_key]
-                del shard['creation_times'][cache_key]
-                shard['creation_locks'].pop(cache_key, None)
-                removed_count += 1
-                logger.info(f"Shard evicted stale cached agent instance: {cache_key}")
-            except KeyError:
-                continue
-
-        if removed_count > 0:
-            logger.info(f"Shard evicted {removed_count} stale cached agent instances")
-
-    async def get_or_create_agent(self, config: AgentConfig):
-        """Get or create a file-preloaded agent instance
-
-        Args:
-            config: AgentConfig object with all initialization parameters
-        """
-
-        # Update request statistics
-        with self._stats_lock:
-            self._global_stats['total_requests'] += 1
-
-        # Asynchronously load config files (with caching)
-        final_system_prompt = await load_system_prompt_async(
-            config.project_dir, config.language, config.system_prompt, config.robot_type, config.bot_id, config.user_identifier
-        )
-        final_mcp_settings = await load_mcp_settings_async(
-            config.project_dir, config.mcp_settings, config.bot_id, config.robot_type
-        )
-        config.system_prompt = final_system_prompt
-        config.mcp_settings = final_mcp_settings
-        cache_key = self._get_cache_key(config.bot_id, config.model_name, config.api_key, config.model_server,
-                                        config.generate_cfg, final_system_prompt, final_mcp_settings,
-                                        config.enable_thinking)
-
-        # Pick the shard
-        shard_index = self._get_shard_index(cache_key)
-        shard = self.shards[shard_index]
-
-        # Use the shard-level async lock to prevent concurrent creation of the same agent
-        async with shard['lock']:
-            # Check whether the agent instance already exists
-            if cache_key in shard['agents']:
-                self._update_access_time(shard, cache_key)
-                agent = shard['agents'][cache_key]
-                # Update cache-hit statistics
-                with self._stats_lock:
-                    self._global_stats['cache_hits'] += 1
-
-                logger.info(f"Shard reused cached agent instance: {cache_key} (bot_id: {config.bot_id}, shard: {shard_index})")
-                return agent
-
-            # Update cache-miss statistics
-            with self._stats_lock:
-                self._global_stats['cache_misses'] += 1
-
-            # Use a finer-grained creation lock
-            creation_lock = shard['creation_locks'].setdefault(cache_key, asyncio.Lock())
-
-        # Create the agent outside the shard lock to shorten lock hold time
-        async with creation_lock:
-            # Check again: another request may have created it while we waited for the lock
-            async with shard['lock']:
-                if cache_key in shard['agents']:
-                    self._update_access_time(shard, cache_key)
-                    agent = shard['agents'][cache_key]
-
-                    with self._stats_lock:
-                        self._global_stats['cache_hits'] += 1
-
-                    return agent
-
-            # Evict stale instances
-            async with shard['lock']:
-                self._cleanup_old_agents(shard)
-
-            # Create a new agent instance
-            logger.info(f"Shard creating new cached agent instance: {cache_key}, bot_id: {config.bot_id}, shard: {shard_index}")
-            current_time = time.time()
-
-            agent = await init_agent(config)
-
-            # Cache the instance
-            async with shard['lock']:
-                shard['agents'][cache_key] = agent
-                shard['unique_ids'][cache_key] = config.bot_id
-                shard['access_times'][cache_key] = current_time
-                shard['creation_times'][cache_key] = current_time
-
-                # Drop the creation lock
-                shard['creation_locks'].pop(cache_key, None)
-
-            # Update creation statistics
-            with self._stats_lock:
-                self._global_stats['agent_creations'] += 1
-
-            logger.info(f"Shard finished creating cached agent instance: {cache_key}, shard: {shard_index}")
-            return agent
-
-    def get_cache_stats(self) -> Dict:
-        """Return cache statistics"""
-        current_time = time.time()
-        total_agents = 0
-        agents_info = []
-
-        for i, shard in enumerate(self.shards):
-            for cache_key, agent in shard['agents'].items():
-                total_agents += 1
-                agents_info.append({
-                    "cache_key": cache_key,
-                    "unique_id": shard['unique_ids'].get(cache_key, "unknown"),
-                    "shard": i,
-                    "created_at": shard['creation_times'].get(cache_key, 0),
-                    "last_accessed": shard['access_times'].get(cache_key, 0),
-                    "age_seconds": int(current_time - shard['creation_times'].get(cache_key, current_time)),
-                    "idle_seconds": int(current_time - shard['access_times'].get(cache_key, current_time))
-                })
-
-        stats = {
-            "total_cached_agents": total_agents,
-            "max_cached_agents": self.max_cached_agents,
-            "shard_count": self.shard_count,
-            "agents": agents_info
-        }
-
-        # Add global statistics
-        with self._stats_lock:
-            stats.update({
-                "total_requests": self._global_stats['total_requests'],
-                "cache_hits": self._global_stats['cache_hits'],
-                "cache_misses": self._global_stats['cache_misses'],
-                "cache_hit_rate": (
-                    self._global_stats['cache_hits'] / max(1, self._global_stats['total_requests']) * 100
-                ),
-                "agent_creations": self._global_stats['agent_creations']
-            })
-
-        return stats
-
-    def clear_cache(self) -> int:
-        """Clear all caches"""
-        cache_count = 0
-
-        for shard in self.shards:
-            cache_count += len(shard['agents'])
-            shard['agents'].clear()
-            shard['unique_ids'].clear()
-            shard['access_times'].clear()
-            shard['creation_times'].clear()
-            shard['creation_locks'].clear()
-
-        # Reset statistics
-        with self._stats_lock:
-            self._global_stats = {
-                'total_requests': 0,
-                'cache_hits': 0,
-                'cache_misses': 0,
-                'agent_creations': 0
-            }
-
-        logger.info(f"Sharded manager cleared all cached agent instances, {cache_count} in total")
-        return cache_count
-
-    def remove_cache_by_unique_id(self, unique_id: str) -> int:
-        """Remove all cache entries for the given unique_id"""
-        removed_count = 0
-
-        for i, shard in enumerate(self.shards):
-            keys_to_remove = []
-
-            # Find all cache keys matching the unique_id
-            for cache_key, stored_unique_id in shard['unique_ids'].items():
-                if stored_unique_id == unique_id:
-                    keys_to_remove.append(cache_key)
-
-            # Remove the matches
-            for cache_key in keys_to_remove:
-                try:
-                    del shard['agents'][cache_key]
-                    del shard['unique_ids'][cache_key]
-                    del shard['access_times'][cache_key]
-                    del shard['creation_times'][cache_key]
-                    shard['creation_locks'].pop(cache_key, None)
-                    removed_count += 1
-                    logger.info(f"Shard {i} removed cached agent instance: {cache_key} (unique_id: {unique_id})")
-                except KeyError:
-                    continue
-
-        if removed_count > 0:
-            logger.info(f"Sharded manager removed {removed_count} cached agent instances for unique_id={unique_id}")
-        else:
-            logger.warning(f"Sharded manager found no cached instances for unique_id={unique_id}")
-
-        return removed_count
-
-
-# Global sharded agent manager instance
-_global_sharded_agent_manager: Optional[ShardedAgentManager] = None
-
-
-def get_global_sharded_agent_manager() -> ShardedAgentManager:
-    """Return the global sharded agent manager instance"""
-    global _global_sharded_agent_manager
-    if _global_sharded_agent_manager is None:
-        _global_sharded_agent_manager = ShardedAgentManager()
-    return _global_sharded_agent_manager
-
-
-def init_global_sharded_agent_manager(max_cached_agents: int = 20, shard_count: int = 16):
-    """Initialize the global sharded agent manager"""
-    global _global_sharded_agent_manager
-    _global_sharded_agent_manager = ShardedAgentManager(max_cached_agents, shard_count)
-    return _global_sharded_agent_manager
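Although this commit deletes the sharded manager, its core trick is easy to see in isolation: hash the cache key once and let the hash choose which shard (and therefore which lock) a request contends on. A standalone sketch:

import hashlib

def shard_index(cache_key: str, shard_count: int = 16) -> int:
    # Interpret the MD5 digest as an integer and reduce it modulo the
    # shard count; the same key always routes to the same shard/lock.
    return int(hashlib.md5(cache_key.encode('utf-8')).hexdigest(), 16) % shard_count

assert shard_index("bot-42") == shard_index("bot-42")  # deterministic routing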
@@ -10,13 +10,12 @@ services:
       - "8001:8001"
     environment:
       # Application configuration
       - PYTHONPATH=/app
       - PYTHONUNBUFFERED=1
-      - AGENT_POOL_SIZE=2
       - BACKEND_HOST=http://api-dev.gbase.ai
       - MAX_CONTEXT_TOKENS=262144
       - DEFAULT_THINKING_ENABLE=true
     volumes:
       # Mount the project data directories
       - ./projects:/app/projects
       - ./public:/app/public
     restart: unless-stopped
 
     healthcheck:
@@ -19,9 +19,6 @@ from utils.log_util.logger import init_with_fastapi
 # Import route modules
 from routes import chat, files, projects, system
 
-# Import the system manager from routes.system to access the initialized components
-from routes.system import agent_manager, connection_pool, file_cache
-
 app = FastAPI(title="Database Assistant API", version="1.0.0")
 
 init_with_fastapi(app)
routes/chat.py
@@ -4,32 +4,26 @@ import asyncio
 from typing import Union, Optional
 from fastapi import APIRouter, HTTPException, Header
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 import logging
 
 logger = logging.getLogger('app')
 from utils import (
     Message, ChatRequest, ChatResponse
 )
-from agent.sharded_agent_manager import init_global_sharded_agent_manager
 from utils.api_models import ChatRequestV2
 from utils.fastapi_utils import (
     process_messages,
     create_project_directory, extract_api_key_from_auth, generate_v2_auth_token, fetch_bot_config,
-    call_preamble_llm, get_preamble_text,
+    call_preamble_llm,
     create_stream_chunk
 )
 from langchain_core.messages import AIMessageChunk, HumanMessage, ToolMessage, AIMessage
 from utils.settings import MAX_OUTPUT_TOKENS, MAX_CACHED_AGENTS, SHARD_COUNT
 from agent.agent_config import AgentConfig
+from agent.deep_assistant import init_agent
 
 router = APIRouter()
 
-# Initialize the global agent manager
-agent_manager = init_global_sharded_agent_manager(
-    max_cached_agents=MAX_CACHED_AGENTS,
-    shard_count=SHARD_COUNT
-)
-
 
 def append_user_last_message(messages: list, content: str) -> bool:
@@ -97,13 +91,13 @@ def format_messages_to_chat_history(messages: list) -> str:
 
 
 async def enhanced_generate_stream_response(
-    agent_manager,
+    agent,
     config: AgentConfig
 ):
     """Enhanced progressive streaming response generator (concurrency-optimized)
 
     Args:
-        agent_manager: the agent manager
+        agent: the LangChain agent object
         config: AgentConfig object containing all parameters
     """
     try:
@@ -137,9 +131,6 @@ async def enhanced_generate_stream_response(
         # Agent task (preparation + streaming)
         async def agent_task():
             try:
-                # Prepare the agent
-                agent = await agent_manager.get_or_create_agent(config)
-
                 # Start streaming
                 logger.info(f"Starting agent stream response")
                 chunk_id = 0
@@ -270,25 +261,43 @@ async def create_agent_and_generate_response(
     Args:
         config: AgentConfig object containing all parameters
     """
-    config.safe_print()
-    config.preamble_text, config.system_prompt = get_preamble_text(config.language, config.system_prompt)
+    # Get or create the agent (it must exist before its checkpointer can be accessed)
+    agent = await init_agent(config)
+
+    # If a checkpointer exists, check for prior history
+    if config.session_id:
+        # Look up the checkpointer
+        checkpointer = None
+        if hasattr(agent, '_checkpointer'):
+            checkpointer = agent._checkpointer
+
+        if checkpointer:
+            from agent.checkpoint_utils import check_checkpoint_history, prepare_messages_for_agent
+            has_history = await check_checkpoint_history(checkpointer, config.session_id)
+
+            # Update config.messages
+            config.messages = prepare_messages_for_agent(config.messages, has_history)
+
+            logger.info(f"Session {config.session_id}: has_history={has_history}, sending {len(config.messages)} messages")
+        else:
+            logger.warning(f"No checkpointer found for session {config.session_id}")
+    else:
+        logger.debug(f"No session_id provided, skipping checkpoint check")
 
     # In streaming mode, use the enhanced streaming response generator
     if config.stream:
         return StreamingResponse(
             enhanced_generate_stream_response(
-                agent_manager=agent_manager,
+                agent=agent,
                 config=config
             ),
             media_type="text/event-stream",
             headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
        )
 
-    messages = config.messages
-    # Delegate all the logic to the shared helper
-    agent = await agent_manager.get_or_create_agent(config)
-    agent_responses = await agent.ainvoke({"messages": messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
-    append_messages = agent_responses["messages"][len(messages):]
-
+    # Use the updated messages
+    agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
+    append_messages = agent_responses["messages"][len(config.messages):]
     response_text = ""
     for msg in append_messages:
         if isinstance(msg, AIMessage):
@@ -313,9 +322,9 @@ async def create_agent_and_generate_response(
                 "finish_reason": "stop"
             }],
             usage={
-                "prompt_tokens": sum(len(msg.get("content", "")) for msg in messages),
+                "prompt_tokens": sum(len(msg.get("content", "")) for msg in config.messages),
                 "completion_tokens": len(response_text),
-                "total_tokens": sum(len(msg.get("content", "")) for msg in messages) + len(response_text)
+                "total_tokens": sum(len(msg.get("content", "")) for msg in config.messages) + len(response_text)
             }
         )
     else:
@@ -375,9 +384,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
         # Build the AgentConfig object
         config = AgentConfig.from_v1_request(request, api_key, project_dir, generate_cfg, messages)
         # Delegate to the shared agent-creation and response-generation logic
-        return await create_agent_and_generate_response(
-            config=config
-        )
+        return await create_agent_and_generate_response(config)
 
     except Exception as e:
         import traceback
@@ -451,9 +458,7 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[str] = Header(None)):
         # Build the AgentConfig object
         config = AgentConfig.from_v2_request(request, bot_config, project_dir, messages)
         # Delegate to the shared agent-creation and response-generation logic
-        return await create_agent_and_generate_response(
-            config=config
-        )
+        return await create_agent_and_generate_response(config)
 
     except HTTPException:
         raise
 
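Note that the usage block in this file approximates token counts with character counts; condensed, the computation is (hypothetical helper, for illustration only):

def approx_usage(messages: list, response_text: str) -> dict:
    # len() over content strings counts characters, not model tokens, so
    # these figures are rough; a real tokenizer would be needed for
    # accurate accounting.
    prompt_chars = sum(len(m.get("content", "")) for m in messages)
    return {
        "prompt_tokens": prompt_chars,
        "completion_tokens": len(response_text),
        "total_tokens": prompt_chars + len(response_text),
    }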
125
routes/system.py
@@ -6,25 +6,12 @@ from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel
 
 from utils import (
     get_global_connection_pool, init_global_connection_pool,
     get_global_file_cache, init_global_file_cache,
     setup_system_optimizations
 )
-from agent.sharded_agent_manager import init_global_sharded_agent_manager
-try:
-    from utils.system_optimizer import apply_optimization_profile
-except ImportError:
-    def apply_optimization_profile(profile):
-        return {"profile": profile, "status": "system_optimizer not available"}
-
 from embedding import get_model_manager
-from pydantic import BaseModel
 import logging
-from utils.settings import (
-    MAX_CACHED_AGENTS, SHARD_COUNT, MAX_CONNECTIONS_PER_HOST, MAX_CONNECTIONS_TOTAL,
-    KEEPALIVE_TIMEOUT, CONNECT_TIMEOUT, TOTAL_TIMEOUT, FILE_CACHE_SIZE, FILE_CACHE_TTL,
-    TOKENIZERS_PARALLELISM
-)
 
 logger = logging.getLogger('app')
 
@@ -49,35 +36,8 @@ class EncodeResponse(BaseModel):
 logger.info("Initializing system optimizations...")
 system_optimizer = setup_system_optimizations()
 
-# Global agent manager configuration (using the tuned settings)
-max_cached_agents = MAX_CACHED_AGENTS  # enlarged cache size
-shard_count = SHARD_COUNT  # number of shards
-
-# Initialize the optimized global agent manager
-agent_manager = init_global_sharded_agent_manager(
-    max_cached_agents=max_cached_agents,
-    shard_count=shard_count
-)
-
-# Initialize the connection pool
-connection_pool = init_global_connection_pool(
-    max_connections_per_host=MAX_CONNECTIONS_PER_HOST,
-    max_connections_total=MAX_CONNECTIONS_TOTAL,
-    keepalive_timeout=KEEPALIVE_TIMEOUT,
-    connect_timeout=CONNECT_TIMEOUT,
-    total_timeout=TOTAL_TIMEOUT
-)
-
-# Initialize the file cache
-file_cache = init_global_file_cache(
-    cache_size=FILE_CACHE_SIZE,
-    ttl=FILE_CACHE_TTL
-)
-
 logger.info("System optimization initialization complete")
-logger.info(f"- Sharded agent manager: {shard_count} shards, up to {max_cached_agents} cached agents")
-logger.info(f"- Connection pool: 100 connections per host, 500 total")
-logger.info(f"- File cache: 1000 files, TTL 300 seconds")
 
 
 @router.get("/api/health")
@@ -90,9 +50,6 @@ async def health_check():
 async def get_performance_stats():
     """Fetch system performance statistics"""
     try:
-        # Agent manager statistics
-        agent_stats = agent_manager.get_cache_stats()
-
         # Connection pool statistics (simplified)
         pool_stats = {
             "connection_pool": "active",
@@ -128,7 +85,6 @@ async def get_performance_stats():
             "success": True,
             "timestamp": int(time.time()),
             "performance": {
-                "agent_manager": agent_stats,
                 "connection_pool": pool_stats,
                 "file_cache": file_cache_stats,
                 "system": system_stats
@@ -140,87 +96,6 @@ async def get_performance_stats():
         raise HTTPException(status_code=500, detail=f"Failed to fetch performance statistics: {str(e)}")
 
 
-@router.post("/api/v1/system/optimize")
-async def optimize_system(profile: str = "balanced"):
-    """Apply a system optimization profile"""
-    try:
-        # Apply the optimization profile
-        config = apply_optimization_profile(profile)
-
-        return {
-            "success": True,
-            "message": f"Applied optimization profile {profile}",
-            "config": config
-        }
-
-    except Exception as e:
-        logger.error(f"Error applying optimization profile: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Failed to apply optimization profile: {str(e)}")
-
-
-@router.post("/api/v1/system/clear-cache")
-async def clear_system_cache(cache_type: Optional[str] = None):
-    """Clear system caches"""
-    try:
-        cleared_counts = {}
-
-        if cache_type is None or cache_type == "agent":
-            # Clear the agent cache
-            agent_count = agent_manager.clear_cache()
-            cleared_counts["agent_cache"] = agent_count
-
-        if cache_type is None or cache_type == "file":
-            # Clear the file cache
-            if hasattr(file_cache, '_cache'):
-                file_count = len(file_cache._cache)
-                file_cache._cache.clear()
-                cleared_counts["file_cache"] = file_count
-
-        return {
-            "success": True,
-            "message": "Cleared the requested cache types",
-            "cleared_counts": cleared_counts
-        }
-
-    except Exception as e:
-        logger.error(f"Error clearing cache: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Failed to clear caches: {str(e)}")
-
-
-@router.get("/api/v1/system/config")
-async def get_system_config():
-    """Return the current system configuration"""
-    try:
-        return {
-            "success": True,
-            "config": {
-                "max_cached_agents": max_cached_agents,
-                "shard_count": shard_count,
-                "tokenizer_parallelism": TOKENIZERS_PARALLELISM,
-                "max_connections_per_host": str(MAX_CONNECTIONS_PER_HOST),
-                "max_connections_total": str(MAX_CONNECTIONS_TOTAL),
-                "file_cache_size": str(FILE_CACHE_SIZE),
-                "file_cache_ttl": str(FILE_CACHE_TTL)
-            }
-        }
-
-    except Exception as e:
-        logger.error(f"Error getting system config: {str(e)}")
-        raise HTTPException(status_code=500, detail=f"Failed to fetch system configuration: {str(e)}")
-
-
-@router.post("/system/remove-project-cache")
-async def remove_project_cache(dataset_id: str):
-    """Remove the cached entries for a specific project"""
-    try:
-        removed_count = agent_manager.remove_cache_by_unique_id(dataset_id)
-        if removed_count > 0:
-            return {"message": f"Project cache removed: {dataset_id}", "removed_count": removed_count}
-        else:
-            return {"message": f"Project cache not found: {dataset_id}", "removed_count": 0}
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Failed to remove project cache: {str(e)}")
-
-
 @router.post("/api/v1/embedding/encode", response_model=EncodeResponse)
 async def encode_texts(request: EncodeRequest):
utils/__init__.py
@@ -31,47 +31,10 @@ from .project_manager import (
 )
 
 
-from .connection_pool import (
-    HTTPConnectionPool,
-    get_global_connection_pool,
-    init_global_connection_pool,
-    OAIWithConnectionPool
-)
-
-from .async_file_ops import (
-    AsyncFileCache,
-    get_global_file_cache,
-    init_global_file_cache,
-    async_read_file,
-    async_read_json,
-    async_write_file,
-    async_write_json,
-    async_file_exists,
-    async_get_file_mtime,
-    ParallelFileReader,
-    get_global_parallel_reader
-)
-
-
 from .system_optimizer import (
-    SystemOptimizer,
-    AsyncioOptimizer,
-    setup_system_optimizations,
-    create_performance_monitor,
-    get_optimized_worker_config,
-    OPTIMIZATION_CONFIGS,
-    apply_optimization_profile,
-    get_global_system_optimizer
+    setup_system_optimizations
 )
 
 # Import config cache module
 # Note: This has been moved to agent package
 # from .config_cache import (
 #     config_cache,
 #     ConfigFileCache
 # )
 
 from .agent_pool import (
     AgentPool,
     get_agent_pool,
utils/connection_pool.py
@@ -1,212 +0,0 @@
-#!/usr/bin/env python3
-"""
-HTTP connection pool manager - efficient HTTP connection reuse
-"""
-
-import aiohttp
-import asyncio
-from typing import Dict, Optional, Any
-import threading
-import time
-import weakref
-
-
-class HTTPConnectionPool:
-    """HTTP connection pool manager
-
-    Provides connection reuse, Keep-Alive connections, sensible timeouts, and more
-    """
-
-    def __init__(self,
-                 max_connections_per_host: int = 100,
-                 max_connections_total: int = 500,
-                 keepalive_timeout: int = 30,
-                 connect_timeout: int = 10,
-                 total_timeout: int = 60):
-        """
-        Initialize the connection pool
-
-        Args:
-            max_connections_per_host: maximum connections per host
-            max_connections_total: total connection limit
-            keepalive_timeout: Keep-Alive timeout in seconds
-            connect_timeout: connect timeout in seconds
-            total_timeout: total request timeout in seconds
-        """
-        self.max_connections_per_host = max_connections_per_host
-        self.max_connections_total = max_connections_total
-        self.keepalive_timeout = keepalive_timeout
-        self.connect_timeout = connect_timeout
-        self.total_timeout = total_timeout
-
-        # Connector configuration
-        self.connector_config = {
-            'limit': max_connections_total,
-            'limit_per_host': max_connections_per_host,
-            'keepalive_timeout': keepalive_timeout,
-            'enable_cleanup_closed': True,  # automatically clean up closed connections
-            'force_close': False,  # do not force-close connections
-            'use_dns_cache': True,  # use the DNS cache
-            'ttl_dns_cache': 300,  # DNS cache TTL
-        }
-
-        # Track one session per event loop (weakly keyed by the loop object)
-        self._sessions = weakref.WeakKeyDictionary()
-        self._lock = threading.RLock()
-
-    def _create_session(self) -> aiohttp.ClientSession:
-        """Create a new aiohttp session"""
-        timeout = aiohttp.ClientTimeout(
-            total=self.total_timeout,
-            connect=self.connect_timeout,
-            sock_connect=self.connect_timeout,
-            sock_read=self.total_timeout
-        )
-
-        connector = aiohttp.TCPConnector(**self.connector_config)
-
-        return aiohttp.ClientSession(
-            connector=connector,
-            timeout=timeout,
-            headers={
-                'User-Agent': 'QwenAgent/1.0',
-                'Connection': 'keep-alive',
-                'Accept-Encoding': 'gzip, deflate, br',
-            }
-        )
-
-    def get_session(self) -> aiohttp.ClientSession:
-        """Return the session for the current event loop"""
-        loop = asyncio.get_running_loop()
-
-        with self._lock:
-            if loop not in self._sessions:
-                self._sessions[loop] = self._create_session()
-            return self._sessions[loop]
-
-    async def request(self, method: str, url: str, **kwargs) -> aiohttp.ClientResponse:
-        """Send an HTTP request with automatic connection reuse"""
-        session = self.get_session()
-        return await session.request(method, url, **kwargs)
-
-    async def get(self, url: str, **kwargs) -> aiohttp.ClientResponse:
-        """Send a GET request"""
-        return await self.request('GET', url, **kwargs)
-
-    async def post(self, url: str, **kwargs) -> aiohttp.ClientResponse:
-        """Send a POST request"""
-        return await self.request('POST', url, **kwargs)
-
-    async def close(self):
-        """Close all sessions"""
-        with self._lock:
-            for loop, session in list(self._sessions.items()):
-                if not session.closed:
-                    await session.close()
-            self._sessions.clear()
-
-    async def __aenter__(self):
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        await self.close()
-
-
-# Global connection pool instance
-_global_connection_pool: Optional[HTTPConnectionPool] = None
-_pool_lock = threading.Lock()
-
-
-def get_global_connection_pool() -> HTTPConnectionPool:
-    """Return the global connection pool instance"""
-    global _global_connection_pool
-    if _global_connection_pool is None:
-        with _pool_lock:
-            if _global_connection_pool is None:
-                _global_connection_pool = HTTPConnectionPool()
-    return _global_connection_pool
-
-
-def init_global_connection_pool(
-    max_connections_per_host: int = 100,
-    max_connections_total: int = 500,
-    keepalive_timeout: int = 30,
-    connect_timeout: int = 10,
-    total_timeout: int = 60
-) -> HTTPConnectionPool:
-    """Initialize the global connection pool"""
-    global _global_connection_pool
-    with _pool_lock:
-        _global_connection_pool = HTTPConnectionPool(
-            max_connections_per_host=max_connections_per_host,
-            max_connections_total=max_connections_total,
-            keepalive_timeout=keepalive_timeout,
-            connect_timeout=connect_timeout,
-            total_timeout=total_timeout
-        )
-    return _global_connection_pool
-
-
-class OAIWithConnectionPool:
-    """OpenAI API client backed by the connection pool"""
-
-    def __init__(self,
-                 config: Dict[str, Any],
-                 connection_pool: Optional[HTTPConnectionPool] = None):
-        """
-        Initialize the client
-
-        Args:
-            config: OpenAI API configuration
-            connection_pool: optional connection pool instance
-        """
-        self.config = config
-        self.pool = connection_pool or get_global_connection_pool()
-        self.base_url = config.get('model_server', '').rstrip('/')
-        if not self.base_url:
-            self.base_url = "https://api.openai.com/v1"
-
-        self.api_key = config.get('api_key', '')
-        self.model = config.get('model', 'gpt-3.5-turbo')
-        self.generate_cfg = config.get('generate_cfg', {})
-
-    async def chat_completions(self, messages: list, stream: bool = False, **kwargs):
-        """Send a chat-completions request"""
-        url = f"{self.base_url}/chat/completions"
-
-        headers = {
-            'Authorization': f'Bearer {self.api_key}',
-            'Content-Type': 'application/json',
-        }
-
-        data = {
-            'model': self.model,
-            'messages': messages,
-            'stream': stream,
-            **self.generate_cfg,
-            **kwargs
-        }
-
-        async with self.pool.post(url, json=data, headers=headers) as response:
-            if response.status == 200:
-                if stream:
-                    return self._handle_stream_response(response)
-                else:
-                    return await response.json()
-            else:
-                error_text = await response.text()
-                raise Exception(f"API request failed with status {response.status}: {error_text}")
-
-    async def _handle_stream_response(self, response: aiohttp.ClientResponse):
-        """Handle a streaming response"""
-        async for line in response.content:
-            line = line.decode('utf-8').strip()
-            if line.startswith('data: '):
-                data = line[6:]  # strip the 'data: ' prefix
-                if data == '[DONE]':
-                    break
-                try:
-                    import json
-                    yield json.loads(data)
-                except json.JSONDecodeError:
-                    continue
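The deleted pool's one subtle design point is caching one ClientSession per running event loop: aiohttp sessions are bound to the loop they were created on, and the WeakKeyDictionary lets an entry disappear together with its loop. The pattern in miniature (illustrative only):

import asyncio
import weakref

_sessions = weakref.WeakKeyDictionary()  # event loop -> session

def session_for_current_loop(factory):
    loop = asyncio.get_running_loop()
    if loop not in _sessions:
        _sessions[loop] = factory()  # create on, and only use from, this loop
    return _sessions[loop]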
utils/fastapi_utils.py
@@ -249,9 +249,6 @@ def process_messages(messages: List[Dict], language: Optional[str] = None) -> List[Dict]:
         parts = re.split(r'\[(THINK|PREAMBLE|TOOL_CALL|TOOL_RESPONSE|ANSWER)\]', msg["content"])
-
         current_tag = None
         assistant_content = ""
         function_calls = []
         tool_responses = []
-        tool_id_counter = 0  # add a unique tool-call counter
-        tool_id_list = []
         for i in range(0, len(parts)):