diff --git a/examples/trainercontroller_configs/epoch-level-training-loss-below-threshold-log-config.yaml b/examples/trainercontroller_configs/epoch-level-training-loss-below-threshold-log-config.yaml new file mode 100644 index 000000000..53c8a1777 --- /dev/null +++ b/examples/trainercontroller_configs/epoch-level-training-loss-below-threshold-log-config.yaml @@ -0,0 +1,14 @@ +controller_metrics: + - name: training_loss_window + class: HistoryBasedMetric + arguments: + window_size: 1 +controllers: + - name: epoch_level_stop_on_training_loss_below_threshold + triggers: + - on_step_end + rule: len(training_loss_window["training_loss"]["loss"]) == training_loss_window["window_size"] and training_loss_window["training_loss"]["loss"][0] < 2.2 and training_loss_window["training_loss"]["epoch"][0] > 2 + config: + trigger_log_level: warning + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index ebb661b3f..2c5ddfcdf 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -59,11 +59,14 @@ CONTROLLER_CONFIG_KEY = "config" CONTROLLER_PATIENCE_CONFIG_KEY = "patience" CONTROLLER_TRIGGERS_KEY = "triggers" +CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL = "trigger_log_level" CONTROLLER_OPERATIONS_KEY = OPERATIONS_KEY -# Default operations / metrics to register +# Default values DEFAULT_OPERATIONS = {"operations": [{"name": "hfcontrols", "class": "HFControls"}]} DEFAULT_METRICS = {} +DEFAULT_CONFIG = {} +DEFAULT_TRIGGER_LOG_LEVEL = "debug" # pylint: disable=too-many-instance-attributes class TrainerControllerCallback(TrainerCallback): @@ -250,15 +253,14 @@ def _take_control_actions(self, event_name: str, **kwargs): continue if rule_succeeded: for operation_action in control_action.operation_actions: - logger.info( - "Taking [%s] action in controller [%s]", - operation_action.action, - control_action.name, - ) operation_action.instance.act( action=operation_action.action, event_name=event_name, tc_metrics=self.metrics, + control_name=control_action.name, + log_level=control_action.config[ + CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL + ], **kwargs, ) @@ -303,6 +305,7 @@ def on_init_end( kwargs["state"] = state kwargs["control"] = control + log_levels = logging.get_log_levels_dict() # Check if there any metrics listed in the configuration if ( CONTROLLER_METRICS_KEY not in self.trainer_controller_config @@ -399,8 +402,24 @@ def on_init_end( ), operation_actions=[], ) + config_log_level_str = DEFAULT_TRIGGER_LOG_LEVEL if CONTROLLER_CONFIG_KEY in controller: control.config = controller[CONTROLLER_CONFIG_KEY] + config_log_level_str = control.config.get( + CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL, config_log_level_str + ) + if config_log_level_str not in log_levels: + logger.warning( + "Incorrect trigger log-level [%s] specified in the config." + " Defaulting to 'debug' level", + config_log_level_str, + ) + config_log_level_str = DEFAULT_TRIGGER_LOG_LEVEL + else: + control.config = DEFAULT_CONFIG + control.config[CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL] = log_levels[ + config_log_level_str + ] if CONTROLLER_PATIENCE_CONFIG_KEY in controller: control.patience = PatienceControl( **controller[CONTROLLER_PATIENCE_CONFIG_KEY] diff --git a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py index baa220c17..6e6d764fb 100644 --- a/tuning/trainercontroller/operations/operation.py +++ b/tuning/trainercontroller/operations/operation.py @@ -3,6 +3,11 @@ import inspect import re +# Third Party +from transformers.utils import logging + +logger = logging.get_logger(__name__) + class Operation(metaclass=abc.ABCMeta): """Base class for operations""" @@ -32,15 +37,27 @@ def validate(self, action: str) -> bool: """ return action in self.valid_actions - def act(self, action: str, **kwargs): + def act( + self, action: str, event_name: str, control_name: str, log_level: int, **kwargs + ): """Validates the action and invokes it. Args: action: str. String depicting the action. + event_name: str. Event name triggering the act. + control_name: str. Name of the controller defining the act. + log_level: int. Log level for triggering the log. kwargs: List of arguments (key, value)-pairs. """ if not self.validate(action): raise ValueError(f"Invalid operation {action}") + logger.log( + log_level, + "Taking [%s] action in controller [%s] triggered at event [%s]", + action, + control_name, + event_name, + ) self.valid_actions[action](**kwargs) def get_actions(self) -> list[str]: