From ef27459f5bba703f3d79209fcf79c78aedb78921 Mon Sep 17 00:00:00 2001
From: Danielle Pintz
Date: Tue, 20 Jun 2023 13:05:27 -0700
Subject: [PATCH] Separate grad scaler test out from test_app_state_mixin

Summary: Currently `test_app_state_mixin` is failing on OSS CI:

```
self =

    def test_app_state_mixin(self) -> None:
        """
        Test that app_state, tracked_optimizers, tracked_lr_schedulers are set as expected with AutoUnit
        """
        my_module = torch.nn.Linear(2, 2)

        auto_unit = DummyAutoUnit(
            module=my_module,
            precision="fp16",
        )

        self.assertEqual(auto_unit.tracked_modules()["module"], my_module)
        self.assertTrue(
            isinstance(
                auto_unit.tracked_misc_statefuls()["grad_scaler"],
                torch.cuda.amp.GradScaler,
            )
        )
        for key in ("module", "optimizer", "lr_scheduler", "grad_scaler"):
>           self.assertTrue(key in auto_unit.app_state())
E           AssertionError: False is not true

tests/framework/test_auto_unit.py:69: AssertionError
```

https://github.com/pytorch/tnt/actions/runs/5321328919/jobs/9636295932

Differential Revision: D46870935

fbshipit-source-id: f7908384d6f6924f6d9b98984330bd01a0e4b2bb
---
 tests/framework/test_auto_unit.py | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/tests/framework/test_auto_unit.py b/tests/framework/test_auto_unit.py
index ec904744ff..a973171979 100644
--- a/tests/framework/test_auto_unit.py
+++ b/tests/framework/test_auto_unit.py
@@ -56,18 +56,33 @@ def test_app_state_mixin(self) -> None:
 
         auto_unit = DummyAutoUnit(
             module=my_module,
-            precision="fp16",
         )
 
         self.assertEqual(auto_unit.tracked_modules()["module"], my_module)
+        for key in ("module", "optimizer", "lr_scheduler"):
+            self.assertIn(key, auto_unit.app_state())
+
+    @unittest.skipUnless(
+        condition=cuda_available, reason="This test needs a GPU host to run."
+    )
+    def test_app_state_mixin_grad_scaler(self) -> None:
+        """
+        Test that grad_scaler is added to the AutoUnit tracked_misc_statefuls when using fp16 precision
+        """
+        my_module = torch.nn.Linear(2, 2)
+
+        auto_unit = DummyAutoUnit(
+            module=my_module,
+            precision="fp16",
+        )
+
         self.assertTrue(
             isinstance(
                 auto_unit.tracked_misc_statefuls()["grad_scaler"],
                 torch.cuda.amp.GradScaler,
             )
         )
-        for key in ("module", "optimizer", "lr_scheduler", "grad_scaler"):
-            self.assertIn(key, auto_unit.app_state())
+        self.assertIn("grad_scaler", auto_unit.app_state())
 
     def test_lr_scheduler_step(self) -> None:
         """
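
The `cuda_available` gate on the new test lines up with how `torch.cuda.amp.GradScaler` behaves on CPU-only hosts such as the OSS CI runners: when constructed without CUDA available, the scaler warns and comes up disabled, so grad scaler state cannot be meaningfully exercised there. A minimal sketch of that behavior using only the public torch API (the `AutoUnit` registration logic itself is not shown in this patch, so the exact root cause of the CI failure is inferred, not confirmed):

```python
import torch

# On a CPU-only host, GradScaler warns at construction and comes up
# disabled; on a CUDA host it is enabled by default. This is why the
# grad scaler assertions above are gated behind `cuda_available`.
scaler = torch.cuda.amp.GradScaler()
print(scaler.is_enabled())  # False without CUDA, True with it
```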