qwen_agent/agent/logging_handler.py
2026-05-11 20:29:27 +08:00

89 lines
3.8 KiB
Python

"""Logging callback handler module."""
import logging
import traceback
from typing import Any, Optional, Dict, List
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
class LoggingCallbackHandler(BaseCallbackHandler):
    """Callback handler that forwards LangChain LLM and tool events to a logger.

    All output goes to ``logging.getLogger(logger_name)``, so it follows
    whatever handlers and levels the project configures for that logger.

    NOTE: the prompt hooks (``on_chat_model_start`` / ``on_llm_start``) are not
    overridden, so model inputs are not logged by this handler.
    """

    # Class-level defaults so subclasses that skip ``super().__init__`` still
    # have working truncation settings. ``None`` means "do not truncate".
    max_tool_input_chars: Optional[int] = 1000
    max_tool_output_chars: Optional[int] = None

    def __init__(
        self,
        logger_name: str = 'app',
        *,
        max_tool_input_chars: Optional[int] = 1000,
        max_tool_output_chars: Optional[int] = None,
    ):
        """Create the handler.

        Args:
            logger_name: Name passed to ``logging.getLogger``.
            max_tool_input_chars: Maximum characters of tool input logged by
                ``on_tool_start``; ``None`` logs the full input.
            max_tool_output_chars: Maximum characters of tool output logged by
                ``on_tool_end``; ``None`` (the default) logs the full output.
        """
        super().__init__()
        self.logger = logging.getLogger(logger_name)
        self.max_tool_input_chars = max_tool_input_chars
        self.max_tool_output_chars = max_tool_output_chars

    @staticmethod
    def _truncate(value: Any, limit: Optional[int]) -> str:
        """Return ``str(value)``, cut to at most ``limit`` characters if a limit is set."""
        text = str(value)
        return text if limit is None else text[:limit]

    @staticmethod
    def _format_exception(error: BaseException) -> str:
        """Render ``error`` and its traceback (if it has one) as a single string."""
        return "".join(
            traceback.format_exception(type(error), error, error.__traceback__)
        )

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Log the output of a finished LLM call.

        Always logs a header line. If ``response`` has a non-empty
        ``generations`` attribute (iterated as a list of lists), each
        generation is logged via ``_log_generation``.
        """
        self.logger.info("✅ LLM End - Output:")
        # Skip the per-line work entirely when INFO output would be dropped anyway.
        if not self.logger.isEnabledFor(logging.INFO):
            return
        generations = getattr(response, 'generations', None)
        if not generations:
            return
        for generation_list in generations:
            for generation in generation_list:
                self._log_generation(generation)

    def _log_generation(self, generation: Any) -> None:
        """Log one generation: its non-blank text lines, then message diagnostics."""
        text = getattr(generation, 'text', None)
        if text:
            for line in text.split("\n"):
                if line.strip():
                    self.logger.info(" %s", line)

        msg = getattr(generation, 'message', None)
        if not msg:
            return
        content = getattr(msg, 'content', '')
        # Flag empty or whitespace-only string content so it stands out in the log.
        if not content or (isinstance(content, str) and not content.strip()):
            self.logger.info(" [EMPTY content]")
        tool_calls = getattr(msg, 'tool_calls', None)
        if tool_calls:
            # assumes each tool call is dict-like (supports .get) — TODO confirm
            self.logger.info(
                " [tool_calls: %s]", [tc.get('name', '') for tc in tool_calls]
            )

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Log an LLM failure together with its traceback."""
        self.logger.error(
            "❌ LLM Error: %s\n%s",
            error,
            self._format_exception(error),
        )

    def on_tool_start(
        self, serialized: Optional[Dict[str, Any]], input_str: str, **kwargs: Any
    ) -> None:
        """Log the start of a tool invocation.

        The tool name is read from ``serialized['name']``, falling back to
        ``'unknown_tool'`` when ``serialized`` is ``None``/empty or has no name.
        The input is truncated to ``max_tool_input_chars`` characters.
        """
        tool_name = (serialized or {}).get('name', 'unknown_tool')
        self.logger.info(
            "🔧 Tool Start - %s with input: %s",
            tool_name,
            self._truncate(input_str, self.max_tool_input_chars),
        )

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Log a tool's output, truncated to ``max_tool_output_chars`` if set."""
        self.logger.info(
            "✅ Tool End Output: %s",
            self._truncate(output, self.max_tool_output_chars),
        )

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Log a tool failure (its repr) together with its traceback."""
        self.logger.error(
            "❌ Tool Error: %s\n%s",
            repr(error),
            self._format_exception(error),
        )