diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 55b9ef2c7183..48ef52c8b552 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -114,10 +114,7 @@ def _at_fork(): "and JAX is multithreaded, so this will likely lead to a deadlock.", RuntimeWarning, stacklevel=2) -# os.register_at_fork only exists on Unix. -if hasattr(os, "register_at_fork"): - os.register_at_fork(before=_at_fork) - +_at_fork_handler_installed = False # Backends @@ -824,12 +821,19 @@ def backends() -> dict[str, xla_client.Client]: global _backends global _backend_errors global _default_backend + global _at_fork_handler_installed _discover_and_register_pjrt_plugins() with _backend_lock: if _backends: return _backends + + # os.register_at_fork only exists on Unix. + if not _at_fork_handler_installed and hasattr(os, "register_at_fork"): + os.register_at_fork(before=_at_fork) + _at_fork_handler_installed = True + if jax_platforms := config.jax_platforms.value: platforms = [] # Allow platform aliases in the list of platforms.