Skip to content

Commit

Permalink
Set default value for max_steps in _is_epoch_done util
Browse files Browse the repository at this point in the history
Summary:
In evaluate and predict the user does not pass in a max_steps argument so it may seem strange that `state.max_steps` is used here: https://github.com/pytorch/tnt/blob/master/torchtnt/framework/evaluate.py#L143

Removing this for better readability

Differential Revision: D48198032

fbshipit-source-id: 1533ec73d5b3a84a3c7eff6c7daeb9b722fcfbdf
  • Loading branch information
daniellepintz authored and facebook-github-bot committed Aug 9, 2023
1 parent 5150591 commit e052e88
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions tests/framework/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,14 @@ def test_is_epoch_done(self) -> None:
)

self.assertTrue(_is_epoch_done(p, max_steps_per_epoch=5, max_steps=200))
self.assertTrue(_is_epoch_done(p, max_steps_per_epoch=5, max_steps=None))
self.assertTrue(_is_epoch_done(p, max_steps_per_epoch=5))
self.assertTrue(_is_epoch_done(p, max_steps_per_epoch=100, max_steps=100))
self.assertTrue(_is_epoch_done(p, max_steps_per_epoch=None, max_steps=100))

self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=6, max_steps=200))
self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=None, max_steps=200))
self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=6, max_steps=None))
self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=None, max_steps=None))
self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=6))
self.assertFalse(_is_epoch_done(p, max_steps_per_epoch=None))

@patch("torchtnt.framework.utils.record_function")
def test_get_timing_context(self, mock_record_function) -> None:
Expand Down
1 change: 0 additions & 1 deletion torchtnt/framework/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def _evaluate_impl(
or _is_epoch_done(
eval_unit.eval_progress,
eval_state.max_steps_per_epoch,
eval_state.max_steps,
)
):
try:
Expand Down
1 change: 0 additions & 1 deletion torchtnt/framework/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def _predict_impl(
or _is_epoch_done(
predict_unit.predict_progress,
predict_state.max_steps_per_epoch,
predict_state.max_steps,
)
):
try:
Expand Down
4 changes: 3 additions & 1 deletion torchtnt/framework/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def _is_done(


def _is_epoch_done(
progress: Progress, max_steps_per_epoch: Optional[int], max_steps: Optional[int]
progress: Progress,
max_steps_per_epoch: Optional[int],
max_steps: Optional[int] = None,
) -> bool:
return (max_steps is not None and progress.num_steps_completed >= max_steps) or (
max_steps_per_epoch is not None
Expand Down

0 comments on commit e052e88

Please sign in to comment.