#!/usr/bin/env python3
"""
HTTP connection pool manager - provides efficient HTTP connection reuse.
"""

import aiohttp
import asyncio
import json
from typing import Dict, Optional, Any
import threading
import time
import weakref


class HTTPConnectionPool:
    """HTTP connection pool manager.

    Maintains one ``aiohttp.ClientSession`` (and therefore one
    ``TCPConnector``) per running event loop, so requests issued from the
    same loop reuse Keep-Alive connections under configurable connection
    limits and timeouts.
    """

    def __init__(self,
                 max_connections_per_host: int = 100,
                 max_connections_total: int = 500,
                 keepalive_timeout: int = 30,
                 connect_timeout: int = 10,
                 total_timeout: int = 60):
        """Store the pool configuration; sessions are created lazily.

        Args:
            max_connections_per_host: Connector ``limit_per_host``.
            max_connections_total: Connector ``limit`` (all hosts combined).
            keepalive_timeout: Seconds an idle Keep-Alive connection is kept.
            connect_timeout: Timeout in seconds for establishing a connection.
            total_timeout: Overall per-request timeout in seconds; also used
                as the socket read timeout.
        """
        self.max_connections_per_host = max_connections_per_host
        self.max_connections_total = max_connections_total
        self.keepalive_timeout = keepalive_timeout
        self.connect_timeout = connect_timeout
        self.total_timeout = total_timeout

        # Keyword arguments for every aiohttp.TCPConnector this pool creates.
        self.connector_config = {
            'limit': max_connections_total,
            'limit_per_host': max_connections_per_host,
            'keepalive_timeout': keepalive_timeout,
            'enable_cleanup_closed': True,  # let aiohttp clean up closed transports
            'force_close': False,  # keep connections open for reuse
            'use_dns_cache': True,  # cache DNS lookups
            'ttl_dns_cache': 300,  # DNS cache TTL in seconds
        }

        # One session per event loop, weakly keyed by the loop object, since
        # an aiohttp session/connector is bound to the loop it was created
        # on. When a loop is garbage collected its entry simply disappears --
        # the session is NOT closed in that case.
        self._sessions = weakref.WeakKeyDictionary()
        # Guards self._sessions against access from several threads (each
        # running its own loop). It is never held across an ``await``.
        self._lock = threading.RLock()

    def _create_session(self) -> aiohttp.ClientSession:
        """Create a new ``aiohttp.ClientSession`` using this pool's settings.

        Must be called from inside a running event loop, because the new
        connector attaches to the current loop.
        """
        timeout = aiohttp.ClientTimeout(
            total=self.total_timeout,
            connect=self.connect_timeout,
            sock_connect=self.connect_timeout,
            sock_read=self.total_timeout,
        )
        connector = aiohttp.TCPConnector(**self.connector_config)
        return aiohttp.ClientSession(
            connector=connector,
            timeout=timeout,
            # Accept-Encoding is deliberately not forced: aiohttp sets it
            # automatically and advertises "br" only when a Brotli decoder is
            # installed. Hard-coding "br" makes such responses undecodable
            # (ContentEncodingError) on installs without Brotli.
            headers={
                'User-Agent': 'QwenAgent/1.0',
                'Connection': 'keep-alive',
            },
        )

    def get_session(self) -> aiohttp.ClientSession:
        """Return the session bound to the currently running event loop.

        A session is created on first use in a loop, and re-created if the
        cached one has been closed.

        Raises:
            RuntimeError: If no event loop is running in this thread.
        """
        loop = asyncio.get_running_loop()
        with self._lock:
            session = self._sessions.get(loop)
            if session is None or session.closed:
                # A closed session would fail every request; replace it.
                session = self._create_session()
                self._sessions[loop] = session
            return session

    async def request(self, method: str, url: str, **kwargs) -> aiohttp.ClientResponse:
        """Send an HTTP request over the current loop's pooled session.

        The returned response holds a pooled connection until it is
        released; use ``async with response:`` or call ``response.release()``
        when finished with the body.

        Args:
            method: HTTP method, e.g. ``'GET'``.
            url: Target URL.
            **kwargs: Forwarded to ``aiohttp.ClientSession.request``.
        """
        session = self.get_session()
        return await session.request(method, url, **kwargs)

    async def get(self, url: str, **kwargs) -> aiohttp.ClientResponse:
        """Send a GET request (see :meth:`request`)."""
        return await self.request('GET', url, **kwargs)

    async def post(self, url: str, **kwargs) -> aiohttp.ClientResponse:
        """Send a POST request (see :meth:`request`)."""
        return await self.request('POST', url, **kwargs)

    async def close(self):
        """Close every session created by this pool and forget them.

        The session owned by the current event loop is closed and awaited.
        Sessions owned by other loops that are still open are scheduled for
        closing on their own loop without waiting, because aiohttp transports
        must only be touched from the loop that created them. Sessions whose
        loop is already closed cannot be shut down cleanly and are dropped.
        """
        # Snapshot and clear under the thread lock, but never await while
        # holding it: that would block other threads during network shutdown.
        with self._lock:
            sessions = list(self._sessions.items())
            self._sessions.clear()

        current_loop = asyncio.get_running_loop()
        for loop, session in sessions:
            if session.closed:
                continue
            if loop is current_loop:
                await session.close()
            elif not loop.is_closed():
                asyncio.run_coroutine_threadsafe(session.close(), loop)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()


# Global connection pool instance (created lazily).
_global_connection_pool: Optional[HTTPConnectionPool] = None
_pool_lock = threading.Lock()


def get_global_connection_pool() -> HTTPConnectionPool:
    """Return the process-wide pool, creating it with defaults on first use.

    Uses double-checked locking so the common path takes no lock.
    """
    global _global_connection_pool
    if _global_connection_pool is None:
        with _pool_lock:
            if _global_connection_pool is None:
                _global_connection_pool = HTTPConnectionPool()
    return _global_connection_pool


def init_global_connection_pool(
    max_connections_per_host: int = 100,
    max_connections_total: int = 500,
    keepalive_timeout: int = 30,
    connect_timeout: int = 10,
    total_timeout: int = 60
) -> HTTPConnectionPool:
    """(Re)create the process-wide pool with the given settings.

    NOTE: any previously installed global pool is replaced but NOT closed;
    callers should ``await old_pool.close()`` themselves if it was in use.

    Returns:
        The newly installed global pool.
    """
    global _global_connection_pool
    with _pool_lock:
        _global_connection_pool = HTTPConnectionPool(
            max_connections_per_host=max_connections_per_host,
            max_connections_total=max_connections_total,
            keepalive_timeout=keepalive_timeout,
            connect_timeout=connect_timeout,
            total_timeout=total_timeout
        )
        return _global_connection_pool


class OAIWithConnectionPool:
    """OpenAI-compatible chat API client that sends requests through a pool."""

    def __init__(self, config: Dict[str, Any], connection_pool: Optional[HTTPConnectionPool] = None):
        """Initialise the client.

        Args:
            config: API configuration. Keys read: ``model_server`` (base URL;
                defaults to the OpenAI v1 endpoint when empty or missing),
                ``api_key``, ``model`` and ``generate_cfg`` (extra request
                body fields).
            connection_pool: Pool to use; defaults to the global pool.
        """
        self.config = config
        self.pool = connection_pool or get_global_connection_pool()
        # ``or ''`` also tolerates an explicit None in the config.
        self.base_url = (config.get('model_server') or '').rstrip('/')
        if not self.base_url:
            self.base_url = "https://api.openai.com/v1"
        self.api_key = config.get('api_key', '')
        self.model = config.get('model', 'gpt-3.5-turbo')
        # ``or {}`` so an explicit None does not break ``**`` unpacking below.
        self.generate_cfg = config.get('generate_cfg') or {}

    async def chat_completions(self, messages: list, stream: bool = False, **kwargs):
        """POST to ``{base_url}/chat/completions``.

        Args:
            messages: Chat messages placed in the request body.
            stream: When True, return an async generator of parsed SSE
                chunks instead of the decoded JSON body.
            **kwargs: Extra body fields; they override ``generate_cfg``.

        Returns:
            The decoded JSON response, or (when ``stream`` is True) an async
            generator that owns the response and releases it when iteration
            ends or the generator is closed.

        Raises:
            Exception: If the server answers with a status other than 200.
        """
        url = f"{self.base_url}/chat/completions"
        headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json',
        }
        data = {
            'model': self.model,
            'messages': messages,
            **self.generate_cfg,
            **kwargs,
            # Set last: response handling below depends on this flag, so a
            # 'stream' entry in generate_cfg must not contradict it.
            'stream': stream,
        }

        # pool.post() is a coroutine returning a ClientResponse; it is not
        # itself an async context manager, so await it first.
        response = await self.pool.post(url, json=data, headers=headers)

        if stream and response.status == 200:
            # Hand ownership to the generator; releasing here would close the
            # body before the caller reads it.
            return self._handle_stream_response(response)

        async with response:
            if response.status != 200:
                error_text = await response.text()
                raise Exception(f"API request failed with status {response.status}: {error_text}")
            return await response.json()

    async def _handle_stream_response(self, response: aiohttp.ClientResponse):
        """Yield parsed JSON objects from a server-sent-events response.

        Only ``data:`` lines are considered; the stream stops at
        ``data: [DONE]`` and undecodable payloads are skipped. The response
        is released when iteration finishes or the generator is closed.
        """
        async with response:
            async for raw_line in response.content:
                line = raw_line.decode('utf-8').strip()
                if not line.startswith('data:'):
                    continue
                # SSE allows an optional space after the colon.
                payload = line[len('data:'):].strip()
                if payload == '[DONE]':
                    break
                try:
                    chunk = json.loads(payload)
                except json.JSONDecodeError:
                    continue
                yield chunk