HELP! PRNGKey with shard_map not working #22862
-
Basically, I was trying to port my code from using pmap to using shard_map, as suggested by the JAX docs. With pmap, considering a single host for now on a TPU v4-8 (4 JAX devices), I used to replicate state such as the train state (flax) and the random keys via flax.jax_utils.replicate, pass these replicated states to the pmapped functions, and it worked great! But with shard_map, I am unable to figure out how to make this work. I have the following code (example) that explains my issue:
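Roughly, the setup looks like this (a simplified sketch rather than my exact code; the names func, ones, rngs_mapped and the mesh axis 'i' match the snippet further down, while the array shapes and the PRNGKey seed are just for illustration):

import jax
import jax.numpy as jnp
from flax import jax_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = jax.devices()                  # 4 devices on a single TPU v4-8 host
mesh = Mesh(devices, ('i',))

def func(x, key):
    # I expected each program to see its shard of x plus a single key of shape (2,)
    return x + jax.random.uniform(key, x.shape), key

ones = jnp.ones((len(devices), 8))
rng = jax.random.PRNGKey(0)
rngs_mapped = jax_utils.replicate(rng)   # shape (num_devices, 2), as I did for pmap

sharded_func = shard_map(func, mesh=mesh,
                         in_specs=(P('i'), P('i')),
                         out_specs=(P('i'), P('i')))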
Now, if I pass rngs_mapped to the shard_mapped function, I immediately get an error:
Printing the object seems to indicate that it has been assigned axes it wasn't supposed to have, I guess? Of course, I can fix this using reshapes in this example:
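Something along these lines (again a sketch, reusing the definitions from the snippet above; the key is reshaped back into a single key inside the mapped function):

def func_fixed(x, key):
    key = key.reshape(2)    # (1, 2) -> (2,): undo the extra leading axis
    out = x + jax.random.uniform(key, x.shape)
    return out, key[None]   # put the leading axis back to match the out_spec

out = shard_map(func_fixed, mesh=mesh,
                in_specs=(P('i'), P('i')),
                out_specs=(P('i'), P('i')))(ones, rngs_mapped)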
But this just seems like a hack. What's the proper way to do this? I am basically trying to train diffusion models, and I fear similar issues might also occur with the train states and whatnot. I am aware of issue #22860, but it was only marked as a 'bug' 8 hours ago and there has been no activity since. I just want to be sure it's not a mistake on my end.
-
A main difference between pmap and shard_map is that pmap also maps your function over the leading per-device axis, while shard_map hands each program its whole shard, leading axis included. So here you have a batch of keys of size one, and func sees a key array of shape (1, 2) rather than a single key; wrapping the function in jax.vmap maps that leading axis away again. Taking this into account (and fixing your sharding specifications to match the input and output shapes) looks like this:

out2 = jax.jit(shard_map(jax.vmap(func), mesh=mesh,
                         in_specs=(P('i'), P('i')),
                         out_specs=(P('i'), P('i'))))(ones, rngs_mapped)
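For reference, here is a self-contained sketch of this vmap-inside-shard_map pattern that runs on whatever devices are available (it uses one key per device via jax.random.split; the shapes and names are just illustrative):

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

devices = jax.devices()
mesh = Mesh(devices, ('i',))
n = len(devices)

def func(x, key):
    # under jax.vmap the leading axis is mapped away:
    # x is a single row and key is a single key of shape (2,)
    return x + jax.random.uniform(key, x.shape), key

ones = jnp.ones((n, 8))
rngs_mapped = jax.random.split(jax.random.PRNGKey(0), n)   # (n, 2): one key per device

out, keys = jax.jit(shard_map(jax.vmap(func), mesh=mesh,
                              in_specs=(P('i'), P('i')),
                              out_specs=(P('i'), P('i'))))(ones, rngs_mapped)
print(out.shape, keys.shape)   # (n, 8) and (n, 2)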
-
There's a nice strategy without requiring vmap here: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/scaling/JAX/data_parallel_fsdp.html

So here's a full answer to your question:

N_CPUS = 8
import os
os.environ["XLA_FLAGS"] = f'--xla_force_host_platform_device_count={N_CPUS}' # Use 8 CPU devices
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
P = jax.sharding.PartitionSpec
devices = mesh_utils.create_device_mesh((N_CPUS,), devices=jax.devices("cpu"))
mesh = jax.sharding.Mesh(devices, ("x",))
sharding = jax.sharding.NamedSharding(mesh, P("x"))
rng = jax.random.PRNGKey(0)
# replicate the rng (no need to split)
rng = jax.device_put(rng, jax.sharding.NamedSharding(mesh, P()))
arr = jnp.arange(N_CPUS)
def fold_rng_over_axis(rng: jax.random.PRNGKey, axis_name: str) -> jax.random.PRNGKey:
    """Folds the random number generator over the given axis.

    This is useful for generating a different random number for each device
    across a certain axis (e.g. the model axis).

    Copied from https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/scaling/JAX/data_parallel_fsdp.html

    Args:
        rng: The random number generator.
        axis_name: The axis name to fold the random number generator over.

    Returns:
        A new random number generator, different for each device index along the axis.
    """
    axis_index = jax.lax.axis_index(axis_name)
    return jax.random.fold_in(rng, axis_index)

@jax.jit
def f(rng, x):
    rng = fold_rng_over_axis(rng, axis_name="x")
    print("x shape", x.shape)
    print("rng shape", rng.shape)
    return jax.random.uniform(rng) + x

f_sh = shard_map(f, mesh=mesh, in_specs=(P(), P("x")), out_specs=P("x"))
print("f_sh output", f_sh(rng, arr))

Output:
Found something. This seems to work:
Basically, if I don't give the axis name at all, then there's no need to replicate anything and no reshaping etc. is required, and I am still able to use collectives inside. This works on multi-host too. I just want to confirm with you the caveats of this method. Is this alright?
Once again, thank you so much for this!
EDIT:
should be this actually: