Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a warning if the user calls os.fork(). #18989

Merged
merged 1 commit into from
Dec 15, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@
)


# Warn the user if they call fork(), because it's not going to go well for them.
def _at_fork():
warnings.warn(
"os.fork() was called. os.fork() is incompatible with multithreaded code, "
"and JAX is multithreaded, so this will likely lead to a deadlock.",
RuntimeWarning, stacklevel=2)
Comment on lines +109 to +114

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm seeing this warning even when not directly using JAX. Could this warning be more localized somehow?

Minimal repro: https://colab.research.google.com/gist/ScottTodd/476cfa0dd6620511ace15ce321b03a4e/colab-fork-jax-warning.ipynb

import os
import transformers
pid = os.fork()

RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  pid = os.fork()

Specific callstacks into os.fork() and this code from my project: https://gist.github.com/ScottTodd/981f1c8696887e0f9aef7fa57adbe84b

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Threading and os.fork() do not mix. You're playing with fire combining os.fork() and pytorch, in my opinion.

It's possible we could delay adding this check until you actually initialize jax, I suppose.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the main confusing part of this is that even when not using JAX (e.g., using PyTorch and it calls fork itself) one gets a warning about JAX. Delaying would at least avoid that happening if not using JAX at all, but seems would still trigger if one used both these libraries together.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax is a pretty heavy import if you're not using it. Maybe the best fix is to ensure that transformers doesn't import jax unless it needs to?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#20734 defers installing the warning until JAX initializes its internals. If you simply import but do not use jax, this should silence the warning.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fast responses! This warning did help us uncover a few issues across various packages, at least :)


os.register_at_fork(before=_at_fork)


# Backends


Expand Down