Skip to content

Commit

Permalink
Make an AxisData struct that bundles axis name, size, and spmd name.
Browse files (browse the repository at this point in the history)
This is just a small cleanup as prep work for stackless. It means fewer
arguments to batching functions and less room for argument-order mistakes.

PiperOrigin-RevId: 676833615
  • Loading branch information
dougalm authored and Google-ML-Automation committed Sep 20, 2024
1 parent a533635 commit 48ca887
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 205 deletions.
9 changes: 4 additions & 5 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,19 +701,18 @@ def transposed(*args_flat):
transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
return transposed_jaxpr, cell.in_cts_zero # pytype: disable=attribute-error

- def remat_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *,
- jaxpr, **params):
+ def remat_vmap(axis_data, main_type, args, dims, *, jaxpr, **params):
assert not jaxpr.constvars
jaxpr_batched_, out_batched = batching.batch_jaxpr_axes(
- pe.close_jaxpr(jaxpr), axis_size, dims,
+ pe.close_jaxpr(jaxpr), axis_data, dims,
[batching.zero_if_mapped] * len(jaxpr.outvars),
- axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
+ main_type=main_type)
jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
if consts:
jaxpr_batched = pe.convert_constvars_jaxpr(jaxpr_batched)
out_dims = [0 if b else None for b in out_batched]
return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
- batching.axis_primitive_batchers[remat_p] = partial(remat_vmap, None)
+ batching.axis_primitive_batchers[remat_p] = remat_vmap
batching.spmd_axis_primitive_batchers[remat_p] = remat_vmap

# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,10 +983,10 @@ def vmap_f(*args, **kwargs):
axis_size_ = (axis_size if axis_size is not None else
_mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap"))
try:
+ axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name)
out_flat = batching.batch(
- flat_fun, axis_name, axis_size_, in_axes_flat,
- lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
- spmd_axis_name=spmd_axis_name
+ flat_fun, axis_data, in_axes_flat,
+ lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
).call_wrapped(*args_flat)
except batching.SpecMatchError as e:
out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/custom_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def maybe_bdim_at_front(x, bdim):
# `f` is pytree-flattened
def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size):
f, out_axes = batching.batch_subtrace(f)
- f = batching._batch_outer(f, axis_name, axis_size, in_axes,
- batching.BatchTrace, None)
+ axis_data = batching.AxisData(axis_name, axis_size, None)
+ f = batching._batch_outer(f, axis_data, in_axes, batching.BatchTrace)
outs = f.call_wrapped(*args)
return outs, out_axes()

Expand Down
26 changes: 10 additions & 16 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,35 +921,31 @@ def _custom_vjp_call_jaxpr_jvp(
ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp

def _custom_vjp_call_jaxpr_vmap(
- spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *,
+ axis_data, main_type, args, in_dims, *,
fun_jaxpr: core.ClosedJaxpr,
fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
else x for x, d in zip(args, in_dims)]

in_batched = [d is not not_mapped for d in in_dims]
_, args_batched = split_list(in_batched, [num_consts])
batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
- fun_jaxpr, axis_size, in_batched, False, axis_name, spmd_axis_name,
- main_type)
+ fun_jaxpr, axis_data, in_batched, False, main_type)
out_dims1 = [0 if b else not_mapped for b in out_batched]
out_dims2 = []

@pe._memoize
def batched_fwd_jaxpr_thunk(*zeros):
fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros)) # consts can be tracers
batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
- fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name,
- main_type)
+ fwd_jaxpr, axis_data, args_batched, False, main_type)
out_dims2.append([0 if b else not_mapped for b in out_batched])
return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts

fwd_args_batched = [0 if b else not_mapped for b in args_batched]
fwd_out_dims = lambda: out_dims2[0]
batched_bwd = batching.batch_custom_vjp_bwd(
- bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type,
- spmd_axis_name)
+ bwd, axis_data, fwd_out_dims, fwd_args_batched, main_type)

batched_outs = custom_vjp_call_jaxpr_p.bind(
*args, fun_jaxpr=batched_fun_jaxpr,
Expand All @@ -959,8 +955,8 @@ def batched_fwd_jaxpr_thunk(*zeros):
return batched_outs, out_dims
batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \
_custom_vjp_call_jaxpr_vmap
- batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(
- _custom_vjp_call_jaxpr_vmap, None)
+ batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \
+ _custom_vjp_call_jaxpr_vmap

xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)

Expand Down Expand Up @@ -1532,7 +1528,7 @@ def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_):
return fwd_jaxpr.out_avals, fwd_jaxpr.effects

def _remat_opt_vmap(
- spmd_axis_name, axis_size, axis_name, main_type, args, in_dims,
+ axis_data, main_type, args, in_dims,
*,
num_consts: int,
num_res: int,
Expand All @@ -1544,8 +1540,7 @@ def _remat_opt_vmap(

in_batched = [d is not not_mapped for d in in_dims]
batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
- fwd_jaxpr, axis_size, in_batched, False,
- axis_name, spmd_axis_name, main_type)
+ fwd_jaxpr, axis_data, in_batched, False, main_type)
extra_consts = batched_fwd_jaxpr.consts
batched_fwd_jaxpr = pe.close_jaxpr(
pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr))
Expand All @@ -1557,8 +1552,7 @@ def _remat_opt_vmap(
def batched_fun_jaxpr_thunk():
fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
- fun_jaxpr, axis_size, prim_batched, False, axis_name, spmd_axis_name,
- main_type)
+ fun_jaxpr, axis_data, prim_batched, False, main_type)
return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts

batched_outs = remat_opt_p.bind(*extra_consts, *args,
Expand Down Expand Up @@ -1667,7 +1661,7 @@ def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn):
mlir.register_lowering(remat_opt_p, mlir.lower_fun(
_remat_opt_impl, multiple_results=True))
batching.spmd_axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap
- batching.axis_primitive_batchers[remat_opt_p] = partial(_remat_opt_vmap, None)
+ batching.axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap
ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp
ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose
pe.dce_rules[remat_opt_p] = _remat_opt_dce
Loading

0 comments on commit 48ca887

Please sign in to comment.