From afa2a9cef0e18b1d2c7ced4dabd6f78a13c5e4d1 Mon Sep 17 00:00:00 2001
From: Gal Rotem
Date: Tue, 22 Aug 2023 19:30:22 -0700
Subject: [PATCH] avoid compiling compute_loss (#511)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/511

# Context
As per the discussion in D48361308, it looks like we can remove this, since it is only used for torch dynamo, which is not officially supported.

# This diff
Separating this change out from the one above, since we also need to clean up some existing UTs that no longer make sense. The above change will be rebased on top of this one.

Differential Revision: D48581289

fbshipit-source-id: 272c8d9e8a534b6f148b06bfa1797ea81b3bcd8a
---
 tests/framework/test_auto_unit.py | 97 +------------------------------
 torchtnt/framework/auto_unit.py   |  5 --
 2 files changed, 2 insertions(+), 100 deletions(-)

diff --git a/tests/framework/test_auto_unit.py b/tests/framework/test_auto_unit.py
index 59d33929e1..62d402ff4d 100644
--- a/tests/framework/test_auto_unit.py
+++ b/tests/framework/test_auto_unit.py
@@ -266,6 +266,7 @@ def test_compile_state_dict(self) -> None:
         """
         device = init_from_env()
         my_module = torch.nn.Linear(2, 2, device=device)
+        self.assertIsNone(my_module._compiled_call_impl)
         my_module_state_dict = my_module.state_dict()
         auto_unit = DummyAutoUnit(
             module=my_module,
@@ -283,94 +284,7 @@ def test_compile_state_dict(self) -> None:
             self.assertTrue(
                 torch.allclose(my_module_state_dict[k], compiled_state_dict[k])
             )
-
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    def test_compile_eager(self) -> None:
-        """
-        e2e torch compile test
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 16
-        batch_size = 2
-        max_epochs = 1
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="eager"),
-        )
-
-        train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        self.assertFalse(auto_unit._compile_used)
-        train(auto_unit, train_dl, max_epochs=max_epochs)
-        self.assertTrue(auto_unit._compile_used)
-
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    @unittest.skipUnless(
-        condition=cuda_available, reason="This test needs a GPU host to run."
-    )
-    def test_compile_train(self) -> None:
-        """
-        e2e torch compile on train
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 16
-        batch_size = 2
-        max_epochs = 1
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="inductor"),
-        )
-
-        train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-
-        self.assertFalse(auto_unit._compile_used)
-        train(auto_unit, train_dl, max_epochs=max_epochs)
-        self.assertTrue(auto_unit._compile_used)
-
-    @unittest.skipUnless(
-        condition=COMPILE_AVAIL,
-        reason="This test needs PyTorch 1.13 or greater to run.",
-    )
-    @unittest.skipUnless(
-        condition=cuda_available, reason="This test needs a GPU host to run."
-    )
-    def test_compile_eval(self) -> None:
-        """
-        e2e torch compile on eval
-        """
-
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 16
-        batch_size = 2
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            torch_compile_params=TorchCompileParams(backend="inductor"),
-        )
-
-        input_dim = 2
-        dataset_len = 8
-        batch_size = 2
-
-        eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        self.assertFalse(auto_unit._compile_used)
-        evaluate(auto_unit, eval_dl)
-        self.assertTrue(auto_unit._compile_used)
+        self.assertIsNotNone(auto_unit.module._compiled_call_impl)
 
     @unittest.skipUnless(
         condition=COMPILE_AVAIL,
@@ -983,16 +897,9 @@ def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None:
 
 # pyre-fixme[11]: Annotation `Batch` is not defined as a type.
 class DummyAutoUnit(AutoUnit[Batch]):
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._compile_used = False
 
     # pyre-fixme[3]: Return annotation cannot contain `Any`.
     def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
-        if COMPILE_AVAIL:
-            self._compile_used = torch._dynamo.is_compiling()
         inputs, targets = data
         outputs = self.module(inputs)
         loss = torch.nn.functional.cross_entropy(outputs, targets)
diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index d30803e62a..b9f79b65d4 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -470,11 +470,6 @@ def __init__(
             )
 
         if torch_compile_params:
-            # pyre-ignore
-            self.compute_loss = torch.compile(
-                self.compute_loss,
-                **asdict(torch_compile_params),
-            )
             try:
                 # use in-place compile to avoid altering the state_dict keys
                 module.compile(**asdict(torch_compile_params))
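
Note (not part of the diff above): a minimal sketch of why the remaining in-place module.compile(...) call is preferred over wrapping with torch.compile(...). In current PyTorch releases (2.1+, where nn.Module.compile() exists), wrapping returns an OptimizedModule whose state_dict keys gain an "_orig_mod." prefix, while compiling in place keeps the original keys and just populates the private _compiled_call_impl attribute that the updated test asserts on. The attribute name is an implementation detail and may change across versions.

import torch

module = torch.nn.Linear(2, 2)

# Wrapping returns a new OptimizedModule; its state_dict keys are prefixed
# with "_orig_mod.", which is what the "avoid altering the state_dict keys"
# comment in the patch refers to.
wrapped = torch.compile(module, backend="eager")
print(list(wrapped.state_dict().keys()))  # ['_orig_mod.weight', '_orig_mod.bias']

# Compiling in place keeps the original keys and sets the private
# _compiled_call_impl attribute checked by test_compile_state_dict.
module.compile(backend="eager")
print(list(module.state_dict().keys()))  # ['weight', 'bias']
print(module._compiled_call_impl is not None)  # True

Because the keys are unchanged, checkpoints saved with and without torch_compile_params stay interchangeable, which is why test_compile_state_dict compares the two state_dicts entry by entry.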