Skip to content

Commit

Permalink
Add get_replication to shard_map.py for verifying if an array is repl…
Browse files Browse the repository at this point in the history
…icated.

PiperOrigin-RevId: 676910872
  • Loading branch information
pschuh authored and Google-ML-Automation committed Sep 20, 2024
1 parent 82b0e0e commit 1acf956
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
10 changes: 10 additions & 0 deletions jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2011,3 +2011,13 @@ def _match_replication(src, dst, x):
if src - dst:
x = pbroadcast(x, tuple(n for n in src if n not in dst))
return x

# TODO(parkers,mattjj): change implementation when we have sharding-in-types.
def get_replication(x: jax.Array) -> set[AxisName]:
"""For a jax.Array, return what axes it is known to be replicated along."""

if isinstance(x, RewriteTracer):
return x.rep
if isinstance(x, batching.BatchTracer):
return get_replication(x.val)
raise ValueError("get_replication not defined on %s" % repr(type(x)))
28 changes: 28 additions & 0 deletions tests/shard_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2151,6 +2151,34 @@ def f(a):

f(A()) # don't crash

def test_get_check_rep(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))

def f(x, reduce_along, use_jit):
out_spec = P(*(n for n in ('x', 'y') if n not in reduce_along))

@partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=out_spec)
def g(x):
result = lax.psum(x, axis_name=reduce_along)
def check_rep(result):
self.assertEqual(
jax.experimental.shard_map.get_replication(result),
set(reduce_along))
return result
result = check_rep(result)
result = jax.vmap(check_rep)(result)
return result
if use_jit:
return jax.jit(g)(x)
else:
return g(x)

for use_jit in [True, False]:
x = np.zeros((8, 8), dtype=np.float32)
f(x, reduce_along=('y',), use_jit=use_jit)
f(x, reduce_along=('x',), use_jit=use_jit)
f(x, reduce_along=('x', 'y'), use_jit=use_jit)


class FunSpec(NamedTuple):
name: str
Expand Down

0 comments on commit 1acf956

Please sign in to comment.