Convert all Chinese comments, docstrings, logger/print output, HTTPException detail messages, and API response messages to English across the entire codebase. Functional zh/ja localized strings (e.g. prompt templates, timezone display names, date formats) are preserved as-is. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
321 lines
11 KiB
Python
321 lines
11 KiB
Python
"""
|
|
Chat history manager.
|
|
Uses the shared database connection pool.
|
|
Stores complete raw chat messages directly in the database without being affected by checkpoint summaries.
|
|
"""
|
|
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Optional, List, Dict, Any
|
|
from dataclasses import dataclass
|
|
|
|
from psycopg_pool import AsyncConnectionPool
|
|
|
|
# Module-level logger; uses the fixed 'app' logger name rather than __name__.
logger = logging.getLogger(name='app')
|
|
|
|
|
|
@dataclass(init=True, repr=True, eq=True)
class ChatMessage:
    """A single stored chat message.

    The field names mirror a subset of the columns of the ``chat_messages`` table.
    """

    id: str  # message primary key (a UUID column in the table)
    session_id: str  # session the message belongs to
    role: str  # message role, e.g. 'user' or 'assistant'
    content: str  # raw message text
    created_at: datetime  # time the row was written
|
|
|
|
|
|
class ChatHistoryManager:
    """
    Chat history manager.

    Uses the shared PostgreSQL connection pool to store complete, raw chat
    messages in the ``chat_messages`` table.
    """

    # Shared by save_message() and save_messages() so both write rows the same way.
    # The column default NOW() is the *transaction* start time, so every row of a
    # batch inserted in one transaction would get an identical created_at.
    # clock_timestamp() is evaluated per statement instead, so messages saved
    # together normally keep their insertion order when sorted by created_at.
    _INSERT_MESSAGE_SQL = """
        INSERT INTO chat_messages (session_id, role, content, bot_id, user_identifier, created_at)
        VALUES (%s, %s, %s, %s, %s, clock_timestamp())
        RETURNING id
    """

    # Newest-first page of one session's messages. LIMIT requests one extra row
    # so the caller can tell whether older messages exist. `id` breaks ties
    # between equal timestamps so the order is deterministic.
    _SELECT_LATEST_SQL = """
        SELECT id, session_id, role, content, created_at
        FROM chat_messages
        WHERE session_id = %s
        ORDER BY created_at DESC, id DESC
        LIMIT %s + 1
    """

    # Like _SELECT_LATEST_SQL, but only rows strictly older than a (created_at, id)
    # cursor. Comparing the pair rather than created_at alone means rows sharing
    # the cursor's timestamp are neither skipped nor repeated across pages.
    _SELECT_BEFORE_SQL = """
        SELECT id, session_id, role, content, created_at
        FROM chat_messages
        WHERE session_id = %s AND (created_at, id) < (%s, %s)
        ORDER BY created_at DESC, id DESC
        LIMIT %s + 1
    """

    def __init__(self, pool: AsyncConnectionPool):
        """
        Initialize the chat history manager.

        Args:
            pool: PostgreSQL connection pool from DBPoolManager
        """
        self._pool = pool

    async def create_table(self) -> None:
        """Create the chat_messages table and its per-session index if they do not exist."""
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS chat_messages (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        session_id VARCHAR(255) NOT NULL,
                        role VARCHAR(50) NOT NULL,
                        content TEXT NOT NULL,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                        bot_id VARCHAR(255),
                        user_identifier VARCHAR(255)
                    )
                """)
                # Index to speed up per-session, newest-first history queries.
                await cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_chat_messages_session_created
                    ON chat_messages (session_id, created_at DESC)
                """)
            await conn.commit()
        logger.info("chat_messages table created successfully")

    async def save_message(
        self,
        session_id: str,
        role: str,
        content: str,
        bot_id: Optional[str] = None,
        user_identifier: Optional[str] = None
    ) -> Optional[str]:
        """
        Save a single chat message.

        Args:
            session_id: Session ID
            role: Message role ('user' or 'assistant')
            content: Message content
            bot_id: Bot ID
            user_identifier: User identifier

        Returns:
            Optional[str]: ID of the inserted message, or None if the INSERT
            returned no row.
        """
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(
                    self._INSERT_MESSAGE_SQL,
                    (session_id, role, content, bot_id, user_identifier),
                )
                result = await cursor.fetchone()
            await conn.commit()
        message_id = str(result[0]) if result else None
        logger.debug("Saved message: session_id=%s, role=%s, id=%s", session_id, role, message_id)
        return message_id

    async def save_messages(
        self,
        session_id: str,
        messages: List[Dict[str, str]],
        bot_id: Optional[str] = None,
        user_identifier: Optional[str] = None
    ) -> List[Optional[str]]:
        """
        Save multiple chat messages in a single transaction.

        Entries with a missing/empty 'role' or empty 'content' are skipped.

        Args:
            session_id: Session ID
            messages: List of messages, each containing 'role' and 'content'
            bot_id: Bot ID
            user_identifier: User identifier

        Returns:
            List[Optional[str]]: IDs of the inserted messages, in input order
            (skipped entries are not represented).
        """
        message_ids: List[Optional[str]] = []
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                for msg in messages:
                    role = msg.get('role')
                    content = msg.get('content', '')
                    if not role or not content:
                        continue
                    await cursor.execute(
                        self._INSERT_MESSAGE_SQL,
                        (session_id, role, content, bot_id, user_identifier),
                    )
                    result = await cursor.fetchone()
                    message_ids.append(str(result[0]) if result else None)
            # All inserts of the batch are committed together.
            await conn.commit()
        logger.info("Saved %d messages for session_id=%s", len(message_ids), session_id)
        return message_ids

    async def get_history(
        self,
        session_id: str,
        limit: int = 20,
        before_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Get chat history in reverse chronological order, newest first.

        Pagination is keyset-based on (created_at, id), so messages that share a
        timestamp are never skipped or duplicated across pages.

        Args:
            session_id: Session ID
            limit: Number of messages to return
            before_id: Return messages older than this message ID, used for
                pagination. If the ID is not found in this session, the newest
                messages are returned instead.

        Returns:
            dict: {
                "messages": [{"id", "role", "content", "timestamp"}, ...],
                "has_more": bool  # True if older messages exist beyond this page
            }
        """
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                page_anchor = None
                if before_id:
                    # Resolve the cursor within this session only, so an ID from
                    # another session cannot shift this session's pages.
                    await cursor.execute("""
                        SELECT created_at, id FROM chat_messages
                        WHERE id = %s AND session_id = %s
                    """, (before_id, session_id))
                    page_anchor = await cursor.fetchone()

                if page_anchor is None:
                    # No cursor, or the cursor ID was not found: start from the newest message.
                    await cursor.execute(self._SELECT_LATEST_SQL, (session_id, limit))
                else:
                    before_time, before_uuid = page_anchor[0], page_anchor[1]
                    await cursor.execute(
                        self._SELECT_BEFORE_SQL,
                        (session_id, before_time, before_uuid, limit),
                    )

                rows = await cursor.fetchall()

        # One extra row was requested to detect whether more results exist.
        has_more = len(rows) > limit
        if has_more:
            rows = rows[:limit]

        messages = [
            {
                "id": str(row[0]),
                "role": row[2],
                "content": row[3],
                "timestamp": row[4].isoformat() if row[4] else None,
            }
            for row in rows
        ]

        return {
            "messages": messages,
            "has_more": has_more
        }

    async def get_history_by_message_id(
        self,
        session_id: str,
        last_message_id: Optional[str] = None,
        limit: int = 20
    ) -> Dict[str, Any]:
        """
        Get earlier history records based on last_message_id.

        Thin wrapper around get_history() with last_message_id as the cursor.

        Args:
            session_id: Session ID
            last_message_id: ID of the last (oldest) message on the previous page
            limit: Number of messages to return

        Returns:
            dict: {
                "messages": [...],
                "has_more": bool
            }
        """
        return await self.get_history(
            session_id=session_id,
            limit=limit,
            before_id=last_message_id
        )
|
|
|
|
|
|
# Module-level singleton holder; created lazily by get_chat_history_manager()
# and None until that function is first called.
_global_manager: Optional['ChatHistoryManagerWithPool'] = None
|
|
|
|
|
|
class ChatHistoryManagerWithPool:
    """
    Lifecycle wrapper around a ChatHistoryManager bound to a shared pool.

    Uses the shared PostgreSQL connection pool, which is owned by
    DBPoolManager and is never closed here.
    """

    def __init__(self):
        self._pool: Optional[AsyncConnectionPool] = None
        self._manager: Optional[ChatHistoryManager] = None
        self._initialized = False
        self._closed = False

    async def initialize(self, pool: AsyncConnectionPool) -> None:
        """Initialize the manager using an externally provided connection pool.

        A no-op if already initialized. Calling it again after close()
        re-initializes the wrapper and makes `manager` usable again.

        Args:
            pool: AsyncConnectionPool instance from DBPoolManager
        """
        if self._initialized:
            return

        self._pool = pool
        self._manager = ChatHistoryManager(pool)
        await self._manager.create_table()
        self._initialized = True
        # Previously a re-initialization after close() left _closed set, so it
        # logged success while the `manager` property kept raising "closed".
        self._closed = False
        logger.info("ChatHistoryManager initialized successfully (using shared pool)")

    @property
    def manager(self) -> ChatHistoryManager:
        """Get the ChatHistoryManager instance.

        Raises:
            RuntimeError: if the wrapper has been closed or not yet initialized.
        """
        if self._closed:
            raise RuntimeError("ChatHistoryManager is closed")

        if not self._initialized or not self._manager:
            raise RuntimeError("ChatHistoryManager not initialized")
        return self._manager

    async def close(self) -> None:
        """Close the manager without closing the pool managed by DBPoolManager."""
        if self._closed:
            return

        logger.info("Closing ChatHistoryManager...")
        self._closed = True
        self._initialized = False
        # Drop references only; the pool itself belongs to DBPoolManager.
        self._manager = None
        self._pool = None
        logger.info("ChatHistoryManager closed (pool managed by DBPoolManager)")
|
|
|
|
|
|
def get_chat_history_manager() -> ChatHistoryManagerWithPool:
    """Return the module-level ChatHistoryManagerWithPool, creating it on first use.

    The returned wrapper is not initialized by this call; use
    init_chat_history_manager() for that.
    """
    global _global_manager
    if _global_manager is not None:
        return _global_manager
    _global_manager = ChatHistoryManagerWithPool()
    return _global_manager
|
|
|
|
|
|
async def init_chat_history_manager(pool: AsyncConnectionPool) -> None:
    """Initialize the module-level chat history manager with a shared pool.

    Creates the singleton if needed; initialization itself is skipped when the
    manager is already initialized.

    Args:
        pool: AsyncConnectionPool instance from DBPoolManager
    """
    await get_chat_history_manager().initialize(pool)
|
|
|
|
|
|
async def close_chat_history_manager() -> None:
    """Close the module-level chat history manager, if one has been created.

    Does not create the singleton when it does not exist yet.
    """
    existing = _global_manager
    if existing is None:
        return
    await existing.close()
|