Fix double precision casting complex buffers #8208

Merged: 7 commits, Jun 30, 2021
CHANGELOG.md (3 additions, 0 deletions)
@@ -345,6 +345,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where calling `log` with a `Metric` instance would raise an error if it was a nested attribute of the model ([#8181](https://github.com/PyTorchLightning/pytorch-lightning/pull/8181))


- Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208))

## [1.3.7] - 2021-06-22

- Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
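For context on the `precision=64` complex-buffer entry above, a minimal sketch (not part of this PR) of the underlying tensor behaviour: casting a complex tensor to a real dtype keeps only the real part, while casting to a wider complex dtype is lossless.

```python
import torch

buf = torch.complex(torch.rand(3), torch.rand(3))
print(buf.dtype)                       # torch.complex64
print(buf.to(torch.float64).dtype)     # torch.float64: imaginary part discarded (PyTorch warns about this)
print(buf.to(torch.complex128).dtype)  # torch.complex128: double precision without losing the imaginary part
```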
pytorch_lightning/plugins/precision/double.py (1 addition, 1 deletion)
@@ -96,7 +96,7 @@ def connect(
        incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or
        `lr_schedulers`.
        """
-        model = cast(pl.LightningModule, model.to(dtype=torch.float64))
+        model = cast(pl.LightningModule, model.double())
        model = LightningDoublePrecisionModule(model)

        return super().connect(model, optimizers, lr_schedulers)
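The one-line change above relies on a difference in how PyTorch modules are cast: `nn.Module.double()` converts only floating point parameters and buffers, while `nn.Module.to(dtype=torch.float64)` also converts complex tensors, turning them into real ones. A minimal sketch, independent of Lightning:

```python
import torch
from torch import nn


def make_module() -> nn.Module:
    m = nn.Linear(4, 4)
    m.register_buffer("complex_buffer", torch.complex(torch.rand(3), torch.rand(3)))
    return m


# Old code path: `.to(dtype=torch.float64)` casts the complex buffer to a real dtype.
old = make_module().to(dtype=torch.float64)
print(old.complex_buffer.dtype)  # torch.float64

# New code path: `.double()` skips non-floating-point tensors, so the buffer keeps its dtype.
new = make_module().double()
print(new.weight.dtype)          # torch.float64
print(new.complex_buffer.dtype)  # torch.complex64
```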
tests/plugins/test_double_plugin.py (23 additions, 1 deletion)
@@ -20,6 +20,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DoublePrecisionPlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf

@@ -123,7 +124,28 @@ def predict_dataloader(self):
        return DataLoader(RandomDataset(32, 64))


-@pytest.mark.parametrize('boring_model', (DoublePrecisionBoringModel, DoublePrecisionBoringModelNoForward))
+class DoublePrecisionBoringModelComplexBuffer(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+
+        self.register_buffer("complex_buffer", torch.complex(torch.rand(10), torch.rand(10)), False)
+
+    def on_fit_start(self):
+        assert self.layer.weight.dtype == torch.float64
+        assert self.complex_buffer.dtype == torch.complex64
+
+
+@pytest.mark.parametrize(
+    'boring_model', [
+        DoublePrecisionBoringModel,
+        DoublePrecisionBoringModelNoForward,
+        pytest.param(
+            DoublePrecisionBoringModelComplexBuffer,
+            marks=pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="torch.complex not available")
+        ),
+    ]
+)
def test_double_precision(tmpdir, boring_model):
    model = boring_model()

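The body of `test_double_precision` is truncated above. For orientation, a hedged sketch of how such a test typically drives the plugin end to end; the helper name and the exact `Trainer` arguments are assumptions, not taken from this PR:

```python
from pytorch_lightning import Trainer


def run_double_precision_check(tmpdir, boring_model):
    # Hypothetical harness: with precision=64, DoublePrecisionPlugin.connect() now calls
    # model.double(), so DoublePrecisionBoringModelComplexBuffer.on_fit_start() can assert
    # float64 parameters alongside an untouched complex64 buffer.
    model = boring_model()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2, precision=64)
    trainer.fit(model)
```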