From b36364e65781089463a41f55740bf3b3c4187019 Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Thu, 25 Jul 2024 01:32:33 -0400 Subject: [PATCH 01/13] feat: logging control operation Signed-off-by: Padmanabha V Seshadri --- .../log_controller.yaml | 16 +++++ tests/data/trainercontroller/__init__.py | 1 + .../trainercontroller/log_controller.yaml | 16 +++++ .../test_tuning_trainercontroller.py | 20 ++++-- .../controllermetrics/trainingstate.py | 6 +- .../trainercontroller/operations/__init__.py | 2 + .../operations/logcontrol.py | 61 +++++++++++++++++++ .../trainercontroller/operations/operation.py | 10 +++ 8 files changed, 126 insertions(+), 6 deletions(-) create mode 100644 examples/trainercontroller_configs/log_controller.yaml create mode 100644 tests/data/trainercontroller/log_controller.yaml create mode 100644 tuning/trainercontroller/operations/logcontrol.py diff --git a/examples/trainercontroller_configs/log_controller.yaml b/examples/trainercontroller_configs/log_controller.yaml new file mode 100644 index 000000000..8bda95355 --- /dev/null +++ b/examples/trainercontroller_configs/log_controller.yaml @@ -0,0 +1,16 @@ +controller_metrics: + - name: state + class: TrainingState +operations: + - name: logcontrolstep + class: LogControl + arguments: + log_format: 'This is a test log format [{event_name}] => {tc_metrics[state]}' + log_level: WARNING +controllers: + - name: log-controller-step + triggers: + - on_step_end + rule: 'True' + operations: + - logcontrolstep.should_log \ No newline at end of file diff --git a/tests/data/trainercontroller/__init__.py b/tests/data/trainercontroller/__init__.py index 35f4f13c9..18035f102 100644 --- a/tests/data/trainercontroller/__init__.py +++ b/tests/data/trainercontroller/__init__.py @@ -78,3 +78,4 @@ _DATA_DIR, "thresholded-training-loss.yaml" ) TRAINER_CONFIG_TEST_ON_SAVE_YAML = os.path.join(_DATA_DIR, "on-save.yaml") +TRAINER_CONFIG_LOG_CONTROLLER_YAML = os.path.join(_DATA_DIR, "log_controller.yaml") diff --git a/tests/data/trainercontroller/log_controller.yaml b/tests/data/trainercontroller/log_controller.yaml new file mode 100644 index 000000000..8bda95355 --- /dev/null +++ b/tests/data/trainercontroller/log_controller.yaml @@ -0,0 +1,16 @@ +controller_metrics: + - name: state + class: TrainingState +operations: + - name: logcontrolstep + class: LogControl + arguments: + log_format: 'This is a test log format [{event_name}] => {tc_metrics[state]}' + log_level: WARNING +controllers: + - name: log-controller-step + triggers: + - on_step_end + rule: 'True' + operations: + - logcontrolstep.should_log \ No newline at end of file diff --git a/tests/trainercontroller/test_tuning_trainercontroller.py b/tests/trainercontroller/test_tuning_trainercontroller.py index c4464da89..3c50eb07f 100644 --- a/tests/trainercontroller/test_tuning_trainercontroller.py +++ b/tests/trainercontroller/test_tuning_trainercontroller.py @@ -137,7 +137,6 @@ def test_thresholded_training_loss(): tc_callback.on_log(args=test_data.args, state=test_data.states[2], control=control) assert control.should_training_stop is True - def test_thresholded_training_loss_on_save(): """Tests the thresholded training loss example in `examples/trainer-controller-configs/on-save.yaml` @@ -145,14 +144,25 @@ def test_thresholded_training_loss_on_save(): test_data = _setup_data() tc_callback = tc.TrainerControllerCallback(td.TRAINER_CONFIG_TEST_ON_SAVE_YAML) control = TrainerControl(should_training_stop=False) - # Trigger on_init_end to perform registration of handlers to events - tc_callback.on_init_end( - args=test_data.args, state=test_data.states[2], control=control - ) # Trigger rule and test the condition tc_callback.on_save(args=test_data.args, state=test_data.states[2], control=control) assert control.should_training_stop is True +def test_log_controller(caplog): + """Tests the expose metric scenario example in + `examples/trainer-controller-configs/log_controller.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback(td.TRAINER_CONFIG_LOG_CONTROLLER_YAML) + control = TrainerControl(should_log=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) + tc_callback.on_step_end( + args=test_data.args, state=test_data.states[2], control=control + ) + assert "This is a test log format" in caplog.text def test_non_decreasing_training_loss(): """Tests the non-decreasing training loss example in diff --git a/tuning/trainercontroller/controllermetrics/trainingstate.py b/tuning/trainercontroller/controllermetrics/trainingstate.py index 59ab3638c..883bcc7cc 100644 --- a/tuning/trainercontroller/controllermetrics/trainingstate.py +++ b/tuning/trainercontroller/controllermetrics/trainingstate.py @@ -21,10 +21,13 @@ # Third Party from transformers import TrainerState +from transformers.utils import logging # Local from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler +logger = logging.get_logger(__name__) + class TrainingState(MetricHandler): """Implements the controller metric which exposes the trainer state""" @@ -49,7 +52,7 @@ def __init__(self, **kwargs): "on_train_begin", "on_evaluate", ], - **kwargs + **kwargs, ) def validate(self) -> bool: @@ -71,4 +74,5 @@ def compute(self, state: TrainerState = None, **kwargs) -> Any: Returns: dict. Trainer state as a dictionary """ + logger.warning(f"Trainer state: {dataclasses.asdict(state)}") return dataclasses.asdict(state) diff --git a/tuning/trainercontroller/operations/__init__.py b/tuning/trainercontroller/operations/__init__.py index 99456d7ec..c0253d8f4 100644 --- a/tuning/trainercontroller/operations/__init__.py +++ b/tuning/trainercontroller/operations/__init__.py @@ -3,6 +3,7 @@ # Local from .hfcontrols import HFControls +from .logcontrol import LogControl from .operation import Operation # List of operation handlers @@ -20,3 +21,4 @@ def register(cl: Type): # Register the default operation handlers in this package here register(HFControls) +register(LogControl) diff --git a/tuning/trainercontroller/operations/logcontrol.py b/tuning/trainercontroller/operations/logcontrol.py new file mode 100644 index 000000000..67f2eebef --- /dev/null +++ b/tuning/trainercontroller/operations/logcontrol.py @@ -0,0 +1,61 @@ +# Third Party +from transformers import TrainingArguments +from transformers.utils import logging + +# Local +from .operation import Operation + +logger = logging.get_logger(__name__) +logger.setLevel(level=logging.DEBUG) + +VALID_LOG_LEVELS = { + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, +} + + +class LogControl(Operation): + """Operation that can be used to log useful information on specific events.""" + + def __init__(self, log_format: str, log_level: str, **kwargs): + """Initializes the HuggingFace controls. In this init, the fields with `should_` of the + transformers.TrainerControl data class are extracted, and for each of those fields, the + control_action() method's pointer is set, and injected as a class member function. + + Args: + kwargs: List of arguments (key, value)-pairs + """ + if log_level not in VALID_LOG_LEVELS: + raise ValueError( + "Specified log_level [%s] is invalid for LogControl" % (log_level) + ) + self.log_level = VALID_LOG_LEVELS[log_level] + self.log_format = log_format + super().__init__(**kwargs) + + def should_log( + self, + event_name: str, + tc_metrics: dict, + args: TrainingArguments, + **kwargs, + ): + """This method peeks into the stack-frame of the caller to get the action the triggered + a call to it. Using the name of the action, the value of the control is set. + + Args: + control: TrainerControl. Data class for controls. + kwargs: List of arguments (key, value)-pairs + """ + log_msg = self.log_format.format( + event_name=event_name, + tc_metrics=tc_metrics, + args=args, + **kwargs, + ) + logger.log( + self.log_level, + log_msg, + ) diff --git a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py index 6e6d764fb..11e359213 100644 --- a/tuning/trainercontroller/operations/operation.py +++ b/tuning/trainercontroller/operations/operation.py @@ -17,6 +17,8 @@ def __init__(self, name: str, **kwargs): every action should preceed with prefix `should_`. If so, it is treated as a valid action. """ + self._name = name + self.kwargs = kwargs self.valid_actions = {} self.name = name self.kwargs = kwargs @@ -26,6 +28,14 @@ def __init__(self, name: str, **kwargs): if re.search(r"^should_.+", action_name) is not None: self.valid_actions[action_name] = action_method + def get_name(self) -> str: + """Returns the name of the operation. + + Returns: + str + """ + return self._name + def validate(self, action: str) -> bool: """Validates the action by checking if it valid action or not. From 2faee059dad3157bf902e84ae2c70ccae4537cbf Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Thu, 25 Jul 2024 01:38:38 -0400 Subject: [PATCH 02/13] fix: Removed unwanted warning Signed-off-by: Padmanabha V Seshadri --- tuning/trainercontroller/controllermetrics/trainingstate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tuning/trainercontroller/controllermetrics/trainingstate.py b/tuning/trainercontroller/controllermetrics/trainingstate.py index 883bcc7cc..8dc276339 100644 --- a/tuning/trainercontroller/controllermetrics/trainingstate.py +++ b/tuning/trainercontroller/controllermetrics/trainingstate.py @@ -74,5 +74,4 @@ def compute(self, state: TrainerState = None, **kwargs) -> Any: Returns: dict. Trainer state as a dictionary """ - logger.warning(f"Trainer state: {dataclasses.asdict(state)}") return dataclasses.asdict(state) From 304b5c1b508a096d9ab4214656070a729247ffef Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Mon, 29 Jul 2024 01:36:09 -0400 Subject: [PATCH 03/13] fix: Replaced log_levels with one from package Signed-off-by: Padmanabha V Seshadri --- tuning/trainercontroller/operations/logcontrol.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tuning/trainercontroller/operations/logcontrol.py b/tuning/trainercontroller/operations/logcontrol.py index 67f2eebef..490cb8249 100644 --- a/tuning/trainercontroller/operations/logcontrol.py +++ b/tuning/trainercontroller/operations/logcontrol.py @@ -8,14 +8,6 @@ logger = logging.get_logger(__name__) logger.setLevel(level=logging.DEBUG) -VALID_LOG_LEVELS = { - "ERROR": logging.ERROR, - "WARNING": logging.WARNING, - "INFO": logging.INFO, - "DEBUG": logging.DEBUG, -} - - class LogControl(Operation): """Operation that can be used to log useful information on specific events.""" @@ -27,11 +19,12 @@ def __init__(self, log_format: str, log_level: str, **kwargs): Args: kwargs: List of arguments (key, value)-pairs """ - if log_level not in VALID_LOG_LEVELS: + log_levels = logging.get_log_levels_dict() + if log_level not in log_levels: raise ValueError( "Specified log_level [%s] is invalid for LogControl" % (log_level) ) - self.log_level = VALID_LOG_LEVELS[log_level] + self.log_level = log_levels[log_level] self.log_format = log_format super().__init__(**kwargs) From 925bc71bf146c899ee93e5efc4b47db94c2b5c20 Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Mon, 29 Jul 2024 01:43:14 -0400 Subject: [PATCH 04/13] fix: Formatting issues resolved Signed-off-by: Padmanabha V Seshadri --- tuning/trainercontroller/operations/logcontrol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tuning/trainercontroller/operations/logcontrol.py b/tuning/trainercontroller/operations/logcontrol.py index 490cb8249..dc2ba67bd 100644 --- a/tuning/trainercontroller/operations/logcontrol.py +++ b/tuning/trainercontroller/operations/logcontrol.py @@ -8,6 +8,7 @@ logger = logging.get_logger(__name__) logger.setLevel(level=logging.DEBUG) + class LogControl(Operation): """Operation that can be used to log useful information on specific events.""" From 73da0b62c7150e87c00b30647ed31b89f0f99f72 Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Mon, 29 Jul 2024 01:46:49 -0400 Subject: [PATCH 05/13] fix: log_level value should be small-case Signed-off-by: Padmanabha V Seshadri --- tests/data/trainercontroller/log_controller.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/trainercontroller/log_controller.yaml b/tests/data/trainercontroller/log_controller.yaml index 8bda95355..da7bb8a4a 100644 --- a/tests/data/trainercontroller/log_controller.yaml +++ b/tests/data/trainercontroller/log_controller.yaml @@ -6,7 +6,7 @@ operations: class: LogControl arguments: log_format: 'This is a test log format [{event_name}] => {tc_metrics[state]}' - log_level: WARNING + log_level: warning controllers: - name: log-controller-step triggers: From 29bbb60d74b63cf886aa29b9035e83b192983172 Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Mon, 5 Aug 2024 01:53:04 -0400 Subject: [PATCH 06/13] fix: Format issues resolved Signed-off-by: Padmanabha V Seshadri --- tests/trainercontroller/test_tuning_trainercontroller.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/trainercontroller/test_tuning_trainercontroller.py b/tests/trainercontroller/test_tuning_trainercontroller.py index 3c50eb07f..f070e999e 100644 --- a/tests/trainercontroller/test_tuning_trainercontroller.py +++ b/tests/trainercontroller/test_tuning_trainercontroller.py @@ -137,6 +137,7 @@ def test_thresholded_training_loss(): tc_callback.on_log(args=test_data.args, state=test_data.states[2], control=control) assert control.should_training_stop is True + def test_thresholded_training_loss_on_save(): """Tests the thresholded training loss example in `examples/trainer-controller-configs/on-save.yaml` @@ -148,6 +149,7 @@ def test_thresholded_training_loss_on_save(): tc_callback.on_save(args=test_data.args, state=test_data.states[2], control=control) assert control.should_training_stop is True + def test_log_controller(caplog): """Tests the expose metric scenario example in `examples/trainer-controller-configs/log_controller.yaml` @@ -164,6 +166,7 @@ def test_log_controller(caplog): ) assert "This is a test log format" in caplog.text + def test_non_decreasing_training_loss(): """Tests the non-decreasing training loss example in `examples/trainer-controller-configs/non-decreasing-training-loss.yaml` From 070ff219308e45f9ea4c6f591ad56148f56dedc0 Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Tue, 6 Aug 2024 01:29:42 -0400 Subject: [PATCH 07/13] fix: Arguments reordered to support folding Signed-off-by: Padmanabha V Seshadri --- tests/trainercontroller/test_tuning_trainercontroller.py | 4 ++++ tuning/trainercontroller/callback.py | 2 +- tuning/trainercontroller/operations/logcontrol.py | 6 ++++-- tuning/trainercontroller/operations/operation.py | 2 +- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/trainercontroller/test_tuning_trainercontroller.py b/tests/trainercontroller/test_tuning_trainercontroller.py index f070e999e..ba1a05808 100644 --- a/tests/trainercontroller/test_tuning_trainercontroller.py +++ b/tests/trainercontroller/test_tuning_trainercontroller.py @@ -145,6 +145,10 @@ def test_thresholded_training_loss_on_save(): test_data = _setup_data() tc_callback = tc.TrainerControllerCallback(td.TRAINER_CONFIG_TEST_ON_SAVE_YAML) control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) # Trigger rule and test the condition tc_callback.on_save(args=test_data.args, state=test_data.states[2], control=control) assert control.should_training_stop is True diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index 2c5ddfcdf..fe8072804 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -255,8 +255,8 @@ def _take_control_actions(self, event_name: str, **kwargs): for operation_action in control_action.operation_actions: operation_action.instance.act( action=operation_action.action, - event_name=event_name, tc_metrics=self.metrics, + event_name=event_name, control_name=control_action.name, log_level=control_action.config[ CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL diff --git a/tuning/trainercontroller/operations/logcontrol.py b/tuning/trainercontroller/operations/logcontrol.py index dc2ba67bd..24f1777f5 100644 --- a/tuning/trainercontroller/operations/logcontrol.py +++ b/tuning/trainercontroller/operations/logcontrol.py @@ -31,9 +31,11 @@ def __init__(self, log_format: str, log_level: str, **kwargs): def should_log( self, - event_name: str, tc_metrics: dict, - args: TrainingArguments, + event_name: str=None, + control_name: str=None, + log_level: str=logging.DEBUG, + args: TrainingArguments=None, **kwargs, ): """This method peeks into the stack-frame of the caller to get the action the triggered diff --git a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py index 11e359213..7d3d1bfce 100644 --- a/tuning/trainercontroller/operations/operation.py +++ b/tuning/trainercontroller/operations/operation.py @@ -48,7 +48,7 @@ def validate(self, action: str) -> bool: return action in self.valid_actions def act( - self, action: str, event_name: str, control_name: str, log_level: int, **kwargs + self, action: str, event_name: str=None, control_name: str=None, log_level: int=logging.DEBUG, **kwargs ): """Validates the action and invokes it. From c07bd28dfe3683e9f5c846bfa34f224209b8916a Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Tue, 6 Aug 2024 01:30:45 -0400 Subject: [PATCH 08/13] fix: Formatting issues resolved Signed-off-by: Padmanabha V Seshadri --- tuning/trainercontroller/operations/logcontrol.py | 8 ++++---- tuning/trainercontroller/operations/operation.py | 7 ++++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tuning/trainercontroller/operations/logcontrol.py b/tuning/trainercontroller/operations/logcontrol.py index 24f1777f5..775734fd0 100644 --- a/tuning/trainercontroller/operations/logcontrol.py +++ b/tuning/trainercontroller/operations/logcontrol.py @@ -32,10 +32,10 @@ def __init__(self, log_format: str, log_level: str, **kwargs): def should_log( self, tc_metrics: dict, - event_name: str=None, - control_name: str=None, - log_level: str=logging.DEBUG, - args: TrainingArguments=None, + event_name: str = None, + control_name: str = None, + log_level: str = logging.DEBUG, + args: TrainingArguments = None, **kwargs, ): """This method peeks into the stack-frame of the caller to get the action the triggered diff --git a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py index 7d3d1bfce..03dcb55c1 100644 --- a/tuning/trainercontroller/operations/operation.py +++ b/tuning/trainercontroller/operations/operation.py @@ -48,7 +48,12 @@ def validate(self, action: str) -> bool: return action in self.valid_actions def act( - self, action: str, event_name: str=None, control_name: str=None, log_level: int=logging.DEBUG, **kwargs + self, + action: str, + event_name: str = None, + control_name: str = None, + log_level: int = logging.DEBUG, + **kwargs, ): """Validates the action and invokes it. From 021e7c79460e240d094516ed62eec350cf28f2dc Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Tue, 6 Aug 2024 01:37:05 -0400 Subject: [PATCH 09/13] fix: Reordered arguments in logcontrol Signed-off-by: Padmanabha V Seshadri --- tuning/trainercontroller/callback.py | 6 +++--- tuning/trainercontroller/operations/logcontrol.py | 2 +- tuning/trainercontroller/operations/operation.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index fe8072804..0200d760f 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -255,12 +255,12 @@ def _take_control_actions(self, event_name: str, **kwargs): for operation_action in control_action.operation_actions: operation_action.instance.act( action=operation_action.action, - tc_metrics=self.metrics, - event_name=event_name, - control_name=control_action.name, log_level=control_action.config[ CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL ], + tc_metrics=self.metrics, + event_name=event_name, + control_name=control_action.name, **kwargs, ) diff --git a/tuning/trainercontroller/operations/logcontrol.py b/tuning/trainercontroller/operations/logcontrol.py index 775734fd0..6d05e9054 100644 --- a/tuning/trainercontroller/operations/logcontrol.py +++ b/tuning/trainercontroller/operations/logcontrol.py @@ -34,7 +34,6 @@ def should_log( tc_metrics: dict, event_name: str = None, control_name: str = None, - log_level: str = logging.DEBUG, args: TrainingArguments = None, **kwargs, ): @@ -47,6 +46,7 @@ def should_log( """ log_msg = self.log_format.format( event_name=event_name, + control_name=control_name, tc_metrics=tc_metrics, args=args, **kwargs, diff --git a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py index 03dcb55c1..b8c87a738 100644 --- a/tuning/trainercontroller/operations/operation.py +++ b/tuning/trainercontroller/operations/operation.py @@ -50,9 +50,9 @@ def validate(self, action: str) -> bool: def act( self, action: str, + log_level: int, event_name: str = None, control_name: str = None, - log_level: int = logging.DEBUG, **kwargs, ): """Validates the action and invokes it. From 006a3501ae7cf9c431a53469cf6e4692538a11a6 Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Tue, 6 Aug 2024 01:59:02 -0400 Subject: [PATCH 10/13] fix: Metrics flattened and passed to operations Signed-off-by: Padmanabha V Seshadri --- tests/data/trainercontroller/log_controller.yaml | 4 ++-- .../loss_on_threshold_with_trainer_state.yaml | 4 ++-- tests/data/trainercontroller/loss_with_invalid_type_rule.yaml | 2 +- tests/data/trainercontroller/on-save.yaml | 4 ++-- tuning/trainercontroller/callback.py | 2 +- tuning/trainercontroller/operations/logcontrol.py | 2 -- 6 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/data/trainercontroller/log_controller.yaml b/tests/data/trainercontroller/log_controller.yaml index da7bb8a4a..0becdc7e2 100644 --- a/tests/data/trainercontroller/log_controller.yaml +++ b/tests/data/trainercontroller/log_controller.yaml @@ -1,11 +1,11 @@ controller_metrics: - - name: state + - name: trainer_state class: TrainingState operations: - name: logcontrolstep class: LogControl arguments: - log_format: 'This is a test log format [{event_name}] => {tc_metrics[state]}' + log_format: 'This is a test log format [{event_name}] => {trainer_state}' log_level: warning controllers: - name: log-controller-step diff --git a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml index 9dc764c4e..cb9bcf957 100644 --- a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml +++ b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: state + - name: trainer_state class: TrainingState - name: training_loss class: Loss @@ -7,6 +7,6 @@ controllers: - name: loss_controller triggers: - on_log - rule: training_loss['loss'] < 2 and state["epoch"] >= 0.5 + rule: training_loss['loss'] < 2 and trainer_state["epoch"] >= 0.5 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml b/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml index 01495f106..bf8648e93 100644 --- a/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml +++ b/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_wrong_os_rule diff --git a/tests/data/trainercontroller/on-save.yaml b/tests/data/trainercontroller/on-save.yaml index a6fffb116..225cba1cc 100644 --- a/tests/data/trainercontroller/on-save.yaml +++ b/tests/data/trainercontroller/on-save.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: state + - name: trainer_state class: TrainingState controllers: - name: stop_on_training_loss_on_save triggers: - on_save - rule: state["epoch"] >= 0.5 + rule: trainer_state["epoch"] >= 0.5 operations: - hfcontrols.should_training_stop diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index 0200d760f..f6c6d9061 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -258,9 +258,9 @@ def _take_control_actions(self, event_name: str, **kwargs): log_level=control_action.config[ CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL ], - tc_metrics=self.metrics, event_name=event_name, control_name=control_action.name, + **self.metrics, **kwargs, ) diff --git a/tuning/trainercontroller/operations/logcontrol.py b/tuning/trainercontroller/operations/logcontrol.py index 6d05e9054..385de3b4d 100644 --- a/tuning/trainercontroller/operations/logcontrol.py +++ b/tuning/trainercontroller/operations/logcontrol.py @@ -31,7 +31,6 @@ def __init__(self, log_format: str, log_level: str, **kwargs): def should_log( self, - tc_metrics: dict, event_name: str = None, control_name: str = None, args: TrainingArguments = None, @@ -47,7 +46,6 @@ def should_log( log_msg = self.log_format.format( event_name=event_name, control_name=control_name, - tc_metrics=tc_metrics, args=args, **kwargs, ) From 2710e9f1bdbdfd5dd412f32589208fd485c0ba41 Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Tue, 6 Aug 2024 05:36:58 -0400 Subject: [PATCH 11/13] fix: Example log controller yaml Signed-off-by: Padmanabha V Seshadri --- examples/trainercontroller_configs/log_controller.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/trainercontroller_configs/log_controller.yaml b/examples/trainercontroller_configs/log_controller.yaml index 8bda95355..f20ad17ce 100644 --- a/examples/trainercontroller_configs/log_controller.yaml +++ b/examples/trainercontroller_configs/log_controller.yaml @@ -1,11 +1,11 @@ controller_metrics: - - name: state + - name: trainer_state class: TrainingState operations: - name: logcontrolstep class: LogControl arguments: - log_format: 'This is a test log format [{event_name}] => {tc_metrics[state]}' + log_format: 'This is a test log format [{event_name}] => {trainer_state}' log_level: WARNING controllers: - name: log-controller-step From 76a17b3cda37f50d7da0413097a1ce3dc86e15b8 Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Thu, 8 Aug 2024 05:45:05 -0400 Subject: [PATCH 12/13] fix: Logging control yaml and kwargs corrected Signed-off-by: Padmanabha V Seshadri --- examples/trainercontroller_configs/log_controller.yaml | 2 +- tuning/trainercontroller/operations/operation.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/trainercontroller_configs/log_controller.yaml b/examples/trainercontroller_configs/log_controller.yaml index f20ad17ce..0becdc7e2 100644 --- a/examples/trainercontroller_configs/log_controller.yaml +++ b/examples/trainercontroller_configs/log_controller.yaml @@ -6,7 +6,7 @@ operations: class: LogControl arguments: log_format: 'This is a test log format [{event_name}] => {trainer_state}' - log_level: WARNING + log_level: warning controllers: - name: log-controller-step triggers: diff --git a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py index b8c87a738..b7019a0eb 100644 --- a/tuning/trainercontroller/operations/operation.py +++ b/tuning/trainercontroller/operations/operation.py @@ -73,6 +73,8 @@ def act( control_name, event_name, ) + kwargs['event_name'] = event_name + kwargs['control_name'] = control_name self.valid_actions[action](**kwargs) def get_actions(self) -> list[str]: From 4656bee68f6f540dcdfd85783103cdbf8ad4030b Mon Sep 17 00:00:00 2001 From: Padmanabha V Seshadri Date: Thu, 8 Aug 2024 05:57:36 -0400 Subject: [PATCH 13/13] fix: Format issues Signed-off-by: Padmanabha V Seshadri --- tuning/trainercontroller/operations/operation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py index b7019a0eb..70805a015 100644 --- a/tuning/trainercontroller/operations/operation.py +++ b/tuning/trainercontroller/operations/operation.py @@ -73,8 +73,8 @@ def act( control_name, event_name, ) - kwargs['event_name'] = event_name - kwargs['control_name'] = control_name + kwargs["event_name"] = event_name + kwargs["control_name"] = control_name self.valid_actions[action](**kwargs) def get_actions(self) -> list[str]: