diff --git a/aiter/ops/triton/fused_fp8_quant.py b/aiter/ops/triton/fused_fp8_quant.py index 39e1c58777..afbc5023ef 100644 --- a/aiter/ops/triton/fused_fp8_quant.py +++ b/aiter/ops/triton/fused_fp8_quant.py @@ -29,6 +29,7 @@ def fused_rms_fp8_group_quant( dtype_quant=fp8_dtype, res1=None, output_unquantized_inp1=False, + transpose_scale=False, ): """ This op contains several steps: @@ -39,10 +40,14 @@ def fused_rms_fp8_group_quant( Key parameters: - x: Matrix X with shape (M, N1, N2). + - transpose_scale: If True, return scale with shape (M, cdiv(N1, group_size)) but stored in + column-major (transposed) memory layout. Equivalent to: + scale.transpose(0, 1).contiguous().view(*scale.shape) Returns: - out1_fp8: The output matrix with shape (M, N1). - out1_bs: The output matrix with shape (M, cdiv(N1, group_size)). + When transpose_scale=True, has column-major memory layout (transposed storage). - out1: The output matrix with shape (M, N1). - out2: The output matrix with shape (M, N2). - out_res1: The output matrix with shape (M, N1). @@ -60,11 +65,20 @@ def fused_rms_fp8_group_quant( else: N2 = 0 out1_fp8 = torch.empty((M, N1), dtype=dtype_quant, device=inp1.device) - out1_bs = torch.empty( - (M, (N1 + group_size - 1) // group_size), - dtype=torch.float32, - device=inp1.device, - ) + num_bs_cols = (N1 + group_size - 1) // group_size + if transpose_scale: + # Create with transposed shape for direct transposed storage + out1_bs = torch.empty( + (num_bs_cols, M), + dtype=torch.float32, + device=inp1.device, + ) + else: + out1_bs = torch.empty( + (M, num_bs_cols), + dtype=torch.float32, + device=inp1.device, + ) out2 = None out2_row_stride = 0 @@ -117,6 +131,15 @@ def fused_rms_fp8_group_quant( if torch.is_floating_point(out1_fp8) else torch.iinfo(out1_fp8.dtype).max ) + + # When transpose_scale=True, swap the strides to write directly in transposed layout + if transpose_scale: + out1_bs_row_stride = out1_bs.stride(1) + out1_bs_col_stride = out1_bs.stride(0) + else: + out1_bs_row_stride = out1_bs.stride(0) + out1_bs_col_stride = out1_bs.stride(1) + _fused_rms_fp8_group_quant_kernel[(M,)]( inp1, inp1_weight, @@ -141,8 +164,8 @@ def fused_rms_fp8_group_quant( res1_col_stride, out1_fp8.stride(0), out1_fp8.stride(1), - out1_bs.stride(0), - out1_bs.stride(1), + out1_bs_row_stride, + out1_bs_col_stride, out2_row_stride, out2_col_stride, out_res1_row_stride, @@ -158,6 +181,10 @@ def fused_rms_fp8_group_quant( FIRST_INPUT_OUT=output_unquantized_inp1, num_warps=num_warps, ) + # When transpose_scale=True, view the transposed buffer back to original shape + # This keeps shape (M, num_bs_cols) but with column-major memory layout + if transpose_scale: + out1_bs = out1_bs.view(M, num_bs_cols) return (out1_fp8, out1_bs), out1, out2, out_res1 diff --git a/op_tests/triton_tests/test_fused_fp8_quant.py b/op_tests/triton_tests/test_fused_fp8_quant.py index 9621756f4f..52ca1133c0 100644 --- a/op_tests/triton_tests/test_fused_fp8_quant.py +++ b/op_tests/triton_tests/test_fused_fp8_quant.py @@ -107,6 +107,84 @@ def test_fused_rms_fp8_group_quant(M: int, N1: int, N2: int, dtype): torch.testing.assert_close(y1_upcast_torch, y1_upcast_triton, atol=0.1, rtol=0.1) +@pytest.mark.parametrize("M", [1, 32, 256]) +@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 7168), (7168, 7168)]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_fused_rms_fp8_group_quant_transpose_scale(M: int, N1: int, N2: int, dtype): + """Test that transpose_scale parameter returns scale with transposed memory layout.""" + group_size = 128 + dtype_quant = aiter.dtypes.fp8 + x1, w1, x2, w2, res1 = generate_fused_rms_quant_data(M, N1, N2, dtype) + + # Call with transpose_scale=False (original behavior) + (y1_q_orig, y1_s_orig), y1_orig, y2_orig, y1_res_orig = fused_rms_fp8_group_quant( + x1, + w1, + 1e-6, + inp2=x2, + inp2_weight=w2, + inp2_epsilon=1e-6, + group_size=group_size, + dtype_quant=dtype_quant, + res1=res1, + output_unquantized_inp1=True, + transpose_scale=False, + ) + + # Call with transpose_scale=True + ( + (y1_q_transposed, y1_s_transposed), + y1_transposed, + y2_transposed, + y1_res_transposed, + ) = fused_rms_fp8_group_quant( + x1, + w1, + 1e-6, + inp2=x2, + inp2_weight=w2, + inp2_epsilon=1e-6, + group_size=group_size, + dtype_quant=dtype_quant, + res1=res1, + output_unquantized_inp1=True, + transpose_scale=True, + ) + + num_bs_cols = (N1 + group_size - 1) // group_size + + # Verify that both outputs have the same shape + assert y1_s_orig.shape == ( + M, + num_bs_cols, + ), f"Expected shape (M, num_bs_cols), got {y1_s_orig.shape}" + assert y1_s_transposed.shape == ( + M, + num_bs_cols, + ), f"Expected shape (M, num_bs_cols), got {y1_s_transposed.shape}" + + # Verify that transpose_scale=True version is equivalent to .transpose().contiguous().view() + y1_s_expected = y1_s_orig.transpose(0, 1).contiguous().view(*y1_s_orig.shape) + + # Verify that both have the same shape and strides (row-major) + assert ( + y1_s_orig.stride() == y1_s_transposed.stride() + ), "Both should have row-major strides" + assert ( + y1_s_orig.is_contiguous() and y1_s_transposed.is_contiguous() + ), "Both should be contiguous" + + # Verify numerical correctness - values should match the transpose().contiguous().view() pattern + torch.testing.assert_close(y1_s_transposed, y1_s_expected, atol=1e-6, rtol=1e-6) + + # Verify that other outputs are identical + # For fp8 tensors, use exact bitwise comparison + torch.testing.assert_close(y1_q_transposed, y1_q_orig, atol=0, rtol=0) + torch.testing.assert_close(y1_transposed, y1_orig, atol=0.1, rtol=0.1) + torch.testing.assert_close(y2_transposed, y2_orig, atol=0.1, rtol=0.1) + torch.testing.assert_close(y1_res_transposed, y1_res_orig, atol=0.1, rtol=0.1) + + def run_torch_flatten_fp8_group_quant(x, dtype_quant, group_size): y_q, y_s = per_token_fp8_group_quant( x.reshape(x.shape[0], -1), dtype_quant, group_size