"""Logging callback handler module."""

import logging
import traceback
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage


class LoggingCallbackHandler(BaseCallbackHandler):
    """Callback handler that routes LangChain LLM and tool events to a logger.

    LLM output, LLM errors, and tool start/end/error events are written to the
    standard-library logger named ``logger_name``. The chat-model and LLM start
    hooks below are currently commented out, so model inputs are not logged.
    """

    def __init__(self, logger_name: str = 'app', max_tool_input_chars: int = 1000):
        """Create the handler.

        Args:
            logger_name: Name passed to ``logging.getLogger``.
            max_tool_input_chars: Maximum number of characters of a tool's
                input that ``on_tool_start`` writes to the log; longer input
                is cut off. Defaults to the previously hard-coded 1000.
        """
        super().__init__()
        self.logger: logging.Logger = logging.getLogger(logger_name)
        self.max_tool_input_chars: int = max_tool_input_chars

    # The input-logging hooks below are currently disabled (commented out).

    # def on_chat_model_start(
    #     self,
    #     serialized: Dict[str, Any],
    #     messages: List[List[BaseMessage]],
    #     **kwargs: Any,
    # ) -> None:
    #     """Called when the chat model starts."""
    #     self.logger.info("✅ Chat Model Start - Messages:")
    #     for msg_list in messages:
    #         for msg in msg_list:
    #             msg_type = msg.__class__.__name__
    #             content = msg.content if hasattr(msg, 'content') else str(msg)
    #             self.logger.info(f"[{msg_type}] {content}")

    # def on_llm_start(self, serialized: Dict[str, Any], prompts: Any, **kwargs: Any) -> None:
    #     """Called when the LLM starts, for standard LLMs rather than chat models."""
    #     self.logger.info("✅ LLM Start - Input:")
    #     for prompt in prompts:
    #         self.logger.info(str(prompt))

    @staticmethod
    def _format_traceback(error: BaseException) -> str:
        """Return the full formatted traceback of ``error`` as a single string."""
        return "".join(
            traceback.format_exception(type(error), error, error.__traceback__)
        )

    @staticmethod
    def _tool_call_name(tool_call: Any) -> Any:
        """Return the ``name`` of a tool call given either as a dict or an object."""
        if isinstance(tool_call, dict):
            return tool_call.get('name', '')
        return getattr(tool_call, 'name', '')

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Log the output of a finished LLM call.

        For every generation in ``response.generations`` each non-blank line of
        ``generation.text`` is logged. When a generation also carries a
        ``message``, an ``[EMPTY content]`` marker is logged if the message
        content is missing or whitespace-only, and the names of any
        ``tool_calls`` are logged.

        Args:
            response: The LLM result. Presumably a
                ``langchain_core.outputs.LLMResult`` whose ``generations`` is a
                list of lists; only duck-typed attribute access is used here.
            **kwargs: Extra callback arguments; ignored.
        """
        self.logger.info("✅ LLM End - Output:")

        generations = getattr(response, 'generations', None)
        if not generations:
            return

        for generation_list in generations:
            for generation in generation_list:
                text = getattr(generation, 'text', None)
                if text:
                    # splitlines() also strips '\r' from CRLF output.
                    for line in text.splitlines():
                        if line.strip():
                            self.logger.info(" %s", line)

                # Chat generations carry a message; surface empty content and
                # tool calls, which the plain text does not reveal.
                msg = getattr(generation, 'message', None)
                if not msg:
                    continue

                content = getattr(msg, 'content', '')
                if not content or (isinstance(content, str) and not content.strip()):
                    self.logger.info(" [EMPTY content]")

                tool_calls = getattr(msg, 'tool_calls', None) or []
                if tool_calls:
                    names = [self._tool_call_name(tc) for tc in tool_calls]
                    self.logger.info(" [tool_calls: %s]", names)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Log an LLM failure together with its full traceback.

        Args:
            error: The exception raised by the LLM call.
            **kwargs: Extra callback arguments; ignored.
        """
        self.logger.error(
            "❌ LLM Error: %s\n%s", error, self._format_traceback(error)
        )

    def on_tool_start(
        self, serialized: Optional[Dict[str, Any]], input_str: str, **kwargs: Any
    ) -> None:
        """Log the start of a tool invocation with its (truncated) input.

        Args:
            serialized: Serialized tool description; its ``name`` key is used
                as the tool name. ``None`` or a missing/empty name is logged as
                ``unknown_tool``.
            input_str: The tool input; at most ``max_tool_input_chars``
                characters of its string form are logged.
            **kwargs: Extra callback arguments; ignored.
        """
        tool_name = (serialized or {}).get('name') or 'unknown_tool'
        self.logger.info(
            "🔧 Tool Start - %s with input: %s",
            tool_name,
            str(input_str)[:self.max_tool_input_chars],
        )

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Log the output of a finished tool invocation.

        Unlike the tool input, the output is logged in full (not truncated).

        Args:
            output: The tool's return value; formatted with ``%s``.
            **kwargs: Extra callback arguments; ignored.
        """
        self.logger.info("✅ Tool End Output: %s", output)

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Log a tool failure (its repr) together with its full traceback.

        Args:
            error: The exception raised by the tool.
            **kwargs: Extra callback arguments; ignored.
        """
        self.logger.error(
            "❌ Tool Error: %s\n%s",
            repr(error),
            self._format_traceback(error),
        )