diff --git a/aiter/ops/triton/quant/fused_fp8_quant.py b/aiter/ops/triton/quant/fused_fp8_quant.py
index f0583a86b3..52b83d3125 100644
--- a/aiter/ops/triton/quant/fused_fp8_quant.py
+++ b/aiter/ops/triton/quant/fused_fp8_quant.py
@@ -497,25 +497,46 @@
 def fused_flatten_fp8_group_quant(
     x: torch.Tensor,
     group_size,
     dtype_quant=fp8_dtype,
+    transpose_scale: bool = False,
 ):
     """
     Flatten the last two dimension of x and perform FP8 per-token group quantization along the last dimension
 
     Key parameters:
     - x: Matrix X with shape (M, N1, N2).
+    - transpose_scale: If True, return the scale with shape (M, cdiv(N1*N2, group_size))
+      but stored in a column-major (transposed) memory layout.
+      Equivalent to: scale.transpose(0, 1).contiguous().view(*scale.shape)
+      Mirrors the same flag on fused_rms_fp8_group_quant.
 
     Returns:
     - out: The output matrix with shape (M, N1 * N2).
     - out_block_scales: The output matrix with shape (M, cdiv((N1 * N2), group_size)).
+      When transpose_scale=True, the data has a column-major memory layout.
     """
     M, N1, N2 = x.shape
     BLOCK_SIZE_N2 = max(triton.next_power_of_2(N2), group_size)
     N = N1 * N2
+    num_bs_cols = triton.cdiv(N, group_size)
     out = torch.empty((M, N), dtype=dtype_quant, device=x.device)
-    out_block_scales = torch.empty(
-        (M, triton.cdiv(N, group_size)), dtype=torch.float32, device=x.device
-    )
+
+    if transpose_scale:
+        # Allocate as (num_bs_cols, M) so the row-major storage of this tensor
+        # is equivalent to a column-major (M, num_bs_cols) layout. We then
+        # tell the inner kernel to write with swapped strides, and view back
+        # to (M, num_bs_cols) at the end.
+        out_block_scales = torch.empty(
+            (num_bs_cols, M), dtype=torch.float32, device=x.device
+        )
+        out_bs_row_stride = out_block_scales.stride(1)  # = 1
+        out_bs_col_stride = out_block_scales.stride(0)  # = M
+    else:
+        out_block_scales = torch.empty(
+            (M, num_bs_cols), dtype=torch.float32, device=x.device
+        )
+        out_bs_row_stride = out_block_scales.stride(0)  # = num_bs_cols
+        out_bs_col_stride = out_block_scales.stride(1)  # = 1
 
     DTYPE_MAX = (
         torch.finfo(out.dtype).max
@@ -532,7 +553,8 @@ def fused_flatten_fp8_group_quant(
         out_block_scales,
         *x.stride(),
         *out.stride(),
-        *out_block_scales.stride(),
+        out_bs_row_stride,
+        out_bs_col_stride,
         N2,
         BLOCK_SIZE_N2=BLOCK_SIZE_N2,
         QUANT_BLOCK_SIZE=group_size,
@@ -540,6 +562,13 @@
         DTYPE_MAX=DTYPE_MAX,
         DTYPE_MIN=-DTYPE_MAX,
     )
 
+    if transpose_scale:
+        # Reinterpret the (num_bs_cols, M) row-major buffer back as
+        # (M, num_bs_cols): same shape as the default path, but the data is
+        # now in column-major layout (consumers such as the CK bpreshuffle
+        # GEMM expect this layout when called with the new attribute marker).
+        out_block_scales = out_block_scales.view(M, num_bs_cols)
+
     return out, out_block_scales
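For reference, the layout trick in the hunks above can be reproduced in a few lines of plain PyTorch. The sketch below is illustrative only and not part of the patch; the names `scale`, `buf`, `flat`, `reference` are made up for the example. It emulates the kernel's swapped-stride writes with a Python loop and checks the result against the `transpose().contiguous().view()` pattern named in the docstring:

```python
# Standalone sketch of the scale-layout trick (illustrative, not library code).
# `scale` stands in for the logical (M, num_bs_cols) block-scale matrix.
import torch

M, num_bs_cols = 4, 3
scale = torch.arange(M * num_bs_cols, dtype=torch.float32).reshape(M, num_bs_cols)

# Reference pattern named in the docstring:
reference = scale.transpose(0, 1).contiguous().view(*scale.shape)

# What the new path does: allocate a (num_bs_cols, M) row-major buffer and
# write element (r, c) with swapped strides, i.e. at offset r * 1 + c * M.
buf = torch.empty((num_bs_cols, M), dtype=torch.float32)
row_stride, col_stride = buf.stride(1), buf.stride(0)  # 1, M
flat = buf.view(-1)  # shares storage with buf
for r in range(M):
    for c in range(num_bs_cols):
        flat[r * row_stride + c * col_stride] = scale[r, c]

# Reinterpreting the buffer as (M, num_bs_cols) reproduces the reference
# without a second copy.
assert torch.equal(buf.view(M, num_bs_cols), reference)
```

Allocating the buffer transposed up front means the kernel writes the scales in their final layout, avoiding the extra device-side copy that a caller-side `.contiguous()` would incur.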
diff --git a/op_tests/triton_tests/quant/test_fused_fp8_quant.py b/op_tests/triton_tests/quant/test_fused_fp8_quant.py
index 96d4f31fd0..1f69715f43 100644
--- a/op_tests/triton_tests/quant/test_fused_fp8_quant.py
+++ b/op_tests/triton_tests/quant/test_fused_fp8_quant.py
@@ -365,6 +365,66 @@ def test_fused_flatten_fp8_group_quant(M: int, N1: int, N2: int, dtype):
     torch.testing.assert_close(y_upcast_torch, y_upcast_triton, atol=0.1, rtol=0.1)
 
 
+@pytest.mark.parametrize("M", [1, 32, 256])
+@pytest.mark.parametrize("N1, N2", [(16, 128)])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_fused_flatten_fp8_group_quant_transpose_scale(
+    M: int, N1: int, N2: int, dtype
+):
+    """Test that the transpose_scale parameter returns the scale in a transposed memory layout."""
+    torch.manual_seed(0)
+    group_size = 128
+    dtype_quant = aiter.dtypes.fp8
+    x = torch.randn((N1, M, N2), dtype=dtype, device="cuda") / 10
+    x = x.transpose(0, 1)
+
+    # Call with transpose_scale=False (original behavior)
+    y_q_orig, y_s_orig = fused_flatten_fp8_group_quant(
+        x,
+        group_size=group_size,
+        dtype_quant=dtype_quant,
+        transpose_scale=False,
+    )
+
+    # Call with transpose_scale=True
+    y_q_transposed, y_s_transposed = fused_flatten_fp8_group_quant(
+        x,
+        group_size=group_size,
+        dtype_quant=dtype_quant,
+        transpose_scale=True,
+    )
+
+    num_bs_cols = (N1 * N2 + group_size - 1) // group_size
+
+    # Verify that both outputs have the same logical shape
+    assert y_s_orig.shape == (
+        M,
+        num_bs_cols,
+    ), f"Expected shape (M, num_bs_cols), got {y_s_orig.shape}"
+    assert y_s_transposed.shape == (
+        M,
+        num_bs_cols,
+    ), f"Expected shape (M, num_bs_cols), got {y_s_transposed.shape}"
+
+    # transpose_scale=True must be equivalent to .transpose().contiguous().view()
+    y_s_expected = y_s_orig.transpose(0, 1).contiguous().view(*y_s_orig.shape)
+
+    # Verify that both have the same strides (row-major after the final view)
+    assert (
+        y_s_orig.stride() == y_s_transposed.stride()
+    ), "Both should have row-major strides"
+    assert (
+        y_s_orig.is_contiguous() and y_s_transposed.is_contiguous()
+    ), "Both should be contiguous"
+
+    # Verify numerical correctness: values should match the transpose().contiguous().view() pattern
+    torch.testing.assert_close(y_s_transposed, y_s_expected, atol=1e-6, rtol=1e-6)
+
+    # Verify that the quantized fp8 tensor is identical.
+    # For fp8 tensors, use exact bitwise comparison.
+    torch.testing.assert_close(y_q_transposed, y_q_orig, atol=0, rtol=0)
+
+
 def run_torch_reduce_act_mul_fp8_group_quant(
     x, x2, activation, dtype, dtype_quant, group_size=128
 ):
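For completeness, a minimal caller-side sketch of the new flag. This is not part of the patch; it assumes a GPU with FP8 support, the module path shown in the diff, and illustrative sizes:

```python
# Usage sketch for the new transpose_scale flag (assumptions noted above).
import torch

from aiter.ops.triton.quant.fused_fp8_quant import fused_flatten_fp8_group_quant

M, N1, N2, group_size = 32, 16, 128, 128
x = torch.randn((M, N1, N2), dtype=torch.bfloat16, device="cuda") / 10

y_q, y_s = fused_flatten_fp8_group_quant(x, group_size, transpose_scale=True)

# The logical shape is unchanged versus the default path ...
assert y_s.shape == (M, (N1 * N2 + group_size - 1) // group_size)
# ... but the buffer now holds the scales serialized column-major, i.e. the
# same values as scale.transpose(0, 1).contiguous().view(*scale.shape),
# produced without an extra transpose-and-copy on the caller's side.
```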