# qwen_agent/agent/guideline_middleware.py
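"""Guideline middleware for the agent pipeline.

Before the agent runs, :class:`GuidelineMiddleware` asks the model to reason
over the configured guidelines (and any embedding-matched terms) and appends
that reasoning to the conversation as a ``<think>``-tagged message. During the
run it overrides each model call's system prompt with the processed prompt
extracted from the original one.
"""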

import logging
import os
from typing import Any, Callable

from langchain.agents.middleware import AgentMiddleware, AgentState, ModelRequest, ModelResponse
from langchain.chat_models import BaseChatModel
from langchain_core.messages import SystemMessage, convert_to_openai_messages
from langgraph.runtime import Runtime

from agent.prompt_loader import load_guideline_prompt
from utils.fastapi_utils import (
    extract_block_from_system_prompt,
    format_messages_to_chat_history,
    get_user_last_message_content,
)

from .agent_config import AgentConfig

logger = logging.getLogger('app')


class GuidelineMiddleware(AgentMiddleware):
    """Injects a guideline-driven reasoning step and a processed system prompt."""

    def __init__(self, model: BaseChatModel, config: AgentConfig, prompt: str):
        self.model = model
        self.bot_id = config.bot_id
        (processed_system_prompt, guidelines, tool_description,
         scenarios, terms_list) = extract_block_from_system_prompt(prompt)
        self.processed_system_prompt = processed_system_prompt
        self.guidelines = guidelines
        self.tool_description = tool_description
        self.scenarios = scenarios
        self.language = config.language
        self.user_identifier = config.user_identifier
        self.robot_type = config.robot_type
        self.terms_list = terms_list
        self.messages = config._origin_messages
        # General agents fall back to default guidelines and a default tool
        # description when the system prompt does not define them.
        if self.robot_type == "general_agent":
            if not self.guidelines:
                self.guidelines = """
1. General Inquiries
Condition: User inquires about products, policies, troubleshooting, factual questions, etc.
Action: Prioritize invoking the 【Knowledge Base Retrieval】 tool to query the knowledge base.
2. Social Dialogue
Condition: User intent involves small talk, greetings, expressions of thanks, compliments, or other non-substantive conversation.
Action: Provide concise, friendly, and personified natural responses.
"""
            if not self.tool_description:
                self.tool_description = """
- **Knowledge Base Retrieval**: For knowledge queries/other inquiries, prioritize searching the knowledge base → rag_retrieve-rag_retrieve
"""

    def get_guideline_prompt(self, messages: list[dict[str, Any]]) -> str:
        # Resolve term matches first so they can be folded into the prompt.
        terms_analysis = self.get_term_analysis(messages)
        guideline_prompt = ""
        if self.guidelines:
            chat_history = format_messages_to_chat_history(messages)
            guideline_prompt = load_guideline_prompt(
                chat_history, self.guidelines, self.tool_description,
                self.scenarios, terms_analysis, self.language, self.user_identifier,
            )
        return guideline_prompt

    def get_term_analysis(self, messages: list[dict[str, Any]]) -> str:
        # Match configured terms against the user's latest message via embeddings.
        terms_analysis = ""
        if self.terms_list:
            logger.info(f"Processing terms: {len(self.terms_list)} terms")
            try:
                # Imported lazily to avoid loading embedding dependencies up front.
                from embedding.embedding import process_terms_with_embedding
                query_text = get_user_last_message_content(messages)
                terms_analysis = process_terms_with_embedding(self.terms_list, self.bot_id, query_text)
                if terms_analysis:
                    self.processed_system_prompt = self.processed_system_prompt.replace("#terms#", terms_analysis)
                    logger.info(f"Terms analysis completed: {len(terms_analysis)} chars")
            except Exception as e:
                logger.error(f"Error processing terms with embedding: {e}")
                terms_analysis = ""
        else:
            # Delete the corresponding pkl cache file when terms_list is empty.
            try:
                cache_file = f"projects/cache/{self.bot_id}_terms.pkl"
                if os.path.exists(cache_file):
                    os.remove(cache_file)
                    logger.info(f"Removed empty terms cache file: {cache_file}")
            except Exception as e:
                logger.error(f"Error removing terms cache file: {e}")
        return terms_analysis

    def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        if not self.guidelines:
            return None
        guideline_prompt = self.get_guideline_prompt(convert_to_openai_messages(self.messages))
        # Build the think request: the guideline prompt as a system message
        # plus the latest message in the conversation.
        system_message = SystemMessage(content=guideline_prompt)
        messages = [system_message, state['messages'][-1]]
        # Invoke the model, tagging the call as a THINK step via metadata.
        response = self.model.invoke(
            messages,
            config={"metadata": {"message_tag": "THINK"}},
        )
        response.additional_kwargs["message_tag"] = "THINK"
        response.content = f"<think>{response.content}</think>"
        # Append the reasoning response to the original message list.
        final_messages = state['messages'] + [response]
        return {"messages": final_messages}

    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        # Async variant of before_agent.
        if not self.guidelines:
            return None
        guideline_prompt = self.get_guideline_prompt(convert_to_openai_messages(self.messages))
        # Build the think request: the guideline prompt as a system message
        # plus the latest message in the conversation.
        system_message = SystemMessage(content=guideline_prompt)
        messages = [system_message, state['messages'][-1]]
        # Invoke the model, tagging the call as a THINK step via metadata.
        response = await self.model.ainvoke(
            messages,
            config={"metadata": {"message_tag": "THINK"}},
        )
        response.additional_kwargs["message_tag"] = "THINK"
        response.content = f"<think>{response.content}</think>"
        # Append the reasoning response to the original message list.
        final_messages = state['messages'] + [response]
        return {"messages": final_messages}

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        # Replace the system prompt with the processed version (terms filled in).
        return handler(request.override(system_prompt=self.processed_system_prompt))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        return await handler(request.override(system_prompt=self.processed_system_prompt))
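

# Usage sketch (a minimal, untested wiring example; `create_agent` and
# `init_chat_model` follow langchain 1.x APIs as assumed here, and `config`,
# `system_prompt`, and `tools` are hypothetical values supplied by the caller):
#
#     from langchain.agents import create_agent
#     from langchain.chat_models import init_chat_model
#
#     model = init_chat_model("openai:gpt-4o-mini")
#     middleware = GuidelineMiddleware(model, config, system_prompt)
#     agent = create_agent(model, tools=tools, middleware=[middleware])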