I am implementing the Lattice Boltzmann method with JAX. The method requires operations similar to convolution (but different), i.e., each lattice node interacts with its neighbors. I am using `jax.vmap` to apply the per-node update to every lattice node. In the following example, `memory_test1` runs, but `memory_test2`, which only doubles the trailing dimension of the intermediate array, runs out of memory on a machine with 48 GB. How much memory does `vmap` need to solve a problem like this?

```python
import jax
import jax.numpy as np


def to_id_xyz(lattice_id):
    # E.g., lattice_id = N**2 + 2*N + 3 will be converted to (1, 2, 3).
    id_z = lattice_id % N
    lattice_id = lattice_id // N
    id_y = lattice_id % N
    id_x = lattice_id // N
    return id_x, id_y, id_z


def extract_3x3x3_grid(lattice_id, value_tensor):
    # Gather the values of the node and its 26 neighbors, with periodic wrap-around.
    id_x, id_y, id_z = to_id_xyz(lattice_id)
    grid_index = np.ix_(np.array([(id_x - 1) % N, id_x, (id_x + 1) % N]),
                        np.array([(id_y - 1) % N, id_y, (id_y + 1) % N]),
                        np.array([(id_z - 1) % N, id_z, (id_z + 1) % N]))
    grid_values = value_tensor[grid_index]
    return grid_values


def memory_test1(lattice_id, value_tensor):
    value_tensor_local = extract_3x3x3_grid(lattice_id, value_tensor)  # (3, 3, 3, dof)
    vel = np.ones((3, 3, 3, dof, 1))
    u_local = value_tensor_local[:, :, :, :, None] * vel  # (3, 3, 3, dof, 1)
    return np.sum(u_local)

memory_test1_vmap = jax.jit(jax.vmap(memory_test1, in_axes=(0, None)))


def memory_test2(lattice_id, value_tensor):
    value_tensor_local = extract_3x3x3_grid(lattice_id, value_tensor)  # (3, 3, 3, dof)
    vel = np.ones((3, 3, 3, dof, 2))
    u_local = value_tensor_local[:, :, :, :, None] * vel  # (3, 3, 3, dof, 2)
    return np.sum(u_local)

memory_test2_vmap = jax.jit(jax.vmap(memory_test2, in_axes=(0, None)))


N = 400
dof = 19
value_tensor = np.ones((N, N, N, dof))

result1 = memory_test1_vmap(np.arange(N * N * N), value_tensor)
print(f"max result1 = {np.max(result1)}")

result2 = memory_test2_vmap(np.arange(N * N * N), value_tensor)
```
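As a quick sanity check of the indexing, here is a minimal sketch on a tiny lattice, assuming the definitions above are in scope (the values of `N` and `dof` below are hypothetical, chosen only for illustration; the functions read them as module-level globals, so they are shrunk before the calls):

```python
# Minimal sanity check on a tiny lattice (hypothetical sizes, for illustration only).
# N and dof are read as module-level globals by the functions above, so shrink them first.
N, dof = 4, 2
small_tensor = np.arange(N * N * N * dof, dtype=np.float32).reshape((N, N, N, dof))

# lattice_id = 1*N**2 + 2*N + 3 should map back to (1, 2, 3), per the comment in to_id_xyz.
print(to_id_xyz(1 * N**2 + 2 * N + 3))            # (1, 2, 3)

# The 3x3x3 neighborhood of node (0, 0, 0) wraps around periodically to index N-1.
print(extract_3x3x3_grid(0, small_tensor).shape)  # (3, 3, 3, 2)
```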
Replies: 1 comment
There is no general answer to "how much memory `vmap` needs to solve a problem". `vmap` doesn't do computation; rather it transforms one abstract computation into another that is applicable to batched inputs. Consider this simple example:

```python
import jax
import jax.numpy as jnp

def f(x):
  return (x[:, None] * x[None, :]).sum()
```

We can get a sense for what operations this lowers to by printing its jaxpr:

```python
x = jnp.ones(100)
print(jax.make_jaxpr(f)(x))
```

This function does two broadcasts, one multiply, and then sums over the resulting (100, 100) array. Now what if we vmap it?

```python
f_vmap = jax.vmap(f)
x_batched = jnp.ones((10, 100))
print(jax.make_jaxpr(f_vmap)(x_batched))
```

Again it's two broadcasts and a multiply, followed by a sum over the resulting (10, 100, 100) array. The point is that `vmap` produces essentially the same operations, just applied to arrays with an extra batch dimension, so the memory a vmapped function needs is determined by the batched intermediate arrays it creates.

Your function is dealing with arrays whose size is measured in tens of gigabytes, on a machine with 48GB available, and you're finding that when you double the size of your arrays, you run out of memory. This on its face is not entirely surprising. If you're interested in the details of which operations and intermediate values are created by your vmapped functions, you can use `jax.make_jaxpr` as above to inspect them.

Does that make sense?
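For concreteness, here is a hedged sketch of that inspection applied to the functions from the question, assuming their definitions (and the question's `np` alias for `jax.numpy`) are in scope; the small `N` below is a hypothetical value chosen so that tracing stays cheap:

```python
# Sketch only: trace the question's vmapped function on a deliberately small lattice.
# make_jaxpr traces abstractly, so nothing large is allocated here.
N, dof = 8, 19                      # hypothetical small sizes; the functions read these globals
small_tensor = np.ones((N, N, N, dof))
lattice_ids = np.arange(N * N * N)
print(jax.make_jaxpr(jax.vmap(memory_test1, in_axes=(0, None)))(lattice_ids, small_tensor))
# The jaxpr shows gathers, broadcasts, and multiplies over arrays with a leading
# N**3 batch dimension; those batched intermediates are what grow when the trailing
# dimension of `vel` goes from 1 (memory_test1) to 2 (memory_test2).
```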
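And a rough upper-bound estimate of those intermediates, assuming float32 and that the full batched neighborhood were materialized (in practice XLA fusion can keep much less live, so treat this as a ceiling rather than a prediction):

```python
# Back-of-the-envelope ceiling: bytes needed if the batched (N**3, 3, 3, 3, dof, trailing)
# intermediate were stored in full as float32. Fusion usually keeps less than this live,
# but whatever is kept scales linearly with the trailing dimension.
def neighborhood_bytes(N, dof, trailing):
    return N**3 * 3 * 3 * 3 * dof * trailing * 4

for trailing in (1, 2):   # memory_test1 vs memory_test2
    print(trailing, neighborhood_bytes(400, 19, trailing) / 2**30, "GiB if fully materialized")
# Doubling the trailing dimension doubles the ceiling, which matches running out of
# memory when going from memory_test1 to memory_test2 on the same machine.
```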