qwen_agent/agent/chat_history_manager.py
2026-01-18 12:29:20 +08:00

291 lines
9.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
聊天历史记录管理器
直接保存完整的原始聊天消息到数据库,不受 checkpoint summary 影响
"""
import logging
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, List, Dict, Any

from psycopg_pool import AsyncConnectionPool
# Module-wide logger; the shared "app" logger is presumably configured by the
# application entry point — verify.
logger = logging.getLogger(name="app")
@dataclass
class ChatMessage:
"""聊天消息"""
id: str
session_id: str
role: str
content: str
created_at: datetime
class ChatHistoryManager:
    """
    Chat history manager.

    Stores complete, raw chat messages in a dedicated ``chat_messages`` table,
    independent of any checkpoint summarization. The PostgreSQL connection
    pool is injected by the caller (per the original design notes, it is the
    pool shared with checkpoint_manager).

    Ordering: history is returned newest-first by ``(created_at, id)``.
    ``created_at`` is written with ``clock_timestamp()`` so messages inserted
    inside one transaction (e.g. by ``save_messages``) keep distinct,
    increasing timestamps; ``id`` breaks any remaining ties so pagination
    never skips or repeats rows.
    """

    # Shared INSERT for save_message/save_messages. clock_timestamp() is used
    # instead of the column default NOW(): NOW() is frozen at transaction
    # start, which would give every message of a batch the same created_at.
    _INSERT_SQL = """
        INSERT INTO chat_messages
            (session_id, role, content, bot_id, user_identifier, created_at)
        VALUES (%s, %s, %s, %s, %s, clock_timestamp())
        RETURNING id
    """

    def __init__(self, pool: AsyncConnectionPool):
        """
        Initialise the chat history manager.

        Args:
            pool: PostgreSQL async connection pool used for every query.
        """
        self._pool = pool

    async def create_table(self) -> None:
        """Create the chat_messages table and its per-session index if missing."""
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                # NOTE(review): gen_random_uuid() is built in from PostgreSQL 13;
                # older servers need the pgcrypto extension — confirm target version.
                await cursor.execute("""
                    CREATE TABLE IF NOT EXISTS chat_messages (
                        id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                        session_id VARCHAR(255) NOT NULL,
                        role VARCHAR(50) NOT NULL,
                        content TEXT NOT NULL,
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                        bot_id VARCHAR(255),
                        user_identifier VARCHAR(255)
                    )
                """)
                # Index to speed up per-session, newest-first queries.
                await cursor.execute("""
                    CREATE INDEX IF NOT EXISTS idx_chat_messages_session_created
                    ON chat_messages (session_id, created_at DESC)
                """)
                await conn.commit()
        logger.info("chat_messages table created successfully")

    async def save_message(
        self,
        session_id: str,
        role: str,
        content: str,
        bot_id: Optional[str] = None,
        user_identifier: Optional[str] = None
    ) -> Optional[str]:
        """
        Save a single chat message.

        Args:
            session_id: Session (conversation) ID.
            role: Message role, e.g. 'user' or 'assistant'.
            content: Message text.
            bot_id: Optional bot ID.
            user_identifier: Optional user identifier.

        Returns:
            The new message's ID as a string, or None if the INSERT returned no row.
        """
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(
                    self._INSERT_SQL,
                    (session_id, role, content, bot_id, user_identifier),
                )
                result = await cursor.fetchone()
                await conn.commit()
        message_id = str(result[0]) if result else None
        logger.debug(
            "Saved message: session_id=%s, role=%s, id=%s",
            session_id, role, message_id,
        )
        return message_id

    async def save_messages(
        self,
        session_id: str,
        messages: List[Dict[str, str]],
        bot_id: Optional[str] = None,
        user_identifier: Optional[str] = None
    ) -> List[Optional[str]]:
        """
        Save several chat messages on one connection, committing once at the end.

        Messages without a role or with empty content are skipped.

        Args:
            session_id: Session (conversation) ID.
            messages: Message dicts, each with 'role' and 'content' keys.
            bot_id: Optional bot ID.
            user_identifier: Optional user identifier.

        Returns:
            IDs of the saved messages in input order (skipped messages produce
            no entry; an entry is None only if an INSERT returned no row).
        """
        message_ids: List[Optional[str]] = []
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                for msg in messages:
                    role = msg.get('role')
                    content = msg.get('content', '')
                    if not role or not content:
                        continue
                    await cursor.execute(
                        self._INSERT_SQL,
                        (session_id, role, content, bot_id, user_identifier),
                    )
                    result = await cursor.fetchone()
                    message_ids.append(str(result[0]) if result else None)
                await conn.commit()
        logger.info("Saved %d messages for session_id=%s", len(message_ids), session_id)
        return message_ids

    async def get_history(
        self,
        session_id: str,
        limit: int = 20,
        before_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Fetch chat history, newest first.

        Args:
            session_id: Session ID.
            limit: Maximum number of messages to return.
            before_id: Pagination cursor. When given, only messages older than
                this message in ``(created_at, id)`` order are returned. If the
                value is not a valid UUID or is not a message of ``session_id``,
                the page starts from the newest message instead.

        Returns:
            dict: {
                "messages": [{"id", "role", "content", "timestamp"}, ...],
                "has_more": bool  # True if older messages exist beyond this page
            }
        """
        # Fetch one extra row: its presence tells us whether another page exists.
        fetch_count = limit + 1
        async with self._pool.connection() as conn:
            async with conn.cursor() as cursor:
                anchor = None
                if before_id:
                    anchor = await self._find_anchor(cursor, session_id, before_id)
                if anchor is None:
                    await cursor.execute("""
                        SELECT id, session_id, role, content, created_at
                        FROM chat_messages
                        WHERE session_id = %s
                        ORDER BY created_at DESC, id DESC
                        LIMIT %s
                    """, (session_id, fetch_count))
                else:
                    anchor_time, anchor_id = anchor
                    # Row-value comparison matches the ORDER BY, so rows sharing
                    # the anchor's created_at are neither skipped nor repeated.
                    await cursor.execute("""
                        SELECT id, session_id, role, content, created_at
                        FROM chat_messages
                        WHERE session_id = %s
                          AND (created_at, id) < (%s, %s)
                        ORDER BY created_at DESC, id DESC
                        LIMIT %s
                    """, (session_id, anchor_time, anchor_id, fetch_count))
                rows = await cursor.fetchall()

        has_more = len(rows) > limit
        if has_more:
            rows = rows[:limit]
        return {
            "messages": [self._row_to_message(row) for row in rows],
            "has_more": has_more
        }

    @staticmethod
    async def _find_anchor(cursor, session_id: str, before_id: str):
        """Return ``(created_at, id)`` of ``before_id`` within ``session_id``, or None.

        A value that is not a valid UUID cannot match the UUID primary key, so
        it is treated as "not found" rather than letting PostgreSQL raise a
        uuid cast error.
        """
        try:
            anchor_uuid = uuid.UUID(str(before_id))
        except ValueError:
            return None
        await cursor.execute("""
            SELECT created_at, id FROM chat_messages
            WHERE id = %s AND session_id = %s
        """, (anchor_uuid, session_id))
        return await cursor.fetchone()

    @staticmethod
    def _row_to_message(row) -> Dict[str, Any]:
        """Convert an (id, session_id, role, content, created_at) row to the API dict."""
        return {
            "id": str(row[0]),
            "role": row[2],
            "content": row[3],
            "timestamp": row[4].isoformat() if row[4] else None
        }

    async def get_history_by_message_id(
        self,
        session_id: str,
        last_message_id: Optional[str] = None,
        limit: int = 20
    ) -> Dict[str, Any]:
        """
        Fetch messages older than ``last_message_id`` (thin alias of get_history).

        Args:
            session_id: Session ID.
            last_message_id: ID of the last (oldest) message of the previous page.
            limit: Maximum number of messages to return.

        Returns:
            dict: {
                "messages": [...],
                "has_more": bool
            }
        """
        return await self.get_history(
            session_id=session_id,
            limit=limit,
            before_id=last_message_id
        )
# Process-wide singleton holder; created lazily by get_chat_history_manager().
_global_manager: Optional["ChatHistoryManagerWithPool"] = None
class ChatHistoryManagerWithPool:
    """
    Lazily initialised holder around a single ChatHistoryManager.

    The connection pool is supplied by the caller (per the original design
    notes, the pool owned by checkpoint_manager). Until initialize() has
    completed, the ``manager`` property raises RuntimeError.
    """

    def __init__(self):
        self._initialized = False
        self._manager: Optional[ChatHistoryManager] = None
        self._pool: Optional[AsyncConnectionPool] = None

    async def initialize(self, pool: AsyncConnectionPool) -> None:
        """
        Bind ``pool``, ensure the chat_messages table exists, and mark ready.

        Once initialisation has succeeded, later calls do nothing (a different
        pool passed later is ignored). If create_table() raises, the holder
        stays uninitialised, so a subsequent call retries.
        """
        if not self._initialized:
            self._pool = pool
            manager = ChatHistoryManager(pool)
            self._manager = manager
            await manager.create_table()
            self._initialized = True
            logger.info("ChatHistoryManager initialized successfully")

    @property
    def manager(self) -> ChatHistoryManager:
        """The wrapped ChatHistoryManager; raises RuntimeError before initialisation."""
        ready = self._initialized and self._manager
        if ready:
            return self._manager
        raise RuntimeError("ChatHistoryManager not initialized")
def get_chat_history_manager() -> ChatHistoryManagerWithPool:
    """Return the process-wide ChatHistoryManagerWithPool, creating it on first use.

    The holder is only constructed here; it must still be initialised (see
    init_chat_history_manager) before its ``manager`` property can be used.
    """
    global _global_manager
    if _global_manager is not None:
        return _global_manager
    _global_manager = ChatHistoryManagerWithPool()
    return _global_manager
async def init_chat_history_manager(pool: AsyncConnectionPool) -> None:
    """Initialise the global chat-history holder with ``pool``.

    Delegates to ChatHistoryManagerWithPool.initialize(), which does nothing
    once initialisation has already succeeded.
    """
    await get_chat_history_manager().initialize(pool)