-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Deprecatetruncated_bptt_steps
flag on Trainer in favor of same setting on the LightningModule
#7323
Conversation
Codecov Report
@@ Coverage Diff @@
## master #7323 +/- ##
======================================
- Coverage 87% 87% -1%
======================================
Files 200 200
Lines 12895 12946 +51
======================================
- Hits 11259 11237 -22
- Misses 1636 1709 +73 |
Hello @ananthsub! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2021-05-04 16:34:32 UTC |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't we have any docs example where the trainer flag use needs to be updated? Did you try grepping?
@carmocca updated the docs to reflect setting the property on the LightningModule instead of passing a trainer flag. Anything else that should be done for the trainer docs? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My grep powers still show a few more occurrences which we should update:
docs/source/common/trainer.rst:1430:truncated_bptt_steps
docs/source/common/trainer.rst:1436: poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/truncated_bptt_steps.jpg"
docs/source/common/trainer.rst:1437: src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/truncated_bptt_steps.mp4"></video>
docs/source/common/trainer.rst:1454: trainer = Trainer(truncated_bptt_steps=None)
docs/source/common/trainer.rst:1457: trainer = Trainer(truncated_bptt_steps=5)
docs/source/starter/new-project.rst:750:- :ref:`Automatic truncated-back-propagation-through-time <common/trainer:truncated_bptt_steps>`
So we should move these into lightning_module.rst
as we do for automatic_optimization
.
Not sure if we should keep the video though
cc: @edenlightning ?
.. testcode:: python | ||
|
||
from pytorch_lightning import LightningModule | ||
|
||
class MyModel(LightningModule): | ||
|
||
def __init__(self): | ||
super().__init__() | ||
# Important: This property activates truncated backpropagation through time | ||
# Setting this value to 2 splits the batch into sequences of size 2 | ||
self.truncated_bptt_steps = 2 | ||
|
||
# Truncated back-propagation through time | ||
def training_step(self, batch, batch_idx, hiddens): | ||
# the training step must be updated to accept a ``hiddens`` argument | ||
# hiddens are the hiddens from the previous truncated backprop step | ||
out, hiddens = self.lstm(data, hiddens) | ||
return { | ||
"loss": ..., | ||
"hiddens": hiddens | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why exactly the same example as in sequences.rst?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no good reason other than copy/paste. any suggestions on how they should be different?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I may drop the example here and keep only the earlier one...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would make this short here and link to the "sequences" document which provides all example and tutorials.
truncated_bptt_steps
flag on Trainer in favor of same setting on the LightningModule
@@ -876,11 +877,22 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens): | |||
) | |||
|
|||
# pass hiddens if using tbptt | |||
if self.trainer.truncated_bptt_steps is not None: | |||
if self._truncated_bptt_enabled(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall this also be a property?
if self._truncated_bptt_enabled(): | |
if self.truncated_bptt_enabled: |
args.append(hiddens) | ||
|
||
return args | ||
|
||
def _truncated_bptt_enabled(self) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def _truncated_bptt_enabled(self) -> bool: | |
@property | |
def truncated_bptt_enabled(self) -> bool: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM !
What does this PR do?
Fixes #7322
This also fixes an inconsistency between the documentation and code. The documentation said
hiddens
would be passed only if truncated_bptt_steps > 0, but the code was sending hiddens if truncated_bptt_steps was not None, so if truncated_tbptt_steps = 0, then we'd send hiddens to the lightning moduleBefore submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:
Did you have fun?
Make sure you had fun coding 🙃