[feat] LlavaNext add feature size check to avoid CUDA Runtime Error #33608
```diff
@@ -518,6 +518,12 @@ def forward(
+            # TODO: @raushan retain only the new behavior after v4.47
+            else:
+                n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
+                n_image_features = image_features.shape[1]
+                if n_image_tokens != n_image_features:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )
```
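The logic of the proposed check can be sketched in isolation (a simplified, torch-free illustration; `IMAGE_TOKEN_INDEX` and the function name are hypothetical, not part of the PR):

```python
# Sketch of the proposed validation: count image placeholder tokens in the
# prompt and compare against the number of image feature vectors. Without
# this check, a mismatch only surfaces later as an opaque CUDA runtime
# error when features are scattered into the embedding tensor.

IMAGE_TOKEN_INDEX = 32000  # hypothetical placeholder token id


def check_image_feature_sizes(input_ids, n_image_features):
    """Raise a readable error when token and feature counts disagree."""
    n_image_tokens = sum(1 for tok in input_ids if tok == IMAGE_TOKEN_INDEX)
    if n_image_tokens != n_image_features:
        raise ValueError(
            f"Image features and image tokens do not match: "
            f"tokens: {n_image_tokens}, features {n_image_features}"
        )
```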
Comment on lines +521 to +526:
Collaborator
I don't know why we are adding this here, as the processor is supposed to check this for the non-legacy path!
Member
Yes, it is supposed to. There was only one edge case with llava-next, which uses the pad/unpad technique; since we used tensors in modeling, there were minor numerical inconsistencies. Right now it should work, but in general IMO it's a good idea to help users pinpoint what went wrong in their code.
Collaborator
Not in the forward pass IMO: we are adding extra processing (the `.sum()` and `.item()` seen above) that runs on every single forward pass. The biggest issue for me is the duplicated work!
```diff
         special_image_mask = (
             (input_ids == self.config.image_token_index)
             .unsqueeze(-1)
```
Mmmm, in general I don't mind, as this should help our users, but the `.item()` might break compile compatibility (well, only full graph). @McPatate that's where and when we would need to see how much we are losing from this small change! 🤗 (FYI @LysandreJik)
We can always wrap these in `is_torchdynamo_compiling`, the same way we now wrap all warnings/logging in the generation code. So we ask users to make sure the code works without compilation (to see all warnings etc.), and then compile the code, which will not show the exact reason why/where this CUDA-side error was triggered.
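The wrapping suggested here could look roughly like the following (a sketch; `is_compiling` stands in for `transformers.utils.is_torchdynamo_compiling`, and the function names are illustrative, not the PR's actual code):

```python
def is_compiling():
    # Stand-in for transformers.utils.is_torchdynamo_compiling. When tracing
    # with torch.compile(fullgraph=True), data-dependent .item() calls and
    # raised exceptions would otherwise break the graph.
    return False


def validate_image_features(n_image_tokens, n_image_features):
    # Run the eager-only sanity check, but skip it while compiling, the same
    # way warnings/logging are skipped in the generation code.
    if not is_compiling() and n_image_tokens != n_image_features:
        raise ValueError(
            f"Image features and image tokens do not match: "
            f"tokens: {n_image_tokens}, features {n_image_features}"
        )
```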
Okay, that makes sense. Just 🥶 to more checks, but this one is most probably cached, so it should be alright.
The thing is, these `is_compiling` checks are unrelated to normal users -> we expose them to unrelated code.
I see what you mean. Yes, the processor should maybe check this, but we cannot perform any checks before getting the image hidden states. My main idea was to bring back the same check we had earlier in the `merge_inputs` method, so that after moving to the new logic we can still easily trace down bugs related to shape mismatches, or let users track them down.

Also, we won't do the `.sum()` and `.item()` on every forward; for generation it runs only at the prefill stage, after which we'll have the image states in the cache. But anyway, if you think this is too many checks (given we now support both the old and new logic in VLMs for a few minor releases), I am okay with not adding it. I don't see it as a major blocker or anything, more like a nice addition for users :D
Okay let's add it then 🤗