qwen_agent/utils/connection_pool.py
2025-11-16 12:25:45 +08:00

212 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
HTTP连接池管理器 - 提供高效的HTTP连接复用
"""
import asyncio
import json
import threading
import time
import weakref
from typing import Any, Dict, Optional

import aiohttp
class HTTPConnectionPool:
    """Shared aiohttp-based HTTP connection pool.

    Keeps one ``aiohttp.ClientSession`` per running event loop so that
    requests issued on the same loop reuse Keep-Alive connections, and applies
    the configured connection limits and timeouts to every session it creates.
    Can be used as ``async with HTTPConnectionPool() as pool: ...``; leaving
    the block closes all sessions.
    """

    def __init__(self,
                 max_connections_per_host: int = 100,
                 max_connections_total: int = 500,
                 keepalive_timeout: int = 30,
                 connect_timeout: int = 10,
                 total_timeout: int = 60):
        """Store the pool settings; sessions are created lazily per event loop.

        Args:
            max_connections_per_host: Connection limit per host
                (passed as ``TCPConnector(limit_per_host=...)``).
            max_connections_total: Overall connection limit
                (passed as ``TCPConnector(limit=...)``).
            keepalive_timeout: Keep-Alive timeout for idle connections, seconds.
            connect_timeout: Timeout for establishing a connection, seconds.
            total_timeout: Timeout for a whole request, seconds; also used as
                the socket read timeout.
        """
        self.max_connections_per_host = max_connections_per_host
        self.max_connections_total = max_connections_total
        self.keepalive_timeout = keepalive_timeout
        self.connect_timeout = connect_timeout
        self.total_timeout = total_timeout
        # Keyword arguments for aiohttp.TCPConnector, reused for every session.
        self.connector_config = {
            'limit': max_connections_total,
            'limit_per_host': max_connections_per_host,
            'keepalive_timeout': keepalive_timeout,
            'enable_cleanup_closed': True,  # automatically clean up closed connections
            'force_close': False,  # keep connections open so they can be reused
            'use_dns_cache': True,  # cache DNS lookups
            'ttl_dns_cache': 300,  # DNS cache TTL in seconds
        }
        # One session per event loop, keyed weakly by the loop object: an
        # aiohttp session is tied to the loop it was created on, and an entry
        # disappears automatically once its loop is garbage-collected.
        self._sessions = weakref.WeakKeyDictionary()
        # Guards _sessions, which may be accessed from several threads/loops.
        self._lock = threading.RLock()

    def _create_session(self) -> "aiohttp.ClientSession":
        """Build a new ClientSession using this pool's connector and timeout settings."""
        timeout = aiohttp.ClientTimeout(
            total=self.total_timeout,
            connect=self.connect_timeout,
            sock_connect=self.connect_timeout,
            sock_read=self.total_timeout,
        )
        connector = aiohttp.TCPConnector(**self.connector_config)
        return aiohttp.ClientSession(
            connector=connector,
            timeout=timeout,
            headers={
                'User-Agent': 'QwenAgent/1.0',
                'Connection': 'keep-alive',
                # Accept-Encoding is intentionally left to aiohttp: its default
                # only advertises encodings it can decode, whereas hard-coding
                # "br" breaks responses when no Brotli decoder is installed.
            },
        )

    def get_session(self) -> "aiohttp.ClientSession":
        """Return the session for the running event loop, creating it if needed.

        A new session is also created when the cached one has been closed.
        Must be called while an event loop is running
        (``asyncio.get_running_loop`` raises ``RuntimeError`` otherwise).
        """
        loop = asyncio.get_running_loop()
        with self._lock:
            session = self._sessions.get(loop)
            if session is None or session.closed:
                session = self._create_session()
                self._sessions[loop] = session
            return session

    async def request(self, method: str, url: str, **kwargs) -> "aiohttp.ClientResponse":
        """Send an HTTP request on the pooled session for the current loop.

        ``kwargs`` are forwarded to ``ClientSession.request``. The response is
        returned unread; release it when done, e.g. ``async with response:``.
        """
        session = self.get_session()
        return await session.request(method, url, **kwargs)

    async def get(self, url: str, **kwargs) -> "aiohttp.ClientResponse":
        """Send a GET request (see ``request``)."""
        return await self.request('GET', url, **kwargs)

    async def post(self, url: str, **kwargs) -> "aiohttp.ClientResponse":
        """Send a POST request (see ``request``)."""
        return await self.request('POST', url, **kwargs)

    async def close(self):
        """Close every session created by this pool and forget them.

        The threading lock is held only while detaching the sessions, never
        across an ``await``, so other threads are not blocked during close.
        NOTE(review): sessions created on a different event loop are closed
        from the current loop; confirm callers only use one loop at a time.
        """
        with self._lock:
            sessions = list(self._sessions.values())
            self._sessions.clear()
        for session in sessions:
            if not session.closed:
                await session.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
# Process-wide default connection pool: created lazily by
# get_global_connection_pool() and replaced by init_global_connection_pool().
_pool_lock = threading.Lock()  # serialises creation/replacement of the pool below
_global_connection_pool: Optional[HTTPConnectionPool] = None
def get_global_connection_pool() -> HTTPConnectionPool:
    """Return the process-wide pool, creating it with default settings on first use.

    The unlocked fast path returns an existing pool immediately; otherwise the
    global is re-checked under the lock so concurrent first callers from
    different threads all end up sharing a single instance.
    """
    global _global_connection_pool
    existing = _global_connection_pool
    if existing is not None:
        return existing
    with _pool_lock:
        if _global_connection_pool is None:
            _global_connection_pool = HTTPConnectionPool()
    return _global_connection_pool
def init_global_connection_pool(
    max_connections_per_host: int = 100,
    max_connections_total: int = 500,
    keepalive_timeout: int = 30,
    connect_timeout: int = 10,
    total_timeout: int = 60,
) -> HTTPConnectionPool:
    """Install (or replace) the process-wide pool with the given settings.

    Any previously installed pool is simply replaced and is not closed here;
    callers that need a clean shutdown should ``await old_pool.close()``.

    Returns:
        The newly created global ``HTTPConnectionPool``.
    """
    global _global_connection_pool
    settings = dict(
        max_connections_per_host=max_connections_per_host,
        max_connections_total=max_connections_total,
        keepalive_timeout=keepalive_timeout,
        connect_timeout=connect_timeout,
        total_timeout=total_timeout,
    )
    with _pool_lock:
        _global_connection_pool = HTTPConnectionPool(**settings)
    return _global_connection_pool
class OAIWithConnectionPool:
    """Minimal OpenAI-compatible chat client that sends requests through a pool."""

    def __init__(self,
                 config: Dict[str, Any],
                 connection_pool: Optional["HTTPConnectionPool"] = None):
        """Initialise the client.

        Args:
            config: API settings. Reads ``model_server`` (base URL, default
                ``https://api.openai.com/v1``), ``api_key``, ``model``
                (default ``gpt-3.5-turbo``) and ``generate_cfg`` (extra body
                fields merged into every request). ``None`` values for
                ``model_server``/``generate_cfg`` are treated as unset.
            connection_pool: Pool to send requests through; defaults to the
                process-wide pool from ``get_global_connection_pool()``.
        """
        self.config = config
        self.pool = connection_pool if connection_pool is not None else get_global_connection_pool()
        self.base_url = (config.get('model_server') or '').rstrip('/')
        if not self.base_url:
            self.base_url = "https://api.openai.com/v1"
        self.api_key = config.get('api_key', '')
        self.model = config.get('model', 'gpt-3.5-turbo')
        self.generate_cfg = config.get('generate_cfg') or {}

    async def chat_completions(self, messages: list, stream: bool = False, **kwargs):
        """POST a chat-completion request.

        Request body precedence (later wins): model/messages/stream, then
        ``generate_cfg``, then ``kwargs``.

        Returns:
            The decoded JSON body when ``stream`` is false; otherwise an async
            generator of decoded stream chunks (see ``_handle_stream_response``).

        Raises:
            Exception: if the server answers with a status other than 200.
        """
        url = f"{self.base_url}/chat/completions"
        headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json',
        }
        data = {
            'model': self.model,
            'messages': messages,
            'stream': stream,
            **self.generate_cfg,
            **kwargs
        }
        # pool.post() is a coroutine: await it to obtain the response first
        # (a coroutine object cannot be used directly with ``async with``).
        response = await self.pool.post(url, json=data, headers=headers)
        if response.status != 200:
            async with response:
                error_text = await response.text()
            raise Exception(f"API request failed with status {response.status}: {error_text}")
        if stream:
            # The generator takes ownership of the response and releases it
            # when it finishes or is closed; releasing it here would cut the
            # stream off before the caller reads it.
            return self._handle_stream_response(response)
        async with response:
            return await response.json()

    async def _handle_stream_response(self, response: "aiohttp.ClientResponse"):
        """Yield decoded JSON chunks from a server-sent-events response.

        Lines that are not ``data:`` fields (blank lines, comments) and
        payloads that are not valid JSON are skipped; ``data: [DONE]`` ends the
        stream. The response is released when the generator ends or is closed.
        """
        async with response:
            async for raw_line in response.content:
                line = raw_line.decode('utf-8').strip()
                if not line.startswith('data:'):
                    continue
                # SSE allows the space after the colon to be omitted.
                payload = line[len('data:'):].lstrip()
                if payload == '[DONE]':
                    break
                try:
                    chunk = json.loads(payload)
                except json.JSONDecodeError:
                    continue
                yield chunk