
Understanding vjp for scan and Remat #9730

Answered by YouJiacheng
srush asked this question in Q&A

@srush My observation is that wrapping the computation in jax.jit can avoid the OOM:

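from functools import partial

import jax
import jax.numpy as jnp


@partial(jnp.vectorize, signature="(c),(),(c)->()")
def cauchy_dot(v, omega, lambd):
    return (v / (omega - lambd)).sum()


x = jnp.zeros((1000,))
omega = jnp.ones((10000, 10000))
# Gradient w.r.t. the first argument; jit lets XLA optimize the whole program
print(jax.jit(jax.grad(lambda *args: jnp.sum(cauchy_dot(*args))))(x, omega, x))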
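For a sense of scale: with the shapes above, jnp.vectorize broadcasts omega - lambd to 10000 × 10000 × 1000 = 10^11 float32 values, roughly 400 GB, which an eager (op-by-op) evaluation would have to materialize. A plausible explanation for why jax.jit helps, stated here as an assumption rather than something verified in this thread, is that XLA fuses the subtraction, division and reduction so that this intermediate is never written out. The sketch below uses small, made-up shapes (c, b1, b2 and the random inputs are illustrative only) to check that the jitted and eager gradients agree with the analytic gradient d loss / d v_k = sum over i, j of 1 / (omega_ij - lambd_k).

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jnp.vectorize, signature="(c),(),(c)->()")
def cauchy_dot(v, omega, lambd):
    # Core computation for one omega entry: sum_k v_k / (omega - lambd_k)
    return (v / (omega - lambd)).sum()


def loss(v, omega, lambd):
    return jnp.sum(cauchy_dot(v, omega, lambd))


# Hypothetical small shapes so even the eager intermediate fits in memory.
c, b1, b2 = 8, 5, 6
key_v, key_l = jax.random.split(jax.random.PRNGKey(0))
v = jax.random.normal(key_v, (c,))
lambd = jax.random.uniform(key_l, (c,), minval=-1.0, maxval=0.0)
omega = jnp.full((b1, b2), 2.0)  # omega - lambd lies between 2 and 3, no division by zero

grad_jit = jax.jit(jax.grad(loss))(v, omega, lambd)
grad_eager = jax.grad(loss)(v, omega, lambd)

# Analytic gradient: broadcast to (b1, b2, c), then sum over the batch axes.
grad_ref = (1.0 / (omega[..., None] - lambd)).sum(axis=(0, 1))

print(jnp.allclose(grad_jit, grad_ref, rtol=1e-5),
      jnp.allclose(grad_eager, grad_ref, rtol=1e-5))
```

At these sizes both paths are cheap; the difference only matters once the broadcast intermediate no longer fits on the device.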

Inspect the unoptimized IR in HLO form

lowered = 
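One way to look at what XLA receives and produces for the full-size problem, sketched here under the assumption of a recent JAX version, is to lower the jitted gradient with abstract jax.ShapeDtypeStruct arguments, so none of the large arrays is allocated. Lowered.as_text() returns the module before XLA's optimization passes (StableHLO text in recent versions), and Compiled.as_text() returns the optimized HLO, where a fusion of the broadcast, division and reduction would show up if XLA performs one. The names loss, x_spec and omega_spec are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jnp.vectorize, signature="(c),(),(c)->()")
def cauchy_dot(v, omega, lambd):
    return (v / (omega - lambd)).sum()


def loss(v, omega, lambd):
    return jnp.sum(cauchy_dot(v, omega, lambd))


# Abstract arguments with the original sizes: lowering and compiling only
# need shapes and dtypes, so no device memory is allocated for the inputs.
x_spec = jax.ShapeDtypeStruct((1000,), jnp.float32)
omega_spec = jax.ShapeDtypeStruct((10000, 10000), jnp.float32)

lowered = jax.jit(jax.grad(loss)).lower(x_spec, omega_spec, x_spec)
unoptimized_ir = lowered.as_text()   # IR before XLA optimization passes
compiled = lowered.compile()
optimized_hlo = compiled.as_text()   # HLO after XLA optimizations (fusions etc.)

print(unoptimized_ir[:2000])
print(optimized_hlo[:2000])
```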

Replies: 1 comment 7 replies

Answer selected by srush
Category: Q&A
2 participants