diff --git a/tests/framework/test_auto_unit.py b/tests/framework/test_auto_unit.py
index 3498746268..4ac2e93ba4 100644
--- a/tests/framework/test_auto_unit.py
+++ b/tests/framework/test_auto_unit.py
@@ -1095,25 +1095,25 @@ def on_train_step_end(
             # train_step should not be in the timer's recorded_durations because it overlaps with other timings in the AutoUnit's train_step
             tc.assertNotIn("TimingAutoUnit.train_step", recorded_timer_keys)
 
-    def on_eval_step_end(
-        self, state: State, data: Batch, step: int, loss: torch.Tensor, outputs: Any
-    ) -> None:
-        if self.eval_progress.num_steps_completed_in_epoch == 1:
-            tc = unittest.TestCase()
-            recorded_timer_keys = state.timer.recorded_durations.keys()
-            for k in (
-                "TimingAutoUnit.on_eval_start",
-                "TimingAutoUnit.on_eval_epoch_start",
-                "evaluate.iter(dataloader)",
-                "evaluate.next(data_iter)",
-                "TimingAutoUnit.move_data_to_device",
-                "TimingAutoUnit.compute_loss",
-                "TimingAutoUnit.on_eval_step_end",
-            ):
-                tc.assertIn(k, recorded_timer_keys)
-
-            # eval_step should not be in the timer's recorded_durations because it overlaps with other timings in the AutoUnit's eval_step
-            tc.assertNotIn("TimingAutoUnit.eval_step", recorded_timer_keys)
+    def on_eval_step_end(
+        self, state: State, data: Batch, step: int, loss: torch.Tensor, outputs: Any
+    ) -> None:
+        if self.eval_progress.num_steps_completed_in_epoch == 1:
+            tc = unittest.TestCase()
+            recorded_timer_keys = state.timer.recorded_durations.keys()
+            for k in (
+                "TimingAutoUnit.on_eval_start",
+                "TimingAutoUnit.on_eval_epoch_start",
+                "evaluate.iter(dataloader)",
+                "evaluate.next(data_iter)",
+                "TimingAutoUnit.move_data_to_device",
+                "TimingAutoUnit.compute_loss",
+                "TimingAutoUnit.on_eval_step_end",
+            ):
+                tc.assertIn(k, recorded_timer_keys)
+
+            # eval_step should not be in the timer's recorded_durations because it overlaps with other timings in the AutoUnit's eval_step
+            tc.assertNotIn("TimingAutoUnit.eval_step", recorded_timer_keys)
 
 
 class TimingAutoPredictUnit(AutoPredictUnit[Batch]):
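For reference (outside the patch): the hooks above assert against state.timer.recorded_durations, a mapping from timing label to the list of durations recorded under that label; eval_step itself is never asserted present because its span overlaps the finer-grained timings inside the AutoUnit's eval_step. Below is a minimal standalone sketch of that assertion pattern. RecordingTimer is a hypothetical stand-in written for this sketch, not torchtnt's real timer, and the labels are copied from the test for illustration.

import time
import unittest
from contextlib import contextmanager
from typing import Dict, Iterator, List


class RecordingTimer:
    """Hypothetical stand-in: records wall-clock durations per label
    in `recorded_durations`, mimicking the shape the test relies on."""

    def __init__(self) -> None:
        self.recorded_durations: Dict[str, List[float]] = {}

    @contextmanager
    def time(self, label: str) -> Iterator[None]:
        start = time.perf_counter()
        try:
            yield
        finally:
            # Append this span's duration under its label
            self.recorded_durations.setdefault(label, []).append(
                time.perf_counter() - start
            )


timer = RecordingTimer()
with timer.time("TimingAutoUnit.compute_loss"):
    pass  # the work being timed would go here

tc = unittest.TestCase()
recorded_timer_keys = timer.recorded_durations.keys()
# Timed labels must appear...
tc.assertIn("TimingAutoUnit.compute_loss", recorded_timer_keys)
# ...while eval_step, which was never timed as its own label, must not
tc.assertNotIn("TimingAutoUnit.eval_step", recorded_timer_keys)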