diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 54517d2b55729..55825d311030e 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -14,7 +14,7 @@ from tests.helpers.boring_model import BoringModel -def test_lightning_module_base_wrapper(tmpdir): +def test_deepspeed_lightning_module(tmpdir): """ Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly. """ @@ -26,6 +26,11 @@ def test_lightning_module_base_wrapper(tmpdir): assert module.dtype == torch.half assert model.dtype == torch.half + x = torch.randn((1, 32), dtype=torch.float) + out = module(x) + + assert out.dtype == torch.half + module.to(torch.double) assert module.dtype == torch.double assert model.dtype == torch.double