Skip to content

Commit

Permalink
Stop using generators for linear_util transformations.
Browse files (browse the repository at this point in the history)
They lead to confusing code, nasty bugs, and unhelpful (but terse!) stack traces.
  • Loading branch information
dougalm committed Nov 13, 2024
1 parent ed9fdbb commit 438c261
Show file tree
Hide file tree
Showing 14 changed files with 218 additions and 218 deletions.
48 changes: 26 additions & 22 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
return tuple(map(_ensure_str, x))

@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
def flatten_fun(f, store, in_tree, *args_flat):
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
ans = yield py_args, py_kwargs
yield tree_flatten(ans)
ans = f(*py_args, **py_kwargs)
ans, out_tree = tree_flatten(ans)
store.store(out_tree)
return ans

def apply_flat_fun(fun, io_tree, *py_args):
in_tree_expected, out_tree = io_tree
Expand All @@ -83,10 +85,12 @@ def apply_flat_fun(fun, io_tree, *py_args):
return tree_unflatten(out_tree, ans)

@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
def flatten_fun_nokwargs(f, store, in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = yield py_args, {}
yield tree_flatten(ans)
ans = f(*py_args)
ans, out_tree = tree_flatten(ans)
store.store(out_tree)
return ans

def apply_flat_fun_nokwargs(fun, io_tree, py_args):
in_tree_expected, out_tree = io_tree
Expand Down Expand Up @@ -119,16 +123,17 @@ def flattened_fun_in_tree(
return in_tree, lambda: out_tree_store.val, has_kwargs

@lu.transformation_with_aux
def flatten_fun_nokwargs2(in_tree, *args_flat):
def flatten_fun_nokwargs2(f, store, in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
pair = yield py_args, {}
pair = f(*py_args)
if not isinstance(pair, (list, tuple)) or len(pair) != 2:
raise TypeError("expected function with aux output to return a two-element "
f"tuple, but got type {type(pair)} with value {pair!r}")
ans, aux = pair
ans_flat, ans_tree = tree_flatten(ans)
aux_flat, aux_tree = tree_flatten(aux)
yield (ans_flat, aux_flat), (ans_tree, aux_tree)
store.store((ans_tree, aux_tree))
return ans_flat, aux_flat

class _HashableWithStrictTypeEquality:
"""Box object used when comparing static arguments as a jit key.
Expand Down Expand Up @@ -278,17 +283,15 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: tuple[int, ...],
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args

@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
def _argnums_partial(f, dyn_argnums, fixed_args, *dyn_args, **kwargs):
sentinel = object()
args = [sentinel] * (len(fixed_args) + len(dyn_args))
for i, arg in zip(dyn_argnums, dyn_args):
args[i] = arg
fixed_args_ = iter(fixed_args)
args = [next(fixed_args_).val if x is sentinel else x for x in args]
assert next(fixed_args_, sentinel) is sentinel
ans = yield args, kwargs
yield ans

return f(*args, **kwargs)

def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
kwargs: dict[str, Any]):
Expand All @@ -312,10 +315,9 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: tuple[str, ...],
return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs

@lu.transformation
def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
def _argnames_partial(f, fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs)
ans = yield args, kwargs
yield ans
return f(*args, **kwargs)


@lru_cache(maxsize=4096)
Expand Down Expand Up @@ -436,8 +438,8 @@ def flat_out_axes(
return f, HashableFunction(out_axes, closure=(tuple(leaves), treedef))

@lu.transformation_with_aux
def _flat_out_axes(leaves, treedef, *args, **kwargs):
ans = yield args, kwargs
def _flat_out_axes(f, store, leaves, treedef, *args, **kwargs):
ans = f(*args, **kwargs)
spec = tree_unflatten(treedef, leaves)
try:
spec_flat = tuple(broadcast_prefix(spec, ans, is_leaf=lambda x: x is None))
Expand All @@ -449,7 +451,8 @@ def _flat_out_axes(leaves, treedef, *args, **kwargs):
"that the `out_axes` argument to `pmap` is a pytree prefix of the "
"pmapped function's output.")
raise ValueError(msg) from None
yield ans, spec_flat
store.store(spec_flat)
return ans

def check_callable(fun):
# In Python 3.10+, the only thing stopping us from supporting staticmethods
Expand Down Expand Up @@ -684,10 +687,11 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
for path, l in generate_key_paths(x) if l is not static)

@lu.transformation_with_aux
def result_paths(*args, **kwargs):
def result_paths(f, store, *args, **kwargs):
"linear_util transform to get output pytree paths of pre-flattened function."
ans = yield args, kwargs
yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]
ans = f(*args, **kwargs)
store.store([keystr(path) for path, _ in generate_key_paths(ans)])
return ans

def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
result_paths: tuple[str, ...] | None = None,
Expand Down
15 changes: 9 additions & 6 deletions jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,10 +331,11 @@ def update_error(error, pred, code, metadata, payload, effect_type):
## Checkify transformation for plumbing functional error values.

@lu.transformation_with_aux
def _flatten_and_get_error_metadata_thunk(*invals):
error, out = yield invals, {}
def _flatten_and_get_error_metadata_thunk(f, store, *invals):
error, out = f(*invals)
out_vals, out_tree = jtu.tree_flatten((error, out))
yield out_vals, (out_tree, set(error._pred.keys()))
store.store((out_tree, set(error._pred.keys())))
return out_vals

def default_checkify_rule(primitive: core.Primitive, error: Error,
enabled_errors, *invals: core.Value,
Expand Down Expand Up @@ -439,9 +440,11 @@ def checkify_jaxpr_flat_hashable(jaxpr, hashable_consts, enabled_errors,
return checkify_jaxpr_flat(jaxpr, consts, enabled_errors, err_tree, *args)

@lu.transformation_with_aux
def flatten_fun_output(*args):
ans = yield args, {}
yield tree_flatten(ans)
def flatten_fun_output(f, store, *args):
ans = f(*args)
ans, out_tree = tree_flatten(ans)
store.store(out_tree)
return ans


def _reduce_any_error(error: Error):
Expand Down
35 changes: 19 additions & 16 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,13 @@ def _zeros_like_pytree(x):

# like the api_util.py function, but also grabs output avals for error checking
@lu.transformation_with_aux
def _flatten_fun_nokwargs(in_tree, *args_flat):
def _flatten_fun_nokwargs(f, store, in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = yield py_args, {}
ans = f(*py_args)
ans_flat, ans_tree = tree_flatten(ans)
ans_avals = [core.get_aval(x) for x in ans_flat]
yield ans_flat, (ans_tree, ans_avals)
store.store((ans_tree, ans_avals))
return ans_flat


### JVPs
Expand Down Expand Up @@ -267,17 +268,17 @@ def _add_args(f, extra_args):
return _add_args_(f, tuple(Unhashable(arg) for arg in extra_args))

@lu.transformation
def _add_args_(extra_args, *args, **kwargs):
def _add_args_(f, extra_args, *args, **kwargs):
extra_args = tuple(arg.val for arg in extra_args)
all_args = (extra_args + args)
yield (yield all_args, kwargs)
return f(*all_args, **kwargs)

@partial(lu.transformation_with_aux, use_eq_store=True)
def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
def _flatten_jvp(f, store, primal_name, jvp_name, in_tree, maybe_out_type, *args):
primals_in, tangents_in = split_list(args, [len(args) // 2])
py_primals = tree_unflatten(in_tree, primals_in)
py_tangents = tree_unflatten(in_tree, tangents_in)
pair_out = yield (py_primals, py_tangents), {}
pair_out = f(py_primals, py_tangents)
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = (f"Custom JVP rule {jvp_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) representing "
Expand Down Expand Up @@ -348,7 +349,8 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
if av_et != av_t)

raise TypeError(msg.format('\n'.join(disagreements)))
yield primals_out + tangents_out, (out_tree, primal_avals)
store.store((out_tree, primal_avals))
return primals_out + tangents_out

class CustomJVPCallPrimitive(core.Primitive):
multiple_results = True
Expand Down Expand Up @@ -653,14 +655,14 @@ def _check_for_tracers(x):
raise UnexpectedTracerError(msg)

@partial(lu.transformation_with_aux, use_eq_store=True)
def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
*args):
if symbolic_zeros:
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
else:
args = args[::2]
py_args = tree_unflatten(in_tree, args)
pair_out = yield py_args, {}
pair_out = f(*py_args)
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) where the first "
Expand Down Expand Up @@ -710,16 +712,17 @@ def _flatten_fwd(symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
"shapes/dtypes of:\n"
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
yield (*res, *primals_out), (out_tree, res_tree)
store.store((out_tree, res_tree))
return (*res, *primals_out)

@lu.transformation
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
def _flatten_bwd(f, in_tree, in_avals, out_trees, *args):
out_tree, res_tree = out_trees()
assert len(args) == res_tree.num_leaves + out_tree.num_leaves
res, cts_out = split_list(args, [res_tree.num_leaves])
py_res = tree_unflatten(res_tree, res)
py_cts_out = tree_unflatten(out_tree, cts_out)
py_cts_in = yield (py_res, py_cts_out), {}
py_cts_in = f(py_res, py_cts_out)
if isinstance(py_cts_in, list) and len(py_cts_in) == len(treedef_children(in_tree)):
py_cts_in = tuple(py_cts_in)
# For each None in py_cts_in, indicating an argument for which the rule
Expand Down Expand Up @@ -775,7 +778,7 @@ def append(x, d):
f"to an input of shape/dtype {a.str_short()}.")
raise ValueError(msg)
results.append(ct)
yield results
return results

# TODO(mattjj): remove both these exceptions to cotangent compatibility check
def _temporary_dtype_exception(a, a_) -> bool:
Expand Down Expand Up @@ -1426,10 +1429,10 @@ def fun_jaxpr_thunk():
return wrapped_fwd

@lu.transformation
def _fix_fwd_args(*args):
def _fix_fwd_args(f, *args):
args = [(x, True) for x in args]
args = [x for pair in args for x in pair]
yield (yield args, {})
return f(*args)

def _remat_opt_impl(
*args,
Expand Down
47 changes: 26 additions & 21 deletions jax/_src/interpreters/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,41 +69,42 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True,
return jvpfun(fun, instantiate, transform_stack), aux

@lu.transformation
def jvpfun(instantiate, transform_stack, primals, tangents):
def jvpfun(f, instantiate, transform_stack, primals, tangents):
tag = core.TraceTag()
tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero)
and dtype(t) == float0 else t for t in tangents]
ctx = (source_info_util.transform_name_stack('jvp') if transform_stack
else contextlib.nullcontext())
with ctx:
out_primals, out_tangents = yield (tag, primals, tangents), {}
out_primals, out_tangents = f(tag, primals, tangents)
if type(instantiate) is bool:
instantiate = [instantiate] * len(out_tangents)
out_tangents = [instantiate_zeros(t) if inst else t for t, inst
in zip(out_tangents, instantiate)]
yield out_primals, out_tangents
return out_primals, out_tangents

@lu.transformation
def jvp_subtrace(tag, primals, tangents):
def jvp_subtrace(f, tag, primals, tangents):
with core.take_current_trace() as parent_trace:
trace = JVPTrace(parent_trace, tag)
in_tracers = [maybe_jvp_tracer(trace, x, t)
for x, t in zip(primals, tangents)]
with core.set_current_trace(trace):
ans = yield in_tracers, {}
ans = f(*in_tracers)
out = unzip2(map(trace.to_primal_tangent_pair, ans))
yield out
return out

@lu.transformation_with_aux
def jvp_subtrace_aux(tag, primals, tangents):
def jvp_subtrace_aux(f, store, tag, primals, tangents):
with core.take_current_trace() as parent_trace:
trace = JVPTrace(parent_trace, tag)
with core.set_current_trace(trace):
ans, aux = yield map(partial(maybe_jvp_tracer, trace), primals, tangents), {}
ans, aux = f(*(map(partial(maybe_jvp_tracer, trace), primals, tangents)))
out_primals, out_tangents = unzip2(map(trace.to_primal_tangent_pair, ans))
aux_primals = [x.primal if isinstance(x, JVPTracer) and x._trace.tag is tag
else x for x in aux]
yield (out_primals, out_tangents), aux_primals
store.store(aux_primals)
return out_primals, out_tangents

def linearize(traceable, *primals, **kwargs):
has_aux = kwargs.pop('has_aux', False)
Expand Down Expand Up @@ -263,9 +264,10 @@ def get_primitive_transpose(p):
"not implemented".format(p)) from err

@lu.transformation_with_aux
def nonzero_tangent_outputs(*args, **kwargs):
results = (_, tangents_out) = yield args, kwargs
yield results, [type(r) is not Zero for r in tangents_out]
def nonzero_tangent_outputs(f, store, *args, **kwargs):
results = (_, tangents_out) = f(*args, **kwargs)
store.store([type(r) is not Zero for r in tangents_out])
return results


class JVPTrace(Trace):
Expand Down Expand Up @@ -544,14 +546,15 @@ def instantiate_zeros(tangent):
return zeros_like_aval(tangent.aval) if type(tangent) is Zero else tangent

@lu.transformation_with_aux
def traceable(in_tree, *primals_and_tangents):
def traceable(f, store, in_tree, *primals_and_tangents):
primals, tangents = tree_unflatten(in_tree, primals_and_tangents)
tangents = [Zero.from_primal_value(p) if t is None else t
for p, t in zip(primals, tangents)]
primals_out, tangents_out = yield (primals, tangents), {}
primals_out, tangents_out = f(primals, tangents)
tangents_out = [None if type(t) is Zero else t for t in tangents_out]
out_flat, out_tree = tree_flatten((primals_out, tangents_out))
yield out_flat, out_tree
store.store(out_tree)
return out_flat


def call_transpose(primitive, params, call_jaxpr, args, ct, _):
Expand Down Expand Up @@ -589,9 +592,10 @@ def _closed_call_transpose(params, jaxpr, args, ct, cts_in_avals):


@lu.transformation_with_aux
def nonzero_outputs(*args, **kwargs):
results = yield args, kwargs
yield results, [type(r) is not Zero for r in results]
def nonzero_outputs(f, store, *args, **kwargs):
results = f(*args, **kwargs)
store.store([type(r) is not Zero for r in results])
return results

def map_transpose(primitive, params, call_jaxpr, args, ct, _):
all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
Expand Down Expand Up @@ -656,16 +660,17 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate):
return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()

@lu.transformation_with_aux
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
def f_jvp_traceable(f, store, nonzeros, *primals_and_nztangents):
num_primals = len(nonzeros)
primals = list(primals_and_nztangents[:num_primals])
nonzero_tangents = iter(primals_and_nztangents[num_primals:])
tangents = [next(nonzero_tangents) if nz else Zero.from_primal_value(p)
for p, nz in zip(primals, nonzeros)]
primals_out, tangents_out = yield (primals, tangents), {}
primals_out, tangents_out = f(primals, tangents)
out_nonzeros = [type(t) is not Zero for t in tangents_out]
nonzero_tangents_out = [t for t in tangents_out if type(t) is not Zero]
yield list(primals_out) + nonzero_tangents_out, out_nonzeros
store.store(out_nonzeros)
return list(primals_out) + nonzero_tangents_out

def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
Expand Down
Loading

0 comments on commit 438c261

Please sign in to comment.