Skip to content

Commit b52bcc1

Browse files
yashk2810jax authors
authored and
jax authors
committed
Reverts 3c07c10
PiperOrigin-RevId: 590700623
1 parent e888806 commit b52bcc1

File tree

3 files changed

+20
-15
lines changed

3 files changed

+20
-15
lines changed

Diff for: CHANGELOG.md

-10
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,6 @@ Remember to align the itemized text with the first line of an item within a list
88

99
## jax 0.4.22
1010

11-
* Changes
12-
* JAX lowering to StableHLO does not depend on physical devices anymore.
13-
If your primitive wraps custom_paritioning or JAX callbacks in the lowering
14-
rule i.e. function passed to `rule` parameter of `mlir.register_lowering` then add your
15-
primitive to `jax._src.dispatch.prim_requires_devices_during_lowering` set.
16-
This is needed because custom_partitioning and JAX callbacks need physical
17-
devices to create `Sharding`s during lowering.
18-
This is a temporary state until we can create `Sharding`s without physical
19-
devices.
20-
2111
* Deprecations
2212
* The `device_buffer` and `device_buffers` properties of JAX arrays are deprecated.
2313
Explicit buffers have been replaced by the more flexible array sharding interface,

Diff for: jax/_src/interpreters/pxla.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ class WeakRefList(list):
102102
MeshDimAssignment = Union[ShardedAxis, Replicated]
103103
ShardingSpec = sharding_specs.ShardingSpec
104104

105+
# TODO(yashkatariya): Remove this flag when nvidia's use cases are fixed.
106+
_JAX_REQUIRE_DEVICES_DURING_LOWERING = config.DEFINE_bool(
107+
"jax_require_devices_during_lowering",
108+
True,
109+
help="Forces physical devices to be passed during lowering to stablehlo.")
110+
105111
### util
106112

107113
def identity(x): return x
@@ -1971,13 +1977,17 @@ def lower_sharding_computation(
19711977
semantic_in_shardings = SemanticallyEqualShardings(in_shardings) # type: ignore
19721978
semantic_out_shardings = SemanticallyEqualShardings(out_shardings) # type: ignore
19731979
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)
1980+
materialized_da = (
1981+
tuple(da_object)
1982+
if prim_requires_devices or _JAX_REQUIRE_DEVICES_DURING_LOWERING.value
1983+
else None)
19741984

19751985
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
19761986
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
19771987
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
19781988
semantic_out_shardings, in_layouts, out_layouts, len(da_object),
1979-
tuple(da_object) if prim_requires_devices else None, donated_invars,
1980-
name_stack, all_default_mem_kind, lowering_parameters=lowering_parameters)
1989+
materialized_da, donated_invars, name_stack, all_default_mem_kind,
1990+
lowering_parameters=lowering_parameters)
19811991

19821992
# backend and device_assignment is passed through to MeshExecutable because
19831993
# if keep_unused=False and all in_shardings are pruned, then there is no way

Diff for: tests/pjit_test.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -3802,9 +3802,14 @@ def g(a):
38023802
b = jax.device_put(out_a, NamedSharding(mesh2, P('y')))
38033803
f(b) # lowering cache *hit*
38043804

3805-
with jtu.count_jit_and_pmap_compiles() as count:
3806-
g(np.arange(8))
3807-
self.assertEqual(count[0], 1)
3805+
prev_value = pxla._JAX_REQUIRE_DEVICES_DURING_LOWERING.value
3806+
try:
3807+
jax.config.update('jax_require_devices_during_lowering', False)
3808+
with jtu.count_jit_and_pmap_compiles() as count:
3809+
g(np.arange(8))
3810+
self.assertEqual(count[0], 1)
3811+
finally:
3812+
jax.config.update('jax_require_devices_during_lowering', prev_value)
38083813

38093814
def test_lowering_cache_miss_different_devices_and_sharding(self):
38103815
if jax.device_count() < 4:

0 commit comments

Comments
 (0)