
Commit 473e2bf

yashk2810 authored and Google-ML-Automation committed
Put abstract_mesh on every eqn so that we can preserve it during eval_jaxpr and check_jaxpr roundtrips.

Also allow users to enter `Auto`/`User` mode inside jit along all or some axes. Add checks to make sure that the mesh of avals inside a context matches the surrounding context mesh. This check currently happens inside `abstract_eval` rules, but we may want a more central place for it later on.

PiperOrigin-RevId: 707128096
1 parent e1f037b commit 473e2bf
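As a quick illustration of the headline change, the snippet below traces a tiny function and reads the abstract mesh now recorded on each equation's context. It assumes this commit's `JaxprEqnContext.cur_abstract_mesh` attribute together with the existing `eqn.ctx` field; with no mesh in scope it simply shows the default (empty) abstract mesh.

import jax
import jax.numpy as jnp

# Every equation now carries the abstract mesh that was current at trace time,
# so eval_jaxpr/check_jaxpr can re-enter the same mesh context later.
jaxpr = jax.make_jaxpr(lambda a, b: a @ b)(jnp.ones((4, 2)), jnp.ones((2, 3)))
for eqn in jaxpr.jaxpr.eqns:
  print(eqn.primitive, eqn.ctx.cur_abstract_mesh)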

File tree

11 files changed: +262 −43 lines

Diff for: jax/_src/api.py

+0 −2

@@ -2187,8 +2187,6 @@ def _infer_src_sharding(src, x) -> Sharding | None:
     return src  # pytype: disable=bad-return-type
   if isinstance(x, array.ArrayImpl):
     return x.sharding
-  if config.sharding_in_types.value and hasattr(x, 'sharding'):
-    return x.sharding
   if isinstance(x, core.Tracer):
     val = x.to_concrete_value()
     if val is not None and isinstance(val, array.ArrayImpl):

Diff for: jax/_src/core.py

+29 −11

@@ -273,10 +273,12 @@ def __init__(self, compute_type: str | None, threefry_partitionable: bool,
                xla_metadata=None):
     self.compute_type = compute_type
     self.threefry_partitionable = threefry_partitionable
+    self.cur_abstract_mesh = mesh_lib.get_abstract_mesh()
     self.xla_metadata = xla_metadata
     self._managers = [
         (compute_on.extend_compute_type, self.compute_type),
         (config.threefry_partitionable.__call__, self.threefry_partitionable),
+        (mesh_lib.set_abstract_mesh, self.cur_abstract_mesh),
         (xla_metadata_lib.set_xla_metadata, self.xla_metadata),
     ]

@@ -292,6 +294,7 @@ def __repr__(self):
     return (
         f"JaxprEqnContext(compute_type={self.compute_type}, "
         f"threefry_partitionable={self.threefry_partitionable}, "
+        f"cur_abstract_mesh={self.cur_abstract_mesh}, "
         f"xla_metadata={self.xla_metadata})"
     )

@@ -535,6 +538,17 @@ def write(v: Var, val: Any) -> None:
     clean_up_dead_vars(eqn, env, lu)
   return map(read, jaxpr.outvars)

+def check_avals_context_mesh(avals, prim_name):
+  if config.sharding_in_types.value:
+    for a in avals:
+      cur_mesh = mesh_lib.get_abstract_mesh()
+      if a.sharding.mesh != cur_mesh:
+        raise ValueError(
+            f"For primitive {prim_name}, context mesh {cur_mesh} should match"
+            f" the aval mesh {a.sharding.mesh} for shape {a.str_short()}. This"
+            " error occurs at source: "
+            f" {source_info_util.summarize(source_info_util.current())}")
+

 # -------------------- tracing --------------------

@@ -1622,7 +1636,10 @@ def get_sharding(sharding, ndim):
   from jax._src.sharding_impls import NamedSharding  # type: ignore

   if sharding is not None:
-    assert len(sharding.spec) == ndim
+    if len(sharding.spec) != ndim:
+      raise ValueError(
+          "Length of sharding.spec must be equal to aval's ndim. Got"
+          f" sharding.spec {sharding.spec} and aval.ndim {ndim}")
     return _maybe_modify_sharding(sharding)

   context_mesh = mesh_lib.get_abstract_mesh()

@@ -2518,17 +2535,18 @@ def write(v: Var, a: AbstractValue) -> None:
     in_avals = [x.aval for x in in_atoms]  # use in_atoms for dyn shapes

     # Compute the type of the primitive application.
-    if prim in custom_typechecks:
-      out_type, eqn_effects = custom_typechecks[prim](
-          ctx_factory, *in_atoms, **eqn.params)
-    elif prim.call_primitive:
-      out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms,
-                                          eqn.params)
-    elif prim.map_primitive:
-      out_type, eqn_effects = _check_map(ctx_factory, prim, in_avals,
-                                         eqn.params)
-    else:
-      out_type, eqn_effects = check_eqn(prim, in_avals, eqn.params)
+    with eqn.ctx.manager:
+      if prim in custom_typechecks:
+        out_type, eqn_effects = custom_typechecks[prim](
+            ctx_factory, *in_atoms, **eqn.params)
+      elif prim.call_primitive:
+        out_type, eqn_effects = _check_call(ctx_factory, prim, in_atoms,
+                                            eqn.params)
+      elif prim.map_primitive:
+        out_type, eqn_effects = _check_map(ctx_factory, prim, in_avals,
+                                           eqn.params)
+      else:
+        out_type, eqn_effects = check_eqn(prim, in_avals, eqn.params)

     # Check the computed effect type matches the eqn's annotation, and is
     # included in the jaxpr's annotation.
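A hedged sketch of what the new `check_avals_context_mesh` guards against: under the experimental sharding-in-types config, each input aval's sharding mesh must equal the abstract mesh that is current while a rule runs. The `_FakeSharding`/`_FakeAval` classes below are hypothetical stand-ins that only expose the attributes the check reads; with the flag off the call is a no-op.

from jax._src import core
from jax._src import mesh as mesh_lib

class _FakeSharding:  # hypothetical stand-in for this sketch only
  def __init__(self, mesh):
    self.mesh = mesh

class _FakeAval:      # exposes just .sharding.mesh and .str_short()
  def __init__(self, mesh):
    self.sharding = _FakeSharding(mesh)
  def str_short(self):
    return 'f32[8]'

aval_mesh = mesh_lib.AbstractMesh((('x', 2),),
                                  axis_types={mesh_lib.AxisTypes.Auto: 'x'})
ctx_mesh = mesh_lib.AbstractMesh((('x', 4),),
                                 axis_types={mesh_lib.AxisTypes.Auto: 'x'})

with mesh_lib.set_abstract_mesh(ctx_mesh):
  # With sharding-in-types enabled this raises:
  #   "For primitive sin, context mesh ... should match the aval mesh ..."
  core.check_avals_context_mesh([_FakeAval(aval_mesh)], 'sin')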

Diff for: jax/_src/dispatch.py

+0 −14

@@ -522,8 +522,6 @@ def _batched_device_put_impl(
 device_put_p.def_impl(_batched_device_put_impl)

 def _device_put_abstract_eval(*xs, devices, srcs, copy_semantics):
-  if config.sharding_in_types.value:
-    return [x.update(sharding=s) for x, s in zip(xs, devices)]
   return xs
 device_put_p.def_abstract_eval(_device_put_abstract_eval)

@@ -566,12 +564,6 @@ def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
   # TODO(yashkatariya): Maybe we should add the custom calls anyways if it's
   # being used inside jit? Atleast for now, this preserves the old behavior.
   if ctx.module_context.all_default_mem_kind:
-    if config.sharding_in_types.value:
-      return [
-          mlir.wrap_with_sharding_op(
-              ctx, x, a, a.sharding._to_xla_hlo_sharding(a.ndim).to_proto())
-          for x, a in zip(xs, ctx.avals_out)
-      ]
     return xs
   def lower(x, device, aval, out_aval):
     if (isinstance(device, (Sharding, TransferToMemoryKind)) and

@@ -597,12 +589,6 @@ def lower(x, device, aval, out_aval):


 def _common_device_put_lowering(ctx, *xs, devices, srcs, copy_semantics):
-  if config.sharding_in_types.value:
-    return [
-        mlir.wrap_with_sharding_op(
-            ctx, x, a, a.sharding._to_xla_hlo_sharding(a.ndim).to_proto())
-        for x, a in zip(xs, ctx.avals_out)
-    ]
   return xs
 mlir.register_lowering(device_put_p, _common_device_put_lowering)

Diff for: jax/_src/interpreters/pxla.py

+2 −0

@@ -2162,6 +2162,8 @@ def _abstract_to_concrete_mesh(abstract_mesh):

   out = []
   for s, a in zip(shardings, avals):
+    # Remove the `UnconstrainedSingleton` logic after UNCONSTRAINED is supported
+    # in out_shardings at top level jit.
     if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
         all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
       out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),

Diff for: jax/_src/lax/parallel.py

+1 −0

@@ -687,6 +687,7 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
     raise ValueError(f"axis_index_groups can only be used with reductions over "
                      f"named axes, but got: {axes}")
   if config.sharding_in_types.value:
+    core.check_avals_context_mesh(args, 'all_reduce')
     out_avals = [
         ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype,
                     sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes))

Diff for: jax/_src/lax/utils.py

+2 −0

@@ -53,6 +53,7 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
   weak_type = weak_type_rule(*avals, **kwargs)
   least_specialized = type(max(avals, key=_get_array_abstraction_level))
   if least_specialized is core.ShapedArray:
+    core.check_avals_context_mesh(avals, prim.name)
     return core.ShapedArray(
         shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
         weak_type=weak_type,

@@ -78,6 +79,7 @@ def standard_multi_result_abstract_eval(
   if least_specialized is core.ShapedArray:
     out_shapes = shape_rule(*avals, **kwargs)
     out_dtypes = dtype_rule(*avals, **kwargs)
+    core.check_avals_context_mesh(avals, prim.name)
     out_shardings = (sharding_rule(*avals, **kwargs)
                      if config.sharding_in_types.value else
                      [None] * len(out_shapes))

Diff for: jax/_src/mesh.py

+8 −1

@@ -350,6 +350,9 @@ def local_devices(self):
   def abstract_mesh(self):
     return AbstractMesh(self.shape_tuple, axis_types=self.axis_types)

+  def with_axis_types(self, new_axis_types) -> Mesh:
+    return Mesh(self.devices, self.axis_names, axis_types=new_axis_types)
+

 EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))

@@ -396,8 +399,9 @@ def __eq__(self, other):
             self._axis_types_tuple == other._axis_types_tuple)

   def __repr__(self):
+    mesh_repr = ", ".join(f"'{n}': {v}" for n, v in self.shape_tuple)
     atr = f", axis_types={self.axis_types}"
-    return f"AbstractMesh({self.shape_tuple}{atr})"
+    return f"AbstractMesh({mesh_repr}{atr})"

   @property
   def axis_names(self):

@@ -427,6 +431,9 @@ def _internal_device_list(self):
   def empty(self):
     return self.size == 0

+  def with_axis_types(self, new_axis_types) -> AbstractMesh:
+    return AbstractMesh(self.shape_tuple, axis_types=new_axis_types)
+
   @functools.cached_property
   def _are_all_axes_collective(self) -> bool:
     return all(t == AxisTypes.Collective for t in self.axis_types.keys())
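A minimal sketch of the `with_axis_types` helper and the new `AbstractMesh` repr added above. `AxisTypes.Auto` appears in this diff; `AxisTypes.User` is assumed from the commit message, and the {axis type: axis names} dict layout follows how `axis_types` is used elsewhere in the change.

from jax._src import mesh as mesh_lib

am = mesh_lib.AbstractMesh((('x', 2), ('y', 4)),
                           axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
# Same shape_tuple, different axis types: the new helper rebuilds the mesh.
user_am = am.with_axis_types({mesh_lib.AxisTypes.User: ('x', 'y')})
print(user_am)  # e.g. AbstractMesh('x': 2, 'y': 4, axis_types={...})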

Diff for: jax/_src/pjit.py

+75 −7

@@ -61,8 +61,8 @@
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
-from jax._src import sharding
 from jax._src.mesh import AbstractMesh
+from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
     NamedSharding, GSPMDSharding,
     SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,

@@ -73,7 +73,7 @@
 from jax._src.tree_util import (
     tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, tree_leaves,
     treedef_children, broadcast_prefix, all_leaves, prefix_errors, keystr,
-    PyTreeDef, none_leaf_registry as none_lr)
+    PyTreeDef, none_leaf_registry as none_lr, tree_map)
 from jax._src.util import (
     HashableFunction, safe_map, safe_zip, wraps,
     distributed_debug_log, split_list, weakref_lru_cache,

@@ -1027,7 +1027,7 @@ def hashable_pytree(pytree):
 def _create_sharding_for_array(mesh, x, name, api_name):
   if x is None and (mesh is None or mesh.empty):
     return UNSPECIFIED
-  if isinstance(x, (AUTO, UnspecifiedValue, sharding.Sharding)):
+  if isinstance(x, (AUTO, UnspecifiedValue, Sharding)):
     return x
   if mesh is None:
     msg = ('jax.jit only supports `Sharding`s being passed to'

@@ -1339,7 +1339,7 @@ def _check_and_canonicalize_out_shardings(
     out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
     out_layouts_leaves, out_tree, out_avals, debug_info, device_or_backend_set):
   orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
-  if isinstance(orig_out_shardings, (UnspecifiedValue, sharding.Sharding)):
+  if isinstance(orig_out_shardings, (UnspecifiedValue, Sharding)):
     out_shardings_flat = (orig_out_shardings,) * len(out_avals)
   else:
     out_shardings_flat = flatten_axis_resources(

@@ -1571,7 +1571,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
       else:
         resolved_in_shardings.append(arg_s)
     else:
-      assert isinstance(arg_s, sharding.Sharding)
+      assert isinstance(arg_s, Sharding)
       if dispatch.is_single_device_sharding(arg_s):
         resolved_in_shardings.append(UNSPECIFIED)
       else:

@@ -1903,7 +1903,7 @@ def _pjit_typecheck(ctx_factory, *in_atoms, jaxpr, **params):
 core.custom_typechecks[pjit_p] = _pjit_typecheck


-def _pjit_abstract_eval(*args, jaxpr, **_):
+def _pjit_abstract_eval(*args, jaxpr, out_shardings, **_):
   return jaxpr.out_avals, jaxpr.effects
 pjit_p.def_effectful_abstract_eval(_pjit_abstract_eval)

@@ -2016,7 +2016,7 @@ def _pjit_batcher(axis_data, vals_in, dims_in,
 batching.ragged_prop_rules[pjit_p] = batching.ragged_mask_no_op_rule

 def _pjit_batcher_for_sharding(
-    s: sharding.Sharding | UnspecifiedValue,
+    s: Sharding | UnspecifiedValue,
     dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int):
   if isinstance(s, UnspecifiedValue):
     return s

@@ -2673,6 +2673,74 @@ def _sharding_constraint_batcher(
 batching.fancy_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher
 batching.skippable_batchers[sharding_constraint_p] = lambda _: ()

+# -------------------- sharding_cast ---------------------------
+
+def _check_mesh_shape_same(src_sharding, dst_sharding, aval):
+  if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple:
+    raise ValueError(
+        f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not'
+        ' match the mesh shape of the target sharding'
+        f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}')
+
+def sharding_cast(xs, shardings):
+  if isinstance(shardings, NamedSharding):
+    return tree_map(lambda x: sharding_cast_p.bind(
+        x, src_sharding=x.sharding, dst_sharding=shardings), xs)
+
+  x_flat, treedef = tree_flatten(xs)
+  shardings_flat = flatten_axes("sharding_cast shardings", treedef, shardings)
+  out_flat = [sharding_cast_p.bind(x, src_sharding=x.sharding, dst_sharding=s)
+              for x, s in safe_zip(x_flat, shardings_flat)]
+  return tree_unflatten(treedef, out_flat)
+
+sharding_cast_p = core.Primitive('sharding_cast')
+def _sharding_cast_abstract_eval(aval, src_sharding, dst_sharding):
+  _check_mesh_shape_same(src_sharding, dst_sharding, aval)
+  return aval.update(sharding=dst_sharding)
+sharding_cast_p.def_abstract_eval(_sharding_cast_abstract_eval)
+
+def _sharding_cast_impl(x, src_sharding, dst_sharding):
+  aval = shaped_abstractify(x)
+  _check_mesh_shape_same(x.sharding, dst_sharding, aval)
+  new_mesh = x.sharding.mesh.with_axis_types(dst_sharding.mesh.axis_types)
+  concrete_dst_sharding = NamedSharding(new_mesh, dst_sharding.spec)
+  # TODO(yashkatariya): Replace this with `dispatch.apply_primitive(...)`
+  return api.jit(_identity_fn, out_shardings=concrete_dst_sharding)(x)
+sharding_cast_p.def_impl(_sharding_cast_impl)
+
+def _sharding_cast_transpose_rule(ct, _, src_sharding, dst_sharding):
+  return [sharding_cast_p.bind(ct, src_sharding=dst_sharding,
+                               dst_sharding=src_sharding)]
+ad.deflinear2(sharding_cast_p, _sharding_cast_transpose_rule)
+
+def _sharding_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
+  aval, = ctx.avals_in
+  aval_out, = ctx.avals_out
+  proto = dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
+  return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)]
+mlir.register_lowering(sharding_cast_p, _sharding_cast_hlo_lowering)
+
+# TODO(yashkatariya): Comment this in after vmap ShiT tests are added.
+# def _sharding_cast_batcher(axis_data, vals_in, dims_in, src_sharding,
+#                            dst_sharding):
+#   if axis_data.spmd_name is not None:
+#     used = {n for ns in dst_sharding.spec
+#             for n in (ns if isinstance(ns, tuple) else (ns,))}
+#     if set(axis_data.spmd_name) & used:
+#       raise ValueError(
+#           f'vmap spmd_axis_name {axis_data.spmd_name} cannot '
+#           f'appear in sharding_cast spec, but got spec {dst_sharding.spec}')
+#   x, = vals_in
+#   d, = dims_in
+
+#   val = None if axis_data.spmd_name is None else axis_data.spmd_name
+#   new_spec = PartitionSpec(*util.tuple_insert(dst_sharding.spec, d, val))
+#   vmapped_dst_sharding = NamedSharding(dst_sharding.mesh, new_spec)
+#   y = sharding_cast_p.bind(x, src_sharding=src_sharding,
+#                            dst_sharding=vmapped_dst_sharding)
+#   return y, d
+# batching.fancy_primitive_batchers[sharding_cast_p] = _sharding_cast_batcher
+# batching.skippable_batchers[sharding_cast_p] = lambda _: ()

 # -------------------- helpers --------------------
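A hedged usage sketch for the new `sharding_cast` helper defined above, written against the internal `jax._src.pjit` path shown in this diff. It assumes at least two devices, that `Mesh` accepts the `axis_types=` keyword as used in `jax/_src/mesh.py` above, and that `AxisTypes.User` exists (the `User` mode named in the commit message); how much of this works eagerly outside the sharding-in-types experiment is not shown by the diff.

import numpy as np
import jax
import jax.numpy as jnp
from jax._src import mesh as mesh_lib
from jax._src import pjit as pjit_lib
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Two meshes over the same devices that differ only in axis types.
devs = np.array(jax.devices()[:2])
user_mesh = Mesh(devs, ('x',), axis_types={mesh_lib.AxisTypes.User: 'x'})
auto_mesh = user_mesh.with_axis_types({mesh_lib.AxisTypes.Auto: 'x'})

x = jax.device_put(jnp.arange(8.), NamedSharding(user_mesh, P('x')))
# The mesh *shape* must match (see _check_mesh_shape_same above); only the
# spec and the axis types may change.
y = pjit_lib.sharding_cast(x, NamedSharding(auto_mesh, P(None)))

Under jit, the same call is what lets a computation switch between `Auto` and `User` mode along chosen axes, with the abstract-eval rule rewriting the output aval's sharding to `dst_sharding`.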

Diff for: jax/_src/sharding_impls.py

+14 −4

@@ -69,15 +69,18 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):

 @util.cache(max_size=128, trace_context_in_key=False)
 def _check_axis_type_consistency(mesh, parsed_pspec):
-  if mesh.axis_types is None:
-    return
   for p in parsed_pspec:
     if p is not None:
       if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
         raise ValueError(
             'AxisTypes should be the same in a tuple subset of PartitionSpec:'
             f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
             f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
+  if mesh_lib.AxisTypes.Auto not in mesh.axis_types and None in parsed_pspec:
+    raise ValueError(
+        f'PartitionSpec {parsed_pspec.get_partition_spec()} cannot contain'
+        ' `P.UNCONSTRAINED` when no mesh axis_types are `Auto`. Got mesh'
+        f' axis_types: {mesh.axis_types}')


 def hashed_index(x) -> int:

@@ -271,11 +274,15 @@ def __init__(
     self._parsed_pspec = preprocess(self.mesh, self.spec, _parsed_pspec)

   def __repr__(self):
-    mesh_repr = ", ".join(f"'{k}': {v}" for k, v in self.mesh.shape.items())
     mem = '' if self.memory_kind is None else f', memory_kind={self.memory_kind}'
     ldi = ('' if self._logical_device_ids is None else
            f', logical_device_ids={self._logical_device_ids}')
-    return f'NamedSharding(mesh=Mesh({mesh_repr}), spec={self.spec}{mem}{ldi})'
+    if isinstance(self.mesh, mesh_lib.AbstractMesh):
+      mesh_repr = f"{self.mesh}"
+    else:
+      nv_str = ", ".join(f"'{n}': {v}" for n, v in self.mesh.shape.items())
+      mesh_repr = f"Mesh({nv_str})"
+    return f'NamedSharding(mesh={mesh_repr}, spec={self.spec}{mem}{ldi})'

   def __reduce__(self):
     return (type(self), (self.mesh, self.spec),

@@ -381,6 +388,9 @@ def with_spec(self, spec: PartitionSpec | Sequence[Any]) -> NamedSharding:
     spec = PartitionSpec(*spec)
     return NamedSharding(self.mesh, spec, memory_kind=self.memory_kind)

+  def with_mesh(self, new_mesh: mesh_lib.Mesh) -> NamedSharding:
+    return NamedSharding(new_mesh, self.spec, memory_kind=self.memory_kind)
+
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
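A brief sketch of the `NamedSharding` changes above, assuming at least two devices are available: `with_mesh` keeps the spec (and memory kind) but rebinds the sharding to another mesh with the same axis names, and the reworked `__repr__` prints mesh axes by name and size.

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devs = np.array(jax.devices()[:2])
mesh1 = Mesh(devs, ('x',))
mesh2 = Mesh(devs[::-1], ('x',))  # same axis names, different device order

s = NamedSharding(mesh1, P('x'))
s2 = s.with_mesh(mesh2)  # same spec, new mesh
print(s2)  # e.g. NamedSharding(mesh=Mesh('x': 2), spec=PartitionSpec('x',))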

Diff for: jax/experimental/jax2tf/tests/primitives_test.py

+2 −0

@@ -172,6 +172,8 @@ def test_primitive_coverage(self):
         continue
       if p.name == "sharding_constraint":
         continue
+      if p.name == "sharding_cast":
+        continue
       # TODO: Remove once tensorflow is 2.10.0 everywhere.
       if p.name == "optimization_barrier":
         continue
