[FEAT] [ROCm] Add AITER int8 scaled gemm kernel #15433
```diff
@@ -4,13 +4,20 @@
 import torch

 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform

 from .cutlass import CutlassScaledMMLinearKernel
 from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig


+def is_rocm_aiter_gemm_w8a8_scaled_mm_enabled() -> bool:
+    return current_platform.is_rocm() \
+        and envs.VLLM_ROCM_USE_AITER_LINEAR \
+        and envs.VLLM_ROCM_USE_AITER


 class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):

     @classmethod
@@ -20,25 +27,20 @@ def get_min_capability(cls) -> int:
     @classmethod
     def can_implement(
             cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
-        if current_platform.is_cpu():
+        if current_platform.is_cpu() or not current_platform.is_rocm():
```
Suggested change:

```diff
-        if current_platform.is_cpu() or not current_platform.is_rocm():
+        if not current_platform.is_rocm():
```
Ok. I have removed the check for CPU.
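For illustration, the enablement check from the diff can be sketched as a standalone function. This is a minimal sketch, not vLLM's actual code: `_env_flag` and the `is_rocm` parameter are hypothetical stand-ins for vLLM's `envs` module and `current_platform`, while the two flag names are taken from the diff above.

```python
import os


def _env_flag(name: str) -> bool:
    # Hypothetical helper: treat "1"/"true" (case-insensitive) as enabled.
    return os.environ.get(name, "0").lower() in ("1", "true")


def aiter_scaled_mm_enabled(is_rocm: bool) -> bool:
    # Mirrors the PR's gate: the AITER kernel is used only on ROCm and only
    # when both environment flags are set.
    return (is_rocm
            and _env_flag("VLLM_ROCM_USE_AITER")
            and _env_flag("VLLM_ROCM_USE_AITER_LINEAR"))
```

Both flags must be set together, so the AITER path stays opt-in even on ROCm builds.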
Isn't this more accurate here?
Suggested change:

```diff
-        per_channel_tensor_scale_a = (x_s.numel() == m)
-        per_channel_tensor_scale_b = (w_s.numel() == n)
+        per_token_scale_a = (x_s.numel() == m)
+        per_channel_scale_b = (w_s.numel() == n)
```
Yes, you are right. I have made the amendments. Thank you so much.
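The naming distinction above can be sketched numerically: in an int8 scaled matmul, a per-token scale for A has one element per row (`numel == m`), while a per-channel scale for B has one element per output column (`numel == n`). A minimal NumPy sketch (names are illustrative, not vLLM's API):

```python
import numpy as np


def scaled_mm_int8(a_q, b_q, scale_a, scale_b):
    """a_q: (m, k) int8, b_q: (k, n) int8.
    scale_a: per-token scales, m elements; scale_b: per-channel scales, n elements."""
    m, n = a_q.shape[0], b_q.shape[1]
    assert scale_a.size == m  # per-token:   one scale per row of A
    assert scale_b.size == n  # per-channel: one scale per column of B
    # Accumulate in int32, then apply both scales to the float result.
    acc = a_q.astype(np.int32) @ b_q.astype(np.int32)
    return acc.astype(np.float32) * scale_a.reshape(m, 1) * scale_b.reshape(1, n)
```

Because each scale multiplies a whole row (or column) of the product, applying them after the int32 accumulation is equivalent to dequantizing the operands first.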
Just curious for future work: does this kernel support fp8?
Also, can you add a comment explaining why w_q needs to be transposed here? I assume it's because the weights are prepared by the Cutlass path, which stores them transposed, so here we restore the original layout?
ROCm/aiter does not support FP8 at this moment.
I have added the comment.
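The reviewer's reading can be illustrated with a toy sketch. The layouts and function names here are assumptions for illustration, not the actual Cutlass or AITER code: if a shared preparation step stores the weight transposed as (n, k), a kernel that expects (k, n) must transpose it back before the gemm.

```python
import numpy as np


def prepare_weight_cutlass_style(w):
    # Hypothetical prep step: w arrives as (k, n) and is stored transposed
    # as a contiguous (n, k) array, the layout the Cutlass path wants.
    return np.ascontiguousarray(w.T)


def gemm_expecting_kn(x, w_q_stored):
    # x: (m, k); w_q_stored: (n, k) as produced by the prep step above.
    # Restore the (k, n) layout before multiplying.
    w_kn = w_q_stored.T
    return x @ w_kn
```

The transpose is a view, so restoring the layout costs no copy; only the logical shape changes.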
This method should just be inlined at its sole call site (unless I'm missing another use).
Resolved.