
Autodiff: checkpointing strategy #936

Closed
louisfd opened this issue Nov 7, 2023 · 4 comments
Labels: performance (Anything related to performance), very hard (Reserved for framework experts: Extremely challenging.)

Comments

louisfd (Member) commented Nov 7, 2023

In autodiff, we should have a checkpointing strategy for better memory consumption (see, for instance, https://www-sop.inria.fr/tropics/papers/DauvergneHascoet06.pdf).

Currently, for most operations run in the forward pass, a state is saved for the backward pass. That state often consists of a few tensors, so the saved states quickly accumulate and use a lot of memory.

A way to use less memory during the backward pass would be, instead of keeping the state in memory, to recompute the operation's forward pass just before computing its backward pass in order to re-obtain the state. This leads to more computation, but less memory consumption.
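
A minimal sketch of the idea, using a hypothetical `Checkpoint` type (not Burn's actual autodiff API): a backward step either reads a state that was stored during the forward pass, or replays a captured closure to recompute that state on demand.

```rust
// Hypothetical sketch, not Burn's actual API: the state needed by a backward
// step is either stored eagerly or recomputed lazily from captured inputs.
enum Checkpoint<T> {
    /// State kept in memory during the forward pass (the current behaviour).
    Stored(T),
    /// State recomputed right before the backward step, trading compute for memory.
    Recompute(Box<dyn Fn() -> T>),
}

impl<T> Checkpoint<T> {
    fn retrieve(self) -> T {
        match self {
            Checkpoint::Stored(state) => state,
            Checkpoint::Recompute(forward) => forward(),
        }
    }
}

fn main() {
    // Element-wise multiplication is cheap to recompute, so only the inputs
    // (plain numbers here for illustration) need to survive until backward.
    let (lhs, rhs) = (3.0_f32, 4.0_f32);
    let output = Checkpoint::Recompute(Box::new(move || lhs * rhs));

    // ...later, during the backward pass:
    assert_eq!(output.retrieve(), 12.0);
}
```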

This leads to a tradeoff between compute and memory. Some operations, like matrix multiplication, are "compute-bound", meaning the bottleneck is generally the actual computation, while others, such as element-wise multiplication, are "memory-bound", meaning the computation is so simple that moving the data is the bottleneck.

For compute-bound operations, it is better to keep the state than to recompute. But for memory-bound operations, we would benefit from recomputing.
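
As a rough sketch of how that per-operation decision could be expressed (the type and function names below are illustrative, not Burn's actual ones):

```rust
// Illustrative only: a per-operation tag deciding whether to store or recompute.
#[derive(Clone, Copy, Debug)]
enum ComputeProperty {
    /// Expensive to recompute (e.g. matmul): keep the state in memory.
    ComputeBound,
    /// Cheap to recompute (e.g. element-wise ops): drop the state and replay
    /// the forward computation just before the backward step.
    MemoryBound,
}

/// Store the state only when recomputation would cost more than it saves.
fn should_store_state(property: ComputeProperty) -> bool {
    matches!(property, ComputeProperty::ComputeBound)
}

fn main() {
    assert!(should_store_state(ComputeProperty::ComputeBound));
    assert!(!should_store_state(ComputeProperty::MemoryBound));
}
```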

Also, if many operations are tagged as memory-bound, this will greatly help Burn-Fusion, which will be able to fuse kernels transparently during the backward pass.

The current strategy, where every state is saved, would simply become a specific case of the new strategy, where everything is considered compute-bound.

louisfd added the performance and very hard labels on Nov 7, 2023
AuruTus (Contributor) commented Nov 23, 2023

Hi, I'm wondering how a toggle for this strategy should be added to Burn's AD graph. The AD tool Tapenade behind that paper seems to have an IR and a pair of directives to control which snippets should be treated with checkpoints.

louisfd (Member, Author) commented Nov 23, 2023

Hi @AuruTus
To be honest, I haven't read the paper; I just thought its figures explain the concept of checkpointing well. I'm not sure whether we should follow what they did or design our own checkpointing strategy.
We plan on tackling this issue in early 2024; for now I haven't given it more thought than what is written above!

benjamin-macadam commented

If you are discussing checkpointing strategies, it may be worth considering JAX's approach to AD, explained in You Only Linearize Once, since it can shed some light on what is going on with checkpointing. The idea is to break the vector-Jacobian product into two pieces - I'm going to use Haskell type signatures where -o is linear implication and ! means a variable may be reused (e.g. it is a smooth argument).

  1. A Jacobian-vector product jvp: (!a -o b) -> ((!a,a) -o b) (this is quite easy to implement). If you set this up correctly, you're guaranteed to have nested derivatives that work out of the box.
  2. A linear transpose t: (a -o b) -> (b -o a) - since the linear map is representable as a matrix, this is the same thing as multiplying by the transpose of that matrix. This is generally where the 'tape' comes in handy. Using currying this can become t_contextual: ((!a, b) -o c) -> ((!a, c) -o b), but realistically you probably want t_contextual as your primitive (that's what we did here).
  3. You can now define vjp = t_contextual . jvp, or vjp f = t_contextual (jvp f). If you set it up this way you're guaranteed to have nested forward/reverse derivatives that work correctly.
  4. (A nice little bonus is that if you already have a vjp operator, you can use it to define the linear transpose. So this really means you only need to write vjps for linear functions, and jvps for all functions.)

Now, I said that you want to treat t_contextual as a primitive, but that's really for the user - that's where all of the logic around checkpointing lives when you look at section 6 of 'You Only Linearize Once'. I've mostly changed my focus to animation/simulation these days, but I do know a few Math/CS profs in Canada who are still working on these things.
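
To make the decomposition above concrete, here is a tiny numeric sketch (plain Rust closures standing in for tensors and linear maps; the `jvp`/`vjp` functions are hand-written for one scalar function rather than derived mechanically, and are not an existing API): for f(x, y) = x * y, the jvp at (x, y) is the linear map (dx, dy) -> y*dx + x*dy, and transposing that map yields the vjp dz -> (y*dz, x*dz).

```rust
// Hand-written sketch of vjp = transpose(jvp) for f(x, y) = x * y,
// using plain f32 closures instead of tensors.

/// jvp at (x, y): the linear map (dx, dy) -> y*dx + x*dy
/// (the Jacobian of f, applied to an input tangent).
fn jvp(x: f32, y: f32) -> impl Fn(f32, f32) -> f32 {
    move |dx, dy| y * dx + x * dy
}

/// Transposing that linear map (the 1x2 Jacobian [y, x]) gives the vjp:
/// an output cotangent dz is pulled back to (y*dz, x*dz).
fn vjp(x: f32, y: f32) -> impl Fn(f32) -> (f32, f32) {
    move |dz| (y * dz, x * dz)
}

fn main() {
    let (x, y) = (2.0_f32, 5.0_f32);
    // Forward mode: directional derivative along (dx, dy) = (1, 0) is y.
    assert_eq!(jvp(x, y)(1.0, 0.0), 5.0);
    // Reverse mode: the gradient of f at (x, y) is (y, x).
    assert_eq!(vjp(x, y)(1.0), (5.0, 2.0));
}
```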

louisfd (Member, Author) commented Feb 27, 2024

Solved in #1358

louisfd closed this as completed on Feb 27, 2024