From e052e8843b5103beaa8620bd571fc161cfe51f3d Mon Sep 17 00:00:00 2001 From: Danielle Pintz Date: Wed, 9 Aug 2023 10:45:54 -0700 Subject: [PATCH] Set default value for max_steps in _is_epoch_done util Summary: In evaluate and predict the user does not pass in a max_steps argument so it may seem strange that `state.max_steps` is used here: https://github.com/pytorch/tnt/blob/master/torchtnt/framework/evaluate.py#L143 Removing this for better readability Differential Revision: D48198032 fbshipit-source-id: 1533ec73d5b3a84a3c7eff6c7daeb9b722fcfbdf --- tests/framework/test_utils.py | 6 +++--- torchtnt/framework/evaluate.py | 1 - torchtnt/framework/predict.py | 1 - torchtnt/framework/utils.py | 4 +++- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/framework/test_utils.py b/tests/framework/test_utils.py index 028e2dc617..2569d15ec0 100644 --- a/tests/framework/test_utils.py +++ b/tests/framework/test_utils.py @@ -160,14 +160,14 @@ def test_is_epoch_done(self) -> None: ) self.assertTrue(_is_epoch_done(p, max_steps_per_epoch=5, max_steps=200)) - self.assertTrue(_is_epoch_done(p, max_steps_per_epoch=5, max_steps=None)) + self.assertTrue(_is_epoch_done(p, max_steps_per_epoch=5)) self.assertTrue(_is_epoch_done(p, max_steps_per_epoch=100, max_steps=100)) self.assertTrue(_is_epoch_done(p, max_steps_per_epoch=None, max_steps=100)) self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=6, max_steps=200)) self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=None, max_steps=200)) - self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=6, max_steps=None)) - self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=None, max_steps=None)) + self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=6)) + self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=None)) @patch("torchtnt.framework.utils.record_function") def test_get_timing_context(self, mock_record_function) -> None: diff --git a/torchtnt/framework/evaluate.py b/torchtnt/framework/evaluate.py index 97d9160a56..789b45d1c5 100644 --- a/torchtnt/framework/evaluate.py +++ b/torchtnt/framework/evaluate.py @@ -140,7 +140,6 @@ def _evaluate_impl( or _is_epoch_done( eval_unit.eval_progress, eval_state.max_steps_per_epoch, - eval_state.max_steps, ) ): try: diff --git a/torchtnt/framework/predict.py b/torchtnt/framework/predict.py index 8ee0096a5f..24e45e965c 100644 --- a/torchtnt/framework/predict.py +++ b/torchtnt/framework/predict.py @@ -146,7 +146,6 @@ def _predict_impl( or _is_epoch_done( predict_unit.predict_progress, predict_state.max_steps_per_epoch, - predict_state.max_steps, ) ): try: diff --git a/torchtnt/framework/utils.py b/torchtnt/framework/utils.py index 6de4a7f7d5..041912fe39 100644 --- a/torchtnt/framework/utils.py +++ b/torchtnt/framework/utils.py @@ -34,7 +34,9 @@ def _is_done( def _is_epoch_done( - progress: Progress, max_steps_per_epoch: Optional[int], max_steps: Optional[int] + progress: Progress, + max_steps_per_epoch: Optional[int], + max_steps: Optional[int] = None, ) -> bool: return (max_steps is not None and progress.num_steps_completed >= max_steps) or ( max_steps_per_epoch is not None