Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions aiter/ops/triton/_triton_kernels/fused_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,112 @@ def _fp8_quant_op(
return x, scale_out


@triton.jit
def _fused_rms_fp8_per_tensor_static_quant_kernel(
    inp1_ptr,
    weight1_ptr,
    inp2_ptr,
    weight2_ptr,
    res1_ptr,
    out1_fp8_ptr,
    out2_ptr,
    out_res1_ptr,
    out1_ptr,
    scale_ptr,
    eps1,
    eps2,
    n_rows,
    inp1_n_cols,
    inp2_n_cols,
    inp1_row_stride,
    inp2_row_stride,
    inp1_col_stride,
    inp2_col_stride,
    res1_row_stride,
    res1_col_stride,
    out1_fp8_row_stride,
    out1_fp8_col_stride,
    out2_row_stride,
    out2_col_stride,
    out_res1_row_stride,
    out_res1_col_stride,
    out1_row_stride,
    out1_col_stride,
    BLOCK_SIZE_N: tl.constexpr,
    DTYPE_MAX: tl.constexpr,
    DTYPE_MIN: tl.constexpr,
    HAVE_SECOND_INPUT: tl.constexpr,
    FIRST_INPUT_RES: tl.constexpr,
    FIRST_INPUT_OUT: tl.constexpr,
):
    """Fused RMS-norm + static per-tensor fp8 quantization kernel.

    One program handles one row (program_id(0) == row index); BLOCK_SIZE_N
    must cover the widest row of inp1/inp2.

    Per row:
      1. optionally (FIRST_INPUT_RES) add the residual ``res1`` to ``inp1``
         and write the sum back to ``out_res1``;
      2. RMS-normalize the (possibly residual-added) ``inp1`` with
         ``weight1``/``eps1`` via ``_rmsmorm_op``;
      3. optionally (FIRST_INPUT_OUT) store the unquantized norm to ``out1``;
      4. quantize the norm with the pre-computed scalar ``*scale_ptr``
         (multiply by its reciprocal, clamp to [DTYPE_MIN, DTYPE_MAX]) and
         store to ``out1_fp8`` in that pointer's element type;
      5. optionally (HAVE_SECOND_INPUT) RMS-normalize ``inp2`` with
         ``weight2``/``eps2`` and store the unquantized result to ``out2``.
    """
    m_pid = tl.program_id(0)
    n_offs = tl.arange(0, BLOCK_SIZE_N)

    # Row of inp1, masked to its actual width; compute in fp32 for accuracy.
    mask1 = n_offs < inp1_n_cols
    inp1 = tl.load(
        inp1_ptr + m_pid * inp1_row_stride + n_offs * inp1_col_stride,
        mask=mask1,
        other=0.0,
        cache_modifier=".cg",
    ).to(tl.float32)

    if FIRST_INPUT_RES:
        # Residual add happens before the norm; the sum is also persisted
        # (stored to out_res1 at the end of the kernel).
        res1 = tl.load(
            res1_ptr + m_pid * res1_row_stride + n_offs * res1_col_stride,
            mask=mask1,
            other=0.0,
            cache_modifier=".cg",
        ).to(tl.float32)
        inp1 = inp1 + res1

    w1 = tl.load(weight1_ptr + n_offs, mask=mask1, other=0.0).to(tl.float32)
    norm1 = _rmsmorm_op(inp1, w1, inp1_n_cols, eps1)

    if FIRST_INPUT_OUT:
        # NOTE(review): this reassignment of mask1 is redundant — mask1 is
        # unchanged since the load above.
        mask1 = n_offs < inp1_n_cols
        tl.store(
            out1_ptr + m_pid * out1_row_stride + n_offs * out1_col_stride,
            norm1,
            mask=mask1,
        )

    # apply quantization
    # Static per-tensor quantization: a single pre-computed scalar scale,
    # applied as multiplication by its reciprocal, saturating at the
    # target dtype's representable range.
    scale = tl.load(scale_ptr).to(tl.float32)
    scale_recip = 1.0 / scale
    out1_fp8 = tl.clamp(norm1 * scale_recip, DTYPE_MIN, DTYPE_MAX)

    # store the results
    tl.store(
        out1_fp8_ptr + m_pid * out1_fp8_row_stride + n_offs * out1_fp8_col_stride,
        out1_fp8.to(out1_fp8_ptr.dtype.element_ty),
        mask=mask1,
    )

    if HAVE_SECOND_INPUT:
        # Second input is only normalized, never quantized.
        mask2 = n_offs < inp2_n_cols
        inp2 = tl.load(
            inp2_ptr + m_pid * inp2_row_stride + n_offs * inp2_col_stride,
            mask=mask2,
            other=0.0,
            cache_modifier=".cg",
        ).to(tl.float32)
        w2 = tl.load(weight2_ptr + n_offs, mask=mask2, other=0.0).to(tl.float32)
        norm2 = _rmsmorm_op(inp2, w2, inp2_n_cols, eps2)
        tl.store(
            out2_ptr + m_pid * out2_row_stride + n_offs * out2_col_stride,
            norm2,
            mask=mask2,
        )

    if FIRST_INPUT_RES:
        # Persist the pre-norm residual sum (inp1 + res1) in the output dtype.
        inp1 = inp1.to(out_res1_ptr.dtype.element_ty)
        tl.store(
            out_res1_ptr + m_pid * out_res1_row_stride + n_offs * out_res1_col_stride,
            inp1,
            mask=mask1,
        )


@triton.jit
def _fused_rms_fp8_group_quant_kernel(
inp1_ptr,
Expand Down
136 changes: 136 additions & 0 deletions aiter/ops/triton/fused_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import triton
import aiter
from aiter.ops.triton._triton_kernels.fused_fp8_quant import (
_fused_rms_fp8_per_tensor_static_quant_kernel,
_fused_rms_fp8_group_quant_kernel,
_fused_flatten_fp8_group_quant_kernel,
_fused_reduce_act_mul_fp8_group_quant,
Expand All @@ -18,6 +19,141 @@
fp8_dtype = aiter.dtypes.fp8


def fused_rms_fp8_per_tensor_static_quant(
    inp1,
    inp1_weight,
    inp1_epsilon,
    inp1_scale,
    inp2=None,
    inp2_weight=None,
    inp2_epsilon=None,
    dtype_quant=fp8_dtype,
    res1=None,
    output_unquantized_inp1=False,
):
    """
    Fused RMS norm + static per-tensor fp8 quantization.

    This op contains several steps:
    1. if res1 is not None, inp1 = inp1 + res1, and the sum is stored to out_res1
    2. perform RMS norm along the last dimension of inp1
    3. if inp2 is not None, perform RMS norm along the last dimension of inp2
       (inp2 is normalized only, never quantized)
    4. perform fp8 quantization of the normalized inp1 using inp1_scale

    Key parameters:
    - inp1: Matrix with shape (M, N1).
    - inp1_weight: RMS-norm weight with shape (N1,).
    - inp1_epsilon: RMS-norm epsilon for inp1.
    - inp1_scale: Pre-computed quantization scale, a single-element tensor.
    - inp2: Optional matrix with shape (M, N2).
    - inp2_weight: RMS-norm weight with shape (N2,); required when inp2 is given.
    - inp2_epsilon: RMS-norm epsilon for inp2.
    - dtype_quant: Target quantized dtype (defaults to fp8).
    - res1: Optional residual with shape (M, N1), added to inp1 before the norm.
    - output_unquantized_inp1: When True, also return the unquantized norm of inp1.

    Returns:
    - out1_fp8: Quantized normalized inp1 with shape (M, N1), dtype ``dtype_quant``.
    - out1: Unquantized normalized inp1 with shape (M, N1), or None.
    - out2: Normalized inp2 with shape (M, N2), or None.
    - out_res1: inp1 + res1 with shape (M, N1), or None.
    """
    M, N1 = inp1.shape
    BLOCK_SIZE_N = triton.next_power_of_2(N1)
    if inp2 is not None:
        M2, N2 = inp2.shape
        # One block must span the wider of the two rows. Taking only
        # next_power_of_2(N2) here (as before) under-covered inp1 whenever
        # N1 > N2, silently truncating its rows.
        BLOCK_SIZE_N = max(BLOCK_SIZE_N, triton.next_power_of_2(N2))
        assert (
            M == M2
        ), "The leading dimension should be identical between inp1 and inp2"
    else:
        N2 = 0
    out1_fp8 = torch.empty((M, N1), dtype=dtype_quant, device=inp1.device)

    # Strides default to 0 for absent optional tensors; the kernel never
    # dereferences the corresponding pointers when the constexpr flag is off.
    out2 = None
    out2_row_stride = 0
    out2_col_stride = 0
    inp2_row_stride = 0
    inp2_col_stride = 0
    if inp2 is not None:
        out2 = torch.empty((M, N2), dtype=inp1.dtype, device=inp1.device)
        inp2_row_stride = inp2.stride(0)
        inp2_col_stride = inp2.stride(1)
        out2_row_stride = out2.stride(0)
        out2_col_stride = out2.stride(1)

    out1 = None
    out1_row_stride = 0
    out1_col_stride = 0
    if output_unquantized_inp1:
        out1 = torch.empty((M, N1), dtype=inp1.dtype, device=inp1.device)
        out1_row_stride = out1.stride(0)
        # Bug fix: this previously read out2.stride(1), which used the wrong
        # tensor and raised AttributeError whenever inp2 was None.
        out1_col_stride = out1.stride(1)

    out_res1 = None
    res1_row_stride = 0
    res1_col_stride = 0
    out_res1_row_stride = 0
    out_res1_col_stride = 0
    if res1 is not None:
        Mr, Nr = res1.shape
        assert (
            M == Mr and N1 == Nr
        ), "The shape should be identical between inp1 and res1"
        out_res1 = torch.empty((M, N1), dtype=inp1.dtype, device=inp1.device)
        res1_row_stride = res1.stride(0)
        res1_col_stride = res1.stride(1)
        out_res1_row_stride = out_res1.stride(0)
        out_res1_col_stride = out_res1.stride(1)

    # Heuristic warp count scaled with the (padded) row width.
    if BLOCK_SIZE_N <= 512:
        num_warps = 1
    elif BLOCK_SIZE_N <= 2048:
        num_warps = 4
    elif BLOCK_SIZE_N <= 4096:
        num_warps = 8
    else:
        num_warps = 16

    # Saturation bound of the target dtype; the quantized range is symmetric
    # around zero, so DTYPE_MIN below is simply -DTYPE_MAX.
    DTYPE_MAX = (
        torch.finfo(out1_fp8.dtype).max
        if torch.is_floating_point(out1_fp8)
        else torch.iinfo(out1_fp8.dtype).max
    )

    # One program per row of the inputs.
    _fused_rms_fp8_per_tensor_static_quant_kernel[(M,)](
        inp1,
        inp1_weight,
        inp2,
        inp2_weight,
        res1,
        out1_fp8,
        out2,
        out_res1,
        out1,
        inp1_scale,
        inp1_epsilon,
        inp2_epsilon,
        M,
        N1,
        N2,
        inp1.stride(0),
        inp2_row_stride,
        inp1.stride(1),
        inp2_col_stride,
        res1_row_stride,
        res1_col_stride,
        out1_fp8.stride(0),
        out1_fp8.stride(1),
        out2_row_stride,
        out2_col_stride,
        out_res1_row_stride,
        out_res1_col_stride,
        out1_row_stride,
        out1_col_stride,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        DTYPE_MAX=DTYPE_MAX,
        DTYPE_MIN=-DTYPE_MAX,
        HAVE_SECOND_INPUT=(inp2 is not None),
        FIRST_INPUT_RES=(res1 is not None),
        FIRST_INPUT_OUT=output_unquantized_inp1,
        num_warps=num_warps,
    )

    return out1_fp8, out1, out2, out_res1


def fused_rms_fp8_group_quant(
inp1,
inp1_weight,
Expand Down
56 changes: 56 additions & 0 deletions op_tests/triton_tests/test_fused_fp8_quant.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import pytest
from aiter.ops.triton.fused_fp8_quant import (
fused_rms_fp8_per_tensor_static_quant,
fused_rms_fp8_group_quant,
fused_flatten_fp8_group_quant,
fused_reduce_act_mul_fp8_group_quant,
Expand Down Expand Up @@ -37,6 +38,13 @@ def per_token_fp8_group_quant(x, dtype_quant, group_size=128):
return x_quant, x_scale


def per_tensor_fp8_static_quant(x, dtype_quant, x_scale):
    """Quantize ``x`` to ``dtype_quant`` with one static per-tensor scale.

    Multiplies by the reciprocal of ``x_scale``, saturates to the target
    dtype's finite range, and casts the result.
    """
    limit = torch.finfo(dtype_quant).max
    scaled = x * (1.0 / x_scale)
    return scaled.clamp(min=-limit, max=limit).to(dtype_quant)


def upcast(x, s, dtype, group_size=128):
x_N = x.shape[1]
x = x.reshape(-1, x_N // group_size, group_size).to(torch.float32) * s.reshape(
Expand Down Expand Up @@ -65,6 +73,54 @@ def generate_fused_rms_quant_data(M, N1, N2, dtype=torch.bfloat16):
return x1, w1, x2, w2, res1


def run_torch_rms_fp8_per_tensor_static_quant(
    x1, w1, eps1, x2, w2, eps2, res1, dtype_quant, x1_scale
):
    """Eager torch reference for the fused RMS-norm + static fp8 quant op.

    Adds the residual to x1, RMS-normalizes both inputs, and statically
    quantizes only the normalized x1. Returns (quantized norm1, norm1,
    norm2, residual sum), the last three cast to x1's dtype.
    """
    residual_sum = x1 + res1
    norm1 = rmsnorm(residual_sum, w1, eps1)
    norm2 = rmsnorm(x2, w2, eps2)
    quantized = per_tensor_fp8_static_quant(norm1, dtype_quant, x1_scale)
    out_dtype = x1.dtype
    return (
        quantized,
        norm1.to(out_dtype),
        norm2.to(out_dtype),
        residual_sum.to(out_dtype),
    )


@pytest.mark.parametrize("M", [1, 32, 256])
@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 7168), (7168, 7168)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_rms_fp8_per_tensor_static_quant(M: int, N1: int, N2: int, dtype):
    """Compare the fused Triton kernel against the eager torch reference."""
    dtype_quant = aiter.dtypes.fp8
    # Bug fix: torch.randn could yield a negative or near-zero scale, whose
    # reciprocal blows up the quantized values and makes the tolerance-based
    # comparisons flaky. Use a positive scale bounded away from zero instead.
    scale = torch.rand(1, dtype=torch.float32, device="cuda") + 0.5
    x1, w1, x2, w2, res1 = generate_fused_rms_quant_data(M, N1, N2, dtype)

    y1_q_torch, y1_torch, y2_torch, y1_res_torch = (
        run_torch_rms_fp8_per_tensor_static_quant(
            x1, w1, 1e-6, x2, w2, 1e-6, res1, dtype_quant, scale
        )
    )

    y1_q_triton, y1_triton, y2_triton, y1_res_triton = (
        fused_rms_fp8_per_tensor_static_quant(
            x1,
            w1,
            1e-6,
            scale,
            inp2=x2,
            inp2_weight=w2,
            inp2_epsilon=1e-6,
            dtype_quant=dtype_quant,
            res1=res1,
            output_unquantized_inp1=True,
        )
    )

    # Unquantized outputs must agree within tolerance.
    torch.testing.assert_close(y1_torch, y1_triton, atol=0.1, rtol=0.1)
    torch.testing.assert_close(y2_torch, y2_triton, atol=0.1, rtol=0.1)
    torch.testing.assert_close(y1_res_torch, y1_res_triton, atol=0.1, rtol=0.1)

    # Compare quantized outputs in the dequantized (upcast) domain, since
    # fp8 values themselves cannot be compared with a float tolerance.
    y1_upcast_torch = y1_q_torch.to(torch.float32) * scale
    y1_upcast_triton = y1_q_triton.to(torch.float32) * scale
    torch.testing.assert_close(y1_upcast_torch, y1_upcast_triton, atol=0.1, rtol=0.1)


@pytest.mark.parametrize("M", [1, 32, 256])
@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 7168), (7168, 7168)])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
Expand Down