[ROCm][Deepseek] dsv3.2 further optimization #41217
Changes from all commits: 36bd605, 4df6432, 999ca3a, 3a6d3f5, f77e120
```diff
@@ -396,6 +396,7 @@ class AiterMLAHelper:
     """

     _AITER_MIN_MLA_HEADS: Final = 16
+    _AITER_UNSUPPORTED_HEADS = [32]

     @staticmethod
     def check_num_heads_validity(num_heads: int):
@@ -419,6 +420,9 @@ def get_actual_mla_num_heads(num_heads: int) -> int:

     @staticmethod
     def get_mla_padded_q(num_heads: int, q: torch.Tensor) -> torch.Tensor:
+        assert num_heads not in AiterMLAHelper._AITER_UNSUPPORTED_HEADS, (
+            f"unsupported head_num: {num_heads}"
+        )
```
Comment on lines +423 to +425 (Contributor): The assertion
```diff
         return (
             q
             if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
```
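The diff truncates before the padding branch of `get_mla_padded_q`. Below is a minimal sketch of what such a helper plausibly does, assuming `q` has shape `[num_tokens, num_heads, head_dim]` and that head counts below 16 are zero-padded up to the kernel's minimum; the function name and the 16-head constant suggest this, but the actual vLLM/aiter implementation may differ.

```python
import torch

_AITER_MIN_MLA_HEADS = 16
_AITER_UNSUPPORTED_HEADS = [32]


def get_mla_padded_q_sketch(num_heads: int, q: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation for illustration only.

    Assumes q has shape [num_tokens, num_heads, head_dim] and that the
    aiter MLA kernel requires at least 16 query heads.
    """
    # Fail fast on head counts known to dispatch to the wrong kernel.
    assert num_heads not in _AITER_UNSUPPORTED_HEADS, (
        f"unsupported head_num: {num_heads}"
    )
    if num_heads >= _AITER_MIN_MLA_HEADS:
        return q
    # Assumption: zero-pad the head dimension up to the kernel's minimum;
    # the padded heads would be sliced away again after the kernel runs.
    pad_heads = _AITER_MIN_MLA_HEADS - num_heads
    pad = q.new_zeros(q.shape[0], pad_heads, q.shape[-1])
    return torch.cat([q, pad], dim=1)
```

With the new assert in place, an unsupported per-rank head count such as 32 fails loudly at this call site instead of silently dispatching to a mismatched kernel.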
Comment:

Is it just head size 32 that has the issue? So can I understand it as: head size 16 is fine; head size 32 is NOT supported; head size 64 is fine; head size 128 is fine as well; etc.?

Reply:

Looks like so. I found that when head_size equals 32, aiter automatically picks the 16-head implementation (as the symbol name suggests), which causes an illegal memory access. I'm not sure whether other shapes have the same issue; I only tested the TP8 and TP4 cases.
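For context, the tested configurations map onto the head counts discussed above if we assume DeepSeek-V3's 128 total query heads (an assumption not stated in the thread): TP8 gives 16 heads per rank, which works, while TP4 gives 32 heads per rank, the failing case the new assert rejects. A quick sketch:

```python
# Hypothetical illustration: per-rank MLA query-head counts under different
# tensor-parallel degrees, assuming 128 total query heads (DeepSeek-V3).
TOTAL_QUERY_HEADS = 128  # assumption, not stated in the thread
_AITER_UNSUPPORTED_HEADS = [32]  # from the diff above

for tp in (8, 4, 2, 1):
    heads_per_rank = TOTAL_QUERY_HEADS // tp
    ok = heads_per_rank not in _AITER_UNSUPPORTED_HEADS
    print(f"TP{tp}: {heads_per_rank} heads/rank -> "
          f"{'passes the assert' if ok else 'rejected by the new assert'}")
# TP8 -> 16 heads/rank (tested, works); TP4 -> 32 heads/rank (tested,
# previously caused an illegal memory access, now rejected up front).
```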