Skip to content

Commit

Permalink
Put the set of current spmd axis names in the axis env instead of spelunking through the trace stack to find it.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 694603535
  • Loading branch information
dougalm authored and Google-ML-Automation committed Nov 8, 2024
1 parent c8f5b2b commit 4c69973
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 14 deletions.
21 changes: 18 additions & 3 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,7 @@ def __eq__(self, other):
@dataclass(frozen=True)
class AxisEnv:
axis_sizes : dict[AxisName, int]
spmd_axis_names : set[AxisName]

def axis_size(self, axis_name):
if axis_name not in self.axis_sizes:
Expand All @@ -971,20 +972,24 @@ def axis_names(self):
def pop_pure(self, axis_name):
  """Return a new AxisEnv with `axis_name` removed from the axis sizes.

  `self` is not modified: the size mapping is copied before the entry is
  removed, and the set of spmd axis names is carried over as-is.

  Raises:
    KeyError: if `axis_name` is not a key of `self.axis_sizes`.
  """
  new_sizes = self.axis_sizes.copy()
  new_sizes.pop(axis_name)
  # AxisEnv has two fields without defaults, so spmd_axis_names must be
  # forwarded explicitly when rebuilding the env.
  return AxisEnv(new_sizes, self.spmd_axis_names)

def extend_pure(self, name_size_pairs):
  """Return a new AxisEnv extended with the given (name, size) pairs.

  `self` is not modified: the size mapping is copied before it is updated.
  Pairs whose name is the `no_axis_name` sentinel are ignored. The set of
  spmd axis names is carried over as-is.

  Args:
    name_size_pairs: iterable of (axis name, axis size) pairs.
  """
  new_sizes = self.axis_sizes.copy()
  new_sizes.update((name, size) for name, size in name_size_pairs
                   if name is not no_axis_name)
  # AxisEnv has two fields without defaults, so spmd_axis_names must be
  # forwarded explicitly when rebuilding the env.
  return AxisEnv(new_sizes, self.spmd_axis_names)

def add_spmd_axis_names(self, axis_names):
  """Return a new AxisEnv whose spmd axis names also include `axis_names`.

  `axis_names` is consumed as an iterable of names (a set is built from it).
  The axis sizes mapping is passed through to the new env unchanged.
  """
  return AxisEnv(self.axis_sizes, self.spmd_axis_names | set(axis_names))

def as_hashable_key(self):
  """Return the env's (name, size) pairs as a tuple.

  Entries keyed by the `no_axis_name` sentinel are left out. Only
  `axis_sizes` contributes to the key; `spmd_axis_names` is not included.
  """
  named_sizes = [
      (axis, extent) for axis, extent in self.axis_sizes.items()
      if axis is not no_axis_name
  ]
  return tuple(named_sizes)

eval_trace = EvalTrace()
# AxisEnv requires both axis_sizes and spmd_axis_names; this env starts with
# no named axes and no spmd axis names.
top_axis_env = AxisEnv({}, set())

class TracingContext(threading.local):
trace: Trace | None
Expand Down Expand Up @@ -1045,6 +1050,16 @@ def extend_axis_env_nd(name_size_pairs : Iterable[tuple[AxisName, int]]):
finally:
trace_ctx.set_axis_env(prev)

@contextmanager
def add_spmd_axis_names(axis_names: Iterable[AxisName] | None):
  """Context manager that adds `axis_names` to the current spmd axis names.

  Args:
    axis_names: an iterable of axis names to add to the current axis env's
      spmd axis names (the env builds a set from it), or None to leave the
      axis env unchanged for the duration of the context.

  On exit, the axis env that was current on entry is restored, including
  when the body raises.
  """
  prev = trace_ctx.axis_env
  try:
    if axis_names is not None:
      trace_ctx.set_axis_env(prev.add_spmd_axis_names(axis_names))
    yield
  finally:
    # Always put back the env captured on entry.
    trace_ctx.set_axis_env(prev)

def get_axis_env():
  """Return the axis env currently stored on `trace_ctx`."""
  current_env = trace_ctx.axis_env
  return current_env

Expand Down
14 changes: 8 additions & 6 deletions jax/_src/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,9 +596,10 @@ def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals):
idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0,
source_info_util.current()))
in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
with core.set_current_trace(trace):
with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]):
outs = yield in_tracers, {}
with (core.set_current_trace(trace),
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
core.add_spmd_axis_names(axis_data.spmd_name)):
outs = yield in_tracers, {}

out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)),
Expand Down Expand Up @@ -795,9 +796,10 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals):
_, in_axes = resolve_ragged_axes(in_vals, in_axes)
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
for val, dim in zip(in_vals, in_axes)]
with core.set_current_trace(trace):
with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]):
outs = yield in_tracers, {}
with (core.set_current_trace(trace),
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
core.add_spmd_axis_names(axis_data.spmd_name)):
outs = yield in_tracers, {}
out_vals, out_axes = unzip2(map(trace.to_batch_info, outs))
new_out_axes = indirectify_ragged_axes_against_inputs_outputs(
out_axes, in_vals, out_vals)
Expand Down
7 changes: 2 additions & 5 deletions jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1506,7 +1506,7 @@ def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]:
# We use a filtered-down version of unmentioned to avoid defensive-psum over
# more chips than required in the transpose-no-check-rep case.
name_set = {n for ns in names.values() for n in ns}
return [n for n in mesh.axis_names if n not in name_set]
return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set]


def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
Expand Down Expand Up @@ -1652,10 +1652,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,

# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]:
  """Return the axis names of `mesh` that are not current spmd axis names.

  The spmd axis names are read from the current axis env.

  Args:
    mesh: mesh whose `axis_names` are filtered.
    trace: unused; retained so callers that still pass it keep working.

  Returns:
    A tuple of the names in `mesh.axis_names`, in their original order, that
    are not in the axis env's `spmd_axis_names`.
  """
  spmd_names = core.get_axis_env().spmd_axis_names
  return tuple(name for name in mesh.axis_names if name not in spmd_names)

# DCE
Expand Down

0 comments on commit 4c69973

Please sign in to comment.