Add a warning if the user calls os.fork(). #18989

Merged: 1 commit, Dec 15, 2023

hawkinsp
Collaborator

Fixes #18852

jax/_src/xla_bridge.py
Comment on lines +109 to +114
# Warn the user if they call fork(), because it's not going to go well for them.
def _at_fork():
  warnings.warn(
    "os.fork() was called. os.fork() is incompatible with multithreaded code, "
    "and JAX is multithreaded, so this will likely lead to a deadlock.",
    RuntimeWarning, stacklevel=2)

I'm seeing this warning even when not directly using JAX. Could this warning be more localized somehow?

Minimal repro: https://colab.research.google.com/gist/ScottTodd/476cfa0dd6620511ace15ce321b03a4e/colab-fork-jax-warning.ipynb

import os
import transformers
pid = os.fork()

RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  pid = os.fork()

Specific callstacks into os.fork() and this code from my project: https://gist.github.com/ScottTodd/981f1c8696887e0f9aef7fa57adbe84b

Collaborator Author

Threading and os.fork() do not mix. You're playing with fire combining os.fork() and pytorch, in my opinion.

It's possible we could delay adding this check until you actually initialize jax, I suppose.

I think the main confusing part of this is that even when not using JAX (e.g., using PyTorch, which calls fork itself), one gets a warning about JAX. Delaying the check would at least avoid the warning when JAX is not used at all, but it seems it would still trigger if one used both of these libraries together.

Collaborator

jax is a pretty heavy import if you're not using it. Maybe the best fix is to ensure that transformers doesn't import jax unless it needs to?
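
One common way to avoid importing a heavy dependency unless it is needed is a function-local import. The sketch below is purely illustrative: `to_array` and its `backend` parameter are hypothetical names, not taken from transformers or JAX.

```python
def to_array(values, backend="python"):
    """Hypothetical dispatch helper; only the 'jax' path pays the import cost."""
    if backend == "jax":
        # Deferred import: jax is loaded the first time this branch runs,
        # not when the enclosing module is imported.
        import jax.numpy as jnp
        return jnp.asarray(values)
    return list(values)
```

With this layout, importing the module that defines `to_array` does not import jax, so any import-time side effects of jax are avoided for callers that never request that backend.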

Collaborator Author

#20734 defers installing the warning until JAX initializes its internals. If you simply import but do not use jax, this should silence the warning.
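
The deferral described here can be sketched as follows. This is not the code from #20734; `initialize_backend`, `install_fork_warning_once`, and `fork_warning_installed` are hypothetical names used only to illustrate registering the hook on first real use instead of at import time.

```python
import os
import threading
import warnings

_install_lock = threading.Lock()
_installed = False


def _at_fork():
    warnings.warn(
        "os.fork() was called. os.fork() is incompatible with multithreaded code, "
        "and JAX is multithreaded, so this will likely lead to a deadlock.",
        RuntimeWarning, stacklevel=2)


def fork_warning_installed():
    """Reports whether the one-time installation step has already run."""
    return _installed


def install_fork_warning_once():
    """Registers the hook at most once, even if called from several threads."""
    global _installed
    with _install_lock:
        if _installed:
            return
        # os.register_at_fork only exists on Unix-like platforms.
        if hasattr(os, "register_at_fork"):
            os.register_at_fork(before=_at_fork)
        _installed = True


def initialize_backend():
    """Hypothetical stand-in for the library's first real initialization."""
    install_fork_warning_once()
    # Thread-starting setup would follow here.
```

Merely importing a module laid out like this registers nothing; the hook is installed only once `initialize_backend()` has been called.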

Thanks for the fast responses! This warning did help us uncover a few issues across various packages, at least :)

Labels: pull ready (Ready for copybara import and testing)

Successfully merging this pull request may close these issues:

JAX hangs when allocating array of at least 2^19 bytes on CPU in subprocess