Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm seeing this warning even when not directly using JAX. Could this warning be more localized somehow?
Minimal repro: https://colab.research.google.com/gist/ScottTodd/476cfa0dd6620511ace15ce321b03a4e/colab-fork-jax-warning.ipynb
Specific callstacks into os.fork() and this code from my project: https://gist.github.com/ScottTodd/981f1c8696887e0f9aef7fa57adbe84b
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Threading and
os.fork()
do not mix. You're playing with fire combiningos.fork()
and pytorch, in my opinion.It's possible we could delay adding this check until you actually initialize
jax
, I suppose.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the main confusing part of this is that even when not using JAX (e.g., using PyTorch and it calls fork itself) one gets a warning about JAX. Delaying would at least avoid that happening if not using JAX at all, but seems would still trigger if one used both these libraries together.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jax
is a pretty heavy import if you're not using it. Maybe the best fix is to ensure thattransformers
doesn't importjax
unless it needs to?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pytorch/pytorch#123954
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#20734 defers installing the warning until JAX initializes its internals. If you simply import but do not use
jax
, this should silence the warning.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fast responses! This warning did help us uncover a few issues across various packages, at least :)