[qwen2-vl] fix FA2 inference #39121
Merged
Commits (9)
- 4301faa fix FA2 (zucchini-nlp)
- 2cb7f21 update is causal flag and remove mask for FA2 (zucchini-nlp)
- 7823ffc update for FA2 with varlen path (zucchini-nlp)
- 5640fe1 Merge remote-tracking branch 'upstream/main' into qwen2-fix (zucchini-nlp)
- ebb46c3 how the tests were passing with different devices? (zucchini-nlp)
- f2248b7 add comment and ref to the PR (zucchini-nlp)
- c567fe6 move mask preparation to base pretrained model (zucchini-nlp)
- 3a3afe3 seq len is the first dim, not second (zucchini-nlp)
- 0e1d686 fix copies to fix GLM4V (zucchini-nlp)
Conversations
I'm a bit unsure about this since we would allow `cu_seq` and `max_seq` only, but on most models we also have RoPE, so it silently breaks those models if we eff up and don't pass `position_ids` (the RoPE positions are bound to `position_ids` as well). We should imo at least add a warning when only the varlen kwargs are passed, to give some discretion here.

On another note, what do the integration tests use? Are they still working as expected 👀 seems a bit sus
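To make the RoPE concern concrete, here is a minimal standalone sketch (my own illustration, not code from this PR): with a packed batch, RoPE needs positions that restart at 0 for each packed sequence, so passing only `cu_seq`/`max_seq` and letting positions default to a plain `arange` gives the second and later sequences wrong rotations.

```python
import torch

# Hypothetical packed batch: two sequences of lengths 3 and 2, flattened into one row.
cu_seqlens = torch.tensor([0, 3, 5])
total_len = 5

# Positions RoPE should see: they restart at 0 for every packed sequence.
starts, ends = cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()
packed_positions = torch.cat([torch.arange(e - s) for s, e in zip(starts, ends)])

# Positions a model would fall back to if no position_ids are passed.
default_positions = torch.arange(total_len)

print(packed_positions)   # tensor([0, 1, 2, 0, 1])
print(default_positions)  # tensor([0, 1, 2, 3, 4]) -> wrong rotations for the second sequence
```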
The integration tests always use `position_ids`, which is the first case, `is_fa2_with_position_ids`; I just copied it from existing code and moved it up here. The second case is added for Qwen only; afaik no other model passes pre-computed `cu_lens` for attention layers.

In Qwen we don't need any position ids, because they are 3D and won't help at all in inferring `cu_lens`.
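For context, a minimal sketch of how `cu_lens` can be inferred from flattened 1D `position_ids` (my own illustration, not the PR's implementation): sequence boundaries are wherever the position counter resets to 0, which is the information that, per the comment above, Qwen's 3D positions don't provide.

```python
import torch

# Hypothetical flattened position_ids for three packed sequences of lengths 3, 2 and 4.
position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])

# A sequence starts wherever the position counter resets to 0.
starts = torch.nonzero(position_ids == 0).flatten()
ends = torch.cat([starts[1:], torch.tensor([position_ids.numel()])])
lengths = ends - starts

# cu_lens / cu_seqlens: cumulative sequence lengths, as flash-attn's varlen kernels expect.
cu_lens = torch.cat([torch.zeros(1, dtype=torch.long), lengths.cumsum(0)])
print(cu_lens)  # tensor([0, 3, 5, 9])
```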
I meant the general integration tests, e.g. `transformers/tests/models/qwen2_vl/test_modeling_qwen2_vl.py` (line 507 in d53518c).

As for why: my concern about the second case on Qwen is future model additions and general usability, not the validity of Qwen. For developers, `is_fa2_with_varlen_kwargs` suggests that these kwargs suffice for varlen, whereas before we (unintentionally) checked for the existence of the (correctly flattened) position ids that RoPE models need when using varlen. Maybe #35941 helps as a reference for what I mean.

Imo, it would help to add at least a comment that for varlen most models need correctly flattened position ids (from e.g. a collator), especially RoPE models, which make up the majority of newer models.
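To illustrate the "from e.g. a collator" point, here is a hypothetical collator step (the function name and return shape are made up for illustration, not an existing transformers API) that packs examples into one row and emits the restarting `position_ids` that RoPE models need on the varlen path.

```python
import torch

def pack_examples(sequences):
    """Hypothetical collator step: concatenate examples into one packed row and
    build the per-sequence restarting position_ids that RoPE models need for varlen."""
    input_ids = torch.cat(sequences).unsqueeze(0)
    position_ids = torch.cat([torch.arange(len(seq)) for seq in sequences]).unsqueeze(0)
    return {"input_ids": input_ids, "position_ids": position_ids}

batch = pack_examples([torch.tensor([11, 12, 13]), torch.tensor([21, 22])])
print(batch["position_ids"])  # tensor([[0, 1, 2, 0, 1]])
```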
Hmm, am I right that users might be passing only `cu_lens` without correct `position_ids`? I believe it would be the users' responsibility to make sure RoPE is applied correctly, but I will add a comment in the code explaining it, sure.

In the slow integration tests we don't pass `position_ids`, not that I know of. For most LLMs the FA2-path integration tests fall back to inferring `cu_lens` from the mask, and in Qwen the `position_ids` are constructed on the fly during the forward call. The model requires adding rope deltas on top of the 3D positions, and I don't think users would be doing all that manually.
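For reference, the mask fallback mentioned above amounts to something like this simplified sketch (a stand-in, not the actual implementation): per-sequence lengths come from the attention mask, and their cumulative sum gives `cu_lens`.

```python
import torch

# A padded batch of two sequences with true lengths 4 and 2 (1 = token, 0 = padding).
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

seqlens = attention_mask.sum(dim=-1)
cu_lens = torch.cat([torch.zeros(1, dtype=seqlens.dtype), seqlens.cumsum(0)])
print(cu_lens)  # tensor([0, 4, 6])
```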
"Hmm, am I right that users might be passing only cu_lens without correct position_ids?" - Yes, not only users but possibly us as well because it's something that's harder to figure out when done wrong imo :D