Improvements in the handling of tracing debug info #24831

Draft · wants to merge 1 commit into main
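
This commit replaces the ad-hoc assembly of source info, signatures, and result paths at each API entry point with a single TracingDebugInfo record attached to the lu.WrappedFun (via lu.wrap_init(..., debug_info=...)). As rough orientation, here is a self-contained sketch of that record; the field names are inferred from the TracingDebugInfo(...) constructor calls and attribute accesses visible in the diff below, while the real definition lives in jax._src.linear_util and may differ.

# Hedged sketch only: fields inferred from the TracingDebugInfo(...) calls in this diff.
from __future__ import annotations
from typing import Callable, NamedTuple

class TracingDebugInfo(NamedTuple):
    traced_for: str                                      # "jit", "pmap", "checkpoint", ...
    func_src_info: str | None                            # e.g. "f at example.py:10"
    arg_names: tuple[str, ...] | None                    # names of the flattened arguments
    result_paths: Callable[[], tuple[str, ...]] | None   # lazy; out_tree unknown at wrap time

dbg = TracingDebugInfo("pmap", "f at example.py:10", ("x", "y['a']"), None)
print(dbg.traced_for, dbg.arg_names)
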
4 changes: 3 additions & 1 deletion jax/_src/ad_checkpoint.py
@@ -25,6 +25,7 @@

from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
@@ -404,7 +405,8 @@ def new_fun(*dyn_args, **kwargs):
# remat-specific errors.
@weakref_lru_cache
def _trace_to_jaxpr(fun, in_tree, in_avals):
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun, traced_for="checkpoint"),
in_tree)
debug = pe.debug_info(fun, in_tree, out_tree, True, "checkpoint")
try:
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
20 changes: 9 additions & 11 deletions jax/_src/api.py
@@ -62,7 +62,7 @@
flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, apply_flat_fun_nokwargs, check_callable, debug_info,
result_paths, flat_out_axes, debug_info_final, fun_sourceinfo)
flat_out_axes, fun_sourceinfo)
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
@@ -1392,19 +1392,17 @@ def _get_global_axis_size(local_axis_size: int, in_devices, backend_name: str,
return global_axis_size


def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
def _prepare_pmap(fun: Callable,
in_axes, out_axes, static_broadcasted_tuple,
donate_tuple, in_devices, backend_name,
axis_size, args, kwargs):
if in_devices is not None and len(in_devices) == 0:
raise ValueError("'devices' argument to pmap must be non-empty, or None.")

src = fun_sourceinfo(fun)
signature = api_util.fun_signature(fun)
dbg = debug_info(fun, 'pmap', args, kwargs,
static_argnums=static_broadcasted_tuple)

dbg = debug_info('pmap', src, signature, args, kwargs,
static_broadcasted_tuple, ())

f = lu.wrap_init(fun)
f = lu.wrap_init(fun, debug_info=dbg)
if static_broadcasted_tuple:
if max(static_broadcasted_tuple) >= len(args):
raise ValueError(
@@ -1451,10 +1449,10 @@ def _prepare_pmap(fun, in_axes, out_axes, static_broadcasted_tuple,
raise ValueError(msg) from None
local_axis_size = _mapped_axis_size(fun, in_tree, args, in_axes_flat, "pmap")

f, res_paths = result_paths(f)
# f, res_paths = result_paths(f)
f, out_axes_thunk = flat_out_axes(f, out_axes)
flat_fun, out_tree = flatten_fun(f, in_tree)
flat_fun = debug_info_final(flat_fun, dbg, res_paths)
# flat_fun = debug_info_final(flat_fun, dbg, res_paths)

is_explicit_global_axis_size = axis_size is not None
global_axis_size = _get_global_axis_size(local_axis_size, in_devices,
@@ -1957,7 +1955,7 @@ def vjp(
del reduce_axes
check_callable(fun)
return _vjp(
lu.wrap_init(fun), *primals, has_aux=has_aux)
lu.wrap_init(fun, traced_for="vjp"), *primals, has_aux=has_aux)

def _vjp(fun: lu.WrappedFun, *primals, has_aux=False):
"""Variant of vjp() that takes an lu.WrappedFun."""
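
For orientation, the reworked _prepare_pmap path above reduces to: build the debug info once from the user function and the concrete call, then hand it to lu.wrap_init. A minimal sketch, assuming the api_util.debug_info and lu.wrap_init signatures introduced in this diff (not runnable against released JAX).

# Sketch only; mirrors the call sites in _prepare_pmap, assuming this PR's signatures.
from jax._src import api_util
from jax._src import linear_util as lu

def wrap_for_pmap(fun, args, kwargs, static_broadcasted_tuple=()):
    # One call derives source info, the signature, and per-argument names.
    dbg = api_util.debug_info(fun, 'pmap', args, kwargs,
                              static_argnums=static_broadcasted_tuple)
    # The WrappedFun carries the debug info from the start, so the
    # result_paths/debug_info_final steps commented out above become unnecessary.
    return lu.wrap_init(fun, debug_info=dbg)
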
95 changes: 70 additions & 25 deletions jax/_src/api_util.py
@@ -68,12 +68,51 @@ def _ensure_str_tuple(x: str | Iterable[str]) -> tuple[str, ...]:
else:
return tuple(map(_ensure_str, x))

def _fill_debug_info_after_flatten(
fun: lu.WrappedFun, in_tree: PyTreeDef,
has_kwargs: bool, out_tree_thunk: Callable[[], PyTreeDef]) -> lu.WrappedFun:
"""Fills in the fun.debug_info."""
dbg = fun.debug_info
if dbg is None:
fun_src_info = fun_sourceinfo(fun.f)
dbg = TracingDebugInfo("unknown", fun_src_info, None, None)
elif dbg.func_src_info is None:
fun_src_info = fun_sourceinfo(fun.f)
else:
fun_src_info = dbg.func_src_info

if dbg.arg_names is None:
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves)
dummy_args, dummy_kwargs = dummy_args if has_kwargs else (dummy_args, {})
fun_sig = fun_signature(fun.f)
arg_names = _arg_names(fun_sig, dummy_args, dummy_kwargs, (), ())
else:
arg_names = dbg.arg_names

def result_paths() -> tuple[str, ...] | None:
out_tree = out_tree_thunk()
try:
num_leaves = out_tree.num_leaves
dummy_result = tree_unflatten(out_tree, [False] * num_leaves)
except:
return None
else:
return tuple(keystr(path) for path, _ in generate_key_paths(dummy_result))
result_paths_ = HashableFunction(result_paths, closure=())
dbg = TracingDebugInfo(dbg.traced_for, fun_src_info,
arg_names, result_paths_)
return lu.add_debug_info(fun, dbg)

@lu.transformation_with_aux
def flatten_fun(in_tree, *args_flat):
def flatten_fun_transformation(in_tree, *args_flat):
py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
ans = yield py_args, py_kwargs
yield tree_flatten(ans)

def flatten_fun(f: lu.WrappedFun, in_tree: PyTreeDef) -> tuple[lu.WrappedFun, Callable]:
wrapped, out_tree_thunk = flatten_fun_transformation(f, in_tree)
return _fill_debug_info_after_flatten(wrapped, in_tree, True, out_tree_thunk), out_tree_thunk

def apply_flat_fun(fun, io_tree, *py_args):
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten((py_args, {}))
@@ -83,11 +122,15 @@ def apply_flat_fun(fun, io_tree, *py_args):
return tree_unflatten(out_tree, ans)

@lu.transformation_with_aux
def flatten_fun_nokwargs(in_tree, *args_flat):
def flatten_fun_nokwargs_transformation(in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = yield py_args, {}
yield tree_flatten(ans)

def flatten_fun_nokwargs(f: lu.WrappedFun, in_tree: PyTreeDef) -> tuple[lu.WrappedFun, Callable]:
wrapped, out_tree_thunk = flatten_fun_nokwargs_transformation(f, in_tree)
return _fill_debug_info_after_flatten(wrapped, in_tree, False, out_tree_thunk), out_tree_thunk

def apply_flat_fun_nokwargs(fun, io_tree, py_args):
in_tree_expected, out_tree = io_tree
args, in_tree = tree_flatten(py_args)
@@ -105,13 +148,13 @@ def flattened_fun_in_tree(
# with flatten_fun or flatten_fun_nokwargs, which could happen e.g. when
# core.eval_jaxpr encounters a call primitive (though at that point we're just
# round-tripping jaxprs and the user errors in question are impossible).
assert isinstance(flatten_fun, partial) and len(flatten_fun.args) == 1
assert (isinstance(flatten_fun_nokwargs, partial) and
len(flatten_fun_nokwargs.args) == 1)
flattens = {flatten_fun.args[0], flatten_fun_nokwargs.args[0]}
assert isinstance(flatten_fun_transformation, partial) and len(flatten_fun_transformation.args) == 1
assert (isinstance(flatten_fun_nokwargs_transformation, partial) and
len(flatten_fun_nokwargs_transformation.args) == 1)
flattens = {flatten_fun_transformation.args[0], flatten_fun_nokwargs_transformation.args[0]}
try:
((in_tree,), out_tree_store, has_kwargs), = (
(args, store, f is flatten_fun.args[0])
(args, store, f is flatten_fun_transformation.args[0])
for (f, args), store in zip(fn.transforms, fn.stores) if f in flattens)
except ValueError:
return None
@@ -639,16 +682,25 @@ def api_hook(fun, tag: str):


def debug_info(
traced_for: str, src: str | None, fun_signature: inspect.Signature | None,
args: tuple[Any, ...], kwargs: dict[str, Any], static_argnums: tuple[int, ...],
static_argnames: tuple[str, ...]
fun: Callable,
traced_for: str,
args: tuple[Any, ...],
kwargs: dict[str, Any],
*,
fun_src_info: str | None = None,
fun_sig: inspect.Signature | None = None,
static_argnums: tuple[int, ...] = (),
static_argnames: tuple[str, ...] = ()
) -> TracingDebugInfo | None:
"""Try to build trace-time debug info for fun when applied to args/kwargs."""
arg_names = _arg_names(fun_signature, args, kwargs, static_argnums,
assert not isinstance(fun, lu.WrappedFun)
if fun_src_info is None:
fun_src_info = fun_sourceinfo(fun)
if fun_sig is None:
fun_sig = fun_signature(fun)
arg_names = _arg_names(fun_sig, args, kwargs, static_argnums,
static_argnames)
if arg_names is None:
return None
return TracingDebugInfo(traced_for, src, arg_names, None)
return TracingDebugInfo(traced_for, fun_src_info, arg_names, None)

def fun_signature(fun: Callable) -> inspect.Signature | None:
try:
@@ -686,10 +738,12 @@ def _arg_names(fn_signature, args, kwargs, static_argnums, static_argnames,
@lu.transformation_with_aux
def result_paths(*args, **kwargs):
"linear_util transform to get output pytree paths of pre-flattened function."
# TODO(necula): remove this function
ans = yield args, kwargs
yield ans, [keystr(path) for path, _ in generate_key_paths(ans)]

def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
def jaxpr_debug_info(jaxpr: core.Jaxpr,
trace_debug: TracingDebugInfo | None,
result_paths: tuple[str, ...] | None = None,
) -> core.Jaxpr:
"""Add debug info to jaxpr, given trace-time debug info and result paths."""
@@ -700,18 +754,9 @@ def jaxpr_debug_info(jaxpr: core.Jaxpr, trace_debug: TracingDebugInfo | None,
result_paths = trace_debug.result_paths() # type: ignore
debug_info = core.JaxprDebugInfo(
trace_debug.traced_for, trace_debug.func_src_info,
trace_debug.arg_names, tuple(result_paths))
trace_debug.arg_names, tuple(result_paths)) # type: ignore
return jaxpr.replace(debug_info=debug_info)

def debug_info_final(f: lu.WrappedFun, dbg: TracingDebugInfo | None,
res_paths: Callable[[], tuple[str, ...]]) -> lu.WrappedFun:
"Attach trace-time debug info and result paths lazy thunk to an lu.WrappedFun"
if dbg is None: return f
assert dbg.result_paths is None
res_paths_ = HashableFunction(res_paths, closure=())
return lu.add_debug_info(f, dbg._replace(result_paths=res_paths_))


def hoist_obj_attrs(f, flat_args):
idxs, objs, flat_args_ = [], [], []
for i, x in enumerate(flat_args):
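
The new _fill_debug_info_after_flatten recovers argument names and result paths purely from pytree structure: it unflattens placeholder leaves through in_tree to label the flat arguments, and defers result paths behind the out_tree thunk. The runnable toy below illustrates that bookkeeping with the public jax.tree_util API; the helper name and sample values are illustrative, not part of the PR.

# Runnable toy of the pytree bookkeeping above, using only public jax.tree_util APIs.
from jax.tree_util import (tree_flatten, tree_unflatten,
                           tree_flatten_with_path, keystr)

def labels_from_treedef(treedef):
    # Unflatten dummy leaves, then read back the key path of each leaf.
    dummy = tree_unflatten(treedef, [False] * treedef.num_leaves)
    keyed_leaves, _ = tree_flatten_with_path(dummy)
    return tuple(keystr(path) for path, _ in keyed_leaves)

args = (1.0, {"a": 2.0, "b": 3.0})
_, in_tree = tree_flatten(args)
print(labels_from_treedef(in_tree))    # e.g. ('[0]', "[1]['a']", "[1]['b']")

out = {"sum": 6.0, "pair": (1.0, 2.0)}
_, out_tree = tree_flatten(out)
# In the PR this runs lazily, via the out_tree thunk, once tracing has finished.
print(labels_from_treedef(out_tree))   # e.g. ("['pair'][0]", "['pair'][1]", "['sum']")
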
1 change: 1 addition & 0 deletions jax/_src/core.py
@@ -42,6 +42,7 @@
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
from jax._src import linear_util as lu
from jax._src.linear_util import TracingDebugInfo

from jax._src import source_info_util
from jax._src.util import (safe_zip, safe_map, curry, tuple_insert,
2 changes: 1 addition & 1 deletion jax/_src/custom_transpose.py
@@ -60,7 +60,7 @@ def transformation_with_aux(
return fun.wrap(gen, gen_static_args, out_store), out_thunk

flatten_fun_nokwargs = transformation_with_aux(
api_util.flatten_fun_nokwargs.args[0])
api_util.flatten_fun_nokwargs_transformation.args[0])


### api
6 changes: 5 additions & 1 deletion jax/_src/interpreters/mlir.py
@@ -34,6 +34,7 @@
import numpy as np

from jax._src import ad_util
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import dtypes
@@ -2051,7 +2052,10 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
as `avals_out`."""
def f_lowered(ctx: LoweringRuleContext, *args, **params):
f = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),)
wrapped_fun = lu.wrap_init(f, params)
fun_src_info = api_util.fun_sourceinfo(fun)
dbg = api_util.debug_info(fun, "lower", args, {},
fun_src_info=fun_src_info)
wrapped_fun = lu.wrap_init(f, params=params, debug_info=dbg)
manager = (contextlib.nullcontext() if ctx.jaxpr_eqn_ctx is None else
ctx.jaxpr_eqn_ctx.manager)

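
One detail worth noting in lower_fun above: source info is computed from the user's fun, not from the multiple-results wrapper f, so the attached debug info names the original function. A small standard-library sketch of that distinction; fake_sourceinfo is a rough stand-in for api_util.fun_sourceinfo, not its actual implementation.

# Illustration only: why source info is taken from `fun` rather than the wrapper `f`.
def fake_sourceinfo(fn):
    # Rough stand-in for api_util.fun_sourceinfo: "name at file:line".
    code = fn.__code__
    return f"{fn.__name__} at {code.co_filename}:{code.co_firstlineno}"

def user_rule(x):
    return x + 1

multiple_results = False
f = user_rule if multiple_results else (lambda *args, **kw: (user_rule(*args, **kw),))

print(fake_sourceinfo(f))          # points at an anonymous <lambda>
print(fake_sourceinfo(user_rule))  # points at the user's function, the useful label
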
32 changes: 23 additions & 9 deletions jax/_src/interpreters/partial_eval.py
@@ -1795,7 +1795,7 @@ def vars_in_shape(aval: AbstractValue) -> Sequence[Var]:


class DynamicJaxprTrace(core.Trace):
def __init__(self, frame):
def __init__(self, frame: JaxprStackFrame):
self.frame = frame

def invalidate(self):
@@ -2101,19 +2101,33 @@ class DebugInfo(NamedTuple):
out_tree: Callable[[], PyTreeDef] | None # lazy, not avail at trace time
has_kwargs: bool # whether in_tree corresponds to (args, kwargs) or args
traced_for: str # "jit", "scan", "make_jaxpr", etc
# TODO(necula): will replace here all DebugInfo
replacement_debug_info: core.TracingDebugInfo

def debug_info(fn: Callable, in_tree: PyTreeDef | None,
out_tree_thunk: Callable[[], PyTreeDef] | None,
has_kwargs: bool, traced_for: str) -> DebugInfo:
has_kwargs: bool, traced_for: str,
*,
replacement_debug_info: core.TracingDebugInfo | None = None) -> DebugInfo:
sig = api_util.fun_signature(fn)
src_info = fun_sourceinfo(fn)
return DebugInfo(src_info, sig, in_tree, out_tree_thunk, has_kwargs,
traced_for)

def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
fun_src_info = fun_sourceinfo(fn)
if replacement_debug_info is None:
assert in_tree is not None
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves)
dummy_args, dummy_kwargs = dummy_args if has_kwargs else (dummy_args, {})
replacement_debug_info = api_util.debug_info(fn,
traced_for,
dummy_args,
dummy_kwargs,
fun_src_info=fun_src_info)
return DebugInfo(fun_src_info, sig, in_tree, out_tree_thunk, has_kwargs,
traced_for, replacement_debug_info)

def debug_info_final(fun: lu.WrappedFun, traced_for: str) -> DebugInfo:
"Make a DebugInfo from data available to final-style primitives like pmap."
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)
in_tree, out_tree, has_kws = flattened_fun_in_tree(fun) or (None, None, False)
return debug_info(fun.f, in_tree, out_tree, has_kws, traced_for,
replacement_debug_info=fun.debug_info)

def arg_info_all(dbg: DebugInfo) -> list[tuple[str, KeyPath]] | None:
ba = None if dbg.in_tree is None else sig_info(dbg)
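
During the migration, partial_eval's DebugInfo keeps its old fields and additionally carries the new-style record. A hedged sketch of the full tuple shape; the leading fields are inferred from the DebugInfo(...) constructor call in pe.debug_info above, since this hunk only shows the tail of the class.

# Sketch only; leading fields inferred from the DebugInfo(...) call in pe.debug_info.
from __future__ import annotations
import inspect
from typing import Any, Callable, NamedTuple

class DebugInfo(NamedTuple):
    func_src_info: str | None             # e.g. "f at example.py:3"
    signature: inspect.Signature | None
    in_tree: Any | None                   # PyTreeDef of (args, kwargs) or args
    out_tree: Callable[[], Any] | None    # lazy, not available at trace time
    has_kwargs: bool                      # whether in_tree corresponds to (args, kwargs)
    traced_for: str                       # "jit", "scan", "make_jaxpr", etc.
    # Per the TODO in the diff, this new-style record is meant to eventually
    # replace the fields above.
    replacement_debug_info: Any           # core.TracingDebugInfo
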
2 changes: 1 addition & 1 deletion jax/_src/interpreters/pxla.py
@@ -722,7 +722,7 @@ def stage_parallel_callable(
fun_name=fun.__name__, event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts, _ = pe.trace_to_jaxpr_dynamic(
fun, sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = api_util.jaxpr_debug_info(jaxpr, orig_fun.debug_info)
jaxpr = api_util.jaxpr_debug_info(jaxpr, fun.debug_info)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

assert len(out_sharded_avals) == len(pci.out_axes), (
1 change: 1 addition & 0 deletions jax/_src/lax/control_flow/common.py
@@ -20,6 +20,7 @@
from functools import partial
from typing import Any

from jax._src import api_util
from jax._src import core
from jax._src import linear_util as lu
from jax._src.lax import lax