Skip to content

Commit

Permalink
Add frontend attributes to Jax. This allows Jax users to annotate Jax…
Browse files Browse the repository at this point in the history
… code with frontend_attributes which can be traced down to the HLO level, to be used for numerical debugging purposes.

PiperOrigin-RevId: 671930431
  • Loading branch information
Google-ML-Automation authored and jax authors committed Sep 6, 2024
1 parent 671acef commit 02b7a76
Show file tree
Hide file tree
Showing 9 changed files with 434 additions and 9 deletions.
11 changes: 11 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ py_library_providing_imports_info(
":version",
":xla",
":xla_bridge",
":xla_metadata",
"//jax/_src/lib",
] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps,
)
Expand Down Expand Up @@ -451,6 +452,7 @@ pytype_strict_library(
":tree_util",
":typing",
":util",
":xla_metadata",
"//jax/_src/lib",
] + py_deps("numpy"),
)
Expand Down Expand Up @@ -703,6 +705,7 @@ pytype_strict_library(
":state_types",
":tree_util",
":util",
":xla_metadata",
] + py_deps("numpy"),
)

Expand Down Expand Up @@ -768,6 +771,14 @@ pytype_strict_library(
deps = [":config"],
)

pytype_strict_library(
name = "xla_metadata",
srcs = ["_src/xla_metadata.py"],
deps = [
":config",
],
)

pytype_strict_library(
name = "layout",
srcs = ["_src/layout.py"],
Expand Down
9 changes: 7 additions & 2 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,16 +202,20 @@ def trace_context():
tls = jax_jit.thread_local_state()
axis_env_state = ()
mesh_context_manager = ()
xla_metadata_context_manager = ()
compute_on_context_manager = ()

context: Any = tls.extra_jit_context
if context and context.axis_env_state is not None:
axis_env_state = context.axis_env_state
if context and context.mesh_context_manager:
mesh_context_manager = context.mesh_context_manager
if context and context.xla_metadata_context_manager:
xla_metadata_context_manager = context.xla_metadata_context_manager
if context and context.compute_on_context_manager:
compute_on_context_manager = context.compute_on_context_manager
return (axis_env_state, mesh_context_manager, compute_on_context_manager,
enable_x64.value,
return (axis_env_state, mesh_context_manager, xla_metadata_context_manager,
compute_on_context_manager, enable_x64.value,
numpy_rank_promotion.value, default_matmul_precision.value,
dynamic_shapes.value, numpy_dtype_promotion.value,
default_device.value, random_seed_offset.value,
Expand Down Expand Up @@ -858,6 +862,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
axis_env_state: Hashable = ()
mesh_context_manager: Hashable = ()
compute_on_context_manager: Hashable = ()
xla_metadata_context_manager: Hashable = ()

# Values set by _StateContextManager context managers.
# CAUTION: these must be initialized to `None`! The state context manager
Expand Down
20 changes: 15 additions & 5 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from jax._src import traceback_util
from jax._src.typing import Array, DimSize, Shape
from jax._src import typing
from jax._src import xla_metadata as xla_metadata_lib

traceback_util.register_exclusion(__file__)

Expand Down Expand Up @@ -261,12 +262,15 @@ def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args):

class JaxprEqnContext:

def __init__(self, compute_type: str | None, threefry_partitionable: bool):
def __init__(self, compute_type: str | None, threefry_partitionable: bool,
xla_metadata=None):
self.compute_type = compute_type
self.threefry_partitionable = threefry_partitionable
self.xla_metadata = xla_metadata
self._managers = [
(compute_on.extend_compute_type, self.compute_type),
(config.threefry_partitionable.__call__, self.threefry_partitionable),
(xla_metadata_lib.set_xla_metadata, self.xla_metadata),
]

@property
Expand All @@ -278,8 +282,11 @@ def manager(self):
yield

def __repr__(self):
return (f"JaxprEqnContext(compute_type={self.compute_type},"
f"threefry_partitionable={self.threefry_partitionable})")
return (
f"JaxprEqnContext(compute_type={self.compute_type},"
f"threefry_partitionable={self.threefry_partitionable}),"
f"xla_metadata={self.xla_metadata}"
)


class JaxprEqn:
Expand Down Expand Up @@ -333,8 +340,11 @@ def replace(
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None,
ctx=None):
source_info = source_info or source_info_util.new_source_info()
ctx = ctx or JaxprEqnContext(compute_on.current_compute_type(),
config.threefry_partitionable.value)
ctx = ctx or JaxprEqnContext(
compute_on.current_compute_type(),
config.threefry_partitionable.value,
xla_metadata_lib.current_xla_metadata(),
)
if config.enable_checks.value:
assert all(isinstance(x, (Var, Literal)) for x in invars)
assert all(isinstance(v, Var) for v in outvars)
Expand Down
27 changes: 27 additions & 0 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1923,6 +1923,10 @@ def lower_per_platform(ctx: LoweringRuleContext,
lambda o: wrap_compute_type_in_place(ctx, o.owner),
filter(_is_not_block_argument, flatten_ir_values(output)),
)
map(
lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
flatten_ir_values(output),
)
return output

assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules)
Expand Down Expand Up @@ -1964,6 +1968,10 @@ def lower_per_platform(ctx: LoweringRuleContext,
lambda o: wrap_compute_type_in_place(ctx, o.owner),
filter(_is_not_block_argument, out_nodes),
)
map(
lambda o: wrap_xla_metadata_in_place(ctx, o.owner),
out_nodes,
)
if inner_ctx.tokens_out is not None:
assert len(ordered_effects) == len(inner_ctx.tokens_out)
out_nodes = [inner_ctx.tokens_out.get(eff)
Expand Down Expand Up @@ -2125,6 +2133,25 @@ def wrap_compute_type_in_place(ctx, op):
op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr)


def wrap_xla_metadata_in_place(ctx, op):
ctx_attributes = {}
existing_attributes = {}
if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.xla_metadata:
for k, v in ctx.jaxpr_eqn_ctx.xla_metadata.items():
ctx_attributes[k] = ir.StringAttr.get(str(v).lower())
if isinstance(op, ir.Operation):
# combine with existing mhlo.frontend_attributes
op_attributes_dict = {attr.name: attr.attr for attr in op.attributes}
for k, attributes in op_attributes_dict.items():
if k == "mhlo.frontend_attributes":
v_dict = {attr.name: attr.attr for attr in attributes}
for fa_key, fa_val in v_dict.items():
existing_attributes[fa_key] = fa_val
op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
ctx_attributes | existing_attributes
)


def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *,
broadcast_dimensions) -> ir.Value:
# broadcast_dimension[i] is the axis of the result where the axis i of
Expand Down
8 changes: 6 additions & 2 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from jax._src import profiler
from jax._src import source_info_util
from jax._src import compute_on
from jax._src import xla_metadata as xla_metadata_lib
from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs,
fun_sourceinfo)
from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
Expand Down Expand Up @@ -898,8 +899,11 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
assert ("donated_invars" in params and
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers]
ctx = ctx or JaxprEqnContext(compute_on.current_compute_type(),
config.threefry_partitionable.value)
ctx = ctx or JaxprEqnContext(
compute_on.current_compute_type(),
config.threefry_partitionable.value,
xla_metadata_lib.current_xla_metadata(),
)
return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers),
out_avals, primitive, params, effects, source_info,
ctx)
Expand Down
55 changes: 55 additions & 0 deletions jax/_src/xla_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
import threading
from contextlib import contextmanager

from jax._src import config


class _XlaMetadata(threading.local):
val: dict[Any, Any]

def __init__(self):
self.val = {}

thread_local_metadata = _XlaMetadata()

def current_xla_metadata():
return thread_local_metadata.val

@contextmanager
def set_xla_metadata(*args, **kwargs):
new_metadata = thread_local_metadata.val.copy()
if args:
new_metadata.update(args[0] if args[0] else {})
else:
new_metadata.update(**kwargs)
prev_metadata, thread_local_metadata.val = (
thread_local_metadata.val,
new_metadata,
)
config.update_thread_local_jit_state(
xla_metadata_context_manager=tuple(
(v, k) for k, v in sorted(new_metadata.items())))
try:
yield
finally:
thread_local_metadata.val = prev_metadata
config.update_thread_local_jit_state(
xla_metadata_context_manager=tuple(
(v, k) for k, v in sorted(prev_metadata.items())
)
)
17 changes: 17 additions & 0 deletions jax/experimental/xla_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the ific language governing permissions and
# limitations under the License.

from jax._src.xla_metadata import (
set_xla_metadata as set_xla_metadata,
)
6 changes: 6 additions & 0 deletions tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1505,6 +1505,12 @@ jax_test(
tags = ["multiaccelerator"],
)

jax_test(
name = "xla_metadata_test",
srcs = ["xla_metadata_test.py"],
deps = ["//jax:experimental"],
)

py_test(
name = "pretty_printer_test",
srcs = ["pretty_printer_test.py"],
Expand Down
Loading

0 comments on commit 02b7a76

Please sign in to comment.