diff --git a/build/accelerate_launch.py b/build/accelerate_launch.py
index ee8718b5d..2cfc9069f 100644
--- a/build/accelerate_launch.py
+++ b/build/accelerate_launch.py
@@ -61,9 +61,6 @@ def get_base_model_from_adapter_config(adapter_config):
 
 
 def main():
-    LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
-    logging.basicConfig(level=LOGLEVEL)
-
     if not os.getenv("TERMINATION_LOG_FILE"):
         os.environ["TERMINATION_LOG_FILE"] = ERROR_LOG
 
@@ -80,6 +77,18 @@ def main():
                 or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
             )
 
+        # Configure log_level of python native logger.
+        # The CLI arg takes precedence over the env var; if neither is set, the default "WARNING" is used.
+        log_level = job_config.get(
+            "log_level"
+        )  # this will be set to either the value found or None
+        if (
+            not log_level
+        ):  # if log level is not set in the job config (JSON), fall back to the env var or the default
+            log_level = os.environ.get("LOG_LEVEL", "WARNING")
+        log_level = log_level.upper()
+        logging.basicConfig(level=log_level)
+
         args = process_accelerate_launch_args(job_config)
         logging.debug("accelerate launch parsed args: %s", args)
     except FileNotFoundError as e:
@@ -109,7 +118,6 @@ def main():
             job_config["output_dir"] = tempdir
             updated_args = serialize_args(job_config)
             os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = updated_args
-
         launch_command(args)
     except subprocess.CalledProcessError as e:
         # If the subprocess throws an exception, the base exception is hidden in the
diff --git a/tests/utils/test_logging.py b/tests/utils/test_logging.py
new file mode 100644
index 000000000..7b7aa1a2a
--- /dev/null
+++ b/tests/utils/test_logging.py
@@ -0,0 +1,84 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# SPDX-License-Identifier: Apache-2.0
+# https://spdx.dev/learn/handling-license-info/
+
+# Standard
+from unittest import mock
+import copy
+import logging
+import os
+
+# First Party
+from tests.test_sft_trainer import TRAIN_ARGS
+
+# Local
+from tuning.utils.logging import set_log_level
+
+
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_set_log_level_for_logger_default():
+    """
+    Ensure that the correct log level is being set for python native logger and
+    transformers logger when no env var or CLI flag is passed
+    """
+
+    train_args = copy.deepcopy(TRAIN_ARGS)
+    training_args, logger = set_log_level(train_args)
+    assert logger.getEffectiveLevel() == logging.WARNING
+    assert training_args.log_level == "passive"
+
+
+@mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True)
+def test_set_log_level_for_logger_with_env_var():
+    """
+    Ensure that the correct log level is being set for python native logger and
+    transformers logger when env var LOG_LEVEL is used
+    """
+
+    train_args_env = copy.deepcopy(TRAIN_ARGS)
+    training_args, logger = set_log_level(train_args_env)
+    assert logger.getEffectiveLevel() == logging.INFO
+    assert training_args.log_level == "info"
+
+
+@mock.patch.dict(os.environ, {"TRANSFORMERS_VERBOSITY": "info"}, clear=True)
+def test_set_log_level_for_logger_with_set_verbosity_and_cli():
+    """
+    Ensure that the correct log level is being set for python native logger and
+    log_level of transformers logger is unchanged when env var TRANSFORMERS_VERBOSITY is used
+    and CLI flag is passed
+    """
+
+    train_args = copy.deepcopy(TRAIN_ARGS)
+    train_args.log_level = "error"
+    training_args, logger = set_log_level(train_args)
+    assert logger.getEffectiveLevel() == logging.ERROR
+    assert training_args.log_level == "error"
+
+
+@mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True)
+def test_set_log_level_for_logger_with_env_var_and_cli():
+    """
+    Ensure that the correct log level is being set for python native logger and
+    transformers logger when env var LOG_LEVEL is used and CLI flag is passed.
+    In this case, CLI arg takes precedence over the set env var LOG_LEVEL.
+    """
+
+    train_args = copy.deepcopy(TRAIN_ARGS)
+    train_args.log_level = "error"
+    training_args, logger = set_log_level(train_args)
+    assert logger.getEffectiveLevel() == logging.ERROR
+    assert training_args.log_level == "error"
diff --git a/tuning/config/configs.py b/tuning/config/configs.py
index c08c90b12..2990ef801 100644
--- a/tuning/config/configs.py
+++ b/tuning/config/configs.py
@@ -136,6 +136,15 @@ class TrainingArguments(transformers.TrainingArguments):
             + "Requires additional configs, see tuning.configs/tracker_configs.py"
         },
     )
+    log_level: str = field(
+        default="passive",
+        metadata={
+            "help": "The log level to adopt during training. \
+            By default, 'passive' level is set, which keeps the \
+            current log level for the Transformers library (which will be 'warning' by default). \
+            Other possible values are 'debug', 'info', 'warning', 'error' and 'critical'"
+        },
+    )
 
 
 @dataclass
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 88c639d40..2fce4ec37 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -33,7 +33,7 @@
     LlamaTokenizerFast,
     TrainerCallback,
 )
-from transformers.utils import is_accelerate_available, logging
+from transformers.utils import is_accelerate_available
 from trl import SFTConfig, SFTTrainer
 import fire
 import transformers
@@ -60,6 +60,7 @@
     USER_ERROR_EXIT_CODE,
     write_termination_log,
 )
+from tuning.utils.logging import set_log_level
 from tuning.utils.preprocessing_utils import (
     format_dataset,
     get_data_collator,
@@ -111,7 +112,7 @@ def train(
         fused_lora and fast_kernels must used together (may change in future). \
     """
 
-    logger = logging.get_logger("sft_trainer")
+    train_args, logger = set_log_level(train_args, "sft_trainer_train")
 
     # Validate parameters
     if (not isinstance(train_args.num_train_epochs, (float, int))) or (
@@ -479,11 +480,8 @@ def parse_arguments(parser, json_config=None):
 
 
 def main(**kwargs):  # pylint: disable=unused-argument
-    logger = logging.get_logger("__main__")
-
     parser = get_parser()
     job_config = get_json_config()
-    logger.debug("Input args parsed: %s", job_config)
     # accept arguments via command-line or JSON
     try:
         (
@@ -498,6 +496,10 @@ def main(**kwargs):  # pylint: disable=unused-argument
             fusedops_kernels_config,
             exp_metadata,
         ) = parse_arguments(parser, job_config)
+
+        # Set log level for python native logger and transformers training logger
+        training_args, logger = set_log_level(training_args, __name__)
+
         logger.debug(
             "Input args parsed: \
             model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
diff --git a/tuning/trackers/aimstack_tracker.py b/tuning/trackers/aimstack_tracker.py
index c3f043839..139fbfc28 100644
--- a/tuning/trackers/aimstack_tracker.py
+++ b/tuning/trackers/aimstack_tracker.py
@@ -14,11 +14,11 @@
 
 # Standard
 import json
+import logging
 import os
 
 # Third Party
 from aim.hugging_face import AimCallback  # pylint: disable=import-error
-from transformers.utils import logging
 
 # Local
 from .tracker import Tracker
@@ -99,7 +99,8 @@ def __init__(self, tracker_config: AimConfig):
             information about the repo or the server and port where aim db is present.
         """
         super().__init__(name="aim", tracker_config=tracker_config)
-        self.logger = logging.get_logger("aimstack_tracker")
+        # Get logger with root log level
+        self.logger = logging.getLogger()
 
     def get_hf_callback(self):
         """Returns the aim.hugging_face.AimCallback object associated with this tracker.
diff --git a/tuning/trackers/filelogging_tracker.py b/tuning/trackers/filelogging_tracker.py
index 213377d96..133687866 100644
--- a/tuning/trackers/filelogging_tracker.py
+++ b/tuning/trackers/filelogging_tracker.py
@@ -15,11 +15,11 @@
 # Standard
 from datetime import datetime
 import json
+import logging
 import os
 
 # Third Party
 from transformers import TrainerCallback
-from transformers.utils import logging
 
 # Local
 from .tracker import Tracker
@@ -80,7 +80,8 @@ def __init__(self, tracker_config: FileLoggingTrackerConfig):
             which contains the location of file where logs are recorded.
""" super().__init__(name="file_logger", tracker_config=tracker_config) - self.logger = logging.get_logger("file_logging_tracker") + # Get logger with root log level + self.logger = logging.getLogger() def get_hf_callback(self): """Returns the FileLoggingCallback object associated with this tracker. diff --git a/tuning/trackers/tracker_factory.py b/tuning/trackers/tracker_factory.py index 98771c143..096099306 100644 --- a/tuning/trackers/tracker_factory.py +++ b/tuning/trackers/tracker_factory.py @@ -14,18 +14,15 @@ # Standard import dataclasses +import logging # Third Party -from transformers.utils import logging from transformers.utils.import_utils import _is_package_available # Local from .filelogging_tracker import FileLoggingTracker from tuning.config.tracker_configs import FileLoggingTrackerConfig, TrackerConfigFactory -logger = logging.get_logger("tracker_factory") - - # Information about all registered trackers AIMSTACK_TRACKER = "aim" FILE_LOGGING_TRACKER = "file_logger" @@ -54,9 +51,9 @@ def _register_aim_tracker(): AimTracker = _get_tracker_class(AimStackTracker, AimConfig) REGISTERED_TRACKERS[AIMSTACK_TRACKER] = AimTracker - logger.info("Registered aimstack tracker") + logging.info("Registered aimstack tracker") else: - logger.info( + logging.info( "Not registering Aimstack tracker due to unavailablity of package.\n" "Please install aim if you intend to use it.\n" "\t pip install aim" @@ -72,14 +69,14 @@ def _is_tracker_installed(name): def _register_file_logging_tracker(): FileTracker = _get_tracker_class(FileLoggingTracker, FileLoggingTrackerConfig) REGISTERED_TRACKERS[FILE_LOGGING_TRACKER] = FileTracker - logger.info("Registered file logging tracker") + logging.info("Registered file logging tracker") # List of Available Trackers # file_logger - Logs loss to a file # aim - Aimstack Tracker def _register_trackers(): - logger.info("Registering trackers") + logging.info("Registering trackers") if AIMSTACK_TRACKER not in REGISTERED_TRACKERS: _register_aim_tracker() if FILE_LOGGING_TRACKER not in REGISTERED_TRACKERS: @@ -142,7 +139,7 @@ def get_tracker(name: str, tracker_configs: TrackerConfigFactory): e = "Requested Tracker {} not found. 
List trackers available for use is - {} ".format( name, available ) - logger.error(e) + logging.error(e) raise ValueError(e) meta = REGISTERED_TRACKERS[name] diff --git a/tuning/trainercontroller/controllermetrics/eval_metrics.py b/tuning/trainercontroller/controllermetrics/eval_metrics.py index 696714437..a87772674 100644 --- a/tuning/trainercontroller/controllermetrics/eval_metrics.py +++ b/tuning/trainercontroller/controllermetrics/eval_metrics.py @@ -18,14 +18,9 @@ # Standard from typing import Any -# Third Party -from transformers.utils import logging - # Local from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler -logger = logging.get_logger(__name__) - class EvalMetrics(MetricHandler): """Implements the controller metric which exposes the evaluation metrics""" diff --git a/tuning/trainercontroller/controllermetrics/history_based_metrics.py b/tuning/trainercontroller/controllermetrics/history_based_metrics.py index ae547d3c6..f66d634e5 100644 --- a/tuning/trainercontroller/controllermetrics/history_based_metrics.py +++ b/tuning/trainercontroller/controllermetrics/history_based_metrics.py @@ -21,12 +21,10 @@ # Third Party from transformers import TrainerState -from transformers.utils import logging # Local from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler -logger = logging.get_logger(__name__) METRICS_KEY = "metrics" LOG_LOSS_KEY = "loss" TRAINING_LOSS_KEY = "training_loss" diff --git a/tuning/trainercontroller/operations/hfcontrols.py b/tuning/trainercontroller/operations/hfcontrols.py index c1f7589e6..0548b4c12 100644 --- a/tuning/trainercontroller/operations/hfcontrols.py +++ b/tuning/trainercontroller/operations/hfcontrols.py @@ -1,17 +1,15 @@ # Standard from dataclasses import fields import inspect +import logging import re # Third Party from transformers import TrainerControl -from transformers.utils import logging # Local from .operation import Operation -logger = logging.get_logger(__name__) - class HFControls(Operation): """Implements the control actions for the HuggingFace controls in @@ -39,7 +37,7 @@ def control_action(self, control: TrainerControl, **kwargs): control: TrainerControl. Data class for controls. kwargs: List of arguments (key, value)-pairs """ - logger.debug("Arguments passed to control_action: %s", repr(kwargs)) + logging.debug("Arguments passed to control_action: %s", repr(kwargs)) frame_info = inspect.currentframe().f_back arg_values = inspect.getargvalues(frame_info) setattr(control, arg_values.locals["action"], True) diff --git a/tuning/trainercontroller/patience.py b/tuning/trainercontroller/patience.py index b8098fdf0..ecdb0699a 100644 --- a/tuning/trainercontroller/patience.py +++ b/tuning/trainercontroller/patience.py @@ -15,8 +15,8 @@ # SPDX-License-Identifier: Apache-2.0 # https://spdx.dev/learn/handling-license-info/ -# Third Party -from transformers.utils import logging +# Standard +import logging # Resets the patience if the rule outcome happens to be false. # Here, the expectation is to have unbroken "True"s for patience @@ -31,8 +31,6 @@ # will be exceeded afer the fifth event. 
 MODE_NO_RESET_ON_FAILURE = "no_reset_on_failure"
 
-logger = logging.get_logger(__name__)
-
 
 class PatienceControl:
     """Implements the patience control for every rule"""
@@ -51,7 +49,7 @@ def should_tolerate(
         elif self._mode == MODE_RESET_ON_FAILURE:
             self._patience_counter = 0
         if self._patience_counter <= self._patience_threshold:
-            logger.debug(
+            logging.debug(
                 "Control {} triggered on event {}: "
                 "Enforcing patience [patience_counter = {:.2f}, "
                 "patience_threshold = {:.2f}]".format(
@@ -62,7 +60,7 @@ def should_tolerate(
                 )
             )
             return True
-        logger.debug(
+        logging.debug(
             "Control {} triggered on event {}: "
             "Exceeded patience [patience_counter = {:.2f}, "
             "patience_threshold = {:.2f}]".format(
diff --git a/tuning/utils/data_type_utils.py b/tuning/utils/data_type_utils.py
index cefebb100..52bae6d77 100644
--- a/tuning/utils/data_type_utils.py
+++ b/tuning/utils/data_type_utils.py
@@ -14,13 +14,11 @@
 
 # Standard
 from typing import Union
+import logging
 
 # Third Party
-from transformers.utils import logging
 import torch
 
-logger = logging.get_logger("data_utils")
-
 
 def str_to_torch_dtype(dtype_str: str) -> torch.dtype:
     """Given a string representation of a Torch data type, convert it to the actual torch dtype.
@@ -35,7 +33,7 @@ def str_to_torch_dtype(dtype_str: str) -> torch.dtype:
     """
     dt = getattr(torch, dtype_str, None)
     if not isinstance(dt, torch.dtype):
-        logger.error(" ValueError: Unrecognized data type of a torch.Tensor")
+        logging.error(" ValueError: Unrecognized data type of a torch.Tensor")
         raise ValueError("Unrecognized data type of a torch.Tensor")
     return dt
 
diff --git a/tuning/utils/logging.py b/tuning/utils/logging.py
new file mode 100644
index 000000000..1f1b6c73e
--- /dev/null
+++ b/tuning/utils/logging.py
@@ -0,0 +1,64 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+import logging
+import os
+
+
+def set_log_level(train_args, logger_name=None):
+    """Set log level of python native logger and transformers logger via CLI arg or env variable.
+
+    Args:
+        train_args
+            Training arguments for training model.
+        logger_name
+            Logger name with which the logger is instantiated.
+
+    Returns:
+        train_args
+            Updated training arguments for training model.
+        train_logger
+            Logger with updated effective log level
+    """
+
+    # Clear any existing handlers if necessary
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+
+    # Configure Python native logger and transformers log level
+    # If CLI arg is passed, assign same log level to python native logger
+    log_level = "WARNING"
+    if train_args.log_level != "passive":
+        log_level = train_args.log_level
+
+    # If CLI arg is not passed and env var LOG_LEVEL is set,
+    # assign the same log level to both loggers
+    elif os.environ.get("LOG_LEVEL"):
+        log_level = os.environ.get("LOG_LEVEL")
+        train_args.log_level = (
+            log_level.lower()
+            if not os.environ.get("TRANSFORMERS_VERBOSITY")
+            else os.environ.get("TRANSFORMERS_VERBOSITY")
+        )
+
+    logging.basicConfig(
+        format="%(levelname)s:%(filename)s:%(message)s", level=log_level.upper()
+    )
+
+    if logger_name:
+        train_logger = logging.getLogger(logger_name)
+    else:
+        train_logger = logging.getLogger()
+    return train_args, train_logger
diff --git a/tuning/utils/preprocessing_utils.py b/tuning/utils/preprocessing_utils.py
index fef512fc4..68b2755d8 100644
--- a/tuning/utils/preprocessing_utils.py
+++ b/tuning/utils/preprocessing_utils.py
@@ -14,12 +14,12 @@
 # Standard
 from typing import Any, Callable, Dict, Optional, Union
 import json
+import logging
 
 # Third Party
 from datasets import Dataset, IterableDataset
 from datasets.exceptions import DatasetGenerationError
 from transformers import AutoTokenizer, DataCollatorForSeq2Seq
-from transformers.utils import logging
 from trl import DataCollatorForCompletionOnlyLM
 import datasets
 
@@ -27,8 +25,6 @@
 from tuning.config import configs
 from tuning.utils.data_utils import apply_custom_formatting_template
 
-logger = logging.get_logger("sft_trainer_preprocessing")
-
 # In future we may make the fields configurable
 JSON_INPUT_KEY = "input"
 JSON_OUTPUT_KEY = "output"
@@ -220,7 +218,7 @@ def format_dataset(
             tokenizer,
             data_args.data_formatter_template,
         )
-        logger.info("Training dataset length is %s", len(train_dataset))
+        logging.info("Training dataset length is %s", len(train_dataset))
         if data_args.validation_data_path:
             (eval_dataset) = get_formatted_dataset_with_single_sequence(
                 data_args.validation_data_path,
@@ -228,7 +226,7 @@ def format_dataset(
                 tokenizer,
                 data_args.data_formatter_template,
             )
-            logger.info("Validation dataset length is %s", len(eval_dataset))
+            logging.info("Validation dataset length is %s", len(eval_dataset))
     else:
         # This is for JSON containing input/output fields
         train_dataset = get_preprocessed_dataset(