session_id

commit 0d50cd8e9f (parent d9ee1edf8a)
@@ -1,15 +1,22 @@
 import json
 import logging
+import os
+import sqlite3
 from typing import Any, Dict, Optional
 from langchain.chat_models import init_chat_model
 # from deepagents import create_deep_agent
 from langchain.agents import create_agent
+from langchain.agents.middleware import SummarizationMiddleware
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from langchain_core.callbacks import BaseCallbackHandler
+from langgraph.checkpoint.memory import MemorySaver
 from utils.fastapi_utils import detect_provider
 
 from .guideline_middleware import GuidelineMiddleware
 
+MAX_CONTEXT_TOKENS = int(os.getenv("MAX_CONTEXT_TOKENS", 65536))
+MAX_OUTPUT_TOKENS = int(os.getenv("MAX_OUTPUT_TOKENS", 8000))
+SUMMARIZATION_MAX_TOKENS = MAX_CONTEXT_TOKENS - MAX_OUTPUT_TOKENS - 1000
 
 class LoggingCallbackHandler(BaseCallbackHandler):
     """Custom CallbackHandler that uses the project's logger for logging."""
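With the default environment values, the new summarization threshold reserves the model's output budget plus a 1,000-token safety margin out of the context window:

```python
# Assuming the defaults above: history is summarized once it approaches
# 65536 (context window) - 8000 (reserved output) - 1000 (safety margin)
SUMMARIZATION_MAX_TOKENS = 65536 - 8000 - 1000  # = 56536 tokens
```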
@@ -120,7 +127,24 @@ async def get_tools_from_mcp(mcp):
 
 async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
                      model_server=None, generate_cfg=None,
-                     system_prompt=None, mcp=None, robot_type=None, language="jp", user_identifier=None):
+                     system_prompt=None, mcp=None, robot_type=None, language="jp", user_identifier=None,
+                     session_id=None):
+    """
+    Initialize the agent, with support for persistent memory and conversation summarization.
+
+    Args:
+        bot_id: Bot ID
+        model_name: Model name
+        api_key: API key
+        model_server: Model server address
+        generate_cfg: Generation config
+        system_prompt: System prompt
+        mcp: MCP config
+        robot_type: Robot type
+        language: Language
+        user_identifier: User identifier
+        session_id: Session ID (if None, persistent memory is disabled)
+    """
     system = system_prompt if system_prompt else read_system_prompt()
     mcp = mcp if mcp else read_mcp_settings()
     mcp_tools = await get_tools_from_mcp(mcp)
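Since `session_id` defaults to `None`, existing callers keep the old stateless behavior; persistence is strictly opt-in. A minimal sketch (the bot and session IDs are illustrative):

```python
# Opt in: checkpointing and conversation summarization are enabled.
agent = await init_agent("bot-123", session_id="sess-42")

# Default: no checkpointer, same behavior as before this commit.
agent = await init_agent("bot-123")
```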
@@ -143,14 +167,33 @@ async def init_agent(bot_id: str, model_name="qwen3-next", api_key=None,
     # Create the custom logging handler
     logging_handler = LoggingCallbackHandler()
 
+    # Build the middleware list
+    middleware = [GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)]
+
+    # Initialize the checkpointer and middleware
+    checkpointer = None
+
+    if session_id:
+        checkpointer = MemorySaver()
+        summarization_middleware = SummarizationMiddleware(
+            model=llm_instance,
+            max_tokens_before_summary=SUMMARIZATION_MAX_TOKENS,
+            messages_to_keep=20,  # keep the 20 most recent messages after summarization
+            summary_prompt="请简洁地总结以上对话的要点,包括重要的用户信息、讨论过的话题和关键结论。"  # "Concisely summarize the key points of the conversation above, including important user information, topics discussed, and key conclusions."
+        )
+        middleware.append(summarization_middleware)
+
     agent = create_agent(
         model=llm_instance,
         system_prompt=system,
         tools=mcp_tools,
-        middleware=[GuidelineMiddleware(bot_id, llm_instance, system, robot_type, language, user_identifier)]
+        middleware=middleware,
+        checkpointer=checkpointer  # pass the checkpointer to enable persistence
     )
 
-    # Store the handler as an attribute on the agent for easy access at call time
+    # Store the handler and checkpointer as attributes on the agent for easy access at call time
     agent.logging_handler = logging_handler
+    agent.checkpointer = checkpointer
+    agent.bot_id = bot_id
+    agent.session_id = session_id
     return agent
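`MemorySaver` keeps checkpoints in process memory, so conversation state is lost on restart and is not shared across workers. The newly added, so far unused `import sqlite3` suggests a durable backend may be the eventual goal; a possible swap, assuming the `langgraph-checkpoint-sqlite` package is installed (async endpoints would want its `AsyncSqliteSaver` counterpart):

```python
import sqlite3
from langgraph.checkpoint.sqlite import SqliteSaver  # assumed extra dependency, not in this diff

# Drop-in replacement for MemorySaver(); checkpoints survive process restarts.
conn = sqlite3.connect("checkpoints.db", check_same_thread=False)
checkpointer = SqliteSaver(conn)
```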
@@ -127,7 +127,8 @@ class ShardedAgentManager:
                             system_prompt: Optional[str] = None,
                             mcp_settings: Optional[List[Dict]] = None,
                             robot_type: Optional[str] = "general_agent",
-                            user_identifier: Optional[str] = None):
+                            user_identifier: Optional[str] = None,
+                            session_id: Optional[str] = None):
         """Get or create a file-preloaded assistant instance."""
 
         # Update request statistics

@@ -201,6 +202,7 @@ class ShardedAgentManager:
             robot_type=robot_type,
             language=language,
             user_identifier=user_identifier,
+            session_id=session_id
         )
 
         # Cache the instance
@@ -82,7 +82,8 @@ async def enhanced_generate_stream_response(
     robot_type: str,
     project_dir: Optional[str],
    generate_cfg: Optional[dict],
-    user_identifier: Optional[str]
+    user_identifier: Optional[str],
+    session_id: Optional[str] = None,
 ):
     """Enhanced progressive streaming response generator - concurrency-optimized version."""
     try:

@@ -133,7 +134,8 @@ async def enhanced_generate_stream_response(
             system_prompt=system_prompt,
             mcp_settings=mcp_settings,
             robot_type=robot_type,
-            user_identifier=user_identifier
+            user_identifier=user_identifier,
+            session_id=session_id,
         )
 
         # Start streaming
@@ -141,7 +143,11 @@ async def enhanced_generate_stream_response(
         chunk_id = 0
         message_tag = ""
 
-        config = {"callbacks": [agent.logging_handler]} if hasattr(agent, 'logging_handler') else {}
+        config = {}
+        if session_id:
+            config["configurable"] = {"thread_id": session_id}
+        if hasattr(agent, 'logging_handler'):
+            config["callbacks"] = [agent.logging_handler]
         async for msg, metadata in agent.astream({"messages": messages}, stream_mode="messages", config=config):
             new_content = ""
 
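`thread_id` is the key under which the checkpointer stores each conversation, so one cached agent can serve many sessions with isolated histories. A sketch, assuming an agent built by `init_agent(..., session_id=...)` so a checkpointer is attached:

```python
cfg_a = {"configurable": {"thread_id": "session-a"}}
cfg_b = {"configurable": {"thread_id": "session-b"}}

# Each thread accumulates its own message history in the checkpointer.
await agent.ainvoke({"messages": [{"role": "user", "content": "My name is Alice."}]}, config=cfg_a)
await agent.ainvoke({"messages": [{"role": "user", "content": "What is my name?"}]}, config=cfg_a)  # recalls "Alice"
await agent.ainvoke({"messages": [{"role": "user", "content": "What is my name?"}]}, config=cfg_b)  # fresh thread, no history
```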
@@ -265,7 +271,8 @@ async def create_agent_and_generate_response(
     robot_type: str,
     project_dir: Optional[str] = None,
     generate_cfg: Optional[dict] = None,
-    user_identifier: Optional[str] = None
+    user_identifier: Optional[str] = None,
+    session_id: Optional[str] = None
 ) -> Union[ChatResponse, StreamingResponse]:
     """Shared logic for creating an agent and generating a response."""
     if generate_cfg is None:

@@ -288,7 +295,8 @@ async def create_agent_and_generate_response(
                 robot_type=robot_type,
                 project_dir=project_dir,
                 generate_cfg=generate_cfg,
-                user_identifier=user_identifier
+                user_identifier=user_identifier,
+                session_id=session_id
             ),
             media_type="text/event-stream",
             headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
@@ -307,14 +315,19 @@ async def create_agent_and_generate_response(
         system_prompt=system_prompt,
         mcp_settings=mcp_settings,
         robot_type=robot_type,
-        user_identifier=user_identifier
+        user_identifier=user_identifier,
+        session_id=session_id,
     )
 
     # Prepare the final messages
     final_messages = messages.copy()
 
     # Non-streaming response
-    config = {"callbacks": [agent.logging_handler]} if hasattr(agent, 'logging_handler') else {}
+    config = {}
+    if session_id:
+        config["configurable"] = {"thread_id": session_id}
+    if hasattr(agent, 'logging_handler'):
+        config["callbacks"] = [agent.logging_handler]
     agent_responses = await agent.ainvoke({"messages": final_messages}, config=config)
     append_messages = agent_responses["messages"][len(final_messages):]
     response_text = ""
@@ -396,7 +409,7 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
     project_dir = create_project_directory(request.dataset_ids, bot_id, request.robot_type)
 
     # Collect the extra parameters as generate_cfg
-    exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier'}
+    exclude_fields = {'messages', 'model', 'model_server', 'dataset_ids', 'language', 'tool_response', 'system_prompt', 'mcp_settings', 'stream', 'robot_type', 'bot_id', 'user_identifier', 'session_id'}
     generate_cfg = {k: v for k, v in request.model_dump().items() if k not in exclude_fields}
 
     # Process the messages
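Adding `'session_id'` to the exclusion set matters because everything left over in `generate_cfg` is forwarded to the model as generation parameters; without the new entry, the session ID would leak into the model call. A reduced illustration:

```python
fields = {"model": "qwen3-next", "temperature": 0.7, "stream": True, "session_id": "sess-42"}
exclude_fields = {"model", "stream", "session_id"}  # trimmed set for the example
generate_cfg = {k: v for k, v in fields.items() if k not in exclude_fields}
assert generate_cfg == {"temperature": 0.7}  # session_id is routed to the checkpointer instead
```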
@@ -417,7 +430,8 @@ async def chat_completions(request: ChatRequest, authorization: Optional[str] =
             robot_type=request.robot_type,
             project_dir=project_dir,
             generate_cfg=generate_cfg,
-            user_identifier=request.user_identifier
+            user_identifier=request.user_identifier,
+            session_id=request.session_id
         )
 
     except Exception as e:
@@ -52,6 +52,7 @@ class ChatRequest(BaseModel):
     mcp_settings: Optional[List[Dict]] = None
     robot_type: Optional[str] = "general_agent"
     user_identifier: Optional[str] = ""
+    session_id: Optional[str] = None
 
 
 class ChatRequestV2(BaseModel):

@@ -61,6 +62,7 @@ class ChatRequestV2(BaseModel):
     bot_id: str
     language: Optional[str] = "zh"
     user_identifier: Optional[str] = ""
+    session_id: Optional[str] = None
 
 
 class FileProcessRequest(BaseModel):
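A client opts into persistent memory by sending the same `session_id` with every turn. An illustrative call (the endpoint path, host, and field values are assumptions, not shown in this diff):

```python
import requests

payload = {
    "model": "qwen3-next",
    "bot_id": "bot-123",
    "session_id": "sess-42",  # reuse this value on every turn of the conversation
    "messages": [{"role": "user", "content": "Hi again!"}],
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
```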