diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py
index 5bbc4df2ba..052104766d 100644
--- a/flashinfer/gemm/gemm_base.py
+++ b/flashinfer/gemm/gemm_base.py
@@ -241,27 +241,77 @@ def forward(
             k_dim = a.shape[2]
             batch_size = a.shape[0]
 
-            # ScaleGranularityK must equal TileK (128)
-            if k_dim < 128:
-                raise ValueError(
-                    f"SM120/SM121 CUTLASS blockwise scaling requires k >= 128, got k={k_dim}. "
-                )
-
             scale_gran_m = 1
             scale_gran_n = 128
             scale_gran_k = 128
 
+            # round up to the next multiple
+            def _pad_to_multiple(x, multiple):
+                return ((x + multiple - 1) // multiple) * multiple
+
+            # SM120 CUTLASS blockwise scaling requires:
+            # - N % 128 == 0 (ScaleGranularityN)
+            # - K % 128 == 0 (TileK)
+            # If not aligned, we pad and then slice the result
+            n_padded = _pad_to_multiple(n_dim, scale_gran_n)
+            k_padded = _pad_to_multiple(k_dim, scale_gran_k)
+            needs_n_padding = n_padded != n_dim
+            needs_k_padding = k_padded != k_dim
+
+            if not needs_k_padding and not needs_n_padding:
+                # No padding needed
+                a_padded = a
+                b_col_major_padded = b_col_major
+            else:
+                # Padding needed
+                if a.dim() == 2:
+                    a_padded = a
+                    if needs_k_padding:
+                        a_padded = torch.nn.functional.pad(
+                            a_padded.contiguous(), (0, k_padded - k_dim)
+                        )
+                    b_col_major_padded = torch.zeros(
+                        (n_padded, k_padded),
+                        dtype=b_col_major.dtype,
+                        device=b_col_major.device,
+                    )
+                    b_col_major_padded[:n_dim, :k_dim].copy_(b_col_major)
+                else:
+                    a_padded = a
+                    if needs_k_padding:
+                        a_padded = torch.nn.functional.pad(
+                            a_padded.contiguous(), (0, k_padded - k_dim)
+                        )
+
+                    b_underlying_padded = torch.zeros(
+                        (batch_size, n_padded, k_padded),
+                        dtype=b_col_major.dtype,
+                        device=b_col_major.device,
+                    )
+                    b_col_major_padded = b_underlying_padded.transpose(-2, -1)
+                    b_col_major_padded[:, :k_dim, :n_dim].copy_(b_col_major)
+
+            # Create padded output if needed
+            if needs_n_padding:
+                if a.dim() == 2:
+                    out_padded = torch.empty(
+                        (m_dim, n_padded), device=out.device, dtype=out.dtype
+                    )
+                else:
+                    out_padded = torch.empty(
+                        (batch_size, m_dim, n_padded),
+                        device=out.device,
+                        dtype=out.dtype,
+                    )
+            else:
+                out_padded = out
+
             # For scalar scales, create compatible shapes for SM120
-            # SM120 requires scale tensors with specific shapes based on granularity
-            # Scale shape should be [m/scale_gran_m, k/scale_gran_k] for A
-            # and [n/scale_gran_n, k/scale_gran_k] for B
             if scale_a.numel() == 1:
                 scale_m_count = (
                     batch_size * m_dim + scale_gran_m - 1
                 ) // scale_gran_m
-                scale_k_count = (
-                    k_dim + scale_gran_k - 1
-                ) // scale_gran_k  # k dimension
+                scale_k_count = (k_padded + scale_gran_k - 1) // scale_gran_k
                 scale_a_expanded = (
                     scale_a.view(1, 1)
                     .expand(scale_m_count, scale_k_count)
@@ -271,13 +321,10 @@ def forward(
                 scale_a_expanded = scale_a
 
             if scale_b.numel() == 1:
-                # Calculate the expected scale dimensions
                 scale_n_count = (
-                    batch_size * n_dim + scale_gran_n - 1
+                    batch_size * n_padded + scale_gran_n - 1
                 ) // scale_gran_n
-                scale_k_count = (
-                    k_dim + scale_gran_k - 1
-                ) // scale_gran_k  # k dimension
+                scale_k_count = (k_padded + scale_gran_k - 1) // scale_gran_k
                 scale_b_expanded = (
                     scale_b.view(1, 1)
                     .expand(scale_n_count, scale_k_count)
@@ -289,16 +336,24 @@ def forward(
             # Call SM120 gemm_fp8_nt_groupwise (now handles both 2D and 3D)
             module.gemm_fp8_nt_groupwise(
                 workspace_buffer,
-                a,
-                b_col_major,
+                a_padded,
+                b_col_major_padded,
                 scale_a_expanded,
                 scale_b_expanded,
-                out,
+                out_padded,
                 scale_gran_m,  # scale_granularity_m
                 scale_gran_n,  # scale_granularity_n
                 scale_gran_k,  # scale_granularity_k
                 "MN",  # scale_major_mode
             )
+
+            # Slice the result if we padded
+            if needs_n_padding:
+                if a.dim() == 2:
+                    out.copy_(out_padded[:, :n_dim])
+                else:
+                    out.copy_(out_padded[:, :, :n_dim])
+
             return out
 
         return CutlassFp8GemmRunner()
@@ -2300,9 +2355,8 @@ def _heuristic_func_bmm_fp8(
         if is_sm_supported:
             heuristic_backends.append("cutlass_sm10x")
         elif is_sm120_supported:
-            k_dim = A.shape[-1] if A.dim() == 2 else A.shape[2]
-            if k_dim >= 128:
-                heuristic_backends.append("cutlass_sm12x")
+            # supports all K values through padding
+            heuristic_backends.append("cutlass_sm12x")
     if "cublas" in suitable_backends:
         heuristic_backends.append("cublas")
     if CUDNN_AVAILABLE and "cudnn" in suitable_backends:
diff --git a/tests/gemm/test_bmm_fp8.py b/tests/gemm/test_bmm_fp8.py
index 326fd2ab9c..d76416b511 100644
--- a/tests/gemm/test_bmm_fp8.py
+++ b/tests/gemm/test_bmm_fp8.py
@@ -8,9 +8,9 @@
 
 
 @pytest.mark.parametrize("b", [1, 16])
-@pytest.mark.parametrize("m", [48, 128])
-@pytest.mark.parametrize("n", [80, 64])
-@pytest.mark.parametrize("k", [64, 256])
+@pytest.mark.parametrize("m", [1, 48, 128])
+@pytest.mark.parametrize("n", [64, 80, 10304])
+@pytest.mark.parametrize("k", [64, 256, 2688])
 @pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
 @pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
 @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
@@ -18,14 +18,6 @@
 @pytest.mark.parametrize("auto_tuning", [True, False])
 def test_bmm_fp8(b, m, n, k, input_dtype, mat2_dtype, res_dtype, backend, auto_tuning):
     compute_capability = get_compute_capability(torch.device("cuda"))
-    if compute_capability[0] == 12 and backend in [
-        "cutlass",
-        "auto",
-    ]:
-        # TODO(yongwwww): enable all test cases for SM120/121 CUTLASS bmm_fp8 backend
-        pytest.xfail(
-            "Not all test cases for CUTLASS bmm_fp8 on SM120/121 are passing at this moment"
-        )
     if backend == "cutlass" and compute_capability[0] not in [10, 11, 12]:
         pytest.skip(
             "bmm_fp8 with cutlass backend is only supported on SM100, SM110, and SM120/121 GPUs."
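
For reference, a minimal standalone sketch of the pad-then-slice idea behind this patch, using a plain PyTorch matmul as a stand-in for the SM120 `gemm_fp8_nt_groupwise` kernel. Only `_pad_to_multiple` mirrors the diff; `ALIGN_N`, `ALIGN_K`, and `padded_gemm` are illustrative names, not FlashInfer API.

```python
# Sketch: pad N and K up to multiples of 128, run the aligned GEMM, slice the output.
# torch.matmul stands in for the SM120 kernel; zero padding does not change results.
import torch

ALIGN_N = 128  # ScaleGranularityN
ALIGN_K = 128  # TileK


def _pad_to_multiple(x: int, multiple: int) -> int:
    # Round x up to the next multiple (same helper as in the diff).
    return ((x + multiple - 1) // multiple) * multiple


def padded_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """a: [m, k], b: [n, k] (NT layout); returns a @ b.T of shape [m, n]."""
    m, k = a.shape
    n, _ = b.shape
    k_pad = _pad_to_multiple(k, ALIGN_K)
    n_pad = _pad_to_multiple(n, ALIGN_N)

    # Zero-pad K on both operands and N on B; the extra zeros contribute nothing
    # to the dot products, they only make the shapes kernel-friendly.
    a_p = torch.nn.functional.pad(a, (0, k_pad - k))
    b_p = torch.nn.functional.pad(b, (0, k_pad - k, 0, n_pad - n))

    out_p = a_p @ b_p.transpose(-2, -1)  # "kernel" runs on aligned shapes
    return out_p[:, :n]  # slice away the padded N columns


if __name__ == "__main__":
    # Shapes taken from the new test parametrization (n=10304, k=2688 are unaligned/odd sizes).
    a = torch.randn(48, 2688)
    b = torch.randn(10304, 2688)
    torch.testing.assert_close(padded_gemm(a, b), a @ b.T)
```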