diff --git a/examples/trainercontroller_configs/log_controller.yaml b/examples/trainercontroller_configs/log_controller.yaml new file mode 100644 index 000000000..0becdc7e2 --- /dev/null +++ b/examples/trainercontroller_configs/log_controller.yaml @@ -0,0 +1,16 @@ +controller_metrics: + - name: trainer_state + class: TrainingState +operations: + - name: logcontrolstep + class: LogControl + arguments: + log_format: 'This is a test log format [{event_name}] => {trainer_state}' + log_level: warning +controllers: + - name: log-controller-step + triggers: + - on_step_end + rule: 'True' + operations: + - logcontrolstep.should_log \ No newline at end of file diff --git a/tests/data/trainercontroller/__init__.py b/tests/data/trainercontroller/__init__.py index 35f4f13c9..18035f102 100644 --- a/tests/data/trainercontroller/__init__.py +++ b/tests/data/trainercontroller/__init__.py @@ -78,3 +78,4 @@ _DATA_DIR, "thresholded-training-loss.yaml" ) TRAINER_CONFIG_TEST_ON_SAVE_YAML = os.path.join(_DATA_DIR, "on-save.yaml") +TRAINER_CONFIG_LOG_CONTROLLER_YAML = os.path.join(_DATA_DIR, "log_controller.yaml") diff --git a/tests/data/trainercontroller/log_controller.yaml b/tests/data/trainercontroller/log_controller.yaml new file mode 100644 index 000000000..0becdc7e2 --- /dev/null +++ b/tests/data/trainercontroller/log_controller.yaml @@ -0,0 +1,16 @@ +controller_metrics: + - name: trainer_state + class: TrainingState +operations: + - name: logcontrolstep + class: LogControl + arguments: + log_format: 'This is a test log format [{event_name}] => {trainer_state}' + log_level: warning +controllers: + - name: log-controller-step + triggers: + - on_step_end + rule: 'True' + operations: + - logcontrolstep.should_log \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml index 9dc764c4e..cb9bcf957 100644 --- 
def test_log_controller(caplog):
    """Checks that the LogControl operation emits its configured message.

    The controller config used here
    (`tests/data/trainercontroller/log_controller.yaml`, mirrored in
    `examples/trainercontroller_configs/log_controller.yaml`) triggers the
    `logcontrolstep.should_log` action on `on_step_end` with a rule that is
    always true.
    """
    fixtures = _setup_data()
    callback = tc.TrainerControllerCallback(td.TRAINER_CONFIG_LOG_CONTROLLER_YAML)
    trainer_control = TrainerControl(should_log=False)
    hook_kwargs = {
        "args": fixtures.args,
        "state": fixtures.states[2],
        "control": trainer_control,
    }
    # on_init_end performs registration of handlers to events; it has to run
    # before the event under test is fired.
    callback.on_init_end(**hook_kwargs)
    callback.on_step_end(**hook_kwargs)
    assert "This is a test log format" in caplog.text
class LogControl(Operation):
    """Operation that can be used to log useful information on specific events.

    The message template and the severity are fixed at construction time; the
    template is rendered every time the `should_log` action is invoked.
    """

    def __init__(self, log_format: str, log_level: str, **kwargs):
        """Stores the message template and resolves the log level name.

        Args:
            log_format: `str.format` template. Its placeholders are filled from
                the keyword arguments handed to `should_log` (`event_name`,
                `control_name`, `args` and any further keyword arguments).
            log_level: Name of a level listed by
                `transformers.utils.logging.get_log_levels_dict()`. An exact
                match is tried first, then a case-insensitive one with
                surrounding whitespace ignored.
            kwargs: Remaining (key, value)-pairs, forwarded to `Operation`.

        Raises:
            ValueError: If `log_level` does not name a known log level.
        """
        log_levels = logging.get_log_levels_dict()
        level_key = log_level
        if level_key not in log_levels:
            # Config files commonly spell levels as `WARNING` / `Warning`;
            # accept those instead of rejecting on case alone.
            level_key = str(log_level).strip().lower()
        if level_key not in log_levels:
            raise ValueError(
                f"Specified log_level [{log_level}] is invalid for LogControl"
            )
        self.log_level = log_levels[level_key]
        self.log_format = log_format
        # The base initializer runs last, as in the original ordering, so the
        # attributes above already exist when it inspects this instance.
        super().__init__(**kwargs)

    def should_log(
        self,
        event_name: str = None,
        control_name: str = None,
        args: TrainingArguments = None,
        **kwargs,
    ):
        """Renders the configured template and logs it at the configured level.

        If the template refers to a value that is not available for this
        invocation, a warning is logged and the message is skipped, so that a
        logging-only operation never interrupts the caller.

        Args:
            event_name: Name of the event that triggered this action.
            control_name: Name of the controller that triggered this action.
            args: Training arguments, exposed to the template as `args`.
            kwargs: Further (key, value)-pairs, each exposed to the template
                under its key.
        """
        try:
            log_msg = self.log_format.format(
                event_name=event_name,
                control_name=control_name,
                args=args,
                **kwargs,
            )
        except (KeyError, IndexError, AttributeError) as err:
            # A placeholder could not be resolved from the supplied values.
            logger.warning(
                "LogControl could not render log_format [%s] on event [%s]: %r",
                self.log_format,
                event_name,
                err,
            )
            return
        logger.log(self.log_level, log_msg)
@@ -38,7 +48,12 @@ def validate(self, action: str) -> bool: return action in self.valid_actions def act( - self, action: str, event_name: str, control_name: str, log_level: int, **kwargs + self, + action: str, + log_level: int, + event_name: str = None, + control_name: str = None, + **kwargs, ): """Validates the action and invokes it. @@ -58,6 +73,8 @@ def act( control_name, event_name, ) + kwargs["event_name"] = event_name + kwargs["control_name"] = control_name self.valid_actions[action](**kwargs) def get_actions(self) -> list[str]: