Fix truncated_bptt_steps hiddens detach() and improve docs (#8145)
* Fix truncated_bptt_steps hiddens detach()
* Improve truncated_bptt_docs
* Add missing import
* Improve documentation wordings
* pep8
* detach typo
* Update test
* Implement comments
* parametrize test
* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <[email protected]>

Signed-off-by: Guillaume Tauzin <[email protected]>

* Remove import

Signed-off-by: Guillaume Tauzin <[email protected]>

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
4 people authored Jul 1, 2021
1 parent 8b0aec8 commit baa7de2
Showing 4 changed files with 46 additions and 12 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -358,6 +358,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208))


+ - Fixed a bug where `truncated_bptt_steps` would throw an AttributeError when the target RNN has multiple hidden states ([#8145](https://github.com/PyTorchLightning/pytorch-lightning/pull/8145))


## [1.3.7] - 2021-06-22

- Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
29 changes: 25 additions & 4 deletions docs/source/common/lightning_module.rst
@@ -1063,7 +1063,11 @@ truncated_bptt_steps
^^^^^^^^^^^^^^^^^^^^

Truncated backpropagation through time performs backprop every k steps of
a much longer sequence. This is made possible by passing training batches,
split along the time dimension into splits of size k, to the
``training_step``. To keep the same forward propagation behavior, all
hidden states should be carried over between time-dimension splits.


If this is enabled, your batches will automatically get truncated
and the trainer will apply truncated backprop to them.
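
In rough terms, the per-batch behavior looks like the following sketch
(illustrative only, not Lightning's internal implementation; the RNN, shapes,
and loss here are hypothetical):

.. code-block:: python

    import torch
    from torch import nn

    rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
    batch = torch.randn(4, 30, 8)  # (batch, time, features)
    k = 2  # truncated_bptt_steps

    hiddens = None
    for split in torch.split(batch, k, dim=1):
        out, hiddens = rnn(split, hiddens)
        loss = out.sum()  # placeholder loss
        loss.backward()
        # detaching stops gradients at the split boundary, which is
        # what Lightning does with the returned hiddens between splits
        hiddens = hiddens.detach()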
@@ -1080,23 +1084,40 @@ recurrent network trajectories."

class MyModel(LightningModule):

-   def __init__(self):
+   def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        # batch_first has to be set to True
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )

        ...

        # Important: This property activates truncated backpropagation through time
        # Setting this value to 2 splits the batch into sequences of size 2
        self.truncated_bptt_steps = 2

    # Truncated back-propagation through time
    def training_step(self, batch, batch_idx, hiddens):
        x, y = batch

        # the training step must be updated to accept a ``hiddens`` argument
        # hiddens are the hiddens from the previous truncated backprop step
-       out, (hiddens, _) = self.lstm(data, hiddens)
+       out, hiddens = self.lstm(x, hiddens)

        ...

        return {
            "loss": ...,
            "hiddens": hiddens
        }

- Lightning takes care to split your batch along the time-dimension.
+ Lightning takes care of splitting your batch along the time dimension. It is
+ assumed to be the second dimension of your batches. Therefore, in the
+ example above we have set ``batch_first=True``.

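To make the dimension convention concrete (hypothetical shapes, not taken from
the docs): with ``batch_first=True`` the time dimension is dimension 1, so a
30-step sequence with ``truncated_bptt_steps = 2`` yields 15 splits that all
keep the batch dimension intact:

.. code-block:: python

    import torch

    batch = torch.randn(4, 30, 8)  # (batch=4, time=30, features=8)
    splits = torch.split(batch, 2, dim=1)  # split size = truncated_bptt_steps
    assert len(splits) == 15
    assert splits[0].shape == (4, 2, 8)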
3 changes: 1 addition & 2 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -345,8 +345,7 @@ def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Op
    if isinstance(training_step_output, dict):
        loss = training_step_output.pop("loss", None)
        hiddens = training_step_output.pop("hiddens", None)
-       if hiddens is not None:
-           hiddens = hiddens.detach()

        results.extra = training_step_output

    # handle scalar return
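
The single line that replaces the removed ``detach()`` call is not visible in
this truncated hunk. Judging from the commit message and changelog entry, it
must detach hiddens that may be a plain tensor or a collection such as an
LSTM's ``(h, c)`` tuple; a minimal sketch of that idea (an assumption for
illustration, not the actual patch) could look like:

    import torch

    def detach_hiddens(hiddens):
        # Recursively detach tensors so that graphs from previous tbptt
        # splits are not retained; handles tensors, tuples, and lists.
        if isinstance(hiddens, torch.Tensor):
            return hiddens.detach()
        if isinstance(hiddens, (tuple, list)):
            return type(hiddens)(detach_hiddens(h) for h in hiddens)
        return hiddens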
22 changes: 16 additions & 6 deletions tests/models/test_cpu.py
@@ -13,6 +13,7 @@
# limitations under the License.
import os

+import pytest
import torch

import tests.helpers.pipelines as tpipes
@@ -322,7 +323,8 @@ def test_all_features_cpu_model(tmpdir):
    tpipes.run_model_test(trainer_options, model, on_gpu=False, min_acc=0.01)


-def test_tbptt_cpu_model(tmpdir):
+@pytest.mark.parametrize("n_hidden_states", [1, 2])
+def test_tbptt_cpu_model(tmpdir, n_hidden_states):
    """Test truncated back propagation through time works."""
    truncated_bptt_steps = 2
    sequence_size = 30
@@ -341,15 +343,19 @@ def __len__(self):

    class BpttTestModel(BoringModel):

-       def __init__(self, batch_size, in_features, out_features, *args, **kwargs):
+       def __init__(self, batch_size, in_features, out_features, n_hidden_states, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.test_hidden = None
            self.batch_size = batch_size
            self.layer = torch.nn.Linear(in_features, out_features)
+           self.n_hidden_states = n_hidden_states

        def training_step(self, batch, batch_idx, hiddens):
            assert hiddens == self.test_hidden, "Hidden state not persistent between tbptt steps"
-           self.test_hidden = torch.rand(1)
+           if self.n_hidden_states == 1:
+               self.test_hidden = torch.rand(1)
+           else:
+               self.test_hidden = tuple([torch.rand(1)] * self.n_hidden_states)

            x_tensor, y_list = batch
            assert x_tensor.shape[1] == truncated_bptt_steps, "tbptt split Tensor failed"
@@ -378,7 +384,12 @@ def train_dataloader(self):
                sampler=None,
            )

-   model = BpttTestModel(batch_size=batch_size, in_features=truncated_bptt_steps, out_features=truncated_bptt_steps)
+   model = BpttTestModel(
+       batch_size=batch_size,
+       in_features=truncated_bptt_steps,
+       out_features=truncated_bptt_steps,
+       n_hidden_states=n_hidden_states
+   )
    model.example_input_array = torch.randn(5, truncated_bptt_steps)

    # fit model
@@ -390,5 +401,4 @@ def train_dataloader(self):
        weights_summary=None,
    )
    trainer.fit(model)
-
-   assert trainer.state.finished, f"Training failed with {trainer.state}"
+   assert trainer.state.finished, f"Training model with `{n_hidden_states}` hidden state failed with {trainer.state}"
