From 48ca8876498e602efa951ccca8aaa3c82f0dda73 Mon Sep 17 00:00:00 2001 From: Dougal Maclaurin Date: Fri, 20 Sep 2024 07:12:41 -0700 Subject: [PATCH] Make an AxisData struct that bundles axis name, size, and spmd name. This is just a small cleanup as prep work for stackless. It means fewer arguments to batching functions and less room for argument-order mistakes. PiperOrigin-RevId: 676833615 --- jax/_src/ad_checkpoint.py | 9 +- jax/_src/api.py | 6 +- jax/_src/custom_batching.py | 4 +- jax/_src/custom_derivatives.py | 26 ++-- jax/_src/interpreters/batching.py | 176 ++++++++-------------- jax/_src/lax/control_flow/conditionals.py | 16 +- jax/_src/lax/control_flow/for_loop.py | 13 +- jax/_src/lax/control_flow/loops.py | 39 ++--- jax/_src/lax/control_flow/solves.py | 22 +-- jax/_src/lax/parallel.py | 28 +++- jax/_src/pjit.py | 17 +-- jax/experimental/shard_map.py | 4 +- 12 files changed, 155 insertions(+), 205 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 39df07359c18..477bd1819403 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -701,19 +701,18 @@ def transposed(*args_flat): transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts) return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error -def remat_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, - jaxpr, **params): +def remat_vmap(axis_data, main_type, args, dims, *, jaxpr, **params): assert not jaxpr.constvars jaxpr_batched_, out_batched = batching.batch_jaxpr_axes( - pe.close_jaxpr(jaxpr), axis_size, dims, + pe.close_jaxpr(jaxpr), axis_data, dims, [batching.zero_if_mapped] * len(jaxpr.outvars), - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + main_type=main_type) jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts if consts: jaxpr_batched = pe.convert_constvars_jaxpr(jaxpr_batched) out_dims = [0 if b else None for b in out_batched] return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims -batching.axis_primitive_batchers[remat_p] = partial(remat_vmap, None) +batching.axis_primitive_batchers[remat_p] = remat_vmap batching.spmd_axis_primitive_batchers[remat_p] = remat_vmap # TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule diff --git a/jax/_src/api.py b/jax/_src/api.py index bd8a951954ac..d74d2b786f33 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -983,10 +983,10 @@ def vmap_f(*args, **kwargs): axis_size_ = (axis_size if axis_size is not None else _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap")) try: + axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name) out_flat = batching.batch( - flat_fun, axis_name, axis_size_, in_axes_flat, - lambda: flatten_axes("vmap out_axes", out_tree(), out_axes), - spmd_axis_name=spmd_axis_name + flat_fun, axis_data, in_axes_flat, + lambda: flatten_axes("vmap out_axes", out_tree(), out_axes) ).call_wrapped(*args_flat) except batching.SpecMatchError as e: out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 35e7d33430bd..213cdba70f07 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -139,8 +139,8 @@ def maybe_bdim_at_front(x, bdim): # `f` is pytree-flattened def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size): f, out_axes = batching.batch_subtrace(f) - f = batching._batch_outer(f, axis_name, axis_size, in_axes, - batching.BatchTrace, None) + axis_data = batching.AxisData(axis_name, axis_size, None) + f = batching._batch_outer(f, axis_data, in_axes, batching.BatchTrace) outs = f.call_wrapped(*args) return outs, out_axes() diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index f5ecdfcda286..d140ea6a8f1b 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -921,18 +921,16 @@ def _custom_vjp_call_jaxpr_jvp( ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp def _custom_vjp_call_jaxpr_vmap( - spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *, + axis_data, main_type, args, in_dims, *, fun_jaxpr: core.ClosedJaxpr, fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]], num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool): args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0 else x for x, d in zip(args, in_dims)] - in_batched = [d is not not_mapped for d in in_dims] _, args_batched = split_list(in_batched, [num_consts]) batched_fun_jaxpr, out_batched = batching.batch_jaxpr( - fun_jaxpr, axis_size, in_batched, False, axis_name, spmd_axis_name, - main_type) + fun_jaxpr, axis_data, in_batched, False, main_type) out_dims1 = [0 if b else not_mapped for b in out_batched] out_dims2 = [] @@ -940,16 +938,14 @@ def _custom_vjp_call_jaxpr_vmap( def batched_fwd_jaxpr_thunk(*zeros): fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( - fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name, - main_type) + fwd_jaxpr, axis_data, args_batched, False, main_type) out_dims2.append([0 if b else not_mapped for b in out_batched]) return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts fwd_args_batched = [0 if b else not_mapped for b in args_batched] fwd_out_dims = lambda: out_dims2[0] batched_bwd = batching.batch_custom_vjp_bwd( - bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type, - spmd_axis_name) + bwd, axis_data, fwd_out_dims, fwd_args_batched, main_type) batched_outs = custom_vjp_call_jaxpr_p.bind( *args, fun_jaxpr=batched_fun_jaxpr, @@ -959,8 +955,8 @@ def batched_fwd_jaxpr_thunk(*zeros): return batched_outs, out_dims batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \ _custom_vjp_call_jaxpr_vmap -batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial( - _custom_vjp_call_jaxpr_vmap, None) +batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \ + _custom_vjp_call_jaxpr_vmap xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p) @@ -1532,7 +1528,7 @@ def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_): return fwd_jaxpr.out_avals, fwd_jaxpr.effects def _remat_opt_vmap( - spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, + axis_data, main_type, args, in_dims, *, num_consts: int, num_res: int, @@ -1544,8 +1540,7 @@ def _remat_opt_vmap( in_batched = [d is not not_mapped for d in in_dims] batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( - fwd_jaxpr, axis_size, in_batched, False, - axis_name, spmd_axis_name, main_type) + fwd_jaxpr, axis_data, in_batched, False, main_type) extra_consts = batched_fwd_jaxpr.consts batched_fwd_jaxpr = pe.close_jaxpr( pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr)) @@ -1557,8 +1552,7 @@ def _remat_opt_vmap( def batched_fun_jaxpr_thunk(): fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk()) batched_fun_jaxpr, out_batched = batching.batch_jaxpr( - fun_jaxpr, axis_size, prim_batched, False, axis_name, spmd_axis_name, - main_type) + fun_jaxpr, axis_data, prim_batched, False, main_type) return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts batched_outs = remat_opt_p.bind(*extra_consts, *args, @@ -1667,7 +1661,7 @@ def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): mlir.register_lowering(remat_opt_p, mlir.lower_fun( _remat_opt_impl, multiple_results=True)) batching.spmd_axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap -batching.axis_primitive_batchers[remat_opt_p] = partial(_remat_opt_vmap, None) +batching.axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose pe.dce_rules[remat_opt_p] = _remat_opt_dce diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 27cde6d31d35..7117bb995b8e 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -274,7 +274,7 @@ def _cont(axis_size, elt, axis): return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val) else: try: - return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val) + return matchaxis(trace.axis_data.name, axis_size, x_.batch_dim, spec, x_.val) except SpecMatchError: raise SpecMatchError(i, x_.batch_dim, spec) from None from_elt_handlers: dict[type, FromEltHandler] = {} @@ -377,12 +377,20 @@ def get_referent(self): else: # TODO(mattjj): could handle the RaggedAxis case? return self + +@dataclasses.dataclass(frozen=True) +class AxisData: + name : core.AxisName + size : core.AxisSize + spmd_name : Any + + class BatchTrace(Trace): - def __init__(self, *args, axis_name, spmd_axis_name = None): + def __init__(self, *args, axis_data): super().__init__(*args) - self.axis_name = axis_name - self.spmd_axis_name = spmd_axis_name + assert isinstance(axis_data, AxisData) + self.axis_data = axis_data def pure(self, val): return BatchTracer(self, val, not_mapped, source_info_util.current()) @@ -393,36 +401,20 @@ def lift(self, val): def sublift(self, val): return BatchTracer(self, val.val, val.batch_dim, source_info_util.current()) - def get_primitive_batcher(self, primitive, frame): + def get_primitive_batcher(self, primitive): if primitive in primitive_batchers: return primitive_batchers[primitive] - elif self.spmd_axis_name is not None and primitive in spmd_axis_primitive_batchers: + elif self.axis_data.spmd_name is not None and primitive in spmd_axis_primitive_batchers: return partial(spmd_axis_primitive_batchers[primitive], - self.spmd_axis_name, frame.size, frame.name, - frame.main_trace.trace_type) + self.axis_data, self.main.trace_type) elif primitive in axis_primitive_batchers: - return self.get_axis_primitive_batcher(primitive, frame) + return self.get_axis_primitive_batcher(primitive) msg = "Batching rule for '{}' not implemented" raise NotImplementedError(msg.format(primitive)) - def get_axis_primitive_batcher(self, primitive, frame): - return partial(axis_primitive_batchers[primitive], - frame.size, frame.name, frame.main_trace.trace_type) - - def get_frame(self, vals, dims) -> core.AxisEnvFrame: - if any(d is not not_mapped for d in dims): - sizes = (x.shape[d] if type(d) is int else d.size - for x, d in zip(vals, dims) if d is not not_mapped) - axis_size, = core.dedup_referents(sizes) - else: - axis_size = None # can't be inferred from data - if self.axis_name is core.no_axis_name: - assert axis_size is not None # must be inferable from data - return core.AxisEnvFrame(self.axis_name, axis_size, self.main) - frame = core.axis_frame(self.axis_name, self.main) - assert axis_size is None or axis_size == frame.size, (axis_size, frame.size) - assert frame.main_trace is self.main - return frame + def get_axis_primitive_batcher(self, primitive): + return partial(axis_primitive_batchers[primitive], self.axis_data, + self.main.trace_type) def process_primitive(self, primitive, tracers, params): if config.dynamic_shapes.value: @@ -431,14 +423,12 @@ def process_primitive(self, primitive, tracers, params): is_axis_primitive = primitive in axis_primitive_batchers used_names = core.used_axis_names(primitive, params) if is_axis_primitive and _main_trace_for_axis_names(self.main, used_names): - frame = self.get_frame(vals_in, dims_in) - batcher_primitive = self.get_axis_primitive_batcher(primitive, frame) + batcher_primitive = self.get_axis_primitive_batcher(primitive) val_out, dim_out = batcher_primitive(vals_in, dims_in, **params) elif all(bdim is not_mapped for bdim in dims_in): return primitive.bind(*vals_in, **params) else: - frame = self.get_frame(vals_in, dims_in) - batched_primitive = self.get_primitive_batcher(primitive, frame) + batched_primitive = self.get_primitive_batcher(primitive) val_out, dim_out = batched_primitive(vals_in, dims_in, **params) src = source_info_util.current() if primitive.multiple_results: @@ -452,13 +442,10 @@ def process_call(self, call_primitive, f, tracers, params): vals, dims = unzip2((t.val, t.batch_dim) for t in tracers) if all(bdim is not_mapped for bdim in dims): return call_primitive.bind(f, *vals, **params) - sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths) - for x, d in zip(vals, dims) if d is not not_mapped) - axis_size, = core.dedup_referents(sizes) segment_lens, dims = indirectify_ragged_axes(dims) f_, dims_out = batch_subtrace(f, self.main, tuple(dims)) f_ = _update_annotation( - f_, f.in_type, axis_size, self.axis_name, dims, segment_lens) + f_, f.in_type, self.axis_data.size, self.axis_data.name, dims, segment_lens) vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params) vals_out, dims_out = resolve_ragged_axes(vals_out, dims_out()) src = source_info_util.current() @@ -561,14 +548,10 @@ def todo(vals): def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, *, out_trees, symbolic_zeros): # pytype: disable=signature-mismatch in_vals, in_dims = unzip2((t.val, t.batch_dim) for t in tracers) - axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) - if d is not not_mapped} fwd_in_dims = [d for in_dim in in_dims for d in [in_dim, not_mapped]] fun, out_dims1 = batch_subtrace(fun, self.main, in_dims) fwd, out_dims2 = batch_subtrace(fwd, self.main, fwd_in_dims) - bwd = batch_custom_vjp_bwd(bwd, self.axis_name, axis_size, - out_dims2, in_dims, self.main.trace_type, - self.spmd_axis_name) + bwd = batch_custom_vjp_bwd(bwd, self.axis_data, out_dims2, in_dims, self.main.trace_type) out_vals = prim.bind(fun, fwd, bwd, *in_vals, out_trees=out_trees, symbolic_zeros=symbolic_zeros) fst, out_dims = lu.merge_linear_aux(out_dims1, out_dims2) @@ -590,9 +573,7 @@ def todo(vals): def post_process_custom_vjp_call_fwd(self, out_tracers, out_trees): vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info) for t in out_tracers) - axis_size, = {x.shape[d] for x, d in zip(vals, dims) if d is not not_mapped} main, trace_type = self.main, self.main.trace_type - axis_name = self.axis_name _, res_tree = out_trees() num_res = res_tree.num_leaves res_dims, primal_dims = split_list(dims, [num_res]) @@ -601,8 +582,7 @@ def todo(vals): trace = main.with_cur_sublevel() return map(partial(BatchTracer, trace), vals, primal_dims, primal_srcs) def bwd_transform(bwd): - return batch_custom_vjp_bwd(bwd, axis_name, axis_size, dims, (None,), - trace_type, self.spmd_axis_name) + return batch_custom_vjp_bwd(bwd, self.axis_data, dims, (None,), trace_type) return vals, todo, bwd_transform def _main_trace_for_axis_names(main_trace: core.MainTrace, @@ -615,36 +595,32 @@ def _main_trace_for_axis_names(main_trace: core.MainTrace, ### API for batching callables with vmappable inputs and outputs -def batch(fun: lu.WrappedFun, axis_name: AxisName, axis_size, +def batch(fun: lu.WrappedFun, axis_data: AxisData, in_dims, out_dim_dests, main_type: type[BatchTrace] = BatchTrace, - spmd_axis_name: tuple[AxisName, ...] | None = None ) -> lu.WrappedFun: # we split up _batch_inner and _batch_outer for the leak checker - f = _batch_inner(fun, axis_size, out_dim_dests) - return _batch_outer(f, axis_name, axis_size, in_dims, main_type, - spmd_axis_name) + f = _batch_inner(fun, axis_data, out_dim_dests) + return _batch_outer(f, axis_data, in_dims, main_type) @lu.transformation -def _batch_outer(axis_name, axis_size, in_dims, main_type, spmd_axis_name, - *in_vals): - with core.new_main( - main_type, axis_name=axis_name, spmd_axis_name=spmd_axis_name) as main: - with core.extend_axis_env(axis_name, axis_size, main): +def _batch_outer(axis_data, in_dims, main_type, *in_vals): + with core.new_main(main_type, axis_data=axis_data) as main: + with core.extend_axis_env(axis_data.name, axis_data.size, main): with source_info_util.transform_name_stack('vmap'): outs = yield (main, in_dims, *in_vals), {} del main yield outs @lu.transformation -def _batch_inner(axis_size, out_dim_dests, main, in_dims, *in_vals): +def _batch_inner(axis_data, out_dim_dests, main, in_dims, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims trace = main.with_cur_sublevel() - idx = memoize(lambda: BatchTracer(trace, make_iota(axis_size), 0, + idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0, source_info_util.current())) in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) outs = yield in_tracers, {} out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests - out_vals = map(partial(from_elt, trace, axis_size), range(len(outs)), + out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)), outs, out_dim_dests) yield out_vals @@ -678,8 +654,9 @@ def _map_to_tile(*args_flat): outputs_flat = yield map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat), {} yield map(untile_axis, outputs_flat, out_axes_flat) + axis_data = AxisData(axis_name, tile_size, None) return _map_to_tile(batch( - f_flat, axis_name, tile_size, in_axes_flat, out_axes_flat, main_type=main_type)) + f_flat, axis_data, in_axes_flat, out_axes_flat, main_type=main_type)) ### API for batching functions with jaxpr type inputs and outputs @@ -765,10 +742,8 @@ def fetch(idx): # Can reuse same pattern for all dynamic shape stuff. def batch_jaxpr2( closed_jaxpr: core.ClosedJaxpr, - axis_size: core.AxisSize, + axis_data: AxisData, in_axes: tuple[int | NotMapped | RaggedAxis, ...], - axis_name: AxisName, - spmd_axis_name: AxisName, main_type: type[BatchTrace], ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped | RaggedAxis, ...]]: # This is only ever used in pjit. The difference vs batch_jaxpr is that @@ -776,27 +751,23 @@ def batch_jaxpr2( # their batch axes are; whereas batch_jaxpr has to obey caller-imposed # consistency constraints, such as type-agreement across arms of a # `lax.cond`, or input-output agreement for the body of a `lax.scan`. - return _batch_jaxpr2(closed_jaxpr, axis_size, tuple(in_axes), axis_name, - spmd_axis_name, main_type) + return _batch_jaxpr2(closed_jaxpr, axis_data, tuple(in_axes), main_type) @weakref_lru_cache def _batch_jaxpr2( closed_jaxpr: core.ClosedJaxpr, - axis_size: core.AxisSize, + axis_data: AxisData, in_axes: tuple[int | NotMapped | RaggedAxis, ...], - axis_name: AxisName, - spmd_axis_name: AxisName, main_type: type[BatchTrace], ) -> tuple[core.ClosedJaxpr, tuple[int | NotMapped, ...]]: f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr)) - f, out_axes = _batch_jaxpr_inner(f, axis_size) - f = _batch_jaxpr_outer(f, axis_name, spmd_axis_name, axis_size, in_axes, - main_type) + f, out_axes = _batch_jaxpr_inner(f, axis_data) + f = _batch_jaxpr_outer(f, axis_data, in_axes, main_type) in_axes2, avals_in = unzip2([ handle_ragged(closed_jaxpr.in_avals, dim, aval) if isinstance(dim, RaggedAxis) else (dim, aval) for dim, aval in zip(in_axes, closed_jaxpr.in_avals)]) - avals_in2 = [core.unmapped_aval(axis_size, axis_name, b, aval) + avals_in2 = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval) if b is not not_mapped else aval for aval, b in unsafe_zip(avals_in, in_axes2)] jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2) @@ -810,14 +781,11 @@ def handle_ragged(in_avals: list[core.AbstractValue], dim: RaggedAxis, new_aval = aval.update(shape=tuple(new_shape)) return dim.stacked_axis, new_aval -def batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, - spmd_axis_name, main_type): +def batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate, main_type): inst = tuple(instantiate) if isinstance(instantiate, list) else instantiate - return _batch_jaxpr(closed_jaxpr, axis_size, tuple(in_batched), inst, - axis_name, spmd_axis_name, main_type) + return _batch_jaxpr(closed_jaxpr, axis_data, tuple(in_batched), inst, main_type) -def _batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, - spmd_axis_name, main_type): +def _batch_jaxpr(closed_jaxpr, axis_data, in_batched, instantiate, main_type): assert (isinstance(instantiate, bool) or isinstance(instantiate, (list, tuple)) and all(isinstance(b, bool) for b in instantiate)) @@ -825,30 +793,25 @@ def _batch_jaxpr(closed_jaxpr, axis_size, in_batched, instantiate, axis_name, instantiate = [instantiate] * len(closed_jaxpr.out_avals) in_axes = [0 if b else not_mapped for b in in_batched] out_axes_dest = [0 if inst else zero_if_mapped for inst in instantiate] - return batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, - axis_name, spmd_axis_name, main_type) + return batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest, main_type) -def batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, axis_name, - spmd_axis_name, main_type): - return _batch_jaxpr_axes(closed_jaxpr, axis_size, tuple(in_axes), - tuple(out_axes_dest), axis_name, spmd_axis_name, - main_type) +def batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest, main_type): + return _batch_jaxpr_axes(closed_jaxpr, axis_data, tuple(in_axes), + tuple(out_axes_dest), main_type) @weakref_lru_cache -def _batch_jaxpr_axes(closed_jaxpr, axis_size, in_axes, out_axes_dest, - axis_name, spmd_axis_name, main_type): +def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest, main_type): f = lu.wrap_init(core.jaxpr_as_fun(closed_jaxpr)) - f, out_axes = _batch_jaxpr_inner(f, axis_size) - f, out_batched = _match_axes_jaxpr(f, axis_size, out_axes_dest, out_axes) - f = _batch_jaxpr_outer(f, axis_name, spmd_axis_name, axis_size, in_axes, - main_type) - avals_in = [core.unmapped_aval(axis_size, axis_name, b, aval) if b is not not_mapped + f, out_axes = _batch_jaxpr_inner(f, axis_data) + f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes) + f = _batch_jaxpr_outer(f, axis_data, in_axes, main_type) + avals_in = [core.unmapped_aval(axis_data.size, axis_data.name, b, aval) if b is not not_mapped else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)] jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in) return core.ClosedJaxpr(jaxpr_out, consts), out_batched() @lu.transformation_with_aux -def _batch_jaxpr_inner(axis_size, main, in_axes, *in_vals): +def _batch_jaxpr_inner(axis_data, main, in_axes, *in_vals): trace = main.with_cur_sublevel() _, in_axes = resolve_ragged_axes(in_vals, in_axes) in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val @@ -861,7 +824,7 @@ def _batch_jaxpr_inner(axis_size, main, in_axes, *in_vals): yield out_vals, new_out_axes @lu.transformation_with_aux -def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes, +def _match_axes_jaxpr(axis_data, out_axes_dest, out_axes, main, in_axes, *in_vals): trace = main.with_cur_sublevel() out_vals = yield (main, in_axes, *in_vals), {} @@ -872,22 +835,18 @@ def _match_axes_jaxpr(axis_size, out_axes_dest, out_axes, main, in_axes, if len(out_axes_dest) != len(out_axes): out_axis_dest, = out_axes_dest out_axes_dest = [out_axis_dest] * len(out_axes) - out_vals = map(partial(matchaxis, trace.axis_name, axis_size), + out_vals = map(partial(matchaxis, axis_data.name, axis_data.size), out_axes, out_axes_dest, out_vals) out_batched = [dst is not None for dst in out_axes_dest] yield out_vals, out_batched @lu.transformation -def _batch_jaxpr_outer(axis_name, spmd_axis_name, axis_size, in_dims, main_type, - *in_vals): - if axis_size is None: - axis_size, = {x.shape[d] for x, d in zip(in_vals, in_dims) if d is not not_mapped} +def _batch_jaxpr_outer(axis_data, in_dims, main_type, *in_vals): in_dims = in_dims() if callable(in_dims) else in_dims in_dims = [canonicalize_axis(ax, np.ndim(x)) if isinstance(ax, int) else ax for x, ax in unsafe_zip(in_vals, in_dims)] - with core.new_main(main_type, axis_name=axis_name, - spmd_axis_name=spmd_axis_name) as main: - with core.extend_axis_env(axis_name, axis_size, main): + with core.new_main(main_type, axis_data=axis_data) as main: + with core.extend_axis_env(axis_data.name, axis_data.size, main): out_vals = yield (main, in_dims, *in_vals), {} del main yield out_vals @@ -909,8 +868,7 @@ class ZeroIfMapped: pass @lu.transformation_with_aux def batch_custom_jvp_subtrace(main, in_dims, *in_vals): - size, = {x.shape[d] for x, d in zip(in_vals, in_dims * 2) - if d is not not_mapped} + size = main.payload['axis_data'].size trace = main.with_cur_sublevel() in_tracers = [val if dim is None else SymbolicZero(core.mapped_aval(size, dim, val.aval)) @@ -925,25 +883,23 @@ def batch_custom_jvp_subtrace(main, in_dims, *in_vals): out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2]) out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2]) out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds) - out_primals = map(partial(matchaxis, trace.axis_name, size), + out_primals = map(partial(matchaxis, trace.axis_data.name, size), out_primal_bds, out_dims, out_primals) - out_tangents = map(partial(matchaxis, trace.axis_name, size), + out_tangents = map(partial(matchaxis, trace.axis_data.name, size), out_tangent_bds, out_dims, out_tangents) yield out_primals + out_tangents, out_dims * 2 -def batch_custom_vjp_bwd(bwd, axis_name, axis_size, in_dims, out_dim_dests, - main_type, spmd_axis_name): +def batch_custom_vjp_bwd(bwd, axis_data, in_dims, out_dim_dests, main_type): def new_bwd(*args): in_dims_ = in_dims() if callable(in_dims) else in_dims - args = [SymbolicZero(core.mapped_aval(axis_size, dim, x.aval)) + args = [SymbolicZero(core.mapped_aval(axis_data.size, dim, x.aval)) if type(x) is SymbolicZero else x for x, dim in zip(args, in_dims_)] in_dims_ = [None if type(x) is SymbolicZero else d for x, d in zip(args, in_dims_)] bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd)) - bwd_ = _batch_outer(bwd_, axis_name, axis_size, in_dims_, main_type, - spmd_axis_name) - bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk, + bwd_ = _batch_outer(bwd_, axis_data, in_dims_, main_type) + bwd_ = _match_axes_and_sum(bwd_, axis_data.size, axis_data.name, out_dims_thunk, out_dim_dests) return bwd_.call_wrapped(*args) return new_bwd diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index d3065d0f96d7..e855afc75b16 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -352,8 +352,7 @@ def _bcast_select_n(pred, *cases): pred = lax.broadcast_in_dim(pred, np.shape(cases[0]), idx) return lax.select_n(pred, *cases) -def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, - dims, branches): +def _cond_batching_rule(axis_data, main_type, args, dims, branches): index, *ops = args index_dim, *op_dims = dims # TODO(sharadmv): clean this up by adding a specific blocklist @@ -375,15 +374,14 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, # optimizations to XLA. # TODO(mattjj,frostig): assumes branches are side-effect-free, revise! index, *ops = ( - batching.bdim_at_front(x, d, axis_size) for x, d in zip(args, dims)) + batching.bdim_at_front(x, d, axis_data.size) for x, d in zip(args, dims)) in_batched = [True] * len(branches[0].in_avals) out_batched = [True] * len(branches[0].out_avals) branches_batched = [ batching.batch_jaxpr( - jaxpr, axis_size, in_batched, out_batched, axis_name, spmd_axis_name, - main_type)[0] + jaxpr, axis_data, in_batched, out_batched, main_type)[0] for jaxpr in branches] branch_outs = [] @@ -401,13 +399,11 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, for b, x, d in zip(ops_bat, ops, op_dims)] branches_out_bat = [ - batching.batch_jaxpr(jaxpr, axis_size, ops_bat, False, axis_name, - spmd_axis_name, main_type)[1] + batching.batch_jaxpr(jaxpr, axis_data, ops_bat, False, main_type)[1] for jaxpr in branches] out_bat = [any(bat) for bat in zip(*branches_out_bat)] branches_batched = tuple( - batching.batch_jaxpr(jaxpr, axis_size, ops_bat, out_bat, axis_name, - spmd_axis_name, main_type)[0] + batching.batch_jaxpr(jaxpr, axis_data, ops_bat, out_bat, main_type)[0] for jaxpr in branches) out_dims = [0 if b else batching.not_mapped for b in out_bat] @@ -811,7 +807,7 @@ def cond_bind(*args, branches): ad.reducing_transposes[cond_p] = _cond_transpose pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval batching.spmd_axis_primitive_batchers[cond_p] = _cond_batching_rule -batching.axis_primitive_batchers[cond_p] = partial(_cond_batching_rule, None) +batching.axis_primitive_batchers[cond_p] = _cond_batching_rule xla.register_initial_style_primitive(cond_p) core.custom_typechecks[cond_p] = partial(_cond_typecheck, False) core.axis_substitution_rules[cond_p] = _cond_axis_substitution diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 21b522b3d8bb..2dbcb8ca9171 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -278,7 +278,7 @@ def _cached_for_jaxpr(jaxpr): discharged_jaxpr, body_consts = discharge_state(jaxpr, ()) return core.ClosedJaxpr(discharged_jaxpr, body_consts) -def _for_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, +def _for_vmap(axis_data, main_type, args, dims, *, jaxpr, nsteps, reverse, which_linear, unroll): init_batched = [d is not batching.not_mapped for d in dims] closed_jaxpr = _cached_for_jaxpr(jaxpr) @@ -286,25 +286,24 @@ def _for_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *, for _ in range(len(batched)): _, out_batched = batching.batch_jaxpr( closed_jaxpr, - axis_size, [False] + batched, instantiate=batched, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + axis_data, [False] + batched, instantiate=batched, main_type=main_type) if out_batched == batched: break batched = map(operator.or_, batched, out_batched) else: raise Exception("Invalid fixpoint") - args = [batching.broadcast(x, axis_size, 0) if now_bat and not was_bat + args = [batching.broadcast(x, axis_data.size, 0) if now_bat and not was_bat else batching.moveaxis(x, d, 0) if now_bat else x for x, d, was_bat, now_bat in zip(args, dims, init_batched, batched)] batched_jaxpr_, _ = batching.batch_jaxpr( - pe.close_jaxpr(jaxpr), axis_size, [False] + batched, [], - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + pe.close_jaxpr(jaxpr), axis_data, [False] + batched, [], + main_type=main_type) batched_jaxpr, () = batched_jaxpr_.jaxpr, batched_jaxpr_.consts # TODO consts out_flat = for_p.bind(*args, jaxpr=batched_jaxpr, nsteps=nsteps, reverse=reverse, which_linear=which_linear, unroll=unroll) return out_flat, [0 if b else batching.not_mapped for b in batched] -batching.axis_primitive_batchers[for_p] = functools.partial(_for_vmap, None) +batching.axis_primitive_batchers[for_p] = _for_vmap batching.spmd_axis_primitive_batchers[for_p] = _for_vmap def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear, diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 7a9596bf2c0d..2c39ee25da7d 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -885,7 +885,7 @@ def transposed(*res1_cbar_bbar_res2): b_ys_avals_stripped + res2_avals)) -def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, +def _scan_batching_rule(axis_data, main_type, args, dims, reverse, length, jaxpr, num_consts, num_carry, linear, unroll, _split_transpose): @@ -902,10 +902,8 @@ def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, for _ in range(1 + len(carry_batched)): batched = const_batched + carry_batched + xs_batched jaxpr_batched, batched_out = batching.batch_jaxpr( - jaxpr, axis_size, batched, + jaxpr, axis_data, batched, instantiate=carry_batched + [False] * num_ys, - axis_name=axis_name, - spmd_axis_name=spmd_axis_name, main_type=main_type) carry_batched_out, ys_batched = batched_out[:num_carry], batched_out[num_carry:] if carry_batched_out == carry_batched: @@ -919,7 +917,7 @@ def _scan_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry]) new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0 else x for x, d in zip(consts, consts_bdims)] - new_init = [batching.broadcast(x, axis_size, 0) if now_batched and not was_batched + new_init = [batching.broadcast(x, axis_data.size, 0) if now_batched and not was_batched else batching.moveaxis(x, d, 0) if now_batched else x for x, d, was_batched, now_batched in zip(init, init_bdims, init_batched, carry_batched)] @@ -1228,7 +1226,7 @@ def scan_bind(*args, **params): xla.register_initial_style_primitive(scan_p) mlir.register_lowering(scan_p, mlir.lower_fun(_scan_impl, multiple_results=True)) -batching.axis_primitive_batchers[scan_p] = partial(_scan_batching_rule, None) +batching.axis_primitive_batchers[scan_p] = _scan_batching_rule batching.spmd_axis_primitive_batchers[scan_p] = _scan_batching_rule core.custom_typechecks[scan_p] = partial(_scan_typecheck, False) pe.partial_eval_jaxpr_custom_rules[scan_p] = _scan_partial_eval_custom @@ -1382,7 +1380,7 @@ def _while_loop_abstract_eval(*avals, cond_jaxpr, body_jaxpr, body_nconsts, return _map(raise_to_shaped, body_jaxpr.out_avals), joined_effects -def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, +def _while_loop_batching_rule(axis_data, main_type, args, dims, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): from jax._src.callback import _IOEffect, _OrderedIOEffect @@ -1401,8 +1399,8 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, # reach a fixpoint. for _ in range(1 + len(carry_bat)): _, carry_bat_out = batching.batch_jaxpr( - body_jaxpr, axis_size, bconst_bat + carry_bat, instantiate=carry_bat, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + body_jaxpr, axis_data, bconst_bat + carry_bat, instantiate=carry_bat, + main_type=main_type) if carry_bat == carry_bat_out: break carry_bat = safe_map(operator.or_, carry_bat, carry_bat_out) @@ -1412,8 +1410,8 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, # Knowing how the carry is batched now, we can determine if the predicate is # batched. _, (pred_bat,) = batching.batch_jaxpr( - cond_jaxpr, axis_size, cconst_bat + carry_bat, instantiate=False, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + cond_jaxpr, axis_data, cconst_bat + carry_bat, instantiate=False, + main_type=main_type) if pred_bat: # If the predicate is batched, we have to batch *all* of the carry @@ -1424,13 +1422,9 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, carry_bat = [True] * len(carry_bat) carry_dims = [0] * len(carry_bat) body_jaxpr_batched, _ = batching.batch_jaxpr_axes( - body_jaxpr, axis_size, bconst_dims + carry_dims, - carry_dims, axis_name=axis_name, spmd_axis_name=spmd_axis_name, - main_type=main_type) + body_jaxpr, axis_data, bconst_dims + carry_dims, carry_dims, main_type=main_type) cond_jaxpr_batched, _ = batching.batch_jaxpr_axes( - cond_jaxpr, axis_size, cconst_dims + carry_dims, [0], - axis_name=axis_name, spmd_axis_name=spmd_axis_name, - main_type=main_type) + cond_jaxpr, axis_data, cconst_dims + carry_dims, [0], main_type=main_type) else: # If the predicate is not batched, we can look at the `cond_jaxpr`'s out # shape to determine the rank of the predicate. From this rank we pick the @@ -1440,13 +1434,12 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, cond_rank = len(cond_jaxpr.out_avals[0].shape) carry_dims = [cond_rank if b else None for b in carry_bat] body_jaxpr_batched, _ = batching.batch_jaxpr_axes( - body_jaxpr, axis_size, bconst_dims + carry_dims, carry_dims, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + body_jaxpr, axis_data, bconst_dims + carry_dims, carry_dims, + main_type=main_type) # Now we need to rebatch the `cond_jaxpr` according to the new dims of the # carry. cond_jaxpr_batched, _ = batching.batch_jaxpr_axes( - cond_jaxpr, axis_size, cconst_dims + carry_dims, (None,), - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + cond_jaxpr, axis_data, cconst_dims + carry_dims, (None,), main_type=main_type) # To prepare the `init` to the `while_p`, we broadcast values if they are # unbatched and need to have an out axis. If their current batch axis does not @@ -1455,7 +1448,7 @@ def _while_loop_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, new_init = [] for x, old_axis, new_axis in zip(init, init_dims, carry_dims): if old_axis is batching.not_mapped and new_axis is not batching.not_mapped: - new_init.append(batching.broadcast(x, axis_size, new_axis)) + new_init.append(batching.broadcast(x, axis_data.size, new_axis)) elif old_axis is batching.not_mapped and new_axis is batching.not_mapped: new_init.append(x) else: @@ -1899,7 +1892,7 @@ def new_cond(*consts_refs_carry): pe.custom_partial_eval_rules[while_p] = _while_partial_eval xla.register_initial_style_primitive(while_p) ad.primitive_transposes[while_p] = _while_transpose_error -batching.axis_primitive_batchers[while_p] = partial(_while_loop_batching_rule, None) +batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule batching.spmd_axis_primitive_batchers[while_p] = _while_loop_batching_rule pe.partial_eval_jaxpr_custom_rules[while_p] = _while_partial_eval_custom mlir.register_lowering(while_p, _while_lowering) diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 4e0f5086b121..45f4bf405294 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -376,7 +376,7 @@ def _linear_solve_transpose_rule(cotangent, *primals, const_lengths, jaxprs): return [None] * sum(const_lengths) + cotangent_b -def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, +def _linear_solve_batching_rule(axis_data, main_type, args, dims, const_lengths, jaxprs): orig_bat = [d is not batching.not_mapped for d in dims] @@ -397,15 +397,15 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, for i in range(1 + len(orig_b_bat) + len(solve.out_avals)): # Apply vecmat and solve -> new batched parts of x solve_jaxpr_batched, solve_x_bat = batching.batch_jaxpr( - solve, axis_size, solve_bat + b_bat, instantiate=x_bat, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + solve, axis_data, solve_bat + b_bat, instantiate=x_bat, + main_type=main_type) if vecmat is None: vecmat_jaxpr_batched = None x_bat_out = solve_x_bat else: vecmat_jaxpr_batched, vecmat_x_bat = batching.batch_jaxpr( - vecmat, axis_size, vecmat_bat + b_bat, instantiate=b_bat, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + vecmat, axis_data, vecmat_bat + b_bat, instantiate=b_bat, + main_type=main_type) # batch all aux data by default x_bat_out = _map(operator.or_, vecmat_x_bat + [True] * num_aux, solve_x_bat) # keep a slice of only the linear operator part of solve's avals @@ -413,15 +413,15 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, # Apply matvec and solve_t -> new batched parts of b matvec_jaxpr_batched, matvec_b_bat = batching.batch_jaxpr( - matvec, axis_size, matvec_bat + x_bat_noaux, instantiate=b_bat, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + matvec, axis_data, matvec_bat + x_bat_noaux, instantiate=b_bat, + main_type=main_type) if solve_t is None: solve_t_jaxpr_batched = None b_bat_out = _map(operator.or_, matvec_b_bat, orig_b_bat) else: solve_t_jaxpr_batched, solve_t_b_aux_bat = batching.batch_jaxpr( - solve_t, axis_size, solve_t_bat + x_bat_noaux, instantiate=x_bat_out, - axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type) + solve_t, axis_data, solve_t_bat + x_bat_noaux, instantiate=x_bat_out, + main_type=main_type) assert len(solve_t_b_aux_bat) == len(orig_b_bat) + num_aux solve_t_b_bat, _ = split_list(solve_t_b_aux_bat, [len(orig_b_bat)]) b_bat_out = _map(lambda m, s, o: m or s or o, matvec_b_bat, solve_t_b_bat, @@ -445,7 +445,7 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, ] # Broadcast out b if necessary new_b = [ - batching.broadcast(x, axis_size, 0) if now_bat and not was_bat else + batching.broadcast(x, axis_data.size, 0) if now_bat and not was_bat else batching.moveaxis(x, d, 0) if now_bat and d != 0 else x for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat) ] @@ -468,5 +468,5 @@ def _linear_solve_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl, multiple_results=True)) ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule -batching.axis_primitive_batchers[linear_solve_p] = partial(_linear_solve_batching_rule, None) +batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule batching.spmd_axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index c9a07072ddc7..d6ea0e23accb 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -536,8 +536,10 @@ def _reduction_batcher(prim, vals_in, dims_in, *, axes, axis_index_groups): return vals_out, [d if d is batching.not_mapped else 0 for d in dims_in] def _batched_reduction_collective( - prim, if_unmapped, axis_size, frame_name, _, vals_in, dims_in, axes, + prim, if_unmapped, axis_data, _, vals_in, dims_in, axes, axis_index_groups): + axis_size = axis_data.size + frame_name = axis_data.name assert prim.multiple_results assert frame_name in axes # Note that we have a choice here. We can either unfuse the reduction into one @@ -765,7 +767,9 @@ def _ppermute_transpose_rule(t, x, perm, axis_name): inverse_perm = list(zip(dsts, srcs)) return [ppermute(t, axis_name=axis_name, perm=inverse_perm)] -def _ppermute_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, perm): +def _ppermute_batcher(axis_data, _, vals_in, dims_in, axis_name, perm): + axis_size = axis_data.size + frame_name = axis_data.name (v,), (d,) = vals_in, dims_in if not isinstance(axis_name, (tuple, list)): axis_name = (axis_name,) @@ -799,7 +803,9 @@ def _pbroadcast_transpose_rule(t, x, source, axis_name): tsum = psum(t, axis_name) return [lax.select(is_source, lax.full_like(t, tsum), lax.full_like(t, 0))] -def _pbroadcast_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, source): +def _pbroadcast_batcher(axis_data, _, vals_in, dims_in, axis_name, source): + axis_size = axis_data.size + frame_name = axis_data.name (v,), (d,) = vals_in, dims_in if not isinstance(axis_name, (tuple, list)): axis_name = (axis_name,) @@ -914,9 +920,11 @@ def _all_to_all_batcher(vals_in, dims_in, *, axis_name, split_axis, concat_axis, ) return result, d -def _all_to_all_batched_collective(axis_size, frame_name, _, vals_in, dims_in, +def _all_to_all_batched_collective(axis_data, _, vals_in, dims_in, axis_name, split_axis, concat_axis, axis_index_groups, tiled): + axis_size = axis_data.size + frame_name = axis_data.name if axis_index_groups is not None: raise NotImplementedError("Please open a feature request!") x, = vals_in @@ -1157,9 +1165,11 @@ def _all_gather_batcher(vals_in, dims_in, *, all_gather_dimension, axis_name, ax tiled=tiled) return result, d -def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in, +def _all_gather_batched_collective(axis_data, _, vals_in, dims_in, all_gather_dimension, axis_name, axis_index_groups, axis_size, tiled): + frame_size = axis_data.size + frame_name = axis_data.name if axis_index_groups is not None: raise NotImplementedError("axis_index_groups not supported in vmap") assert axis_size == frame_size, "axis size doesn't match" @@ -1289,9 +1299,11 @@ def _reduce_scatter_batcher(vals_in, dims_in, *, scatter_dimension, axis_name, tiled=tiled) return result, d -def _reduce_scatter_collective(frame_size, frame_name, _, vals_in, dims_in, +def _reduce_scatter_collective(axis_data, _, vals_in, dims_in, scatter_dimension, axis_name, axis_index_groups, axis_size, tiled): + frame_size = axis_data.size + frame_name = axis_data.name if axis_index_groups is not None: raise NotImplementedError("axis_index_groups not supported in vmap") assert axis_size == frame_size, "axis size doesn't match" @@ -1536,7 +1548,9 @@ def _pgather_batcher(vals_in, dims_in, *, axes): else: assert False # This shouldn't get called anyway -def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, axes): +def _pgather_collective_batcher(axis_data, _, vals_in, dims_in, *, axes): + axis_size = axis_data.size + frame_name = axis_data.name src, idx = vals_in dsrc, didx = dims_in if dsrc is batching.not_mapped: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index ac1318ed7810..59a1610745db 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1959,14 +1959,13 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, mlir.register_lowering(pjit_p, _pjit_lowering) -def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, +def _pjit_batcher(axis_data, main_type, vals_in, dims_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline): segment_lens, dims_in = batching.indirectify_ragged_axes(dims_in) new_jaxpr, axes_out = batching.batch_jaxpr2( - jaxpr, axis_size, dims_in, axis_name=axis_name, - spmd_axis_name=spmd_axis_name, main_type=main_type) + jaxpr, axis_data, dims_in, main_type=main_type) if resource_env is not None: mesh = resource_env.physical_mesh @@ -1975,11 +1974,11 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, # TODO(axch): prepend with Nones (?) to account for new segment_lens inputs in_shardings = tuple( - _pjit_batcher_for_sharding(i, axis_in, spmd_axis_name, mesh, aval.ndim) + _pjit_batcher_for_sharding(i, axis_in, axis_data.spmd_name, mesh, aval.ndim) if axis_in is not None else i for axis_in, i, aval in zip(dims_in, in_shardings, new_jaxpr.in_avals)) out_shardings = tuple( - _pjit_batcher_for_sharding(o, axis_out, spmd_axis_name, mesh, aval.ndim) + _pjit_batcher_for_sharding(o, axis_out, axis_data.spmd_name, mesh, aval.ndim) if axis_out is not None else o for axis_out, o, aval in zip(axes_out, out_shardings, new_jaxpr.out_avals)) # TODO(yashkatariya): Figure out layouts should change under vmap. @@ -2006,7 +2005,7 @@ def _pjit_batcher(spmd_axis_name, axis_size, axis_name, main_type, return vals_out, resolved_axes_out batching.spmd_axis_primitive_batchers[pjit_p] = _pjit_batcher -batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, None) +batching.axis_primitive_batchers[pjit_p] = _pjit_batcher def _pjit_batcher_for_sharding( s: sharding.Sharding | UnspecifiedValue, @@ -2558,8 +2557,9 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, layout, def _sharding_constraint_batcher( - spmd_axis_name, axis_size, axis_name, main_type, vals_in, + axis_data, main_type, vals_in, dims_in, sharding, layout, resource_env, unconstrained_dims): + spmd_axis_name = axis_data.spmd_name if spmd_axis_name is not None and isinstance(sharding, NamedSharding): used = {n for ns in sharding.spec for n in (ns if isinstance(ns, tuple) else (ns,))} @@ -2597,8 +2597,7 @@ def _sharding_constraint_batcher( unconstrained_dims=unconstrained_dims) return y, d batching.spmd_axis_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher -batching.axis_primitive_batchers[sharding_constraint_p] = partial( - _sharding_constraint_batcher, None) +batching.axis_primitive_batchers[sharding_constraint_p] = _sharding_constraint_batcher # -------------------- helpers -------------------- diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 35d665943792..b245c925fe60 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1338,7 +1338,7 @@ def _shard_map_batch( fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims)) new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] # type: ignore for ax in names} for names, d in zip(in_names, in_dims)] - spmd_axis_name = trace.spmd_axis_name + spmd_axis_name = trace.axis_data.spmd_name if spmd_axis_name is not None: used = {n for names in in_names for ns in names.values() for n in ns} if not config.disable_vmap_shmap_error.value and set(spmd_axis_name) & used: @@ -1367,7 +1367,7 @@ def _shard_map_batch_post_process(trace, out_tracers, mesh, in_names, def todo(vals): trace = m.with_cur_sublevel() return map(partial(batching.BatchTracer, trace), vals, dims, srcs) - out_names_transform = partial(_batch_out_names, trace.spmd_axis_name, dims) + out_names_transform = partial(_batch_out_names, trace.axis_data.spmd_name, dims) return vals, (todo, out_names_transform) batching.BatchTrace.post_process_shard_map = _shard_map_batch_post_process