Improvements in the handling of tracing debug info
DO_NOT_SUBMIT WIP
gnecula committed Nov 11, 2024
1 parent 7491fdd commit c6d5b41
Showing 14 changed files with 166 additions and 90 deletions.
4 changes: 3 additions & 1 deletion jax/_src/ad_checkpoint.py
@@ -25,6 +25,7 @@

from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
@@ -404,7 +405,8 @@ def new_fun(*dyn_args, **kwargs):
# remat-specific errors.
@weakref_lru_cache
def _trace_to_jaxpr(fun, in_tree, in_avals):
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun, traced_for="checkpoint"),
in_tree)
debug = pe.debug_info(fun, in_tree, out_tree, True, "checkpoint")
try:
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
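A minimal sketch of the call-site pattern the hunk above moves toward (illustration only, not part of the commit): the function is wrapped with its tracing context attached up front instead of being patched on afterwards. Both keyword arguments used below are assumed from the call sites visible in this diff; the linear_util.py change itself is not among the loaded hunks.

# Sketch only, not commit code. Assumes lu.wrap_init accepts the
# traced_for= and debug_info= keywords used by the call sites in this diff.
from jax._src import api_util
from jax._src import linear_util as lu

def f(x):
  return x * 2.0

wrapped = lu.wrap_init(f, traced_for="checkpoint")       # as in the hunk above
dbg = api_util.debug_info(f, "checkpoint", (1.0,), {})   # reworked helper, see api_util.py below
wrapped_with_dbg = lu.wrap_init(f, debug_info=dbg)       # as in api.py / mlir.py below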
20 changes: 9 additions & 11 deletions jax/_src/api.py
@@ -62,7 +62,7 @@
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, apply_flat_fun_nokwargs, check_callable, debug_info,
result_paths, flat_out_axes, debug_info_final, fun_sourceinfo)
flat_out_axes, fun_sourceinfo)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
@@ -1392,19 +1392,17 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
return global_axis_size


def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
def _prepare_pmap(fun: Callable,
in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, in_devices, backend_name,
axis_size, args, kwargs):
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")

src = fun_sourceinfo(fun)
signature = api_util.fun_signature(fun)
dbg = debug_info(fun, 'pmap', args, kwargs,
static_argnums=static_broadcasted_tuple)

dbg = debug_info('pmap', src, signature, args, kwargs,
static_broadcasted_tuple, ())

f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=dbg)
if static_broadcasted_tuple:
if max(static_broadcasted_tuple) >= len(args):
raise ValueError(
@@ -1451,10 +1449,10 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
raise ValueError(msg) from None
local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")

f, res_paths = result_paths(f)
# f, res_paths = result_paths(f)
f, out_axes_thunk = flat_out_axes(f, out_axes)
flat_fun, out_tree = flatten_fun(f, in_tree)
flat_fun = debug_info_final(flat_fun, dbg, res_paths)
# flat_fun = debug_info_final(flat_fun, dbg, res_paths)

is_explicit_global_axis_size = axis_size is not None
global_axis_size = _get_global_axis_size(local_axis_size, in_devices,
@@ -1957,7 +1955,7 @@ def vjp(
del reduce_axes
check_callable(fun)
return _vjp(
lu.wrap_init(fun), *primals, has_aux=has_aux)
lu.wrap_init(fun, traced_for="vjp"), *primals, has_aux=has_aux)

def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
"""Variant of vjp() that takes an lu.WrappedFun."""
93 changes: 69 additions & 24 deletions jax/_src/api_util.py
@@ -68,12 +68,51 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
else:
return tuple(map(_ensure_str, x))

def _fill_debug_info_after_flatten(
fun: lu.WrappedFun, in_tree: PyTreeDef,
has_kwargs: bool, out_tree_thunk: Callable[[], PyTreeDef]) -> lu.WrappedFun:
"""Fills in the fun.debug_info."""
dbg = fun.debug_info
if dbg is None:
fun_src_info = fun_sourceinfo(fun.f)
dbg = TracingDebugInfo("unknown", fun_src_info, None, None)
elif dbg.func_src_info is None:
fun_src_info = fun_sourceinfo(fun.f)
else:
fun_src_info = dbg.func_src_info

if dbg.arg_names is None:
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves)
dummy_args, dummy_kwargs = dummy_args if has_kwargs else (dummy_args, {})
fun_sig = fun_signature(fun.f)
arg_names = _arg_names(fun_sig, dummy_args, dummy_kwargs, (), ())
else:
arg_names = dbg.arg_names

def result_paths() -> tuple[str] | None:
out_tree = out_tree_thunk()
try:
num_leaves = out_tree.num_leaves
dummy_result = tree_unflatten(out_tree, [False] * num_leaves)
except:
return None
else:
return tuple(keystr(path) for path, _ in generate_key_paths(dummy_result))
result_paths_ = HashableFunction(result_paths, closure=())
dbg = TracingDebugInfo(dbg.traced_for, fun_src_info,
arg_names, result_paths_)
return lu.add_debug_info(fun, dbg)

@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
def flatten_fun_transformation(in_tree, *args_flat):
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
ans = yield py_args, py_kwargs
yield tree_flatten(ans)

def flatten_fun(f: lu.WrappedFun, in_tree: PyTreeDef) -> tuple[lu.WrappedFun, Callable]:
wrapped, out_tree_thunk = flatten_fun_transformation(f, in_tree)
return _fill_debug_info_after_flatten(wrapped, in_tree, True, out_tree_thunk), out_tree_thunk

def apply_flat_fun(fun, io_tree, *py_args):
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten((py_args, {}))
@@ -83,11 +122,15 @@ def apply_flat_fun(fun, io_tree, *py_args):
return tree_unflatten(out_tree, ans)

@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
def flatten_fun_nokwargs_transformation(in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = yield py_args, {}
yield tree_flatten(ans)

def flatten_fun_nokwargs(f: lu.WrappedFun, in_tree: PyTreeDef) -> tuple[lu.WrappedFun, Callable]:
wrapped, out_tree_thunk = flatten_fun_nokwargs_transformation(f, in_tree)
return _fill_debug_info_after_flatten(wrapped, in_tree, False, out_tree_thunk), out_tree_thunk

def apply_flat_fun_nokwargs(fun, io_tree, py_args):
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten(py_args)
@@ -105,13 +148,13 @@ def flattened_fun_in_tree(
# with flatten_fun or flatten_fun_nokwargs, which could happen e.g. when
# core.eval_jaxpr encounters a call primitive (though at that point we're just
# round-tripping jaxprs and the user errors in question are impossible).
assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
assert (isinstance(flatten_fun_nokwargs, partial) and
len(flatten_fun_nokwargs.args) == 1)
flattens = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
assert isinstance(flatten_fun_transformation, partial) and len(flatten_fun_transformation.args) == 1
assert (isinstance(flatten_fun_nokwargs_transformation, partial) and
len(flatten_fun_nokwargs_transformation.args) == 1)
flattens = {flatten_fun_transformation.args[0], flatten_fun_nokwargs_transformation.args[0]}
try:
((in_tree,), out_tree_store, has_kwargs), = (
(args, store, f is flatten_fun.args[0])
(args, store, f is flatten_fun_transformation.args[0])
for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens)
except ValueError:
return None
@@ -639,16 +682,25 @@ def api_hook(fun, tag: str):


def debug_info(
traced_for: str, src: str | None, fun_signature: inspect.Signature | None,
args: tuple[Any, ...], kwargs: dict[str, Any], static_argnums: tuple[int, ...],
static_argnames: tuple[str, ...]
fun: Callable,
traced_for: str,
args: tuple[Any, ...],
kwargs: dict[str, Any],
*,
fun_src_info: str | None = None,
fun_sig: inspect.Signature | None = None,
static_argnums: tuple[int, ...] = (),
static_argnames: tuple[str, ...] = ()
) -> TracingDebugInfo | None:
"""Try to build trace-time debug info for fun when applied to args/kwargs."""
arg_names = _arg_names(fun_signature, args, kwargs, static_argnums,
assert isinstance(fun, Callable)
if fun_src_info is None:
fun_src_info = fun_sourceinfo(fun)
if fun_sig is None:
fun_sig = fun_signature(fun)
arg_names = _arg_names(fun_sig, args, kwargs, static_argnums,
static_argnames)
if arg_names is None:
return None
return TracingDebugInfo(traced_for, src, arg_names, None)
return TracingDebugInfo(traced_for, fun_src_info, arg_names, None)

def fun_signature(fun: Callable) -> inspect.Signature | None:
try:
@@ -686,10 +738,12 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
@lu.transformation_with_aux
def result_paths(*args, **kwargs):
"linear_util transform to get output pytree paths of pre-flattened function."
# TODO(necula): remove this function
ans = yield args, kwargs
yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]

def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
def jaxpr_debug_info(jaxpr: core.Jaxpr,
trace_debug: TracingDebugInfo | None,
result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
@@ -703,15 +757,6 @@ def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
trace_debug.arg_names, tuple(result_paths))
return jaxpr.replace(debug_info=debug_info)

def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None,
res_paths: Callable[[], tuple[str, ...]]) -> lu.WrappedFun:
"Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun"
if dbg is None: return f
assert dbg.result_paths is None
res_paths_ = HashableFunction(res_paths, closure=())
return lu.add_debug_info(f, dbg._replace(result_paths=res_paths_))


def hoist_obj_attrs(f, flat_args):
idxs, objs, flat_args_ = [], [], []
for i, x in enumerate(flat_args):
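A usage sketch for the reworked api_util.debug_info above (illustration only, assuming the WIP signature lands as shown; the TracingDebugInfo fields are the ones referenced elsewhere in this diff):

from jax._src import api_util

def loss(params, batch, *, average=True):
  ...

# args/kwargs are only inspected for structure and argument names; loss is never called.
dbg = api_util.debug_info(loss, "my_transform", (0.0, 1.0), {"average": False})
if dbg is not None:          # None when the argument names cannot be resolved
  print(dbg.traced_for)      # "my_transform"
  print(dbg.func_src_info)   # e.g. "loss at example.py:3"
  print(dbg.arg_names)       # e.g. ("params", "batch", "average")
  print(dbg.result_paths)    # None here; result paths are filled in after flattening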
1 change: 1 addition & 0 deletions jax/_src/core.py
@@ -42,6 +42,7 @@
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
from jax._src import linear_util as lu
from jax._src.linear_util import TracingDebugInfo

from jax._src import source_info_util
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
2 changes: 1 addition & 1 deletion jax/_src/custom_transpose.py
@@ -60,7 +60,7 @@ def transformation_with_aux(
return fun.wrap(gen, gen_static_args, out_store), out_thunk

flatten_fun_nokwargs = transformation_with_aux(
api_util.flatten_fun_nokwargs.args[0])
api_util.flatten_fun_nokwargs_transformation.args[0])


### api
6 changes: 5 additions & 1 deletion jax/_src/interpreters/mlir.py
@@ -34,6 +34,7 @@
import numpy as np

from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
@@ -2051,7 +2052,10 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
as `avals_out`."""
def f_lowered(ctx: LoweringRuleContext, *args, **params):
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
wrapped_fun = lu.wrap_init(f, params)
fun_src_info = api_util.fun_sourceinfo(fun)
dbg = api_util.debug_info(fun, "lower", args, {},
fun_src_info=fun_src_info)
wrapped_fun = lu.wrap_init(f, params=params, debug_info=dbg)
manager = (contextlib.nullcontext() if ctx.jaxpr_eqn_ctx is None else
ctx.jaxpr_eqn_ctx.manager)

32 changes: 23 additions & 9 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1795,7 +1795,7 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:


class DynamicJaxprTrace(core.Trace):
def __init__(self, frame):
def __init__(self, frame: JaxprStackFrame):
self.frame = frame

def invalidate(self):
@@ -2101,19 +2101,33 @@ class DebugInfo(NamedTuple):
out_tree: Callable[[], PyTreeDef] | None # lazy, not avail at trace time
has_kwargs: bool # whether in_tree corresponds to (args, kwargs) or args
traced_for: str # "jit", "scan", "make_jaxpr", etc
# TODO(necula): will replace here all DebugInfo
replacement_debug_info: core.TracingDebugInfo

def debug_info(fn: Callable, in_tree: PyTreeDef | None,
out_tree_thunk: Callable[[], PyTreeDef] | None,
has_kwargs: bool, traced_for: str) -> DebugInfo:
has_kwargs: bool, traced_for: str,
*,
replacement_debug_info: core.TracingDebugInfo | None = None) -> DebugInfo:
sig = api_util.fun_signature(fn)
src_info = fun_sourceinfo(fn)
return DebugInfo(src_info, sig, in_tree, out_tree_thunk, has_kwargs,
traced_for)

def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
fun_src_info = fun_sourceinfo(fn)
if replacement_debug_info is None:
assert in_tree is not None
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves)
dummy_args, dummy_kwargs = dummy_args if has_kwargs else (dummy_args, {})
replacement_debug_info = api_util.debug_info(fn,
traced_for,
dummy_args,
dummy_kwargs,
fun_src_info=fun_src_info)
return DebugInfo(fun_src_info, sig, in_tree, out_tree_thunk, has_kwargs,
traced_for, replacement_debug_info)

def debug_info_final(fun: lu.WrappedFun, traced_for: str) -> DebugInfo:
"Make a DebugInfo from data available to final-style primitives like pmap."
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
in_tree, out_tree, has_kws = flattened_fun_in_tree(fun) or (None, None, False)
return debug_info(fun.f, in_tree, out_tree, has_kws, traced_for,
replacement_debug_info=fun.debug_info)

def arg_info_all(dbg: DebugInfo) -> list[tuple[str, KeyPath]] | None:
ba = None if dbg.in_tree is None else sig_info(dbg)
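The dummy-args trick used by pe.debug_info above (and by _fill_debug_info_after_flatten in api_util.py) can be tried on its own; a small sketch with the public tree utilities, independent of this commit:

from jax import tree_util

# Flatten an (args, kwargs) pair, then rebuild a placeholder pytree of the
# same structure; binding the placeholders against the function signature is
# what recovers per-leaf argument names.
args_flat, in_tree = tree_util.tree_flatten(((1.0, (2.0, 3.0)), {}))
dummy = tree_util.tree_unflatten(in_tree, [False] * in_tree.num_leaves)
dummy_args, dummy_kwargs = dummy
print(dummy_args)    # (False, (False, False))
print(dummy_kwargs)  # {}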
1 change: 1 addition & 0 deletions jax/_src/lax/control_flow/common.py
@@ -20,6 +20,7 @@
from functools import partial
from typing import Any

from jax._src import api_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src.lax import lax