Skip to content

Commit

Permalink
Add a warning if the user calls os.fork().
Browse files Browse the repository at this point in the history
Fixes #18852
  • Loading branch information
hawkinsp committed Dec 15, 2023
1 parent 1559d64 commit 680399f
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@
)


# Warn the user if they call fork(), because it's not going to go well for them.
def _at_fork():
warnings.warn(
"os.fork() is incompatible with multithreaded code, "
"and JAX is multithreaded, so this will likely lead to a deadlock.",
RuntimeWarning, stacklevel=2)

os.register_at_fork(before=_at_fork)


# Backends


Expand Down

0 comments on commit 680399f

Please sign in to comment.