From f9ba3c8e51199cc3837bb91b48b4ecf6ad3ad6da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=B1=E6=BD=AE?= Date: Sun, 18 Jan 2026 12:29:20 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=81=8A=E5=A4=A9=E8=AE=B0?= =?UTF-8?q?=E5=BD=95=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent/chat_history_manager.py | 290 ++++++++++++++++++++++++++++++++++ agent/checkpoint_manager.py | 4 + routes/chat.py | 155 +++++++++++++++++- utils/api_models.py | 27 +++- 4 files changed, 472 insertions(+), 4 deletions(-) create mode 100644 agent/chat_history_manager.py diff --git a/agent/chat_history_manager.py b/agent/chat_history_manager.py new file mode 100644 index 0000000..f2e2604 --- /dev/null +++ b/agent/chat_history_manager.py @@ -0,0 +1,290 @@ +""" +聊天历史记录管理器 +直接保存完整的原始聊天消息到数据库,不受 checkpoint summary 影响 +""" + +import logging +from datetime import datetime +from typing import Optional, List, Dict, Any +from dataclasses import dataclass + +from psycopg_pool import AsyncConnectionPool + +logger = logging.getLogger('app') + + +@dataclass +class ChatMessage: + """聊天消息""" + id: str + session_id: str + role: str + content: str + created_at: datetime + + +class ChatHistoryManager: + """ + 聊天历史管理器 + + 使用独立的数据库表存储完整的聊天历史记录 + 复用 checkpoint_manager 的 PostgreSQL 连接池 + """ + + def __init__(self, pool: AsyncConnectionPool): + """ + 初始化聊天历史管理器 + + Args: + pool: PostgreSQL 连接池 + """ + self._pool = pool + + async def create_table(self) -> None: + """创建 chat_messages 表""" + async with self._pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + CREATE TABLE IF NOT EXISTS chat_messages ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + session_id VARCHAR(255) NOT NULL, + role VARCHAR(50) NOT NULL, + content TEXT NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), + bot_id VARCHAR(255), + user_identifier VARCHAR(255) + ) + """) + # 创建索引以加速查询 + await cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_chat_messages_session_created + ON chat_messages (session_id, created_at DESC) + """) + await conn.commit() + logger.info("chat_messages table created successfully") + + async def save_message( + self, + session_id: str, + role: str, + content: str, + bot_id: Optional[str] = None, + user_identifier: Optional[str] = None + ) -> str: + """ + 保存一条聊天消息 + + Args: + session_id: 会话ID + role: 消息角色 ('user' 或 'assistant') + content: 消息内容 + bot_id: 机器人ID + user_identifier: 用户标识 + + Returns: + str: 消息ID + """ + async with self._pool.connection() as conn: + async with conn.cursor() as cursor: + await cursor.execute(""" + INSERT INTO chat_messages (session_id, role, content, bot_id, user_identifier) + VALUES (%s, %s, %s, %s, %s) + RETURNING id + """, (session_id, role, content, bot_id, user_identifier)) + result = await cursor.fetchone() + await conn.commit() + message_id = str(result[0]) if result else None + logger.debug(f"Saved message: session_id={session_id}, role={role}, id={message_id}") + return message_id + + async def save_messages( + self, + session_id: str, + messages: List[Dict[str, str]], + bot_id: Optional[str] = None, + user_identifier: Optional[str] = None + ) -> List[str]: + """ + 批量保存聊天消息 + + Args: + session_id: 会话ID + messages: 消息列表,每条消息包含 role 和 content + bot_id: 机器人ID + user_identifier: 用户标识 + + Returns: + List[str]: 消息ID列表 + """ + message_ids = [] + async with self._pool.connection() as conn: + async with conn.cursor() as cursor: + for msg in messages: + role = msg.get('role') + content = msg.get('content', '') + if not role or not content: + continue + await cursor.execute(""" + INSERT INTO chat_messages (session_id, role, content, bot_id, user_identifier) + VALUES (%s, %s, %s, %s, %s) + RETURNING id + """, (session_id, role, content, bot_id, user_identifier)) + result = await cursor.fetchone() + message_id = str(result[0]) if result else None + message_ids.append(message_id) + await conn.commit() + logger.info(f"Saved {len(message_ids)} messages for session_id={session_id}") + return message_ids + + async def get_history( + self, + session_id: str, + limit: int = 20, + before_id: Optional[str] = None + ) -> Dict[str, Any]: + """ + 获取聊天历史记录(倒序,最新在前) + + Args: + session_id: 会话ID + limit: 返回的消息数量 + before_id: 获取此消息ID之前的消息(用于分页) + + Returns: + dict: { + "messages": [...], + "has_more": bool + } + """ + async with self._pool.connection() as conn: + async with conn.cursor() as cursor: + if before_id: + # 查询 before_id 的 created_at,然后获取更早的消息 + await cursor.execute(""" + SELECT created_at FROM chat_messages + WHERE id = %s + """, (before_id,)) + result = await cursor.fetchone() + if not result: + # 如果找不到指定的 ID,从头开始 + query = """ + SELECT id, session_id, role, content, created_at + FROM chat_messages + WHERE session_id = %s + ORDER BY created_at DESC + LIMIT %s + 1 + """ + await cursor.execute(query, (session_id, limit)) + else: + before_time = result[0] + query = """ + SELECT id, session_id, role, content, created_at + FROM chat_messages + WHERE session_id = %s AND created_at < %s + ORDER BY created_at DESC + LIMIT %s + 1 + """ + await cursor.execute(query, (session_id, before_time, limit)) + else: + query = """ + SELECT id, session_id, role, content, created_at + FROM chat_messages + WHERE session_id = %s + ORDER BY created_at DESC + LIMIT %s + 1 + """ + await cursor.execute(query, (session_id, limit)) + + rows = await cursor.fetchall() + + # 判断是否有更多(多取一条用于判断) + has_more = len(rows) > limit + if has_more: + rows = rows[:limit] + + messages = [] + for row in rows: + messages.append({ + "id": str(row[0]), + "role": row[2], + "content": row[3], + "timestamp": row[4].isoformat() if row[4] else None + }) + + return { + "messages": messages, + "has_more": has_more + } + + async def get_history_by_message_id( + self, + session_id: str, + last_message_id: Optional[str] = None, + limit: int = 20 + ) -> Dict[str, Any]: + """ + 根据 last_message_id 获取更早的历史记录 + + Args: + session_id: 会话ID + last_message_id: 上一页最后一条消息的ID + limit: 返回的消息数量 + + Returns: + dict: { + "messages": [...], + "has_more": bool + } + """ + return await self.get_history( + session_id=session_id, + limit=limit, + before_id=last_message_id + ) + + +# 全局单例 +_global_manager: Optional['ChatHistoryManagerWithPool'] = None + + +class ChatHistoryManagerWithPool: + """ + 带连接池的聊天历史管理器单例 + 复用 checkpoint_manager 的连接池 + """ + def __init__(self): + self._pool: Optional[AsyncConnectionPool] = None + self._manager: Optional[ChatHistoryManager] = None + self._initialized = False + + async def initialize(self, pool: AsyncConnectionPool) -> None: + """初始化管理器""" + if self._initialized: + return + + self._pool = pool + self._manager = ChatHistoryManager(pool) + await self._manager.create_table() + self._initialized = True + logger.info("ChatHistoryManager initialized successfully") + + @property + def manager(self) -> ChatHistoryManager: + """获取 ChatHistoryManager 实例""" + if not self._initialized or not self._manager: + raise RuntimeError("ChatHistoryManager not initialized") + return self._manager + + +def get_chat_history_manager() -> ChatHistoryManagerWithPool: + """获取全局 ChatHistoryManager 单例""" + global _global_manager + if _global_manager is None: + _global_manager = ChatHistoryManagerWithPool() + return _global_manager + + +async def init_chat_history_manager(pool: AsyncConnectionPool) -> None: + """初始化全局聊天历史管理器""" + manager = get_chat_history_manager() + await manager.initialize(pool) diff --git a/agent/checkpoint_manager.py b/agent/checkpoint_manager.py index 7627267..8b96e27 100644 --- a/agent/checkpoint_manager.py +++ b/agent/checkpoint_manager.py @@ -68,6 +68,10 @@ class CheckpointerManager: checkpointer = AsyncPostgresSaver(conn=conn) await checkpointer.setup() + # 初始化 ChatHistoryManager(复用同一个连接池) + from .chat_history_manager import init_chat_history_manager + await init_chat_history_manager(self._pool) + self._initialized = True logger.info("PostgreSQL checkpointer pool initialized successfully") except Exception as e: diff --git a/routes/chat.py b/routes/chat.py index e69dc69..c41b378 100644 --- a/routes/chat.py +++ b/routes/chat.py @@ -1,7 +1,7 @@ import json import os import asyncio -from typing import Union, Optional +from typing import Union, Optional, Any, List, Dict from fastapi import APIRouter, HTTPException, Header from fastapi.responses import StreamingResponse import logging @@ -34,11 +34,18 @@ async def enhanced_generate_stream_response( agent: LangChain agent 对象 config: AgentConfig 对象,包含所有参数 """ + # 用于收集完整的响应内容,用于保存到数据库 + full_response_content = [] + try: # 创建输出队列和控制事件 output_queue = asyncio.Queue() preamble_completed = asyncio.Event() + # 在流式开始前保存用户消息 + if config.session_id: + asyncio.create_task(_save_user_messages(config)) + # Preamble 任务 async def preamble_task(): try: @@ -100,8 +107,11 @@ async def enhanced_generate_stream_response( message_tag = "TOOL_RESPONSE" new_content = f"[{message_tag}] {msg.name}\n{msg.text}\n" - # 发送内容块 + # 收集完整内容 if new_content: + full_response_content.append(new_content) + + # 发送内容块 if chunk_id == 0: logger.info(f"Agent首个Token已生成, 开始流式输出") chunk_id += 1 @@ -176,6 +186,10 @@ async def enhanced_generate_stream_response( yield "data: [DONE]\n\n" logger.info(f"Enhanced stream response completed") + # 流式结束后保存 AI 响应 + if full_response_content and config.session_id: + asyncio.create_task(_save_assistant_response(config, "".join(full_response_content))) + except Exception as e: logger.error(f"Error in enhanced_generate_stream_response: {e}") yield f'data: {{"error": "{str(e)}"}}\n\n' @@ -197,7 +211,7 @@ async def create_agent_and_generate_response( media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"} ) - + agent, checkpointer = await init_agent(config) # 使用更新后的 messages agent_responses = await agent.ainvoke({"messages": config.messages}, config=config.invoke_config(), max_tokens=MAX_OUTPUT_TOKENS) @@ -244,11 +258,86 @@ async def create_agent_and_generate_response( "total_tokens": sum(len(msg.get("content", "")) for msg in config.messages) + len(response_text) } ) + + # 保存聊天历史到数据库(与流式接口保持一致的逻辑) + await _save_user_messages(config) + await _save_assistant_response(config, response_text) else: raise HTTPException(status_code=500, detail="No response from agent") return result +async def _save_user_messages(config: AgentConfig) -> None: + """ + 保存最后一条用户消息(用于流式和非流式接口) + + Args: + config: AgentConfig 对象 + """ + # 只有在 session_id 存在时才保存 + if not config.session_id: + return + + try: + from agent.chat_history_manager import get_chat_history_manager + + manager = get_chat_history_manager() + + # 只保存最后一条 user 消息 + for msg in reversed(config.messages): + if isinstance(msg, dict): + role = msg.get("role", "") + content = msg.get("content", "") + if role == "user" and content: + await manager.manager.save_message( + session_id=config.session_id, + role=role, + content=content, + bot_id=config.bot_id, + user_identifier=config.user_identifier + ) + break # 只保存最后一条,然后退出 + + logger.debug(f"Saved last user message for session_id={config.session_id}") + except Exception as e: + # 保存失败不影响主流程 + logger.error(f"Failed to save user messages: {e}") + + +async def _save_assistant_response(config: AgentConfig, assistant_response: str) -> None: + """ + 保存 AI 助手的响应(用于流式和非流式接口) + + Args: + config: AgentConfig 对象 + assistant_response: AI 助手的响应内容 + """ + # 只有在 session_id 存在时才保存 + if not config.session_id: + return + + if not assistant_response: + return + + try: + from agent.chat_history_manager import get_chat_history_manager + + manager = get_chat_history_manager() + + # 保存 AI 助手的响应 + await manager.manager.save_message( + session_id=config.session_id, + role="assistant", + content=assistant_response, + bot_id=config.bot_id, + user_identifier=config.user_identifier + ) + + logger.debug(f"Saved assistant response for session_id={config.session_id}") + except Exception as e: + # 保存失败不影响主流程 + logger.error(f"Failed to save assistant response: {e}") + @router.post("/api/v1/chat/completions") async def chat_completions(request: ChatRequest, authorization: Optional[str] = Header(None)): @@ -563,3 +652,63 @@ async def chat_completions_v2(request: ChatRequestV2, authorization: Optional[st logger.error(f"Error in chat_completions_v2: {str(e)}") logger.error(f"Full traceback: {error_details}") raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + + +# ============================================================================ +# 聊天历史查询接口 +# ============================================================================ + +@router.get("/api/v1/chat/history", response_model=dict) +async def get_chat_history( + session_id: str, + last_message_id: Optional[str] = None, + limit: int = 20 +): + """ + 获取聊天历史记录 + + 从独立的聊天历史表查询,返回完整的原始消息(不受 checkpoint summary 影响) + + 参数: + session_id: 会话ID + last_message_id: 上一页最后一条消息的ID,用于获取更早的消息 + limit: 每次返回的消息数量,默认 20,最大 100 + + 返回: + { + "messages": [ + { + "id": "唯一消息ID", + "role": "user 或 assistant", + "content": "消息内容", + "timestamp": "ISO 8601 格式的时间戳" + }, + ... + ], + "has_more": true/false // 是否还有更多历史消息 + } + """ + try: + from agent.chat_history_manager import get_chat_history_manager + + # 参数验证 + limit = min(max(1, limit), 100) + + manager = get_chat_history_manager() + result = await manager.manager.get_history_by_message_id( + session_id=session_id, + last_message_id=last_message_id, + limit=limit + ) + + return { + "messages": result["messages"], + "has_more": result["has_more"] + } + + except Exception as e: + import traceback + error_details = traceback.format_exc() + logger.error(f"Error in get_chat_history: {str(e)}") + logger.error(f"Full traceback: {error_details}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") diff --git a/utils/api_models.py b/utils/api_models.py index 1ea2d91..2db59d3 100644 --- a/utils/api_models.py +++ b/utils/api_models.py @@ -356,7 +356,7 @@ def create_chat_response( """Create a chat completion response""" import time import uuid - + return { "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", "object": "chat.completion", @@ -378,3 +378,28 @@ def create_chat_response( "total_tokens": 0 } } + + +# ============================================================================ +# 聊天历史查询相关模型 +# ============================================================================ + +class ChatHistoryRequest(BaseModel): + """聊天历史查询请求""" + session_id: str = Field(..., description="会话ID (thread_id)") + last_message_id: Optional[str] = Field(None, description="上一条消息的ID,用于分页查询更早的消息") + limit: int = Field(20, ge=1, le=100, description="每次查询的消息数量上限") + + +class ChatHistoryMessage(BaseModel): + """聊天历史消息""" + id: str = Field(..., description="消息唯一ID") + role: str = Field(..., description="消息角色: user 或 assistant") + content: str = Field(..., description="消息内容") + timestamp: Optional[str] = Field(None, description="消息时间戳 (ISO 8601)") + + +class ChatHistoryResponse(BaseModel): + """聊天历史查询响应""" + messages: List[ChatHistoryMessage] = Field(..., description="消息列表,按时间倒序排列") + has_more: bool = Field(..., description="是否还有更多历史消息")