# Listing metadata: 181 lines, 8.4 KiB, Python
from langchain.agents.middleware import AgentState, AgentMiddleware, ModelRequest, ModelResponse
|
||
from langchain_core.messages import convert_to_openai_messages
|
||
from agent.prompt_loader import load_guideline_prompt
|
||
from utils.fastapi_utils import (extract_block_from_system_prompt, format_messages_to_chat_history, get_user_last_message_content)
|
||
from langchain.chat_models import BaseChatModel
|
||
from langgraph.runtime import Runtime
|
||
|
||
from langchain_core.messages import SystemMessage, HumanMessage
|
||
from typing import Any, Callable
|
||
from langchain_core.callbacks import BaseCallbackHandler
|
||
from langchain_core.outputs import LLMResult
|
||
from .agent_config import AgentConfig
|
||
import logging
|
||
import re
|
||
|
||
# Module-level logger, looked up under the fixed name 'app' rather than __name__.
logger = logging.getLogger('app')
|
||
|
||
|
||
class GuidelineMiddleware(AgentMiddleware):
    """Agent middleware that runs a guideline "thinking" pass before the agent.

    On construction the raw system prompt is split by
    ``extract_block_from_system_prompt`` into the processed system prompt,
    guidelines, tool description, scenarios and a terms list. Before the agent
    runs, the model is invoked once with a guideline prompt plus the latest
    message; its answer is wrapped in ``<think>...</think>``, tagged ``THINK``
    and appended to the conversation, followed by a short human follow-up
    message. Every model call made by the agent then uses the processed system
    prompt.
    """

    # Fallback guidelines used when the system prompt does not define any.
    # NOTE(review): the continuation lines of the two literals below are kept
    # flush-left; confirm the intended leading whitespace against the deployed
    # prompt text.
    _DEFAULT_GUIDELINES = """
1. General Inquiries
Condition: User inquiries about products, policies, troubleshooting, factual questions, definitions, workflows, data lookups, or other knowledge-seeking requests.
Action: First choose the most suitable 【Knowledge Base Retrieval】 tool by scenario. Use table_rag_retrieve first for structured data, lists, statistics, comparisons, extraction, mixed requests, or unclear cases. Use rag_retrieve first only for clearly pure concept / definition / workflow / policy explanation questions. If the first retrieval result is empty, errored, irrelevant, or only partially answers the request, call the other retrieval tool before replying. Only reply that no relevant information was found after both retrieval tools have been tried and still provide no sufficient evidence.

2.Social Dialogue
Condition: User intent involves small talk, greetings, expressions of thanks, compliments, or other non-substantive conversations.
Action: Provide concise, friendly, and personified natural responses.
"""

    # Fallback tool description used when the system prompt does not define one.
    _DEFAULT_TOOL_DESCRIPTION = """
- **Knowledge Base Retrieval**: Choose retrieval order by scenario. Default to `table_rag_retrieve -> rag_retrieve` for structured, list, mixed, or unclear requests. Use `rag_retrieve -> table_rag_retrieve` only for clearly pure concept or workflow questions. Do not answer with "no result" until both tools have been tried when retrieval is needed.
"""

    # Follow-up prompts keyed by language code; "en" is the fallback.
    _FOLLOW_UP_PROMPTS = {
        "ja": "以上の分析に基づいて、ユーザーに返信してください。",
        "jp": "以上の分析に基づいて、ユーザーに返信してください。",
        "zh": "请根据以上分析,回复用户。",
        "zh-TW": "請根據以上分析,回覆用戶。",
        "ko": "위 분석을 바탕으로 사용자에게 답변해 주세요.",
        "en": "Based on the above analysis, please respond to the user.",
    }

    # Metadata attached to the guideline model call and to its response.
    _THINK_TAG = "THINK"

    def __init__(self, model: BaseChatModel, config: AgentConfig, prompt: str):
        """Split ``prompt`` into its blocks and store per-agent settings.

        Args:
            model: Chat model used for the guideline "thinking" call.
            config: Agent configuration; kept in full so that
                ``_session_history`` and ``_mem0_context`` can be read later.
            prompt: Raw system prompt handed to
                ``extract_block_from_system_prompt``.
        """
        self.model = model
        self.config = config  # keep the full config to reach _mem0_context later
        self.bot_id = config.bot_id

        (
            processed_system_prompt,
            guidelines,
            tool_description,
            scenarios,
            terms_list,
        ) = extract_block_from_system_prompt(prompt)

        self.processed_system_prompt = processed_system_prompt
        self.guidelines = guidelines
        self.tool_description = tool_description
        self.scenarios = scenarios

        self.language = config.language
        self.user_identifier = config.user_identifier

        self.terms_list = terms_list
        self.messages = config.messages

        # Fall back to the built-in defaults when the prompt supplied none.
        if not self.guidelines:
            self.guidelines = self._DEFAULT_GUIDELINES
        if not self.tool_description:
            self.tool_description = self._DEFAULT_TOOL_DESCRIPTION

    def get_guideline_prompt(self, config: AgentConfig) -> str:
        """Build the guideline prompt.

        Also runs term analysis, which may rewrite
        ``self.processed_system_prompt`` as a side effect.

        Args:
            config: AgentConfig object providing ``_session_history`` and
                ``_mem0_context``.

        Returns:
            str: The generated guideline prompt, or "" when no guidelines
            are set.
        """
        messages = convert_to_openai_messages(config._session_history)
        memory_text = config._mem0_context

        # Handle terms (this updates self.processed_system_prompt).
        self.get_term_analysis(messages)

        if not self.guidelines:
            return ""

        chat_history = format_messages_to_chat_history(messages)
        return load_guideline_prompt(
            chat_history,
            memory_text,
            self.guidelines,
            self.tool_description,
            self.scenarios,
            self.language,
            self.user_identifier,
        )

    def get_term_analysis(self, messages: list[dict[str, Any]]) -> str:
        """Run term analysis and inject it into the processed system prompt.

        When a terms list is configured, the last user message is used as the
        query for ``process_terms_with_embedding``; a non-empty result replaces
        the ``#terms#`` placeholder in ``self.processed_system_prompt``. When
        no terms are configured, the bot's terms cache file is removed instead.

        Args:
            messages: Conversation messages in OpenAI dict format.

        Returns:
            str: The terms analysis text, or "" when there are no terms or the
            analysis failed (failures are logged, never raised).
        """
        if not self.terms_list:
            # No terms configured: drop the stale pkl cache file, if any.
            self._remove_terms_cache()
            return ""

        logger.info("Processing terms: %s terms", len(self.terms_list))
        try:
            from embedding.embedding import process_terms_with_embedding

            query_text = get_user_last_message_content(messages)
            # Use the instance's terms list. The bare name `terms_list` does
            # not exist in this scope and raised a NameError that the broad
            # except below swallowed, so the analysis never ran.
            terms_analysis = process_terms_with_embedding(self.terms_list, self.bot_id, query_text)
            if terms_analysis:
                self.processed_system_prompt = self.processed_system_prompt.replace("#terms#", terms_analysis)
                logger.info("Terms analysis completed: %s chars", len(terms_analysis))
        except Exception as e:
            # Best effort: a failed analysis must not break the agent run.
            logger.error("Error processing terms with embedding: %s", e)
            terms_analysis = ""
        return terms_analysis

    def _remove_terms_cache(self) -> None:
        """Delete this bot's terms cache file if it exists; log other failures."""
        import os

        cache_file = f"projects/cache/{self.bot_id}_terms.pkl"
        try:
            # Remove directly instead of checking existence first, so a file
            # vanishing between the check and the delete is not an error.
            os.remove(cache_file)
        except FileNotFoundError:
            return
        except Exception as e:
            logger.error("Error removing terms cache file: %s", e)
            return
        logger.info("Removed empty terms cache file: %s", cache_file)

    def _build_think_messages(self, state: AgentState) -> list | None:
        """Return ``[guideline system message, latest message]`` or None.

        None is returned when the state holds no messages, because there is
        nothing to analyse and indexing the last message would raise.
        """
        history = state.get('messages')
        if not history:
            logger.warning("No messages in state; skipping guideline analysis")
            return None

        guideline_prompt = self.get_guideline_prompt(self.config)
        # Only the guideline prompt and the latest message are sent; earlier
        # turns reach the model through the chat history inside the prompt.
        return [SystemMessage(content=guideline_prompt), history[-1]]

    def _append_think_response(self, state: AgentState, response: Any) -> dict[str, Any]:
        """Tag the model response as THINK and append it to the state messages."""
        response.additional_kwargs["message_tag"] = self._THINK_TAG
        response.content = f"<think>{response.content}</think>"

        # Append the response, then a HumanMessage so the list ends with a user
        # turn: some models do not support assistant message prefill and
        # require the last message to come from the user.
        follow_up = HumanMessage(content=self._get_follow_up_prompt())
        state['messages'] = state['messages'] + [response, follow_up]
        return state

    def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Run the guideline thinking pass synchronously.

        Args:
            state: Current agent state; its ``messages`` list is extended.
            runtime: Runtime passed by the framework (unused here).

        Returns:
            The updated state, or None when there are no guidelines or no
            messages to analyse.
        """
        if not self.guidelines:
            return None

        think_messages = self._build_think_messages(state)
        if think_messages is None:
            return None

        response = self.model.invoke(
            think_messages,
            config={"metadata": {"message_tag": self._THINK_TAG}},
        )
        return self._append_think_response(state, response)

    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Async counterpart of ``before_agent``.

        NOTE(review): the guideline prompt is still built synchronously, so
        term analysis runs on the event loop thread — confirm this is
        acceptable for the embedding backend in use.
        """
        if not self.guidelines:
            return None

        think_messages = self._build_think_messages(state)
        if think_messages is None:
            return None

        response = await self.model.ainvoke(
            think_messages,
            config={"metadata": {"message_tag": self._THINK_TAG}},
        )
        return self._append_think_response(state, response)

    def _get_follow_up_prompt(self) -> str:
        """Return the language-specific prompt that asks the main agent to reply."""
        prompts = self._FOLLOW_UP_PROMPTS
        return prompts.get(self.language, prompts["en"])

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Forward the request with the processed system prompt applied."""
        return handler(request.override(system_prompt=self.processed_system_prompt))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Async variant of ``wrap_model_call``; the handler result is awaited."""
        return await handler(request.override(system_prompt=self.processed_system_prompt))
|