37 changes: 33 additions & 4 deletions aiter/ops/triton/quant/fused_fp8_quant.py
@@ -497,25 +497,46 @@ def fused_flatten_fp8_group_quant(
x: torch.Tensor,
group_size,
dtype_quant=fp8_dtype,
transpose_scale: bool = False,
):
"""
Flatten the last two dimensions of x and perform FP8 per-token group quantization along the last dimension.

Key parameters:
- x: Matrix X with shape (M, N1, N2).
- transpose_scale: If True, return scale with shape (M, cdiv(N1*N2, group_size))
but stored in column-major (transposed) memory layout.
Equivalent to: scale.transpose(0, 1).contiguous().view(*scale.shape)
Mirrors the same flag on fused_rms_fp8_group_quant.

Returns:
- out: The output matrix with shape (M, N1 * N2).
- out_block_scales: The output scales with shape (M, cdiv(N1 * N2, group_size)).
When transpose_scale=True, the data is stored in a column-major memory layout.
"""
M, N1, N2 = x.shape

BLOCK_SIZE_N2 = max(triton.next_power_of_2(N2), group_size)
N = N1 * N2
num_bs_cols = triton.cdiv(N, group_size)
out = torch.empty((M, N), dtype=dtype_quant, device=x.device)

if transpose_scale:
# Allocate as (num_bs_cols, M) so the row-major storage of this tensor
# is equivalent to a column-major (M, num_bs_cols) layout. We then
# tell the inner kernel to write with swapped strides, and view back
# to (M, num_bs_cols) at the end.
out_block_scales = torch.empty(
(num_bs_cols, M), dtype=torch.float32, device=x.device
)
out_bs_row_stride = out_block_scales.stride(1) # = 1
out_bs_col_stride = out_block_scales.stride(0) # = M
else:
out_block_scales = torch.empty(
(M, num_bs_cols), dtype=torch.float32, device=x.device
)
out_bs_row_stride = out_block_scales.stride(0) # = num_bs_cols
out_bs_col_stride = out_block_scales.stride(1) # = 1

DTYPE_MAX = (
torch.finfo(out.dtype).max
@@ -532,14 +553,22 @@ def fused_flatten_fp8_group_quant(
out_block_scales,
*x.stride(),
*out.stride(),
out_bs_row_stride,
out_bs_col_stride,
N2,
BLOCK_SIZE_N2=BLOCK_SIZE_N2,
QUANT_BLOCK_SIZE=group_size,
DTYPE_MAX=DTYPE_MAX,
DTYPE_MIN=-DTYPE_MAX,
)

if transpose_scale:
# Reinterpret the (num_bs_cols, M) row-major buffer back as
# (M, num_bs_cols) — same shape as default path, but data is now
# in column-major layout (consumers like CK bpreshuffle GEMM
# expect this layout when called with the new attribute marker).
out_block_scales = out_block_scales.view(M, num_bs_cols)

return out, out_block_scales
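
For readers skimming the diff, here is a minimal torch-only sketch (not part of this PR) of the layout trick used above: allocating (num_bs_cols, M) row-major, writing with swapped strides, and viewing the buffer back is the same as the scale.transpose(0, 1).contiguous().view(*scale.shape) pattern named in the docstring. The sizes and names below are illustrative.

# Illustrative sketch only: mimics the transpose_scale bookkeeping in
# plain PyTorch, without the Triton kernel.
import torch

M, num_bs_cols = 4, 3
scale = torch.arange(M * num_bs_cols, dtype=torch.float32).reshape(M, num_bs_cols)

# Reference pattern from the docstring.
expected = scale.transpose(0, 1).contiguous().view(M, num_bs_cols)

# transpose_scale path: allocate (num_bs_cols, M) row-major, write element
# (r, c) at linear offset r * row_stride + c * col_stride with row_stride = 1
# and col_stride = M, then view the buffer back as (M, num_bs_cols).
buf = torch.empty(num_bs_cols, M, dtype=torch.float32)
flat = buf.view(-1)
for r in range(M):
    for c in range(num_bs_cols):
        flat[r + c * M] = scale[r, c]
out = buf.view(M, num_bs_cols)

assert torch.equal(out, expected)  # same values, column-major storage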


60 changes: 60 additions & 0 deletions op_tests/triton_tests/quant/test_fused_fp8_quant.py
@@ -365,6 +365,66 @@ def test_fused_flatten_fp8_group_quant(M: int, N1: int, N2: int, dtype):
torch.testing.assert_close(y_upcast_torch, y_upcast_triton, atol=0.1, rtol=0.1)


@pytest.mark.parametrize("M", [1, 32, 256])
@pytest.mark.parametrize("N1, N2", [(16, 128)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_flatten_fp8_group_quant_transpose_scale(
M: int, N1: int, N2: int, dtype
):
"""Test that transpose_scale parameter returns scale with transposed memory layout."""
torch.manual_seed(0)
group_size = 128
dtype_quant = aiter.dtypes.fp8
x = torch.randn((N1, M, N2), dtype=dtype, device="cuda") / 10
x = x.transpose(0, 1)

# Call with transpose_scale=False (original behavior)
y_q_orig, y_s_orig = fused_flatten_fp8_group_quant(
x,
group_size=group_size,
dtype_quant=dtype_quant,
transpose_scale=False,
)

# Call with transpose_scale=True
y_q_transposed, y_s_transposed = fused_flatten_fp8_group_quant(
x,
group_size=group_size,
dtype_quant=dtype_quant,
transpose_scale=True,
)

num_bs_cols = (N1 * N2 + group_size - 1) // group_size

# Verify that both outputs have the same shape
assert y_s_orig.shape == (
M,
num_bs_cols,
), f"Expected shape (M, num_bs_cols), got {y_s_orig.shape}"
assert y_s_transposed.shape == (
M,
num_bs_cols,
), f"Expected shape (M, num_bs_cols), got {y_s_transposed.shape}"

# Verify that the transpose_scale=True result is equivalent to .transpose().contiguous().view()
y_s_expected = y_s_orig.transpose(0, 1).contiguous().view(*y_s_orig.shape)

# Verify that both have the same shape and strides (row-major after view)
assert (
y_s_orig.stride() == y_s_transposed.stride()
), "Both should have row-major strides"
assert (
y_s_orig.is_contiguous() and y_s_transposed.is_contiguous()
), "Both should be contiguous"

# Verify numerical correctness - values should match the transpose().contiguous().view() pattern
torch.testing.assert_close(y_s_transposed, y_s_expected, atol=1e-6, rtol=1e-6)

# Verify that the quantized fp8 tensor is identical: the flag only changes
# the scale layout, so the values must match exactly (atol=0, rtol=0)
torch.testing.assert_close(y_q_transposed, y_q_orig, atol=0, rtol=0)

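As an aside (not part of the test file), one further check makes the column-major claim concrete: undoing the final view exposes the raw (num_bs_cols, M) buffer, which should be exactly the plain transpose of the original scales. This sketch reuses y_s_orig, y_s_transposed, M, and num_bs_cols from the test above.

# Illustrative follow-up: reverse the final view to expose the underlying
# (num_bs_cols, M) row-major buffer and compare against a plain transpose.
y_s_buf = y_s_transposed.view(num_bs_cols, M)
torch.testing.assert_close(y_s_buf, y_s_orig.transpose(0, 1))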

def run_torch_reduce_act_mul_fp8_group_quant(
x, x2, activation, dtype, dtype_quant, group_size=128
):