import logging
from typing import Any, Awaitable, Callable

from langchain.agents.middleware import (
    AgentState,
    AgentMiddleware,
    ModelRequest,
    ModelResponse,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import SystemMessage, convert_to_openai_messages
from langgraph.runtime import Runtime

from agent.prompt_loader import load_guideline_prompt
from utils.fastapi_utils import (
    extract_block_from_system_prompt,
    format_messages_to_chat_history,
    get_user_last_message_content,
)
from .agent_config import AgentConfig
logger = logging.getLogger('app')
class GuidelineMiddleware(AgentMiddleware):
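    """Middleware that runs a guideline "THINK" pass before each agent turn.

    At construction time the raw system prompt is split into guideline, tool,
    scenario, and term blocks. Before each agent run, a guideline prompt is
    built from the session history and mem0 memory, the model is invoked once
    with a THINK tag, and the response is appended to the message list. The
    wrap_model_call hooks then substitute the processed system prompt for the
    main model call.
    """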
    def __init__(self, model: BaseChatModel, config: AgentConfig, prompt: str):
        self.model = model
        self.config = config  # Keep the full config for access to _mem0_context
self.bot_id = config.bot_id
processed_system_prompt, guidelines, tool_description, scenarios, terms_list = extract_block_from_system_prompt(prompt)
self.processed_system_prompt = processed_system_prompt
self.guidelines = guidelines
self.tool_description = tool_description
self.scenarios = scenarios
self.language = config.language
self.user_identifier = config.user_identifier
self.terms_list = terms_list
self.messages = config.messages
if not self.guidelines:
self.guidelines = """
1. General Inquiries
Condition: User inquiries about products, policies, troubleshooting, factual questions, etc.
Action: Priority given to invoking the 【Knowledge Base Retrieval】 tool to query the knowledge base.
        2. Social Dialogue
Condition: User intent involves small talk, greetings, expressions of thanks, compliments, or other non-substantive conversations.
Action: Provide concise, friendly, and personified natural responses.
"""
if not self.tool_description:
self.tool_description = """
- **Knowledge Base Retrieval**: For knowledge queries/other inquiries, prioritize searching the knowledge base → rag_retrieve-rag_retrieve
"""
    def get_guideline_prompt(self, config: AgentConfig) -> str:
        """Build the guideline prompt.

        Args:
            config: AgentConfig carrying _session_history and _mem0_context.

        Returns:
            str: the rendered guideline prompt, or "" when no guidelines exist.
        """
messages = convert_to_openai_messages(config._session_history)
memory_text = config._mem0_context
        # Process terms (mutates self.processed_system_prompt in place)
self.get_term_analysis(messages)
guideline_prompt = ""
if self.guidelines:
chat_history = format_messages_to_chat_history(messages)
            guideline_prompt = load_guideline_prompt(
                chat_history,
                memory_text,
                self.guidelines,
                self.tool_description,
                self.scenarios,
                self.language,
                self.user_identifier,
            )
return guideline_prompt
    def get_term_analysis(self, messages: list[dict[str, Any]]) -> str:
        """Run embedding-based term matching and splice the result into the system prompt."""
        terms_analysis = ""
if self.terms_list:
logger.info(f"Processing terms: {len(self.terms_list)} terms")
try:
from embedding.embedding import process_terms_with_embedding
query_text = get_user_last_message_content(messages)
                terms_analysis = process_terms_with_embedding(self.terms_list, self.bot_id, query_text)
if terms_analysis:
self.processed_system_prompt = self.processed_system_prompt.replace("#terms#", terms_analysis)
logger.info(f"Terms analysis completed: {len(terms_analysis)} chars")
except Exception as e:
logger.error(f"Error processing terms with embedding: {e}")
terms_analysis = ""
else:
            # When terms_list is empty, remove the corresponding pkl cache file
try:
import os
cache_file = f"projects/cache/{self.bot_id}_terms.pkl"
if os.path.exists(cache_file):
os.remove(cache_file)
logger.info(f"Removed empty terms cache file: {cache_file}")
except Exception as e:
logger.error(f"Error removing terms cache file: {e}")
return terms_analysis
    def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Run the guideline THINK pass and append its output to the messages."""
        if not self.guidelines:
            return None
        guideline_prompt = self.get_guideline_prompt(self.config)
        # Send only the guideline system message plus the latest user message
        system_message = SystemMessage(content=guideline_prompt)
        messages = [system_message, state['messages'][-1]]
        # Tag the call as THINK so downstream consumers can tell this
        # reasoning turn apart from the final answer
        response = self.model.invoke(
            messages,
            config={"metadata": {"message_tag": "THINK"}}
        )
        response.additional_kwargs["message_tag"] = "THINK"
        # Append the THINK response to the conversation
        state['messages'] = state['messages'] + [response]
        return state
    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Async variant of before_agent."""
        if not self.guidelines:
            return None
        guideline_prompt = self.get_guideline_prompt(self.config)
        # Send only the guideline system message plus the latest user message
        system_message = SystemMessage(content=guideline_prompt)
        messages = [system_message, state['messages'][-1]]
        # Tag the call as THINK (async variant)
        response = await self.model.ainvoke(
            messages,
            config={"metadata": {"message_tag": "THINK"}}
        )
        response.additional_kwargs["message_tag"] = "THINK"
        # Append the THINK response to the conversation
        state['messages'] = state['messages'] + [response]
        return state
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse:
return handler(request.override(system_prompt=self.processed_system_prompt))
    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        return await handler(request.override(system_prompt=self.processed_system_prompt))
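
# --- Usage sketch (hypothetical wiring; names below are assumptions, not part
# of this module): a minimal example of attaching the middleware to an agent.
#
# from langchain.agents import create_agent
# from langchain.chat_models import init_chat_model
#
# model = init_chat_model("openai:gpt-4o-mini")
# config = AgentConfig(...)  # must supply bot_id, language, messages, etc.
# middleware = GuidelineMiddleware(model=model, config=config, prompt=raw_system_prompt)
# agent = create_agent(model=model, tools=[], middleware=[middleware])
# result = agent.invoke({"messages": [{"role": "user", "content": "hi"}]})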