diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 85fb25f625..686b6875bd 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -4203,7 +4203,7 @@ def gemm_fp8_nt_groupwise( a_scale: torch.Tensor if the backend is ``cutlass``: - Column-major scale tensor for a, shape ``(m, k // block_size)`` if scale_major_mode is ``K`` + Row-major scale tensor for a, shape ``(m, k // block_size)`` if scale_major_mode is ``K`` or shape ``(k // block_size, m)`` if scale_major_mode is ``MN`` if the backend is ``trtllm``: scale_major_mode should be None, the scale tensor should be (m, k // block_size), @@ -4211,7 +4211,7 @@ def gemm_fp8_nt_groupwise( b_scale: torch.Tensor if the backend is ``cutlass``: - Row-major scale tensor for b, shape ``(n // block_size, k // block_size)`` if scale_major_k is ``K`` + Row-major scale tensor for b, shape ``(n // block_size, k // block_size)`` if scale_major_mode is ``K`` or shape ``(k // block_size, n // block_size)`` if scale_major_mode is ``MN`` if the backend is ``trtllm``: scale_major_mode should be None, the scale tensor should be (k // block_size, n // block_size), @@ -4573,7 +4573,7 @@ def group_gemm_fp8_nt_groupwise( Column-major input tensor shape ``(batch_size, n, k)``, data type is ``torch.float8_e4m3fn`` or ``torch.float8_e5m2``. a_scale: torch.Tensor - Column-major scale tensor for a, shape ``(cum_m, k // block_size)`` if scale_major_mode is ``K`` + Row-major scale tensor for a, shape ``(cum_m, k // block_size)`` if scale_major_mode is ``K`` or shape ``(k // block_size, cum_m)`` if scale_major_mode is ``MN``, data type is ``torch.float32``. b_scale: torch.Tensor @@ -4783,7 +4783,7 @@ def group_gemm_mxfp8_mxfp4_nt_groupwise( Column-major input tensor, shape ``(batch_size, n, k // 2)``, data type is ``torch.uint8``. a_scale: torch.Tensor - Column-major scale tensor for a, shape ``(cum_m_padded, k // 32)``, data type is ``torch.uint8``. + Row-major scale tensor for a, shape ``(cum_m_padded, k // 32)``, data type is ``torch.uint8``. b_scale: torch.Tensor Row-major scale tensor for b, shape ``(batch_size, n_padded, k // 32)``, data type is ``torch.uint8``.