diff --git a/.github/scripts/split_tests.sh b/.github/scripts/split_tests.sh
index 05e58b9626..b9819cc5d6 100755
--- a/.github/scripts/split_tests.sh
+++ b/.github/scripts/split_tests.sh
@@ -134,7 +134,7 @@ elif [[ "$TEST_TYPE" == "triton" ]]; then
     FILE_TIMES[op_tests/triton_tests/attention/test_mha.py]=1452
     FILE_TIMES[op_tests/triton_tests/test_pa_decode_gluon.py]=718
     FILE_TIMES[op_tests/triton_tests/attention/test_pa_decode.py]=635
-    FILE_TIMES[op_tests/triton_tests/test_causal_conv1d.py]=634
+    FILE_TIMES[op_tests/triton_tests/conv/test_causal_conv1d.py]=634
     FILE_TIMES[op_tests/triton_tests/gemm/batched/test_batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant.py]=402
     FILE_TIMES[op_tests/triton_tests/attention/test_flash_attn_kvcache.py]=357
     FILE_TIMES[op_tests/triton_tests/attention/test_chunked_pa_prefill.py]=336
diff --git a/.gitignore b/.gitignore
index 6bab99e0d0..b272a12f1c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -79,4 +79,8 @@ scripts/results/*.csv
 
 # Jupyter notebook checkpoints and outputs
 .ipynb_checkpoints/
-**/.ipynb_checkpoints/
\ No newline at end of file
+**/.ipynb_checkpoints/
+
+# vim swaps
+*.swo
+*.swp
\ No newline at end of file
diff --git a/aiter/ops/triton/__init__.py b/aiter/ops/triton/__init__.py
index bef4ccc506..c1ffd2b3ed 100644
--- a/aiter/ops/triton/__init__.py
+++ b/aiter/ops/triton/__init__.py
@@ -141,6 +141,9 @@
     # Quant modules (quant/)
     "fused_fp8_quant": "quant.fused_fp8_quant",
     "fused_mxfp4_quant": "quant.fused_mxfp4_quant",
+    # Conv modules (conv/)
+    "causal_conv1d": "conv.causal_conv1d",
+    "causal_conv1d_update_single_token": "conv.causal_conv1d_update_single_token",
 }
diff --git a/aiter/ops/triton/_triton_kernels/conv/__init__.py b/aiter/ops/triton/_triton_kernels/conv/__init__.py
new file mode 100644
index 0000000000..7824504eb0
--- /dev/null
+++ b/aiter/ops/triton/_triton_kernels/conv/__init__.py
@@ -0,0 +1,6 @@
+# SPDX-License-Identifier: MIT
+# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+
+from .causal_conv1d import PAD_SLOT_ID
+
+__all__ = ["PAD_SLOT_ID"]
diff --git a/aiter/ops/triton/_triton_kernels/causal_conv1d.py b/aiter/ops/triton/_triton_kernels/conv/causal_conv1d.py
similarity index 100%
rename from aiter/ops/triton/_triton_kernels/causal_conv1d.py
rename to aiter/ops/triton/_triton_kernels/conv/causal_conv1d.py
diff --git a/aiter/ops/triton/_triton_kernels/causal_conv1d_update_single_token.py b/aiter/ops/triton/_triton_kernels/conv/causal_conv1d_update_single_token.py
similarity index 100%
rename from aiter/ops/triton/_triton_kernels/causal_conv1d_update_single_token.py
rename to aiter/ops/triton/_triton_kernels/conv/causal_conv1d_update_single_token.py
diff --git a/aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py b/aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py
index 5acc36b81f..9cc751f2e2 100644
--- a/aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py
+++ b/aiter/ops/triton/_triton_kernels/quant/fused_fp8_quant.py
@@ -181,6 +181,9 @@ def _fused_rms_fp8_group_quant_kernel(
     out_res1_col_stride,
     out1_row_stride,
     out1_col_stride,
+    gate_ptr,
+    linear_bias_ptr,
+    stride_gate_row,
     BLOCK_SIZE_N: tl.constexpr,
     QUANT_BLOCK_SIZE: tl.constexpr,
     DTYPE_MAX: tl.constexpr,
@@ -188,81 +191,205 @@
     HAVE_SECOND_INPUT: tl.constexpr,
     FIRST_INPUT_RES: tl.constexpr,
     FIRST_INPUT_OUT: tl.constexpr,
+    GATED_RMS_FP8: tl.constexpr,
+    RMS_TILE: tl.constexpr,
+    ROWS_PER_BLOCK: tl.constexpr,
+    GROUP_SIZE_GATED: tl.constexpr,
+    NUM_GROUPS_GATED: tl.constexpr,
+    BLOCK_G: tl.constexpr,
+    HAS_BIAS_GATED: tl.constexpr,
+    HAS_Z_GATED: tl.constexpr,
+    NORM_BEFORE_GATE: tl.constexpr,
+    FP8_MIN: tl.constexpr,
+    FP8_MAX: tl.constexpr,
+    USE_UE8M0: tl.constexpr,
+    FP8_MIN_SCALING_FACTOR: tl.constexpr,
+    ACTIVATION: tl.constexpr,
 ):
-    m_pid = tl.program_id(0)
-    n_offs = tl.arange(0, BLOCK_SIZE_N)
-    NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE
+    """RMSNorm + FP8 row/group quant (classic) or gated RMSNorm + FP8 (vLLM-style).
+
+    When ``GATED_RMS_FP8`` is True, use grid ``(cdiv(M, ROWS_PER_BLOCK),)`` and batch
+    ``ROWS_PER_BLOCK`` rows per program (vLLM ``calc_rows_per_block`` heuristic on the host).
+    Extra pointer args are unused in the classic path but must refer to valid tensors.
+ """ + if GATED_RMS_FP8: + # --- Gated path (adapted from vLLM / ROCm gated RMSNorm FP8 kernel) --- + X = inp1_ptr + W = weight1_ptr + Bptr = linear_bias_ptr + Z = gate_ptr + Y_quant = out1_fp8_ptr + Scales = out1_bs_ptr + stride_x_row = inp1_row_stride + stride_z_row = stride_gate_row + stride_y_row = out1_fp8_row_stride + stride_s_row = out1_bs_row_stride + stride_s_g = out1_bs_col_stride + M = n_rows + N = inp1_n_cols + eps = eps1 + + row_start = tl.program_id(0) * ROWS_PER_BLOCK + rows = row_start + tl.arange(0, ROWS_PER_BLOCK) + row_mask_1d = rows < M + + sumsq = tl.zeros([ROWS_PER_BLOCK], dtype=tl.float32) + off_rms = 0 + while off_rms < N: + cols = tl.arange(0, RMS_TILE) + off_rms + col_mask = cols < N + mask_r = row_mask_1d[:, None] & col_mask[None, :] + row_offsets = rows[:, None] * stride_x_row + col_offsets = cols[None, :] + X_base = X + row_offsets + col_offsets + x_el = tl.load(X_base, mask=mask_r, other=0.0).to(tl.float32) + if HAS_Z_GATED and (not NORM_BEFORE_GATE): + Z_base = Z + rows[:, None] * stride_z_row + col_offsets + z_el = tl.load(Z_base, mask=mask_r, other=0.0).to(tl.float32) + if ACTIVATION == "swish": + x_el = x_el * (z_el * tl.sigmoid(z_el)) + elif ACTIVATION == "silu": + x_el = x_el * (z_el * tl.sigmoid(z_el)) + elif ACTIVATION == "sigmoid": + x_el = x_el * tl.sigmoid(z_el) + xbar_sq = tl.where(mask_r, x_el, 0.0) + sumsq = sumsq + tl.sum(xbar_sq * xbar_sq, axis=1) + off_rms += RMS_TILE + + var = sumsq / N + rstd = tl.math.rsqrt(var + eps) + + for g in range(NUM_GROUPS_GATED): + col0 = g * GROUP_SIZE_GATED + cols = tl.arange(0, BLOCK_G) + col0 + col_mask = (cols < N) & (cols < col0 + GROUP_SIZE_GATED) + mask_g = row_mask_1d[:, None] & col_mask[None, :] + row_offsets = rows[:, None] * stride_x_row + col_offsets = cols[None, :] + X_base = X + row_offsets + col_offsets + x_el = tl.load(X_base, mask=mask_g, other=0.0).to(tl.float32) + + if HAS_Z_GATED and (not NORM_BEFORE_GATE): + Z_base = Z + rows[:, None] * stride_z_row + col_offsets + z_el = tl.load(Z_base, mask=mask_g, other=0.0).to(tl.float32) + if ACTIVATION == "swish": + x_el = x_el * (z_el * tl.sigmoid(z_el)) + elif ACTIVATION == "silu": + x_el = x_el * (z_el * tl.sigmoid(z_el)) + elif ACTIVATION == "sigmoid": + x_el = x_el * tl.sigmoid(z_el) + + x_hat = x_el * rstd[:, None] + + w_mask = col_mask + w_el = tl.load(W + cols, mask=w_mask, other=0.0).to(tl.float32) + if HAS_BIAS_GATED: + b_el = tl.load(Bptr + cols, mask=w_mask, other=0.0).to(tl.float32) + y_el = x_hat * w_el[None, :] + b_el[None, :] + else: + y_el = x_hat * w_el[None, :] + + if HAS_Z_GATED and NORM_BEFORE_GATE: + Z_base = Z + rows[:, None] * stride_z_row + col_offsets + z_el = tl.load(Z_base, mask=mask_g, other=0.0).to(tl.float32) + if ACTIVATION == "swish": + y_el = y_el * (z_el * tl.sigmoid(z_el)) + elif ACTIVATION == "silu": + y_el = y_el * (z_el * tl.sigmoid(z_el)) + elif ACTIVATION == "sigmoid": + y_el = y_el * tl.sigmoid(z_el) + + abs_y = tl.where(mask_g, tl.abs(y_el), 0.0) + absmax = tl.max(abs_y, axis=1) + scales_raw = absmax / FP8_MAX + if USE_UE8M0: + scales_raw = tl.exp2(tl.ceil(tl.log2(scales_raw))) + scales = tl.maximum(scales_raw, FP8_MIN_SCALING_FACTOR) + + y_scaled = y_el / scales[:, None] + y_q = tl.maximum(tl.minimum(y_scaled, FP8_MAX), FP8_MIN) + + Y_base = Y_quant + rows[:, None] * stride_y_row + col_offsets + tl.store(Y_base, y_q.to(Y_quant.dtype.element_ty), mask=mask_g) + + S_ptr = Scales + rows * stride_s_row + g * stride_s_g + tl.store(S_ptr, scales, mask=row_mask_1d) + else: + m_pid = tl.program_id(0) + n_offs = 
tl.arange(0, BLOCK_SIZE_N) + NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // QUANT_BLOCK_SIZE - mask1 = n_offs < inp1_n_cols - inp1 = tl.load( - inp1_ptr + m_pid * inp1_row_stride + n_offs * inp1_col_stride, - mask=mask1, - other=0.0, - cache_modifier=".cg", - ).to(tl.float32) - if FIRST_INPUT_RES: - res1 = tl.load( - res1_ptr + m_pid * res1_row_stride + n_offs * res1_col_stride, + mask1 = n_offs < inp1_n_cols + inp1 = tl.load( + inp1_ptr + m_pid * inp1_row_stride + n_offs * inp1_col_stride, mask=mask1, other=0.0, cache_modifier=".cg", ).to(tl.float32) - inp1 = inp1 + res1 - - w1 = tl.load(weight1_ptr + n_offs, mask=mask1, other=0.0).to(tl.float32) + if FIRST_INPUT_RES: + res1 = tl.load( + res1_ptr + m_pid * res1_row_stride + n_offs * res1_col_stride, + mask=mask1, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + inp1 = inp1 + res1 - norm1 = _rmsmorm_op(inp1, w1, inp1_n_cols, eps1) + w1 = tl.load(weight1_ptr + n_offs, mask=mask1, other=0.0).to(tl.float32) - if FIRST_INPUT_OUT: - mask1 = n_offs < inp1_n_cols - tl.store( - out1_ptr + m_pid * out1_row_stride + n_offs * out1_col_stride, - norm1, - mask=mask1, - ) + norm1 = _rmsmorm_op(inp1, w1, inp1_n_cols, eps1) - out1_fp8, out1_block_scales = _fp8_quant_op( - norm1, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN - ) - out1_fp8 = tl.ravel(out1_fp8) - out1_block_scales = tl.ravel(out1_block_scales) + if FIRST_INPUT_OUT: + mask1 = n_offs < inp1_n_cols + tl.store( + out1_ptr + m_pid * out1_row_stride + n_offs * out1_col_stride, + norm1, + mask=mask1, + ) - # store the results - tl.store( - out1_fp8_ptr + m_pid * out1_fp8_row_stride + n_offs * out1_fp8_col_stride, - out1_fp8.to(out1_fp8_ptr.dtype.element_ty), - mask=mask1, - ) - g_offs = tl.arange(0, NUM_QUANT_BLOCKS) - num_bs_cols = (inp1_n_cols + QUANT_BLOCK_SIZE - 1) // QUANT_BLOCK_SIZE - tl.store( - out1_bs_ptr + m_pid * out1_bs_row_stride + g_offs * out1_bs_col_stride, - out1_block_scales.to(out1_bs_ptr.dtype.element_ty), - mask=g_offs < num_bs_cols, - ) - if HAVE_SECOND_INPUT: - mask2 = n_offs < inp2_n_cols - inp2 = tl.load( - inp2_ptr + m_pid * inp2_row_stride + n_offs * inp2_col_stride, - mask=mask2, - other=0.0, - cache_modifier=".cg", - ).to(tl.float32) - w2 = tl.load(weight2_ptr + n_offs, mask=mask2, other=0.0).to(tl.float32) - norm2 = _rmsmorm_op(inp2, w2, inp2_n_cols, eps2) - tl.store( - out2_ptr + m_pid * out2_row_stride + n_offs * out2_col_stride, - norm2, - mask=mask2, + out1_fp8_t, out1_block_scales = _fp8_quant_op( + norm1, 1, BLOCK_SIZE_N, QUANT_BLOCK_SIZE, DTYPE_MAX, DTYPE_MIN ) + out1_fp8_t = tl.ravel(out1_fp8_t) + out1_block_scales = tl.ravel(out1_block_scales) - if FIRST_INPUT_RES: - inp1 = inp1.to(out_res1_ptr.dtype.element_ty) tl.store( - out_res1_ptr + m_pid * out_res1_row_stride + n_offs * out_res1_col_stride, - inp1, + out1_fp8_ptr + m_pid * out1_fp8_row_stride + n_offs * out1_fp8_col_stride, + out1_fp8_t.to(out1_fp8_ptr.dtype.element_ty), mask=mask1, ) + g_offs = tl.arange(0, NUM_QUANT_BLOCKS) + num_bs_cols = (inp1_n_cols + QUANT_BLOCK_SIZE - 1) // QUANT_BLOCK_SIZE + tl.store( + out1_bs_ptr + m_pid * out1_bs_row_stride + g_offs * out1_bs_col_stride, + out1_block_scales.to(out1_bs_ptr.dtype.element_ty), + mask=g_offs < num_bs_cols, + ) + if HAVE_SECOND_INPUT: + mask2 = n_offs < inp2_n_cols + inp2 = tl.load( + inp2_ptr + m_pid * inp2_row_stride + n_offs * inp2_col_stride, + mask=mask2, + other=0.0, + cache_modifier=".cg", + ).to(tl.float32) + w2 = tl.load(weight2_ptr + n_offs, mask=mask2, other=0.0).to(tl.float32) + norm2 = _rmsmorm_op(inp2, w2, 
inp2_n_cols, eps2) + tl.store( + out2_ptr + m_pid * out2_row_stride + n_offs * out2_col_stride, + norm2, + mask=mask2, + ) + + if FIRST_INPUT_RES: + inp1 = inp1.to(out_res1_ptr.dtype.element_ty) + tl.store( + out_res1_ptr + + m_pid * out_res1_row_stride + + n_offs * out_res1_col_stride, + inp1, + mask=mask1, + ) @triton.jit @@ -766,127 +893,3 @@ def _fused_silu_mul_fp8_per_tensor_static_quant_kernel( quant_fp8_out.to(out_fp8_ptr.dtype.element_ty), mask=mask, ) - - -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. - - -@triton.heuristics( - { - "HAS_BIAS": lambda args: args["B"] is not None, - "HAS_Z": lambda args: args["Z"] is not None, - } -) -@triton.jit -def _fused_rms_gated_fp8_group_quant_kernel( - X, - W, - B, - Z, - Y_quant, - Scales, - stride_x_row, - stride_z_row, - stride_y_row, - stride_s_row, - stride_s_g, - M, - N: tl.constexpr, - eps, - RMS_TILE: tl.constexpr, - ROWS_PER_BLOCK: tl.constexpr, - GROUP_SIZE: tl.constexpr, - NUM_GROUPS: tl.constexpr, - BLOCK_G: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_Z: tl.constexpr, - NORM_BEFORE_GATE: tl.constexpr, - FP8_MIN: tl.constexpr, - FP8_MAX: tl.constexpr, - USE_UE8M0: tl.constexpr, - FP8_MIN_SCALING_FACTOR: tl.constexpr, - ACTIVATION: tl.constexpr, -): - row_start = tl.program_id(0) * ROWS_PER_BLOCK - rows = row_start + tl.arange(0, ROWS_PER_BLOCK) - row_mask_1d = rows < M - - # --- Full-row RMS: accumulate sum of squares in float32 --- - sumsq = tl.zeros([ROWS_PER_BLOCK], dtype=tl.float32) - off = 0 - while off < N: - cols = tl.arange(0, RMS_TILE) + off - col_mask = cols < N - mask = row_mask_1d[:, None] & col_mask[None, :] - row_offsets = rows[:, None] * stride_x_row - col_offsets = cols[None, :] - X_base = X + row_offsets + col_offsets - x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32) - if HAS_Z and not NORM_BEFORE_GATE: - Z_base = Z + rows[:, None] * stride_z_row + col_offsets - z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32) - if ACTIVATION == "swish" or ACTIVATION == "silu": - x *= z * tl.sigmoid(z) - elif ACTIVATION == "sigmoid": - x *= tl.sigmoid(z) - xbar = tl.where(mask, x, 0.0) - sumsq += tl.sum(xbar * xbar, axis=1) - off += RMS_TILE - - var = sumsq / N - rstd = tl.rsqrt(var + eps) - - # --- Per-group: normalize (when NORM_BEFORE_GATE), linear, optional gate, FP8 --- - for g in range(NUM_GROUPS): - col0 = g * GROUP_SIZE - cols = tl.arange(0, BLOCK_G) + col0 - col_mask = cols < N - mask = row_mask_1d[:, None] & col_mask[None, :] - row_offsets = rows[:, None] * stride_x_row - col_offsets = cols[None, :] - X_base = X + row_offsets + col_offsets - x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32) - - if HAS_Z and not NORM_BEFORE_GATE: - Z_base = Z + rows[:, None] * stride_z_row + col_offsets - z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32) - if ACTIVATION == "swish" or ACTIVATION == "silu": - x *= z * tl.sigmoid(z) - elif ACTIVATION == "sigmoid": - x *= tl.sigmoid(z) - - x_hat = x * rstd[:, None] - - w_mask = cols < N - w = tl.load(W + cols, mask=w_mask, other=0.0).to(tl.float32) - if HAS_BIAS: - b = tl.load(B + cols, mask=w_mask, other=0.0).to(tl.float32) - y = x_hat * w[None, :] + b[None, :] - else: - y = x_hat * w[None, :] - - if HAS_Z and NORM_BEFORE_GATE: - Z_base = Z + rows[:, None] * stride_z_row + col_offsets - z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32) - if ACTIVATION == "swish" or ACTIVATION == "silu": - y *= z * tl.sigmoid(z) - 
elif ACTIVATION == "sigmoid": - y *= tl.sigmoid(z) - - abs_y = tl.where(mask, tl.abs(y), 0.0) - absmax = tl.max(abs_y, axis=1) - scales_raw = absmax / FP8_MAX - if USE_UE8M0: - scales_raw = tl.exp2(tl.ceil(tl.log2(scales_raw))) - scales = tl.maximum(scales_raw, FP8_MIN_SCALING_FACTOR) - - y_scaled = y / scales[:, None] - y_quant = tl.maximum(tl.minimum(y_scaled, FP8_MAX), FP8_MIN) - - Y_base = Y_quant + rows[:, None] * stride_y_row + col_offsets - tl.store(Y_base, y_quant.to(Y_quant.dtype.element_ty), mask=mask) - - S_ptr = Scales + rows * stride_s_row + g * stride_s_g - tl.store(S_ptr, scales, mask=row_mask_1d) diff --git a/aiter/ops/triton/conv/__init__.py b/aiter/ops/triton/conv/__init__.py new file mode 100644 index 0000000000..39852166ea --- /dev/null +++ b/aiter/ops/triton/conv/__init__.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +from .causal_conv1d import ( + causal_conv1d_fn, + causal_conv1d_update, + PAD_SLOT_ID, +) +from .causal_conv1d_update_single_token import ( + causal_conv1d_update_single_token, + fused_reshape_causal_conv1d_update_single_token, +) + +__all__ = [ + "PAD_SLOT_ID", + "causal_conv1d_fn", + "causal_conv1d_update", + "causal_conv1d_update_single_token", + "fused_reshape_causal_conv1d_update_single_token", +] diff --git a/aiter/ops/triton/causal_conv1d.py b/aiter/ops/triton/conv/causal_conv1d.py similarity index 99% rename from aiter/ops/triton/causal_conv1d.py rename to aiter/ops/triton/conv/causal_conv1d.py index 79973a6c38..161d71724c 100644 --- a/aiter/ops/triton/causal_conv1d.py +++ b/aiter/ops/triton/conv/causal_conv1d.py @@ -1,6 +1,6 @@ import torch import triton -from aiter.ops.triton._triton_kernels.causal_conv1d import ( +from aiter.ops.triton._triton_kernels.conv.causal_conv1d import ( _causal_conv1d_fwd_kernel, _causal_conv1d_update_kernel, PAD_SLOT_ID, diff --git a/aiter/ops/triton/causal_conv1d_update_single_token.py b/aiter/ops/triton/conv/causal_conv1d_update_single_token.py similarity index 98% rename from aiter/ops/triton/causal_conv1d_update_single_token.py rename to aiter/ops/triton/conv/causal_conv1d_update_single_token.py index 07ac393568..7d9f9b0950 100644 --- a/aiter/ops/triton/causal_conv1d_update_single_token.py +++ b/aiter/ops/triton/conv/causal_conv1d_update_single_token.py @@ -14,8 +14,8 @@ import torch import triton -from aiter.ops.triton._triton_kernels.causal_conv1d import PAD_SLOT_ID -from aiter.ops.triton._triton_kernels.causal_conv1d_update_single_token import ( +from aiter.ops.triton._triton_kernels.conv.causal_conv1d import PAD_SLOT_ID +from aiter.ops.triton._triton_kernels.conv.causal_conv1d_update_single_token import ( _causal_conv1d_update_single_token_kernel, _reshape_causal_conv1d_update_single_token_kernel, ) diff --git a/aiter/ops/triton/quant/fused_fp8_quant.py b/aiter/ops/triton/quant/fused_fp8_quant.py index f0583a86b3..42ace1dc6c 100644 --- a/aiter/ops/triton/quant/fused_fp8_quant.py +++ b/aiter/ops/triton/quant/fused_fp8_quant.py @@ -3,16 +3,15 @@ import torch import triton import aiter +from aiter.ops.triton.utils.types import get_fp8_e4m3_dtype from aiter.ops.triton._triton_kernels.quant.fused_fp8_quant import ( _fused_rms_fp8_per_tensor_static_quant_kernel, _fused_rms_fp8_group_quant_kernel, - _fused_rms_gated_fp8_group_quant_kernel, _fused_flatten_fp8_group_quant_kernel, _fused_reduce_act_mul_fp8_group_quant, _fused_reduce_rms_fp8_group_quant_kernel, _fused_silu_mul_fp8_per_tensor_static_quant_kernel, ) -from 
aiter.ops.triton.utils.types import get_fp8_e4m3_dtype from aiter.ops.triton._triton_kernels.activation import ( _get_activation_from_str, ) @@ -315,6 +314,9 @@ def fused_rms_fp8_group_quant( out_res1_col_stride, out1_row_stride, out1_col_stride, + inp1, + inp1_weight, + inp1.stride(0), BLOCK_SIZE_N=BLOCK_SIZE_N, QUANT_BLOCK_SIZE=group_size, DTYPE_MAX=DTYPE_MAX, @@ -322,6 +324,20 @@ def fused_rms_fp8_group_quant( HAVE_SECOND_INPUT=(inp2 is not None), FIRST_INPUT_RES=(res1 is not None), FIRST_INPUT_OUT=output_unquantized_inp1, + GATED_RMS_FP8=False, + RMS_TILE=512, + ROWS_PER_BLOCK=1, + GROUP_SIZE_GATED=1, + NUM_GROUPS_GATED=1, + BLOCK_G=1, + HAS_BIAS_GATED=False, + HAS_Z_GATED=False, + NORM_BEFORE_GATE=False, + FP8_MIN=-DTYPE_MAX, + FP8_MAX=DTYPE_MAX, + USE_UE8M0=False, + FP8_MIN_SCALING_FACTOR=1.0, + ACTIVATION="silu", num_warps=num_warps, ) # When transpose_scale=True, view the transposed buffer back to original shape @@ -334,7 +350,8 @@ def fused_rms_fp8_group_quant( def get_fp8_min_max_bounds(fp8_dtype: torch.dtype) -> tuple[float, float]: """Match vLLM ``quant_utils.get_fp8_min_max`` for ``fp8_dtype`` (incl. ROCm fnuz ±224).""" - if fp8_dtype == torch.float8_e4m3fnuz: + fnuz = getattr(torch, "float8_e4m3fnuz", None) + if fnuz is not None and fp8_dtype == fnuz: return -224.0, 224.0 finfo = torch.finfo(fp8_dtype) return float(finfo.min), float(finfo.max) @@ -342,20 +359,20 @@ def get_fp8_min_max_bounds(fp8_dtype: torch.dtype) -> tuple[float, float]: @cache def _num_compute_units(device_id: int = 0) -> int: - """Match vLLM ``vllm.utils.platform_utils.num_compute_units`` (``current_platform.num_compute_units``).""" - return torch.cuda.get_device_properties(device_id).multi_processor_count + """Approximate vLLM ``num_compute_units`` for heuristic tuning.""" + return int(torch.cuda.get_device_properties(device_id).multi_processor_count) def calc_rows_per_block(M: int, device: torch.device) -> int: - """Same heuristic as vLLM ``input_quant_fp8.calc_rows_per_block``.""" + """Heuristic from vLLM ``input_quant_fp8.calc_rows_per_block`` (gated RMSNorm+FP8 launch).""" if device.type != "cuda": raise ValueError( - "fused_rms_gated_fp8_group_quant targets AMD ROCm (HIP); expected a CUDA/HIP device." + "calc_rows_per_block targets CUDA/HIP; expected a CUDA/HIP device." ) device_id = ( device.index if device.index is not None else torch.cuda.current_device() ) - sm_count = max(int(_num_compute_units(device_id)), 1) + sm_count = max(_num_compute_units(device_id), 1) rows_per_block = triton.next_power_of_2(triton.cdiv(M, 2 * sm_count)) return min(int(rows_per_block), 4) @@ -376,42 +393,10 @@ def fused_rms_gated_fp8_group_quant( fp8_min_scaling_factor: float | None = None, group_size: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Fused RMSNorm (with optional bias), optional multiplicative gate from ``z``, - and FP8 quantization (same contract as vLLM ``_rmsnorm_quantize_group_native`` for - ``group_size == N``). - - Comparison with ``fused_rms_fp8_group_quant``: - Use ``fused_rms_fp8_group_quant`` when you need optional **two-stream** RMSNorm - (``inp1`` / optional ``inp2`` with separate weights and epsilons), optional - **residual** fused into ``inp1`` (``res1``), FP8 group quantization on the **first** - normalized stream only, the richer return tuple (quantized FP8, block scales, - optional unquantized ``inp1``, second RMS output, residual output), and optional - ``transpose_scale`` layout for scales. 
-
-    Use **this** function for **single** hidden ``x``, one RMS **weight** (and optional
-    **bias**), plus ``z`` for **elementwise multiplicative gating** (SiLU / sigmoid-style
-    activations on ``z``) matching ``x``'s shape; optional ``norm_before_gate`` ordering;
-    vLLM-aligned FP8 bounds / optional UE8M0 / ``group_size`` (``None`` = one scale per
-    row, else per-column-group scales). Returns only ``(x_quant_fp8, scales)``. Suited to
-    gated RMSNorm input quantization (e.g. SwiGLU-style / vLLM
-    ``_rmsnorm_quantize_group_native`` contracts), not the two-stream + residual pattern
-    above.
-
-    ``x`` and ``z`` must be 2D contiguous with identical shape ``(M, N)``.
-    Returns ``(x_quant_fp8, scales)`` where ``scales`` is ``(M,)`` float32 if
-    ``group_size`` is ``None`` (one scale per row), or ``(M, N // group_size)`` float32
-    when ``group_size`` divides ``N`` (one scale per row per column group).
-
-    ``fp8_min`` / ``fp8_max`` / ``fp8_min_scaling_factor`` default from ``out_dtype`` (or
-    ``get_fp8_e4m3_dtype()``) using the same rules as vLLM ``get_fp8_min_max`` and
-    ``1.0 / (_FP8_MAX * 512)``. Pass them explicitly when you want to pin values (e.g. from
-    vLLM's ``get_fp8_min_max()`` at model init).
-
-    Raises:
-        ValueError: if ``group_size`` is not ``None`` and ``group_size > N``,
-            ``group_size <= 0``, or ``N`` is not divisible by ``group_size``.
-    """
+    """Gated RMSNorm + FP8 quant; launches ``_fused_rms_fp8_group_quant_kernel`` with ``GATED_RMS_FP8=True``.
+
+    Uses ``calc_rows_per_block`` and grid ``(cdiv(M, rows_per_block),)`` like the legacy gated-only kernel,
+    independent of the non-gated path (which stays at grid ``(M,)``)."""
     assert x.is_contiguous() and z.is_contiguous()
     assert x.shape == z.shape, "x and z must have the same shape"
     fp8_dtype = out_dtype if out_dtype is not None else get_fp8_e4m3_dtype()
@@ -449,7 +434,6 @@ def fused_rms_gated_fp8_group_quant(
 
     rms_tile = min(512, triton.next_power_of_2(N))
     block_g = triton.next_power_of_2(effective_gs)
-    rows_per_block = calc_rows_per_block(M, x.device)
     num_warps = min(max(block_g // 256, 1), 8)
 
     x_quant = torch.empty(M, N, dtype=fp8_dtype, device=x.device)
@@ -461,34 +445,71 @@ def fused_rms_gated_fp8_group_quant(
         scales = torch.empty(M, num_groups, dtype=torch.float32, device=x.device)
         stride_s_row, stride_s_g = (int(scales.stride(0)), int(scales.stride(1)))
 
+    bias_ptr = bias if bias is not None else weight
+
+    dummy = torch.empty(1, dtype=x.dtype, device=x.device)
+
+    rows_per_block = calc_rows_per_block(M, x.device)
     grid = (triton.cdiv(M, rows_per_block),)
-    _fused_rms_gated_fp8_group_quant_kernel[grid](
+    BLOCK_SIZE_PAD = max(triton.next_power_of_2(N), effective_gs)
+
+    _fused_rms_fp8_group_quant_kernel[grid](
         x,
         weight,
-        bias,
-        z,
+        dummy,
+        dummy,
+        dummy,
         x_quant,
         scales,
+        dummy,
+        dummy,
+        dummy,
+        eps,
+        0.0,
+        M,
+        N,
+        0,
         x.stride(0),
-        z.stride(0),
+        1,
+        x.stride(1),
+        1,
+        1,
+        1,
         x_quant.stride(0),
+        x_quant.stride(1),
         stride_s_row,
         stride_s_g,
-        M,
-        N,
-        eps,
+        1,
+        1,
+        1,
+        1,
+        1,
+        1,
+        z,
+        bias_ptr,
+        z.stride(0),
+        BLOCK_SIZE_N=BLOCK_SIZE_PAD,
+        QUANT_BLOCK_SIZE=effective_gs,
+        DTYPE_MAX=fp8_max,
+        DTYPE_MIN=-fp8_max,
+        HAVE_SECOND_INPUT=False,
+        FIRST_INPUT_RES=False,
+        FIRST_INPUT_OUT=False,
+        GATED_RMS_FP8=True,
         RMS_TILE=rms_tile,
         ROWS_PER_BLOCK=rows_per_block,
-        GROUP_SIZE=effective_gs,
-        NUM_GROUPS=num_groups,
+        GROUP_SIZE_GATED=effective_gs,
+        NUM_GROUPS_GATED=num_groups,
         BLOCK_G=block_g,
+        HAS_BIAS_GATED=(bias is not None),
+        HAS_Z_GATED=True,
         NORM_BEFORE_GATE=norm_before_gate,
         FP8_MIN=fp8_min,
         FP8_MAX=fp8_max,
         USE_UE8M0=use_ue8m0,
         FP8_MIN_SCALING_FACTOR=fp8_min_scaling_factor,
-        num_warps=num_warps,
         ACTIVATION=activation,
+        num_warps=num_warps,
     )
     return x_quant, scales
diff --git a/op_tests/triton_tests/test_causal_conv1d.py b/op_tests/triton_tests/conv/test_causal_conv1d.py
similarity index 98%
rename from op_tests/triton_tests/test_causal_conv1d.py
rename to op_tests/triton_tests/conv/test_causal_conv1d.py
index 4d940520e8..bd07dd6682 100644
--- a/op_tests/triton_tests/test_causal_conv1d.py
+++ b/op_tests/triton_tests/conv/test_causal_conv1d.py
@@ -5,8 +5,11 @@
 from einops import rearrange
 import numpy as np
 
-from aiter.ops.triton.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-from aiter.ops.triton._triton_kernels.causal_conv1d import PAD_SLOT_ID
+from aiter.ops.triton.causal_conv1d import (
+    PAD_SLOT_ID,
+    causal_conv1d_fn,
+    causal_conv1d_update,
+)
 
 
 def seed_everything(seed: int = 0) -> None:
diff --git a/op_tests/triton_tests/test_causal_conv1d_update_single_token.py b/op_tests/triton_tests/conv/test_causal_conv1d_update_single_token.py
similarity index 99%
rename from op_tests/triton_tests/test_causal_conv1d_update_single_token.py
rename to op_tests/triton_tests/conv/test_causal_conv1d_update_single_token.py
index ad59e58427..2f151110a6 100644
--- a/op_tests/triton_tests/test_causal_conv1d_update_single_token.py
+++ b/op_tests/triton_tests/conv/test_causal_conv1d_update_single_token.py
@@ -19,7 +19,7 @@
 import torch
 import triton
 
-from aiter.ops.triton._triton_kernels.causal_conv1d import PAD_SLOT_ID
+from aiter.ops.triton.causal_conv1d import PAD_SLOT_ID
 from aiter.ops.triton.causal_conv1d_update_single_token import (
     causal_conv1d_update_single_token,
     fused_reshape_causal_conv1d_update_single_token,
diff --git a/op_tests/triton_tests/quant/test_fused_rms_gated_fp8_group_quant.py b/op_tests/triton_tests/quant/test_fused_rms_gated_fp8_group_quant.py
index 0653eccc01..b9d053e853 100644
--- a/op_tests/triton_tests/quant/test_fused_rms_gated_fp8_group_quant.py
+++ b/op_tests/triton_tests/quant/test_fused_rms_gated_fp8_group_quant.py
@@ -7,7 +7,7 @@
 import pytest
 import torch
 
-from aiter.ops.triton.quant.fused_fp8_quant import (
+from aiter.ops.triton.quant import (
     fused_rms_gated_fp8_group_quant,
     get_fp8_min_max_bounds,
 )
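
On the `FILE_TIMES` entry updated in `.github/scripts/split_tests.sh`: per-file duration tables like this typically feed a duration-balanced test sharder. The snippet below is a generic longest-processing-time packing sketch in Python, offered only to illustrate the idea; the script's actual bash logic and shard count are not shown in this patch and may differ.

```python
import heapq


def split_by_time(file_times: dict[str, int], num_shards: int) -> list[list[str]]:
    """Greedy LPT packing: give the slowest remaining test file to the
    currently lightest shard. Illustrative only, not split_tests.sh itself."""
    shards = [(0, i, []) for i in range(num_shards)]  # (total_secs, shard_id, files)
    heapq.heapify(shards)
    for name, secs in sorted(file_times.items(), key=lambda kv: -kv[1]):
        total, i, files = heapq.heappop(shards)
        files.append(name)
        heapq.heappush(shards, (total + secs, i, files))
    return [files for _, _, files in sorted(shards, key=lambda s: s[1])]
```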
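On the `aiter/ops/triton/__init__.py` hunk: the new `"causal_conv1d": "conv.causal_conv1d"` entries extend a name-to-submodule table, which is why the relocated tests can keep importing from the flat `aiter.ops.triton.causal_conv1d` path after the files moved into `conv/`. A common shape for that kind of redirect is a PEP 562 module-level `__getattr__`; the sketch below is a hypothetical standalone illustration, not aiter's actual loader.

```python
# Hypothetical PEP 562 lazy-import shim for a package __init__.py
# (aiter's real mechanism may differ).
import importlib
import sys

_module_map = {
    "causal_conv1d": "conv.causal_conv1d",
    "causal_conv1d_update_single_token": "conv.causal_conv1d_update_single_token",
}


def __getattr__(name: str):
    """Resolve e.g. `aiter.ops.triton.causal_conv1d` on first attribute access."""
    if name in _module_map:
        mod = importlib.import_module(f".{_module_map[name]}", __name__)
        # Registering the flat alias in sys.modules lets subsequent imports of
        # `<package>.causal_conv1d` resolve to the relocated conv/ module.
        sys.modules[f"{__name__}.{name}"] = mod
        return mod
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```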
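On the quant changes: the gated path folded into `_fused_rms_fp8_group_quant_kernel` computes a full-row RMS statistic, applies the norm weight (plus optional bias), gates with `z` either before the statistics (`norm_before_gate=False`) or after the affine step (`norm_before_gate=True`), then quantizes per column group with an absmax-derived scale. The eager PyTorch sketch below mirrors that contract for the silu/swish activation; the function name is hypothetical, the UE8M0 power-of-two scale rounding and the `sigmoid` activation are omitted, and the FP8 defaults shown (448, `1/(448*512)`) are the e4m3fn values implied by `get_fp8_min_max_bounds`.

```python
import torch


def ref_gated_rmsnorm_fp8_quant(
    x: torch.Tensor,                 # (M, N) hidden states
    weight: torch.Tensor,            # (N,) RMSNorm weight
    z: torch.Tensor,                 # (M, N) gate input
    bias: torch.Tensor | None = None,
    eps: float = 1e-6,
    norm_before_gate: bool = True,
    group_size: int | None = None,
    fp8_max: float = 448.0,          # torch.finfo(torch.float8_e4m3fn).max
    min_scale: float = 1.0 / (448.0 * 512),
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
):
    """Illustrative eager reference for the kernel's GATED_RMS_FP8 branch."""
    silu = lambda t: t * torch.sigmoid(t)
    x = x.float()
    if not norm_before_gate:
        x = x * silu(z.float())      # gate feeds the RMS statistics too
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    y = x * rstd * weight.float()
    if bias is not None:
        y = y + bias.float()
    if norm_before_gate:
        y = y * silu(z.float())      # gate applied after the affine step
    M, N = y.shape
    gs = N if group_size is None else group_size
    assert N % gs == 0, "group_size must divide N"
    y = y.view(M, N // gs, gs)
    # One absmax-derived scale per (row, group), floored at min_scale.
    scales = (y.abs().amax(dim=-1) / fp8_max).clamp_min(min_scale)
    y_q = (y / scales.unsqueeze(-1)).clamp(-fp8_max, fp8_max)
    # scales is (M, N // gs); squeeze for the one-scale-per-row case.
    return y_q.view(M, N).to(fp8_dtype), scales
```

On the launch side, `calc_rows_per_block` batches rows per program before taking grid `cdiv(M, rows_per_block)`: for example, with M = 4096 rows and an assumed 304 compute units, `cdiv(4096, 2 * 304) = 7`, `next_power_of_2(7) = 8`, capped at 4 rows per program, giving a grid of 1024 programs.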