diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 7977f6329531..cf63ed9d69a7 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -106,6 +106,16 @@ ) +# Warn the user if they call fork(), because it's not going to go well for them. +def _at_fork(): + warnings.warn( + "os.fork() is incompatible with multithreaded code, " + "and JAX is multithreaded, so this will likely lead to a deadlock.", + RuntimeWarning, stacklevel=2) + +os.register_at_fork(before=_at_fork) + + # Backends