291 lines
9.7 KiB
Python
291 lines
9.7 KiB
Python
"""
Chat history manager.

Saves the complete, original chat messages directly to the database so the
stored history is not affected by checkpoint summaries.
"""
|
||
|
||
import logging
|
||
from datetime import datetime
|
||
from typing import Optional, List, Dict, Any
|
||
from dataclasses import dataclass
|
||
|
||
from psycopg_pool import AsyncConnectionPool
|
||
|
||
# Module logger: records go to the shared application-wide "app" logger.
logger = logging.getLogger(name='app')
|
||
|
||
|
||
@dataclass(init=True, repr=True, eq=True)
class ChatMessage:
    """
    One chat message.

    The field names match the columns of the same name in the
    ``chat_messages`` table.
    """

    # Message identifier (the table stores it as a UUID; kept as text here).
    id: str
    # Conversation the message belongs to.
    session_id: str
    # Speaker role, e.g. 'user' or 'assistant'.
    role: str
    # Raw message text.
    content: str
    # When the message was stored.
    created_at: datetime
|
||
|
||
|
||
class ChatHistoryManager:
    """
    Chat history manager.

    Stores the complete chat history in a dedicated table (``chat_messages``)
    and reuses the PostgreSQL connection pool owned by checkpoint_manager.

    Ordering:

    - New rows get ``created_at = clock_timestamp()`` instead of the column
      default ``NOW()``. ``NOW()`` is the *transaction* start time, so every
      message inserted in one transaction (e.g. a ``save_messages`` batch)
      would otherwise share one timestamp. Their order would then be
      undefined, and timestamp-only pagination would skip rows that tie with
      the cursor.
    - Reads order by ``(created_at, id)``. The pagination key is therefore
      unique even for rows that already share a timestamp.
    """

    # Single INSERT shared by save_message() and save_messages().
    _INSERT_SQL = """
        INSERT INTO chat_messages (session_id, role, content, bot_id, user_identifier, created_at)
        VALUES (%s, %s, %s, %s, %s, clock_timestamp())
        RETURNING id
    """

    # First page: newest messages of a session.
    # Result columns: id, session_id, role, content, created_at.
    _HISTORY_SQL = """
        SELECT id, session_id, role, content, created_at
        FROM chat_messages
        WHERE session_id = %s
        ORDER BY created_at DESC, id DESC
        LIMIT %s
    """

    # Following pages: messages strictly older than the (created_at, id) cursor.
    _HISTORY_BEFORE_SQL = """
        SELECT id, session_id, role, content, created_at
        FROM chat_messages
        WHERE session_id = %s AND (created_at, id) < (%s, %s)
        ORDER BY created_at DESC, id DESC
        LIMIT %s
    """

    def __init__(self, pool: AsyncConnectionPool):
        """
        Initialize the chat history manager.

        Args:
            pool: PostgreSQL async connection pool used for every query.
        """
        self._pool = pool

    async def create_table(self) -> None:
        """Create the ``chat_messages`` table and its session/time index if missing."""
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                # gen_random_uuid() is built in from PostgreSQL 13;
                # older servers need the pgcrypto extension.
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS chat_messages (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        session_id VARCHAR(255) NOT NULL,
                        role VARCHAR(50) NOT NULL,
                        content TEXT NOT NULL,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                        bot_id VARCHAR(255),
                        user_identifier VARCHAR(255)
                    )
                """)
                # Index to speed up per-session, newest-first queries.
                await cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_chat_messages_session_created
                    ON chat_messages (session_id, created_at DESC)
                """)
                await conn.commit()
        logger.info("chat_messages table created successfully")

    async def save_message(
        self,
        session_id: str,
        role: str,
        content: str,
        bot_id: Optional[str] = None,
        user_identifier: Optional[str] = None
    ) -> str:
        """
        Save one chat message and commit it.

        Args:
            session_id: Session ID.
            role: Message role ('user' or 'assistant').
            content: Message content.
            bot_id: Bot ID.
            user_identifier: User identifier.

        Returns:
            str: The new message ID (the database-generated UUID as text).
        """
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(
                    self._INSERT_SQL,
                    (session_id, role, content, bot_id, user_identifier),
                )
                result = await cursor.fetchone()
                await conn.commit()
        message_id = str(result[0]) if result else None
        logger.debug(f"Saved message: session_id={session_id}, role={role}, id={message_id}")
        return message_id

    async def save_messages(
        self,
        session_id: str,
        messages: List[Dict[str, str]],
        bot_id: Optional[str] = None,
        user_identifier: Optional[str] = None
    ) -> List[str]:
        """
        Save several chat messages in one transaction.

        Messages with a missing or empty ``role`` or ``content`` are skipped.
        Each inserted row gets its own ``clock_timestamp()``, so the batch
        keeps its input order when read back.

        Args:
            session_id: Session ID.
            messages: List of messages; each must contain ``role`` and ``content``.
            bot_id: Bot ID.
            user_identifier: User identifier.

        Returns:
            List[str]: IDs of the inserted messages, in input order
            (skipped messages are not represented).
        """
        message_ids = []
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                for msg in messages:
                    role = msg.get('role')
                    content = msg.get('content', '')
                    if not role or not content:
                        continue
                    await cursor.execute(
                        self._INSERT_SQL,
                        (session_id, role, content, bot_id, user_identifier),
                    )
                    result = await cursor.fetchone()
                    message_ids.append(str(result[0]) if result else None)
                # One commit for the whole batch.
                await conn.commit()
        logger.info(f"Saved {len(message_ids)} messages for session_id={session_id}")
        return message_ids

    @staticmethod
    async def _find_cursor_key(cursor, session_id: str, before_id: str):
        """
        Resolve ``before_id`` to its ``(created_at, id)`` key within ``session_id``.

        Returns None when ``before_id`` is not a valid UUID or does not name a
        message in this session.
        """
        import uuid

        try:
            message_uuid = uuid.UUID(str(before_id))
        except ValueError:
            # A malformed id would make PostgreSQL raise; treat it as "not found".
            return None
        await cursor.execute(
            "SELECT created_at, id FROM chat_messages WHERE id = %s AND session_id = %s",
            (message_uuid, session_id),
        )
        row = await cursor.fetchone()
        return (row[0], row[1]) if row else None

    async def get_history(
        self,
        session_id: str,
        limit: int = 20,
        before_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Get chat history, newest first.

        Args:
            session_id: Session ID.
            limit: Maximum number of messages to return.
            before_id: Return only messages older than this message ID (used
                for pagination). If it is malformed or not found in this
                session, results start from the newest message.

        Returns:
            dict: {
                "messages": [{"id", "role", "content", "timestamp"}, ...],
                "has_more": bool
            }
        """
        # Fetch one extra row to tell whether another page exists.
        fetch_count = limit + 1
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                cursor_key = None
                if before_id:
                    cursor_key = await self._find_cursor_key(cursor, session_id, before_id)

                if cursor_key is None:
                    await cursor.execute(self._HISTORY_SQL, (session_id, fetch_count))
                else:
                    before_time, before_uuid = cursor_key
                    await cursor.execute(
                        self._HISTORY_BEFORE_SQL,
                        (session_id, before_time, before_uuid, fetch_count),
                    )
                rows = await cursor.fetchall()

        has_more = len(rows) > limit
        if has_more:
            rows = rows[:limit]

        messages = [
            {
                "id": str(row[0]),
                "role": row[2],
                "content": row[3],
                "timestamp": row[4].isoformat() if row[4] else None,
            }
            for row in rows
        ]
        return {
            "messages": messages,
            "has_more": has_more
        }

    async def get_history_by_message_id(
        self,
        session_id: str,
        last_message_id: Optional[str] = None,
        limit: int = 20
    ) -> Dict[str, Any]:
        """
        Get older history relative to ``last_message_id``.

        Thin wrapper around ``get_history`` with ``before_id=last_message_id``.

        Args:
            session_id: Session ID.
            last_message_id: ID of the last (oldest) message of the previous page.
            limit: Maximum number of messages to return.

        Returns:
            dict: {
                "messages": [...],
                "has_more": bool
            }
        """
        return await self.get_history(
            session_id=session_id,
            limit=limit,
            before_id=last_message_id
        )
|
||
|
||
|
||
# Module-wide singleton holder; created lazily by get_chat_history_manager().
_global_manager: "Optional[ChatHistoryManagerWithPool]" = None
|
||
|
||
|
||
class ChatHistoryManagerWithPool:
    """
    Holder for a lazily initialized ChatHistoryManager singleton.

    The connection pool is not created here; it is handed in by the caller
    (the pool owned by checkpoint_manager).
    """

    def __init__(self):
        # Becomes True only after initialize() has fully completed.
        self._initialized = False
        self._manager: Optional[ChatHistoryManager] = None
        self._pool: Optional[AsyncConnectionPool] = None

    async def initialize(self, pool: AsyncConnectionPool) -> None:
        """
        Bind ``pool``, build the manager and make sure its table exists.

        Once initialization has succeeded, later calls do nothing (and ignore
        the ``pool`` they are given).
        """
        if not self._initialized:
            self._pool = pool
            self._manager = ChatHistoryManager(pool)
            await self._manager.create_table()
            self._initialized = True
            logger.info("ChatHistoryManager initialized successfully")

    @property
    def manager(self) -> ChatHistoryManager:
        """
        The wrapped ChatHistoryManager.

        Raises:
            RuntimeError: if initialize() has not completed yet.
        """
        current = self._manager
        if self._initialized and current:
            return current
        raise RuntimeError("ChatHistoryManager not initialized")
|
||
|
||
|
||
def get_chat_history_manager() -> ChatHistoryManagerWithPool:
    """Return the module-wide ChatHistoryManagerWithPool, creating it on first use."""
    global _global_manager
    if _global_manager is not None:
        return _global_manager
    _global_manager = ChatHistoryManagerWithPool()
    return _global_manager
|
||
|
||
|
||
async def init_chat_history_manager(pool: AsyncConnectionPool) -> None:
    """Initialize the module-wide chat history manager with the given pool."""
    await get_chat_history_manager().initialize(pool)
|