Skip to content
57 changes: 54 additions & 3 deletions src/axolotl/logging_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,56 @@
Common logging module for axolotl
"""

import logging
import os
import sys
from logging import Formatter
from logging import Formatter, Logger, LogRecord
from logging.config import dictConfig
from typing import Any, Dict

from colorama import Fore, Style, init

DEFAULT_AXOLOTL_LOG_LEVEL = "INFO"
DEFAULT_LOG_LEVEL = "WARNING"


class AxolotlOrWarnErrorFilter(logging.Filter):
Comment thread
salmanmohammadi marked this conversation as resolved.
"""
Allows ANY WARNING or higher (unless overridden by LOG_LEVEL)
Allows axolotl.* at INFO or higher (unless overridden by AXOLOTL_LOG_LEVEL)
Drops all other records (i.e. non-axolotl.INFO, DEBUG, etc. by default)
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)

self.axolotl_level = logging.getLevelNamesMapping()[
os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL)
]
self.other_level = logging.getLevelNamesMapping()[
os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)
]

def filter(self, record: LogRecord) -> bool:
# General filter
if record.levelno >= self.other_level:
return True

# Axolotl filter
return (
record.name.startswith("axolotl") and record.levelno >= self.axolotl_level
)


class AxolotlLogger(Logger):
"""A Logger that automatically rejects non-axolotl INFOs."""

def __init__(self, name: str, level: int = logging.NOTSET):
super().__init__(name, level)

# set global filter on the logger itself
self.addFilter(AxolotlOrWarnErrorFilter())


class ColorfulFormatter(Formatter):
"""
Expand Down Expand Up @@ -55,11 +97,15 @@ def format(self, record):
"stream": sys.stdout,
},
},
"root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
# log level will be superseded by the AxolotlLogger
"root": {
"handlers": ["console"],
"level": os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL),
},
"loggers": {
"axolotl": {
"handlers": ["color_console"],
"level": "DEBUG",
"level": os.getenv("AXOLOTL_LOG_LEVEL", DEFAULT_AXOLOTL_LOG_LEVEL),
"propagate": False,
},
},
Expand All @@ -70,3 +116,8 @@ def configure_logging():
"""Configure with default logging"""
init() # Initialize colorama
dictConfig(DEFAULT_LOGGING_CONFIG)
logging.setLoggerClass(AxolotlLogger)

# set default `ACCELERATE_LOG_LEVEL` to `LOG_LEVEL` if available and not set
if "ACCELERATE_LOG_LEVEL" not in os.environ:
os.environ["ACCELERATE_LOG_LEVEL"] = os.getenv("LOG_LEVEL", DEFAULT_LOG_LEVEL)