From 4410a6192fd056a47c724580485838268ca3ad87 Mon Sep 17 00:00:00 2001 From: Blake Ledden Date: Sat, 21 Feb 2026 09:57:36 -0800 Subject: [PATCH 1/2] docs: Fix incorrect column-major scale layout description in FP8 GEMM docstrings The a_scale parameter docstrings in gemm_fp8_nt_groupwise, group_gemm_fp8_nt_groupwise, and group_gemm_mxfp8_mxfp4_nt_groupwise incorrectly described the scale tensor as "Column-major". The kernel actually expects standard contiguous (row-major) tensors, consistent with what quantize_fp8 produces and the test suite passes. Changed "Column-major" to "Row-major" in all three a_scale descriptions to match the b_scale docs, which already correctly say "Row-major". Fixes #2147 Signed-off-by: Blake Ledden --- flashinfer/gemm/gemm_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 85fb25f625..e5433823cc 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -4203,7 +4203,7 @@ def gemm_fp8_nt_groupwise( a_scale: torch.Tensor if the backend is ``cutlass``: - Column-major scale tensor for a, shape ``(m, k // block_size)`` if scale_major_mode is ``K`` + Row-major scale tensor for a, shape ``(m, k // block_size)`` if scale_major_mode is ``K`` or shape ``(k // block_size, m)`` if scale_major_mode is ``MN`` if the backend is ``trtllm``: scale_major_mode should be None, the scale tensor should be (m, k // block_size), @@ -4573,7 +4573,7 @@ def group_gemm_fp8_nt_groupwise( Column-major input tensor shape ``(batch_size, n, k)``, data type is ``torch.float8_e4m3fn`` or ``torch.float8_e5m2``. a_scale: torch.Tensor - Column-major scale tensor for a, shape ``(cum_m, k // block_size)`` if scale_major_mode is ``K`` + Row-major scale tensor for a, shape ``(cum_m, k // block_size)`` if scale_major_mode is ``K`` or shape ``(k // block_size, cum_m)`` if scale_major_mode is ``MN``, data type is ``torch.float32``. b_scale: torch.Tensor @@ -4783,7 +4783,7 @@ def group_gemm_mxfp8_mxfp4_nt_groupwise( Column-major input tensor, shape ``(batch_size, n, k // 2)``, data type is ``torch.uint8``. a_scale: torch.Tensor - Column-major scale tensor for a, shape ``(cum_m_padded, k // 32)``, data type is ``torch.uint8``. + Row-major scale tensor for a, shape ``(cum_m_padded, k // 32)``, data type is ``torch.uint8``. b_scale: torch.Tensor Row-major scale tensor for b, shape ``(batch_size, n_padded, k // 32)``, data type is ``torch.uint8``. From 2b84047f13ee7ee55f1b80218de1bd42b93b87f5 Mon Sep 17 00:00:00 2001 From: Blake Ledden Date: Sun, 22 Feb 2026 15:41:49 -0800 Subject: [PATCH 2/2] fix: also fix adjacent scale_major_k typo in b_scale docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per CodeRabbit feedback, the adjacent b_scale docstring had scale_major_k instead of scale_major_mode — fixing while I'm already editing these docstrings. Co-Authored-By: Claude Opus 4.6 --- flashinfer/gemm/gemm_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index e5433823cc..686b6875bd 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -4211,7 +4211,7 @@ def gemm_fp8_nt_groupwise( b_scale: torch.Tensor if the backend is ``cutlass``: - Row-major scale tensor for b, shape ``(n // block_size, k // block_size)`` if scale_major_k is ``K`` + Row-major scale tensor for b, shape ``(n // block_size, k // block_size)`` if scale_major_mode is ``K`` or shape ``(k // block_size, n // block_size)`` if scale_major_mode is ``MN`` if the backend is ``trtllm``: scale_major_mode should be None, the scale tensor should be (k // block_size, n // block_size),