docker/Dockerfile.rocm_base (2 changes: 1 addition & 1 deletion)

@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="6487649"
+ARG AITER_BRANCH="916bf3c"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"
 
 FROM ${BASE_IMAGE} AS base
@@ -8,11 +8,55 @@
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 
 from .cutlass import CutlassScaledMMLinearKernel
 from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
 
 
+def rocm_aiter_gemm_w8a8_impl(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+
+    from aiter import gemm_a8w8_CK
+
+    # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
+    # a to be [M, K]
+    # b to be [N, K]
+    # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
Comment on lines +28 to +31 (Contributor, severity: medium):

The comment should focus on the expected input format for gemm_a8w8_CK only, as the CutlassScaledMMLinearKernel detail is no longer relevant here.

Suggested change:
-    # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
-    # a to be [M, K]
-    # b to be [N, K]
-    # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
+    # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects:
+    # a: [M, K]
+    # b: [N, K]
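As an aside on the layout contract this suggested comment describes: `gemm_a8w8_CK` takes the activation in [M, K] and the weight in [N, K], while the weight is prepared upstream in [K, N], so call sites pass a transposed view. A minimal shape sketch (the sizes and tensor names below are illustrative, not taken from the PR):

import torch

M, K, N = 4, 8, 16
x_q = torch.zeros(M, K, dtype=torch.int8)  # quantized activations, [M, K]
w_q = torch.zeros(K, N, dtype=torch.int8)  # weight as prepared upstream, [K, N]

# gemm_a8w8_CK wants its second operand in [N, K] layout,
# so the caller passes a transposed view (no data copy):
b = w_q.t()
assert b.shape == (N, K)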

+    return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
+
+
+def rocm_aiter_gemm_w8a8_fake(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+
+    m = A.shape[0]
+    n = B.shape[0]
+    Y = torch.empty(m, n, dtype=output_dtype, device=A.device)
+    return Y
+
+
+if current_platform.is_rocm():
+    direct_register_custom_op(
+        op_name="rocm_aiter_gemm_w8a8",
+        op_func=rocm_aiter_gemm_w8a8_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_gemm_w8a8_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )


 class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
 
     @classmethod

@@ -111,10 +155,9 @@ def apply_weights(self,
             " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
             "does not support AITER block scaled GEMM.")
 
-        from aiter import gemm_a8w8_CK
-
-        # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
-        # a to be [M, K]
-        # b to be [N, K]
-        # CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
-        return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype)
+        return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s,
+                                                   bias, out_dtype)
Comment on lines +162 to +163 (Contributor, severity: medium):

Update the comment to reflect the use of torch.ops.vllm.rocm_aiter_gemm_w8a8 instead of the direct gemm_a8w8_CK call, explaining the weight tensor transposition for the new custom op.

Suggested change:
+        # The AITER GEMM kernel expects the weight tensor to be in [N, K] format.
+        # `CutlassScaledMMLinearKernel` prepares the weight `w_q` in [K, N] format,
+        # so we transpose it before passing it to the kernel.
         return torch.ops.vllm.rocm_aiter_gemm_w8a8(x_q, w_q.t(), x_s, w_s,
                                                    bias, out_dtype)
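A note on the custom-op pattern this diff introduces: registering the kernel through `direct_register_custom_op` pairs the real ROCm implementation with a fake implementation, so `torch.compile` and other tracing machinery can propagate output shapes and dtypes without executing device code. The sketch below mirrors the shape contract of `rocm_aiter_gemm_w8a8_fake` as a standalone, illustrative function (it does not import vLLM or AITER; the function name and sizes are made up for the example):

import torch

def gemm_w8a8_fake_shape(A: torch.Tensor, B: torch.Tensor,
                         output_dtype: torch.dtype = torch.float16) -> torch.Tensor:
    # A is [M, K], B is [N, K]; the op reports an [M, N] output in output_dtype.
    m, n = A.shape[0], B.shape[0]
    return torch.empty(m, n, dtype=output_dtype, device=A.device)

A = torch.empty(4, 8, dtype=torch.int8)   # activations, [M, K]
B = torch.empty(16, 8, dtype=torch.int8)  # transposed weight, [N, K]
assert gemm_w8a8_fake_shape(A, B).shape == (4, 16)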

@@ -55,7 +55,7 @@ def rocm_aiter_gemm_w8a8_blockscale_impl(
 ) -> torch.Tensor:
     import aiter as rocm_aiter
 
-    return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype)
+    return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
Contributor comment (severity: medium):

Consider renaming rocm_aiter to aiter to align with the import alias used in the function.

Suggested change:
-    return rocm_aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
+    return aiter.gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)
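Because the entry point changed between the two AITER pins this PR touches (`gemm_a8w8_blockscale_CK` on the old commit, `gemm_a8w8_blockscale` on `916bf3c`), a call site that has to tolerate either revision could resolve the attribute defensively. A sketch under that assumption, not part of the PR; it requires an AITER install on ROCm, and the variable name and error message are illustrative:

import aiter as rocm_aiter

# Prefer the new entry point, fall back to the pre-916bf3c name.
_blockscale_gemm = (getattr(rocm_aiter, "gemm_a8w8_blockscale", None)
                    or getattr(rocm_aiter, "gemm_a8w8_blockscale_CK", None))
if _blockscale_gemm is None:
    raise ImportError("installed AITER exposes no w8a8 blockscale GEMM")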



 def rocm_aiter_gemm_w8a8_blockscale_fake(