
Unexpectedly high grad-of-scan memory usage #3186

Closed · jeffgortmaker opened this issue May 22, 2020 · 5 comments
Labels: question (Questions for the JAX team)


jeffgortmaker commented May 22, 2020

Consider the following function that sums x * (y + z) over all y in ys and then averages over the resulting matrix of sums:

```python
import jax.lax
import jax.numpy as jnp

def f(x, ys):
    z = jnp.ones((3000, 3000))

    def scanned(carry, y):
        return carry + x * (y + z), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys)
    return summed.mean()
```

Because I use lax.scan (instead of, e.g., vmap or lax.map followed by a sum over the first axis), memory usage doesn't significantly scale with the number of ys. The following code uses ~203MB regardless of whether n = 5 or n = 10:

```python
import resource

n = 5  # or n = 10; peak memory is ~203MB either way
print(f(1.0, jnp.ones(n)))
print(f"{1e-3 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}MB")
```

But the gradient uses 557MB for n = 5 and 908MB for n = 10:

```python
import jax

print(jax.grad(f)(1.0, jnp.ones(n)))
print(f"{1e-3 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}MB")
```

The story is similar when these functions are jitted.

My best guess about what's going on here is that grad is storing every (y + z) in memory. Is this intended? And is there some way to tell grad to be more economical about what it stores, so that computing the gradient gets a memory reduction similar to the one lax.scan gives the forward pass?

skye (Member) commented May 22, 2020

You're right that grad causes every (y + z) to be stored. Since the result of f is computed using x * (y + z), it needs to save the (y + z) values to compute the gradient. You can try using the new jax.remat, which causes values needed by the gradient computation to be recomputed instead of stored, thus saving memory. This probably makes sense for a scan like this, where you're creating a large amount of easy-to-compute values. See #1749 for examples of using remat. I think doing scan(remat(scanned), ...) should work in this case.
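Concretely, something like this (an untested sketch, reusing the names from your example):

```python
import jax
import jax.lax
import jax.numpy as jnp

def f(x, ys):
    z = jnp.ones((3000, 3000))

    def scanned(carry, y):
        return carry + x * (y + z), None

    # Wrap the scan body in remat so its intermediates (like y + z) are
    # recomputed during the backward pass instead of being stored.
    summed, _ = jax.lax.scan(jax.remat(scanned), jnp.zeros_like(z), ys)
    return summed.mean()
```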

cc @mattjj who created remat

skye self-assigned this May 22, 2020
jeffgortmaker (Author) commented:
This is perfect, thanks so much! I hadn't seen remat before -- looks like it's tailor-made for this type of problem.

For some reason rematifying scanned directly didn't seem to work; I found that I had to rematify the actual computation within the scan to get the desired memory reduction:

```python
def f(x, ys):
    z = jnp.ones((3000, 3000))

    @jax.remat
    def inner(y):
        return x * (y + z)

    def scanned(carry, y):
        return carry + inner(y), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys)
    return summed.mean()
```

mattjj (Collaborator) commented May 22, 2020

By the way, we're working on some other improvements that should make this work well even without remat by never instantiating the large ones((3000, 3000)) array. We'd still need remat in general, but in this case the memory savings can be had by avoiding the large constant.
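In the meantime, you can get a similar effect by hand by not materializing z at all and relying on broadcasting (a rough sketch, equivalent here only because z happens to be all ones):

```python
import jax
import jax.lax
import jax.numpy as jnp

def f(x, ys):
    # z was jnp.ones((3000, 3000)); since it is all ones, adding the scalar
    # 1.0 and letting broadcasting do the work gives the same values without
    # ever materializing a second 3000 x 3000 array.
    def scanned(carry, y):
        return carry + x * (y + 1.0), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros((3000, 3000)), ys)
    return summed.mean()
```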

mattjj added the question (Questions for the JAX team) label May 22, 2020
jeffgortmaker (Author) commented:
Very cool, I'll keep my eyes peeled and keep updating the package. The work you all are doing here is really great.

jwnys commented Jun 7, 2022

Not sure if it's entirely relevant, but I'll mention what helped me instead.
If you don't want to trade compute time for memory (which is what remat does), what worked for me (bringing memory requirements down from >150GB to <32GB) was to unroll the scan with unroll=len(xs). I needed Hessians of a scan function, and this somehow resolved everything for me. I'm still not sure why it worked, so it would be good to get some information on this, @mattjj, just to know whether this is actually a good idea.
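For reference, this is the pattern I mean, applied to the example from this thread (a sketch on my end; jax.lax.scan accepts an integer unroll argument):

```python
import jax
import jax.lax
import jax.numpy as jnp

def f(x, ys):
    z = jnp.ones((3000, 3000))

    def scanned(carry, y):
        return carry + x * (y + z), None

    # unroll=len(ys) fully unrolls the scan at trace time, trading compilation
    # time and program size for the memory behavior described above.
    summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys, unroll=len(ys))
    return summed.mean()
```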
