Merged

41 commits
2b4ce8d
Removal of transformers logger and addition of python logger
Abhishek-TAMU Jul 30, 2024
76fffc5
FMT and lint check: Removal of transformers logger and addition of py…
Abhishek-TAMU Jul 30, 2024
773f0d1
fix: remove lm_head for granite with llama arch models (#258)
Ssukriti Jul 29, 2024
2e36147
Merge branch 'main' into fix_logging
Abhishek-TAMU Jul 30, 2024
8846bc2
Fix: Addition of env var TRANSFORMERS_VERBOSITY check
Abhishek-TAMU Jul 30, 2024
4ed3878
FMT Fix: Addition of env var TRANSFORMERS_VERBOSITY check
Abhishek-TAMU Jul 30, 2024
efb8363
Adding logging support to accelerate launch
Abhishek-TAMU Jul 30, 2024
2ecfaf7
FMT_FIX: Adding logging support to accelerate launch
Abhishek-TAMU Jul 30, 2024
f2e8afb
Merge branch 'main' into fix_logging
Abhishek-TAMU Jul 30, 2024
c44b1dc
Merge branch 'main' into fix_logging
Abhishek-TAMU Jul 31, 2024
abd8abc
Merge branch 'foundation-model-stack:main' into fix_logging
Abhishek-TAMU Aug 1, 2024
5d08efb
Logging changes and unit tests added
Abhishek-TAMU Aug 1, 2024
a4cc9c2
Merge branch 'foundation-model-stack:main' into fix_logging
Abhishek-TAMU Aug 1, 2024
7fffda7
Solved conflict with main
Abhishek-TAMU Aug 1, 2024
ba8a972
FMT:Fix Solved conflict with main
Abhishek-TAMU Aug 1, 2024
f1159f9
enabling tests for prompt tuning
Abhishek-TAMU Aug 1, 2024
756d097
Merge branch 'main' into main
anhuong Aug 1, 2024
bf30ed2
Merge remote-tracking branch 'upstream/main'
Abhishek-TAMU Aug 5, 2024
cb846ca
Merge branch 'main' of github.com:Abhishek-TAMU/fms-hf-tuning
Abhishek-TAMU Aug 5, 2024
da7acc6
merge with main
Abhishek-TAMU Aug 5, 2024
8af5792
PR changes for changing logger
Abhishek-TAMU Aug 5, 2024
697056c
Merge remote-tracking branch 'upstream/main'
Abhishek-TAMU Aug 5, 2024
768d93a
Merge branch 'main' into fix_logging
Abhishek-TAMU Aug 5, 2024
ba489b5
Unit Tests changes
Abhishek-TAMU Aug 5, 2024
dc9a521
commented os.environ[LOG_LEVEL] in accelerate.py for testing
Abhishek-TAMU Aug 6, 2024
f3a984a
Merge remote-tracking branch 'upstream/main'
Abhishek-TAMU Aug 6, 2024
024b12e
Merge branch 'main' into fix_logging
Abhishek-TAMU Aug 6, 2024
cfeb709
PR changes
Abhishek-TAMU Aug 6, 2024
bf36b36
FIX:FMT
Abhishek-TAMU Aug 6, 2024
fe4b6d5
PR Changes
Abhishek-TAMU Aug 6, 2024
c544f47
PR Changes
Abhishek-TAMU Aug 6, 2024
f0fcfdb
Merge remote-tracking branch 'upstream/main'
Abhishek-TAMU Aug 7, 2024
e65cb2d
Merge branch 'main' into fix_logging
Abhishek-TAMU Aug 7, 2024
4841119
PR Changes
Abhishek-TAMU Aug 7, 2024
fe8bb05
Merge remote-tracking branch 'upstream/main'
Abhishek-TAMU Aug 8, 2024
5ecf4dd
Metrics file epoch indexing from 0
Abhishek-TAMU Aug 8, 2024
89124ac
Revert last commit
Abhishek-TAMU Aug 8, 2024
bddad06
Merge remote-tracking branch 'upstream/main'
Abhishek-TAMU Aug 8, 2024
0777460
Merge branch 'main' into fix_logging
Abhishek-TAMU Aug 8, 2024
3068f51
PR Changes
Abhishek-TAMU Aug 8, 2024
0866bce
PR Changes
Abhishek-TAMU Aug 9, 2024
16 changes: 12 additions & 4 deletions build/accelerate_launch.py
@@ -61,9 +61,6 @@ def get_base_model_from_adapter_config(adapter_config):


def main():
    LOGLEVEL = os.environ.get("LOG_LEVEL", "WARNING").upper()
    logging.basicConfig(level=LOGLEVEL)

    if not os.getenv("TERMINATION_LOG_FILE"):
        os.environ["TERMINATION_LOG_FILE"] = ERROR_LOG

@@ -80,6 +77,18 @@ def main():
            or 'SFT_TRAINER_CONFIG_JSON_ENV_VAR'."
        )

        # Configure log_level of python native logger.
Collaborator: Let's also add a comment here that CLI arg takes precedence over env var. And if neither is set, we use default "WARNING".

Collaborator (Author): Thank you! This is done.

        # CLI arg takes precedence over env var. And if neither is set, we use default "WARNING"
        log_level = job_config.get(
            "log_level"
        )  # this will be set to either the value found or None
        if (
            not log_level
        ):  # if log level not set by job_config aka by JSON, set it via env var or set default
            log_level = os.environ.get("LOG_LEVEL", "WARNING")
        log_level = log_level.upper()
        logging.basicConfig(level=log_level)

        args = process_accelerate_launch_args(job_config)
        logging.debug("accelerate launch parsed args: %s", args)
    except FileNotFoundError as e:
@@ -109,7 +118,6 @@ def main():
job_config["output_dir"] = tempdir
updated_args = serialize_args(job_config)
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = updated_args

launch_command(args)
except subprocess.CalledProcessError as e:
# If the subprocess throws an exception, the base exception is hidden in the
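With this change the launch script resolves verbosity in one place: the JSON/CLI `log_level` wins, then the `LOG_LEVEL` env var, then the default `"WARNING"`. A standalone sketch of the rule (the `job_config` value here is made up; the real one comes from `SFT_TRAINER_CONFIG_JSON_PATH` or `SFT_TRAINER_CONFIG_JSON_ENV_VAR`):

```python
# Minimal sketch of the precedence rule added above; values are illustrative.
import logging
import os

job_config = {"log_level": "debug"}  # pretend this was parsed from the JSON config

# JSON/CLI value wins; fall back to the env var; default to WARNING.
log_level = job_config.get("log_level") or os.environ.get("LOG_LEVEL", "WARNING")
logging.basicConfig(level=log_level.upper())

logging.debug("visible, since the resolved level is DEBUG")
```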
84 changes: 84 additions & 0 deletions tests/utils/test_logging.py
@@ -0,0 +1,84 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Standard
from unittest import mock
import copy
import logging
import os

# First Party
from tests.test_sft_trainer import TRAIN_ARGS

# Local
from tuning.utils.logging import set_log_level


@mock.patch.dict(os.environ, {}, clear=True)
def test_set_log_level_for_logger_default():
    """
    Ensure that the correct log level is being set for python native logger and
    transformers logger when no env var or CLI flag is passed
    """

    train_args = copy.deepcopy(TRAIN_ARGS)
    training_args, logger = set_log_level(train_args)
    assert logger.getEffectiveLevel() == logging.WARNING
    assert training_args.log_level == "passive"


@mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True)
def test_set_log_level_for_logger_with_env_var():
    """
    Ensure that the correct log level is being set for python native logger and
    transformers logger when env var LOG_LEVEL is used
    """

    train_args_env = copy.deepcopy(TRAIN_ARGS)
    training_args, logger = set_log_level(train_args_env)
    assert logger.getEffectiveLevel() == logging.INFO
    assert training_args.log_level == "info"


@mock.patch.dict(os.environ, {"TRANSFORMERS_VERBOSITY": "info"}, clear=True)
def test_set_log_level_for_logger_with_set_verbosity_and_cli():
    """
    Ensure that the correct log level is being set for python native logger and
    log_level of transformers logger is unchanged when env var TRANSFORMERS_VERBOSITY is used
    and CLI flag is passed
    """

    train_args = copy.deepcopy(TRAIN_ARGS)
    train_args.log_level = "error"
    training_args, logger = set_log_level(train_args)
    assert logger.getEffectiveLevel() == logging.ERROR
    assert training_args.log_level == "error"


@mock.patch.dict(os.environ, {"LOG_LEVEL": "info"}, clear=True)
def test_set_log_level_for_logger_with_env_var_and_cli():
    """
    Ensure that the correct log level is being set for python native logger and
    transformers logger when env var LOG_LEVEL is used and CLI flag is passed.
    In this case, CLI arg takes precedence over the set env var LOG_LEVEL.
    """

    train_args = copy.deepcopy(TRAIN_ARGS)
    train_args.log_level = "error"
    training_args, logger = set_log_level(train_args)
    assert logger.getEffectiveLevel() == logging.ERROR
    assert training_args.log_level == "error"
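The `tuning.utils.logging.set_log_level` helper exercised by these tests is added elsewhere in this PR and its body is not part of this excerpt. A minimal sketch consistent with the asserted behavior (hypothetical; the real module may differ):

```python
# Hypothetical reconstruction of tuning/utils/logging.py, inferred from the
# tests above. Precedence: CLI/JSON log_level > LOG_LEVEL env var > "WARNING".
import logging
import os


def set_log_level(train_args, logger_name=None):
    """Resolve the log level and return the updated args and a configured logger."""
    log_level = "WARNING"  # default when neither CLI flag nor env var is set

    if train_args.log_level != "passive":
        # CLI/JSON flag takes precedence over the env var
        log_level = train_args.log_level
    elif os.environ.get("LOG_LEVEL"):
        log_level = os.environ["LOG_LEVEL"]
        # propagate to the transformers logger unless TRANSFORMERS_VERBOSITY
        # was set explicitly by the user
        if not os.environ.get("TRANSFORMERS_VERBOSITY"):
            train_args.log_level = log_level.lower()

    logging.basicConfig(level=log_level.upper())
    logger = logging.getLogger(logger_name) if logger_name else logging.getLogger()
    logger.setLevel(log_level.upper())
    return train_args, logger
```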
9 changes: 9 additions & 0 deletions tuning/config/configs.py
@@ -136,6 +136,15 @@ class TrainingArguments(transformers.TrainingArguments):
+ "Requires additional configs, see tuning.configs/tracker_configs.py"
},
)
log_level: str = field(
default="passive",
metadata={
"help": "The log level to adopt during training. \
By default, 'passive' level is set which keeps the \
current log level for the Transformers library (which will be 'warning` by default) \
Other possible values are 'debug', 'info', 'warning', 'error' and 'critical'"
},
)


@dataclass
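For a quick sanity check of the new field (a hedged sketch; the argv values are illustrative only), the flag parses through `transformers.HfArgumentParser` the same way the rest of `TrainingArguments` does:

```python
# Sketch: exercising the new log_level flag via HfArgumentParser.
# "/tmp/out" and the argv list are illustrative values only.
from transformers import HfArgumentParser

from tuning.config.configs import TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "/tmp/out", "--log_level", "debug"]
)
print(training_args.log_level)  # -> "debug"
```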
12 changes: 7 additions & 5 deletions tuning/sft_trainer.py
@@ -33,7 +33,7 @@
    LlamaTokenizerFast,
    TrainerCallback,
)
from transformers.utils import is_accelerate_available, logging
from transformers.utils import is_accelerate_available
from trl import SFTConfig, SFTTrainer
import fire
import transformers
@@ -60,6 +60,7 @@
    USER_ERROR_EXIT_CODE,
    write_termination_log,
)
from tuning.utils.logging import set_log_level
from tuning.utils.preprocessing_utils import (
    format_dataset,
    get_data_collator,
@@ -111,7 +112,7 @@ def train(
    fused_lora and fast_kernels must be used together (may change in future). \
    """

logger = logging.get_logger("sft_trainer")
train_args, logger = set_log_level(train_args, "sft_trainer_train")

# Validate parameters
if (not isinstance(train_args.num_train_epochs, (float, int))) or (
@@ -479,11 +480,8 @@


def main(**kwargs):  # pylint: disable=unused-argument
    logger = logging.get_logger("__main__")

    parser = get_parser()
    job_config = get_json_config()
    logger.debug("Input args parsed: %s", job_config)
    # accept arguments via command-line or JSON
    try:
        (
@@ -498,6 +496,10 @@ def main(**kwargs):  # pylint: disable=unused-argument
            fusedops_kernels_config,
            exp_metadata,
        ) = parse_arguments(parser, job_config)

        # Set log level of python native logger and transformers training logger
        training_args, logger = set_log_level(training_args, __name__)

        logger.debug(
            "Input args parsed: \
            model_args %s, data_args %s, training_args %s, trainer_controller_args %s, \
5 changes: 3 additions & 2 deletions tuning/trackers/aimstack_tracker.py
@@ -14,11 +14,11 @@

# Standard
import json
import logging
import os

# Third Party
from aim.hugging_face import AimCallback # pylint: disable=import-error
from transformers.utils import logging

# Local
from .tracker import Tracker
@@ -99,7 +99,8 @@ def __init__(self, tracker_config: AimConfig):
        information about the repo or the server and port where aim db is present.
        """
        super().__init__(name="aim", tracker_config=tracker_config)
        self.logger = logging.get_logger("aimstack_tracker")
        # Get logger with root log level
        self.logger = logging.getLogger()
Collaborator: @Abhishek-TAMU any particular reason to remove the module name from the logger? I had kept it to segregate logs from the various subcomponents.

Collaborator: Same comment for file_logging_tracker.

Collaborator: Explained in the comment above; no need for further explanation.

Collaborator (Author): Cool. Thank you!

Collaborator: I still don't quite understand this. You're setting log_level at runtime to restrict the amount of logs that a user wants to see (debug vs info etc.). But if you don't do logging.get_logger("aimstack_tracker") and so on in each file, all the logs will say they come from root, not the specific module; helpful, but not as helpful as if the module name were also included. Maybe it'd be clearer if you could share an example run of what the logs look like if a user sets debug or error, for example, so we can see what the logs look like.
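For illustration of the point being discussed, here is a hypothetical snippet (not from the PR) showing how record names differ between the root logger and a named logger when the format string includes `%(name)s`:

```python
# Hypothetical example: root logger vs. named logger record names.
import logging

logging.basicConfig(
    level=logging.DEBUG, format="%(levelname)s:%(name)s:%(message)s"
)

logging.getLogger().debug("from the root logger")
# prints: DEBUG:root:from the root logger

logging.getLogger("aimstack_tracker").debug("from a named logger")
# prints: DEBUG:aimstack_tracker:from a named logger
```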


    def get_hf_callback(self):
        """Returns the aim.hugging_face.AimCallback object associated with this tracker.
5 changes: 3 additions & 2 deletions tuning/trackers/filelogging_tracker.py
@@ -15,11 +15,11 @@
# Standard
from datetime import datetime
import json
import logging
import os

# Third Party
from transformers import TrainerCallback
from transformers.utils import logging

# Local
from .tracker import Tracker
@@ -80,7 +80,8 @@ def __init__(self, tracker_config: FileLoggingTrackerConfig):
        which contains the location of file where logs are recorded.
        """
        super().__init__(name="file_logger", tracker_config=tracker_config)
        self.logger = logging.get_logger("file_logging_tracker")
        # Get logger with root log level
        self.logger = logging.getLogger()

    def get_hf_callback(self):
        """Returns the FileLoggingCallback object associated with this tracker.
15 changes: 6 additions & 9 deletions tuning/trackers/tracker_factory.py
@@ -14,18 +14,15 @@

# Standard
import dataclasses
import logging

# Third Party
from transformers.utils import logging
from transformers.utils.import_utils import _is_package_available

# Local
from .filelogging_tracker import FileLoggingTracker
from tuning.config.tracker_configs import FileLoggingTrackerConfig, TrackerConfigFactory

logger = logging.get_logger("tracker_factory")


# Information about all registered trackers
AIMSTACK_TRACKER = "aim"
FILE_LOGGING_TRACKER = "file_logger"
@@ -54,9 +51,9 @@ def _register_aim_tracker():
        AimTracker = _get_tracker_class(AimStackTracker, AimConfig)

        REGISTERED_TRACKERS[AIMSTACK_TRACKER] = AimTracker
        logger.info("Registered aimstack tracker")
        logging.info("Registered aimstack tracker")
    else:
        logger.info(
        logging.info(
            "Not registering Aimstack tracker due to unavailability of package.\n"
            "Please install aim if you intend to use it.\n"
            "\t pip install aim"
@@ -72,14 +69,14 @@
def _register_file_logging_tracker():
    FileTracker = _get_tracker_class(FileLoggingTracker, FileLoggingTrackerConfig)
    REGISTERED_TRACKERS[FILE_LOGGING_TRACKER] = FileTracker
    logger.info("Registered file logging tracker")
    logging.info("Registered file logging tracker")


# List of Available Trackers
# file_logger - Logs loss to a file
# aim - Aimstack Tracker
def _register_trackers():
    logger.info("Registering trackers")
    logging.info("Registering trackers")
    if AIMSTACK_TRACKER not in REGISTERED_TRACKERS:
        _register_aim_tracker()
    if FILE_LOGGING_TRACKER not in REGISTERED_TRACKERS:
@@ -142,7 +139,7 @@ def get_tracker(name: str, tracker_configs: TrackerConfigFactory):
e = "Requested Tracker {} not found. List trackers available for use is - {} ".format(
name, available
)
logger.error(e)
logging.error(e)
raise ValueError(e)

meta = REGISTERED_TRACKERS[name]
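One side effect worth noting (our observation, not raised in the PR discussion): the module-level `logging.info(...)` calls above run at tracker-registration time, and if nothing has called `logging.basicConfig` yet, the first call implicitly configures the root logger at WARNING, so those INFO messages are dropped. A small sketch:

```python
# Sketch: logging.info() before any basicConfig call is silently dropped,
# because the implicit root configuration defaults to level WARNING.
import logging

logging.info("Registered file logging tracker")  # not printed

logging.basicConfig(level=logging.INFO, force=True)  # force requires Python 3.8+
logging.info("Registered file logging tracker")  # printed now
```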
5 changes: 0 additions & 5 deletions tuning/trainercontroller/controllermetrics/eval_metrics.py
@@ -18,14 +18,9 @@
# Standard
from typing import Any

# Third Party
from transformers.utils import logging

# Local
from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler

logger = logging.get_logger(__name__)


class EvalMetrics(MetricHandler):
    """Implements the controller metric which exposes the evaluation metrics"""
@@ -21,12 +21,10 @@

# Third Party
from transformers import TrainerState
from transformers.utils import logging

# Local
from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler

logger = logging.get_logger(__name__)
METRICS_KEY = "metrics"
LOG_LOSS_KEY = "loss"
TRAINING_LOSS_KEY = "training_loss"
6 changes: 2 additions & 4 deletions tuning/trainercontroller/operations/hfcontrols.py
@@ -1,17 +1,15 @@
# Standard
from dataclasses import fields
import inspect
import logging
import re

# Third Party
from transformers import TrainerControl
from transformers.utils import logging

# Local
from .operation import Operation

logger = logging.get_logger(__name__)


class HFControls(Operation):
    """Implements the control actions for the HuggingFace controls in
@@ -39,7 +37,7 @@ def control_action(self, control: TrainerControl, **kwargs):
            control: TrainerControl. Data class for controls.
            kwargs: List of arguments (key, value)-pairs
        """
        logger.debug("Arguments passed to control_action: %s", repr(kwargs))
        logging.debug("Arguments passed to control_action: %s", repr(kwargs))
        frame_info = inspect.currentframe().f_back
        arg_values = inspect.getargvalues(frame_info)
        setattr(control, arg_values.locals["action"], True)
10 changes: 4 additions & 6 deletions tuning/trainercontroller/patience.py
@@ -15,8 +15,8 @@
# SPDX-License-Identifier: Apache-2.0
# https://spdx.dev/learn/handling-license-info/

# Third Party
from transformers.utils import logging
# Standard
import logging

# Resets the patience if the rule outcome happens to be false.
# Here, the expectation is to have unbroken "True"s for patience
@@ -31,8 +31,6 @@
# will be exceeded after the fifth event.
MODE_NO_RESET_ON_FAILURE = "no_reset_on_failure"

logger = logging.get_logger(__name__)


class PatienceControl:
    """Implements the patience control for every rule"""
@@ -51,7 +49,7 @@ def should_tolerate(
        elif self._mode == MODE_RESET_ON_FAILURE:
            self._patience_counter = 0
        if self._patience_counter <= self._patience_threshold:
            logger.debug(
            logging.debug(
                "Control {} triggered on event {}: "
                "Enforcing patience [patience_counter = {:.2f}, "
                "patience_threshold = {:.2f}]".format(
@@ -62,7 +60,7 @@
                )
            )
            return True
        logger.debug(
        logging.debug(
            "Control {} triggered on event {}: "
            "Exceeded patience [patience_counter = {:.2f}, "
            "patience_threshold = {:.2f}]".format(