Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Put the set of current spmd axis names in the axis env instead of spelunking #24803

Merged
merged 1 commit into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,7 @@ def __eq__(self, other):
@dataclass(frozen=True)
class AxisEnv:
axis_sizes : dict[AxisName, int]
spmd_axis_names : set[AxisName]

def axis_size(self, axis_name):
if axis_name not in self.axis_sizes:
Expand All @@ -971,20 +972,24 @@ def axis_names(self):
def pop_pure(self, axis_name):
  """Return a new AxisEnv with `axis_name` removed; `self` is not modified.

  The axis-size mapping is copied before the removal. The set of spmd axis
  names is carried over unchanged.

  Raises:
    KeyError: if `axis_name` is not in `self.axis_sizes` (dict.pop is called
      without a default).
  """
  new_sizes = self.axis_sizes.copy()
  new_sizes.pop(axis_name)
  # AxisEnv has two required fields, so the spmd axis names must be passed
  # through explicitly.
  return AxisEnv(new_sizes, self.spmd_axis_names)

def extend_pure(self, name_size_pairs):
  """Return a new AxisEnv extended with `name_size_pairs`; `self` is not modified.

  Args:
    name_size_pairs: iterable of `(name, size)` pairs. Pairs whose name is the
      `no_axis_name` sentinel are skipped. Later pairs override earlier
      entries with the same name (dict.update semantics).

  Returns:
    An AxisEnv with the updated sizes and the same spmd axis names as `self`.
  """
  new_sizes = self.axis_sizes.copy()
  new_sizes.update((name, size) for name, size in name_size_pairs
                   if name is not no_axis_name)
  # AxisEnv has two required fields, so the spmd axis names must be passed
  # through explicitly.
  return AxisEnv(new_sizes, self.spmd_axis_names)

def add_spmd_axis_names(self, axis_names):
  """Return a new AxisEnv whose spmd axis names also include `axis_names`.

  `axis_names` is consumed as an iterable of names. The axis-size mapping is
  reused as-is (not copied), and `self` is left unchanged.
  """
  merged_names = self.spmd_axis_names.union(axis_names)
  return AxisEnv(self.axis_sizes, merged_names)

def as_hashable_key(self):
  """Return the axis sizes as a tuple of `(name, size)` pairs.

  Entries named by the `no_axis_name` sentinel are left out. The spmd axis
  names are not part of the key.
  """
  entries = []
  for axis, extent in self.axis_sizes.items():
    if axis is no_axis_name:
      continue
    entries.append((axis, extent))
  return tuple(entries)

eval_trace = EvalTrace()
# The outermost axis environment: no named axes and no spmd axis names.
# AxisEnv requires both fields, so the empty set is passed explicitly.
top_axis_env = AxisEnv({}, set())

class TracingContext(threading.local):
trace: Trace | None
Expand Down Expand Up @@ -1045,6 +1050,16 @@ def extend_axis_env_nd(name_size_pairs : Iterable[tuple[AxisName, int]]):
finally:
trace_ctx.set_axis_env(prev)

@contextmanager
def add_spmd_axis_names(axis_names: Iterable[AxisName] | None):
  """Context manager that adds `axis_names` to the current env's spmd axis names.

  Args:
    axis_names: an iterable of axis names to add (it is passed to `set(...)`
      by `AxisEnv.add_spmd_axis_names`, so a bare string would be split into
      characters), or None to leave the axis env untouched.

  On exit, including when the body raises, the axis env that was current on
  entry is reinstalled.
  """
  prev = trace_ctx.axis_env
  try:
    if axis_names is not None:
      trace_ctx.set_axis_env(prev.add_spmd_axis_names(axis_names))
    yield
  finally:
    # Restore unconditionally; this is a no-op change when nothing was added.
    trace_ctx.set_axis_env(prev)

def get_axis_env():
  """Return the axis environment currently stored on `trace_ctx`."""
  current_env = trace_ctx.axis_env
  return current_env

Expand Down
14 changes: 8 additions & 6 deletions jax/_src/interpreters/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,9 +596,10 @@ def _batch_inner(axis_data, out_dim_dests, tag, in_dims, *in_vals):
idx = memoize(lambda: BatchTracer(trace, make_iota(axis_data.size), 0,
source_info_util.current()))
in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
with core.set_current_trace(trace):
with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]):
outs = yield in_tracers, {}
with (core.set_current_trace(trace),
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
core.add_spmd_axis_names(axis_data.spmd_name)):
outs = yield in_tracers, {}

out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)),
Expand Down Expand Up @@ -795,9 +796,10 @@ def _batch_jaxpr_inner(axis_data, tag, in_axes, *in_vals):
_, in_axes = resolve_ragged_axes(in_vals, in_axes)
in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
for val, dim in zip(in_vals, in_axes)]
with core.set_current_trace(trace):
with core.extend_axis_env_nd([(axis_data.name, axis_data.size)]):
outs = yield in_tracers, {}
with (core.set_current_trace(trace),
core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
core.add_spmd_axis_names(axis_data.spmd_name)):
outs = yield in_tracers, {}
out_vals, out_axes = unzip2(map(trace.to_batch_info, outs))
new_out_axes = indirectify_ragged_axes_against_inputs_outputs(
out_axes, in_vals, out_vals)
Expand Down
7 changes: 2 additions & 5 deletions jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1506,7 +1506,7 @@ def _unmentioned2(mesh: Mesh, names: AxisNames) -> list[AxisName]:
# We use a filtered-down version of unmentioned to avoid defensive-psum over
# more chips than required in the transpose-no-check-rep case.
name_set = {n for ns in names.values() for n in ns}
return [n for n in mesh.axis_names if n not in name_set]
return [n for n in _all_mesh_names_except_spmd(mesh) if n not in name_set]


def _shard_map_transpose(out_cts, *args, jaxpr, mesh, in_names, out_names,
Expand Down Expand Up @@ -1652,10 +1652,7 @@ def _pe_custom_params(unks_in, inst_in, kept_outs_known, kept_outs_staged,

# TODO(mattjj): remove this mechanism when we revise mesh scopes
def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]:
  """Return the mesh's axis names, in mesh order, minus the current spmd axis names.

  The spmd axis names are read from the current axis env instead of being
  collected by walking the trace stack.

  Args:
    mesh: the mesh whose `axis_names` are filtered.
    trace: unused; kept so existing callers that pass a trace keep working.
  """
  spmd_names = core.get_axis_env().spmd_axis_names
  return tuple(name for name in mesh.axis_names if name not in spmd_names)

# DCE
Expand Down
Loading