Skip to content

Commit

Permalink
add AutoUnitMixin (pytorch#571)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#571

Add AutoUnitMixin which handles prefetching and initialization of units and use in AutoUnit and AutoPredictUnit

Reviewed By: JKSenthil

Differential Revision: D49852040

fbshipit-source-id: 4e2b433d0e091c243189ff9528ea1ce4e95cd72c
  • Loading branch information
galrotem authored and facebook-github-bot committed Oct 26, 2023
1 parent a748cb1 commit 2794b9a
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 228 deletions.
8 changes: 4 additions & 4 deletions tests/framework/test_auto_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,7 @@ def test_is_last_batch(self) -> None:

my_unit = DummyAutoUnit(module=my_module)
train(my_unit, dataloader, max_epochs=1, max_steps_per_epoch=4)
self.assertFalse(my_unit._is_last_train_batch)
self.assertFalse(my_unit._is_last_batch)

def test_auto_unit_timing_train(self) -> None:
"""
Expand Down Expand Up @@ -844,12 +844,12 @@ def test_get_next_batch_with_single_phase(self) -> None:
batch = auto_unit._get_next_batch(state, second_data_iter)
self.assertEqual(batch, 3)
self._assert_next_batch_dicts(auto_unit, train_prefetched=True)
self.assertTrue(auto_unit._is_last_train_batch)
self.assertTrue(auto_unit._is_last_batch)

with move_data_to_device_mock, self.assertRaises(StopIteration):
auto_unit._get_next_batch(state, second_data_iter)
self._assert_next_batch_dicts(auto_unit)
self.assertFalse(auto_unit._is_last_train_batch)
self.assertFalse(auto_unit._is_last_batch)

def test_get_next_batch_with_multiple_phases(self) -> None:
auto_unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
Expand Down Expand Up @@ -990,7 +990,7 @@ def compute_loss(
) -> Tuple[torch.Tensor, torch.Tensor]:
tc = unittest.TestCase()
tc.assertEqual(
self._is_last_train_batch,
self._is_last_batch,
self.train_progress.num_steps_completed_in_epoch + 1
== self.expected_steps_per_epoch,
)
Expand Down
Loading

0 comments on commit 2794b9a

Please sign in to comment.