
Deprecate `truncated_bptt_steps` flag on Trainer in favor of same setting on the LightningModule #7323

Merged
merged 17 commits into master
May 5, 2021

Conversation

ananthsub
Contributor

@ananthsub ananthsub commented May 3, 2021

What does this PR do?

Fixes #7322

This also fixes an inconsistency between the documentation and the code. The documentation said hiddens would be passed only if truncated_bptt_steps > 0, but the code sent hiddens whenever truncated_bptt_steps was not None, so with truncated_bptt_steps = 0 we would still send hiddens to the LightningModule.
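The corrected condition can be sketched as a standalone check (the helper name `should_pass_hiddens` is illustrative only, not the actual Lightning internal):

```python
from typing import Optional


def should_pass_hiddens(truncated_bptt_steps: Optional[int]) -> bool:
    """Illustrative sketch: only pass hiddens when TBPTT is actually enabled.

    The docs promised hiddens only for truncated_bptt_steps > 0, but the old
    code checked `is not None`, so a value of 0 still sent hiddens.
    """
    return truncated_bptt_steps is not None and truncated_bptt_steps > 0
```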

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines. In short, check the following:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@codecov

codecov bot commented May 3, 2021

Codecov Report

Merging #7323 (d1669fd) into master (f135deb) will decrease coverage by 1%.
The diff coverage is 95%.

@@          Coverage Diff           @@
##           master   #7323   +/-   ##
======================================
- Coverage      87%     87%   -1%     
======================================
  Files         200     200           
  Lines       12895   12946   +51     
======================================
- Hits        11259   11237   -22     
- Misses       1636    1709   +73     

@pep8speaks

pep8speaks commented May 3, 2021

Hello @ananthsub! Thanks for updating this PR.

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-05-04 16:34:32 UTC

Contributor

@carmocca carmocca left a comment

Don't we have any docs example where the trainer flag use needs to be updated? Did you try grepping?

pytorch_lightning/trainer/trainer.py (outdated, resolved)
pytorch_lightning/core/lightning.py (outdated, resolved)
pytorch_lightning/core/lightning.py (outdated, resolved)
pytorch_lightning/core/lightning.py (outdated, resolved)
@mergify mergify bot added the has conflicts label May 4, 2021
@ananthsub ananthsub requested a review from edenlightning as a code owner May 4, 2021 00:43
@mergify mergify bot removed the has conflicts label May 4, 2021
@ananthsub
Contributor Author

Don't we have any docs example where the trainer flag use needs to be updated? Did you try grepping?

@carmocca updated the docs to reflect setting the property on the LightningModule instead of passing a trainer flag. Anything else that should be done for the trainer docs?
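The migration the docs now describe can be sketched with simplified stand-in classes (the real Trainer and LightningModule have far more machinery; class names here are illustrative): the new style sets the attribute on the module, while the deprecated Trainer flag emits a warning.

```python
import warnings


class LightningModuleSketch:
    """Stand-in for LightningModule: TBPTT is configured here now."""

    def __init__(self) -> None:
        # New style: set the property on the module itself
        self.truncated_bptt_steps = 2


class TrainerSketch:
    """Stand-in for Trainer: the old flag still works but warns."""

    def __init__(self, truncated_bptt_steps=None) -> None:
        if truncated_bptt_steps is not None:
            warnings.warn(
                "Trainer(truncated_bptt_steps=...) is deprecated; set "
                "LightningModule.truncated_bptt_steps instead.",
                DeprecationWarning,
            )
        self.truncated_bptt_steps = truncated_bptt_steps
```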

Contributor

@carmocca carmocca left a comment

My grep powers still show a few more occurrences which we should update:

docs/source/common/trainer.rst:1430:truncated_bptt_steps
docs/source/common/trainer.rst:1436:    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/truncated_bptt_steps.jpg"
docs/source/common/trainer.rst:1437:    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/truncated_bptt_steps.mp4"></video>
docs/source/common/trainer.rst:1454:    trainer = Trainer(truncated_bptt_steps=None)
docs/source/common/trainer.rst:1457:    trainer = Trainer(truncated_bptt_steps=5)
docs/source/starter/new-project.rst:750:- :ref:`Automatic truncated-back-propagation-through-time <common/trainer:truncated_bptt_steps>`

So we should move these into lightning_module.rst as we do for automatic_optimization.
Not sure if we should keep the video though

cc: @edenlightning ?

docs/source/advanced/sequences.rst (resolved)
pytorch_lightning/accelerators/accelerator.py (outdated, resolved)
pytorch_lightning/core/lightning.py (outdated, resolved)
pytorch_lightning/core/lightning.py (outdated, resolved)
pytorch_lightning/trainer/training_loop.py (outdated, resolved)
Comment on lines +1021 to +1041

.. testcode:: python

    from pytorch_lightning import LightningModule


    class MyModel(LightningModule):

        def __init__(self):
            super().__init__()
            # Important: this property activates truncated backpropagation through time
            # Setting this value to 2 splits the batch into sequences of size 2
            self.truncated_bptt_steps = 2

        # Truncated back-propagation through time
        def training_step(self, batch, batch_idx, hiddens):
            # the training step must be updated to accept a ``hiddens`` argument
            # hiddens are the hiddens from the previous truncated backprop step
            out, hiddens = self.lstm(batch, hiddens)
            return {"loss": ..., "hiddens": hiddens}
Member

why exactly the same example as in sequences.rst?

Contributor Author

no good reason other than copy/paste. any suggestions on how they should be different?

Member

I may drop the example here and keep only the earlier one...

Contributor

I would make it short here and link to the "sequences" document, which provides all the examples and tutorials.

@awaelchli awaelchli added the _Will label May 4, 2021
docs/source/common/lightning_module.rst (outdated, resolved)
docs/source/common/lightning_module.rst (outdated, resolved)
pytorch_lightning/core/lightning.py (resolved)
pytorch_lightning/core/lightning.py (outdated, resolved)
@carmocca carmocca added the design Includes a design discussion label May 4, 2021
@ananthsub ananthsub changed the title from "Deprecate truncated_bptt_steps flag on Trainer in favor of same setting on the LightningModule" to "Deprecate `truncated_bptt_steps` flag on Trainer in favor of same setting on the LightningModule" May 4, 2021
@@ -876,11 +877,22 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
)

# pass hiddens if using tbptt
if self.trainer.truncated_bptt_steps is not None:
if self._truncated_bptt_enabled():
Member

shall this also be a property?

Suggested change
if self._truncated_bptt_enabled():
if self.truncated_bptt_enabled:

args.append(hiddens)

return args

def _truncated_bptt_enabled(self) -> bool:
Member

Suggested change
def _truncated_bptt_enabled(self) -> bool:
@property
def truncated_bptt_enabled(self) -> bool:
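The suggestion above can be sketched with a simplified stand-in for the training loop (the attribute lookup chain and class name are assumptions for illustration, not the exact Lightning code): the check becomes a read-only property that prefers the module setting and falls back to the deprecated Trainer flag.

```python
class TrainingLoopSketch:
    """Simplified stand-in showing the check exposed as a read-only property."""

    def __init__(self, module_steps=None, trainer_steps=None) -> None:
        # module_steps: LightningModule.truncated_bptt_steps (new location)
        # trainer_steps: the deprecated Trainer flag (old location)
        self._module_steps = module_steps
        self._trainer_steps = trainer_steps

    @property
    def truncated_bptt_enabled(self) -> bool:
        # Prefer the module setting; fall back to the deprecated flag.
        # TBPTT is enabled only for a positive step count.
        steps = self._module_steps if self._module_steps is not None else self._trainer_steps
        return steps is not None and steps > 0
```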

pytorch_lightning/trainer/training_loop.py Show resolved Hide resolved
Contributor

@tchaton tchaton left a comment

LGTM !

@tchaton tchaton merged commit 98670c8 into Lightning-AI:master May 5, 2021
Labels
design (Includes a design discussion), refactor
Projects
None yet
Development

Successfully merging this pull request may close these issues.

truncated_bptt_steps should be a property of the LightningModule, not the Trainer
7 participants