diff --git a/examples/trainercontroller_configs/loss.yaml b/examples/trainercontroller_configs/loss.yaml index d7d0baa2b..c4322a6b4 100644 --- a/examples/trainercontroller_configs/loss.yaml +++ b/examples/trainercontroller_configs/loss.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller triggers: - on_log - rule: loss < 1.0 + rule: training_loss["loss"] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_custom_operation.yaml b/tests/data/trainercontroller/loss_custom_operation.yaml index 603459234..3ec952a85 100644 --- a/tests/data/trainercontroller/loss_custom_operation.yaml +++ b/tests/data/trainercontroller/loss_custom_operation.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss operations: - name: custom_operation @@ -8,6 +8,6 @@ controllers: - name: loss_controller_custom_operation triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - custom_operation.should_perform_action_xyz \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml b/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml index 3dac47cb2..e0d3a71d3 100644 --- a/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml +++ b/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss operations: - name: custom_operation @@ -8,6 +8,6 @@ controllers: - name: loss_controller_custom_operation_invalid_action triggers: - on_log - rule: loss < 1.0 + rule: training_loss["loss"] < 1.0 operations: - custom_operation.should_ \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_metric.yaml b/tests/data/trainercontroller/loss_invalid_metric.yaml index 4d94878aa..8491175b0 100644 --- a/tests/data/trainercontroller/loss_invalid_metric.yaml +++ b/tests/data/trainercontroller/loss_invalid_metric.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: MissingMetricClass controllers: - name: loss_controller_invalid_metric triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_operation.yaml b/tests/data/trainercontroller/loss_invalid_operation.yaml index f904e27d9..769c9441a 100644 --- a/tests/data/trainercontroller/loss_invalid_operation.yaml +++ b/tests/data/trainercontroller/loss_invalid_operation.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_invalid_operation triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - missingop.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_operation_action.yaml b/tests/data/trainercontroller/loss_invalid_operation_action.yaml index 3015516ef..7d8a17ad0 100644 --- a/tests/data/trainercontroller/loss_invalid_operation_action.yaml +++ b/tests/data/trainercontroller/loss_invalid_operation_action.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_invalid_operation_action triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.missingaction \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_trigger.yaml b/tests/data/trainercontroller/loss_invalid_trigger.yaml index 382ad7783..38abe7ed9 100644 --- a/tests/data/trainercontroller/loss_invalid_trigger.yaml +++ b/tests/data/trainercontroller/loss_invalid_trigger.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_invalid_trigger triggers: - log_it_all_incorrect_trigger_name - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_on_threshold.yaml b/tests/data/trainercontroller/loss_on_threshold.yaml index d7d0baa2b..24891e8ed 100644 --- a/tests/data/trainercontroller/loss_on_threshold.yaml +++ b/tests/data/trainercontroller/loss_on_threshold.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller triggers: - on_log - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml index 45e2a3eea..9dc764c4e 100644 --- a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml +++ b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml @@ -1,12 +1,12 @@ controller_metrics: - name: state class: TrainingState - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller triggers: - on_log - rule: loss < 2 and state["epoch"] >= 0.5 + rule: training_loss['loss'] < 2 and state["epoch"] >= 0.5 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_unavailable_metric.yaml b/tests/data/trainercontroller/loss_unavailable_metric.yaml index 055b93cf3..564184290 100644 --- a/tests/data/trainercontroller/loss_unavailable_metric.yaml +++ b/tests/data/trainercontroller/loss_unavailable_metric.yaml @@ -1,10 +1,10 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_unavailable_metric triggers: - on_step_end - rule: loss < 1.0 + rule: training_loss['loss'] < 1.0 operations: - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml b/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml index 6d5c65328..e2cbb26de 100644 --- a/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml +++ b/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_wrong_input_rule diff --git a/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml b/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml index badcf940a..5ee4bc224 100644 --- a/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml +++ b/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml @@ -1,5 +1,5 @@ controller_metrics: - - name: loss + - name: training_loss class: Loss controllers: - name: loss_controller_wrong_os_rule diff --git a/tests/trainercontroller/custom_operation.py b/tests/trainercontroller/custom_operation.py index 2c402fa96..522200b49 100644 --- a/tests/trainercontroller/custom_operation.py +++ b/tests/trainercontroller/custom_operation.py @@ -26,13 +26,6 @@ class CustomOperation(Operation): """Implements a custom operation for testing""" - def __init__(self, **_): - """Initializes the custom operation class. - Args: - kwargs: List of arguments (key, value)-pairs - """ - super().__init__() - def should_perform_action_xyz(self, control: TrainerControl, **_): """This method performs a set training stop flag action. diff --git a/tests/trainercontroller/custom_operation_invalid_action.py b/tests/trainercontroller/custom_operation_invalid_action.py index 5c04199d3..6871a64fd 100644 --- a/tests/trainercontroller/custom_operation_invalid_action.py +++ b/tests/trainercontroller/custom_operation_invalid_action.py @@ -26,13 +26,6 @@ class CustomOperationInvalidAction(Operation): """Implements a custom operation for testing""" - def __init__(self, **_): - """Initializes the custom operation class. - Args: - kwargs: List of arguments (key, value)-pairs - """ - super().__init__() - def should_(self, control: TrainerControl, **_): """This method defines an action within an invalid name. diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 0e360ad4f..30095b986 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -160,7 +160,7 @@ def train( trainer_controller_args.trainer_controller_config_file is not None ): tc_callback = TrainerControllerCallback( - trainer_controller_args.trainer_controller_config_file + trainer_controller_args.trainer_controller_config_file, ) trainer_callbacks.append(tc_callback) diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index 0ca833051..ebb661b3f 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -258,6 +258,7 @@ def _take_control_actions(self, event_name: str, **kwargs): operation_action.instance.act( action=operation_action.action, event_name=event_name, + tc_metrics=self.metrics, **kwargs, ) diff --git a/tuning/trainercontroller/controllermetrics/loss.py b/tuning/trainercontroller/controllermetrics/loss.py index 2fd450148..543d6395c 100644 --- a/tuning/trainercontroller/controllermetrics/loss.py +++ b/tuning/trainercontroller/controllermetrics/loss.py @@ -61,4 +61,4 @@ def compute(self, state: TrainerState = None, **kwargs) -> Any: log = state.log_history[i] if "loss" not in log: continue - return float(log["loss"]) + return log diff --git a/tuning/trainercontroller/operations/hfcontrols.py b/tuning/trainercontroller/operations/hfcontrols.py index 2bba9a1d2..c1f7589e6 100644 --- a/tuning/trainercontroller/operations/hfcontrols.py +++ b/tuning/trainercontroller/operations/hfcontrols.py @@ -29,7 +29,7 @@ def __init__(self, **kwargs): for control_field in fields(TrainerControl): if re.search(r"^should_.+", control_field.name) is not None: setattr(self, control_field.name, self.control_action) - super().__init__() + super().__init__(**kwargs) def control_action(self, control: TrainerControl, **kwargs): """This method peeks into the stack-frame of the caller to get the action the triggered diff --git a/tuning/trainercontroller/operations/operation.py b/tuning/trainercontroller/operations/operation.py index 916420e81..baa220c17 100644 --- a/tuning/trainercontroller/operations/operation.py +++ b/tuning/trainercontroller/operations/operation.py @@ -7,12 +7,14 @@ class Operation(metaclass=abc.ABCMeta): """Base class for operations""" - def __init__(self): + def __init__(self, name: str, **kwargs): """Initializes the HuggingFace controls. In this init, we follow the convention that every action should preceed with prefix `should_`. If so, it is treated as a valid action. """ self.valid_actions = {} + self.name = name + self.kwargs = kwargs for action_name, action_method in inspect.getmembers( self, predicate=inspect.ismethod ):