diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 7977f6329531..d02ec8e648cf 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -106,6 +106,16 @@ ) +# Warn the user if they call fork(), because it's not going to go well for them. +def _at_fork(): + warnings.warn( + "os.fork() was called. os.fork() is incompatible with multithreaded code, " + "and JAX is multithreaded, so this will likely lead to a deadlock.", + RuntimeWarning, stacklevel=2) + +os.register_at_fork(before=_at_fork) + + # Backends