Skip to content

Commit

Permalink
Merge pull request #18989 from hawkinsp:fork
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 591286756
  • Loading branch information
jax authors committed Dec 15, 2023
2 parents c8b3567 + ec89e5e commit eb595af
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@
)


# Warn the user if they call fork(), because it's not going to go well for them.
def _at_fork():
warnings.warn(
"os.fork() was called. os.fork() is incompatible with multithreaded code, "
"and JAX is multithreaded, so this will likely lead to a deadlock.",
RuntimeWarning, stacklevel=2)

os.register_at_fork(before=_at_fork)


# Backends


Expand Down

0 comments on commit eb595af

Please sign in to comment.