Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Install fork() warning during backend initialization, rather than jax… #20734

Merged
merged 1 commit into from
Apr 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,7 @@ def _at_fork():
"and JAX is multithreaded, so this will likely lead to a deadlock.",
RuntimeWarning, stacklevel=2)

# os.register_at_fork only exists on Unix.
if hasattr(os, "register_at_fork"):
os.register_at_fork(before=_at_fork)

_at_fork_handler_installed = False

# Backends

Expand Down Expand Up @@ -824,12 +821,19 @@ def backends() -> dict[str, xla_client.Client]:
global _backends
global _backend_errors
global _default_backend
global _at_fork_handler_installed

_discover_and_register_pjrt_plugins()

with _backend_lock:
if _backends:
return _backends

# os.register_at_fork only exists on Unix.
if not _at_fork_handler_installed and hasattr(os, "register_at_fork"):
os.register_at_fork(before=_at_fork)
_at_fork_handler_installed = True

if jax_platforms := config.jax_platforms.value:
platforms = []
# Allow platform aliases in the list of platforms.
Expand Down