From 2bcc5caa099dd129b2fec3dc28709e4f1ca463ce Mon Sep 17 00:00:00 2001 From: Gal Rotem Date: Fri, 18 Aug 2023 21:29:44 -0700 Subject: [PATCH] fix indent bug in test_auto_unit (#507) Summary: bypass-github-export-checks Pull Request resolved: https://github.com/pytorch/tnt/pull/507 # Context I'm working to fix torchtnt pyre targets (coming up in next change) and discovered this bug # This diff Fix redundant indentation Reviewed By: JKSenthil Differential Revision: D48478812 fbshipit-source-id: 1dc2a28a3930331141ee9d96d9f479334eb4256f --- tests/framework/test_auto_unit.py | 38 +++++++++++++++---------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/tests/framework/test_auto_unit.py b/tests/framework/test_auto_unit.py index 3498746268..4ac2e93ba4 100644 --- a/tests/framework/test_auto_unit.py +++ b/tests/framework/test_auto_unit.py @@ -1095,25 +1095,25 @@ def on_train_step_end( # train_step should not be in the timer's recorded_durations because it overlaps with other timings in the AutoUnit's train_step tc.assertNotIn("TimingAutoUnit.train_step", recorded_timer_keys) - def on_eval_step_end( - self, state: State, data: Batch, step: int, loss: torch.Tensor, outputs: Any - ) -> None: - if self.eval_progress.num_steps_completed_in_epoch == 1: - tc = unittest.TestCase() - recorded_timer_keys = state.timer.recorded_durations.keys() - for k in ( - "TimingAutoUnit.on_eval_start", - "TimingAutoUnit.on_eval_epoch_start", - "evaluate.iter(dataloader)", - "evaluate.next(data_iter)", - "TimingAutoUnit.move_data_to_device", - "TimingAutoUnit.compute_loss", - "TimingAutoUnit.on_eval_step_end", - ): - tc.assertIn(k, recorded_timer_keys) - - # eval_step should not be in the timer's recorded_durations because it overlaps with other timings in the AutoUnit's eval_step - tc.assertNotIn("TimingAutoUnit.eval_step", recorded_timer_keys) + def on_eval_step_end( + self, state: State, data: Batch, step: int, loss: torch.Tensor, outputs: Any + ) -> None: + if self.eval_progress.num_steps_completed_in_epoch == 1: + tc = unittest.TestCase() + recorded_timer_keys = state.timer.recorded_durations.keys() + for k in ( + "TimingAutoUnit.on_eval_start", + "TimingAutoUnit.on_eval_epoch_start", + "evaluate.iter(dataloader)", + "evaluate.next(data_iter)", + "TimingAutoUnit.move_data_to_device", + "TimingAutoUnit.compute_loss", + "TimingAutoUnit.on_eval_step_end", + ): + tc.assertIn(k, recorded_timer_keys) + + # eval_step should not be in the timer's recorded_durations because it overlaps with other timings in the AutoUnit's eval_step + tc.assertNotIn("TimingAutoUnit.eval_step", recorded_timer_keys) class TimingAutoPredictUnit(AutoPredictUnit[Batch]):