diff --git a/CHANGELOG.md b/CHANGELOG.md
index 77086c3c342f6..f2c9bcb726a79 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -358,6 +358,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208))
 
+
+- Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))
+
+
 ## [1.3.7] - 2021-06-22
 
 - Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index f08da426e892b..6043eab649ebf 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -1063,7 +1063,11 @@ truncated_bptt_steps
 ^^^^^^^^^^^^^^^^^^^^
 
 Truncated back prop breaks performs backprop every k steps of
-a much longer sequence.
+a much longer sequence. This is made possible by passing training batches
+splitted along the time-dimensions into splits of size k to the
+``training_step``. In order to keep the same forward propagation behavior, all
+hidden states should be kept in-between each time-dimension split.
+
 
 If this is enabled, your batches will automatically get truncated
 and the trainer will apply Truncated Backprop to it.
@@ -1080,23 +1084,40 @@ recurrent network trajectories."
 
     class MyModel(LightningModule):
 
-        def __init__(self):
+        def __init__(self, input_size, hidden_size, num_layers):
             super().__init__()
+            # batch_first has to be set to True
+            self.lstm = nn.LSTM(
+                input_size=input_size,
+                hidden_size=hidden_size,
+                num_layers=num_layers,
+                batch_first=True,
+            )
+
+            ...
+
             # Important: This property activates truncated backpropagation through time
             # Setting this value to 2 splits the batch into sequences of size 2
             self.truncated_bptt_steps = 2
 
         # Truncated back-propagation through time
         def training_step(self, batch, batch_idx, hiddens):
+            x, y = batch
+
             # the training step must be updated to accept a ``hiddens`` argument
             # hiddens are the hiddens from the previous truncated backprop step
-            out, (hiddens, _) = self.lstm(data, hiddens)
+            out, hiddens = self.lstm(x, hiddens)
+
+            ...
+
             return {
                 "loss": ...,
                 "hiddens": hiddens
             }
 
-Lightning takes care to split your batch along the time-dimension.
+Lightning takes care of splitting your batch along the time-dimension. It is
+assumed to be the second dimension of your batches. Therefore, in the
+example above we have set ``batch_first=True``.
 
 .. code-block:: python
 
diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py
index 76051fc3f1e94..e261a61a366fb 100644
--- a/pytorch_lightning/loops/batch/training_batch_loop.py
+++ b/pytorch_lightning/loops/batch/training_batch_loop.py
@@ -345,8 +345,7 @@ def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Op
         if isinstance(training_step_output, dict):
             loss = training_step_output.pop("loss", None)
             hiddens = training_step_output.pop("hiddens", None)
-            if hiddens is not None:
-                hiddens = hiddens.detach()
+
             results.extra = training_step_output
 
         # handle scalar return
diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py
index b54e0d091bd16..84721fe8b575c 100644
--- a/tests/models/test_cpu.py
+++ b/tests/models/test_cpu.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 
+import pytest
 import torch
 
 import tests.helpers.pipelines as tpipes
@@ -322,7 +323,8 @@ def test_all_features_cpu_model(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, min_acc=0.01)
 
 
-def test_tbptt_cpu_model(tmpdir):
+@pytest.mark.parametrize("n_hidden_states", [1, 2])
+def test_tbptt_cpu_model(tmpdir, n_hidden_states):
     """Test truncated back propagation through time works."""
     truncated_bptt_steps = 2
     sequence_size = 30
@@ -341,15 +343,19 @@ def __len__(self):
 
     class BpttTestModel(BoringModel):
 
-        def __init__(self, batch_size, in_features, out_features, *args, **kwargs):
+        def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args, **kwargs):
             super().__init__(*args, **kwargs)
             self.test_hidden = None
             self.batch_size = batch_size
             self.layer = torch.nn.Linear(in_features, out_features)
+            self.n_hidden_states = n_hidden_states
 
         def training_step(self, batch, batch_idx, hiddens):
             assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-            self.test_hidden = torch.rand(1)
+            if self.n_hidden_states == 1:
+                self.test_hidden = torch.rand(1)
+            else:
+                self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)
 
             x_tensor, y_list = batch
             assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
@@ -378,7 +384,12 @@ def train_dataloader(self):
                 sampler=None,
             )
 
-    model = BpttTestModel(batch_size=batch_size, in_features=truncated_bptt_steps, out_features=truncated_bptt_steps)
+    model = BpttTestModel(
+        batch_size=batch_size,
+        in_features=truncated_bptt_steps,
+        out_features=truncated_bptt_steps,
+        n_hidden_states=n_hidden_states
+    )
     model.example_input_array = torch.randn(5, truncated_bptt_steps)
 
     # fit model
@@ -390,5 +401,4 @@ def train_dataloader(self):
         weights_summary=None,
     )
     trainer.fit(model)
-
-    assert trainer.state.finished, f"Training failed with {trainer.state}"
+    assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
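
For context, below is a minimal sketch (not part of the PR) of the usage pattern this change fixes: a LightningModule whose LSTM carries a tuple of hidden states `(h, c)` across the time-dimension splits created by `truncated_bptt_steps`. The class name, sizes, and toy dataset are illustrative assumptions, not code from the repository.

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import LightningModule, Trainer


    class TBPTTExample(LightningModule):
        def __init__(self, input_size=4, hidden_size=8):
            super().__init__()
            # batch_first=True so dim 1 is time, the dimension Lightning splits
            self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)
            self.head = nn.Linear(hidden_size, input_size)
            # cut every batch into chunks of 5 time steps and backprop after each chunk
            self.truncated_bptt_steps = 5

        def training_step(self, batch, batch_idx, hiddens):
            x, y = batch
            # `hiddens` starts out as None and afterwards carries the (h, c) tuple
            # returned from the previous split; returning a tuple under the
            # "hiddens" key is what raised an AttributeError before this fix
            out, hiddens = self.lstm(x, hiddens)
            loss = nn.functional.mse_loss(self.head(out), y)
            return {"loss": loss, "hiddens": hiddens}

        def train_dataloader(self):
            # 16 toy sequences of 20 time steps -> 20 / 5 = 4 tbptt splits per batch
            x = torch.randn(16, 20, 4)
            return DataLoader(TensorDataset(x, x), batch_size=4)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)


    if __name__ == "__main__":
        Trainer(max_epochs=1, logger=False, checkpoint_callback=False).fit(TBPTTExample())

Before this fix, only a single hidden-state tensor could be passed back between splits, so an RNN with multiple hidden states (such as an LSTM's `(h, c)` pair) failed as described in the changelog entry above.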