100 changes: 77 additions & 23 deletions flashinfer/gemm/gemm_base.py
@@ -241,27 +241,77 @@ def forward(
                k_dim = a.shape[2]
                batch_size = a.shape[0]

                # ScaleGranularityK must equal TileK (128)
                if k_dim < 128:
                    raise ValueError(
                        f"SM120/SM121 CUTLASS blockwise scaling requires k >= 128, got k={k_dim}. "
                    )

                scale_gran_m = 1
                scale_gran_n = 128
                scale_gran_k = 128

                # round up to the next multiple
                def _pad_to_multiple(x, multiple):
                    return ((x + multiple - 1) // multiple) * multiple

                # SM120 CUTLASS blockwise scaling requires:
                # - N % 128 == 0 (ScaleGranularityN)
                # - K % 128 == 0 (TileK)
                # If not aligned, we pad and then slice the result
                n_padded = _pad_to_multiple(n_dim, scale_gran_n)
                k_padded = _pad_to_multiple(k_dim, scale_gran_k)
                needs_n_padding = n_padded != n_dim
                needs_k_padding = k_padded != k_dim

                if not needs_k_padding and not needs_n_padding:
                    # No padding needed
                    a_padded = a
                    b_col_major_padded = b_col_major
                else:
                    # Padding needed
                    if a.dim() == 2:
                        a_padded = a
                        if needs_k_padding:
                            a_padded = torch.nn.functional.pad(
                                a_padded.contiguous(), (0, k_padded - k_dim)
                            )
                        b_col_major_padded = torch.zeros(
                            (n_padded, k_padded),
                            dtype=b_col_major.dtype,
                            device=b_col_major.device,
                        )
                        b_col_major_padded[:n_dim, :k_dim].copy_(b_col_major)
                    else:
                        a_padded = a
                        if needs_k_padding:
                            a_padded = torch.nn.functional.pad(
                                a_padded.contiguous(), (0, k_padded - k_dim)
                            )

                        b_underlying_padded = torch.zeros(
                            (batch_size, n_padded, k_padded),
                            dtype=b_col_major.dtype,
                            device=b_col_major.device,
                        )
                        b_col_major_padded = b_underlying_padded.transpose(-2, -1)
                        b_col_major_padded[:, :k_dim, :n_dim].copy_(b_col_major)
Comment on lines +267 to +292
Contributor

medium

There's some code duplication in how a_padded is handled for 2D and 3D cases. You can hoist the padding logic for a out of the if a.dim() == 2: block to avoid repetition and improve maintainability.

                    a_padded = a
                    if needs_k_padding:
                        a_padded = torch.nn.functional.pad(
                            a_padded.contiguous(), (0, k_padded - k_dim)
                        )

                    if a.dim() == 2:
                        b_col_major_padded = torch.zeros(
                            (n_padded, k_padded),
                            dtype=b_col_major.dtype,
                            device=b_col_major.device,
                        )
                        b_col_major_padded[:n_dim, :k_dim].copy_(b_col_major)
                    else:
                        b_underlying_padded = torch.zeros(
                            (batch_size, n_padded, k_padded),
                            dtype=b_col_major.dtype,
                            device=b_col_major.device,
                        )
                        b_col_major_padded = b_underlying_padded.transpose(-2, -1)
                        b_col_major_padded[:, :k_dim, :n_dim].copy_(b_col_major)


                # Create padded output if needed
                if needs_n_padding:
                    if a.dim() == 2:
                        out_padded = torch.empty(
                            (m_dim, n_padded), device=out.device, dtype=out.dtype
                        )
                    else:
                        out_padded = torch.empty(
                            (batch_size, m_dim, n_padded),
                            device=out.device,
                            dtype=out.dtype,
                        )
                else:
                    out_padded = out

                # For scalar scales, create compatible shapes for SM120
                # SM120 requires scale tensors with specific shapes based on granularity
                # Scale shape should be [m/scale_gran_m, k/scale_gran_k] for A
                # and [n/scale_gran_n, k/scale_gran_k] for B
                if scale_a.numel() == 1:
                    scale_m_count = (
                        batch_size * m_dim + scale_gran_m - 1
                    ) // scale_gran_m
                    scale_k_count = (
                        k_dim + scale_gran_k - 1
                    ) // scale_gran_k  # k dimension
                    scale_k_count = (k_padded + scale_gran_k - 1) // scale_gran_k
                    scale_a_expanded = (
                        scale_a.view(1, 1)
                        .expand(scale_m_count, scale_k_count)
@@ -271,13 +321,10 @@ def forward(
                    scale_a_expanded = scale_a

                if scale_b.numel() == 1:
                    # Calculate the expected scale dimensions
                    scale_n_count = (
                        batch_size * n_dim + scale_gran_n - 1
                        batch_size * n_padded + scale_gran_n - 1
                    ) // scale_gran_n
                    scale_k_count = (
                        k_dim + scale_gran_k - 1
                    ) // scale_gran_k  # k dimension
                    scale_k_count = (k_padded + scale_gran_k - 1) // scale_gran_k
                    scale_b_expanded = (
                        scale_b.view(1, 1)
                        .expand(scale_n_count, scale_k_count)
@@ -289,16 +336,24 @@ def forward(
                # Call SM120 gemm_fp8_nt_groupwise (now handles both 2D and 3D)
                module.gemm_fp8_nt_groupwise(
                    workspace_buffer,
                    a,
                    b_col_major,
                    a_padded,
                    b_col_major_padded,
                    scale_a_expanded,
                    scale_b_expanded,
                    out,
                    out_padded,
                    scale_gran_m,  # scale_granularity_m
                    scale_gran_n,  # scale_granularity_n
                    scale_gran_k,  # scale_granularity_k (adjusted for small k)
                    "MN",  # scale_major_mode
                )

                # Slice the result if we padded
                if needs_n_padding:
                    if a.dim() == 2:
                        out.copy_(out_padded[:, :n_dim])
                    else:
                        out.copy_(out_padded[:, :, :n_dim])

                return out

        return CutlassFp8GemmRunner()
@@ -2300,9 +2355,8 @@ def _heuristic_func_bmm_fp8(
    if is_sm_supported:
        heuristic_backends.append("cutlass_sm10x")
    elif is_sm120_supported:
        k_dim = A.shape[-1] if A.dim() == 2 else A.shape[2]
        if k_dim >= 128:
            heuristic_backends.append("cutlass_sm12x")
        # supports all K values through padding
        heuristic_backends.append("cutlass_sm12x")
    if "cublas" in suitable_backends:
        heuristic_backends.append("cublas")
    if CUDNN_AVAILABLE and "cudnn" in suitable_backends:
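A quick sanity check on the pad-and-slice approach used in forward above. This is a minimal sketch with plain float32 torch.matmul rather than the FP8 groupwise kernel, and the shapes are made up for illustration: zero-padding K on both operands and N on B only adds zero contributions and zero output columns, so slicing the padded output back to n columns reproduces the unpadded result.

import torch


def _pad_to_multiple(x, multiple):
    # round up to the next multiple, as in the PR
    return ((x + multiple - 1) // multiple) * multiple


m, n, k = 48, 80, 96  # neither n nor k is a multiple of 128
a = torch.randn(m, k)
b = torch.randn(n, k)  # [n, k], consumed as a column-major B

n_padded = _pad_to_multiple(n, 128)
k_padded = _pad_to_multiple(k, 128)

# Zero-pad K on A, and both N and K on B.
a_padded = torch.nn.functional.pad(a, (0, k_padded - k))
b_padded = torch.zeros(n_padded, k_padded)
b_padded[:n, :k].copy_(b)

# The extra columns of the padded result are all zero; slicing recovers the original GEMM.
out_ref = a @ b.t()
out_padded = a_padded @ b_padded.t()
assert torch.allclose(out_ref, out_padded[:, :n], atol=1e-5)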
14 changes: 3 additions & 11 deletions tests/gemm/test_bmm_fp8.py
@@ -8,24 +8,16 @@


@pytest.mark.parametrize("b", [1, 16])
@pytest.mark.parametrize("m", [48, 128])
@pytest.mark.parametrize("n", [80, 64])
@pytest.mark.parametrize("k", [64, 256])
@pytest.mark.parametrize("m", [1, 48, 128])
@pytest.mark.parametrize("n", [64, 80, 10304])
@pytest.mark.parametrize("k", [64, 256, 2688])
@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("backend", ["cudnn", "cublas", "cutlass", "auto"])
@pytest.mark.parametrize("auto_tuning", [True, False])
def test_bmm_fp8(b, m, n, k, input_dtype, mat2_dtype, res_dtype, backend, auto_tuning):
    compute_capability = get_compute_capability(torch.device("cuda"))
    if compute_capability[0] == 12 and backend in [
        "cutlass",
        "auto",
    ]:
        # TODO(yongwwww): enable all test cases for SM120/121 CUTLASS bmm_fp8 backend
        pytest.xfail(
            "Not all test cases for CUTLASS bmm_fp8 on SM120/121 are passing at this moment"
        )
    if backend == "cutlass" and compute_capability[0] not in [10, 11, 12]:
        pytest.skip(
            "bmm_fp8 with cutlass backend is only supported on SM100, SM110, and SM120/121 GPUs."
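The new parametrizations above (m=1, n=10304, k=2688) deliberately include a non-multiple-of-128 N so the padding path is exercised. As a rough companion illustration of the scalar-scale expansion described in the gemm_base.py comments (the helper name below is hypothetical and not part of the PR; the diff additionally multiplies the row count by batch_size for batched inputs): a scalar scale for B is broadcast to one entry per 128x128 block of the padded operand, while A uses granularity (1, 128).

import math

import torch


def expand_scalar_scale(scale, rows, cols, gran_rows, gran_cols):
    # Broadcast one scalar to the per-block grid the kernel expects:
    # one scale entry per (gran_rows x gran_cols) tile of the operand.
    r = math.ceil(rows / gran_rows)
    c = math.ceil(cols / gran_cols)
    return scale.view(1, 1).expand(r, c).contiguous()


# n=10304 pads to 10368 and k=2688 is already 128-aligned, so the scale grid
# for B comes out as [81, 21].
scale_b = torch.tensor(0.05)
grid = expand_scalar_scale(scale_b, 10368, 2688, 128, 128)
print(grid.shape)  # torch.Size([81, 21])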