diff --git a/tests/framework/test_auto_unit.py b/tests/framework/test_auto_unit.py index 59d33929e1..62d402ff4d 100644 --- a/tests/framework/test_auto_unit.py +++ b/tests/framework/test_auto_unit.py @@ -266,6 +266,7 @@ def test_compile_state_dict(self) -> None: """ device = init_from_env() my_module = torch.nn.Linear(2, 2, device=device) + self.assertIsNone(my_module._compiled_call_impl) my_module_state_dict = my_module.state_dict() auto_unit = DummyAutoUnit( module=my_module, @@ -283,94 +284,7 @@ def test_compile_state_dict(self) -> None: self.assertTrue( torch.allclose(my_module_state_dict[k], compiled_state_dict[k]) ) - - @unittest.skipUnless( - condition=COMPILE_AVAIL, - reason="This test needs PyTorch 1.13 or greater to run.", - ) - def test_compile_eager(self) -> None: - """ - e2e torch compile test - """ - - my_module = torch.nn.Linear(2, 2) - - input_dim = 2 - dataset_len = 16 - batch_size = 2 - max_epochs = 1 - - auto_unit = DummyAutoUnit( - module=my_module, - torch_compile_params=TorchCompileParams(backend="eager"), - ) - - train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size) - self.assertFalse(auto_unit._compile_used) - train(auto_unit, train_dl, max_epochs=max_epochs) - self.assertTrue(auto_unit._compile_used) - - @unittest.skipUnless( - condition=COMPILE_AVAIL, - reason="This test needs PyTorch 1.13 or greater to run.", - ) - @unittest.skipUnless( - condition=cuda_available, reason="This test needs a GPU host to run." - ) - def test_compile_train(self) -> None: - """ - e2e torch compile on train - """ - - my_module = torch.nn.Linear(2, 2) - - input_dim = 2 - dataset_len = 16 - batch_size = 2 - max_epochs = 1 - - auto_unit = DummyAutoUnit( - module=my_module, - torch_compile_params=TorchCompileParams(backend="inductor"), - ) - - train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size) - - self.assertFalse(auto_unit._compile_used) - train(auto_unit, train_dl, max_epochs=max_epochs) - self.assertTrue(auto_unit._compile_used) - - @unittest.skipUnless( - condition=COMPILE_AVAIL, - reason="This test needs PyTorch 1.13 or greater to run.", - ) - @unittest.skipUnless( - condition=cuda_available, reason="This test needs a GPU host to run." - ) - def test_compile_eval(self) -> None: - """ - e2e torch compile on eval - """ - - my_module = torch.nn.Linear(2, 2) - - input_dim = 2 - dataset_len = 16 - batch_size = 2 - - auto_unit = DummyAutoUnit( - module=my_module, - torch_compile_params=TorchCompileParams(backend="inductor"), - ) - - input_dim = 2 - dataset_len = 8 - batch_size = 2 - - eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size) - self.assertFalse(auto_unit._compile_used) - evaluate(auto_unit, eval_dl) - self.assertTrue(auto_unit._compile_used) + self.assertIsNotNone(auto_unit.module._compiled_call_impl) @unittest.skipUnless( condition=COMPILE_AVAIL, @@ -983,16 +897,9 @@ def test_predict_detect_anomaly(self, mock_detect_anomaly) -> None: # pyre-fixme[11]: Annotation `Batch` is not defined as a type. class DummyAutoUnit(AutoUnit[Batch]): - # pyre-fixme[3]: Return type must be annotated. - # pyre-fixme[2]: Parameter must be annotated. - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._compile_used = False # pyre-fixme[3]: Return annotation cannot contain `Any`. def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]: - if COMPILE_AVAIL: - self._compile_used = torch._dynamo.is_compiling() inputs, targets = data outputs = self.module(inputs) loss = torch.nn.functional.cross_entropy(outputs, targets) diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py index d30803e62a..b9f79b65d4 100644 --- a/torchtnt/framework/auto_unit.py +++ b/torchtnt/framework/auto_unit.py @@ -470,11 +470,6 @@ def __init__( ) if torch_compile_params: - # pyre-ignore - self.compute_loss = torch.compile( - self.compute_loss, - **asdict(torch_compile_params), - ) try: # use in-place compile to avoid altering the state_dict keys module.compile(**asdict(torch_compile_params))