Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions tests/kernels/quantization/test_rocm_skinny_gemms.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@
SEEDS = [0]


def pad_weights_fp8(weight):
    """Return *weight* as a non-contiguous view into padded storage.

    Appends 256 bytes' worth of zero elements to the last dimension, then
    slices the padding back off.  Shape and values are unchanged, but the
    result is no longer contiguous — which is exactly the layout the
    ``padded=True`` test variants need to exercise.
    """
    import torch.nn.functional as F

    pad_elems = 256 // weight.element_size()
    padded = F.pad(weight, (0, pad_elems), "constant", 0)
    return padded[..., :-pad_elems]


@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
Expand Down Expand Up @@ -191,18 +198,21 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded", [False, True])
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed, padded):
torch.manual_seed(seed)

A = torch.rand(n, k, device="cuda") - 0.5
B = torch.rand(m, k, device="cuda") - 0.5

A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
if padded:
B = pad_weights_fp8(B)

ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b
Expand All @@ -222,11 +232,12 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed):
@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("padded", [False, True])
@pytest.mark.skipif(
not (current_platform.is_rocm() and current_platform.supports_fp8()),
reason="only test for rocm fp8",
)
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):
def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed, padded):
torch.manual_seed(seed)

xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas
Expand All @@ -236,6 +247,8 @@ def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed):

A, scale_a = ref_dynamic_per_tensor_fp8_quant(A)
B, scale_b = ref_dynamic_per_tensor_fp8_quant(B)
if padded:
B = pad_weights_fp8(B)

ref_out = torch._scaled_mm(
A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS
Expand Down
16 changes: 0 additions & 16 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2027,35 +2027,20 @@ def selective_scan_fwd(
)


# NOTE: The wvSplitK kernel (and all of the kernels in skinny_gemms.cu)
# cannot correctly handle non-contiguous tensors.  It might be a good
# TODO(rasmith) to augment these kernels so they accept non-contiguous
# inputs for better performance.
def rocm_enforce_contiguous_skinny_gemm_inputs(
    a: torch.Tensor, b: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return contiguous versions of ``a`` and ``b``.

    ``Tensor.contiguous()`` returns the tensor itself when it is already
    contiguous and a contiguous clone otherwise.
    """
    return a.contiguous(), b.contiguous()


# ROCm skinny gemms
def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor:
    """Invoke the custom ROCm ``_rocm_C.LLMM1`` skinny-GEMM op.

    Inputs are forced contiguous first; the skinny-GEMM kernels cannot
    correctly handle non-contiguous tensors.
    """
    a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
    return torch.ops._rocm_C.LLMM1(a, b, rows_per_block)


def wvSplitK(
    a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor:
    """Invoke the custom ROCm ``_rocm_C.wvSplitK`` skinny-GEMM op.

    ``cu_count`` is the number of compute units passed through to the
    kernel; ``bias`` is optional.  Inputs are forced contiguous first
    because the skinny-GEMM kernels cannot handle non-contiguous tensors.
    """
    a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
    return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count)


def wvSplitKrc(
    a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None
) -> torch.Tensor:
    """Invoke the custom ROCm ``_rocm_C.wvSplitKrc`` skinny-GEMM op.

    Same calling convention as ``wvSplitK``; inputs are forced contiguous
    first because the skinny-GEMM kernels cannot handle non-contiguous
    tensors.
    """
    a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
    return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count)


Expand All @@ -2068,7 +2053,6 @@ def wvSplitKQ(
cu_count: int,
bias: torch.Tensor = None,
) -> torch.Tensor:
a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b)
out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device)
torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count)
return out
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl(
A.shape[0] == 1
and B.shape[1] % 16 == 0
and ((bias is None) or (bias.dtype == out_dtype))
and A.is_contiguous()
):
output = ops.wvSplitKQ(
B.t(),
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def rocm_unquantized_gemm_impl(
and n <= 128
and k > 512
and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count()
and x.is_contiguous()
)
# k == 2880 and (m == 640 or m == 128))
)
Expand All @@ -179,6 +180,7 @@ def rocm_unquantized_gemm_impl(
and on_gfx9()
and x.dtype in [torch.float16, torch.bfloat16]
and k % 8 == 0
and x.is_contiguous()
)

if use_skinny is not True:
Expand Down