212 lines
7.2 KiB
Python
212 lines
7.2 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
HTTP连接池管理器 - 提供高效的HTTP连接复用
|
||
"""
|
||
|
||
import aiohttp
|
||
import asyncio
|
||
from typing import Dict, Optional, Any
|
||
import threading
|
||
import time
|
||
import weakref
|
||
|
||
|
||
class HTTPConnectionPool:
    """HTTP connection pool manager.

    Keeps one ``aiohttp.ClientSession`` per running event loop so that
    Keep-Alive connections are reused across requests, and applies the
    configured connection limits and timeouts to every session it creates.
    """

    def __init__(self,
                 max_connections_per_host: int = 100,
                 max_connections_total: int = 500,
                 keepalive_timeout: int = 30,
                 connect_timeout: int = 10,
                 total_timeout: int = 60):
        """Store the pool configuration; sessions are created lazily.

        Args:
            max_connections_per_host: Max simultaneous connections per host
                (``TCPConnector(limit_per_host=...)``).
            max_connections_total: Max simultaneous connections overall
                (``TCPConnector(limit=...)``).
            keepalive_timeout: Seconds an idle Keep-Alive connection is kept.
            connect_timeout: Connection-establishment timeout in seconds.
            total_timeout: Overall per-request timeout in seconds; also used
                as the per-read socket timeout.
        """
        self.max_connections_per_host = max_connections_per_host
        self.max_connections_total = max_connections_total
        self.keepalive_timeout = keepalive_timeout
        self.connect_timeout = connect_timeout
        self.total_timeout = total_timeout

        # Keyword arguments for every TCPConnector this pool creates.
        self.connector_config = {
            'limit': max_connections_total,
            'limit_per_host': max_connections_per_host,
            'keepalive_timeout': keepalive_timeout,
            'enable_cleanup_closed': True,  # clean up closed SSL transports
            'force_close': False,           # keep connections open for reuse
            'use_dns_cache': True,          # cache DNS lookups
            'ttl_dns_cache': 300,           # DNS cache TTL in seconds
        }

        # One session per event loop: aiohttp sessions are bound to the loop
        # they were created on and cannot be shared across loops. Weak keys
        # let an entry disappear once its loop is garbage-collected.
        self._sessions = weakref.WeakKeyDictionary()
        # Guards _sessions against loops running in different threads.
        # It is never held across an ``await``.
        self._lock = threading.RLock()

    def _create_session(self) -> aiohttp.ClientSession:
        """Create a new aiohttp session bound to the current event loop."""
        timeout = aiohttp.ClientTimeout(
            total=self.total_timeout,
            connect=self.connect_timeout,
            sock_connect=self.connect_timeout,
            sock_read=self.total_timeout,
        )

        connector = aiohttp.TCPConnector(**self.connector_config)

        return aiohttp.ClientSession(
            connector=connector,
            timeout=timeout,
            headers={
                'User-Agent': 'QwenAgent/1.0',
                'Connection': 'keep-alive',
                # Accept-Encoding is deliberately left to aiohttp's default,
                # which only advertises 'br' when a Brotli decoder is
                # installed. Hard-coding 'br' makes Brotli-encoded responses
                # undecodable on installs without one.
            },
        )

    def get_session(self) -> aiohttp.ClientSession:
        """Return the session for the running event loop, creating it if needed.

        A session that has already been closed (for example by a caller that
        kept a direct reference to it) is replaced with a fresh one.

        Raises:
            RuntimeError: If called when no event loop is running.
        """
        loop = asyncio.get_running_loop()

        with self._lock:
            session = self._sessions.get(loop)
            if session is None or session.closed:
                session = self._create_session()
                self._sessions[loop] = session
            return session

    async def request(self, method: str, url: str, **kwargs) -> aiohttp.ClientResponse:
        """Send an HTTP request using the current loop's pooled session.

        Extra keyword arguments are forwarded to ``ClientSession.request``.
        The caller must release the returned response (``response.release()``)
        so its connection returns to the pool.
        """
        session = self.get_session()
        return await session.request(method, url, **kwargs)

    async def get(self, url: str, **kwargs) -> aiohttp.ClientResponse:
        """Send a GET request (see ``request``)."""
        return await self.request('GET', url, **kwargs)

    async def post(self, url: str, **kwargs) -> aiohttp.ClientResponse:
        """Send a POST request (see ``request``)."""
        return await self.request('POST', url, **kwargs)

    async def close(self):
        """Close every session this pool has created and forget them.

        Sessions are detached under the lock and closed afterwards, so the
        lock is never held across an ``await``. A session owned by the
        current loop is closed directly. A session owned by another loop is
        closed on that loop, because it is bound there, but only if that loop
        is still running. Otherwise it is dropped, since it cannot be closed
        safely from here.
        """
        with self._lock:
            sessions = list(self._sessions.items())
            self._sessions.clear()

        current_loop = asyncio.get_running_loop()
        for owner_loop, session in sessions:
            if session.closed:
                continue
            if owner_loop is current_loop:
                await session.close()
            elif owner_loop.is_running() and not owner_loop.is_closed():
                future = asyncio.run_coroutine_threadsafe(session.close(), owner_loop)
                await asyncio.wrap_future(future)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
|
||
|
||
|
||
# Process-wide shared pool: created lazily by get_global_connection_pool()
# or replaced by init_global_connection_pool(); _pool_lock guards it.
_pool_lock = threading.Lock()
_global_connection_pool: Optional[HTTPConnectionPool] = None
|
||
|
||
|
||
def get_global_connection_pool() -> HTTPConnectionPool:
    """Return the process-wide pool, creating it with defaults on first use.

    Double-checked locking: the unlocked fast path avoids the lock once the
    pool exists, and the re-check under the lock guarantees that only one
    instance is created when several threads race on first use.
    """
    global _global_connection_pool
    if _global_connection_pool is not None:
        return _global_connection_pool
    with _pool_lock:
        if _global_connection_pool is None:
            _global_connection_pool = HTTPConnectionPool()
    return _global_connection_pool
|
||
|
||
|
||
def init_global_connection_pool(
    max_connections_per_host: int = 100,
    max_connections_total: int = 500,
    keepalive_timeout: int = 30,
    connect_timeout: int = 10,
    total_timeout: int = 60
) -> HTTPConnectionPool:
    """Create a new process-wide pool with the given settings and install it.

    Args are forwarded unchanged to ``HTTPConnectionPool``.

    Returns:
        The pool created by this call. It is returned directly rather than
        re-read from the global, so a concurrent call cannot make this one
        return a pool built with different settings.

    Note:
        Any previously installed pool is only replaced, not closed, because
        closing requires ``await``. Callers that need its sessions closed
        should ``await old_pool.close()`` themselves.
    """
    global _global_connection_pool
    pool = HTTPConnectionPool(
        max_connections_per_host=max_connections_per_host,
        max_connections_total=max_connections_total,
        keepalive_timeout=keepalive_timeout,
        connect_timeout=connect_timeout,
        total_timeout=total_timeout,
    )
    with _pool_lock:
        _global_connection_pool = pool
    return pool
|
||
|
||
|
||
class OAIWithConnectionPool:
    """OpenAI-compatible chat-completions client that sends requests through an HTTPConnectionPool."""

    def __init__(self,
                 config: Dict[str, Any],
                 connection_pool: Optional[HTTPConnectionPool] = None):
        """Initialize the client.

        Args:
            config: API configuration. Reads ``model_server`` (base URL,
                default ``https://api.openai.com/v1``), ``api_key``, ``model``
                (default ``gpt-3.5-turbo``) and ``generate_cfg`` (extra
                request-body fields).
            connection_pool: Pool to send requests through; defaults to the
                process-wide pool from ``get_global_connection_pool()``.
        """
        self.config = config
        self.pool = connection_pool or get_global_connection_pool()
        # `or ''` / `or {}` tolerate keys that are present but set to None.
        self.base_url = (config.get('model_server') or '').rstrip('/')
        if not self.base_url:
            self.base_url = "https://api.openai.com/v1"

        self.api_key = config.get('api_key', '')
        self.model = config.get('model', 'gpt-3.5-turbo')
        self.generate_cfg = config.get('generate_cfg') or {}

    async def chat_completions(self, messages: list, stream: bool = False, **kwargs):
        """Send a ``/chat/completions`` request.

        Args:
            messages: Chat messages, sent as the ``messages`` field.
            stream: If True, request a streamed response and return an async
                generator over its chunks.
            **kwargs: Extra request-body fields; they override ``generate_cfg``.

        Returns:
            The decoded JSON body when ``stream`` is False. Otherwise, an async
            generator yielding each decoded ``data:`` payload. The generator
            owns the HTTP response and releases it when it is exhausted or
            closed, so iterate it fully or ``aclose()`` it.

        Raises:
            Exception: If the server answers with a non-200 status.
        """
        url = f"{self.base_url}/chat/completions"

        headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json',
        }

        data = {
            'model': self.model,
            'messages': messages,
            **self.generate_cfg,
            **kwargs,
            # Set last so a 'stream' entry in generate_cfg cannot contradict
            # how the response is parsed below.
            'stream': stream,
        }

        # pool.post() is a coroutine returning a ClientResponse, so it cannot
        # be used directly with `async with`; await it and release manually.
        response = await self.pool.post(url, json=data, headers=headers)
        if stream and response.status == 200:
            # Ownership passes to the generator, which releases the response
            # in its `finally`; releasing it here would cut the stream off.
            return self._handle_stream_response(response)
        try:
            if response.status != 200:
                error_text = await response.text()
                raise Exception(f"API request failed with status {response.status}: {error_text}")
            return await response.json()
        finally:
            response.release()

    async def _handle_stream_response(self, response: aiohttp.ClientResponse):
        """Yield decoded JSON payloads from a Server-Sent Events response.

        Stops at the ``[DONE]`` sentinel. Lines that are not ``data:`` fields,
        and payloads that are not valid JSON, are skipped. The response is
        released when the generator finishes, fails, or is closed early.
        """
        import json

        try:
            async for raw_line in response.content:
                line = raw_line.decode('utf-8').strip()
                if not line.startswith('data:'):
                    continue
                # SSE allows the value to follow the colon with or without a space.
                payload = line[len('data:'):].strip()
                if payload == '[DONE]':
                    break
                try:
                    chunk = json.loads(payload)
                except json.JSONDecodeError:
                    continue
                yield chunk
        finally:
            response.release()