Unexpectedly high grad-of-scan memory usage #3186
Comments
You're right that `grad` is saving every `(y + z)` for the backward pass. You can use `jax.remat` to recompute those intermediates instead of storing them; cc @mattjj, who created `jax.remat`.
This is perfect, thanks so much! I hadn't seen `jax.remat` before. For some reason, this is the arrangement that ended up working for me:

```python
import jax
import jax.numpy as jnp

def f(x, ys):
    z = jnp.ones((3000, 3000))

    # Recompute (y + z) on the backward pass instead of storing it per scan iteration.
    @jax.remat
    def inner(y):
        return x * (y + z)

    def scanned(carry, y):
        return carry + inner(y), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys)
    return summed.mean()
```
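A rough usage sketch, reusing `f` from the snippet above; the value of `n`, the scalar `x`, and the concrete `ys` below are illustrative placeholders chosen to match the `(3000, 3000)` arrays discussed in the issue, not inputs from the original thread.

```python
import jax
import jax.numpy as jnp

# Placeholder inputs; only the (3000, 3000) shape is taken from the issue.
n = 10
x = 2.0                           # assumed scalar; any broadcast-compatible array works
ys = jnp.ones((n, 3000, 3000))

# With @jax.remat on inner, the (y + z) intermediates are recomputed during the
# backward pass rather than stored for each scan iteration, so peak memory
# should grow much more slowly with n.
dx = jax.grad(f)(x, ys)           # gradient of the mean with respect to x
```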
By the way, we're working on some other improvements that should make this work well even without `jax.remat`.
Very cool, I'll keep my eyes peeled and keep updating the package. The work you all are doing here is really great.
Not sure if it's entirely relevant, but I'll mention what helped me instead.
Original issue:

Consider the following function, which sums `x * (y + z)` over all `y` in `ys` and then averages over the resulting matrix of sums. Because I use `lax.scan` (instead of, e.g., `vmap` or `lax.map` followed by a sum over the first axis), memory usage doesn't significantly scale with the number of `ys`: the code uses ~203MB regardless of whether `n = 5` or `n = 10`. But the gradient uses 557MB for `n = 5` and 908MB for `n = 10`. The story is similar when these functions are `jit`ted.

My best guess about what's going on here is that `grad` is storing every `(y + z)` in memory. Is this intended? And is there some way to tell `grad` to be more economical about what it stores in memory, so that the gradient gets a memory reduction similar to the one `lax.scan` provides for the forward pass?
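The code blocks from the original issue aren't shown above, so here is a sketch of the setup as described: a scan-based sum of `x * (y + z)` followed by a mean, without `jax.remat`. The function name `f`, the scalar `x`, and the concrete values of `n` are assumptions for illustration, not the author's exact code.

```python
import jax
import jax.numpy as jnp

# Scan-based sum of x * (y + z) over the leading axis of ys, then a mean,
# as described in the issue (no remat applied).
def f(x, ys):
    z = jnp.ones((3000, 3000))

    def scanned(carry, y):
        return carry + x * (y + z), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys)
    return summed.mean()

n = 5                               # or 10; forward-pass memory stays roughly flat
x = 2.0                             # assumed scalar
ys = jnp.ones((n, 3000, 3000))

f(x, ys)                            # forward pass: memory roughly independent of n
jax.grad(f)(x, ys)                  # gradient: memory grows with n, because the
                                    # (y + z) residuals are saved for each iteration
```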