添加聊天记录查询

This commit is contained in:
朱潮 2026-01-18 12:29:20 +08:00
parent fa3e30cc07
commit f9ba3c8e51
4 changed files with 472 additions and 4 deletions

View File

@ -0,0 +1,290 @@
"""
聊天历史记录管理器
直接保存完整的原始聊天消息到数据库不受 checkpoint summary 影响
"""
import logging
from datetime import datetime
from typing import Optional, List, Dict, Any
from dataclasses import dataclass
from psycopg_pool import AsyncConnectionPool
logger = logging.getLogger('app')
@dataclass
class ChatMessage:
"""聊天消息"""
id: str
session_id: str
role: str
content: str
created_at: datetime
class ChatHistoryManager:
"""
聊天历史管理器
使用独立的数据库表存储完整的聊天历史记录
复用 checkpoint_manager PostgreSQL 连接池
"""
def __init__(self, pool: AsyncConnectionPool):
"""
初始化聊天历史管理器
Args:
pool: PostgreSQL 连接池
"""
self._pool = pool
async def create_table(self) -> None:
"""创建 chat_messages 表"""
async with self._pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute("""
CREATE TABLE IF NOT EXISTS chat_messages (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
session_id VARCHAR(255) NOT NULL,
role VARCHAR(50) NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
bot_id VARCHAR(255),
user_identifier VARCHAR(255)
)
""")
# 创建索引以加速查询
await cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_chat_messages_session_created
ON chat_messages (session_id, created_at DESC)
""")
await conn.commit()
logger.info("chat_messages table created successfully")
async def save_message(
self,
session_id: str,
role: str,
content: str,
bot_id: Optional[str] = None,
user_identifier: Optional[str] = None
) -> str:
"""
保存一条聊天消息
Args:
session_id: 会话ID
role: 消息角色 ('user' 'assistant')
content: 消息内容
bot_id: 机器人ID
user_identifier: 用户标识
Returns:
str: 消息ID
"""
async with self._pool.connection() as conn:
async with conn.cursor() as cursor:
await cursor.execute("""
INSERT INTO chat_messages (session_id, role, content, bot_id, user_identifier)
VALUES (%s, %s, %s, %s, %s)
RETURNING id
""", (session_id, role, content, bot_id, user_identifier))
result = await cursor.fetchone()
await conn.commit()
message_id = str(result[0]) if result else None
logger.debug(f"Saved message: session_id={session_id}, role={role}, id={message_id}")
return message_id
async def save_messages(
self,
session_id: str,
messages: List[Dict[str, str]],
bot_id: Optional[str] = None,
user_identifier: Optional[str] = None
) -> List[str]:
"""
批量保存聊天消息
Args:
session_id: 会话ID
messages: 消息列表每条消息包含 role content
bot_id: 机器人ID
user_identifier: 用户标识
Returns:
List[str]: 消息ID列表
"""
message_ids = []
async with self._pool.connection() as conn:
async with conn.cursor() as cursor:
for msg in messages:
role = msg.get('role')
content = msg.get('content', '')
if not role or not content:
continue
await cursor.execute("""
INSERT INTO chat_messages (session_id, role, content, bot_id, user_identifier)
VALUES (%s, %s, %s, %s, %s)
RETURNING id
""", (session_id, role, content, bot_id, user_identifier))
result = await cursor.fetchone()
message_id = str(result[0]) if result else None
message_ids.append(message_id)
await conn.commit()
logger.info(f"Saved {len(message_ids)} messages for session_id={session_id}")
return message_ids
async def get_history(
self,
session_id: str,
limit: int = 20,
before_id: Optional[str] = None
) -> Dict[str, Any]:
"""
获取聊天历史记录倒序最新在前
Args:
session_id: 会话ID
limit: 返回的消息数量
before_id: 获取此消息ID之前的消息用于分页
Returns:
dict: {
"messages": [...],
"has_more": bool
}
"""
async with self._pool.connection() as conn:
async with conn.cursor() as cursor:
if before_id:
# 查询 before_id 的 created_at然后获取更早的消息
await cursor.execute("""
SELECT created_at FROM chat_messages
WHERE id = %s
""", (before_id,))
result = await cursor.fetchone()
if not result:
# 如果找不到指定的 ID从头开始
query = """
SELECT id, session_id, role, content, created_at
FROM chat_messages
WHERE session_id = %s
ORDER BY created_at DESC
LIMIT %s + 1
"""
await cursor.execute(query, (session_id, limit))
else:
before_time = result[0]
query = """
SELECT id, session_id, role, content, created_at
FROM chat_messages
WHERE session_id = %s AND created_at < %s
ORDER BY created_at DESC
LIMIT %s + 1
"""
await cursor.execute(query, (session_id, before_time, limit))
else:
query = """
SELECT id, session_id, role, content, created_at
FROM chat_messages
WHERE session_id = %s
ORDER BY created_at DESC
LIMIT %s + 1
"""
await cursor.execute(query, (session_id, limit))
rows = await cursor.fetchall()
# 判断是否有更多(多取一条用于判断)
has_more = len(rows) > limit
if has_more:
rows = rows[:limit]
messages = []
for row in rows:
messages.append({
"id": str(row[0]),
"role": row[2],
"content": row[3],
"timestamp": row[4].isoformat() if row[4] else None
})
return {
"messages": messages,
"has_more": has_more
}
async def get_history_by_message_id(
self,
session_id: str,
last_message_id: Optional[str] = None,
limit: int = 20
) -> Dict[str, Any]:
"""
根据 last_message_id 获取更早的历史记录
Args:
session_id: 会话ID
last_message_id: 上一页最后一条消息的ID
limit: 返回的消息数量
Returns:
dict: {
"messages": [...],
"has_more": bool
}
"""
return await self.get_history(
session_id=session_id,
limit=limit,
before_id=last_message_id
)
# 全局单例
_global_manager: Optional['ChatHistoryManagerWithPool'] = None
class ChatHistoryManagerWithPool:
"""
带连接池的聊天历史管理器单例
复用 checkpoint_manager 的连接池
"""
def __init__(self):
self._pool: Optional[AsyncConnectionPool] = None
self._manager: Optional[ChatHistoryManager] = None
self._initialized = False
async def initialize(self, pool: AsyncConnectionPool) -> None:
"""初始化管理器"""
if self._initialized:
return
self._pool = pool
self._manager = ChatHistoryManager(pool)
await self._manager.create_table()
self._initialized = True
logger.info("ChatHistoryManager initialized successfully")
@property
def manager(self) -> ChatHistoryManager:
"""获取 ChatHistoryManager 实例"""
if not self._initialized or not self._manager:
raise RuntimeError("ChatHistoryManager not initialized")
return self._manager
def get_chat_history_manager() -> ChatHistoryManagerWithPool:
"""获取全局 ChatHistoryManager 单例"""
global _global_manager
if _global_manager is None:
_global_manager = ChatHistoryManagerWithPool()
return _global_manager
async def init_chat_history_manager(pool: AsyncConnectionPool) -> None:
"""初始化全局聊天历史管理器"""
manager = get_chat_history_manager()
await manager.initialize(pool)

View File

@ -68,6 +68,10 @@ class CheckpointerManager:
checkpointer = AsyncPostgresSaver(conn=conn)
await checkpointer.setup()
# 初始化 ChatHistoryManager复用同一个连接池
from .chat_history_manager import init_chat_history_manager
await init_chat_history_manager(self._pool)
self._initialized = True
logger.info("PostgreSQL checkpointer pool initialized successfully")
except Exception as e:

View File

@ -1,7 +1,7 @@
import json
import os
import asyncio
from typing import Union, Optional
from typing import Union, Optional, Any, List, Dict
from fastapi import APIRouter, HTTPException, Header
from fastapi.responses import StreamingResponse
import logging
@ -34,11 +34,18 @@ async def enhanced_generate_stream_response(
agent: LangChain agent 对象
config: AgentConfig 对象包含所有参数
"""
# 用于收集完整的响应内容,用于保存到数据库
full_response_content = []
try:
# 创建输出队列和控制事件
output_queue = asyncio.Queue()
preamble_completed = asyncio.Event()
# 在流式开始前保存用户消息
if config.session_id:
asyncio.create_task(_save_user_messages(config))
# Preamble 任务
async def preamble_task():
try:
@ -100,8 +107,11 @@ async def enhanced_generate_stream_response(
message_tag = "TOOL_RESPONSE"
new_content = f"[{message_tag}] {msg.name}\n{msg.text}\n"
# 发送内容块
# 收集完整内容
if new_content:
full_response_content.append(new_content)
# 发送内容块
if chunk_id == 0:
logger.info(f"Agent首个Token已生成, 开始流式输出")
chunk_id += 1
@ -176,6 +186,10 @@ async def enhanced_generate_stream_response(
yield "data: [DONE]\n\n"
logger.info(f"Enhanced stream response completed")
# 流式结束后保存 AI 响应
if full_response_content and config.session_id:
asyncio.create_task(_save_assistant_response(config, "".join(full_response_content)))
except Exception as e:
logger.error(f"Error in enhanced_generate_stream_response: {e}")
yield f'data: {{"error": "{str(e)}"}}\n\n'
@ -197,7 +211,7 @@ async def create_agent_and_generate_response(
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
)
agent, checkpointer = await init_agent(config)
# 使用更新后的 messages
agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
@ -244,11 +258,86 @@ async def create_agent_and_generate_response(
"total_tokens": sum(len(msg.get("content", "")) for msg in config.messages) + len(response_text)
}
)
# 保存聊天历史到数据库(与流式接口保持一致的逻辑)
await _save_user_messages(config)
await _save_assistant_response(config, response_text)
else:
raise HTTPException(status_code=500, detail="No response from agent")
return result
async def _save_user_messages(config: AgentConfig) -> None:
"""
保存最后一条用户消息用于流式和非流式接口
Args:
config: AgentConfig 对象
"""
# 只有在 session_id 存在时才保存
if not config.session_id:
return
try:
from agent.chat_history_manager import get_chat_history_manager
manager = get_chat_history_manager()
# 只保存最后一条 user 消息
for msg in reversed(config.messages):
if isinstance(msg, dict):
role = msg.get("role", "")
content = msg.get("content", "")
if role == "user" and content:
await manager.manager.save_message(
session_id=config.session_id,
role=role,
content=content,
bot_id=config.bot_id,
user_identifier=config.user_identifier
)
break # 只保存最后一条,然后退出
logger.debug(f"Saved last user message for session_id={config.session_id}")
except Exception as e:
# 保存失败不影响主流程
logger.error(f"Failed to save user messages: {e}")
async def _save_assistant_response(config: AgentConfig, assistant_response: str) -> None:
"""
保存 AI 助手的响应用于流式和非流式接口
Args:
config: AgentConfig 对象
assistant_response: AI 助手的响应内容
"""
# 只有在 session_id 存在时才保存
if not config.session_id:
return
if not assistant_response:
return
try:
from agent.chat_history_manager import get_chat_history_manager
manager = get_chat_history_manager()
# 保存 AI 助手的响应
await manager.manager.save_message(
session_id=config.session_id,
role="assistant",
content=assistant_response,
bot_id=config.bot_id,
user_identifier=config.user_identifier
)
logger.debug(f"Saved assistant response for session_id={config.session_id}")
except Exception as e:
# 保存失败不影响主流程
logger.error(f"Failed to save assistant response: {e}")
@router.post("/api/v1/chat/completions")
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
@ -563,3 +652,63 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
logger.error(f"Error in chat_completions_v2: {str(e)}")
logger.error(f"Full traceback: {error_details}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
# ============================================================================
# 聊天历史查询接口
# ============================================================================
@router.get("/api/v1/chat/history", response_model=dict)
async def get_chat_history(
session_id: str,
last_message_id: Optional[str] = None,
limit: int = 20
):
"""
获取聊天历史记录
从独立的聊天历史表查询返回完整的原始消息不受 checkpoint summary 影响
参数:
session_id: 会话ID
last_message_id: 上一页最后一条消息的ID用于获取更早的消息
limit: 每次返回的消息数量默认 20最大 100
返回:
{
"messages": [
{
"id": "唯一消息ID",
"role": "user 或 assistant",
"content": "消息内容",
"timestamp": "ISO 8601 格式的时间戳"
},
...
],
"has_more": true/false // 是否还有更多历史消息
}
"""
try:
from agent.chat_history_manager import get_chat_history_manager
# 参数验证
limit = min(max(1, limit), 100)
manager = get_chat_history_manager()
result = await manager.manager.get_history_by_message_id(
session_id=session_id,
last_message_id=last_message_id,
limit=limit
)
return {
"messages": result["messages"],
"has_more": result["has_more"]
}
except Exception as e:
import traceback
error_details = traceback.format_exc()
logger.error(f"Error in get_chat_history: {str(e)}")
logger.error(f"Full traceback: {error_details}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

View File

@ -356,7 +356,7 @@ def create_chat_response(
"""Create a chat completion response"""
import time
import uuid
return {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion",
@ -378,3 +378,28 @@ def create_chat_response(
"total_tokens": 0
}
}
# ============================================================================
# 聊天历史查询相关模型
# ============================================================================
class ChatHistoryRequest(BaseModel):
"""聊天历史查询请求"""
session_id: str = Field(..., description="会话ID (thread_id)")
last_message_id: Optional[str] = Field(None, description="上一条消息的ID用于分页查询更早的消息")
limit: int = Field(20, ge=1, le=100, description="每次查询的消息数量上限")
class ChatHistoryMessage(BaseModel):
"""聊天历史消息"""
id: str = Field(..., description="消息唯一ID")
role: str = Field(..., description="消息角色: user 或 assistant")
content: str = Field(..., description="消息内容")
timestamp: Optional[str] = Field(None, description="消息时间戳 (ISO 8601)")
class ChatHistoryResponse(BaseModel):
"""聊天历史查询响应"""
messages: List[ChatHistoryMessage] = Field(..., description="消息列表,按时间倒序排列")
has_more: bool = Field(..., description="是否还有更多历史消息")