diff --git a/jax/_src/core.py b/jax/_src/core.py index 6f96dc760cc0..d9c9306d854c 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -955,6 +955,7 @@ def __eq__(self, other): @dataclass(frozen=True) class AxisEnv: axis_sizes : dict[AxisName, int] + spmd_axis_names : set[AxisName] def axis_size(self, axis_name): if axis_name not in self.axis_sizes: @@ -971,20 +972,24 @@ def axis_names(self): def pop_pure(self, axis_name): new_sizes = self.axis_sizes.copy() new_sizes.pop(axis_name) - return AxisEnv(new_sizes) + return AxisEnv(new_sizes, self.spmd_axis_names) def extend_pure(self, name_size_pairs): new_sizes = self.axis_sizes.copy() new_sizes.update((name, size) for name, size in name_size_pairs if name is not no_axis_name) - return AxisEnv(new_sizes) + return AxisEnv(new_sizes, self.spmd_axis_names) + + def add_spmd_axis_names(self, axis_names): + new_spmd_axis_names = self.spmd_axis_names | set(axis_names) + return AxisEnv(self.axis_sizes, new_spmd_axis_names) def as_hashable_key(self): return tuple((name, size) for (name, size) in self.axis_sizes.items() if name is not no_axis_name) eval_trace = EvalTrace() -top_axis_env = AxisEnv({}) +top_axis_env = AxisEnv({}, set()) class TracingContext(threading.local): trace: Trace | None @@ -1045,6 +1050,16 @@ def extend_axis_env_nd(name_size_pairs : Iterable[tuple[AxisName, int]]): finally: trace_ctx.set_axis_env(prev) +@contextmanager +def add_spmd_axis_names(axis_names: AxisName | None): + prev = trace_ctx.axis_env + try: + if axis_names is not None: + trace_ctx.set_axis_env(prev.add_spmd_axis_names(axis_names)) + yield + finally: + trace_ctx.set_axis_env(prev) + def get_axis_env(): return trace_ctx.axis_env diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 590e60383b90..0adb582a7993 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -596,9 +596,10 @@ def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals): idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0, source_info_util.current())) in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims) - with core.set_current_trace(trace): - with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]): - outs = yield in_tracers, {} + with (core.set_current_trace(trace), + core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), + core.add_spmd_axis_names(axis_data.spmd_name)): + outs = yield in_tracers, {} out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)), @@ -795,9 +796,10 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals): _, in_axes = resolve_ragged_axes(in_vals, in_axes) in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val for val, dim in zip(in_vals, in_axes)] - with core.set_current_trace(trace): - with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]): - outs = yield in_tracers, {} + with (core.set_current_trace(trace), + core.extend_axis_env_nd([(axis_data.name, axis_data.size)]), + core.add_spmd_axis_names(axis_data.spmd_name)): + outs = yield in_tracers, {} out_vals, out_axes = unzip2(map(trace.to_batch_info, outs)) new_out_axes = indirectify_ragged_axes_against_inputs_outputs( out_axes, in_vals, out_vals) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index c67b4f68cc9b..7ddd3805b5d0 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1506,7 +1506,7 @@ def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]: # We use a filtered-down version of unmentioned to avoid defensive-psum over # more chips than required in the transpose-no-check-rep case. name_set = {n for ns in names.values() for n in ns} - return [n for n in mesh.axis_names if n not in name_set] + return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set] def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names, @@ -1652,10 +1652,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged, # TODO(mattjj): remove this mechanism when we revise mesh scopes def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]: - trace = core.unsafe_get_current_trace() if trace is None else trace - stack = core.unsafe_get_trace_stack(trace) - batch_traces = [t for t in stack if isinstance(t, batching.BatchTrace)] - spmd_names = {n for trace in batch_traces for n in trace.axis_data.spmd_name } + spmd_names = core.get_axis_env().spmd_axis_names return tuple(name for name in mesh.axis_names if name not in spmd_names) # DCE