Skip to content

Commit

Permalink
Skip broadcast_one_to_all for single-process JAX execution
Browse files Browse the repository at this point in the history
There is no need to actually perform exchange across processes if there's only one process.

PiperOrigin-RevId: 646309814
  • Loading branch information
junwhanahn authored and jax authors committed Jun 25, 2024
1 parent 94c5d0d commit 817eb7a
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions jax/experimental/multihost_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ def broadcast_one_to_all(in_tree: Any, is_source: bool | None = None) -> Any:
A pytree matching in_tree where the leaves now all contain the data from the
first host.
"""
if jax.process_count() == 1:
# Note: This may return results that are different from the multi-host case
# below since it does not force-convert inputs to numpy arrays. We don't do
# such conversion here (and the API contract does not promise such a
# requirement) because doing so could be expensive for single-controller
# runtimes with lots of addressable devices.
return in_tree

if is_source is None:
is_source = jax.process_index() == 0

Expand Down

0 comments on commit 817eb7a

Please sign in to comment.