
How vmap allocates memory #14381

Answered by jakevdp
tianjuxue asked this question in Q&A

There is no general answer to the question of how much memory vmap needs to solve a problem. vmap does not perform any computation itself; rather, it transforms one abstract computation into another that operates on batched inputs.
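Here is a minimal sketch of what "transforms one computation into another" means in practice. It assumes a hypothetical per-example function g and illustrative input values, neither of which comes from the original answer. Calling jax.vmap does no arithmetic on its own. It only returns a new function, and calling that function gives the same result as an explicit Python loop over the rows.

```python
import jax
import jax.numpy as jnp

def g(v):
  # Hypothetical per-example computation on a single 1-D vector.
  return jnp.dot(v, v)

# No work happens here: vmap just builds a new function that
# expects an extra leading batch axis on its input.
batched_g = jax.vmap(g)

xs = jnp.arange(12.0).reshape(3, 4)        # 3 examples of length 4
out = batched_g(xs)                        # one result per row, shape (3,)
ref = jnp.stack([g(row) for row in xs])    # same thing via a Python loop
print(out)  # [ 14. 126. 366.]
```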

Consider this simple example:

```python
def f(x):
  return (x[:, None] * x[None, :]).sum()
```
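The following is a small worked check of this f, using illustrative input values. The sum of all pairwise products x[i] * x[j] equals the square of the sum of x, so the result is a single scalar. Computed this way, however, it first builds an n-by-n intermediate array.

```python
import jax.numpy as jnp

def f(x):
  return (x[:, None] * x[None, :]).sum()

x = jnp.arange(4.0)                # [0., 1., 2., 3.]
outer = x[:, None] * x[None, :]    # every pairwise product x[i] * x[j]
print(outer.shape)                 # (4, 4): n-by-n intermediate for n = 4
print(f(x))                        # 36.0
print(x.sum() ** 2)                # 36.0, i.e. (0 + 1 + 2 + 3) ** 2
```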

We can get a sense for what operations this lowers to by printing its jaxpr:

```python
import jax
import jax.numpy as jnp

x = jnp.ones(100)
print(jax.make_jaxpr(f)(x))
```
```
{ lambda ; a:f32[100]. let
    b:f32[100,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(100, 1)] a
    c:f32[1,100] = broadcast_in_dim[broadcast_dimensions=(1,) shape=(1, 100)] a
    d:f32[100,100] = mul b c
    e:f32[] = reduce_sum[axes=(0, 1)] d
  in (e,) }
```
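One rough way to read memory off a jaxpr like this is to look at the sizes of the values that each equation produces. The sketch below uses a hypothetical helper, largest_intermediate_bytes (not a JAX API), which walks closed_jaxpr.jaxpr.eqns and reports the biggest output. It measures logical array sizes only, not what XLA actually allocates. After compilation, fusion may mean the f32[100,100] product is never materialized at all.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
  return (x[:, None] * x[None, :]).sum()

def largest_intermediate_bytes(closed_jaxpr):
  """Hypothetical helper: bytes of the biggest value any equation produces.

  Counts logical array sizes in the traced program only; XLA fusion may
  avoid materializing some of these arrays.
  """
  sizes = [
      int(np.prod(v.aval.shape)) * np.dtype(v.aval.dtype).itemsize
      for eqn in closed_jaxpr.jaxpr.eqns
      for v in eqn.outvars
  ]
  return max(sizes)

x = jnp.ones(100)
print(largest_intermediate_bytes(jax.make_jaxpr(f)(x)))  # 40000: f32[100,100] * 4 bytes
```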

This function does …
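The same kind of inspection can be applied to jax.vmap(f) to see the transformation at work. In this sketch, the batch size of 10 and the helper largest_intermediate_elems are illustrative assumptions. Every intermediate in the traced program gains a leading batch axis, so the largest logical intermediate grows from 100 * 100 to 10 * 100 * 100 elements. As before, this describes the traced computation, not the memory XLA ends up using after optimization.

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
  return (x[:, None] * x[None, :]).sum()

def largest_intermediate_elems(closed_jaxpr):
  # Hypothetical helper: largest element count of any equation's output.
  return max(int(np.prod(v.aval.shape))
             for eqn in closed_jaxpr.jaxpr.eqns for v in eqn.outvars)

batch = jnp.ones((10, 100))                      # 10 examples of length 100
batched_jaxpr = jax.make_jaxpr(jax.vmap(f))(batch)
print(batched_jaxpr)                             # shapes now carry a batch axis of 10
print(largest_intermediate_elems(batched_jaxpr)) # 100000 = 10 * 100 * 100
```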

Answer selected by tianjuxue