Skip to content

Commit

Permalink
Change to internal dead code elimination. Now the functions in `dce_r…
Browse files Browse the repository at this point in the history
…ules` are responsible for checking if the equation has no used outputs or effects, and behaving appropriately in that case (which usually means eliminating said equation).

PiperOrigin-RevId: 695538316
  • Loading branch information
james-martens authored and Google-ML-Automation committed Nov 12, 2024
1 parent 3a5ac48 commit 57ce6c4
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 16 deletions.
2 changes: 2 additions & 0 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,8 @@ def remat_vmap(axis_data, args, dims, *, jaxpr, **params):
# TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
def remat_dce(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn | None]:
if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
new_params = dict(eqn.params, jaxpr=new_jaxpr)
if (not any(used_inputs) and not any(used_outputs) and
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,8 @@ def _remat_opt_transpose(
"remat optimization for custom_vjp does not support higher-order AD")

def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn):
if not any(used_outs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
used_res, used_prims = split_list(used_outs, [eqn.params["num_res"]])
outvars = [v for used, v in zip(used_outs, eqn.outvars) if used]
if any(used_res):
Expand Down
33 changes: 19 additions & 14 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def const_out_axes_thunk():
staged_out_axes, _ = partition_list(out_knowns, out_axes)
staged_in_axes = (0,) * len(res) + (None,) * len(env) + (*unk_in_axes,)

# Create the input tracers for the staged-out (unkonwn-value) call.
# Create the input tracers for the staged-out (unknown-value) call.
const_tracers = map(self.new_instantiated_const, res)
env_tracers = map(self.to_jaxpr_tracer, env)
unknown_arg_tracers = [t for t in tracers if not t.is_known()]
Expand Down Expand Up @@ -1382,6 +1382,11 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
return new_jaxpr, used_consts, used_inputs


def has_effects(eqn: JaxprEqn) -> bool:
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
return bool(effs)


@weakref_lru_cache
def _dce_jaxpr(jaxpr: Jaxpr, used_outputs: tuple[bool, ...],
instantiate: tuple[bool, ...]
Expand All @@ -1395,21 +1400,14 @@ def write(x: Atom, b: bool) -> None:
if type(x) is Var:
env[x] = read(x) or b

def has_effects(eqn: JaxprEqn) -> bool:
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
return bool(effs)

new_eqns = []
map(write, jaxpr.outvars, used_outputs)
for eqn in jaxpr.eqns[::-1]:
used_outs = map(read, eqn.outvars)
if not any(used_outs) and not has_effects(eqn):
used_ins = [False] * len(eqn.invars)
else:
rule = dce_rules.get(eqn.primitive, _default_dce_rule)
used_ins, new_eqn = rule(used_outs, eqn)
if new_eqn is not None:
new_eqns.append(new_eqn)
rule = dce_rules.get(eqn.primitive, _default_dce_rule)
used_ins, new_eqn = rule(used_outs, eqn)
if new_eqn is not None:
new_eqns.append(new_eqn)
map(write, eqn.invars, used_ins)
used_inputs = map(read, jaxpr.invars)
used_inputs = map(op.or_, instantiate, used_inputs)
Expand All @@ -1433,14 +1431,18 @@ def has_effects(eqn: JaxprEqn) -> bool:

def _default_dce_rule(
used_outs: list[bool], eqn: JaxprEqn
) -> tuple[list[bool], JaxprEqn]:
) -> tuple[list[bool], JaxprEqn | None]:
if not any(used_outs) and not has_effects(eqn):
return [False] * len(eqn.invars), None
return [True] * len(eqn.invars), eqn

dce_rules: dict[Primitive, DCERule] = {}


def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn
) -> tuple[list[bool], JaxprEqn | None]:
if not any(used_outputs) and not has_effects(eqn):
return [False] * len(eqn.invars), None
new_jaxpr, used_inputs = dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
new_params = dict(eqn.params, call_jaxpr=new_jaxpr)
update_params = call_param_updaters.get(eqn.primitive)
Expand All @@ -1454,6 +1456,7 @@ def dce_jaxpr_call_rule(used_outputs: list[bool], eqn: JaxprEqn
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_jaxpr.effects, eqn.source_info, eqn.ctx)
return used_inputs, new_eqn

dce_rules[core.call_p] = dce_jaxpr_call_rule


Expand All @@ -1465,8 +1468,10 @@ def _cached_closed_call_dce(jaxpr_, used_outputs: tuple[bool, ...]
return core.ClosedJaxpr(new_jaxpr, consts), used_inputs

def dce_jaxpr_closed_call_rule(used_outputs: list[bool], eqn: JaxprEqn
) -> tuple[list[bool], JaxprEqn]:
) -> tuple[list[bool], JaxprEqn | None]:
# TODO(mattjj): de-duplicate with above rule?
if not any(used_outputs) and not has_effects(eqn):
return [False] * len(eqn.invars), None
jaxpr_ = eqn.params['call_jaxpr']
closed_jaxpr, used_inputs = _cached_closed_call_dce(jaxpr_, tuple(used_outputs))
new_params = dict(eqn.params, call_jaxpr=closed_jaxpr)
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,6 +1353,8 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval):

def _pmap_dce_rule(used_outputs, eqn):
# just like pe.dce_jaxpr_call_rule, except handles in_axes / out_axes
if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
axis_name = eqn.params["axis_name"]
with core.extend_axis_env_nd([(axis_name, eqn.params["global_axis_size"])]):
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
Expand Down
6 changes: 5 additions & 1 deletion jax/_src/lax/control_flow/conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,11 @@ def _ordered_unique(xs):
return list(d.keys())

def _cond_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn,
) -> tuple[list[bool], core.JaxprEqn]:
) -> tuple[list[bool], core.JaxprEqn | None]:

if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None

closed_branches = eqn.params['branches']
branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches]

Expand Down
4 changes: 3 additions & 1 deletion jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,9 @@ def _scan_padding_rule(in_avals, out_avals, *args, jaxpr, **params):
return scan_p.bind(*args, jaxpr=_cached_scan_pad_jaxpr(jaxpr), **params)

def _scan_dce_rule(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn]:
) -> tuple[list[bool], core.JaxprEqn | None]:
if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
jaxpr = eqn.params['jaxpr']
num_consts, num_carry = eqn.params['num_consts'], eqn.params['num_carry']
num_xs = len(jaxpr.in_avals) - num_consts - num_carry
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2326,6 +2326,10 @@ def _dce_jaxpr_pjit(

def dce_jaxpr_pjit_rule(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn | None]:

if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None

dced_jaxpr, used_inputs = _dce_jaxpr_pjit(
eqn.params['jaxpr'], tuple(used_outputs))

Expand Down
2 changes: 2 additions & 0 deletions jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,6 +1660,8 @@ def _all_mesh_names_except_spmd(mesh: Mesh, trace=None) -> tuple[AxisName, ...]:
# TODO(mattjj): de-duplicate with pe.dce_jaxpr_call_rule, and/or _pmap_dce_rule?
def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn
) -> tuple[list[bool], core.JaxprEqn | None]:
if not any(used_outputs) and not pe.has_effects(eqn):
return [False] * len(eqn.invars), None
mesh = eqn.params["mesh"]
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['jaxpr'], used_outputs)
Expand Down
1 change: 1 addition & 0 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
debug_info_final as debug_info_final,
def_trivial_padding as def_trivial_padding,
forwarding_rules as forwarding_rules,
has_effects as has_effects,
infer_lambda_input_type as infer_lambda_input_type,
instantiate_const_at as instantiate_const_at,
make_jaxpr_effects as make_jaxpr_effects,
Expand Down

0 comments on commit 57ce6c4

Please sign in to comment.