Understanding vjp for scan and Remat #9730
Hi all. I am trying to understand the different ways to write a large kernel in Jax. The underlying code looks like this (assume lambd is a huge tensor and adding an extra dimension would blow up memory):

```python
@partial(np.vectorize, signature="(c),(),(c)->()")
def cauchy_dot(v, omega, lambd):
    return (v / (omega - lambd)).sum()
```

Running this code runs out of memory on the gradient step, but I can get it to fit with jax.remat.
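Concretely, the remat band-aid might look something like the sketch below; the exact decorator placement is an assumption, not quoted from the run above:

```python
from functools import partial

import jax
import jax.numpy as np  # the snippets here import jax.numpy as np

# Sketch: wrap the kernel in jax.remat so the big (omega - lambd) intermediate
# is recomputed on the backward pass instead of being stored.
@partial(np.vectorize, signature="(c),(),(c)->()")
@jax.remat
def cauchy_dot(v, omega, lambd):
    return (v / (omega - lambd)).sum()
```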
My understanding though is that this is only a band-aid, as remat just avoids keeping the memory but still materializes the matrix. So I tried writing a non-materialized version:

```python
@partial(np.vectorize, signature="(c),(),(c)->()")
def cauchy_dot(v, omega, lambd):
    def s(carry, x):
        v2, l = x
        return carry + (v2 / (omega - l)), None

    return jax.lax.scan(s, 0.0, (v, lambd))[0]
```

This however fails: even though the forward pass is non-materialized, my understanding is that the vjp for scan will need to materialize the intermediate steps.
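One way to see where that memory goes is to look at the gradient's jaxpr; the per-step residuals that scan saves for the backward pass should show up there as extra stacked outputs of the forward scan. A sketch with made-up small shapes (not from the post):

```python
import jax
import jax.numpy as jnp

# Hypothetical small shapes, just to keep the printout readable.
v = jnp.ones((8,))
lambd = jnp.ones((8,))
omega = jnp.ones(())

# Print the jaxpr of the gradient of the scan-based cauchy_dot defined above.
print(jax.make_jaxpr(jax.grad(lambda v: cauchy_dot(v, omega, lambd)))(v))
```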
But I can get around this by:

```python
@partial(np.vectorize, signature="(c),(),(c)->()")
def cauchy_dot(v, omega, lambd):
    @jax.remat
    def inner(v2, l):
        return v2 / (omega - l)

    def s(carry, x):
        v2, l = x
        return carry + inner(v2, l), None

    return jax.lax.scan(s, 0.0, (v, lambd))[0]
```

============

Update: Further benchmarking on this seems to now show that …
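For reference, a minimal usage sketch of this last version (the shapes are made up, not from the post):

```python
import jax
import jax.numpy as jnp

# Made-up shapes: one vector of v/lambd, many omegas.
v = jnp.ones((1000,))
lambd = jnp.ones((1000,))
omega = jnp.ones((10000,))

# Differentiate a scalar loss through the scan-based, per-step-remat kernel.
loss = lambda v: cauchy_dot(v, omega, lambd).sum()
grad_fn = jax.jit(jax.grad(loss))
print(grad_fn(v).shape)
```

As a side note, jax.remat/jax.checkpoint also takes a prevent_cse flag; the JAX docs suggest prevent_cse=False can help performance when the checkpointed function sits inside a jax.lax.scan.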
@srush My observation is: jax.jit can avoid OOM.

```python
@partial(jnp.vectorize, signature="(c),(),(c)->()")
def cauchy_dot(v, omega, lambd):
    return (v / (omega - lambd)).sum()

x = jnp.zeros((1000,))
omega = jnp.ones((10000, 10000))
print(jax.jit(jax.grad(lambda *args: jnp.sum(cauchy_dot(*args))))(x, omega, x))
```

Inspect the unoptimized IR in HLO form:

```python
lowered = jax.jit(jax.grad(lambda *args: jnp.sum(cauchy_dot(*args)))).lower(x, omega, x)
print(lowered.compiler_ir(dialect="mhlo"))
```
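To see what actually runs after XLA's optimization passes, one could also print the compiled module; the as_text() call below is an assumption about the jax.stages API and may differ across JAX versions:

```python
# Optimized HLO after XLA fusion; compare with the unoptimized IR above.
compiled = lowered.compile()
print(compiled.as_text())
```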
Maybe the reduce avoids materializing the huge matrix? I'm not sure. You maybe misunderstand the effect of jax.remat: it avoids saving the intermediates computed inside the checkpointed function, but it still saves that function's inputs so they can be replayed on the backward pass. Thus, if you checkpoint every step, the saved inputs still cost O(#steps) memory.
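One way to check what actually gets kept (assuming a JAX version that ships jax.ad_checkpoint.print_saved_residuals) is to print the residuals for whichever cauchy_dot variant you are testing:

```python
import jax
import jax.ad_checkpoint
import jax.numpy as jnp

# Small made-up shapes so the report stays readable.
v = jnp.ones((8,))
lambd = jnp.ones((8,))
omega = jnp.ones(())

# Prints a report of every value saved for the backward pass, which makes
# the saved-inputs point above concrete.
jax.ad_checkpoint.print_saved_residuals(cauchy_dot, v, omega, lambd)
```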