Skip to content

Commit

Permalink
Merge pull request #20734 from hawkinsp:atfork
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625038840
  • Loading branch information
jax authors committed Apr 15, 2024
2 parents 5f22b12 + fc128b3 commit b9a853d
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,7 @@ def _at_fork():
"and JAX is multithreaded, so this will likely lead to a deadlock.",
RuntimeWarning, stacklevel=2)

# os.register_at_fork only exists on Unix.
if hasattr(os, "register_at_fork"):
os.register_at_fork(before=_at_fork)

_at_fork_handler_installed = False

# Backends

Expand Down Expand Up @@ -809,12 +806,19 @@ def backends() -> dict[str, xla_client.Client]:
global _backends
global _backend_errors
global _default_backend
global _at_fork_handler_installed

_discover_and_register_pjrt_plugins()

with _backend_lock:
if _backends:
return _backends

# os.register_at_fork only exists on Unix.
if not _at_fork_handler_installed and hasattr(os, "register_at_fork"):
os.register_at_fork(before=_at_fork)
_at_fork_handler_installed = True

if jax_platforms := config.jax_platforms.value:
platforms = []
# Allow platform aliases in the list of platforms.
Expand Down

0 comments on commit b9a853d

Please sign in to comment.