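"""Guideline middleware for the agent pipeline.

Runs a guideline-analysis model step before the agent loop starts and
overrides the system prompt on every model call with the processed prompt
extracted from the original system prompt.
"""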
import logging
import re
from typing import Any, Awaitable, Callable

from langchain.agents.middleware import AgentState, AgentMiddleware, ModelRequest, ModelResponse
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import convert_to_openai_messages
from langgraph.runtime import Runtime

from agent.prompt_loader import load_guideline_prompt
from utils.fastapi_utils import (
    extract_block_from_system_prompt,
    format_messages_to_chat_history,
    get_user_last_message_content,
)

logger = logging.getLogger('app')


class GuidelineMiddleware(AgentMiddleware):
    """Middleware that analyzes guidelines before the agent runs.

    Extracts guidelines, tool descriptions, scenarios, and a terms list from
    the system prompt, prepends a guideline-analysis step before the agent
    loop, and swaps in the processed system prompt on every model call.
    """

    def __init__(self, bot_id: str, model: BaseChatModel, prompt: str, robot_type: str, language: str, user_identifier: str):
        self.model = model
        self.bot_id = bot_id

        processed_system_prompt, guidelines, tool_description, scenarios, terms_list = extract_block_from_system_prompt(prompt)

        self.processed_system_prompt = processed_system_prompt
        self.guidelines = guidelines
        self.tool_description = tool_description
        self.scenarios = scenarios

        self.language = language
        self.user_identifier = user_identifier

        self.robot_type = robot_type
        self.terms_list = terms_list

        # General agents fall back to default guidelines and a default tool
        # description when the system prompt does not provide them.
        if self.robot_type == "general_agent":
            if not self.guidelines:
                self.guidelines = """
1. General Inquiries
Condition: User inquires about products, policies, troubleshooting, factual questions, etc.
Action: Prioritize invoking the 【Knowledge Base Retrieval】 tool to query the knowledge base.

2. Social Dialogue
Condition: User intent involves small talk, greetings, expressions of thanks, compliments, or other non-substantive conversation.
Action: Provide concise, friendly, and natural personified responses.
"""
            if not self.tool_description:
                self.tool_description = """
- **Knowledge Base Retrieval**: For knowledge queries/other inquiries, prioritize searching the knowledge base → rag_retrieve-rag_retrieve
"""

    def get_guideline_prompt(self, messages: list[dict[str, Any]]) -> str:
        # Analyze domain terms relevant to the latest user query.
        terms_analysis = self.get_term_analysis(messages)

        guideline_prompt = ""

        if self.guidelines:
            chat_history = format_messages_to_chat_history(messages)
            query_text = get_user_last_message_content(messages)
            guideline_prompt = load_guideline_prompt(chat_history, query_text, self.guidelines, self.tool_description, self.scenarios, terms_analysis, self.language, self.user_identifier)

        return guideline_prompt

    def get_term_analysis(self, messages: list[dict[str, Any]]) -> str:
        # Build a terms analysis for the latest user query via embedding lookup,
        # substituting the result into the processed system prompt.
        terms_analysis = ""
        if self.terms_list:
            logger.info(f"Processing terms: {len(self.terms_list)} terms")
            try:
                # Imported lazily so the embedding dependency is only loaded when needed.
                from embedding.embedding import process_terms_with_embedding

                query_text = get_user_last_message_content(messages)
                terms_analysis = process_terms_with_embedding(self.terms_list, self.bot_id, query_text)
                if terms_analysis:
                    self.processed_system_prompt = self.processed_system_prompt.replace("#terms#", terms_analysis)
                    logger.info(f"Terms analysis completed: {len(terms_analysis)} chars")
            except Exception as e:
                logger.error(f"Error processing terms with embedding: {e}")
                terms_analysis = ""
        else:
            # When terms_list is empty, delete the corresponding .pkl cache file.
            try:
                import os

                cache_file = f"projects/cache/{self.bot_id}_terms.pkl"
                if os.path.exists(cache_file):
                    os.remove(cache_file)
                    logger.info(f"Removed empty terms cache file: {cache_file}")
            except Exception as e:
                logger.error(f"Error removing terms cache file: {e}")
        return terms_analysis

    def before_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        if not self.guidelines:
            return None

        guideline_prompt = self.get_guideline_prompt(convert_to_openai_messages(state["messages"]))

        # Invoke the model with a callback handler attached.
        response = self.model.invoke(
            guideline_prompt,
            config={"callbacks": [BaseCallbackHandler()]},
        )

        # Extract the content between <think> and </think> as the thinking trace;
        # fall back to the full response content if no such tags are present.
        match = re.search(r'<think>(.*?)</think>', response.content, re.DOTALL)
        response.additional_kwargs["thinking"] = match.group(1).strip() if match else response.content

        messages = state["messages"] + [response]
        return {"messages": messages}

    async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
        if not self.guidelines:
            return None

        guideline_prompt = self.get_guideline_prompt(convert_to_openai_messages(state["messages"]))

        # Invoke the model asynchronously with a callback handler attached.
        response = await self.model.ainvoke(
            guideline_prompt,
            config={"callbacks": [BaseCallbackHandler()]},
        )

        # Extract the content between <think> and </think> as the thinking trace;
        # fall back to the full response content if no such tags are present.
        match = re.search(r'<think>(.*?)</think>', response.content, re.DOTALL)
        response.additional_kwargs["thinking"] = match.group(1).strip() if match else response.content

        messages = state["messages"] + [response]
        return {"messages": messages}

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        # Swap in the processed system prompt for every model call.
        return handler(request.override(system_prompt=self.processed_system_prompt))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        return await handler(request.override(system_prompt=self.processed_system_prompt))
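

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). `create_agent` and
# `init_chat_model` are public langchain v1 APIs; the model name, bot id,
# and prompt text below are placeholder assumptions, and the prompt format
# is whatever `extract_block_from_system_prompt` expects in this project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain.agents import create_agent
    from langchain.chat_models import init_chat_model

    model = init_chat_model("openai:gpt-4o-mini")  # placeholder model choice
    middleware = GuidelineMiddleware(
        bot_id="demo-bot",  # hypothetical bot id
        model=model,
        prompt="You are a helpful assistant. #terms#",  # hypothetical prompt
        robot_type="general_agent",
        language="en",
        user_identifier="user",
    )
    agent = create_agent(model=model, tools=[], middleware=[middleware])
    result = agent.invoke({"messages": [{"role": "user", "content": "Hi!"}]})
    print(result["messages"][-1].content)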