diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py index 26416f1edf420..6ac36de4558ca 100644 --- a/tests/plugins/test_double_plugin.py +++ b/tests/plugins/test_double_plugin.py @@ -125,5 +125,6 @@ def test_double_precision(tmpdir, boring_model): trainer.fit(model) trainer.test(model) trainer.predict(model) + torch.set_grad_enabled(True) # trainer.predict kills gradient assert model.training_step == original_training_step