104 lines
3.5 KiB
Python
104 lines
3.5 KiB
Python
import logging
|
|
import uuid
|
|
import json
|
|
from datetime import datetime
|
|
|
|
from fastapi import Request, FastAPI
|
|
from starlette.responses import JSONResponse
|
|
from starlette.requests import HTTPConnection
|
|
from typing import Callable, Awaitable
|
|
|
|
from .context import g
|
|
|
|
|
|
def add_request_routes(app: FastAPI):
    """Register an HTTP middleware on *app* that propagates a per-request trace id.

    The trace id is taken from the incoming ``X-Request-ID`` header when it is
    present and non-empty; otherwise a fresh ``"agent_" + uuid4`` id is minted.
    The id is stored on the shared context ``g`` before the request is handled,
    and the response's ``X-Request-ID`` header is set from ``g.trace_id``.
    """

    @app.middleware("http")
    async def before_request(request: Request, call_next: Callable[[Request], Awaitable[JSONResponse]]):
        incoming_id = request.headers.get('X-Request-ID')
        # Missing or empty header -> generate a new id.
        g.trace_id = incoming_id if incoming_id else "agent_" + str(uuid.uuid4())

        response = await call_next(request)

        # Echo the id held by the context once the request has been handled.
        response.headers['X-Request-ID'] = g.trace_id
        return response
|
|
|
|
|
|
class Formatter(logging.Formatter):
    """Log formatter that injects request-context fields into each record.

    Before delegating to ``logging.Formatter.format`` it sets:

    * ``trace_id``  - from ``g.trace_id``, or ``"N/A"`` when no context value exists.
    * ``subagent``  - from ``g.subagent``, or ``"main"`` when no context value exists.
    * ``timestamp`` - the record time rendered by :meth:`formatTime`.

    ``trace_id`` and ``subagent`` already present on the record (e.g. supplied
    via ``extra=``) are left untouched; ``timestamp`` is always overwritten.
    """

    def formatTime(self, record, datefmt=None):
        """Render ``record.created`` (local time) with millisecond precision.

        Every ``%f`` directive in *datefmt* is rendered as 3-digit milliseconds
        rather than 6-digit microseconds; escaped ``%%`` sequences are preserved.
        Without *datefmt* the result is ``HH:MM:SS.mmm``.
        """
        dt = datetime.fromtimestamp(record.created)
        # Truncate (not round) microseconds to milliseconds, zero-padded.
        millis = f"{dt.microsecond // 1000:03d}"
        if not datefmt:
            return dt.strftime("%H:%M:%S.") + millis
        # Substitute %f *before* strftime instead of chopping 3 characters off
        # the result: chopping is only correct when the format ends in %f and
        # corrupts anything else (e.g. "%H:%M:%S" would lose its seconds).
        # Splitting on "%%" first stops an escaped percent sign followed by
        # "f" from being mistaken for a %f directive.
        parts = datefmt.split("%%")
        return dt.strftime("%%".join(part.replace("%f", millis) for part in parts))

    def format(self, record):
        """Attach ``trace_id``, ``subagent`` and ``timestamp`` to *record*, then format it."""
        if not hasattr(record, "trace_id"):
            # No request context (e.g. startup/background logging) -> "N/A".
            record.trace_id = self._context_value("trace_id", "N/A")
        if not hasattr(record, "subagent"):
            # Default to "main" for the orchestrator / no-context paths.
            record.subagent = self._context_value("subagent", "main")

        record.timestamp = self.formatTime(record, self.datefmt)

        return super().format(record)

    @staticmethod
    def _context_value(name, default):
        """Return ``g.<name>``, or *default* when the context has no such value."""
        try:
            return getattr(g, name)
        except (LookupError, AttributeError):
            # LookupError also covers KeyError, which the context's __getattr__
            # raises on a missing key. AttributeError is caught too so that a
            # missing attribute can never make a log call fail.
            return default
|
|
|
|
|
|
def init_logger_once(name, level):
    """Configure logger *name* with a context-aware stream handler.

    Sets the logger's level to *level* and attaches a ``logging.StreamHandler``
    that uses :class:`Formatter`, so every line carries the trace id and subagent.
    Repeated calls are safe: the handler is attached only on the first call
    (later calls just update the level), so records are not emitted twice.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Level passed to ``Logger.setLevel``.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level=level)

    # Name our handler so later calls can recognise it; without this check
    # each call would add another handler and duplicate every log line.
    handler_name = "context_trace_handler"
    if any(existing.get_name() == handler_name for existing in logger.handlers):
        return

    formatter = Formatter(
        "%(timestamp)s | %(levelname)-5s | %(trace_id)s | %(subagent)s | "
        "%(name)s:%(funcName)s:%(lineno)s - %(message)s",
        datefmt='%Y-%m-%d %H:%M:%S.%f',
    )
    handler = logging.StreamHandler()
    handler.set_name(handler_name)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
|
|
|
|
|
|
def init_with_fastapi(app, level=logging.INFO):
    """Set up the ``"app"`` logger and install the trace-id middleware on *app*.

    Args:
        app: The FastAPI application that receives the middleware.
        level: Level applied to the ``"app"`` logger (defaults to ``logging.INFO``).
    """
    logger_name = "app"
    init_logger_once(logger_name, level)
    add_request_routes(app)
|
|
|
|
def info(message, *args, **kwargs):
    """Log *message* at INFO level on the ``"app"`` logger.

    Positional and keyword arguments are forwarded to ``Logger.info``.
    ``stacklevel`` defaults to 2 so the record's ``funcName``/``lineno``/
    ``filename`` describe the caller of this wrapper, not the wrapper itself.
    An explicit ``stacklevel`` from the caller is respected.
    """
    kwargs.setdefault("stacklevel", 2)
    logging.getLogger('app').info(message, *args, **kwargs)
|
|
|
|
def debug(message, *args, **kwargs):
    """Log *message* at DEBUG level on the ``"app"`` logger.

    Positional and keyword arguments are forwarded to ``Logger.debug``.
    ``stacklevel`` defaults to 2 so the record's ``funcName``/``lineno``/
    ``filename`` describe the caller of this wrapper, not the wrapper itself.
    An explicit ``stacklevel`` from the caller is respected.
    """
    kwargs.setdefault("stacklevel", 2)
    logging.getLogger('app').debug(message, *args, **kwargs)
|
|
|
|
def warning(message, *args, **kwargs):
    """Log *message* at WARNING level on the ``"app"`` logger.

    Positional and keyword arguments are forwarded to ``Logger.warning``.
    ``stacklevel`` defaults to 2 so the record's ``funcName``/``lineno``/
    ``filename`` describe the caller of this wrapper, not the wrapper itself.
    An explicit ``stacklevel`` from the caller is respected.
    """
    kwargs.setdefault("stacklevel", 2)
    logging.getLogger('app').warning(message, *args, **kwargs)
|
|
|
|
def error(message, *args, **kwargs):
    """Log *message* at ERROR level on the ``"app"`` logger.

    Positional and keyword arguments are forwarded to ``Logger.error``.
    ``stacklevel`` defaults to 2 so the record's ``funcName``/``lineno``/
    ``filename`` describe the caller of this wrapper, not the wrapper itself.
    An explicit ``stacklevel`` from the caller is respected.
    """
    kwargs.setdefault("stacklevel", 2)
    logging.getLogger('app').error(message, *args, **kwargs)
|
|
|
|
def critical(message, *args, **kwargs):
    """Log *message* at CRITICAL level on the ``"app"`` logger.

    Positional and keyword arguments are forwarded to ``Logger.critical``.
    ``stacklevel`` defaults to 2 so the record's ``funcName``/``lineno``/
    ``filename`` describe the caller of this wrapper, not the wrapper itself.
    An explicit ``stacklevel`` from the caller is respected.
    """
    kwargs.setdefault("stacklevel", 2)
    logging.getLogger('app').critical(message, *args, **kwargs)
|