diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index f6c6d9061..fad1bbf70 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -582,3 +582,45 @@ def on_save( kwargs["state"] = state kwargs["control"] = control self._actions_on_event(event_name="on_save", **kwargs) + + def on_step_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_step_begin", **kwargs) + + def on_optimizer_step( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_optimizer_step", **kwargs) + + def on_substep_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + # Training arguments, state and controls are folded into kwargs to be passed off to + # handlers + kwargs["args"] = args + kwargs["state"] = state + kwargs["control"] = control + self._actions_on_event(event_name="on_substep_end", **kwargs)