From 9219195febdb76ff55ef4738f99b7615861f9898 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 29 Jun 2021 22:15:30 +0100 Subject: [PATCH 1/5] Fix double precision casting complex buffers --- pytorch_lightning/plugins/precision/double.py | 2 +- tests/plugins/test_double_plugin.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py index 86177c5500e2f..064c65b500f29 100644 --- a/pytorch_lightning/plugins/precision/double.py +++ b/pytorch_lightning/plugins/precision/double.py @@ -96,7 +96,7 @@ def connect( incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or `lr_schedulers`. """ - model = cast(pl.LightningModule, model.to(dtype=torch.float64)) + model = cast(pl.LightningModule, model.double()) model = LightningDoublePrecisionModule(model) return super().connect(model, optimizers, lr_schedulers) diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py index be4f690f25ed6..d1fd7bebe6d2b 100644 --- a/tests/plugins/test_double_plugin.py +++ b/tests/plugins/test_double_plugin.py @@ -40,6 +40,11 @@ def __len__(self): class DoublePrecisionBoringModel(BoringModel): + def __init__(self): + super().__init__() + + self.register_buffer("complex_buffer", torch.complex(torch.rand(10), torch.rand(10)), False) + def training_step(self, batch, batch_idx): float_data, int_data = batch assert torch.tensor([0.]).dtype == torch.float64 @@ -77,9 +82,11 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): def on_fit_start(self): assert self.layer.weight.dtype == torch.float64 + assert self.complex_buffer.dtype == torch.complex64 def on_after_backward(self): assert self.layer.weight.grad.dtype == torch.float64 + assert self.complex_buffer.dtype == torch.complex64 def train_dataloader(self): dataset = RandomFloatIntDataset(32, 64) From 300268a73f15f10e127b4f23237139071ce5cf54 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 29 Jun 2021 22:18:09 +0100 Subject: [PATCH 2/5] Update CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 34eaf6e359257..5bc354cb465a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -342,6 +342,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where calling `log` with a `Metric` instance would raise an error if it was a nested attribute of the model ([#8181](https://github.com/PyTorchLightning/pytorch-lightning/pull/8181)) + +- Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208)) + ## [1.3.7] - 2021-06-22 - Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975)) From c7111b652c621940fec77b2e982ef7376d554980 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 30 Jun 2021 07:55:26 +0100 Subject: [PATCH 3/5] Fixes --- tests/plugins/test_double_plugin.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py index d1fd7bebe6d2b..cdd0447279713 100644 --- a/tests/plugins/test_double_plugin.py +++ b/tests/plugins/test_double_plugin.py @@ -40,11 +40,6 @@ def __len__(self): class DoublePrecisionBoringModel(BoringModel): - def __init__(self): - super().__init__() - - self.register_buffer("complex_buffer", torch.complex(torch.rand(10), torch.rand(10)), False) - def training_step(self, batch, batch_idx): float_data, int_data = batch assert torch.tensor([0.]).dtype == torch.float64 @@ -82,11 +77,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): def on_fit_start(self): assert self.layer.weight.dtype == torch.float64 - assert self.complex_buffer.dtype == torch.complex64 def on_after_backward(self): assert self.layer.weight.grad.dtype == torch.float64 - assert self.complex_buffer.dtype == torch.complex64 def train_dataloader(self): dataset = RandomFloatIntDataset(32, 64) @@ -130,7 +123,25 @@ def predict_dataloader(self): return DataLoader(RandomDataset(32, 64)) -@pytest.mark.parametrize('boring_model', (DoublePrecisionBoringModel, DoublePrecisionBoringModelNoForward)) +class DoublePrecisionBoringModelComplexBuffer(BoringModel): + + def __init__(self): + super().__init__() + + self.register_buffer("complex_buffer", torch.complex(torch.rand(10), torch.rand(10)), False) + + def on_fit_start(self): + assert self.layer.weight.dtype == torch.float64 + assert self.complex_buffer.dtype == torch.complex64 + + +@pytest.mark.parametrize( + 'boring_model', ( + DoublePrecisionBoringModel, + DoublePrecisionBoringModelNoForward, + DoublePrecisionBoringModelComplexBuffer, + ) +) def test_double_precision(tmpdir, boring_model): model = boring_model() From 7d19a04c8e1b6c953f63e9b4a57bc48570c5f676 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 30 Jun 2021 08:02:20 +0100 Subject: [PATCH 4/5] Fixes --- tests/plugins/test_double_plugin.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py index cdd0447279713..3ba769c9dd6e3 100644 --- a/tests/plugins/test_double_plugin.py +++ b/tests/plugins/test_double_plugin.py @@ -20,6 +20,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins import DoublePrecisionPlugin +from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7 from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -136,11 +137,14 @@ def on_fit_start(self): @pytest.mark.parametrize( - 'boring_model', ( + 'boring_model', [ DoublePrecisionBoringModel, DoublePrecisionBoringModelNoForward, - DoublePrecisionBoringModelComplexBuffer, - ) + pytest.param( + DoublePrecisionBoringModelComplexBuffer, + marks=pytest.mark.skipif(_TORCH_GREATER_EQUAL_1_7, reason="`torch.complex` not available") + ), + ] ) def test_double_precision(tmpdir, boring_model): model = boring_model() From dbc6ff9b39416f7186cb8c94e9abc3ba89b5bd5e Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 30 Jun 2021 08:16:17 +0100 Subject: [PATCH 5/5] Fix --- tests/plugins/test_double_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py index 3ba769c9dd6e3..302ee985b2379 100644 --- a/tests/plugins/test_double_plugin.py +++ b/tests/plugins/test_double_plugin.py @@ -142,7 +142,7 @@ def on_fit_start(self): DoublePrecisionBoringModelNoForward, pytest.param( DoublePrecisionBoringModelComplexBuffer, - marks=pytest.mark.skipif(_TORCH_GREATER_EQUAL_1_7, reason="`torch.complex` not available") + marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="torch.complex not available") ), ] )