[ROCm][CI] Fix ROCm attention backend validation for head sizes, block sizes, and compute capability checks #36292
Changes from all commits:

- dbe3dda
- d5d3aac
- 074ddb6
- c8c7324
- 69e80b0
```diff
@@ -29,6 +29,12 @@ class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
         return [MultipleOf(16)]

+    @classmethod
+    def supports_block_size(cls, block_size: int | None) -> bool:
+        if block_size is None:
+            return True
+        return block_size % 16 == 0
```
Comment on lines +32 to +37

Collaborator (Author): @gshtras This is also what we had before the two PRs landed, correct?

Contributor: You don't need this,

Collaborator (Author): Some tests were failing before I added this.

Collaborator (Author): Without these, for example,

Collaborator (Author): Repro commands:
```diff
+    @classmethod
+    def supports_head_size(cls, head_size: int) -> bool:
+        return head_size >= 32
```
```diff
@@ -188,6 +188,12 @@ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
     # uses our optimized kernel logic.
     return [16, 32, 544]

+    @classmethod
+    def supports_block_size(cls, block_size: int | None) -> bool:
+        if block_size is None:
+            return True
+        return block_size in (16, 32, 544)
```
Comment on lines +191 to +196

Collaborator (Author): @gshtras Do we need to add more here?

Contributor: Shouldn't need this,

Collaborator (Author): Responding in the comment below on this. I believe that these definitions are necessary, as there are failures if they are deleted.

Collaborator: This PR is going to fix the block size issue of Qwen3.5: #35923

Collaborator (Author): @tjtanaa I see that both attempt to address this, indeed. Do you need me to remove my patch here, or shall we close the other one?

Contributor: @AndreasKaratzas thanks!

Collaborator (Author): @JartX Missed that part. Let me see if I can just merge your commit.

Collaborator (Author): Oh, you did that already 😅
```diff
+    @classmethod
+    def get_supported_head_sizes(cls) -> list[int]:
+        return [32, 64, 80, 96, 128, 160, 192, 224, 256]
```
```diff
@@ -273,6 +273,12 @@ class TritonAttentionBackend(AttentionBackend):
     def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
         return [MultipleOf(16)]

+    @classmethod
+    def supports_block_size(cls, block_size: int | None) -> bool:
+        if block_size is None:
+            return True
+        return block_size % 16 == 0
```
Comment on lines +276 to +281

Collaborator (Author): @gshtras (Same question I had for ROCm AITER Unified Attn): This is also what we had before the two PRs landed, correct?

Contributor: Same comment

Collaborator (Author): Without these, for example,

Collaborator: Did we triage why?

Collaborator (Author): There is a "block size not supported" message, so the straightforward solution was to define the supported block size set.
```diff
     forward_includes_kv_cache_update: bool = False

     @staticmethod
```
Reviewer: Should we do this, or override `supports_head_size` and always return True instead?

Collaborator (Author): @Rohan138 I think that having the proper supported head sizes in each attention backend is the way to go. Besides, regarding AITER, there may be some unsupported head sizes, if I am not mistaken. It's just that we have not precisely identified them, right?
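The two options weighed in this thread can be contrasted with a short sketch. The class names below are illustrative only, not taken from vLLM.

```python
# Illustrative contrast of the two review options; class names are
# hypothetical and not from vLLM.

class EnumeratingBackendSketch:
    """Option 1 (what this PR does): declare the exact head sizes the
    backend's kernels are known to support."""

    @classmethod
    def get_supported_head_sizes(cls) -> list:
        return [32, 64, 80, 96, 128, 160, 192, 224, 256]

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        return head_size in cls.get_supported_head_sizes()


class PermissiveBackendSketch:
    """Option 2 (the reviewer's alternative): accept any head size and
    rely on the kernel to reject unsupported ones at run time."""

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        return True
```

The author's argument for option 1 is that AITER may have unsupported head sizes that simply have not been identified yet, so an explicit allowlist fails early instead of at kernel launch.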