-
-
Notifications
You must be signed in to change notification settings - Fork 11.5k
[V1] [ROCm] [AITER] Upgrade AITER to commit 916bf3c and bugfix APIs
#20880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,11 +8,55 @@ | |||||||||||
| import vllm.envs as envs | ||||||||||||
| from vllm import _custom_ops as ops | ||||||||||||
| from vllm.platforms import current_platform | ||||||||||||
| from vllm.utils import direct_register_custom_op | ||||||||||||
|
|
||||||||||||
| from .cutlass import CutlassScaledMMLinearKernel | ||||||||||||
| from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def rocm_aiter_gemm_w8a8_impl( | ||||||||||||
| A: torch.Tensor, | ||||||||||||
| B: torch.Tensor, | ||||||||||||
| As: torch.Tensor, | ||||||||||||
| Bs: torch.Tensor, | ||||||||||||
| bias: Optional[torch.Tensor] = None, | ||||||||||||
| output_dtype: torch.dtype = torch.float16, | ||||||||||||
| ) -> torch.Tensor: | ||||||||||||
|
|
||||||||||||
| from aiter import gemm_a8w8_CK | ||||||||||||
|
|
||||||||||||
| # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects | ||||||||||||
| # a to be [M, K] | ||||||||||||
| # b to be [N, K] | ||||||||||||
| # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format | ||||||||||||
| return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def rocm_aiter_gemm_w8a8_fake( | ||||||||||||
| A: torch.Tensor, | ||||||||||||
| B: torch.Tensor, | ||||||||||||
| As: torch.Tensor, | ||||||||||||
| Bs: torch.Tensor, | ||||||||||||
| bias: Optional[torch.Tensor] = None, | ||||||||||||
| output_dtype: torch.dtype = torch.float16, | ||||||||||||
| ) -> torch.Tensor: | ||||||||||||
|
|
||||||||||||
| m = A.shape[0] | ||||||||||||
| n = B.shape[0] | ||||||||||||
| Y = torch.empty(m, n, dtype=output_dtype, device=A.device) | ||||||||||||
| return Y | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| if current_platform.is_rocm(): | ||||||||||||
| direct_register_custom_op( | ||||||||||||
| op_name="rocm_aiter_gemm_w8a8", | ||||||||||||
| op_func=rocm_aiter_gemm_w8a8_impl, | ||||||||||||
| mutates_args=[], | ||||||||||||
| fake_impl=rocm_aiter_gemm_w8a8_fake, | ||||||||||||
| dispatch_key=current_platform.dispatch_key, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel): | ||||||||||||
|
|
||||||||||||
| @classmethod | ||||||||||||
|
|
@@ -111,10 +155,9 @@ def apply_weights(self, | |||||||||||
| " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " + | ||||||||||||
| "does not support AITER block scaled GEMM.") | ||||||||||||
|
|
||||||||||||
| from aiter import gemm_a8w8_CK | ||||||||||||
|
|
||||||||||||
| # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects | ||||||||||||
| # a to be [M, K] | ||||||||||||
| # b to be [N, K] | ||||||||||||
| # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format | ||||||||||||
| return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype) | ||||||||||||
| return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s, | ||||||||||||
| bias, out_dtype) | ||||||||||||
|
Comment on lines
+162
to
+163
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Update the comment to reflect the use of the custom op.
Suggested change
|
||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,7 +55,7 @@ def rocm_aiter_gemm_w8a8_blockscale_impl( | |
| ) -> torch.Tensor: | ||
| import aiter as rocm_aiter | ||
|
|
||
| return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype) | ||
| return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
|
|
||
| def rocm_aiter_gemm_w8a8_blockscale_fake( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment should focus on the expected input format for
`gemm_a8w8_CK` only, as the `CutlassScaledMMLinearKernel` detail is no longer relevant here.