添加聊天记录查询
This commit is contained in:
parent
fa3e30cc07
commit
f9ba3c8e51
290
agent/chat_history_manager.py
Normal file
290
agent/chat_history_manager.py
Normal file
@ -0,0 +1,290 @@
|
||||
"""
|
||||
聊天历史记录管理器
|
||||
直接保存完整的原始聊天消息到数据库,不受 checkpoint summary 影响
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
logger = logging.getLogger('app')
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""聊天消息"""
|
||||
id: str
|
||||
session_id: str
|
||||
role: str
|
||||
content: str
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ChatHistoryManager:
|
||||
"""
|
||||
聊天历史管理器
|
||||
|
||||
使用独立的数据库表存储完整的聊天历史记录
|
||||
复用 checkpoint_manager 的 PostgreSQL 连接池
|
||||
"""
|
||||
|
||||
def __init__(self, pool: AsyncConnectionPool):
|
||||
"""
|
||||
初始化聊天历史管理器
|
||||
|
||||
Args:
|
||||
pool: PostgreSQL 连接池
|
||||
"""
|
||||
self._pool = pool
|
||||
|
||||
async def create_table(self) -> None:
|
||||
"""创建 chat_messages 表"""
|
||||
async with self._pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS chat_messages (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
session_id VARCHAR(255) NOT NULL,
|
||||
role VARCHAR(50) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
||||
bot_id VARCHAR(255),
|
||||
user_identifier VARCHAR(255)
|
||||
)
|
||||
""")
|
||||
# 创建索引以加速查询
|
||||
await cursor.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_chat_messages_session_created
|
||||
ON chat_messages (session_id, created_at DESC)
|
||||
""")
|
||||
await conn.commit()
|
||||
logger.info("chat_messages table created successfully")
|
||||
|
||||
async def save_message(
|
||||
self,
|
||||
session_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
bot_id: Optional[str] = None,
|
||||
user_identifier: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
保存一条聊天消息
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
role: 消息角色 ('user' 或 'assistant')
|
||||
content: 消息内容
|
||||
bot_id: 机器人ID
|
||||
user_identifier: 用户标识
|
||||
|
||||
Returns:
|
||||
str: 消息ID
|
||||
"""
|
||||
async with self._pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("""
|
||||
INSERT INTO chat_messages (session_id, role, content, bot_id, user_identifier)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
""", (session_id, role, content, bot_id, user_identifier))
|
||||
result = await cursor.fetchone()
|
||||
await conn.commit()
|
||||
message_id = str(result[0]) if result else None
|
||||
logger.debug(f"Saved message: session_id={session_id}, role={role}, id={message_id}")
|
||||
return message_id
|
||||
|
||||
async def save_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
messages: List[Dict[str, str]],
|
||||
bot_id: Optional[str] = None,
|
||||
user_identifier: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
批量保存聊天消息
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
messages: 消息列表,每条消息包含 role 和 content
|
||||
bot_id: 机器人ID
|
||||
user_identifier: 用户标识
|
||||
|
||||
Returns:
|
||||
List[str]: 消息ID列表
|
||||
"""
|
||||
message_ids = []
|
||||
async with self._pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
for msg in messages:
|
||||
role = msg.get('role')
|
||||
content = msg.get('content', '')
|
||||
if not role or not content:
|
||||
continue
|
||||
await cursor.execute("""
|
||||
INSERT INTO chat_messages (session_id, role, content, bot_id, user_identifier)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
RETURNING id
|
||||
""", (session_id, role, content, bot_id, user_identifier))
|
||||
result = await cursor.fetchone()
|
||||
message_id = str(result[0]) if result else None
|
||||
message_ids.append(message_id)
|
||||
await conn.commit()
|
||||
logger.info(f"Saved {len(message_ids)} messages for session_id={session_id}")
|
||||
return message_ids
|
||||
|
||||
async def get_history(
|
||||
self,
|
||||
session_id: str,
|
||||
limit: int = 20,
|
||||
before_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取聊天历史记录(倒序,最新在前)
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
limit: 返回的消息数量
|
||||
before_id: 获取此消息ID之前的消息(用于分页)
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"messages": [...],
|
||||
"has_more": bool
|
||||
}
|
||||
"""
|
||||
async with self._pool.connection() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
if before_id:
|
||||
# 查询 before_id 的 created_at,然后获取更早的消息
|
||||
await cursor.execute("""
|
||||
SELECT created_at FROM chat_messages
|
||||
WHERE id = %s
|
||||
""", (before_id,))
|
||||
result = await cursor.fetchone()
|
||||
if not result:
|
||||
# 如果找不到指定的 ID,从头开始
|
||||
query = """
|
||||
SELECT id, session_id, role, content, created_at
|
||||
FROM chat_messages
|
||||
WHERE session_id = %s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s + 1
|
||||
"""
|
||||
await cursor.execute(query, (session_id, limit))
|
||||
else:
|
||||
before_time = result[0]
|
||||
query = """
|
||||
SELECT id, session_id, role, content, created_at
|
||||
FROM chat_messages
|
||||
WHERE session_id = %s AND created_at < %s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s + 1
|
||||
"""
|
||||
await cursor.execute(query, (session_id, before_time, limit))
|
||||
else:
|
||||
query = """
|
||||
SELECT id, session_id, role, content, created_at
|
||||
FROM chat_messages
|
||||
WHERE session_id = %s
|
||||
ORDER BY created_at DESC
|
||||
LIMIT %s + 1
|
||||
"""
|
||||
await cursor.execute(query, (session_id, limit))
|
||||
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
# 判断是否有更多(多取一条用于判断)
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
messages = []
|
||||
for row in rows:
|
||||
messages.append({
|
||||
"id": str(row[0]),
|
||||
"role": row[2],
|
||||
"content": row[3],
|
||||
"timestamp": row[4].isoformat() if row[4] else None
|
||||
})
|
||||
|
||||
return {
|
||||
"messages": messages,
|
||||
"has_more": has_more
|
||||
}
|
||||
|
||||
async def get_history_by_message_id(
|
||||
self,
|
||||
session_id: str,
|
||||
last_message_id: Optional[str] = None,
|
||||
limit: int = 20
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
根据 last_message_id 获取更早的历史记录
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
last_message_id: 上一页最后一条消息的ID
|
||||
limit: 返回的消息数量
|
||||
|
||||
Returns:
|
||||
dict: {
|
||||
"messages": [...],
|
||||
"has_more": bool
|
||||
}
|
||||
"""
|
||||
return await self.get_history(
|
||||
session_id=session_id,
|
||||
limit=limit,
|
||||
before_id=last_message_id
|
||||
)
|
||||
|
||||
|
||||
# 全局单例
|
||||
_global_manager: Optional['ChatHistoryManagerWithPool'] = None
|
||||
|
||||
|
||||
class ChatHistoryManagerWithPool:
|
||||
"""
|
||||
带连接池的聊天历史管理器单例
|
||||
复用 checkpoint_manager 的连接池
|
||||
"""
|
||||
def __init__(self):
|
||||
self._pool: Optional[AsyncConnectionPool] = None
|
||||
self._manager: Optional[ChatHistoryManager] = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self, pool: AsyncConnectionPool) -> None:
|
||||
"""初始化管理器"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._pool = pool
|
||||
self._manager = ChatHistoryManager(pool)
|
||||
await self._manager.create_table()
|
||||
self._initialized = True
|
||||
logger.info("ChatHistoryManager initialized successfully")
|
||||
|
||||
@property
|
||||
def manager(self) -> ChatHistoryManager:
|
||||
"""获取 ChatHistoryManager 实例"""
|
||||
if not self._initialized or not self._manager:
|
||||
raise RuntimeError("ChatHistoryManager not initialized")
|
||||
return self._manager
|
||||
|
||||
|
||||
def get_chat_history_manager() -> ChatHistoryManagerWithPool:
|
||||
"""获取全局 ChatHistoryManager 单例"""
|
||||
global _global_manager
|
||||
if _global_manager is None:
|
||||
_global_manager = ChatHistoryManagerWithPool()
|
||||
return _global_manager
|
||||
|
||||
|
||||
async def init_chat_history_manager(pool: AsyncConnectionPool) -> None:
|
||||
"""初始化全局聊天历史管理器"""
|
||||
manager = get_chat_history_manager()
|
||||
await manager.initialize(pool)
|
||||
@ -68,6 +68,10 @@ class CheckpointerManager:
|
||||
checkpointer = AsyncPostgresSaver(conn=conn)
|
||||
await checkpointer.setup()
|
||||
|
||||
# 初始化 ChatHistoryManager(复用同一个连接池)
|
||||
from .chat_history_manager import init_chat_history_manager
|
||||
await init_chat_history_manager(self._pool)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("PostgreSQL checkpointer pool initialized successfully")
|
||||
except Exception as e:
|
||||
|
||||
155
routes/chat.py
155
routes/chat.py
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
from typing import Union, Optional
|
||||
from typing import Union, Optional, Any, List, Dict
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
import logging
|
||||
@ -34,11 +34,18 @@ async def enhanced_generate_stream_response(
|
||||
agent: LangChain agent 对象
|
||||
config: AgentConfig 对象,包含所有参数
|
||||
"""
|
||||
# 用于收集完整的响应内容,用于保存到数据库
|
||||
full_response_content = []
|
||||
|
||||
try:
|
||||
# 创建输出队列和控制事件
|
||||
output_queue = asyncio.Queue()
|
||||
preamble_completed = asyncio.Event()
|
||||
|
||||
# 在流式开始前保存用户消息
|
||||
if config.session_id:
|
||||
asyncio.create_task(_save_user_messages(config))
|
||||
|
||||
# Preamble 任务
|
||||
async def preamble_task():
|
||||
try:
|
||||
@ -100,8 +107,11 @@ async def enhanced_generate_stream_response(
|
||||
message_tag = "TOOL_RESPONSE"
|
||||
new_content = f"[{message_tag}] {msg.name}\n{msg.text}\n"
|
||||
|
||||
# 发送内容块
|
||||
# 收集完整内容
|
||||
if new_content:
|
||||
full_response_content.append(new_content)
|
||||
|
||||
# 发送内容块
|
||||
if chunk_id == 0:
|
||||
logger.info(f"Agent首个Token已生成, 开始流式输出")
|
||||
chunk_id += 1
|
||||
@ -176,6 +186,10 @@ async def enhanced_generate_stream_response(
|
||||
yield "data: [DONE]\n\n"
|
||||
logger.info(f"Enhanced stream response completed")
|
||||
|
||||
# 流式结束后保存 AI 响应
|
||||
if full_response_content and config.session_id:
|
||||
asyncio.create_task(_save_assistant_response(config, "".join(full_response_content)))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in enhanced_generate_stream_response: {e}")
|
||||
yield f'data: {{"error": "{str(e)}"}}\n\n'
|
||||
@ -197,7 +211,7 @@ async def create_agent_and_generate_response(
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
|
||||
)
|
||||
|
||||
|
||||
agent, checkpointer = await init_agent(config)
|
||||
# 使用更新后的 messages
|
||||
agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS)
|
||||
@ -244,11 +258,86 @@ async def create_agent_and_generate_response(
|
||||
"total_tokens": sum(len(msg.get("content", "")) for msg in config.messages) + len(response_text)
|
||||
}
|
||||
)
|
||||
|
||||
# 保存聊天历史到数据库(与流式接口保持一致的逻辑)
|
||||
await _save_user_messages(config)
|
||||
await _save_assistant_response(config, response_text)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="No response from agent")
|
||||
|
||||
return result
|
||||
|
||||
async def _save_user_messages(config: AgentConfig) -> None:
|
||||
"""
|
||||
保存最后一条用户消息(用于流式和非流式接口)
|
||||
|
||||
Args:
|
||||
config: AgentConfig 对象
|
||||
"""
|
||||
# 只有在 session_id 存在时才保存
|
||||
if not config.session_id:
|
||||
return
|
||||
|
||||
try:
|
||||
from agent.chat_history_manager import get_chat_history_manager
|
||||
|
||||
manager = get_chat_history_manager()
|
||||
|
||||
# 只保存最后一条 user 消息
|
||||
for msg in reversed(config.messages):
|
||||
if isinstance(msg, dict):
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", "")
|
||||
if role == "user" and content:
|
||||
await manager.manager.save_message(
|
||||
session_id=config.session_id,
|
||||
role=role,
|
||||
content=content,
|
||||
bot_id=config.bot_id,
|
||||
user_identifier=config.user_identifier
|
||||
)
|
||||
break # 只保存最后一条,然后退出
|
||||
|
||||
logger.debug(f"Saved last user message for session_id={config.session_id}")
|
||||
except Exception as e:
|
||||
# 保存失败不影响主流程
|
||||
logger.error(f"Failed to save user messages: {e}")
|
||||
|
||||
|
||||
async def _save_assistant_response(config: AgentConfig, assistant_response: str) -> None:
|
||||
"""
|
||||
保存 AI 助手的响应(用于流式和非流式接口)
|
||||
|
||||
Args:
|
||||
config: AgentConfig 对象
|
||||
assistant_response: AI 助手的响应内容
|
||||
"""
|
||||
# 只有在 session_id 存在时才保存
|
||||
if not config.session_id:
|
||||
return
|
||||
|
||||
if not assistant_response:
|
||||
return
|
||||
|
||||
try:
|
||||
from agent.chat_history_manager import get_chat_history_manager
|
||||
|
||||
manager = get_chat_history_manager()
|
||||
|
||||
# 保存 AI 助手的响应
|
||||
await manager.manager.save_message(
|
||||
session_id=config.session_id,
|
||||
role="assistant",
|
||||
content=assistant_response,
|
||||
bot_id=config.bot_id,
|
||||
user_identifier=config.user_identifier
|
||||
)
|
||||
|
||||
logger.debug(f"Saved assistant response for session_id={config.session_id}")
|
||||
except Exception as e:
|
||||
# 保存失败不影响主流程
|
||||
logger.error(f"Failed to save assistant response: {e}")
|
||||
|
||||
|
||||
@router.post("/api/v1/chat/completions")
|
||||
async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)):
|
||||
@ -563,3 +652,63 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st
|
||||
logger.error(f"Error in chat_completions_v2: {str(e)}")
|
||||
logger.error(f"Full traceback: {error_details}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 聊天历史查询接口
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/api/v1/chat/history", response_model=dict)
|
||||
async def get_chat_history(
|
||||
session_id: str,
|
||||
last_message_id: Optional[str] = None,
|
||||
limit: int = 20
|
||||
):
|
||||
"""
|
||||
获取聊天历史记录
|
||||
|
||||
从独立的聊天历史表查询,返回完整的原始消息(不受 checkpoint summary 影响)
|
||||
|
||||
参数:
|
||||
session_id: 会话ID
|
||||
last_message_id: 上一页最后一条消息的ID,用于获取更早的消息
|
||||
limit: 每次返回的消息数量,默认 20,最大 100
|
||||
|
||||
返回:
|
||||
{
|
||||
"messages": [
|
||||
{
|
||||
"id": "唯一消息ID",
|
||||
"role": "user 或 assistant",
|
||||
"content": "消息内容",
|
||||
"timestamp": "ISO 8601 格式的时间戳"
|
||||
},
|
||||
...
|
||||
],
|
||||
"has_more": true/false // 是否还有更多历史消息
|
||||
}
|
||||
"""
|
||||
try:
|
||||
from agent.chat_history_manager import get_chat_history_manager
|
||||
|
||||
# 参数验证
|
||||
limit = min(max(1, limit), 100)
|
||||
|
||||
manager = get_chat_history_manager()
|
||||
result = await manager.manager.get_history_by_message_id(
|
||||
session_id=session_id,
|
||||
last_message_id=last_message_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return {
|
||||
"messages": result["messages"],
|
||||
"has_more": result["has_more"]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(f"Error in get_chat_history: {str(e)}")
|
||||
logger.error(f"Full traceback: {error_details}")
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
@ -356,7 +356,7 @@ def create_chat_response(
|
||||
"""Create a chat completion response"""
|
||||
import time
|
||||
import uuid
|
||||
|
||||
|
||||
return {
|
||||
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
"object": "chat.completion",
|
||||
@ -378,3 +378,28 @@ def create_chat_response(
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 聊天历史查询相关模型
|
||||
# ============================================================================
|
||||
|
||||
class ChatHistoryRequest(BaseModel):
|
||||
"""聊天历史查询请求"""
|
||||
session_id: str = Field(..., description="会话ID (thread_id)")
|
||||
last_message_id: Optional[str] = Field(None, description="上一条消息的ID,用于分页查询更早的消息")
|
||||
limit: int = Field(20, ge=1, le=100, description="每次查询的消息数量上限")
|
||||
|
||||
|
||||
class ChatHistoryMessage(BaseModel):
|
||||
"""聊天历史消息"""
|
||||
id: str = Field(..., description="消息唯一ID")
|
||||
role: str = Field(..., description="消息角色: user 或 assistant")
|
||||
content: str = Field(..., description="消息内容")
|
||||
timestamp: Optional[str] = Field(None, description="消息时间戳 (ISO 8601)")
|
||||
|
||||
|
||||
class ChatHistoryResponse(BaseModel):
|
||||
"""聊天历史查询响应"""
|
||||
messages: List[ChatHistoryMessage] = Field(..., description="消息列表,按时间倒序排列")
|
||||
has_more: bool = Field(..., description="是否还有更多历史消息")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user