From ec89e5e4c54583fd16a5bc4d6af5e5b5af8f063e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 14 Dec 2023 18:02:13 -0500 Subject: [PATCH] Add a warning if the user calls os.fork(). Fixes https://github.com/google/jax/issues/18852 --- jax/_src/xla_bridge.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 7977f6329531..d02ec8e648cf 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -106,6 +106,16 @@ ) +# Warn the user if they call fork(), because it's not going to go well for them. +def _at_fork(): + warnings.warn( + "os.fork() was called. os.fork() is incompatible with multithreaded code, " + "and JAX is multithreaded, so this will likely lead to a deadlock.", + RuntimeWarning, stacklevel=2) + +os.register_at_fork(before=_at_fork) + + # Backends