From ef220d18de652d5258f812a78cd3ed80571586c9 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Thu, 29 Jan 2026 20:54:28 +0000 Subject: [PATCH 1/2] Fixing the skinny gemm dispatch logic. Weights can be padded for it to work Signed-off-by: Gregory Shtrasberg --- vllm/_custom_ops.py | 16 ---------------- .../quantization/kernels/scaled_mm/rocm.py | 1 + vllm/model_executor/layers/utils.py | 2 ++ 3 files changed, 3 insertions(+), 16 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ebc8e83eb6f1..ea44beda5931 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2027,35 +2027,20 @@ def selective_scan_fwd( ) -# NOTE: The wvSplitK kernel (and all of the kernels in skinny_gemms.cu) -# are unable to properly handle non-contiguous -# tensors. It might be a good TODO(rasmith) to augment these kernels -# to be able to handle non-contiguous kernels for better performance. -def rocm_enforce_contiguous_skinny_gemm_inputs( - a: torch.Tensor, b: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - a = a.contiguous() # no-op if already contiguous, else clone - b = b.contiguous() # no-op if already contiguous, else clone - return a, b - - # ROCm skinny gemms def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor: - a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b) return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) def wvSplitK( a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None ) -> torch.Tensor: - a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b) return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count) def wvSplitKrc( a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None ) -> torch.Tensor: - a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b) return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count) @@ -2068,7 +2053,6 @@ def wvSplitKQ( cu_count: int, bias: torch.Tensor = None, ) -> torch.Tensor: - a, b = rocm_enforce_contiguous_skinny_gemm_inputs(a, b) out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device) torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count) return out diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py index e52015ba2d11..ee660812e9f9 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/rocm.py @@ -28,6 +28,7 @@ def rocm_per_tensor_float_w8a8_scaled_mm_impl( A.shape[0] == 1 and B.shape[1] % 16 == 0 and ((bias is None) or (bias.dtype == out_dtype)) + and A.is_contiguous() ): output = ops.wvSplitKQ( B.t(), diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index 2ec364213977..55490bfbbabb 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -151,6 +151,7 @@ def rocm_unquantized_gemm_impl( and n <= 128 and k > 512 and math.ceil(k / 512) * math.ceil(m / 16) < get_cu_count() + and x.is_contiguous() ) # k == 2880 and (m == 640 or m == 128)) ) @@ -165,6 +166,7 @@ def rocm_unquantized_gemm_impl( and on_gfx9() and x.dtype in [torch.float16, torch.bfloat16] and k % 8 == 0 + and x.is_contiguous() ) if use_skinny is not True: From c29b513ea2f1a93a9499dedb584abe8fc84ddcc6 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg Date: Thu, 29 Jan 2026 21:04:35 +0000 Subject: [PATCH 2/2] Added weight padding into fp8 skinny gemm tests Signed-off-by: Gregory Shtrasberg --- .../quantization/test_rocm_skinny_gemms.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 1f6464e210b3..1505604a6919 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -87,6 +87,13 @@ SEEDS = [0] +def pad_weights_fp8(weight): + num_pad = 256 // weight.element_size() + import torch.nn.functional as F + + return F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] + + @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITKRC) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @@ -191,11 +198,12 @@ def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("padded", [False, True]) @pytest.mark.skipif( not (current_platform.is_rocm() and current_platform.supports_fp8()), reason="only test for rocm fp8", ) -def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): +def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed, padded): torch.manual_seed(seed) A = torch.rand(n, k, device="cuda") - 0.5 @@ -203,6 +211,8 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) + if padded: + B = pad_weights_fp8(B) ref_out = torch._scaled_mm( A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b @@ -222,11 +232,12 @@ def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("padded", [False, True]) @pytest.mark.skipif( not (current_platform.is_rocm() and current_platform.supports_fp8()), reason="only test for rocm fp8", ) -def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed): +def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed, padded): torch.manual_seed(seed) xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas @@ -236,6 +247,8 @@ def test_rocm_wvsplitk_fp8_bias1D_kernel(n, k, m, dtype, seed): A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) + if padded: + B = pad_weights_fp8(B) ref_out = torch._scaled_mm( A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS