Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generalize closure api in Lightning #8642

Merged
merged 53 commits into from
Aug 26, 2021
Merged

generalize closure api in Lightning #8642

merged 53 commits into from
Aug 26, 2021

Conversation

awaelchli
Copy link
Contributor

@awaelchli awaelchli commented Jul 30, 2021

What does this PR do?

Proposal for a general API for handling closures in Lightning.

Motivation:

Depending on how involved a loop customization is, sooner or later a user has to deal with closures at some point. This is quite difficult in Lightning currently due to a high entanglement with gradient accumulation, manual backward, precision etc. By extracting out a high level api for building closures, we enable customization while providing a clean implementation for manual optimization etc.

Proposed Solution

A Lightning Closure with the following responsibilities:

  • takes three functions as input (step, backward, zero grad) and calls them in the right order.
  • caches the output of the step (training_step) function, to be accessed outside once the closure has run
  • includes profiling on top of the three functions
class LightningClosure(Closure):
    def __init__(
        self,
        step_fn: Callable[[], dict],
        backward_fn: Optional[Callable[[Tensor], Tensor]] = None,
        zero_grad_fn: Optional[Callable[[], None]] = None,
        profiler: Optional[BaseProfiler] = None,
    ):

A Lightning Closure gets created inside our training loop like so (pseudo code):

closure = LightningClosure(step_fn=self._training_step, backward_fn=self._backward, zero_grad_fn=self._zero_grad)

and then later on it gets passed to the optimizer after which we can access results:

optimizer.step(closure)
result = closure.result

With the help of the closure abstraction, it is now very easy to implement manual optimization by setting

closure = LightningClosure(step_fn=self._manual_training_step, backward_fn=None, zero_grad_fn=None)

Result

  • Easier to read our internal code path branching between automatic optimization and manual opt.
  • A step towards loop customization: Expert users may use this simple API to make their own closures.
  • Less magic: Standardized way to access results produced within the closure

Final Thoughts

This PR adds an abstraction. Lightning in its philosophy doesn't want to put abstraction in the way of a user, and we don't want users to fight against our abstractions. The introduction of LightningClosure doesn't contradict that. It is mainly motivated by internal use and documentation. It's a self-contained component that I believe offers great insight for any code readers and explorers looking into how Lightning works with optimizers and closures.

Fixes #9129

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

I made sure I had fun coding 🙃

@pep8speaks
Copy link

pep8speaks commented Jul 30, 2021

Hello @awaelchli! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:

Line 258:13: W503 line break before binary operator
Line 259:13: W503 line break before binary operator

Comment last updated at 2021-08-13 12:30:53 UTC

@awaelchli awaelchli added design Includes a design discussion feature Is an improvement or enhancement refactor labels Jul 30, 2021
@codecov
Copy link

codecov bot commented Jul 30, 2021

Codecov Report

Merging #8642 (e79c15a) into master (366fb39) will decrease coverage by 4%.
The diff coverage is 100%.

@@           Coverage Diff           @@
##           master   #8642    +/-   ##
=======================================
- Coverage      92%     88%    -4%     
=======================================
  Files         175     176     +1     
  Lines       14696   14741    +45     
=======================================
- Hits        13508   12941   -567     
- Misses       1188    1800   +612     

Copy link
Member

@justusschock justusschock left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One comment on memory and two really minor ones on naming :)

pytorch_lightning/loops/closure.py Outdated Show resolved Hide resolved
pytorch_lightning/loops/closure.py Outdated Show resolved Hide resolved
pytorch_lightning/loops/closure.py Outdated Show resolved Hide resolved
@awaelchli awaelchli added this to the v1.5 milestone Aug 13, 2021
Co-authored-by: Carlos Mocholí <[email protected]>
@awaelchli awaelchli force-pushed the feature/loops/closure branch from bc52030 to 592daa8 Compare August 13, 2021 02:02
@awaelchli awaelchli force-pushed the feature/loops/closure branch from 9ce755f to 3137123 Compare August 13, 2021 09:46
@awaelchli awaelchli mentioned this pull request Aug 13, 2021
11 tasks
pytorch_lightning/loops/batch/training_batch_loop.py Outdated Show resolved Hide resolved
pytorch_lightning/loops/batch/training_batch_loop.py Outdated Show resolved Hide resolved
pytorch_lightning/loops/batch/training_batch_loop.py Outdated Show resolved Hide resolved
pytorch_lightning/loops/closure.py Outdated Show resolved Hide resolved
pytorch_lightning/loops/closure.py Outdated Show resolved Hide resolved
pytorch_lightning/loops/closure.py Outdated Show resolved Hide resolved
@mergify mergify bot added ready PRs ready to be merged and removed has conflicts labels Aug 25, 2021
Copy link
Contributor

@tchaton tchaton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great !

@tchaton tchaton enabled auto-merge (squash) August 25, 2021 19:03
@Borda Borda requested a review from ananthsub August 26, 2021 05:54
@tchaton tchaton merged commit 6592d0e into master Aug 26, 2021
@tchaton tchaton deleted the feature/loops/closure branch August 26, 2021 08:36
Comment on lines +279 to +281
# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
check_finite_loss(self.trainer.lightning_module, loss)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@awaelchli since this finite check was added here, can we remove _process_closure_result? (just below)

Copy link
Contributor Author

@awaelchli awaelchli Aug 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I think the way it is right now, it is redundant and can be removed. However, looking at it right now, the NaN checks are not the way they are supposed to be.

The intention of terminate_on_nan is to check the loss BEFORE backward applies, and to check the weights AFTER backward. this must have been messed up in some of our loop refactors :(

I will address this

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
design Includes a design discussion feature Is an improvement or enhancement ready PRs ready to be merged refactor
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[RFC] Introduce LightningClosure
8 participants