qwen_agent/agent/guideline_middleware.py
2025-12-13 02:52:01 +08:00

149 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from langchain.agents.middleware import AgentState, AgentMiddleware, ModelRequest, ModelResponse
from langchain_core.messages import convert_to_openai_messages
from agent.prompt_loader import load_guideline_prompt
from utils.fastapi_utils import (extract_block_from_system_prompt, format_messages_to_chat_history, get_user_last_message_content)
from langchain.chat_models import BaseChatModel
from langgraph.runtime import Runtime
from typing import Any, Callable
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
import logging
logger = logging.getLogger('app')
class ThinkingCallbackHandler(BaseCallbackHandler):
    """Callback handler attached to the guideline-analysis model call.

    NOTE(review): despite its name, this handler performs no conversion —
    it only emits an info log when the LLM call finishes. The actual
    "thinking" conversion (copying ``response.content`` into
    ``additional_kwargs["thinking"]``) is done by ``GuidelineMiddleware``
    after the call returns. Confirm whether this handler is still needed.
    """

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        """Log after the LLM call completes; does not modify the response."""
        logger.info("Successfully converted response content to thinking format")
class GuidelineMiddleware(AgentMiddleware):
    """Middleware that runs a guideline-analysis model call before the agent.

    The constructor decomposes the raw system prompt into structured sections
    (guidelines, tool description, scenarios, terms list). Before each agent
    turn the middleware asks ``model`` for a guideline analysis, appends the
    result to the message history tagged as "thinking", and rewrites the
    outgoing model request's system prompt to the processed prompt.
    """

    def __init__(self, bot_id: str, model: BaseChatModel, prompt: str,
                 robot_type: str, language: str, user_identifier: str):
        """Split ``prompt`` into its sections and store the configuration.

        Args:
            bot_id: Identifier used for per-bot caches (terms pickle file).
            model: Chat model used for the guideline-analysis call.
            prompt: Raw system prompt containing the structured blocks.
            robot_type: Bot flavor; "general_agent" gets fallback defaults.
            language: Target language passed through to the prompt loader.
            user_identifier: User identity passed through to the prompt loader.
        """
        self.model = model
        self.bot_id = bot_id
        # Decompose the raw system prompt into its structured sections.
        (processed_system_prompt, guidelines, tool_description,
         scenarios, terms_list) = extract_block_from_system_prompt(prompt)
        self.processed_system_prompt = processed_system_prompt
        self.guidelines = guidelines
        self.tool_description = tool_description
        self.scenarios = scenarios
        self.language = language
        self.user_identifier = user_identifier
        self.robot_type = robot_type
        self.terms_list = terms_list
        # General-purpose agents fall back to built-in guidelines / tool
        # descriptions when the prompt does not supply them.
        if self.robot_type == "general_agent":
            if not self.guidelines:
                self.guidelines = """
1. General Inquiries
Condition: User inquiries about products, policies, troubleshooting, factual questions, etc.
Action: Priority given to invoking the 【Knowledge Base Retrieval】 tool to query the knowledge base.
2.Social Dialogue
Condition: User intent involves small talk, greetings, expressions of thanks, compliments, or other non-substantive conversations.
Action: Provide concise, friendly, and personified natural responses.
"""
            if not self.tool_description:
                self.tool_description = """
- **Knowledge Base Retrieval**: For knowledge queries/other inquiries, prioritize searching the knowledge base → rag_retrieve-rag_retrieve
"""

    def get_guideline_prompt(self, messages: list[dict[str, Any]]) -> str:
        """Build the guideline-analysis prompt for the current conversation.

        Args:
            messages: Conversation in OpenAI message-dict format.

        Returns:
            The rendered guideline prompt, or "" when no guidelines exist.
        """
        # Term analysis must run first: when it yields text it substitutes
        # the "#terms#" placeholder inside the processed system prompt.
        terms_analysis = self.get_term_analysis(messages)
        guideline_prompt = ""
        if self.guidelines:
            chat_history = format_messages_to_chat_history(messages)
            query_text = get_user_last_message_content(messages)
            guideline_prompt = load_guideline_prompt(
                chat_history, query_text, self.guidelines,
                self.tool_description, self.scenarios, terms_analysis,
                self.language, self.user_identifier)
        return guideline_prompt

    def get_term_analysis(self, messages: list[dict[str, Any]]) -> str:
        """Run embedding-based term analysis for the user's latest message.

        Side effects:
            - On success, replaces the "#terms#" placeholder inside
              ``self.processed_system_prompt`` with the analysis text
              (only the first successful call substitutes; the placeholder
              is gone afterwards).
            - When no terms are configured, removes this bot's stale
              on-disk terms cache instead.

        Returns:
            The analysis text, or "" when no terms exist or analysis fails.
        """
        terms_analysis = ""
        if self.terms_list:
            logger.info(f"Processing terms: {len(self.terms_list)} terms")
            try:
                from embedding.embedding import process_terms_with_embedding
                query_text = get_user_last_message_content(messages)
                # BUGFIX: originally passed the bare name `terms_list`,
                # which is undefined here and raised NameError whenever
                # any terms were configured.
                terms_analysis = process_terms_with_embedding(
                    self.terms_list, self.bot_id, query_text)
                if terms_analysis:
                    self.processed_system_prompt = self.processed_system_prompt.replace(
                        "#terms#", terms_analysis)
                    logger.info(f"Terms analysis completed: {len(terms_analysis)} chars")
            except Exception as e:
                # Best-effort: term analysis failure must not break the turn.
                logger.error(f"Error processing terms with embedding: {e}")
                terms_analysis = ""
        else:
            # No terms configured: drop any stale pickle cache for this bot.
            try:
                import os
                cache_file = f"projects/cache/{self.bot_id}_terms.pkl"
                if os.path.exists(cache_file):
                    os.remove(cache_file)
                    logger.info(f"Removed empty terms cache file: {cache_file}")
            except Exception as e:
                logger.error(f"Error removing terms cache file: {e}")
        return terms_analysis

    def _attach_thinking(self, state: AgentState, response) -> dict[str, Any]:
        """Tag the guideline response as "thinking" and append it to history."""
        response.additional_kwargs["thinking"] = response.content
        return {"messages": state['messages'] + [response]}

    def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Sync hook: run the guideline analysis before the agent executes.

        Returns a state update with the analysis appended, or None when no
        guidelines are configured.
        """
        if not self.guidelines:
            return None
        guideline_prompt = self.get_guideline_prompt(
            convert_to_openai_messages(state['messages']))
        response = self.model.invoke(
            guideline_prompt,
            config={"callbacks": [ThinkingCallbackHandler()]},
        )
        return self._attach_thinking(state, response)

    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        """Async hook: run the guideline analysis before the agent executes.

        Mirrors ``before_agent`` but awaits the model call.
        """
        if not self.guidelines:
            return None
        guideline_prompt = self.get_guideline_prompt(
            convert_to_openai_messages(state['messages']))
        response = await self.model.ainvoke(
            guideline_prompt,
            config={"callbacks": [ThinkingCallbackHandler()]},
        )
        return self._attach_thinking(state, response)

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Swap in the processed system prompt for the outgoing model call."""
        return handler(request.override(system_prompt=self.processed_system_prompt))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Async variant: swap in the processed system prompt."""
        return await handler(request.override(system_prompt=self.processed_system_prompt))