"""
Chat history manager.

Uses the shared database connection pool. Stores complete raw chat messages
directly in the database without being affected by checkpoint summaries.
"""

import logging
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional

from psycopg_pool import AsyncConnectionPool

logger = logging.getLogger('app')

# INSERT shared by save_message / save_messages.
# created_at is set explicitly to clock_timestamp() instead of relying on the
# column default NOW(): NOW() is the *transaction* start time, so every row
# written by one save_messages() call would otherwise share the same
# created_at, losing their relative order and breaking keyset pagination.
_INSERT_MESSAGE_SQL = """
    INSERT INTO chat_messages
        (session_id, role, content, bot_id, user_identifier, created_at)
    VALUES (%s, %s, %s, %s, %s, clock_timestamp())
    RETURNING id
"""


def _parse_uuid(value: Optional[str]) -> Optional[uuid.UUID]:
    """Return ``value`` parsed as a UUID, or None if empty or not a valid UUID."""
    if not value:
        return None
    try:
        return uuid.UUID(str(value))
    except ValueError:
        return None


@dataclass
class ChatMessage:
    """Chat message."""
    id: str
    session_id: str
    role: str
    content: str
    created_at: datetime


class ChatHistoryManager:
    """
    Chat history manager.

    Uses the shared PostgreSQL connection pool to store complete chat history
    records in the ``chat_messages`` table.
    """

    def __init__(self, pool: AsyncConnectionPool):
        """
        Initialize the chat history manager.

        Args:
            pool: PostgreSQL connection pool from DBPoolManager
        """
        self._pool = pool

    async def create_table(self) -> None:
        """Create the chat_messages table and its lookup index if missing."""
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS chat_messages (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        session_id VARCHAR(255) NOT NULL,
                        role VARCHAR(50) NOT NULL,
                        content TEXT NOT NULL,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                        bot_id VARCHAR(255),
                        user_identifier VARCHAR(255)
                    )
                """)
                # Create an index to speed up per-session, newest-first queries.
                await cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_chat_messages_session_created
                    ON chat_messages (session_id, created_at DESC)
                """)
                await conn.commit()
        logger.info("chat_messages table created successfully")

    async def save_message(
        self,
        session_id: str,
        role: str,
        content: str,
        bot_id: Optional[str] = None,
        user_identifier: Optional[str] = None
    ) -> Optional[str]:
        """
        Save a single chat message.

        Args:
            session_id: Session ID
            role: Message role ('user' or 'assistant')
            content: Message content
            bot_id: Bot ID
            user_identifier: User identifier

        Returns:
            The new message ID as a string (None only if the INSERT
            unexpectedly returned no row).
        """
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(
                    _INSERT_MESSAGE_SQL,
                    (session_id, role, content, bot_id, user_identifier),
                )
                result = await cursor.fetchone()
                await conn.commit()

        message_id = str(result[0]) if result else None
        logger.debug(
            "Saved message: session_id=%s, role=%s, id=%s",
            session_id, role, message_id,
        )
        return message_id

    async def save_messages(
        self,
        session_id: str,
        messages: List[Dict[str, str]],
        bot_id: Optional[str] = None,
        user_identifier: Optional[str] = None
    ) -> List[str]:
        """
        Save multiple chat messages in one transaction, preserving list order.

        Messages with a missing/empty ``role`` or ``content`` are skipped.

        Args:
            session_id: Session ID
            messages: List of messages, each containing role and content
            bot_id: Bot ID
            user_identifier: User identifier

        Returns:
            List[str]: IDs of the inserted messages, in insertion order
        """
        message_ids = []
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                for msg in messages:
                    role = msg.get('role')
                    content = msg.get('content', '')
                    if not role or not content:
                        continue

                    # clock_timestamp() in _INSERT_MESSAGE_SQL gives each row its
                    # own increasing created_at even inside this one transaction.
                    await cursor.execute(
                        _INSERT_MESSAGE_SQL,
                        (session_id, role, content, bot_id, user_identifier),
                    )
                    result = await cursor.fetchone()
                    message_ids.append(str(result[0]) if result else None)

                await conn.commit()

        logger.info("Saved %s messages for session_id=%s", len(message_ids), session_id)
        return message_ids

    async def get_history(
        self,
        session_id: str,
        limit: int = 20,
        before_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Get chat history in reverse chronological order, newest first.

        Rows are ordered by ``(created_at, id)`` descending so that ordering
        and pagination are deterministic even when timestamps tie.

        Args:
            session_id: Session ID
            limit: Number of messages to return (must be >= 0)
            before_id: Return messages strictly before this message ID, used
                for pagination. If it is not a valid UUID or does not belong
                to ``session_id``, the page starts from the newest message.

        Returns:
            dict: {
                "messages": [...],
                "has_more": bool
            }

        Raises:
            ValueError: if ``limit`` is negative.
        """
        if limit < 0:
            raise ValueError(f"limit must be non-negative, got {limit}")

        # Fetch one extra row to determine whether more results exist.
        fetch_count = limit + 1

        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                anchor = None
                before_uuid = _parse_uuid(before_id)
                if before_uuid is not None:
                    # Scope the anchor lookup to this session so an ID from
                    # another session cannot be used as a pagination cursor.
                    await cursor.execute("""
                        SELECT created_at, id FROM chat_messages
                        WHERE id = %s AND session_id = %s
                    """, (before_uuid, session_id))
                    anchor = await cursor.fetchone()

                if anchor is None:
                    if before_id:
                        logger.debug(
                            "before_id=%s not found for session_id=%s; "
                            "starting from newest",
                            before_id, session_id,
                        )
                    # If the specified ID is not found, start from the beginning.
                    await cursor.execute("""
                        SELECT id, session_id, role, content, created_at
                        FROM chat_messages
                        WHERE session_id = %s
                        ORDER BY created_at DESC, id DESC
                        LIMIT %s
                    """, (session_id, fetch_count))
                else:
                    # Keyset pagination on the (created_at, id) tuple: rows that
                    # share the anchor's timestamp are neither skipped nor repeated.
                    await cursor.execute("""
                        SELECT id, session_id, role, content, created_at
                        FROM chat_messages
                        WHERE session_id = %s
                          AND (created_at, id) < (%s, %s)
                        ORDER BY created_at DESC, id DESC
                        LIMIT %s
                    """, (session_id, anchor[0], anchor[1], fetch_count))

                rows = await cursor.fetchall()

        has_more = len(rows) > limit
        messages = [
            {
                "id": str(row[0]),
                "role": row[2],
                "content": row[3],
                "timestamp": row[4].isoformat() if row[4] else None,
            }
            for row in rows[:limit]
        ]

        return {
            "messages": messages,
            "has_more": has_more
        }

    async def get_history_by_message_id(
        self,
        session_id: str,
        last_message_id: Optional[str] = None,
        limit: int = 20
    ) -> Dict[str, Any]:
        """
        Get earlier history records based on last_message_id.

        Thin wrapper around :meth:`get_history`.

        Args:
            session_id: Session ID
            last_message_id: ID of the last message on the previous page
            limit: Number of messages to return

        Returns:
            dict: {
                "messages": [...],
                "has_more": bool
            }
        """
        return await self.get_history(
            session_id=session_id,
            limit=limit,
            before_id=last_message_id
        )


# Global singleton
_global_manager: Optional['ChatHistoryManagerWithPool'] = None


class ChatHistoryManagerWithPool:
    """
    Singleton chat history manager with connection pool support.

    Uses the shared PostgreSQL connection pool; the pool itself is owned by
    DBPoolManager and is never closed here.
    """

    def __init__(self):
        self._pool: Optional[AsyncConnectionPool] = None
        self._manager: Optional[ChatHistoryManager] = None
        self._initialized = False
        self._closed = False

    async def initialize(self, pool: AsyncConnectionPool) -> None:
        """Initialize the manager using an externally provided connection pool.

        Calling this again after :meth:`close` re-opens the manager.

        Args:
            pool: AsyncConnectionPool instance from DBPoolManager
        """
        if self._initialized:
            return

        self._pool = pool
        self._manager = ChatHistoryManager(pool)
        await self._manager.create_table()
        # Clear the closed flag so a close() -> initialize() cycle yields a
        # usable manager instead of one whose `.manager` still raises.
        self._closed = False
        self._initialized = True
        logger.info("ChatHistoryManager initialized successfully (using shared pool)")

    @property
    def manager(self) -> ChatHistoryManager:
        """Get the ChatHistoryManager instance.

        Raises:
            RuntimeError: if the manager is closed or not yet initialized.
        """
        if self._closed:
            raise RuntimeError("ChatHistoryManager is closed")
        if not self._initialized or not self._manager:
            raise RuntimeError("ChatHistoryManager not initialized")
        return self._manager

    async def close(self) -> None:
        """Close the manager without closing the pool managed by DBPoolManager."""
        if self._closed:
            return

        logger.info("Closing ChatHistoryManager...")
        self._closed = True
        self._initialized = False
        # Drop references only; the pool's lifecycle belongs to DBPoolManager.
        self._manager = None
        self._pool = None
        logger.info("ChatHistoryManager closed (pool managed by DBPoolManager)")


def get_chat_history_manager() -> ChatHistoryManagerWithPool:
    """Get the global ChatHistoryManager singleton, creating it on first use."""
    global _global_manager
    if _global_manager is None:
        _global_manager = ChatHistoryManagerWithPool()
    return _global_manager


async def init_chat_history_manager(pool: AsyncConnectionPool) -> None:
    """Initialize the global chat history manager.

    Args:
        pool: AsyncConnectionPool instance from DBPoolManager
    """
    manager = get_chat_history_manager()
    await manager.initialize(pool)


async def close_chat_history_manager() -> None:
    """Close the global chat history manager, if one was created."""
    if _global_manager is not None:
        await _global_manager.close()