From 1bfd09717662f3d4f4514acf01f0d2a74fb9e59b Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Tue, 26 May 2026 19:39:18 +0000 Subject: [PATCH 1/4] optimize cutlass fp8 Signed-off-by: yewentao256 --- vllm/model_executor/kernels/linear/scaled_mm/cutlass.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py index 9e65edb851e3..412b664f340e 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py @@ -312,6 +312,15 @@ def apply_block_scaled_mm( ) -> torch.Tensor: out_dtype = self.config.out_dtype if self.is_hopper: + # avoid padding when M is already 4-aligned. + if A.shape[0] % 4 == 0: + return ops.cutlass_scaled_mm( + A, + B.T, + out_dtype=out_dtype, + scale_a=As, + scale_b=Bs.T, + ) return torch.ops.vllm.padded_cutlass( A, B, From 6db67c12492d900db3874b1d8889d54a99e585b1 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Wed, 27 May 2026 22:00:44 +0000 Subject: [PATCH 2/4] update Signed-off-by: yewentao256 --- .../kernels/linear/scaled_mm/cutlass.py | 27 +++++++------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py index 412b664f340e..66567852d5f3 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py @@ -311,16 +311,8 @@ def apply_block_scaled_mm( Bs: torch.Tensor, ) -> torch.Tensor: out_dtype = self.config.out_dtype - if self.is_hopper: - # avoid padding when M is already 4-aligned. - if A.shape[0] % 4 == 0: - return ops.cutlass_scaled_mm( - A, - B.T, - out_dtype=out_dtype, - scale_a=As, - scale_b=Bs.T, - ) + # hopper requires padding only when M is not 4-aligned. + if self.is_hopper and A.shape[0] % 4 != 0: return torch.ops.vllm.padded_cutlass( A, B, @@ -329,14 +321,13 @@ def apply_block_scaled_mm( list(self.weight_group_shape), out_dtype, ) - else: - return ops.cutlass_scaled_mm( - A, - B.T, - out_dtype=out_dtype, - scale_a=As, - scale_b=Bs.T, - ) + return ops.cutlass_scaled_mm( + A, + B.T, + out_dtype=out_dtype, + scale_a=As, + scale_b=Bs.T, + ) def cutlass_scaled_mm( From 18ac0c69061ef912df6fa7211e83d13d847bc73e Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Thu, 28 May 2026 15:23:09 +0000 Subject: [PATCH 3/4] fix unit test Signed-off-by: yewentao256 --- .../kernels/linear/scaled_mm/cutlass.py | 54 +++++++++++++++---- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py index 66567852d5f3..7dd091b2d292 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py @@ -311,16 +311,50 @@ def apply_block_scaled_mm( Bs: torch.Tensor, ) -> torch.Tensor: out_dtype = self.config.out_dtype - # hopper requires padding only when M is not 4-aligned. - if self.is_hopper and A.shape[0] % 4 != 0: - return torch.ops.vllm.padded_cutlass( - A, - B, - As, - Bs, - list(self.weight_group_shape), - out_dtype, - ) + if self.is_hopper: + + def run_padded( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.vllm.padded_cutlass( + A, + B, + As, + Bs, + list(self.weight_group_shape), + out_dtype, + ) + + def run_direct( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + ) -> torch.Tensor: + return ops.cutlass_scaled_mm( + A, + B.T, + out_dtype=out_dtype, + scale_a=As, + scale_b=Bs.T, + ) + + if torch.compiler.is_compiling(): + # vLLM compile drops shape guards, so keep the M-alignment + # decision inside the graph + return torch.cond( + A.shape[0] % 4 != 0, + run_padded, + run_direct, + (A, B, As, Bs), + ) + + if A.shape[0] % 4 != 0: + return run_padded(A, B, As, Bs) + return ops.cutlass_scaled_mm( A, B.T, From 83c094942de469d7aa8d68f04424529c7d0f5ff2 Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 29 May 2026 21:08:42 +0000 Subject: [PATCH 4/4] fix ci Signed-off-by: yewentao256 --- .../kernels/linear/scaled_mm/cutlass.py | 98 +++++++++++-------- 1 file changed, 56 insertions(+), 42 deletions(-) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py index 7dd091b2d292..b52d2c5b1012 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py @@ -312,48 +312,14 @@ def apply_block_scaled_mm( ) -> torch.Tensor: out_dtype = self.config.out_dtype if self.is_hopper: - - def run_padded( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - ) -> torch.Tensor: - return torch.ops.vllm.padded_cutlass( - A, - B, - As, - Bs, - list(self.weight_group_shape), - out_dtype, - ) - - def run_direct( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - ) -> torch.Tensor: - return ops.cutlass_scaled_mm( - A, - B.T, - out_dtype=out_dtype, - scale_a=As, - scale_b=Bs.T, - ) - - if torch.compiler.is_compiling(): - # vLLM compile drops shape guards, so keep the M-alignment - # decision inside the graph - return torch.cond( - A.shape[0] % 4 != 0, - run_padded, - run_direct, - (A, B, As, Bs), - ) - - if A.shape[0] % 4 != 0: - return run_padded(A, B, As, Bs) + return torch.ops.vllm.dynamic_padded_cutlass( + A, + B, + As, + Bs, + list(self.weight_group_shape), + out_dtype, + ) return ops.cutlass_scaled_mm( A, @@ -431,8 +397,56 @@ def _padded_cutlass_fake( ) +def _dynamic_padded_cutlass( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + block_size: list[int], + output_dtype: torch.dtype, +) -> torch.Tensor: + def run_padded( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + return _padded_cutlass( + qx, weight, x_scale, weight_scale, block_size, output_dtype + ) + + def run_direct( + qx: torch.Tensor, + weight: torch.Tensor, + x_scale: torch.Tensor, + weight_scale: torch.Tensor, + ) -> torch.Tensor: + return cutlass_scaled_mm( + qx, weight, x_scale, weight_scale, block_size, output_dtype + ) + + if torch.compiler.is_compiling(): + return torch.cond( + qx.shape[0] % 4 != 0, + run_padded, + run_direct, + (qx, weight, x_scale, weight_scale), + ) + + if qx.shape[0] % 4 != 0: + return run_padded(qx, weight, x_scale, weight_scale) + + return run_direct(qx, weight, x_scale, weight_scale) + + direct_register_custom_op( "padded_cutlass", _padded_cutlass, fake_impl=_padded_cutlass_fake, ) + +direct_register_custom_op( + "dynamic_padded_cutlass", + _dynamic_padded_cutlass, + fake_impl=_padded_cutlass_fake, +)