"""
Chat history manager.

Persists the complete, raw chat messages to a dedicated database table, so the
stored history is not affected by checkpoint summaries.
"""
import logging
import uuid
from datetime import datetime
from typing import Optional, List, Dict, Any, Tuple
from dataclasses import dataclass

from psycopg_pool import AsyncConnectionPool

logger = logging.getLogger('app')

# Shared INSERT statement for save_message() and save_messages().
# created_at is set explicitly with clock_timestamp(): the column default NOW()
# returns the *transaction* start time, so every row written inside one
# transaction (as save_messages() does) would otherwise get an identical
# created_at, leaving their relative order undefined.
_INSERT_MESSAGE_SQL = """
    INSERT INTO chat_messages
        (session_id, role, content, bot_id, user_identifier, created_at)
    VALUES (%s, %s, %s, %s, %s, clock_timestamp())
    RETURNING id
"""


@dataclass
class ChatMessage:
    """A single chat message."""
    id: str
    session_id: str
    role: str
    content: str
    created_at: datetime


class ChatHistoryManager:
    """
    Chat history manager.

    Stores the full chat history in a dedicated ``chat_messages`` table,
    using a PostgreSQL connection pool supplied by the caller (intended to be
    the one shared with checkpoint_manager).
    """

    def __init__(self, pool: AsyncConnectionPool):
        """
        Initialize the chat history manager.

        Args:
            pool: PostgreSQL async connection pool.
        """
        self._pool = pool

    async def create_table(self) -> None:
        """Create the ``chat_messages`` table and its index if they do not exist."""
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                # NOTE: gen_random_uuid() is built in from PostgreSQL 13;
                # older servers need the pgcrypto extension.
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS chat_messages (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        session_id VARCHAR(255) NOT NULL,
                        role VARCHAR(50) NOT NULL,
                        content TEXT NOT NULL,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                        bot_id VARCHAR(255),
                        user_identifier VARCHAR(255)
                    )
                """)
                # Index to speed up per-session, newest-first history queries.
                await cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_chat_messages_session_created
                    ON chat_messages (session_id, created_at DESC)
                """)
            await conn.commit()
        logger.info("chat_messages table created successfully")

    async def save_message(
        self,
        session_id: str,
        role: str,
        content: str,
        bot_id: Optional[str] = None,
        user_identifier: Optional[str] = None
    ) -> Optional[str]:
        """
        Save one chat message and commit it.

        Args:
            session_id: Session ID.
            role: Message role (e.g. 'user' or 'assistant').
            content: Message text.
            bot_id: Bot ID.
            user_identifier: User identifier.

        Returns:
            The new message's ID as a string, or None if the INSERT returned
            no row.
        """
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(
                    _INSERT_MESSAGE_SQL,
                    (session_id, role, content, bot_id, user_identifier),
                )
                result = await cursor.fetchone()
            await conn.commit()
        message_id = str(result[0]) if result else None
        logger.debug(
            "Saved message: session_id=%s, role=%s, id=%s",
            session_id, role, message_id,
        )
        return message_id

    async def save_messages(
        self,
        session_id: str,
        messages: List[Dict[str, str]],
        bot_id: Optional[str] = None,
        user_identifier: Optional[str] = None
    ) -> List[str]:
        """
        Save several chat messages in a single transaction.

        Entries with a missing/empty ``role`` or ``content`` are skipped.

        Args:
            session_id: Session ID.
            messages: Messages to save; each must contain ``role`` and
                ``content``.
            bot_id: Bot ID.
            user_identifier: User identifier.

        Returns:
            IDs of the inserted messages, in insertion order.
        """
        message_ids = []
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                for msg in messages:
                    role = msg.get('role')
                    content = msg.get('content', '')
                    if not role or not content:
                        continue
                    await cursor.execute(
                        _INSERT_MESSAGE_SQL,
                        (session_id, role, content, bot_id, user_identifier),
                    )
                    result = await cursor.fetchone()
                    message_ids.append(str(result[0]) if result else None)
            await conn.commit()
        logger.info(
            "Saved %s messages for session_id=%s", len(message_ids), session_id
        )
        return message_ids

    @staticmethod
    async def _find_page_cursor(
        cursor, session_id: str, before_id: str
    ) -> Optional[Tuple[datetime, uuid.UUID]]:
        """Return ``(created_at, id)`` of ``before_id`` within ``session_id``, or None."""
        try:
            before_uuid = uuid.UUID(str(before_id))
        except ValueError:
            # Not a UUID: it cannot match any row, so treat it like "not found"
            # instead of letting PostgreSQL raise and abort the transaction.
            return None
        await cursor.execute("""
            SELECT created_at, id FROM chat_messages
            WHERE id = %s AND session_id = %s
        """, (before_uuid, session_id))
        row = await cursor.fetchone()
        return (row[0], row[1]) if row else None

    @staticmethod
    def _row_to_message(row) -> Dict[str, Any]:
        """Convert an (id, session_id, role, content, created_at) row into the API dict."""
        return {
            "id": str(row[0]),
            "role": row[2],
            "content": row[3],
            "timestamp": row[4].isoformat() if row[4] else None
        }

    async def get_history(
        self,
        session_id: str,
        limit: int = 20,
        before_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Get chat history, newest first.

        Args:
            session_id: Session ID.
            limit: Maximum number of messages to return.
            before_id: Return only messages older than this message ID (for
                pagination). If it is not a message of this session, paging
                starts from the newest message.

        Returns:
            dict: {
                "messages": [...],
                "has_more": bool
            }
        """
        # Fetch one extra row to find out whether an older page exists.
        fetch_size = limit + 1
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                page_cursor = None
                if before_id:
                    page_cursor = await self._find_page_cursor(
                        cursor, session_id, before_id
                    )

                if page_cursor is None:
                    query = """
                        SELECT id, session_id, role, content, created_at
                        FROM chat_messages
                        WHERE session_id = %s
                        ORDER BY created_at DESC, id DESC
                        LIMIT %s
                    """
                    params = (session_id, fetch_size)
                else:
                    # Keyset pagination on (created_at, id): id breaks ties
                    # between rows sharing a timestamp, so no row is skipped
                    # or repeated across pages.
                    before_time, before_uuid = page_cursor
                    query = """
                        SELECT id, session_id, role, content, created_at
                        FROM chat_messages
                        WHERE session_id = %s
                          AND (created_at, id) < (%s, %s)
                        ORDER BY created_at DESC, id DESC
                        LIMIT %s
                    """
                    params = (session_id, before_time, before_uuid, fetch_size)

                await cursor.execute(query, params)
                rows = await cursor.fetchall()

        has_more = len(rows) > limit
        if has_more:
            rows = rows[:limit]

        return {
            "messages": [self._row_to_message(row) for row in rows],
            "has_more": has_more
        }

    async def get_history_by_message_id(
        self,
        session_id: str,
        last_message_id: Optional[str] = None,
        limit: int = 20
    ) -> Dict[str, Any]:
        """
        Get the history older than ``last_message_id``.

        Args:
            session_id: Session ID.
            last_message_id: ID of the last (oldest) message of the previous page.
            limit: Maximum number of messages to return.

        Returns:
            dict: {
                "messages": [...],
                "has_more": bool
            }
        """
        return await self.get_history(
            session_id=session_id,
            limit=limit,
            before_id=last_message_id
        )


# Process-wide singleton.
_global_manager: Optional['ChatHistoryManagerWithPool'] = None


class ChatHistoryManagerWithPool:
    """
    Singleton holder for a ChatHistoryManager bound to a connection pool.

    The pool is passed in by the caller (intended to be checkpoint_manager's).
    """

    def __init__(self):
        self._pool: Optional[AsyncConnectionPool] = None
        self._manager: Optional[ChatHistoryManager] = None
        self._initialized = False

    async def initialize(self, pool: AsyncConnectionPool) -> None:
        """Bind to ``pool`` and create the table; later calls are no-ops."""
        if self._initialized:
            return
        self._pool = pool
        self._manager = ChatHistoryManager(pool)
        await self._manager.create_table()
        self._initialized = True
        logger.info("ChatHistoryManager initialized successfully")

    @property
    def manager(self) -> ChatHistoryManager:
        """
        Return the ChatHistoryManager instance.

        Raises:
            RuntimeError: If initialize() has not completed.
        """
        if not self._initialized or not self._manager:
            raise RuntimeError("ChatHistoryManager not initialized")
        return self._manager


def get_chat_history_manager() -> ChatHistoryManagerWithPool:
    """Return the global ChatHistoryManagerWithPool singleton, creating it on first use."""
    global _global_manager
    if _global_manager is None:
        _global_manager = ChatHistoryManagerWithPool()
    return _global_manager


async def init_chat_history_manager(pool: AsyncConnectionPool) -> None:
    """Initialize the global chat history manager with ``pool``."""
    manager = get_chat_history_manager()
    await manager.initialize(pool)