From fc128b374d6e8784da5fac901d3c7881dd50e073 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 12 Apr 2024 13:44:04 -0400 Subject: [PATCH] Install fork() warning during backend initialization, rather than jax import. This avoids warning people making an incidental import of JAX. --- jax/_src/xla_bridge.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 55b9ef2c7183..48ef52c8b552 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -114,10 +114,7 @@ def _at_fork(): "and JAX is multithreaded, so this will likely lead to a deadlock.", RuntimeWarning, stacklevel=2) -# os.register_at_fork only exists on Unix. -if hasattr(os, "register_at_fork"): - os.register_at_fork(before=_at_fork) - +_at_fork_handler_installed = False # Backends @@ -824,12 +821,19 @@ def backends() -> dict[str, xla_client.Client]: global _backends global _backend_errors global _default_backend + global _at_fork_handler_installed _discover_and_register_pjrt_plugins() with _backend_lock: if _backends: return _backends + + # os.register_at_fork only exists on Unix. + if not _at_fork_handler_installed and hasattr(os, "register_at_fork"): + os.register_at_fork(before=_at_fork) + _at_fork_handler_installed = True + if jax_platforms := config.jax_platforms.value: platforms = [] # Allow platform aliases in the list of platforms.