Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions aiter/ops/triton/fused_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def fused_rms_fp8_group_quant(
dtype_quant=fp8_dtype,
res1=None,
output_unquantized_inp1=False,
transpose_scale=False,
):
"""
This op contains several steps:
Expand All @@ -39,10 +40,14 @@ def fused_rms_fp8_group_quant(

Key parameters:
- x: Matrix X with shape (M, N1, N2).
- transpose_scale: If True, return scale with shape (M, cdiv(N1, group_size)) but stored in
column-major (transposed) memory layout. Equivalent to:
scale.transpose(0, 1).contiguous().view(*scale.shape)

Returns:
- out1_fp8: The output matrix with shape (M, N1).
- out1_bs: The output matrix with shape (M, cdiv(N1, group_size)).
When transpose_scale=True, has column-major memory layout (transposed storage).
- out1: The output matrix with shape (M, N1).
- out2: The output matrix with shape (M, N2).
- out_res1: The output matrix with shape (M, N1).
Expand All @@ -60,11 +65,20 @@ def fused_rms_fp8_group_quant(
else:
N2 = 0
out1_fp8 = torch.empty((M, N1), dtype=dtype_quant, device=inp1.device)
out1_bs = torch.empty(
(M, (N1 + group_size - 1) // group_size),
dtype=torch.float32,
device=inp1.device,
)
num_bs_cols = (N1 + group_size - 1) // group_size
if transpose_scale:
# Create with transposed shape for direct transposed storage
out1_bs = torch.empty(
(num_bs_cols, M),
dtype=torch.float32,
device=inp1.device,
)
else:
out1_bs = torch.empty(
(M, num_bs_cols),
dtype=torch.float32,
device=inp1.device,
)

out2 = None
out2_row_stride = 0
Expand Down Expand Up @@ -117,6 +131,15 @@ def fused_rms_fp8_group_quant(
if torch.is_floating_point(out1_fp8)
else torch.iinfo(out1_fp8.dtype).max
)

# When transpose_scale=True, swap the strides to write directly in transposed layout
if transpose_scale:
out1_bs_row_stride = out1_bs.stride(1)
out1_bs_col_stride = out1_bs.stride(0)
else:
out1_bs_row_stride = out1_bs.stride(0)
out1_bs_col_stride = out1_bs.stride(1)

_fused_rms_fp8_group_quant_kernel[(M,)](
inp1,
inp1_weight,
Expand All @@ -141,8 +164,8 @@ def fused_rms_fp8_group_quant(
res1_col_stride,
out1_fp8.stride(0),
out1_fp8.stride(1),
out1_bs.stride(0),
out1_bs.stride(1),
out1_bs_row_stride,
out1_bs_col_stride,
out2_row_stride,
out2_col_stride,
out_res1_row_stride,
Expand All @@ -158,6 +181,10 @@ def fused_rms_fp8_group_quant(
FIRST_INPUT_OUT=output_unquantized_inp1,
num_warps=num_warps,
)
# When transpose_scale=True, view the transposed buffer back to original shape
# This keeps shape (M, num_bs_cols) but with column-major memory layout
if transpose_scale:
out1_bs = out1_bs.view(M, num_bs_cols)

return (out1_fp8, out1_bs), out1, out2, out_res1

Expand Down
78 changes: 78 additions & 0 deletions op_tests/triton_tests/test_fused_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,84 @@ def test_fused_rms_fp8_group_quant(M: int, N1: int, N2: int, dtype):
torch.testing.assert_close(y1_upcast_torch, y1_upcast_triton, atol=0.1, rtol=0.1)


@pytest.mark.parametrize("M", [1, 32, 256])
@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 7168), (7168, 7168)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_rms_fp8_group_quant_transpose_scale(M: int, N1: int, N2: int, dtype):
    """Check that transpose_scale changes only the scale's memory layout, not its values.

    Runs the fused op twice on identical inputs — once with transpose_scale=False
    (baseline) and once with transpose_scale=True — then verifies:
      * both scale tensors report the same logical shape (M, cdiv(N1, group_size));
      * the transposed variant matches the documented equivalence pattern
        scale.transpose(0, 1).contiguous().view(*scale.shape);
      * every other output of the op is unaffected by the flag.
    """
    group_size = 128
    dtype_quant = aiter.dtypes.fp8
    x1, w1, x2, w2, res1 = generate_fused_rms_quant_data(M, N1, N2, dtype)

    # Everything except transpose_scale is shared between the two calls.
    shared_kwargs = dict(
        inp2=x2,
        inp2_weight=w2,
        inp2_epsilon=1e-6,
        group_size=group_size,
        dtype_quant=dtype_quant,
        res1=res1,
        output_unquantized_inp1=True,
    )

    # Baseline: scale stored in the default row-major layout.
    (y1_q_base, y1_s_base), y1_base, y2_base, y1_res_base = fused_rms_fp8_group_quant(
        x1, w1, 1e-6, transpose_scale=False, **shared_kwargs
    )

    # Variant under test: scale stored in column-major (transposed) layout.
    (
        (y1_q_t, y1_s_t),
        y1_t,
        y2_t,
        y1_res_t,
    ) = fused_rms_fp8_group_quant(
        x1, w1, 1e-6, transpose_scale=True, **shared_kwargs
    )

    num_bs_cols = (N1 + group_size - 1) // group_size

    # The logical shape of the scale must be identical in both modes.
    assert y1_s_base.shape == (
        M,
        num_bs_cols,
    ), f"Expected shape (M, num_bs_cols), got {y1_s_base.shape}"
    assert y1_s_t.shape == (
        M,
        num_bs_cols,
    ), f"Expected shape (M, num_bs_cols), got {y1_s_t.shape}"

    # Reference result of the documented transpose().contiguous().view() pattern.
    y1_s_expected = y1_s_base.transpose(0, 1).contiguous().view(*y1_s_base.shape)

    # Strides and contiguity should look row-major on both tensors
    # (the transposed buffer is viewed back to the original shape).
    assert (
        y1_s_base.stride() == y1_s_t.stride()
    ), "Both should have row-major strides"
    assert (
        y1_s_base.is_contiguous() and y1_s_t.is_contiguous()
    ), "Both should be contiguous"

    # Values must match the equivalence pattern exactly (tight float tolerance).
    torch.testing.assert_close(y1_s_t, y1_s_expected, atol=1e-6, rtol=1e-6)

    # All remaining outputs must be unaffected by the flag.
    # fp8 quantized output: exact bitwise comparison.
    torch.testing.assert_close(y1_q_t, y1_q_base, atol=0, rtol=0)
    torch.testing.assert_close(y1_t, y1_base, atol=0.1, rtol=0.1)
    torch.testing.assert_close(y2_t, y2_base, atol=0.1, rtol=0.1)
    torch.testing.assert_close(y1_res_t, y1_res_base, atol=0.1, rtol=0.1)


def run_torch_flatten_fp8_group_quant(x, dtype_quant, group_size):
y_q, y_s = per_token_fp8_group_quant(
x.reshape(x.shape[0], -1), dtype_quant, group_size
Expand Down