Commit e20a483

hyeontaek authored and Google-ML-Automation committed
[JAX] Add end-to-end execution support in colocated Python API
This change adds a capability to run colocated Python function calls through
`PyLoadedExecutable`. This capability is not yet used for McJAX, but is tested
with a prototype of a colocated Python backend. The overall behavior remains
the same for McJAX (running the user code inline when colocated Python is
called); the new logic will be used once we introduce a colocated Python
backend for McJAX.

Key highlights:

* Colocated Python is compiled into `PyLoadedExecutable` and uses the JAX C++
  dispatch path.
* `CustomCallProgram` for a colocated Python compilation now includes the
  specialization (input/output specs, devices). This information allows a
  colocated Python backend to transform inputs/outputs and validate
  PyTree/dtype/shape/sharding.
* `out_specs_fn` now receives `jax.ShapeDtypeStruct`s instead of concrete
  values.
* Deserialization of devices now prefers the default backend. This improves
  compatibility with environments that use a multi-platform backend and the
  standard "cpu" backend at the same time.
* Several bugs have been fixed (e.g., correctly using `{}` for kwargs).

PiperOrigin-RevId: 703172997
1 parent 3f5f3e1 commit e20a483
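For context, here is a minimal sketch of the user-facing flow this change affects, pieced together from the tests in this commit. The decorator spelling is an assumption based on the `jax.experimental.colocated_python` module; `colocated_cpu_devices`, `specialize`, and `out_specs_fn` appear verbatim in the diffs below.

```python
import jax
import numpy as np
from jax.experimental import colocated_python

@colocated_python.colocated_python  # assumed decorator spelling
def add_one(x):
  return x + 1

# Place the input on a CPU device colocated with the default devices.
cpu_devices = colocated_python.colocated_cpu_devices(jax.local_devices())
x = jax.device_put(np.array(1),
                   jax.sharding.SingleDeviceSharding(cpu_devices[0]))

# out_specs_fn now receives jax.ShapeDtypeStructs instead of concrete values,
# so the identity function declares "outputs have the same specs as inputs".
add_one = add_one.specialize(out_specs_fn=lambda x: x)
out = jax.device_get(add_one(x))  # == np.array(2)
```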

File tree

4 files changed: +109 −34 lines

* jax/experimental/colocated_python/func.py (+59 −18)
* jax/experimental/colocated_python/serialization.py (+17 −3)
* tests/BUILD (+1)
* tests/colocated_python_test.py (+32 −13)

Diff for: jax/experimental/colocated_python/func.py

+59-18
```diff
@@ -24,6 +24,7 @@
 import jax
 from jax._src import api
 from jax._src import tree_util
+from jax._src.interpreters import pxla
 from jax._src.lib import xla_client as xc
 from jax._src.traceback_util import api_boundary
 from jax._src.util import wraps
```
```diff
@@ -137,23 +138,54 @@ def _infer_devices_from_args(args: Sequence[Any]) -> xc.DeviceList | None:
 def _compile_to_executable(
     name: str,
     fun: Callable[..., Any],
+    in_specs_treedef: tree_util.PyTreeDef,
     in_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
+    out_specs_treedef: tree_util.PyTreeDef,
     out_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
     devices: xc.DeviceList,
 ) -> Callable[..., Any]:
   """Compiles a Python function into a runtime executable."""
-  pickled_function = _serialize(fun)
+  fun_and_specialization = (
+      fun,
+      in_specs_treedef,
+      in_specs_leaves,
+      out_specs_treedef,
+      out_specs_leaves,
+      devices,
+  )
+  pickled_function = _serialize(fun_and_specialization)
   program = ifrt_programs.make_colocated_python_program(
       name, pickled_function, devices, in_specs_leaves, out_specs_leaves
   )
-  # TODO(hyeontaek): Compile the program and use the executable.
-  del program
+  ifrt_client = devices[0].client
+  out_sdss = tuple(
+      jax.core.ShapedArray(sds.shape, sds.dtype) for sds in out_specs_leaves
+  )
+  out_shardings = tuple(sds.sharding for sds in out_specs_leaves)
+  try:
+    compile_options = ifrt_programs.make_colocated_python_compile_options()
+    loaded_executable = ifrt_client.compile_ifrt_program(
+        program, compile_options
+    )
+    out_handlers = pxla.global_avals_to_results_handler(
+        out_sdss, out_shardings, committed=True
+    ).handlers
+
+    def call(*args, **kwargs):
+      args_leaves = tree_util.tree_leaves((args, kwargs))
+      execute_result = loaded_executable.execute_sharded(
+          args_leaves, with_tokens=False
+      )
+      results = execute_result.consume_with_handlers(out_handlers)
+      return tree_util.tree_unflatten(out_specs_treedef, results)
 
-  del name
-  del in_specs_leaves
-  del out_specs_leaves
-  del devices
-  return fun
+    return call
+  except jax.errors.JaxRuntimeError as e:
+    # TODO(hyeontaek): Implement colocated Python support in McJAX and remove
+    # this fallback path.
+    if "PjRtCompiler requires an HloProgram" in str(e):
+      return fun
+    raise
 
 
 def _make_output_specs_and_push_result_fun(
```
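The new `call` wrapper dispatches by flattening `(args, kwargs)` into a single leaf list for `execute_sharded`, then rebuilds the output pytree with `out_specs_treedef`. A standalone illustration of the flattening step, using only the public `jax.tree_util` API:

```python
import numpy as np
import jax

args = (np.array(1), [np.array(2), np.array(3)])
kwargs = {"v": np.array(4)}

# The flat argument list handed to the loaded executable; leaves come out in
# deterministic pytree order, positional args first, then sorted kwargs.
leaves = jax.tree_util.tree_leaves((args, kwargs))
# -> [array(1), array(2), array(3), array(4)]
```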
```diff
@@ -170,20 +202,22 @@ def _make_output_specs_and_push_result_fun(
 
   def lowered_fun(*args, **kwargs) -> Sequence[jax.Array]:
     result = info.fun(*args, **kwargs)
-    out_leaves, out_treedef = tree_util.tree_flatten(result)
-    out_spec_leaves = tuple(_get_spec(x) for x in out_leaves)
-    func_backend.SINGLETON_RESULT_STORE.push(uid, out_leaves)
+    result_leaves, out_treedef = tree_util.tree_flatten(result)
+    out_spec_leaves = tuple(_get_spec(x) for x in result_leaves)
+    func_backend.SINGLETON_RESULT_STORE.push(uid, result_leaves)
     return _serialize_specs(out_treedef, out_spec_leaves, devices)
 
-  out_specs_leaves, _ = tree_util.tree_flatten(
+  out_specs_leaves, out_specs_treedef = tree_util.tree_flatten(
       _make_specs_for_serialized_specs(specialization.devices),
   )
   name = getattr(info.fun, "__name__", "unknown")
   name = f"{name}_output_specs_and_push_result"
   return _compile_to_executable(
       name=name,
       fun=lowered_fun,
+      in_specs_treedef=specialization.in_specs_treedef,
       in_specs_leaves=specialization.in_specs_leaves,
+      out_specs_treedef=out_specs_treedef,
       out_specs_leaves=tuple(out_specs_leaves),
       devices=specialization.devices,
   )
```
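Here `lowered_fun` runs the user function, stashes the concrete result leaves on the executor side, and returns only serialized output specs; the staged values are retrieved by the pop-result function in the next hunk. A hypothetical simplification of the store contract behind `func_backend.SINGLETON_RESULT_STORE` (the real implementation may differ, e.g., in blocking behavior):

```python
import threading
from typing import Any, Sequence

class ResultStore:
  """Holds result leaves keyed by a per-call uid until they are popped."""

  def __init__(self) -> None:
    self._lock = threading.Lock()
    self._results: dict[int, Sequence[Any]] = {}

  def push(self, uid: int, leaves: Sequence[Any]) -> None:
    with self._lock:
      self._results[uid] = leaves

  def pop(self, uid: int) -> Sequence[Any]:
    with self._lock:
      return self._results.pop(uid)
```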
```diff
@@ -200,21 +234,23 @@ def _make_pop_result_fun(
   out_specs_treedef = specialization.out_specs_treedef
 
   def lowered_fun() -> Any:
-    flat_result = func_backend.SINGLETON_RESULT_STORE.pop(uid)
-    return tree_util.tree_unflatten(out_specs_treedef, flat_result)
+    result_leaves = func_backend.SINGLETON_RESULT_STORE.pop(uid)
+    return tree_util.tree_unflatten(out_specs_treedef, result_leaves)
 
-  in_specs, _ = tree_util.tree_flatten((
+  in_specs_leaves, in_specs_treedef = tree_util.tree_flatten((
       # args
       (),
       # kwargs
-      (),
+      {},
   ))
   name = getattr(info.fun, "__name__", "unknown")
   name = f"{name}_pop_result"
   return _compile_to_executable(
       name=name,
       fun=lowered_fun,
-      in_specs_leaves=tuple(in_specs),
+      in_specs_treedef=in_specs_treedef,
+      in_specs_leaves=tuple(in_specs_leaves),
+      out_specs_treedef=specialization.out_specs_treedef,
       out_specs_leaves=specialization.out_specs_leaves,
       devices=specialization.devices,
   )
```
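The `()` → `{}` change is the kwargs fix called out in the commit message: empty keyword arguments must flatten as a dict, not a tuple, or the input treedef recorded here will never match the `(args, kwargs)` treedef produced at call time. A quick demonstration:

```python
import jax

_, treedef_dict = jax.tree_util.tree_flatten(((), {}))   # (args, kwargs)
_, treedef_tuple = jax.tree_util.tree_flatten(((), ()))  # the old, buggy form

# Both have zero leaves, but the structures differ, so a treedef check
# against a real call's (args, kwargs) would fail for the tuple form.
assert treedef_dict != treedef_tuple
```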
```diff
@@ -234,7 +270,9 @@ def _make_async_execution_fun(
   return _compile_to_executable(
       name=name,
       fun=info.fun,
+      in_specs_treedef=specialization.in_specs_treedef,
       in_specs_leaves=specialization.in_specs_leaves,
+      out_specs_treedef=specialization.out_specs_treedef,
       out_specs_leaves=specialization.out_specs_leaves,
       devices=specialization.devices,
   )
```
```diff
@@ -283,7 +321,10 @@ def specialized_func(*args, **kwargs) -> Any:
       return _make_pop_result_fun(info, specialization, uid)()
     else:
       # Compute out_specs using out_specs_fn and inputs.
-      out_specs = specialization.out_specs_fn(*args, **kwargs)
+      args_specs, kwargs_specs = tree_util.tree_map(
+          _get_spec, (args, kwargs)
+      )
+      out_specs = specialization.out_specs_fn(*args_specs, **kwargs_specs)
       # Type checking is ignored to silence mypy error: Incompatible types
       # in assignment (expression has type "list[Any]", variable has type
       # "tuple[ShapeDtypeStruct, ...]") [assignment]
```

Diff for: jax/experimental/colocated_python/serialization.py

+17-3
```diff
@@ -51,8 +51,22 @@ def _get_cpu_device_map() -> dict[int, jax.Device]:
   # associated with colocated_python. When deserializing on the colocated_python
   # executor, it should be the CPU backend visible to the user function running
   # under colocated_python.
-  for backed in xb.backends().values():
-    for d in backed._get_all_devices():  # pylint: disable=protected-access
+
+  # Look for CPU devices in the default backend.
+  for d in xb.local_devices()[0].client._get_all_devices():  # pylint: disable=protected-access
+    if d.device_kind == "cpu":
+      if d.id in cpu_device_map:
+        raise ValueError(
+            f"Multiple CPU devices with id {d.id} found:"
+            f" {cpu_device_map[d.id]} and {d}"
+        )
+      cpu_device_map[d.id] = d
+  if cpu_device_map:
+    return cpu_device_map
+
+  # Fall back to searching CPU devices in all backends.
+  for backend in xb.backends().values():
+    for d in backend._get_all_devices():  # pylint: disable=protected-access
       if d.device_kind == "cpu":
         if d.id in cpu_device_map:
           raise ValueError(
```
```diff
@@ -87,7 +101,7 @@ def make_device_list(device_ids: Sequence[int]) -> DeviceList:
     devices = np.vectorize(lambda device_id: cpu_device_map[device_id])(
         device_ids
     )
-    return DeviceList(devices)
+    return DeviceList(tuple(devices))
 
   device_ids = [d.id for d in device_list]
   return make_device_list, (device_ids,)
```
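The `tuple(devices)` fix is needed because `np.vectorize` returns an `np.ndarray` of device objects rather than a plain sequence. A standalone illustration, assuming `DeviceList` resolves to `xla_client.DeviceList` as the module's imports suggest:

```python
import numpy as np
import jax
from jax._src.lib import xla_client as xc

cpu_device_map = {d.id: d for d in jax.devices("cpu")}
device_ids = sorted(cpu_device_map)

devices = np.vectorize(lambda device_id: cpu_device_map[device_id])(device_ids)
print(type(devices))  # <class 'numpy.ndarray'> of device objects, not a tuple

device_list = xc.DeviceList(tuple(devices))  # tuple() makes it constructible
```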

Diff for: tests/BUILD

+1
```diff
@@ -1562,6 +1562,7 @@ exports_files(
     "api_test.py",
     "array_test.py",
     "cache_key_test.py",
+    "colocated_python_test.py",
     "compilation_cache_test.py",
     "memories_test.py",
     "pmap_test.py",
```

Diff for: tests/colocated_python_test.py

+32-13
```diff
@@ -34,15 +34,20 @@ def _colocated_cpu_devices(
     devices: Sequence[jax.Device],
 ) -> Sequence[jax.Device]:
   """Returns CPU devices colocated with the given devices."""
-  # TODO(hyeontaek): Use `colocated_python.colocated_cpu_devices(devices)` once
-  # PjRt-IFRT prepares CPU devices by its own.
-  cpu_backend_devices = jax.local_devices(backend="cpu")
-  device_index_map = {device.id: i for i, device in enumerate(jax.devices())}
+  try:
+    return colocated_python.colocated_cpu_devices(devices)
+  except (ValueError, AttributeError):
+    # PjRt-IFRT prepares CPU devices by its own.
+    # TODO(hyeontaek): Remove this fallback path once PjRt-IFRT prepares CPU
+    # devices by its own.
+    cpu_backend_devices = jax.local_devices(backend="cpu")
+    device_index_map = {device.id: i for i, device in enumerate(jax.devices())}
+
+    available_devices = devices[: min(len(cpu_backend_devices), len(devices))]
+    return [
+        cpu_backend_devices[device_index_map[d.id]] for d in available_devices
+    ]
 
-  available_devices = devices[:min(len(cpu_backend_devices), len(devices))]
-  return [
-      cpu_backend_devices[device_index_map[d.id]] for d in available_devices
-  ]
 
 
 @contextlib.contextmanager
 def _count_colocated_python_specialization_cache_miss() -> list[int]:
```
```diff
@@ -79,20 +84,20 @@ class ColocatedPythonTest(jtu.JaxTestCase):
 
   def setUp(self):
     super().setUp()
-    if xla_extension_version < 298:
-      self.skipTest("Requires xla_extension_version >= 298")
+    if xla_extension_version < 300:
+      self.skipTest("Requires xla_extension_version >= 300")
 
   def testMakeColocatedPythonProgram(self):
     def add_one(x):
       return x + 1
 
     cpu_devices = _colocated_cpu_devices(jax.local_devices())
     sharding = jax.sharding.SingleDeviceSharding(cpu_devices[0])
-    aval = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding)
+    sds = jax.ShapeDtypeStruct((), jnp.int32, sharding=sharding)
 
     pickled_function = serialization._serialize(add_one)
     program = ifrt_programs.make_colocated_python_program(
-        "add_one", pickled_function, [cpu_devices[0]], [aval], [aval]
+        "add_one", pickled_function, [cpu_devices[0]], [sds], [sds]
     )
     del program
```

```diff
@@ -107,10 +112,12 @@ def add_one(x):
 
     with _count_colocated_python_specialization_cache_miss() as count:
       out = add_one(x)
+      out = jax.device_get(out)
      self.assertEqual(out, np.array(2))
       self.assertEqual(count[0], 1)
 
       out = add_one(x)
+      out = jax.device_get(out)
       self.assertEqual(out, np.array(2))
       self.assertEqual(count[0], 1)
```
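The recurring `out = jax.device_get(out)` additions reflect that results now come back from the executable dispatch path as `jax.Array`s; `device_get` copies them to host numpy values so the `assertEqual` comparisons operate on concrete arrays regardless of which execution path produced them. In isolation:

```python
import jax
import numpy as np

out = jax.device_put(np.array(2), jax.devices()[0])  # a committed jax.Array
host_out = jax.device_get(out)                       # back to np.ndarray
np.testing.assert_array_equal(host_out, np.array(2))
```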

```diff
@@ -125,10 +132,12 @@ def add_one(x):
 
     with _count_colocated_python_specialization_cache_miss() as count:
       out = add_one(x)
+      out = jax.device_get(out)
       self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
       self.assertEqual(count[0], 1)
 
       out = add_one(x)
+      out = jax.device_get(out)
       self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
       self.assertEqual(count[0], 1)
```

```diff
@@ -154,10 +163,12 @@ def make_zero():
     with _count_colocated_python_specialization_cache_miss() as count:
       make_zero = make_zero.specialize(devices=cpu_devices[:1])
       out = make_zero()
+      out = jax.device_get(out)
       self.assertEqual(out, np.array(0))
       self.assertEqual(count[0], 1)
 
       out = make_zero()
+      out = jax.device_get(out)
       self.assertEqual(out, np.array(0))
       self.assertEqual(count[0], 1)
```

```diff
@@ -172,10 +183,12 @@ def add_one(x):
 
     with _count_colocated_python_specialization_cache_miss() as count:
       out = add_one(x)
+      out = jax.device_get(out)
       self.assertEqual(out, np.array(2))
       self.assertEqual(count[0], 1)
 
       out = add_one(x)
+      out = jax.device_get(out)
       self.assertEqual(out, np.array(2))
       self.assertEqual(count[0], 1)
```

```diff
@@ -184,10 +197,12 @@ def add_one(x):
       x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0]))
 
       out = add_one(x)
+      out = jax.device_get(out)
       self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
       self.assertEqual(count[0], 2)
 
       out = add_one(x)
+      out = jax.device_get(out)
       self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
       self.assertEqual(count[0], 2)
```

```diff
@@ -203,22 +218,26 @@ def add_one(x):
     with _count_colocated_python_specialization_cache_miss() as count:
       add_one = add_one.specialize(out_specs_fn=lambda x: x)
       out = add_one(x)
+      out = jax.device_get(out)
       self.assertEqual(out, np.array(2))
       self.assertEqual(count[0], 1)
 
       out = add_one(x)
+      out = jax.device_get(out)
       self.assertEqual(out, np.array(2))
       self.assertEqual(count[0], 1)
 
       # Different input tree structure and dtype/shape.
-      x = [np.array(1), (np.array(2), {"v": jnp.array(3)})]
+      x = [np.array(1), (np.array(2), {"v": np.array(3)})]
       x = jax.device_put(x, jax.sharding.SingleDeviceSharding(cpu_devices[0]))
 
       out = add_one(x)
+      out = jax.device_get(out)
       self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
       self.assertEqual(count[0], 2)
 
       out = add_one(x)
+      out = jax.device_get(out)
       self.assertEqual(out, [np.array(2), (np.array(3), {"v": np.array(4)})])
       self.assertEqual(count[0], 2)
```
