"""日志回调处理器模块"""
import logging
from typing import Any, Optional, Dict, List

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage


class LoggingCallbackHandler(BaseCallbackHandler):
    """Custom CallbackHandler that records LangChain events via the project's logger.

    Routes LLM/tool lifecycle events (end, error, tool start/end/error) to a
    named ``logging`` logger so callback output shares the project's handlers
    and formatting.
    """

    def __init__(self, logger_name: str = 'app'):
        # Bind to a named logger so output inherits the project's configuration.
        self.logger = logging.getLogger(logger_name)

    def _log_text_lines(self, text: str) -> None:
        """Log each non-blank line of *text* as a separate INFO record."""
        for line in text.split("\n"):
            if line.strip():
                self.logger.info("%s", line)

    def on_llm_end(self, response, **kwargs: Any) -> None:
        """Called when the LLM finishes; logs every non-blank line of the output.

        ``response`` is expected to be an ``LLMResult``-like object carrying a
        nested ``generations`` list; access is guarded with ``hasattr`` so an
        unexpected shape is skipped rather than raising.
        """
        self.logger.info("✅ LLM End - Output:")
        if hasattr(response, 'generations') and response.generations:
            for generation_list in response.generations:
                for generation in generation_list:
                    if hasattr(generation, 'text'):
                        self._log_text_lines(generation.text)
                    elif hasattr(generation, 'message'):
                        # Bug fix: a message object has no .split(); its text
                        # lives in .content (may be a non-str for multimodal
                        # content, so coerce defensively).
                        content = generation.message.content
                        if not isinstance(content, str):
                            content = str(content)
                        self._log_text_lines(content)

    def on_llm_error(
        self, error: Exception, **kwargs: Any
    ) -> None:
        """Called when the LLM raises an error."""
        self.logger.error("❌ LLM Error: %s", error)

    def on_tool_start(
        self, serialized: Optional[Dict[str, Any]], input_str: str, **kwargs: Any
    ) -> None:
        """Called when a tool starts; logs the tool name and its (truncated) input."""
        if serialized is None:
            tool_name = 'unknown_tool'
        else:
            tool_name = serialized.get('name', 'unknown_tool')
        # Truncate to 1000 chars so huge tool inputs don't flood the log.
        self.logger.info("🔧 Tool Start - %s with input: %s", tool_name, str(input_str)[:1000])

    def on_tool_end(self, output: str, **kwargs: Any) -> None:
        """Called when a tool call finishes; logs its output."""
        self.logger.info("✅ Tool End Output: %s", output)

    def on_tool_error(
        self, error: Exception, **kwargs: Any
    ) -> None:
        """Called when a tool call raises an error."""
        self.logger.error("❌ Tool Error: %s", error)