-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a warning if the user calls os.fork(). #18989
Conversation
# Warn the user if they call fork(), because it's not going to go well for them. | ||
def _at_fork(): | ||
warnings.warn( | ||
"os.fork() was called. os.fork() is incompatible with multithreaded code, " | ||
"and JAX is multithreaded, so this will likely lead to a deadlock.", | ||
RuntimeWarning, stacklevel=2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm seeing this warning even when not directly using JAX. Could this warning be more localized somehow?
Minimal repro: https://colab.research.google.com/gist/ScottTodd/476cfa0dd6620511ace15ce321b03a4e/colab-fork-jax-warning.ipynb
import os
import transformers
pid = os.fork()
RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
pid = os.fork()
Specific callstacks into os.fork() and this code from my project: https://gist.github.com/ScottTodd/981f1c8696887e0f9aef7fa57adbe84b
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Threading and os.fork()
do not mix. You're playing with fire combining os.fork()
and pytorch, in my opinion.
It's possible we could delay adding this check until you actually initialize jax
, I suppose.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the main confusing part of this is that even when not using JAX (e.g., using PyTorch and it calls fork itself) one gets a warning about JAX. Delaying would at least avoid that happening if not using JAX at all, but seems would still trigger if one used both these libraries together.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jax
is a pretty heavy import if you're not using it. Maybe the best fix is to ensure that transformers
doesn't import jax
unless it needs to?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#20734 defers installing the warning until JAX initializes its internals. If you simply import but do not use jax
, this should silence the warning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fast responses! This warning did help us uncover a few issues across various packages, at least :)
Fixes #18852