How can I see what collective ops are used during auto-parallelisation #22813
-
How can I see what collective ops are applied under the hood? I would hope to see a ppermute here, as that's all rolling does.

import os
import jax
import jax.numpy as jnp
from jax import config, lax
from jax.experimental.mesh_utils import create_device_mesh
# Set num jax devices
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"
from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding
def create_mesh(shape, axis_names, devices=None):
    """Construct a `Mesh` of the given `shape` labelled with `axis_names`.

    When `devices` is None, `create_device_mesh` picks from the available
    JAX devices.
    """
    return Mesh(
        create_device_mesh(mesh_shape=shape, devices=devices),
        axis_names=axis_names,
    )
def main(num_shards):
    """Print the jaxpr of rolling an array sharded over `num_shards` devices."""
    mesh = create_mesh(shape=(num_shards,), axis_names=('a',))
    sharding = NamedSharding(mesh, PartitionSpec('a'))

    @jax.jit
    def rolled(arr):
        # Pin the input layout so the partitioner sees the intended sharding
        # before the roll is lowered.
        constrained = lax.with_sharding_constraint(arr, sharding)
        return jnp.roll(constrained, 1)

    data = jnp.arange(num_shards * 2)
    print(jax.make_jaxpr(rolled)(data))
if __name__ == '__main__':
main(num_shards=os.cpu_count())

This gives:
|
Beta Was this translation helpful? Give feedback.
Replies: 2 comments
-
Try: |
Beta Was this translation helpful? Give feedback.
-
Cool, this worked. With this I see a permute, which it must have figured out from a slice->concatenate rule.
|
Beta Was this translation helpful? Give feedback.
Try:
print(f.lower(x).compile().as_text())
. This will print the HLO with collectives inside it.