From 0040d23e6a5d7b1fc3c2cc1a1dc4a159a2e525e1 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Wed, 1 Apr 2026 18:35:50 -0700 Subject: [PATCH 01/17] native kernel --- csrc/flashinfer_rmsnorm_silu_binding.cu | 21 + csrc/rmsnorm_silu.cu | 115 + flashinfer/__init__.py | 1 + flashinfer/aot.py | 20 + flashinfer/jit/rmsnorm_silu.py | 342 +++ flashinfer/norm/__init__.py | 138 ++ .../flashinfer/norm/ln_fwd_silu_kernel.cuh | 431 ++++ include/flashinfer/norm/ln_silu_headers.cuh | 1919 +++++++++++++++++ .../norm/sm100_rms_norm_silu_knobs.h | 216 ++ tests/norm/test_fused_rmsnorm_silu.py | 493 +++++ 10 files changed, 3696 insertions(+) create mode 100644 csrc/flashinfer_rmsnorm_silu_binding.cu create mode 100644 csrc/rmsnorm_silu.cu create mode 100644 flashinfer/jit/rmsnorm_silu.py create mode 100644 include/flashinfer/norm/ln_fwd_silu_kernel.cuh create mode 100644 include/flashinfer/norm/ln_silu_headers.cuh create mode 100644 include/flashinfer/norm/sm100_rms_norm_silu_knobs.h create mode 100644 tests/norm/test_fused_rmsnorm_silu.py diff --git a/csrc/flashinfer_rmsnorm_silu_binding.cu b/csrc/flashinfer_rmsnorm_silu_binding.cu new file mode 100644 index 0000000000..4f3fea6c76 --- /dev/null +++ b/csrc/flashinfer_rmsnorm_silu_binding.cu @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tvm_ffi_utils.h" + +void rmsnorm_silu(TensorView output, TensorView input, TensorView weight, double eps, + TensorView workspace, int64_t sm_count); + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(rmsnorm_silu, rmsnorm_silu); diff --git a/csrc/rmsnorm_silu.cu b/csrc/rmsnorm_silu.cu new file mode 100644 index 0000000000..5629de80aa --- /dev/null +++ b/csrc/rmsnorm_silu.cu @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Include order matters: headers → config (defines Ktraits) → kernel (uses Ktraits) +#include +#include +#include + +#include "rmsnorm_silu_config.inc" +#include "tvm_ffi_utils.h" + +void rmsnorm_silu(TensorView output, TensorView input, TensorView weight, double eps, + TensorView workspace, int64_t sm_count) { + CHECK_LAST_DIM_CONTIGUOUS_INPUT(input); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(output); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight); + CHECK_DEVICE(input, weight); + CHECK_DIM(2, input); + CHECK_DIM(2, output); + CHECK_DIM(1, weight); + + int rows = input.size(0); + int cols = input.size(1); + TVM_FFI_ICHECK_EQ(cols, HIDDEN_SIZE) << "Input cols must match compiled HIDDEN_SIZE"; + TVM_FFI_ICHECK_EQ(output.size(0), rows); + + ffi::CUDADeviceGuard device_guard(input.device().device_id); + const cudaStream_t stream = get_stream(input.device()); + + // Grid dimensions (same logic as Sm100RmsNormSiluEngine::execute) + int ctas_per_col_max = (rows + WARPS_M - 1) / WARPS_M; + int ctas_per_col; + if (KERNEL_CFG == 2) { + ctas_per_col = ctas_per_col_max; + } else { + ctas_per_col = + std::min(static_cast(sm_count) * DESIRED_OCCUPANCY / CTAS_PER_ROW, ctas_per_col_max); + } + ctas_per_col = std::max(ctas_per_col, 1); + + dim3 grid(CTAS_PER_ROW * ctas_per_col); + dim3 block(WARPS_M * WARPS_N * 32); + + // Pack kernel params + PersistentLnFwdParams params{}; + params.rows = rows; + params.cols = cols; + params.ctas_per_col = ctas_per_col; + params.isRMSNorm = true; + params.noScale = false; + params.noBias = true; + params.isBatchFirst = true; + params.batchSize = 1; + params.seqLen = rows; + params.epsilon = static_cast(eps); + params.x = input.data_ptr(); + params.z = output.data_ptr(); + params.gamma = weight.data_ptr(); + + // Workspace layout (128-byte aligned segments) + char* ws_ptr = static_cast(workspace.data_ptr()); + + // [0] rs: rows * sizeof(float) + params.rs = ws_ptr; + int64_t off = static_cast(rows) * sizeof(float); + off = ((off + 127) / 128) * 128; 
+ + // [aligned] fp8_scale: sizeof(float) + if (isFP8Out) { + params.fp8_out = true; + float* default_scale = reinterpret_cast(ws_ptr + off); + // Set scale = 1.0f via cudaMemcpyAsync from host + static const float one = 1.0f; + cudaMemcpyAsync(default_scale, &one, sizeof(float), cudaMemcpyHostToDevice, stream); + params.scale = default_scale; + } + off += sizeof(float); + off = ((off + 127) / 128) * 128; + + // [aligned] scale_row: rows * ceil(C/16) bytes (NVFP4 only) + if (isFP4Out) { + params.scale_row = ws_ptr + off; + off += static_cast(rows) * ((cols + 15) / 16); + off = ((off + 127) / 128) * 128; + } + + // [aligned] cooperative workspace + barriers (multi-CTA only) + if (CTAS_PER_ROW > 1) { + params.workspace = ws_ptr + off; + int64_t coop_ws_size = + static_cast(ctas_per_col) * WARPS_M * CTAS_PER_ROW * sizeof(float) * 2 * 2; + off += coop_ws_size; + off = ((off + 127) / 128) * 128; + + params.barrier = reinterpret_cast(ws_ptr + off); + cudaMemsetAsync(params.barrier, 0, 2 * ctas_per_col * sizeof(int32_t), stream); + } + + reduced_divisor divisor(rows); + + ln_fwd_kernel<<>>(params, divisor); +} diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index 079dcb0c23..8ced5c509a 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -114,6 +114,7 @@ from .norm import gemma_rmsnorm as gemma_rmsnorm from .norm import rmsnorm as rmsnorm from .norm import rmsnorm_quant as rmsnorm_quant +from .norm import fused_rmsnorm_silu as fused_rmsnorm_silu try: from .norm import rmsnorm_fp4quant as rmsnorm_fp4quant diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 9909befbbd..e83997876f 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -87,6 +87,13 @@ ) from .jit.mla import gen_mla_module from .jit.norm import gen_norm_module +from .jit.rmsnorm_silu import ( + gen_rmsnorm_silu_module, + select_knobs, + _estimate_ctas_per_row, + _SUPPORTED_C, + _SUPPORTED_TOKENS, +) from .jit.page import gen_page_module from .jit.quantization import 
gen_quantization_module from .jit.rope import gen_rope_module @@ -558,6 +565,19 @@ def gen_all_modules( gen_sampling_module(), gen_topk_module(), ] + # Fused RMSNorm+SiLU: pre-compile all LUT configs (SM100+ only) + if has_sm100: + for C in _SUPPORTED_C: + for tokens in _SUPPORTED_TOKENS: + for dtype in ["bf16", "fp8", "nvfp4"]: + knobs = select_knobs(C, tokens, dtype) + if knobs is None: + continue + wm, sc, kcfg, occ, bpl = knobs + cpr = _estimate_ctas_per_row(C, sc, kcfg, bpl) + jit_specs.append( + gen_rmsnorm_silu_module(C, dtype, wm, cpr, bpl, kcfg, occ) + ) # selective_state_update: one module per dtype combo per GPU arch _ssu_dtype_combos = [ # (state, input, weight, matrixA, stateIndex, state_scale_dtype) diff --git a/flashinfer/jit/rmsnorm_silu.py b/flashinfer/jit/rmsnorm_silu.py new file mode 100644 index 0000000000..292873b072 --- /dev/null +++ b/flashinfer/jit/rmsnorm_silu.py @@ -0,0 +1,342 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import shutil + +from . 
import env as jit_env +from .core import JitSpec, gen_jit_spec +from .utils import write_if_different + + +# Knob LUT ported from sm100_rms_norm_silu_knobs.h +# Format: (warps_m, split_cols, kernel_cfg, occupancy, bytes_per_ldg) +_KNOB_LUT = { + # C=64 + (64, 1560, "bf16"): (8, 0, 0, 2, 4), + (64, 6240, "bf16"): (32, 4, 0, 2, 4), + (64, 24960, "bf16"): (32, 4, 0, 2, 4), + (64, 99840, "bf16"): (8, 0, 1, 8, 4), + (64, 399360, "bf16"): (4, 0, 1, 16, 4), + (64, 1560, "fp8"): (8, 4, 0, 6, 4), + (64, 6240, "fp8"): (8, 0, 0, 3, 2), + (64, 24960, "fp8"): (8, 0, 0, 7, 4), + (64, 99840, "fp8"): (8, 0, 1, 6, 2), + (64, 399360, "fp8"): (32, 0, 1, 2, 2), + (64, 1560, "nvfp4"): (8, 0, 2, 1, 4), + (64, 6240, "nvfp4"): (8, 4, 0, 4, 4), + (64, 24960, "nvfp4"): (8, 0, 1, 6, 4), + (64, 99840, "nvfp4"): (32, 4, 1, 2, 4), + (64, 399360, "nvfp4"): (32, 4, 1, 2, 4), + # C=128 + (128, 1560, "bf16"): (8, 4, 0, 3, 4), + (128, 6240, "bf16"): (8, 0, 0, 3, 4), + (128, 24960, "bf16"): (8, 0, 0, 6, 4), + (128, 99840, "bf16"): (32, 4, 0, 2, 4), + (128, 399360, "bf16"): (8, 0, 0, 8, 4), + (128, 1560, "fp8"): (8, 0, 0, 3, 4), + (128, 6240, "fp8"): (8, 0, 0, 4, 8), + (128, 24960, "fp8"): (8, 0, 0, 8, 8), + (128, 99840, "fp8"): (32, 0, 0, 2, 8), + (128, 399360, "fp8"): (32, 0, 0, 2, 8), + (128, 1560, "nvfp4"): (8, 4, 0, 3, 8), + (128, 6240, "nvfp4"): (8, 0, 0, 5, 8), + (128, 24960, "nvfp4"): (8, 0, 1, 8, 8), + (128, 99840, "nvfp4"): (32, 0, 1, 2, 8), + (128, 399360, "nvfp4"): (32, 0, 1, 2, 8), + # C=160 + (160, 1560, "bf16"): (8, 0, 0, 4, 2), + (160, 6240, "bf16"): (8, 0, 0, 4, 2), + (160, 24960, "bf16"): (8, 4, 1, 6, 2), + (160, 99840, "bf16"): (32, 4, 1, 2, 2), + (160, 399360, "bf16"): (32, 4, 1, 2, 2), + (160, 1560, "fp8"): (8, 0, 0, 2, 2), + (160, 6240, "fp8"): (8, 0, 0, 4, 2), + (160, 24960, "fp8"): (8, 4, 0, 6, 2), + (160, 99840, "fp8"): (32, 4, 1, 2, 2), + (160, 399360, "fp8"): (32, 4, 1, 2, 2), + (160, 1560, "nvfp4"): (4, 4, 0, 4, 2), + (160, 6240, "nvfp4"): (8, 0, 1, 4, 2), + (160, 24960, 
"nvfp4"): (8, 4, 1, 8, 2), + (160, 99840, "nvfp4"): (32, 4, 0, 1, 2), + (160, 399360, "nvfp4"): (32, 0, 1, 2, 2), + # C=256 + (256, 1560, "bf16"): (8, 0, 0, 6, 16), + (256, 6240, "bf16"): (8, 0, 0, 4, 4), + (256, 24960, "bf16"): (8, 0, 0, 8, 16), + (256, 99840, "bf16"): (4, 4, 0, 16, 16), + (256, 399360, "bf16"): (4, 0, 0, 16, 16), + (256, 1560, "fp8"): (8, 4, 0, 2, 4), + (256, 6240, "fp8"): (8, 0, 0, 4, 4), + (256, 24960, "fp8"): (8, 4, 0, 8, 16), + (256, 99840, "fp8"): (4, 0, 0, 16, 16), + (256, 399360, "fp8"): (32, 0, 0, 2, 16), + (256, 1560, "nvfp4"): (8, 0, 2, 1, 16), + (256, 6240, "nvfp4"): (8, 0, 2, 1, 16), + (256, 24960, "nvfp4"): (8, 4, 1, 6, 16), + (256, 99840, "nvfp4"): (32, 0, 1, 1, 16), + (256, 399360, "nvfp4"): (32, 0, 1, 2, 16), + # C=320 + (320, 1560, "bf16"): (8, 4, 1, 4, 4), + (320, 6240, "bf16"): (8, 4, 0, 5, 4), + (320, 24960, "bf16"): (8, 0, 0, 5, 4), + (320, 99840, "bf16"): (4, 0, 1, 16, 4), + (320, 399360, "bf16"): (32, 4, 0, 2, 4), + (320, 1560, "fp8"): (8, 0, 0, 2, 4), + (320, 6240, "fp8"): (8, 0, 0, 5, 4), + (320, 24960, "fp8"): (8, 0, 0, 5, 4), + (320, 99840, "fp8"): (32, 0, 1, 2, 4), + (320, 399360, "fp8"): (32, 0, 1, 2, 4), + (320, 1560, "nvfp4"): (4, 4, 0, 9, 4), + (320, 6240, "nvfp4"): (4, 0, 0, 9, 4), + (320, 24960, "nvfp4"): (8, 0, 1, 8, 4), + (320, 99840, "nvfp4"): (32, 4, 1, 2, 4), + (320, 399360, "nvfp4"): (32, 4, 1, 2, 4), + # C=512 + (512, 1560, "bf16"): (8, 0, 0, 2, 16), + (512, 6240, "bf16"): (8, 0, 0, 5, 16), + (512, 24960, "bf16"): (4, 0, 0, 8, 16), + (512, 99840, "bf16"): (4, 0, 2, 1, 8), + (512, 399360, "bf16"): (4, 0, 2, 1, 4), + (512, 1560, "fp8"): (8, 0, 0, 2, 8), + (512, 6240, "fp8"): (8, 0, 0, 4, 8), + (512, 24960, "fp8"): (4, 0, 0, 9, 8), + (512, 99840, "fp8"): (32, 4, 1, 2, 8), + (512, 399360, "fp8"): (32, 4, 1, 2, 8), + (512, 1560, "nvfp4"): (4, 4, 0, 3, 16), + (512, 6240, "nvfp4"): (4, 0, 0, 9, 16), + (512, 24960, "nvfp4"): (4, 0, 2, 1, 16), + (512, 99840, "nvfp4"): (32, 4, 0, 1, 16), + (512, 399360, "nvfp4"): 
(32, 0, 0, 1, 16), + # C=640 + (640, 1560, "bf16"): (4, 0, 0, 4, 4), + (640, 6240, "bf16"): (4, 0, 0, 5, 4), + (640, 24960, "bf16"): (4, 0, 0, 5, 4), + (640, 99840, "bf16"): (4, 0, 2, 1, 8), + (640, 399360, "bf16"): (4, 0, 2, 1, 8), + (640, 1560, "fp8"): (4, 0, 0, 3, 8), + (640, 6240, "fp8"): (8, 0, 0, 4, 8), + (640, 24960, "fp8"): (8, 0, 0, 4, 8), + (640, 99840, "fp8"): (4, 4, 0, 9, 8), + (640, 399360, "fp8"): (32, 4, 1, 2, 8), + (640, 1560, "nvfp4"): (4, 4, 0, 5, 8), + (640, 6240, "nvfp4"): (4, 0, 1, 9, 8), + (640, 24960, "nvfp4"): (4, 0, 2, 1, 8), + (640, 99840, "nvfp4"): (32, 0, 1, 1, 8), + (640, 399360, "nvfp4"): (32, 4, 1, 1, 8), + # C=1024 + (1024, 1560, "bf16"): (4, 4, 0, 3, 16), + (1024, 6240, "bf16"): (4, 0, 0, 5, 16), + (1024, 24960, "bf16"): (4, 4, 1, 10, 16), + (1024, 99840, "bf16"): (8, 0, 2, 1, 16), + (1024, 399360, "bf16"): (8, 0, 2, 1, 16), + (1024, 1560, "fp8"): (4, 0, 0, 3, 4), + (1024, 6240, "fp8"): (4, 0, 0, 5, 8), + (1024, 24960, "fp8"): (1, 4, 0, 16, 8), + (1024, 99840, "fp8"): (4, 0, 1, 9, 8), + (1024, 399360, "fp8"): (32, 4, 1, 1, 8), + (1024, 1560, "nvfp4"): (4, 4, 0, 7, 16), + (1024, 6240, "nvfp4"): (4, 0, 2, 1, 16), + (1024, 24960, "nvfp4"): (4, 0, 2, 1, 16), + (1024, 99840, "nvfp4"): (32, 0, 1, 1, 16), + (1024, 399360, "nvfp4"): (32, 4, 1, 1, 16), +} + +_SUPPORTED_C = [64, 128, 160, 256, 320, 512, 640, 1024] +_SUPPORTED_TOKENS = [1560, 6240, 24960, 99840, 399360] + + +def _compute_default_knobs(C: int, dtype: str): + """Conservative fallback knobs for non-LUT sizes.""" + input_size = 2 # bf16 + warps_m = 32 if dtype == "nvfp4" else 1 + warps_n = 1 + cpr = 1 + + for bpl in [4, 8, 16, 2]: + num_elts = bpl // input_size + if num_elts <= 0 or C % num_elts != 0: + continue + vec_cols = C // num_elts + vec_cols_per_ldg = cpr * warps_n * 32 + if vec_cols_per_ldg <= 0 or vec_cols % vec_cols_per_ldg != 0: + continue + ldgs = vec_cols // vec_cols_per_ldg + if ldgs > 1024: + continue + return (warps_m, 0, 0, 1, bpl) + + return None + + +def 
select_knobs(C: int, num_tokens: int, dtype: str): + """Select knobs from LUT or fallback heuristic. Returns (warps_m, split_cols, kernel_cfg, occupancy, bytes_per_ldg).""" + key = (C, num_tokens, dtype) + if key in _KNOB_LUT: + return _KNOB_LUT[key] + return _compute_default_knobs(C, dtype) + + +def _estimate_ctas_per_row( + C: int, split_cols: int, kernel_cfg: int, bytes_per_ldg: int, warps_n: int = 1 +) -> int: + """Estimate CTAS_PER_ROW from knobs (matches cuDNN's estimate_ctas_per_row).""" + if split_cols != 4 or kernel_cfg == 2: + return 1 + input_size = 2 # bf16 + num_elts = bytes_per_ldg // input_size + elts_per_ldg = num_elts * warps_n * 32 + if elts_per_ldg <= 0 or C % elts_per_ldg != 0: + return 1 + ldgs_per_row = C // elts_per_ldg + ldgs_to_cause_register_spill = 64 // num_elts if num_elts > 0 else 1 + ctas_per_row = 1 + for ldgs in range(min(ldgs_per_row, ldgs_to_cause_register_spill - 1), 0, -1): + if ldgs_per_row % ldgs == 0: + ctas_per_row = ldgs_per_row // ldgs + break + return ctas_per_row + + +def _generate_config( + C: int, + output_dtype: str, + warps_m: int, + ctas_per_row: int, + bytes_per_ldg: int, + kernel_cfg: int, + occupancy: int, +) -> str: + """Generate the constexpr config .inc file content.""" + lines = [ + "// Auto-generated RmsNorm+SiLU kernel config. 
Do not edit.", + "", + "using ITYPE = nv_bfloat16;", + ] + + is_fp8 = output_dtype == "fp8" + is_nvfp4 = output_dtype == "nvfp4" + + if output_dtype == "bf16": + lines.append("using OTYPE = nv_bfloat16;") + lines.append("using NORM_OTYPE = nv_bfloat16;") + elif output_dtype == "fp8": + lines.append("using OTYPE = nv_fp8_e4m3;") + lines.append("using NORM_OTYPE = float;") + elif output_dtype == "nvfp4": + lines.append("using OTYPE = nv_fp4_e2m1;") + lines.append("using NORM_OTYPE = float;") + + lines += [ + "using WTYPE = nv_bfloat16;", + "using CTYPE = float;", + "", + f"constexpr int HIDDEN_SIZE = {C};", + "constexpr int BATCH_SIZE = 1;", + f"constexpr int CTAS_PER_ROW = {ctas_per_row};", + f"constexpr int WARPS_M = {warps_m};", + "constexpr int WARPS_N = 1;", + f"constexpr int BYTES_PER_LDG = {bytes_per_ldg};", + f"constexpr int KERNEL_CFG = {kernel_cfg};", + "constexpr bool isRMSNorm = true;", + "constexpr bool isAdaLN = false;", + "constexpr bool isBatchFirst = true;", + "constexpr bool hasGamma = true;", + "constexpr bool hasBeta = false;", + "constexpr bool isZeroCenteredGamma = false;", + "constexpr bool isZeroCenteredGammaCastBeforeAdd = false;", + ] + + use_smem_gamma = kernel_cfg == 1 + use_non_persistent = kernel_cfg == 2 + lines += [ + f"constexpr bool useSmemGamma = {'true' if use_smem_gamma else 'false'};", + f"constexpr bool GAMMA_ON_DEMAND = {'true' if (not use_smem_gamma and use_non_persistent) else 'false'};", + f"constexpr bool isFP8Out = {'true' if is_fp8 else 'false'};", + "constexpr bool hasScaleInv = false;", + "constexpr bool hasAmax = false;", + "#define LN_USE_CLUSTER 0", + "constexpr bool USE_CLUSTER = false;", + f"constexpr bool isBlockScaleOut = {'true' if is_nvfp4 else 'false'};", + f"constexpr bool isFP4Out = {'true' if is_nvfp4 else 'false'};", + f"constexpr bool isBlockScale_1D1X1X = {'true' if is_nvfp4 else 'false'};", + "constexpr bool isBlockScale_1D2X2X = false;", + "constexpr bool isBlockScale_1D2X2X_Transpose = false;", + 
"constexpr bool useBlockScaleColwiseKernel = false;", + f"constexpr int DESIRED_OCCUPANCY = {occupancy};", + "", + "using Ktraits = Kernel_traits;", + "", + "#define USE_STATIC_SMEM_VALUE ((int)sizeof(LnFwdShared))", + ] + + return "\n".join(lines) + "\n" + + +def _get_uri( + C: int, + output_dtype: str, + warps_m: int, + ctas_per_row: int, + bytes_per_ldg: int, + kernel_cfg: int, + occupancy: int, +) -> str: + return ( + f"rmsnorm_silu_C{C}_{output_dtype}" + f"_wm{warps_m}_cpr{ctas_per_row}_bpl{bytes_per_ldg}" + f"_cfg{kernel_cfg}_occ{occupancy}" + ) + + +def gen_rmsnorm_silu_module( + C: int, + output_dtype: str, + warps_m: int, + ctas_per_row: int, + bytes_per_ldg: int, + kernel_cfg: int, + occupancy: int, +) -> JitSpec: + uri = _get_uri( + C, output_dtype, warps_m, ctas_per_row, bytes_per_ldg, kernel_cfg, occupancy + ) + + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / uri + gen_directory.mkdir(parents=True, exist_ok=True) + + config_content = _generate_config( + C, output_dtype, warps_m, ctas_per_row, bytes_per_ldg, kernel_cfg, occupancy + ) + write_if_different(gen_directory / "rmsnorm_silu_config.inc", config_content) + + sources = [] + for fname in ["rmsnorm_silu.cu", "flashinfer_rmsnorm_silu_binding.cu"]: + dst = gen_directory / fname + shutil.copy(jit_env.FLASHINFER_CSRC_DIR / fname, dst) + sources.append(dst) + + return gen_jit_spec( + uri, + sources, + extra_cuda_cflags=[ + "-DENABLE_BF16", + "-DENABLE_FP8", + ], + extra_include_paths=[str(gen_directory)], + ) diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index 89259177a9..01fdd34fc9 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -519,6 +519,143 @@ def _layernorm_fake( pass +# ============================================================ +# Fused RMSNorm + SiLU (ported from cuDNN frontend OSS engine) +# ============================================================ + +from ..jit.rmsnorm_silu import ( + gen_rmsnorm_silu_module, + select_knobs, + 
_estimate_ctas_per_row, +) + + +@functools.cache +def _get_rmsnorm_silu_sm_count(): + """Cache the SM count for the current device.""" + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + return props.multi_processor_count + + +@functools.cache +def _get_rmsnorm_silu_module( + C, output_dtype, warps_m, ctas_per_row, bytes_per_ldg, kernel_cfg, occupancy +): + return gen_rmsnorm_silu_module( + C, output_dtype, warps_m, ctas_per_row, bytes_per_ldg, kernel_cfg, occupancy + ).build_and_load() + + +def _compute_rmsnorm_silu_workspace_size( + rows, cols, output_dtype, warps_m, ctas_per_row, kernel_cfg, occupancy, sm_count +): + """Compute workspace size matching the engine's layout.""" + # rs + ws = rows * 4 # sizeof(float) + ws = ((ws + 127) // 128) * 128 + # fp8_scale + ws += 4 + ws = ((ws + 127) // 128) * 128 + # scale_row (NVFP4 only) + if output_dtype == "nvfp4": + ws += rows * ((cols + 15) // 16) + ws = ((ws + 127) // 128) * 128 + # cooperative workspace (multi-CTA) + if ctas_per_row > 1: + ctas_per_col_max = (rows + warps_m - 1) // warps_m + if kernel_cfg == 2: + ctas_per_col = ctas_per_col_max + else: + ctas_per_col = min(sm_count * occupancy // ctas_per_row, ctas_per_col_max) + ctas_per_col = max(ctas_per_col, 1) + ws += ctas_per_col * warps_m * ctas_per_row * 8 * 2 # sizeof(float2) * 2 + ws = ((ws + 127) // 128) * 128 + ws += 2 * ctas_per_col * 4 # sizeof(int32_t) + ws = ((ws + 127) // 128) * 128 + ws += 128 # final alignment padding + return ws + + +def _torch_dtype_to_str(dtype): + if dtype == torch.bfloat16: + return "bf16" + elif dtype == torch.float8_e4m3fn: + return "fp8" + elif hasattr(torch, "float4_e2m1fn_x2") and dtype == torch.float4_e2m1fn_x2: + return "nvfp4" + return "bf16" + + +@flashinfer_api +def fused_rmsnorm_silu( + input: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + r"""Fused RMSNorm + SiLU activation. 
+ + ``out[i] = SiLU(RMSNorm(input[i], weight, eps))`` + + where ``SiLU(x) = x / (1 + exp(-x))`` + + This kernel is ported from the cuDNN frontend OSS Sm100RmsNormSiluEngine + and is optimized for WAN VAE decoder workloads on B200. + + Parameters + ---------- + input: torch.Tensor + Input tensor, shape ``(num_tokens, hidden_size)``, dtype ``bfloat16``. + weight: torch.Tensor + Scale (gamma) tensor, shape ``(hidden_size,)``, dtype ``bfloat16``. + eps: float + Epsilon for numerical stability. + out: Optional[torch.Tensor] + Output tensor. If None, allocated as same shape/dtype as input. + + Returns + ------- + output: torch.Tensor + Normalized + SiLU activated tensor, shape ``(num_tokens, hidden_size)``. + """ + if out is None: + out = torch.empty_like(input) + + num_tokens = input.size(0) + C = input.size(1) + output_dtype_str = _torch_dtype_to_str(out.dtype) + + knobs = select_knobs(C, num_tokens, output_dtype_str) + if knobs is None: + raise ValueError( + f"Unsupported problem size for fused_rmsnorm_silu: " + f"C={C}, num_tokens={num_tokens}, dtype={output_dtype_str}" + ) + + warps_m, split_cols, kernel_cfg, occupancy, bytes_per_ldg = knobs + ctas_per_row = _estimate_ctas_per_row(C, split_cols, kernel_cfg, bytes_per_ldg) + sm_count = _get_rmsnorm_silu_sm_count() + + module = _get_rmsnorm_silu_module( + C, output_dtype_str, warps_m, ctas_per_row, bytes_per_ldg, kernel_cfg, occupancy + ) + + ws_size = _compute_rmsnorm_silu_workspace_size( + num_tokens, + C, + output_dtype_str, + warps_m, + ctas_per_row, + kernel_cfg, + occupancy, + sm_count, + ) + workspace = torch.empty(ws_size, dtype=torch.uint8, device=input.device) + + module.rmsnorm_silu(out, input, weight, eps, workspace, sm_count) + return out + + # Public API exports __all__ = [ # JIT module generator (always available) @@ -531,4 +668,5 @@ def _layernorm_fake( "gemma_rmsnorm", "gemma_fused_add_rmsnorm", "layernorm", + "fused_rmsnorm_silu", ] diff --git a/include/flashinfer/norm/ln_fwd_silu_kernel.cuh 
b/include/flashinfer/norm/ln_fwd_silu_kernel.cuh new file mode 100644 index 0000000000..2b1030e327 --- /dev/null +++ b/include/flashinfer/norm/ln_fwd_silu_kernel.cuh @@ -0,0 +1,431 @@ +#pragma once +// Extracted from cudnn_frontend ln_fwd_silu_kernel.h for RmsNorm+SiLU kernel. +// Original: cudnn_frontend/include/.../generated/rms_norm_silu/sm100/ln_fwd_silu_kernel.h +// +// IMPORTANT: Include ln_silu_headers.cuh and the config .inc BEFORE this file. +// The config must define Ktraits, DESIRED_OCCUPANCY, and all constexpr flags. + +constexpr int mxfp8_block_size = 32; +constexpr int nvfp4_block_size = 16; + +template +struct LnFwdShared { + using Traits = _Traits; + static constexpr int32_t SMEM_STATS_ELEMENTS = + ((Traits::Stats::SMEM_BYTES > 0) ? Traits::Stats::SMEM_BYTES : 1); + static constexpr int32_t SMEM_BAR_ELEMENTS = + ((Traits::USE_CLUSTER && Traits::CTAS_PER_ROW > 1) + ? (Traits::WARPS_M + 1 + Traits::WARPS_M * Traits::CTAS_PER_ROW) + : 1); + static constexpr int32_t SMEM_MXFP8_ELEMENTS = + ((isBlockScale_1D2X2X && !useBlockScaleColwiseKernel) + ? ((mxfp8_block_size * Traits::NUM_ELTS + (Traits::NUM_ELTS - 1)) * + (mxfp8_block_size + 1)) + : 1); + static constexpr int32_t GAMMA_ELEMENTS = + ((Traits::hasGamma && Traits::USE_GAMMA_SMEM) + ? (Traits::BATCH_SIZE * Traits::LDGS * Traits::THREADS_PER_ROW * Traits::NUM_ELTS) + : 1); + static constexpr int32_t BETA_ELEMENTS = + ((Traits::hasBeta && Traits::USE_GAMMA_SMEM) + ? 
(Traits::BATCH_SIZE * Traits::LDGS * Traits::THREADS_PER_ROW * Traits::NUM_ELTS) + : 1); + + __align__(16) char smem_stats[SMEM_STATS_ELEMENTS]; + __align__(16) uint64_t smem_bar[SMEM_BAR_ELEMENTS]; + __align__(16) typename Traits::weight_t smem_gamma[GAMMA_ELEMENTS]; + __align__(16) typename Traits::weight_t smem_beta[BETA_ELEMENTS]; + __align__(16) float smem_mxfp8[SMEM_MXFP8_ELEMENTS]; +}; + +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA, DESIRED_OCCUPANCY) void ln_fwd_kernel( + PersistentLnFwdParams params, + reduced_divisor + divisor) { // divisor is div_batch if it is batch-first case, else it is div_seqLen + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; + enum { VEC_COLS = Ktraits::VEC_COLS }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::NUM_ELTS }; + enum { THREADS_PER_WARP = Ktraits::THREADS_PER_WARP }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + enum { COLS = Ktraits::COLS }; + enum { COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG * Ktraits::NUM_ELTS }; + enum { COLS_PER_LDG_PER_CTA = COLS_PER_LDG / Ktraits::CTAS_PER_ROW }; + enum { VEC_COLS_PER_LDG_PER_CTA = VEC_COLS_PER_LDG / Ktraits::CTAS_PER_ROW }; + enum { USE_GAMMA_SMEM = Ktraits::USE_GAMMA_SMEM }; + enum { BATCH_SIZE = Ktraits::BATCH_SIZE }; + enum { isAdaLN = Ktraits::isAdaLN }; + enum { isBatchFirst = Ktraits::isBatchFirst }; + + using output_t = typename Ktraits::output_t; + using weight_t = typename Ktraits::weight_t; + using index_t = typename Ktraits::index_t; + using compute_t = typename Ktraits::compute_t; + using norm_output_t = typename Ktraits::norm_output_t; + using Ivec = typename Ktraits::Ivec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using 
NormOvec = typename Ktraits::NormOvec; + + using Stats = typename Ktraits::Stats; + using stats_t = typename Stats::stats_t; + +#ifdef USE_STATIC_SMEM_VALUE + __shared__ __align__(16) char smem_base_[USE_STATIC_SMEM_VALUE]; +#else + extern __shared__ char smem_base_[]; +#endif + + LnFwdShared* shared = reinterpret_cast*>(smem_base_); + + uint64_t* smemBar = shared->smem_bar; +#if LN_USE_CLUSTER + if (CTAS_PER_ROW > 1) { +#if (__CUDA_ARCH__ >= 900) && (CUDART_VERSION >= 12080) + // Init the empty bars for each warp + if (threadIdx.x < WARPS_M) { + cuda::ptx::mbarrier_init(&smemBar[threadIdx.x], CTAS_PER_ROW * WARPS_N * THREADS_PER_WARP); + } + // Init the full bar (shared by the CTA) + if (threadIdx.x == 0) { + cuda::ptx::mbarrier_init(&smemBar[WARPS_M], 1); + cuda::ptx::fence_mbarrier_init(cuda::ptx::sem_release, cuda::ptx::scope_cluster); + } + cuda::ptx::barrier_cluster_arrive(cuda::ptx::sem_relaxed); + cuda::ptx::barrier_cluster_wait(); +#else + static_assert(true, "Cluster enabled on host side but not available on device"); +#endif // (__CUDA_ARCH__ >= 900) && (CUDART_VERSION >= 12080) + } +#endif // LN_USE_CLUSTER + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t r = bidm * ROWS_PER_CTA + warp_m; + + const index_t col_in_tile = warp_n * THREADS_PER_WARP + lane; + const index_t c = bidn * THREADS_PER_ROW + col_in_tile; + + Stats stats(params, bidm, bidn, warp_m, warp_n, tidx, lane, shared->smem_stats, smemBar); + + // Unused when USE_GAMMA_SMEM is true and will be optimized out + [[maybe_unused]] Wvec gamma_regs[BATCH_SIZE][LDGS]; + [[maybe_unused]] Wvec beta_regs[BATCH_SIZE][LDGS]; + weight_t *gamma_smem = nullptr, *beta_smem = nullptr; + if constexpr (USE_GAMMA_SMEM) { + static 
constexpr int32_t SMEM_BYTES_GAMMA = THREADS_PER_ROW * BATCH_SIZE * LDGS * sizeof(Wvec); + if constexpr (Ktraits::hasGamma) { + gamma_smem = shared->smem_gamma; + } + if constexpr (Ktraits::hasBeta) { + beta_smem = shared->smem_beta; + } + } + + // If we are mxfp8 output type, we need shared memory for amax calculations across warps + // 2d1x1x (not yet implemented) requires 1 (1 for a 32x32 block) + // 1d1x1x requires 0 (since it reduces over a row which can be done with warp reduce) + // 1d2x2x requires 32x(32+1)xNUM_ELTS, more details as follows: + constexpr int block_scale_size = isFP4Out ? nvfp4_block_size : mxfp8_block_size; + BlockScaleRowHelper rowwise_scale_helper{}; + BlockScaleColHelper colwise_scale_helper{shared->smem_mxfp8}; + compute_t* mu_ptr = static_cast(params.mu); + compute_t* rs_ptr = static_cast(params.rs); + + constexpr compute_t rn = 1.f / compute_t(Ktraits::COLS); + + index_t idx = c; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("griddepcontrol.wait;\n"); +#endif + + // Load gamma and beta into shared memory or registers +#pragma unroll + for (int b = 0; b < BATCH_SIZE; b++) { +// CL-14115: The unroll factor 128 for LDGS was chosen based on the compilation/perf results for +// APEX LN_fwd engines +#pragma unroll 128 + for (int it = 0; it < LDGS; it++) { + if constexpr (USE_GAMMA_SMEM) { + if (warp_m == 0) { + const index_t cur_gamma_smem_base_idx = (b * LDGS + it) * THREADS_PER_ROW * NUM_ELTS + + warp_n * THREADS_PER_WARP * NUM_ELTS + lane; + if constexpr (Ktraits::hasGamma) { + Wvec cur_gamma_vec; + cur_gamma_vec.load_from(params.gamma, idx); +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + const index_t cur_gamma_smem_idx = cur_gamma_smem_base_idx + jt * THREADS_PER_WARP; + gamma_smem[cur_gamma_smem_idx] = cur_gamma_vec.data.elt[jt]; + } + } + if constexpr (Ktraits::hasBeta) { + Wvec cur_beta_vec; + cur_beta_vec.load_from(params.beta, idx); +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + 
const index_t cur_beta_smem_idx = cur_gamma_smem_base_idx + jt * THREADS_PER_WARP; + beta_smem[cur_beta_smem_idx] = cur_beta_vec.data.elt[jt]; + } + } + } + } else if constexpr (!GAMMA_ON_DEMAND) { + if constexpr (Ktraits::hasGamma) { + gamma_regs[b][it].load_from(params.gamma, idx); + } + if constexpr (Ktraits::hasBeta) { + beta_regs[b][it].load_from(params.beta, idx); + } + } + idx += VEC_COLS_PER_LDG; + } + } + + if constexpr (USE_GAMMA_SMEM) { + __syncthreads(); + } + + // Initialize scale and bias for FP8 output + compute_t scale = 1.f; + if constexpr (isFP8Out) { + scale = __ldg(params.scale); + } + compute_t amax = 0; + + index_t remaining_rows = params.rows - bidm * ROWS_PER_CTA; + int row_increment_step = params.ctas_per_col * ROWS_PER_CTA; + int batch_idx = 0, remainder = 0; + int batch_increment_step = 0, step_remainder = 0; + if constexpr (isAdaLN) { + if constexpr (isBatchFirst) { + divisor.divmod(r, batch_idx, remainder); // row = batch_idx * seqLen + remainder; + // (remainder < seqLen) batch_idx = r/seqLen + divisor.divmod(row_increment_step, batch_increment_step, + step_remainder); // row_increment_step = + // batch_increment_step * seqLen + + // step_remainder (remainder < + // seqLen) + } else { + batch_idx = divisor.mod(r); // batch_idx = row % BATCH_SIZE; + batch_increment_step = divisor.mod( + row_increment_step); // batch_increment_step = row_increment_step % BATCH_SIZE; + } + } + + for (int row = r; row < params.rows; + row += row_increment_step, batch_idx += batch_increment_step, remainder += step_remainder) { + index_t idx = static_cast(row) * VEC_COLS + c; + + // Load x and convert to compute type per row per thread + compute_t xf[LDGS * NUM_ELTS]; +#pragma unroll 128 + for (int it = 0; it < LDGS; it++) { + Ivec x_it{}; + x_it.load_from(params.x, idx); +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t x_ij = compute_t(x_it.data.elt[jt]); + xf[it * NUM_ELTS + jt] = x_ij; + } + idx += VEC_COLS_PER_LDG; + } + + // Compute 
mean and variance per row per thread + // How many rows current CTA will handle for this iteration + int rows_per_cta = remaining_rows >= ROWS_PER_CTA ? ROWS_PER_CTA : remaining_rows; + stats_t s = stats.compute(xf, rn, rows_per_cta); + remaining_rows -= params.ctas_per_col * ROWS_PER_CTA; // for next iteration + compute_t mu = Get<0>::of(s); + compute_t m2 = Get<1>::of(s); + if constexpr (!Ktraits::isRMSNorm) { + if (bidn == 0 && warp_n == 0 && lane == 0) { + mu_ptr[row] = mu; + } + } + compute_t rs = rsqrtf(rn * m2 + params.epsilon); + + if (bidn == 0 && warp_n == 0 && lane == 0) { + rs_ptr[row] = rs; + } + + idx = row * VEC_COLS + c; + + if constexpr (isAdaLN) { + if constexpr (isBatchFirst) { + if (remainder >= params.seqLen) { + batch_idx += 1; + remainder -= params.seqLen; + } + } else { + if (batch_idx >= BATCH_SIZE) { + batch_idx -= BATCH_SIZE; + } + } + } + index_t gamma_idx = c + (batch_idx * LDGS) * VEC_COLS_PER_LDG; +#pragma unroll 128 + for (int it = 0; it < LDGS; it++) { + Cvec z_math; + [[maybe_unused]] Wvec g_wt; + [[maybe_unused]] Wvec b_wt; + if constexpr (GAMMA_ON_DEMAND && Ktraits::hasGamma && !USE_GAMMA_SMEM) { + g_wt.load_from(params.gamma, gamma_idx); + + if constexpr (Ktraits::hasBeta) { + b_wt.load_from(params.beta, gamma_idx); + } + } +#pragma unroll + // Compute output per ldg per row per thread + for (int jt = 0; jt < NUM_ELTS; jt++) { + compute_t y_ij = rs * (xf[it * NUM_ELTS + jt] - mu); + + if constexpr (Ktraits::hasGamma) { + weight_t g_ij_wt{}; + const int32_t cur_gamma_smem_base_idx = + (batch_idx * LDGS + it) * THREADS_PER_ROW * NUM_ELTS + + warp_n * THREADS_PER_WARP * NUM_ELTS + lane; + const int32_t cur_gamma_smem_idx = cur_gamma_smem_base_idx + jt * THREADS_PER_WARP; + if constexpr (USE_GAMMA_SMEM) { + g_ij_wt = gamma_smem[cur_gamma_smem_idx]; + } else if constexpr (GAMMA_ON_DEMAND) { + g_ij_wt = g_wt.data.elt[jt]; + } else { + g_ij_wt = gamma_regs[batch_idx][it].data.elt[jt]; + } + compute_t g_ij = static_cast(g_ij_wt); + if 
constexpr (isZeroCenteredGamma) { + if constexpr (isZeroCenteredGammaCastBeforeAdd) { + g_ij = static_cast(g_ij_wt) + static_cast(1.f); + } else { + g_ij = static_cast(g_ij_wt + static_cast(1.f)); + } + } + if constexpr (Ktraits::hasBeta) { + compute_t b_ij{}; + const int32_t cur_beta_smem_idx = cur_gamma_smem_base_idx + jt * THREADS_PER_WARP; + if constexpr (USE_GAMMA_SMEM) { + b_ij = beta_smem[cur_beta_smem_idx]; + } else if constexpr (GAMMA_ON_DEMAND) { + b_ij = static_cast(b_wt.data.elt[jt]); + } else { + b_ij = beta_regs[batch_idx][it].data.elt[jt]; + } + y_ij = g_ij * y_ij + b_ij; + } else { + y_ij = g_ij * y_ij; + } + } + + // SiLU activation: y = y * sigmoid(y) = y / (1 + exp(-y)) + // Applied after norm + gamma [+ beta], before FP8/block-scale quantization. + y_ij = __fdividef(y_ij, 1.0f + __expf(-y_ij)); + + if constexpr (isFP8Out) { + if (hasAmax) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(y_ij)); + } + y_ij *= scale; + } + z_math.data.elt[jt] = y_ij; + } // NUM_ELTS + + if constexpr (isBlockScaleOut) { + static_assert(!isBlockScaleOut || + (Ktraits::COLS % mxfp8_block_size == 0)); // ensure cols divisable by 32 + + index_t sf_row_idx = idx / mxfp8_block_size; + [[maybe_unused]] NormOvec z_intermediate; + if constexpr (std::is_same::value) { + rowwise_scale_helper.blockQuantizeStore(z_math, params.scale_row, sf_row_idx, params.z, + idx); + } else { + z_math.to(z_intermediate); + rowwise_scale_helper.blockQuantizeStore(z_intermediate, params.scale_row, sf_row_idx, + params.z, idx); + } + if constexpr (isBlockScale_1D2X2X) { + if constexpr (useBlockScaleColwiseKernel) { + // Store the temporary z_math values in workspace and launch a separate kernel to + // compute the colwise scaling results + if constexpr (std::is_same::value) { + z_math.store_to(params.z_math, idx); + } else { + z_intermediate.store_to(params.z_math, idx); + } + } else { + if constexpr (std::is_same::value) { + colwise_scale_helper.initTile(z_math, THREADS_PER_ROW * 
WARPS_M); + // static_assert(!std::is_same::value); + } else { + colwise_scale_helper.initTile(z_intermediate, THREADS_PER_ROW * WARPS_M); + } + index_t sf_col_row_idx = 0; + index_t sf_col_col_idx = 0; + index_t sf_col_row_width = 0; + index_t z_col_idx = 0; + index_t z_row_offset = row - row % mxfp8_block_size; + if constexpr (!isBlockScale_1D2X2X_Transpose) { + sf_col_row_idx = row / mxfp8_block_size; + sf_col_col_idx = it * VEC_COLS_PER_LDG + bidn * VEC_COLS_PER_LDG_PER_CTA + warp; + sf_col_row_width = VEC_COLS; + z_col_idx = (z_row_offset + lane) * VEC_COLS + sf_col_col_idx; + } else { + constexpr index_t group_size = mxfp8_block_size / NUM_ELTS; + sf_col_row_idx = it * COLS_PER_LDG + bidn * COLS_PER_LDG_PER_CTA + warp * NUM_ELTS; + sf_col_col_idx = row / mxfp8_block_size; + sf_col_row_width = params.rows / mxfp8_block_size; + z_col_idx = (sf_col_row_idx + lane / group_size) * params.rows / NUM_ELTS + + z_row_offset / NUM_ELTS + (lane % group_size); + } + colwise_scale_helper.blockQuantizeStore( + params.scale_col, sf_col_row_idx, sf_col_col_idx, sf_col_row_width, params.z_col, + z_col_idx, THREADS_PER_ROW * WARPS_M); + } + } + } else { + Ovec z; + z_math.to(z); + z.store_to(params.z, idx); + } + idx += VEC_COLS_PER_LDG; + gamma_idx += VEC_COLS_PER_LDG; + } // LDGS + } // grid stride loop + + // Write scale_inv before launch_dependents - consumer needs it to dequantize FP8 output + if constexpr (isFP8Out) { + if (hasScaleInv && blockIdx.x == 0 && threadIdx.x == 0) { + *reinterpret_cast(params.scale_inv) = __fdividef(1.f, scale); + } + } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + asm volatile("griddepcontrol.launch_dependents;\n"); +#endif + + // amax can be after launch_dependents - only needed for delayed scaling (next iteration) + if constexpr (isFP8Out) { + if constexpr (hasAmax) { + amax = reduce_max(amax, warp, threadIdx.x); + if (threadIdx.x == 0) { + atomicMaxFloat(reinterpret_cast(params.amax), amax); + } + } + } +} diff --git 
a/include/flashinfer/norm/ln_silu_headers.cuh b/include/flashinfer/norm/ln_silu_headers.cuh
new file mode 100644
index 0000000000..2172776e04
--- /dev/null
+++ b/include/flashinfer/norm/ln_silu_headers.cuh
@@ -0,0 +1,1919 @@
+#pragma once
+// Extracted from cudnn_frontend ln_headers.h for RmsNorm+SiLU kernel.
+// Original: cudnn_frontend/include/.../generated/rms_norm_silu/sm100/ln_headers.h
+
+#pragma once
+// Auto-generated lightweight LN kernel header.
+// Replaces the 151K-line persistent_ln_headers_13.0.h with:
+// - Standard CUDA TK #include directives (resolved via --include-path at NVRTC compile time)
+// - ~2900 lines of cuDNN-authored LN-specific code extracted from the original
+//
+// CGA/cluster support (USE_CLUSTER) is disabled:
+// - #include <cooperative_groups.h> and the cooperative_groups namespace are NOT included
+// - All if (USE_CLUSTER) branches changed to if constexpr (USE_CLUSTER) so dead code is
+//   not compiled
+//
+// Requires CUDA Toolkit headers available at the --include-path location.
+// Generated by step3b_extract_ln_headers.py — do not edit manually.
+
+// ============================================================
+// Standard CUDA Toolkit headers (resolved via --include-path)
+// ============================================================
+
+#ifdef __CUDACC_RTC__
+// Minimal std stubs needed by NVRTC (no standard library available).
+// NOTE(review): the extraction that produced this chunk stripped the <...>
+// template parameter lists and the #include targets; they are restored below
+// from the usage visible in this file — confirm against the original
+// cudnn_frontend ln_headers.h.
+#if __cplusplus >= 201103L
+namespace std {
+// enable_if / enable_if_t: SFINAE helper, mirrors <type_traits>.
+template <bool B, class T = void>
+struct enable_if {};
+template <class T>
+struct enable_if<true, T> {
+  typedef T type;
+};
+#if __cplusplus >= 201402L
+template <bool B, class T = void>
+using enable_if_t = typename enable_if<B, T>::type;
+#endif
+
+struct true_type {
+  enum { value = true };
+  operator bool() const { return true; }
+};
+struct false_type {
+  enum { value = false };
+  operator bool() const { return false; }
+};
+
+template <class T>
+struct is_floating_point : false_type {};
+template <>
+struct is_floating_point<float> : true_type {};
+template <>
+struct is_floating_point<double> : true_type {};
+
+template <class T>
+struct is_integral : false_type {};
+// One specialization per supported integral type. The original stub had
+// exactly 11 specializations; this list is the best reconstruction —
+// TODO(review): confirm against upstream.
+template <>
+struct is_integral<char> : true_type {};
+template <>
+struct is_integral<signed char> : true_type {};
+template <>
+struct is_integral<unsigned char> : true_type {};
+template <>
+struct is_integral<short> : true_type {};
+template <>
+struct is_integral<unsigned short> : true_type {};
+template <>
+struct is_integral<int> : true_type {};
+template <>
+struct is_integral<unsigned int> : true_type {};
+template <>
+struct is_integral<long> : true_type {};
+template <>
+struct is_integral<unsigned long> : true_type {};
+template <>
+struct is_integral<long long> : true_type {};
+template <>
+struct is_integral<unsigned long long> : true_type {};
+
+template <class T, class U>
+struct is_same : false_type {};
+template <class T>
+struct is_same<T, T> : true_type {};
+}  // namespace std
+#endif
+#else
+#include <type_traits>
+#endif
+
+// CUDA numeric type headers (targets restored: the file below uses half2,
+// nv_bfloat162, __nv_fp8_e4m3/e5m2 and __nv_fp4_e2m1, which come from these
+// four headers respectively).
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cuda_fp8.h>
+#include <cuda_fp4.h>
+
+// NOTE: #include <cooperative_groups.h> and <cuda/barrier> are intentionally omitted.
+// They are only needed for CGA/cluster support (USE_CLUSTER=true), which this
+// OSS engine does not use. All USE_CLUSTER code paths are guarded with
+// "if constexpr" so dead branches are not compiled.
+ +// ============================================================ +// Fixed-width integer typedefs for NVRTC (no available) +// ============================================================ +constexpr int THREADS_PER_WARP = 32; + +#ifndef __CUDACC_RTC__ +#include +#else +typedef signed char int8_t; +typedef unsigned char uint8_t; +typedef signed short int16_t; +typedef unsigned short uint16_t; +typedef signed int int32_t; +typedef unsigned int uint32_t; +typedef signed long long int64_t; +typedef unsigned long long uint64_t; +#endif + +// ============================================================ +// cuDNN-authored LN kernel utilities +// Extracted from persistent_ln_headers_13.0.h +// ============================================================ + +// cuDNN type aliases for FP8 types (from after inlined cuda_fp8.h) +typedef __nv_fp8_e4m3 nv_fp8_e4m3; +typedef __nv_fp8x2_e4m3 nv_fp8x2_e4m3; +typedef __nv_fp8_e5m2 nv_fp8_e5m2; +typedef __nv_fp8x2_e5m2 nv_fp8x2_e5m2; + +// cuDNN type aliases for FP4 types (from after inlined cuda_fp4.h) +typedef __nv_fp4_e2m1 nv_fp4_e2m1; +typedef __nv_fp4x2_e2m1 nv_fp4x2_e2m1; +typedef __nv_fp4x4_e2m1 nv_fp4x4_e2m1; + +// Pointer typedefs for NVRTC (from original Jitify stubs) +typedef int64_t intptr_t; +typedef uint64_t uintptr_t; +typedef int64_t intmax_t; +typedef uint64_t uintmax_t; + +// Large vector types for 64-byte and 32-byte loads +struct uint16 { + uint4 u; + uint4 v; + uint4 s; + uint4 t; +}; + +struct uint8 { + uint4 u; + uint4 v; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//// Datatype helpers + +template +struct BytesToType {}; + +template <> +struct BytesToType<64> { + using Type = uint16; +}; + +template <> +struct BytesToType<32> { + using Type = uint8; +}; + +template <> +struct BytesToType<16> { + using Type = uint4; +}; + +template <> +struct BytesToType<8> { + using Type = uint64_t; +}; + +template <> +struct BytesToType<4> { + using Type = 
uint32_t; +}; + +template <> +struct BytesToType<2> { + using Type = uint16_t; +}; + +template <> +struct BytesToType<1> { + using Type = uint8_t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TypeToVec2 {}; + +template <> +struct TypeToVec2 { + using Type = float2; +}; + +template <> +struct TypeToVec2 { + using Type = half2; +}; + +template <> +struct TypeToVec2 { + using Type = nv_bfloat162; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Get { + template + static inline __device__ R of(const T& vec); +}; + +template <> +template +inline __device__ R Get<0>::of(const T& vec) { + return vec.x; +} + +template <> +template +inline __device__ R Get<1>::of(const T& vec) { + return vec.y; +} + +template <> +template +inline __device__ R Get<2>::of(const T& vec) { + return vec.z; +} + +template <> +template +inline __device__ R Get<3>::of(const T& vec) { + return vec.w; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter { + static inline __device__ Dst convert(const Src& from) { return Dst(from); } +}; + +template <> +struct Converter { + static inline __device__ half2 convert(const float2& x) { return __float22half2_rn(x); } +}; + +template <> +struct Converter { + static inline __device__ nv_bfloat162 convert(const float2& x) { +#if __CUDA_ARCH__ >= 800 + return __float22bfloat162_rn(x); +#else + union { + nv_bfloat162 raw; + nv_bfloat16 x; + nv_bfloat16 y; + } tmp; + tmp.x = __float2bfloat16_rn(x.x); + tmp.y = __float2bfloat16_rn(x.y); + return tmp.raw; +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Zeros { + static inline __device__ T get() { return T(0.f); } +}; + +template <> +struct Zeros { + static inline __device__ float2 
get() { return make_float2(0.f, 0.f); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +__device__ __inline__ uint8_t float_to_e8m0(float val) { + // CL-15277: use rounding-up mode for float -> e8m0 conversion +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) && \ + (defined(__CUDA_ARCH_FEAT_SM100_ALL) || defined(__CUDA_ARCH_FEAT_SM101_ALL) || \ + defined(__CUDA_ARCH_FEAT_SM110_ALL) || defined(__CUDA_ARCH_FEAT_SM120_ALL)) + uint16_t out; + asm volatile( + "{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast(&out); +#else +#if CUDART_VERSION >= 12080 + // Use explicit round-up mode to match the SM_100+ hardware instruction + return __nv_cvt_float_to_e8m0(val, __NV_SATFINITE, cudaRoundPosInf); +#else + if (isnan(val) || isinf(val)) { + return 0xFF; + } + uint32_t val_u32 = __float_as_uint(val); + uint8_t exponent = (val_u32 >> 23) & 0xFF; + uint32_t mantissa = val_u32 & 0x7FFFFF; + if ((mantissa > 0) && (exponent != 0xFE)) { // exp can only be < 0xFE here + exponent++; + } + return exponent; +#endif +#endif +} + +__device__ __inline__ float e8m0_to_float(uint8_t val) { +#if CUDART_VERSION >= 12080 + __nv_fp8_e8m0 e8m0; + e8m0.__x = val; + return static_cast(e8m0); +#else + return __uint_as_float(static_cast(val) << 23); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//// Vectorization data structures +template +struct Vec { + using Elt_type = ELT_TYPE; + + enum { BYTES = NUM_ELT * sizeof(Elt_type) }; + using Vec_type = typename BytesToType::Type; + + using Alias_type = union { + Vec_type vec; + Elt_type elt[NUM_ELT]; + }; + + Alias_type data; + + template + inline __device__ void to(Vec& other) { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + other.data.elt[it] = S(this->data.elt[it]); + } + } + + template + inline __device__ void assign(const Op& op) { 
+#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + this->data.elt[it] = op(it); + } + } + + inline __device__ void load_from(const void* base_ptr, const size_t idx) { + this->data.vec = static_cast(base_ptr)[idx]; + } + + inline __device__ void store_to(void* base_ptr, const size_t idx) { + static_cast(base_ptr)[idx] = this->data.vec; + } + + inline __device__ void operator+=(const Vec& rhs) { +#pragma unroll + for (int it = 0; it < NUM_ELT; it++) { + this->data.elt[it] += rhs.data.elt[it]; + } + } +}; + +inline __device__ float2 operator+(const float2& a, const float2& b) { + return {a.x + b.x, a.y + b.y}; +} + +inline __device__ void operator+=(float2& a, const float2& b) { + a.x += b.x; + a.y += b.y; +} + +template +struct Sum { + inline __device__ Sum() {} + inline __device__ T operator()(const T& a, const T& b) { return a + b; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void __nv_ptx_builtin_ocg_fence_view_async_shared(void) { + asm volatile("fence.proxy.async.shared::cta;\n"); +} + +#ifdef __CUDACC_RTC__ +extern "C" { +__device__ uint32_t __nvvm_get_smem_pointer(void* ptr); +__device__ void __nv_ptx_builtin_ocg_write_async_shared_b32(uint32_t dstAddr, uint32_t mbarrierAddr, + uint32_t b0); +__device__ void __nv_ptx_builtin_ocg_write_async_shared_b64(uint32_t dstAddr, uint32_t mbarrierAddr, + uint64_t b0); +__device__ void __nv_ptx_builtin_ocg_write_async_shared_v2_b32(uint32_t dstAddr, + uint32_t mbarrierAddr, uint32_t b0, + uint32_t b1); +__device__ void __nv_ptx_builtin_ocg_write_async_shared_v4_b32(uint32_t dstAddr, + uint32_t mbarrierAddr, uint32_t b0, + uint32_t b1, uint32_t b2, + uint32_t b3); +} +#endif + +namespace utils { + +////////////////////////////////////////////////////////////////////////////////////////////////// +//// Shuffle syncronization helpers +template +inline __device__ T warp_shuffle_xor(const T& x, uint32_t idx) { + return 
__shfl_xor_sync(uint32_t(-1), x, idx); +} + +template <> +inline __device__ float2 warp_shuffle_xor(const float2& x, uint32_t idx) { + return {warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx)}; +} + +template +inline __device__ T warp_shuffle_down(const T& x, uint32_t idx) { + return __shfl_down_sync(uint32_t(-1), x, idx); +} + +template <> +inline __device__ float2 warp_shuffle_down(const float2& x, uint32_t idx) { + return {warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx)}; +} + +// This is a helper function to support __shfl_sync(float2) +template +__device__ inline static T shfl_sync_helper(unsigned mask, T var, int srcLane) { + return __shfl_sync(mask, var, srcLane); +} + +// specialize for float2 +template <> +__device__ inline float2 shfl_sync_helper(unsigned mask, float2 var, int srcLane) { + double ret = __shfl_sync(mask, *reinterpret_cast(&var), srcLane); // cast to double + return *reinterpret_cast(&ret); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// +__device__ __forceinline__ void namedBarrierSync(int name, int numThreads) { + asm volatile("bar.sync %0, %1;" : : "r"(name), "r"(numThreads) : "memory"); +} + +} // namespace utils + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//// Atomics +static inline __device__ void atomicMaxFloat(float* addr, const float value) { + atomicMax(reinterpret_cast(addr), __float_as_int(value)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//// Warp reduction functions +template +static inline __device__ float warp_reduce_max(const float m, unsigned int mask = 0xFFFFFFFF) { + static constexpr int WARP_SIZE = 32; + static_assert(GROUP_SIZE > 0 && GROUP_SIZE <= WARP_SIZE, + "group size must be less than or equal to warp size"); + + float max_val = m; +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) && 
defined(__CUDA_ARCH_FEAT_SM100_ALL) + // Only enable the credux instruction for reduction on the full warp, as we saw perf regression + // for GROUP_SIZE < 32, where we need to iterate over #(WARP_SIZE/GROUP_SIZE) groups in the warp + if constexpr (GROUP_SIZE == WARP_SIZE) { + // For sm_100a arch, we can use redux.sync.op{.abs.}{.NaN}.f32 instruction to reduce the min/max + // value + // https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync + asm volatile("redux.sync.max.f32 %0, %1, %2;" : "=f"(max_val) : "f"(m), "r"(mask)); + return max_val; + } +#endif +#pragma unroll + for (int delta = GROUP_SIZE / 2; delta > 0; delta /= 2) { + const float other_m = __shfl_down_sync(mask, max_val, delta); + __builtin_assume(max_val >= 0); + __builtin_assume(other_m >= 0); + max_val = fmaxf(max_val, other_m); + } + return max_val; +} + +template +static inline __device__ compute_t reduce_max(const compute_t m, const int warpid) { + __shared__ float staging[num_warps]; + constexpr int warp_size = 32; + const float my_max = m; + const float my_warp_max = warp_reduce_max(my_max); + if (threadIdx.x % 32 == 0) { + staging[warpid] = my_warp_max; + } + __syncthreads(); + compute_t result = 0; + if (warpid == 0) { + const float my_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0; + result = warp_reduce_max(my_max); + } + return result; +} + +template +static inline __device__ compute_t reduce_max(const compute_t m, const int warpid, int tidx) { + __shared__ float staging[num_warps]; + constexpr int warp_size = 32; + const float my_max = m; + const float my_warp_max = warp_reduce_max(my_max); + if (tidx % 32 == 0) { + staging[warpid] = my_warp_max; + } + utils::namedBarrierSync(1, num_warps * warp_size); + compute_t result = 0; + if (warpid == 0) { + const float my_max = tidx < num_warps ? 
staging[tidx] : 0; + result = warp_reduce_max(my_max); + } + return result; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//// Misc +constexpr __device__ int32_t find_largest_divisor(int32_t n, int32_t divisor_upper_bound) { + if (n <= divisor_upper_bound) { + return n; + } + for (int32_t i = divisor_upper_bound; i >= 1; --i) { + if (n % i == 0) { + return i; // found the largest divisor + } + } + return divisor_upper_bound; // fallback +} + +constexpr __device__ bool is_power_of_2(unsigned int n) { return (n > 0) && ((n & (n - 1)) == 0); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//// Cooperative groups related +//// !!!! Keep this block at the end of the file - start !!!! + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//// Runtime parameters + +struct PersistentLnParamsBase { + // For Multi-CTA, number of different CTA groups; otherwise same as gridDim.x. + int ctas_per_col = 0; + + // Input is interpreted as matrix and we normalize across columns. + int rows = 0; + int cols = 0; + int batchSize = 1; // For AdaLN + int seqLen = 1; // For AdaLN + + // Common data pointers for forward and backward passes. + void* x = nullptr; + void* mu = nullptr; + void* rs = nullptr; + void* gamma = nullptr; + + // Multi-CTA workspace in gmem. + void* workspace = nullptr; + + // Multi-CTA sync barriers in gmem. + int* barrier = nullptr; + + bool isRMSNorm = false; + bool noScale = false; + bool noBias = false; + bool isAdaLN = false; + bool isBatchFirst = true; +}; + +struct PersistentLnFwdParams : public PersistentLnParamsBase { + // Output of LN FWD. 
+ void* z = nullptr; + + void* beta = nullptr; + float epsilon = 0.f; + + // FP8 support + bool fp8_out = false; + float* scale = nullptr; + float* scale_inv = nullptr; + float* amax = nullptr; + + // Reuse z for row scaled output (or block scaled output if we use 2d1x1x) + void* scale_row = nullptr; + void* scale_col = nullptr; // only used in 1d2x2x + void* z_col = nullptr; // only used in 1d2x2x + void* z_math = nullptr; // only used when enabling the 1d2x2x colwise kernel +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//// Inter-CTA synchronization +template +struct InterCTASync { + template + inline __device__ InterCTASync(Params& params, uint32_t bidm, uint32_t bidn, uint32_t tidx) + : phase_counter_(0), + b0_(params.barrier + bidm) // The barrier for this group of CTAs. + , + b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs. + , + tidx_(tidx) { + // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! + } + + inline __device__ void spin_wait_(int* barrier, int step, int expected) { + asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); + for (int found = -1; found != expected;) { + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); + } + } + + inline __device__ void sync(uint32_t threads_per_cta) { + // ALL THREADS MUST ENTER! + + // We switch barrier every iteration. + int* barrier = phase_counter_ & 0x1 ? b1_ : b0_; + // We decrement every other iteration. + bool dec = phase_counter_ & 0x2; + int step = dec ? -1 : 1; + int expected = dec ? 0 : CTAS_PER_ROW; + // There are only 4 phases: up/down for b0/b1. 
+ phase_counter_ = (phase_counter_ + 1) & 0x3; + + if (0 == tidx_) { + spin_wait_(barrier, step, expected); + } + + // CTA waits for thread 0 + utils::namedBarrierSync(0, threads_per_cta); + } + + int phase_counter_; + int* b0_; + int* b1_; + uint32_t tidx_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//// LayerNorm Reduction and stats computation +// NOTE: WHOLE_CTA means if we'll use one empty mbarrier for inter-CTA sync using cluster. +// In the TMA LN engines it should be "true" since we're syncing the whole CTA with one empty +// mbarrier, while in the APEX LN engines it should be "false" as there's one empty mbarrier per +// warp m (and therefore WARPS_M empty mbarriers in total) +template +struct Reducer : public Reducer { + using InterCTASync = InterCTASync; + using Base = Reducer; + using Type = typename Base::Type; + + enum { SMEM_BYTES = Base::SMEM_BYTES }; + + enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; + enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; + + // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) + enum { + WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES + }; + + template + inline __device__ Reducer(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t tidx, uint32_t lane, void* smem, + uint64_t* smem_bar) + : Base(params, bidm, bidn, warp_m, warp_n, tidx, lane, smem, smem_bar), + inter_cta_(params, bidm, bidn, tidx), + bidn_(bidn), // CTA id within the group. + warp_m_(warp_m), + w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), + w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW), + smem_bar_(smem_bar), + parity_(0) { + // the first several elements are the barriers + smem_cga_ = reinterpret_cast(&smem_bar[WHOLE_CTA ? 
2 : (WARPS_M + 1)]); + } + + template + inline __device__ T allreduce(T data, Op& op, int32_t warps_m) { + data = Base::reduce(data, op, warps_m); + + if constexpr (USE_CLUSTER) { +#if (__CUDA_ARCH__ >= 900) && (CUDART_VERSION >= 12080) + auto cluster = cooperative_groups::this_cluster(); + + // NOTE: right now only lane 0 has the valid value so that a intra-warp broadcast is required + // here. This issue is not in fwd (struct Stats) which does it earlier inside + // warp_chan_upd_dynamic(). + data = utils::shfl_sync_helper(uint32_t(-1), data, 0); + + // Broadcast local results to other CTAs inside the CGA + // Size of smem_cga_: [WARPS_M][CTAS_PER_ROW] + // Thread 0 sends to block 0, Thread 1 sends to block 1, etc... + if ((this->warp_n_ == 0) && (this->lane_ < CTAS_PER_ROW)) { + st_async_remote(&smem_cga_[warp_m_ * CTAS_PER_ROW + bidn_], data, this->lane_, + &smem_bar_[WHOLE_CTA ? 0 : WARPS_M]); + } + + // Leader thread arrives on local barrier to indicate expected tx count + // We can't use (threadIdx.x == 0) directly here because it won't work for TMA + if ((this->warp_n_ == 0) && (warp_m_ == 0) && (this->lane_ == 0)) { + // It's possible that not all warp_m are active here, therefore, we can't use WARPS_M + // directly + uint32_t expected_tx_count = sizeof(data) * CTAS_PER_ROW * warps_m; + utils::mbarrier_arrive_expect_tx_relaxed_cluster(&smem_bar_[WHOLE_CTA ? 0 : WARPS_M], + expected_tx_count); + } + + // Wait on local barrier + while (!utils::mbarrier_try_wait_parity_relaxed_cluster(&smem_bar_[WHOLE_CTA ? 0 : WARPS_M], + parity_)) { + } + utils::fence_acquire_smem(); // Ensure we can read local smem values that are released by + // barrier flip + + T total = Zeros::get(); + if (this->lane_ < CTAS_PER_ROW) { + total = smem_cga_[warp_m_ * CTAS_PER_ROW + this->lane_]; + } + utils::fence_release_smem(); // Ensure read of buffer remains ordered before + // smem_bar.arrive.relaxed + + // Signal barrier is empty. 
+ // Each thread must arrive on all the barriers (like an arrive broadcast). + // NOTE: here we only let each warp m arrive on the remote barriers of the same warp m (i.e. + // the same row), because if we have one empty barrier for the entire CTA, there's a + // difficulty how to init the count of the barrier (since some warps could be disabled here). + // The downside is we'll need a __syncthreads() here, which doesn't have obvious perf impact + // though. + // TODO: we could probably have one full bar per row (same as the empty bar) to get rid of the + // CTA sync, while the downside is more smem (mbarriers) will be needed. + if (!WHOLE_CTA) { + utils::namedBarrierSync(0, warps_m * WARPS_N * 32); + } + for (int other_block_rank = 0; other_block_rank < CTAS_PER_ROW; ++other_block_rank) { + uint64_t* remote_bar_empty = + cluster.map_shared_rank(&smem_bar_[WHOLE_CTA ? 1 : warp_m_], other_block_rank); + utils::mbarrier_arrive_relaxed_cluster(remote_bar_empty); + } + + total = Reducer::allreduce_(total, op); + + // Wait for remote buffer to be writable again by waiting on local barrier: + while (!utils::mbarrier_try_wait_parity_relaxed_cluster(&smem_bar_[WHOLE_CTA ? 1 : warp_m_], + parity_)) { + } + utils::fence_acquire_smem(); // Order wrt st_async_remote + + // Flip parity bit for next iteration + parity_ ^= 1; + + return total; +#else + static_assert(true, "Cluster enabled on host side but not available on device"); +#endif + } else { + // We switch workspace every iteration. + T* workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + // Warp leaders 0 hold the CTA-local results. 
+ if (this->warp_n_ == 0 && this->lane_ == 0) { + workspace[bidn_] = data; + } + + // We need to check for the case where we are using CTA > 1 and WARP_M > 1 and N == 1 + // Because in this case we cannot rely on the __syncThreads in template::compute (because we will use the specialized template ::compute which does not have a sync) + if (CTAS_PER_ROW > 1 && WARPS_M > 1 && WARPS_N == 1) { + utils::namedBarrierSync(0, warps_m * WARPS_N * 32); + } + + // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. + inter_cta_.sync(warps_m * WARPS_N * 32); + + T total = Zeros::get(); + + // Every warp does the final reduction locally. + if constexpr (CTAS_PER_ROW <= THREADS_PER_WARP) { + if (this->lane_ < CTAS_PER_ROW) { + total = workspace[this->lane_]; + } + } else { + static constexpr int32_t LOOP_PER_THREAD = + (CTAS_PER_ROW + THREADS_PER_WARP - 1) / THREADS_PER_WARP; + + // Collect stats for the current lane from ctas (#bidn % THREADS_PER_WARP == #lane) and then + // do the intra-warp reduction + total = workspace[this->lane_]; +#pragma unroll + for (int32_t loop_in_thread = 1; loop_in_thread < LOOP_PER_THREAD; ++loop_in_thread) { + const int32_t lane_in_loop = this->lane_ + loop_in_thread * THREADS_PER_WARP; + if (lane_in_loop < CTAS_PER_ROW) { + total = op(total, workspace[lane_in_loop]); + } + } + } + + total = Reducer::allreduce_(total, op); + + return total; + } + } + + InterCTASync inter_cta_; + T* w0_; + T* w1_; + int bidn_; + int warp_m_; + + private: + T* smem_cga_; + uint64_t* smem_bar_; + int parity_; +}; + +template +struct Reducer { + using Type = T; + enum { SMEM_BYTES = 0 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t /*tidx*/, uint32_t lane, void* smem, + uint64_t* smem_bar) + : warp_n_(warp_n), lane_(lane) {} + + template + static inline __device__ T allreduce_(T data, 
Op& op) { +#pragma unroll + for (int it = 1; it < THREADS_PER_WARP; it *= 2) { + data = op(data, utils::warp_shuffle_xor(data, it)); + } + return data; + } + + template + inline __device__ T allreduce(T data, Op& op, int32_t warps_m) { + return allreduce_(data, op); + } + + protected: + template + inline __device__ T reduce(T data, Op& op, int32_t warps_m) { +// only lane 0 holds the result! +#pragma unroll + for (int it = THREADS_PER_WARP / 2; it > 0; it /= 2) { + data = op(data, utils::warp_shuffle_down(data, it)); + } + return data; + } + + public: + int warp_n_; + int lane_; +}; + +template +struct Reducer + : public Reducer { + using Base = Reducer; + using Type = T; + + enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t tidx, uint32_t lane, void* smem, + uint64_t* smem_bar) + : Base(params, bidm, bidn, warp_m, warp_n, tidx, lane, smem, smem_bar), use0_(true) { + smem0_ = &static_cast(smem)[warp_m * WARPS_N]; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ T allreduce(T data, Op& op, int32_t warps_m) { + T* smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + data = Base::reduce(data, op, warps_m); + if (this->lane_ == 0) { + smem[this->warp_n_] = data; + } + // The number matters here. 0 will only use one barrier per block. "1" will use 2 according to + // ncu. We should use 0 unless it is absolutely necessary to use seperate barriers. + utils::namedBarrierSync(0, warps_m * WARPS_N * 32); + T out = Zeros::get(); +#pragma unroll + for (int it = 0; it < WARPS_N; it++) { + out = op(out, smem[it]); + } + return out; + } + + protected: + template + inline __device__ T reduce(T data, Op& op, int32_t warps_m) { + T* smem = use0_ ? 
smem0_ : smem1_;
+    use0_ = !use0_;
+    // only intra-CTA group leader holds the result!
+    data = Base::reduce(data, op, warps_m);
+    // Warp leaders stage their partial result into the current ping-pong buffer.
+    if (this->lane_ == 0) {
+      smem[this->warp_n_] = data;
+    }
+    utils::namedBarrierSync(0, warps_m * WARPS_N * 32);
+    T out = Zeros::get();
+    // Only the (warp_n_ == 0, lane_ == 0) leader folds the per-warp partials together;
+    // every other thread returns the neutral element.
+    if (this->warp_n_ == 0 && this->lane_ == 0) {
+#pragma unroll
+      for (int it = 0; it < WARPS_N; it++) {
+        out = op(out, smem[it]);
+      }
+    }
+    return out;
+  }
+
+ private:
+  T* smem0_;  // ping-pong staging buffers for per-warp partials
+  T* smem1_;
+  bool use0_;  // selects which buffer the next reduction uses
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// Merge two partial normalization statistics (mean m, sum of squared deviations m2, count n)
+// into the "a" operands, using the pairwise/parallel variance combination formula (note the
+// delta^2 * n_a * n_b / n_ab correction term). For RMSNorm only m2 accumulates and the mean
+// stays zero. With handleNan set, an empty combined count yields a zero reciprocal instead of
+// a division by zero.
+// NOTE(review): the intermediates are declared float even though the inputs are T — presumably
+// T is always float here; confirm before instantiating with another type.
+template
+inline __device__ void update_norm_stats(T& m_a, T& m2_a, T& n_a, T m_b, T m2_b, T n_b) {
+  const T n_ab = n_a + n_b;
+
+  if constexpr (!isRMSNorm) {
+    const T rn_ab = (handleNan && (n_ab == Zeros::get())) ? Zeros::get() : (1.f / n_ab);
+    const T delta = m_a - m_b;
+    const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab;
+    const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab;
+    n_a = n_ab;
+    m_a = m_ab;
+    m2_a = m2_ab;
+  } else {
+    const float m2_ab = m2_a + m2_b;
+    n_a = n_ab;
+    m_a = Zeros::get();
+    m2_a = m2_ab;
+  }
+}
+
+// Tree-reduce the (m, m2, n) stats held by the first num_active lanes of a warp via
+// shuffle-down exchanges, then broadcast the merged result from lane 0 to the whole warp.
+// NOTE(review): assumes num_active >= 2 — for num_active == 1 the first shift amount
+// (highest_bit_set - 1) would be negative; all call sites visible here pass >= 2. TODO confirm.
+template
+inline __device__ void warp_chan_upd_dynamic(T& m_a, T& m2_a, T& n_a, int num_active) {
+  // Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN
+  // otherwise)
+  int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1);
+
+#pragma unroll
+  for (int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2) {
+    // Exchange
+    T n_b = utils::warp_shuffle_down(n_a, step);
+    T m_b = (!isRMSNorm ? utils::warp_shuffle_down(m_a, step) : Zeros::get());
+    T m2_b = utils::warp_shuffle_down(m2_a, step);
+    update_norm_stats(m_a, m2_a, n_a, m_b, m2_b, n_b);
+  }
+  // Intra-warp broadcast (only lane 0 has valid stats).
+ m_a = __shfl_sync(uint32_t(-1), m_a, 0); + m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// NOTE: WHOLE_CTA means if we'll use one empty mbarrier for inter-CTA sync using cluster. +// In the TMA LN engines it should be "true" since we're syncing the whole CTA with one empty +// mbarrier, while in the APEX LN engines it should be "false" as there's one empty mbarrier per +// warp m (and therefore WARPS_M empty mbarriers in total) +template +struct Stats { + // This could be done generically with the Reducer. But then we would have to exchange 3 instead + // of 2 fields. + + using InterCTASync = InterCTASync; + using BlockStats = Stats; + using stats_t = typename BlockStats::stats_t; + + enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; + + template + inline __device__ Stats(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t tidx, uint32_t lane, void* smem, + uint64_t* smem_bar) + : inter_cta_(params, bidm, bidn, tidx), + block_stats_(params, bidm, bidn, warp_m, warp_n, tidx, lane, smem, smem_bar), + bidn_(bidn), // CTA id within the group. + w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW), + w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW), + warp_n_(warp_n), + warp_m_(warp_m), + lane_(lane), + smem_bar_(smem_bar), + parity_(0) { + // the first several elements are the barriers + smem_cga_ = reinterpret_cast(&smem_bar[WHOLE_CTA ? 2 : (WARPS_M + 1)]); + } + + template + inline __device__ stats_t compute(const T (&elts)[LDGS * NUM_ELTS], const T rn, int32_t warps_m) { + static constexpr uint32_t N = LDGS * NUM_ELTS; + constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; + // TODO rn is not really needed here.. 
+ constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); + stats_t block_stats = + block_stats_.template compute(elts, block_rn, warps_m); + + if constexpr (USE_CLUSTER) { +#if (__CUDA_ARCH__ >= 900) && (CUDART_VERSION >= 12080) + auto cluster = cooperative_groups::this_cluster(); + + // Broadcast local results to other CTAs inside the CGA + // Size of smem_cga_: [WARPS_M][CTAS_PER_ROW] + // Thread 0 sends to block 0, Thread 1 sends to block 1, etc... + if ((warp_n_ == 0) && (lane_ < CTAS_PER_ROW)) { + st_async_remote(&smem_cga_[warp_m_ * CTAS_PER_ROW + bidn_], block_stats, lane_, + &smem_bar_[WHOLE_CTA ? 0 : WARPS_M]); + } + + // Leader thread arrives on local barrier to indicate expected tx count + // We can't use (threadIdx.x == 0) directly here because it won't work for TMA + if ((warp_n_ == 0) && (warp_m_ == 0) && (lane_ == 0)) { + // It's possible that not all warp_m are active here, therefore, we can't use WARPS_M + // directly + uint32_t expected_tx_count = sizeof(block_stats) * CTAS_PER_ROW * warps_m; + utils::mbarrier_arrive_expect_tx_relaxed_cluster(&smem_bar_[WHOLE_CTA ? 0 : WARPS_M], + expected_tx_count); + } + + // Wait on local barrier + while (!utils::mbarrier_try_wait_parity_relaxed_cluster(&smem_bar_[WHOLE_CTA ? 0 : WARPS_M], + parity_)) { + } + utils::fence_acquire_smem(); // Ensure we can read local smem values that are released by + // barrier flip + + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); + if (lane_ < CTAS_PER_ROW) { + stats_t result = smem_cga_[warp_m_ * CTAS_PER_ROW + lane_]; + n = ELTS_PER_ROW_PER_CTA; + m = Get<0>::of(result); // should be 0 for RMSNorm + m2 = Get<1>::of(result); + } + utils::fence_release_smem(); // Ensure read of buffer remains ordered before + // smem_bar.arrive.relaxed + + // Signal barrier is empty. + // Each thread must arrive on all the barriers (like an arrive broadcast). + // NOTE: here we only let each warp m arrive on the remote barriers of the same warp m (i.e. 
+ // the same row), because if we have one empty barrier for the entire CTA, there's a + // difficulty how to init the count of the barrier (since some warps could be disabled here). + // The downside is we'll need a __syncthreads() here, which doesn't have obvious perf impact + // though. + // TODO: we could probably have one full bar per row (same as the empty bar) to get rid of the + // CTA sync, while the downside is more smem (mbarriers) will be needed. + if (!WHOLE_CTA) { + utils::namedBarrierSync(0, warps_m * WARPS_N * 32); + } + for (int other_block_rank = 0; other_block_rank < CTAS_PER_ROW; ++other_block_rank) { + uint64_t* remote_bar_empty = + cluster.map_shared_rank(&smem_bar_[WHOLE_CTA ? 1 : warp_m_], other_block_rank); + utils::mbarrier_arrive_relaxed_cluster(remote_bar_empty); + } + + warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); + + // Wait for remote buffer to be writable again by waiting on local barrier: + while (!utils::mbarrier_try_wait_parity_relaxed_cluster(&smem_bar_[WHOLE_CTA ? 1 : warp_m_], + parity_)) { + } + utils::fence_acquire_smem(); // Order wrt st_async_remote + + // Flip parity bit for next iteration + parity_ ^= 1; + + return {m, m2}; +#else + static_assert(true, "Cluster enabled on host side but not available on device"); +#endif + } else { + // We switch workspace every iteration. + stats_t* workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + // Warp leaders 0 hold the CTA-local results. + if (warp_n_ == 0 && lane_ == 0) { + workspace[bidn_] = block_stats; + } + + // We need to check for the case where we are using CTA > 1 and WARP_M > 1 and N == 1 + // Because in this case we cannot rely on the __syncThreads in template::compute (because we will use the specialized template ::compute which does not have a sync) Bug 5221388: For WARPS_M > 1 and + // WARPS_N > 1, we also need the sync here, otherwise the stats received in `workspace[lane_]` + // may be NaN (as if some stats from the other CTAs are not initialized). 
Reason unknown yet + // and tracked in CL-16775. + if (CTAS_PER_ROW > 1 && WARPS_M > 1) { + utils::namedBarrierSync(0, warps_m * WARPS_N * 32); + } + + // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. + inter_cta_.sync(warps_m * WARPS_N * 32); + + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Every warp does the final reduction locally. + if constexpr (CTAS_PER_ROW <= THREADS_PER_WARP) { + if (lane_ < CTAS_PER_ROW) { + stats_t result = workspace[lane_]; + n = ELTS_PER_ROW_PER_CTA; + m = Get<0>::of(result); // should be 0 for RMSNorm + m2 = Get<1>::of(result); + } + warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); + } else { + static constexpr int32_t LOOP_PER_THREAD = + (CTAS_PER_ROW + THREADS_PER_WARP - 1) / THREADS_PER_WARP; + + stats_t curr = workspace[lane_]; + n = ELTS_PER_ROW_PER_CTA; + m = Get<0>::of(curr); // should be 0 for RMSNorm + m2 = Get<1>::of(curr); + + // Collect stats for the current lane from ctas (#bidn % THREADS_PER_WARP == #lane) and then + // do the intra-warp reduction +#pragma unroll + for (int32_t loop_in_thread = 1; loop_in_thread < LOOP_PER_THREAD; ++loop_in_thread) { + const int32_t lane_in_loop = lane_ + loop_in_thread * THREADS_PER_WARP; + if (lane_in_loop < CTAS_PER_ROW) { + curr = workspace[lane_in_loop]; + T m_b = Get<0>::of(curr); + T m2_b = Get<1>::of(curr); + update_norm_stats(m, m2, n, m_b, m2_b, ELTS_PER_ROW_PER_CTA); + } + } + warp_chan_upd_dynamic(m, m2, n, THREADS_PER_WARP); + } + + return {m, m2}; + } + } + + InterCTASync inter_cta_; + BlockStats block_stats_; + + stats_t* w0_; + stats_t* w1_; + int bidn_; + int warp_n_; + int warp_m_; + int lane_; + + private: + stats_t* smem_cga_; + uint64_t* smem_bar_; + int parity_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + using WarpStats = Stats; + using stats_t = typename WarpStats::stats_t; + + enum { SMEM_BYTES = WARPS_M * WARPS_N * 
sizeof(stats_t) * 2 };
+
+  // Carves the staging area out of smem: each warp_m row owns WARPS_N slots, duplicated once
+  // for ping-pong double buffering.
+  template
+  inline __device__ Stats(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m,
+                          uint32_t warp_n, uint32_t tidx, uint32_t lane, void* smem,
+                          uint64_t* smem_bar)
+      : warp_stats_(params, bidm, bidn, warp_m, warp_n, tidx, lane, smem, smem_bar), use0_(true) {
+    smem0_ = static_cast(smem) + warp_m * WARPS_N;
+    smem1_ = smem0_ + WARPS_M * WARPS_N;
+  }
+
+  // Combine the warp-local stats of all WARPS_N warps in this row: each warp reduces locally,
+  // warp leaders stage their result in smem, and after the barrier the first WARPS_N lanes of
+  // every warp merge the staged partials (warp_chan_upd_dynamic broadcasts the result warp-wide).
+  template
+  inline __device__ stats_t compute(const T (&elts)[LDGS * NUM_ELTS], const T rn, int32_t warps_m) {
+    // Ping-pong so the next compute() call cannot overwrite slots that threads of this call
+    // may still be reading (note there is no barrier after the reads below).
+    stats_t* smem = use0_ ? smem0_ : smem1_;
+    use0_ = !use0_;
+    // Compute warp local for all WARPS_N
+    static constexpr uint32_t N = LDGS * NUM_ELTS;
+    constexpr T warp_rn = 1.f / T(N * THREADS_PER_WARP);
+    stats_t warp_stats =
+        warp_stats_.template compute(elts, warp_rn, warps_m);
+
+    // Each warp warp leader stores its stats
+    const auto warp_n = warp_stats_.reducer_.warp_n_;
+    const auto lane = warp_stats_.reducer_.lane_;
+    if (lane == 0) {
+      smem[warp_n] = warp_stats;
+    }
+    utils::namedBarrierSync(0, warps_m * WARPS_N * 32);
+
+    T n = Zeros::get();
+    T m = Zeros::get();
+    T m2 = Zeros::get();
+
+    if (lane < WARPS_N) {
+      stats_t result = smem[lane];
+      n = N * THREADS_PER_WARP;
+      m = Get<0>::of(result);  // should be 0 for RMSNorm
+      m2 = Get<1>::of(result);
+    }
+
+    warp_chan_upd_dynamic(m, m2, n, WARPS_N);
+
+    return {m, m2};
+  }
+
+  WarpStats warp_stats_;
+  stats_t* smem0_;  // ping-pong staging buffers, one stats_t slot per warp_n
+  stats_t* smem1_;
+  bool use0_;
+};
+
+template
+struct Stats {
+  using stats_t = typename TypeToVec2::Type;
+  // The simple Warp reducer.
+ using Reducer = Reducer; + + enum { SMEM_BYTES = 0 }; + + template + inline __device__ Stats(Params& params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, + uint32_t warp_n, uint32_t tidx, uint32_t lane, void* smem, + uint64_t* smem_bar) + : reducer_(params, bidm, bidn, warp_m, warp_n, tidx, lane, smem, smem_bar) {} + + template + inline __device__ stats_t compute(const T (&elts)[LDGS * NUM_ELTS], const T rn, int32_t warps_m) { + static constexpr uint32_t N = LDGS * NUM_ELTS; + auto sum = Sum(); + + T m = Zeros::get(); + if (!isRMSNorm) { +// CL-14115: The unroll factor 128 for LDGS was chosen based on the compilation/perf results for +// APEX LN_fwd engines +#pragma unroll(128 * NUM_ELTS) + for (int it = 0; it < N; it++) { + m += elts[it]; + } + m = reducer_.allreduce(m, sum, warps_m) * rn; + } + + T m2 = Zeros::get(); +#pragma unroll(128 * NUM_ELTS) + for (int it = 0; it < N; it++) { + T diff = (elts[it] - m); + m2 = __fmaf_ieee_rn(diff, diff, m2); + } + m2 = reducer_.allreduce(m2, sum, warps_m); + + return {m, m2}; + } + + Reducer reducer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//// Fast division helpers + +// Count leading zeros - start from most significant bit. +static int clz(int32_t x) { + for (int32_t i = 31; i >= 0; --i) + if (((1 << i) & x) != 0) { + return 31 - i; + } + return 32; +} + +static int32_t find_log_2(int32_t x, bool round_up = false) { + int32_t a = 31 - clz(x); + if (round_up) { + a += !(0 == (x & (x - 1))); + } + return a; +} + +static void find_divisor(int32_t denom, uint32_t& mul_coeff, uint32_t& shift_coeff) { + if (denom == 0) { + return; + } + if (denom == 1) { + // if dividing by 1, reduced math doesn't work because mul_coeff would + // need to be 2^32, which doesn't fit into unsigned int. the div() + // routine handles this special case separately. 
+    mul_coeff = 0;
+    shift_coeff = 0;
+    return;
+  }
+  // To express the division N/D in terms of a multiplication, what we first
+  // imagine is simply N*(1/D). However, 1/D will always evaluate to 0 (for D>1),
+  // so we need another way. There's nothing that says we have to use exactly
+  // the fraction 1/D; instead it could be any X/Y that reduces to 1/D (i.e.,
+  // Y=X*D), or at least to "close enough" to it. If we pick Y that is a power
+  // of two, then the N*(X/Y) can be N*X followed by a right-shift by some amount.
+  // The power of two we should pick should be at least 2^32, because in the
+  // div() routine we'll use umulhi(), which returns only the upper 32 bits --
+  // this being equivalent to a right-shift by 32. But we might want a higher
+  // power of two for better accuracy depending on the magnitude of the denominator.
+  // Once we've picked Y, then X [our mul_coeff value] is simply Y/D, rounding up,
+  // and we save shift_coeff as whatever further shift we have to do beyond
+  // what the umulhi() implies.
+  // p >= 32 for any denom >= 2 (denom 0 and 1 returned early above), so
+  // shift_coeff = p - 32 is the extra shift beyond umulhi()'s implicit shift-by-32.
+  uint32_t p = 31 + find_log_2(denom, true);
+  uint32_t m = ((1ull << p) + static_cast(denom - 1)) / static_cast(denom);
+  mul_coeff = m;
+  shift_coeff = p - 32;
+}
+
+// Upper 32 bits of the 32x32 -> 64-bit unsigned product: __umulhi on device, a plain 64-bit
+// multiply plus shift on the host.
+__device__ __forceinline__ uint32_t umulhi(uint32_t x, uint32_t y) {
+#if defined(__CUDA_ARCH__)
+  return __umulhi(x, y);
+#else
+  uint64_t z = static_cast(x) * y;
+  return static_cast(z >> 32);
+#endif
+}
+
+// Magic-number integer division: div(x) evaluates x / y as umulhi(x, mul_coeff) >> shift_coeff
+// (a multiply plus shift instead of a hardware divide), with y == 1 special-cased below.
+// NOTE(review): the default constructor leaves all members uninitialized — callers must assign
+// a properly constructed instance before use.
+class reduced_divisor {
+ public:
+  reduced_divisor() {}
+  __forceinline__ reduced_divisor(int32_t _y) : y(_y) {
+    mul_coeff = 0U;
+    shift_coeff = 0U;
+    find_divisor(y, mul_coeff, shift_coeff);
+  }
+  // Constructs from precomputed coefficients (e.g. computed once on the host).
+  __device__ __forceinline__ reduced_divisor(uint32_t _mul_coeff, uint32_t _shift_coeff, int32_t _y)
+      : mul_coeff(_mul_coeff), shift_coeff(_shift_coeff), y(_y) {}
+  __device__ __forceinline__ int32_t div(int32_t x) const {
+    // if dividing by 1, then find_divisor wouldn't have worked because
+    // mul_coeff would have had to be 2^32, which can't be represented,
+    // so we have to special case that one.
+    return (y != 1) ? umulhi(static_cast(x), mul_coeff) >> shift_coeff : x;
+  }
+  __device__ __forceinline__ int32_t mod(int32_t x) const { return x - (div(x) * y); }
+  __device__ __forceinline__ void divmod(int32_t x, int32_t& q, int32_t& mod) const {
+    q = div(x);
+    mod = x - (q * y);
+  }
+  __device__ __forceinline__ int32_t get() const { return y; }
+
+ protected:
+  uint32_t mul_coeff;
+  uint32_t shift_coeff;
+  int32_t y;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Type and size parameters shared by all kernel trait bundles below.
+template
+struct Kernel_traits_base {
+  using weight_t = weight_t_;
+  using input_t = input_t_;
+  using output_t = output_t_;
+  using compute_t = compute_t_;
+  using index_t = index_t_;
+
+  enum { HIDDEN_SIZE = HIDDEN_SIZE_ };
+  enum { THREADS_PER_CTA = THREADS_PER_CTA_ };
+  enum { THREADS_PER_WARP = 32 };
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template >
+struct Kernel_traits : public Base {
+  using
input_t = typename Base::input_t; + using weight_t = typename Base::weight_t; + using compute_t = typename Base::compute_t; + using output_t = typename Base::output_t; + using index_t = typename Base::index_t; + using norm_output_t = norm_output_t_; + enum { isRMSNorm = isRMSNorm_ }; + enum { isAdaLN = isAdaLN_ }; + enum { isBatchFirst = isBatchFirst_ }; + enum { hasGamma = hasGamma_ }; + enum { hasBeta = hasBeta_ }; + enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; + enum { WARPS_M = WARPS_M_ }; + enum { WARPS_N = WARPS_N_ }; + enum { COLS = HIDDEN_SIZE_ }; + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { BATCH_SIZE = BATCH_SIZE_ }; + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; + enum { USE_GAMMA_SMEM = useGammaSmem_ }; + enum { USE_CLUSTER = useCluster_ }; + enum { WHOLE_CTA = wholeCTA_ }; + enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; + enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; + enum { ROWS_PER_WARP = ROWS_PER_WARP_ }; + enum { ROWS_PER_CTA = WARPS_M * ROWS_PER_WARP }; + + enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; + enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; + // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed + enum { + SMEM_BYTES_WGRAD = ROWS_PER_CTA == 1 ? 0 : BATCH_SIZE * ROWS_PER_CTA * COLS * sizeof(compute_t) + }; + + using reduce_t = typename TypeToVec2::Type; + using RMSReducer = Reducer; + using Reducer = Reducer; + + enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; + enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; + + using Ivec = Vec; + using Ovec = Vec; + using Wvec = Vec; + using Cvec = Vec; + using NormOvec = Vec; + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; + + // Assume that each thread can handle the same number of elements in the output and weights as in + // the input. The number of columns fetched per load from input: one per thread. 
+ enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; + // The total number of vectorized loads/stores per hidden vector. + enum { VEC_COLS = COLS / ELTS_PER_LDG }; + // The number of loads per thread for the input. + enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; + + using Stats = Stats; + enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr float CUDNN_FLT_MIN = + 1.17549435082228750796873653722224568e-38F; // Minimum positive normalized float value +/** + * Base class to provide common block scaling utilities, private to `BlockScaleHelper` + */ +template +class BlockScaleHelperBase; + +/** + * Base class to provide common block scaling utilities, for nvfp4 + */ +template +class BlockScaleHelperBase { + public: + using scale_t = nv_fp8_e4m3; + + protected: + using compute_t = typename Cvec::Elt_type; + using output_t = typename Ovec::Elt_type; + + using index_t = uint32_t; + + static constexpr int32_t THREADS_PER_WARP = 32; + static constexpr int32_t BLOCK_SIZE = 16; + static constexpr int32_t NUM_ELTS = sizeof(Cvec) / sizeof(compute_t); + static constexpr int32_t GROUP_SIZE = BLOCK_SIZE / NUM_ELTS; + static constexpr int32_t SF_COUNT = THREADS_PER_WARP / GROUP_SIZE; + + static __device__ compute_t scale(const compute_t value) { + constexpr compute_t nvfp4_max = 6.0f; + compute_t result = value / nvfp4_max; + // Clamp to CUDNN_FLT_MIN to match reference implementation and avoid 0 scaling factors + return fmaxf(result, CUDNN_FLT_MIN); + } + + static __device__ scale_t fp32ToE4M3(float value) { return scale_t(value); } + + static __device__ Vec vecDivideAndCast(const Cvec& input, compute_t scale) { + // Construct the nv_fp4x2_e2m1 outputs. 
+ // If NUM_ELTS == 1, nv_fp4(ele1,0) + // If NUM_ELTS % 2 == 0; f2_i=nv_fp4(ele_i, ele_i+1) --> out[i/2] + // Per cuda_fp4.h and cuda_fp8.h + // The data type of __nv_fp4x2_storage_t is __nv_fp8_storage_t + // The data type of __nv_fp8_storage_t is unsigned char: + // typedef unsigned char __nv_fp8_storage_t; + // The reason why we use unsigned char is that the __shfl_xor_sync don't support nv_fp4 data + // type + Vec output; +#pragma unroll + for (int i = 0; i < NUM_ELTS; i += 2) { + float2 f2 = make_float2((input.data.elt[i] / scale), + (i + 1 < NUM_ELTS ? (input.data.elt[i + 1] / scale) : 0)); + output.data.elt[i / 2] = static_cast(nv_fp4x2_e2m1(f2).__x); + } + return output; + } +}; + +/** + * Base class to provide common block scaling utilities, for mxfp8 + */ +template +class BlockScaleHelperBase { + public: + using lne8m0_t = uint8_t; + using scale_t = lne8m0_t; + + protected: + using compute_t = typename Cvec::Elt_type; + using output_t = typename Ovec::Elt_type; + using index_t = uint32_t; + + static constexpr int32_t THREADS_PER_WARP = 32; + static constexpr int32_t BLOCK_SIZE = 32; + static constexpr int32_t NUM_ELTS = sizeof(Cvec) / sizeof(compute_t); + static constexpr int32_t GROUP_SIZE = BLOCK_SIZE / NUM_ELTS; + + static __device__ compute_t scale(const compute_t value) { + static constexpr float FP8_MAX = std::is_same::value ? 
57344.f : 448.f; + compute_t result = value / FP8_MAX; + // Clamp to CUDNN_FLT_MIN to match reference implementation and avoid 0 scaling factors + return fmaxf(result, CUDNN_FLT_MIN); + } + + template + static __device__ Ovec vecDivideAndCast(const IOvec& input, compute_t scale) { + Ovec output; +#pragma unroll + for (int32_t i = 0; i < NUM_ELTS; ++i) { + output.data.elt[i] = output_t((compute_t(input.data.elt[i]) / scale)); + } + return output; + }; +}; + +/** + * Helper class to define the block scale operations + */ +template +class BlockScaleRowHelper; + +template +class BlockScaleColHelper; + +/** + * Specialization for MXFP8 rowwise scaling + */ +template +class BlockScaleRowHelper : private BlockScaleHelperBase { + private: + using Base = BlockScaleHelperBase; + using Base::BLOCK_SIZE; + using Base::GROUP_SIZE; + using Base::NUM_ELTS; + using Base::THREADS_PER_WARP; + using typename Base::compute_t; + using typename Base::index_t; + using typename Base::scale_t; + + public: + __device__ BlockScaleRowHelper() = default; + + template + __device__ void blockQuantizeStore(const IOvec& z_math, void* sf_ptr, index_t sf_idx, void* z_ptr, + index_t z_idx) { + const index_t group_id = lane / GROUP_SIZE; + + // 1. 
compute scaling factor + compute_t sf = 0; +#pragma unroll + // compute the amax across NUM_ELTS elements + for (int jt = 0; jt < NUM_ELTS; jt++) { + sf = fmaxf(sf, fabsf(compute_t(z_math.data.elt[jt]))); + } + +// Currently disabled for rowwise block scaling as perf regression observed +#if 0 && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) && defined(__CUDA_ARCH_FEAT_SM100_ALL) +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + if (group_id == jt) { + asm volatile("redux.sync.max.f32 %0, %1, %2;" : "=f"(sf) : "f"(sf), "r"(__activemask())); + } + } +#else + // warp reduce (binary reduction) starting from GROUP_SIZE / 2 to aggregate + // a total of NUM_ELTS * GROUP_SIZE = BLOCK_SIZE number of z values in sf +#pragma unroll + for (int offset = GROUP_SIZE / 2; offset > 0; offset /= 2) { + sf = fmaxf(sf, __shfl_down_sync(0xffffffff, sf, offset)); + } + + // broadcast from the lane with the valid values to the other threads in the group + sf = __shfl_sync(0xffffffff, sf, group_id * GROUP_SIZE); +#endif + // scale the row amax to get the scaling factor + sf = Base::scale(sf); + + // 2. store scaling factor + // Warp orchestration to compute scale_row: + // e.g. for NUM_ELTS == 2: + // | W0SF0 W0SF1 | W0SF0 W0SF1 | ... | W0SF0 W0SF1 | W0SF0 W0SF1 |... + // | W1SF0 W1SF1 | W1SF0 W1SF1 | ... | W1SF0 W1SF1 | W1SF0 W1SF1 |... + // | ... | ... | ... | ... | ... |... + // | W31SF0 W31SF1 | W31SF0 W31SF1 | ... | W31SF0 W31SF1 | W31SF0 W31SF1 |... + // |---------------|---------------|-----|--------------------------|---------------|... + // |----- CTA0 ----|----- CTA1 ----| ... |--- CTA --|----- CTA0 ----|... + // | | + // |-------------------------------LDG0-----------------------------|-------LDG1-----... 
+ Vec sf_out; + // only correct for group_id == 0, but it's ok because only lane == 0 is storing sf_out + sf_out.data.elt[0] = float_to_e8m0(sf); +#pragma unroll + for (int jt = 1; jt < NUM_ELTS; jt++) { + sf_out.data.elt[jt] = float_to_e8m0(__shfl_sync(0xffffffff, sf, jt * GROUP_SIZE)); + } + if (lane == 0) { + sf_out.store_to(sf_ptr, sf_idx); + } + // No need to check for 0x00 - scale() now clamps to FLT_MIN which converts to 0x01 + float scale = e8m0_to_float(sf_out.data.elt[0]); + Ovec z_row_scaled = Base::vecDivideAndCast(z_math, scale); + z_row_scaled.store_to(z_ptr, z_idx); + } + + private: + const index_t lane = threadIdx.x % THREADS_PER_WARP; +}; + +/** + * Specialization for MXFP8 colwise scaling in 1D2X2X + */ +template +class BlockScaleColHelper : private BlockScaleHelperBase { + private: + using Base = BlockScaleHelperBase; + using Base::BLOCK_SIZE; + using Base::GROUP_SIZE; + using Base::NUM_ELTS; + using Base::THREADS_PER_WARP; + using typename Base::compute_t; + using typename Base::index_t; + using typename Base::output_t; + using typename Base::scale_t; + + public: + __device__ BlockScaleColHelper(float* mxfp8_tile) : mxfp8_tile_(mxfp8_tile) {} + + template + __device__ void initTile(const IOvec& z_math, int num_threads) { + return initTile(z_math, warp, num_threads); + } + + template + __device__ void initTile(const IOvec& z_math, index_t row_in_tile, int num_threads) { + constexpr int entry_barrier_arbitrary = 5; + utils::namedBarrierSync(entry_barrier_arbitrary, num_threads); +#pragma unroll + for (index_t jt = 0; jt < NUM_ELTS; jt++) { + index_t idx_in_mxfp8_smem = + (jt * (BLOCK_SIZE + /*padding row*/ 1) + row_in_tile) * MXFP8_1D2X2X_SMEM_COLS + lane; + mxfp8_tile_[idx_in_mxfp8_smem] = float(z_math.data.elt[jt]); + } + } + + template + __device__ void blockQuantizeStore(void* sf_ptr, index_t sf_row_idx, index_t sf_col_idx, + index_t sf_row_width, void* z_ptr, index_t z_idx, + int num_threads) { + return blockQuantizeStore(sf_ptr, sf_row_idx, 
sf_col_idx, sf_row_width, z_ptr, z_idx, + num_threads, warp); + } + + template + __device__ void blockQuantizeStore(void* sf_ptr, index_t sf_row_idx, index_t sf_col_idx, + index_t sf_row_width, void* z_ptr, index_t z_idx, + int num_threads, index_t row_in_tile) { + // Ensure all temp outputs have been written to mxfp8 smem for the current ldg + constexpr int entry_barrier_arbitrary = 4; + utils::namedBarrierSync(entry_barrier_arbitrary, num_threads); + const index_t group_id = lane / GROUP_SIZE; + +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) && defined(__CUDA_ARCH_FEAT_SM100_ALL) + // For sm_100a arch, we can use redux.sync.op{.abs.}{.NaN}.f32 instruction to reduce the min/max + // value + // https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-redux-sync + compute_t amax_vals[NUM_ELTS]; +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + const index_t idx_in_mxfp8_smem = + jt * NEXT_COL_SMEM_OFFSET + lane * MXFP8_1D2X2X_SMEM_COLS + row_in_tile; + asm volatile("redux.sync.max.abs.f32 %0, %1, 0xffffffff;" + : "=f"(amax_vals[jt]) + : "f"(mxfp8_tile_[idx_in_mxfp8_smem])); + } + compute_t sf = amax_vals[group_id]; +#else + // 1. 
compute scaling factor + compute_t sf = 0; +#pragma unroll + // compute the amax across NUM_ELTS elements in the same col (in shared memory, adjacent rows) + for (int jt = 0; jt < NUM_ELTS * MXFP8_1D2X2X_SMEM_COLS; jt += MXFP8_1D2X2X_SMEM_COLS) { + const index_t idx_in_mxfp8_smem = ROW_SMEM_OFFSET + jt + row_in_tile; + sf = fmaxf(sf, fabsf(mxfp8_tile_[idx_in_mxfp8_smem])); + } + + // warp reduce (binary reduction) starting from GROUP_SIZE / 2 to aggregate + // a total of NUM_ELTS * GROUP_SIZE = BLOCK_SIZE number of z values in sf +#pragma unroll + for (int offset = GROUP_SIZE / 2; offset > 0; offset /= 2) { + sf = fmaxf(sf, __shfl_down_sync(0xffffffff, sf, offset)); + } + + // broadcast from the lane with the valid values to the other threads in the group + sf = __shfl_sync(0xffffffff, sf, group_id * GROUP_SIZE); +#endif + + // scale the col amax to get the scaling factor + sf = Base::scale(sf); + + // Warp orchestration to compute scale_col: + // e.g. for NUM_ELTS == 2: + // | W0SF0 W0SF1 W1SF0 W1SF1 ... W31SF0 W31SF1 | W0SF0 W0SF1 ... W31SF0 W31SF1 |... + // |----warp0---|---warp1--| ... |----warp31---|----warp0---|... |----warp31---|... + // |--------------------CTA0-------------------|---------------CTA1------------|... + // |------------------------------------------LDG0------------------------------... + Ovec z_col_scaled; + if constexpr (!Transpose) { + // 2. store scaling factor + // different columns require different sf, so all threads must exchange sf + Vec sf_out; +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + sf_out.data.elt[jt] = float_to_e8m0(__shfl_sync(0xffffffff, sf, jt * GROUP_SIZE)); + } + if (lane == 0) { + sf_out.store_to(sf_ptr, sf_row_idx * sf_row_width + sf_col_idx); + } + + // 3. 
scale and store output +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + // iterate over NUM_ELTS elements in a row (in shared memory, every other NUM_ELTS in a + // column) + index_t z_idx_in_mxfp8_smem = + jt * NEXT_COL_SMEM_OFFSET + lane * MXFP8_1D2X2X_SMEM_COLS + row_in_tile; + // No need to check for 0x00 - scale() now clamps to FLT_MIN which converts to 0x01 + float scale = e8m0_to_float(sf_out.data.elt[jt]); + z_col_scaled.data.elt[jt] = output_t(mxfp8_tile_[z_idx_in_mxfp8_smem] / scale); + } + } else { + // 2. store scaling factor + uint8_t sf_e8m0 = float_to_e8m0(sf); + // No need to check for 0x00 - scale() now clamps to FLT_MIN which converts to 0x01 + if (lane % GROUP_SIZE == 0) { + static_cast(sf_ptr)[(sf_row_idx + group_id) * sf_row_width + sf_col_idx] = + sf_e8m0; + } + + // 3. scale and store output + float scale = e8m0_to_float(sf_e8m0); + // No need to check for 0.0f - scale() now clamps to FLT_MIN +#pragma unroll + for (int jt = 0; jt < NUM_ELTS; jt++) { + // iterate over NUM_ELTS elements in a column (in shared memory, adjacent NUM_ELTS rows) + index_t z_idx_in_mxfp8_smem = ROW_SMEM_OFFSET + jt * MXFP8_1D2X2X_SMEM_COLS + row_in_tile; + z_col_scaled.data.elt[jt] = output_t(mxfp8_tile_[z_idx_in_mxfp8_smem] / scale); + } + } + z_col_scaled.store_to(z_ptr, z_idx); + } + + private: + static constexpr index_t MXFP8_1D2X2X_SMEM_ROWS = + BLOCK_SIZE * NUM_ELTS + /*padding rows between*/ (NUM_ELTS - 1); + static constexpr index_t MXFP8_1D2X2X_SMEM_COLS = BLOCK_SIZE + /*padding col*/ 1; + static constexpr index_t NEXT_COL_SMEM_OFFSET = + (BLOCK_SIZE + /*padding row*/ 1) * MXFP8_1D2X2X_SMEM_COLS; + // For MXFP8 1d2x2x, we store the 32x32xNUM_ELTS tile in shared memory to compute the col amax for + // scaling within the current CTA e.g. for NUM_ELTS == 2, the shared memory layout for 1d2x2x is + // as follows: + // + // col0 col2 col4 col6 ... 
col62 + // ↓ ↓ ↓ ↓ ↓ + // |-------------------------------------------| + // | T0V0 | T1V0 | T2V0 | T3V0 |...| T31V0 | 0 | <-- warp0 to store at jt=0 + // | T0V0 | T1V0 | T2V0 | T3V0 |...| T31V0 | 0 | <-- warp1 to store at jt=0 + // | .. | .. | .. | .. |...| ... | 0 | <-- ... + // | T0V0 | T1V0 | T2V0 | T3V0 |...| T31V0 | 0 | <-- warp31 to store at jt=0 + // |--------------------|------|---|-------|---| + // | 0 | 0 | 0 | 0 |...| 0 | 0 | <-- padding row between to avoid bank + // conflicts when col-scaling + // |--------------------|------|---|-------|---| + // | T0V1 | T1V1 | T2V1 | T3V1 |...| T31V1 | 0 | <-- warp0 to store at jt=1 + // | T0V1 | T1V1 | T2V1 | T3V1 |...| T31V1 | 0 | <-- warp1 to store at jt=1 + // | .. | .. | .. | .. |...| ... | 0 | <-- ... + // | T0V1 | T1V1 | T2V1 | T3V1 |...| T31V1 | 0 | <-- warp31 to store at jt=1 + // |-------------------------------------------| + // ↑ ↑ ↑ ↑ ↑ ↑ + // col1 col3 col5 col7 ... col63 padding col to avoid bank conflicts when computing + // col-scale factors + // + // ↑ ↑ ↑ ↑ ↑ + // warp0 warp1 warp2 warp3 ... 
warp31 + // to compute and apply the col-scale factors + float* mxfp8_tile_ = nullptr; + + const index_t lane = threadIdx.x % THREADS_PER_WARP; + const index_t warp = threadIdx.x / THREADS_PER_WARP; + + const index_t ROW_SMEM_OFFSET = + ((lane * NUM_ELTS) + /*padding rows*/ (lane / GROUP_SIZE)) * MXFP8_1D2X2X_SMEM_COLS; +}; + +/** + * Specialization for NVFP4 rowwise scaling + */ +template +class BlockScaleRowHelper : private BlockScaleHelperBase { + using Base = BlockScaleHelperBase; + using Base::BLOCK_SIZE; + using Base::GROUP_SIZE; + using Base::NUM_ELTS; + using Base::SF_COUNT; + using Base::THREADS_PER_WARP; + using typename Base::compute_t; + using typename Base::index_t; + using typename Base::output_t; + using typename Base::scale_t; + + public: + __device__ BlockScaleRowHelper() = default; + + __device__ void blockQuantizeStore(const Cvec& z_math, void* sf_ptr, index_t sf_idx, void* z_ptr, + index_t z_idx) { + // 1. compute scaling factor + compute_t sf = 0; +#pragma unroll + // compute the amax across NUM_ELTS elements + for (int jt = 0; jt < NUM_ELTS; jt++) { + sf = fmaxf(sf, fabsf(z_math.data.elt[jt])); + } + + // warp reduce (binary reduction) starting from GROUP_SIZE / 2 to aggregate + // a total of NUM_ELTS * GROUP_SIZE = BLOCK_SIZE number of z values in sf +#pragma unroll + for (int offset = GROUP_SIZE / 2; offset > 0; offset /= 2) { + sf = fmaxf(sf, __shfl_down_sync(0xffffffff, sf, offset)); + } + + // broadcast from the lane with the valid values to the other threads in the group + const index_t group_id = lane / GROUP_SIZE; + sf = __shfl_sync(0xffffffff, sf, group_id * GROUP_SIZE); + + // scale the row amax to get the scaling factor + sf = Base::scale(sf); + + // 2. 
store scaling factor + Vec sf_out; + // only correct for group_id == 0, but it's ok because only lane == 0 is storing sf_out + sf_out.data.elt[0] = Base::fp32ToE4M3(sf); +#pragma unroll + for (int jt = 1; jt < SF_COUNT; jt++) { + sf_out.data.elt[jt] = Base::fp32ToE4M3(__shfl_sync(0xffffffff, sf, jt * GROUP_SIZE)); + } + if (lane == 0) { + sf_out.store_to(sf_ptr, sf_idx); + } + + // TODO: can consider doing something like the following to remove the need for a shuffle loop + // template + // struct UintContainer { + // using type = std::conditional_t>>>; // void for unsupported sizes + // }; + // Base::vecDivideAndCast returns Vec, 2> + // 3. scale and store output + auto z_row_scaled = Base::vecDivideAndCast(z_math, sf); +#pragma unroll + // Consturct z_row_scaled by combining the adjacent thread's z_row_scaled + // Vec z_row_scaled + for (int i = 0; i < NUM_ELTS / 2; i++) { + uint8_t adjacent_val = __shfl_xor_sync(0xffffffff, z_row_scaled.data.elt[i], 1); + z_row_scaled.data.elt[i + NUM_ELTS / 2] = adjacent_val; + } + if (NUM_ELTS == 1) { + uint8_t adjacent_val = __shfl_xor_sync(0xffffffff, z_row_scaled.data.elt[0], 1); + z_row_scaled.data.elt[0] |= adjacent_val << 4U; + } + if (lane % 2 == 0) { + z_row_scaled.store_to(z_ptr, z_idx / 2); + } + } + + private: + const index_t lane = threadIdx.x % THREADS_PER_WARP; +}; diff --git a/include/flashinfer/norm/sm100_rms_norm_silu_knobs.h b/include/flashinfer/norm/sm100_rms_norm_silu_knobs.h new file mode 100644 index 0000000000..e89f83e72e --- /dev/null +++ b/include/flashinfer/norm/sm100_rms_norm_silu_knobs.h @@ -0,0 +1,216 @@ +#pragma once + +// Auto-generated knob selection logic for Sm100RmsNormSiluEngine. +// Generated from optimal knob sweep results on B200 (SM100). 
+// Knob mapping: knobTileRows {0:1,1:4,2:8,3:32}, knobLoadSize {0:2,1:4,2:8,3:16} + +#include + +namespace cudnn_frontend { +namespace experimental { + +enum class RmsNormSiluDtype : uint8_t { + BF16 = 0, + FP8 = 1, + NVFP4 = 2, +}; + +// Compact knob configuration per (C, tokens, dtype) entry. +// WARPS_N is always 1. All values are the ACTUAL kernel parameters +// (not knob indices). +struct RmsNormSiluKnobs { + uint8_t warps_m; // WARPS_M value: 1, 4, 8, or 32 + uint8_t split_cols; // knobSplitCols: 0 = no split, 4 = estimated CTAS_PER_ROW + uint8_t kernel_cfg; // knobKernelCfg: 0, 1, or 2 + uint8_t occupancy; // DESIRED_OCCUPANCY: 0-16 + uint8_t bytes_per_ldg; // BYTES_PER_LDG: 2, 4, 8, or 16 +}; + +static constexpr int kSupportedC[] = {64, 128, 160, 256, 320, 512, 640, 1024}; +static constexpr int kSupportedTokens[] = {1560, 6240, 24960, 99840, 399360}; +static constexpr int kNumC = 8; +static constexpr int kNumTokens = 5; +static constexpr int kNumDtypes = 3; + +// Knob LUT indexed as: knob_lut[c_idx][tokens_idx][dtype_idx] +// c_idx: 0=64, 1=128, 2=160, 3=256, 4=320, 5=512, 6=640, 7=1024 +// tokens_idx: 0=1560, 1=6240, 2=24960, 3=99840, 4=399360 +// dtype_idx: 0=bf16, 1=fp8, 2=nvfp4 +static constexpr RmsNormSiluKnobs knob_lut[kNumC][kNumTokens][kNumDtypes] = { + { + // C=64 + {{8, 0, 0, 2, 4}, {8, 4, 0, 6, 4}, {8, 0, 2, 1, 4}}, // tokens=1560 + {{32, 4, 0, 2, 4}, {8, 0, 0, 3, 2}, {8, 4, 0, 4, 4}}, // tokens=6240 + {{32, 4, 0, 2, 4}, {8, 0, 0, 7, 4}, {8, 0, 1, 6, 4}}, // tokens=24960 + {{8, 0, 1, 8, 4}, {8, 0, 1, 6, 2}, {32, 4, 1, 2, 4}}, // tokens=99840 + {{4, 0, 1, 16, 4}, {32, 0, 1, 2, 2}, {32, 4, 1, 2, 4}}, // tokens=399360 + }, + { + // C=128 + {{8, 4, 0, 3, 4}, {8, 0, 0, 3, 4}, {8, 4, 0, 3, 8}}, // tokens=1560 + {{8, 0, 0, 3, 4}, {8, 0, 0, 4, 8}, {8, 0, 0, 5, 8}}, // tokens=6240 + {{8, 0, 0, 6, 4}, {8, 0, 0, 8, 8}, {8, 0, 1, 8, 8}}, // tokens=24960 + {{32, 4, 0, 2, 4}, {32, 0, 0, 2, 8}, {32, 0, 1, 2, 8}}, // tokens=99840 + {{8, 0, 0, 8, 4}, {32, 0, 0, 
2, 8}, {32, 0, 1, 2, 8}}, // tokens=399360 + }, + { + // C=160 + {{8, 0, 0, 4, 2}, {8, 0, 0, 2, 2}, {4, 4, 0, 4, 2}}, // tokens=1560 + {{8, 0, 0, 4, 2}, {8, 0, 0, 4, 2}, {8, 0, 1, 4, 2}}, // tokens=6240 + {{8, 4, 1, 6, 2}, {8, 4, 0, 6, 2}, {8, 4, 1, 8, 2}}, // tokens=24960 + {{32, 4, 1, 2, 2}, {32, 4, 1, 2, 2}, {32, 4, 0, 1, 2}}, // tokens=99840 + {{32, 4, 1, 2, 2}, {32, 4, 1, 2, 2}, {32, 0, 1, 2, 2}}, // tokens=399360 + }, + { + // C=256 + {{8, 0, 0, 6, 16}, {8, 4, 0, 2, 4}, {8, 0, 2, 1, 16}}, // tokens=1560 + {{8, 0, 0, 4, 4}, {8, 0, 0, 4, 4}, {8, 0, 2, 1, 16}}, // tokens=6240 + {{8, 0, 0, 8, 16}, {8, 4, 0, 8, 16}, {8, 4, 1, 6, 16}}, // tokens=24960 + {{4, 4, 0, 16, 16}, {4, 0, 0, 16, 16}, {32, 0, 1, 1, 16}}, // tokens=99840 + {{4, 0, 0, 16, 16}, {32, 0, 0, 2, 16}, {32, 0, 1, 2, 16}}, // tokens=399360 + }, + { + // C=320 + {{8, 4, 1, 4, 4}, {8, 0, 0, 2, 4}, {4, 4, 0, 9, 4}}, // tokens=1560 + {{8, 4, 0, 5, 4}, {8, 0, 0, 5, 4}, {4, 0, 0, 9, 4}}, // tokens=6240 + {{8, 0, 0, 5, 4}, {8, 0, 0, 5, 4}, {8, 0, 1, 8, 4}}, // tokens=24960 + {{4, 0, 1, 16, 4}, {32, 0, 1, 2, 4}, {32, 4, 1, 2, 4}}, // tokens=99840 + {{32, 4, 0, 2, 4}, {32, 0, 1, 2, 4}, {32, 4, 1, 2, 4}}, // tokens=399360 + }, + { + // C=512 + {{8, 0, 0, 2, 16}, {8, 0, 0, 2, 8}, {4, 4, 0, 3, 16}}, // tokens=1560 + {{8, 0, 0, 5, 16}, {8, 0, 0, 4, 8}, {4, 0, 0, 9, 16}}, // tokens=6240 + {{4, 0, 0, 8, 16}, {4, 0, 0, 9, 8}, {4, 0, 2, 1, 16}}, // tokens=24960 + {{4, 0, 2, 1, 8}, {32, 4, 1, 2, 8}, {32, 4, 0, 1, 16}}, // tokens=99840 + {{4, 0, 2, 1, 4}, {32, 4, 1, 2, 8}, {32, 0, 0, 1, 16}}, // tokens=399360 + }, + { + // C=640 + {{4, 0, 0, 4, 4}, {4, 0, 0, 3, 8}, {4, 4, 0, 5, 8}}, // tokens=1560 + {{4, 0, 0, 5, 4}, {8, 0, 0, 4, 8}, {4, 0, 1, 9, 8}}, // tokens=6240 + {{4, 0, 0, 5, 4}, {8, 0, 0, 4, 8}, {4, 0, 2, 1, 8}}, // tokens=24960 + {{4, 0, 2, 1, 8}, {4, 4, 0, 9, 8}, {32, 0, 1, 1, 8}}, // tokens=99840 + {{4, 0, 2, 1, 8}, {32, 4, 1, 2, 8}, {32, 4, 1, 1, 8}}, // tokens=399360 + }, + { + // C=1024 + {{4, 4, 0, 3, 16}, 
{4, 0, 0, 3, 4}, {4, 4, 0, 7, 16}}, // tokens=1560 + {{4, 0, 0, 5, 16}, {4, 0, 0, 5, 8}, {4, 0, 2, 1, 16}}, // tokens=6240 + {{4, 4, 1, 10, 16}, {1, 4, 0, 16, 8}, {4, 0, 2, 1, 16}}, // tokens=24960 + {{8, 0, 2, 1, 16}, {4, 0, 1, 9, 8}, {32, 0, 1, 1, 16}}, // tokens=99840 + {{8, 0, 2, 1, 16}, {32, 4, 1, 1, 8}, {32, 4, 1, 1, 16}}, // tokens=399360 + }, +}; + +// Compute conservative default knobs for arbitrary problem sizes not in the LUT. +// Uses safe defaults (WARPS_M=1, BPL=4, occupancy=1) and validates vectorization +// divisibility constraints before accepting a configuration. +// Returns true if a valid configuration was found, false otherwise. +inline bool compute_default_knobs(int C, int num_tokens, RmsNormSiluDtype dtype, + RmsNormSiluKnobs& out) { + // Conservative defaults: + // CTAS_PER_ROW = 1, WARPS_M = 1, WARPS_N = 1, BPL = 4, occupancy = 1, kernel_cfg = 0 + // For block-scale output (NVFP4): WARPS_M = 32 + + int input_size = 2; // bf16 input always + + // Start with conservative defaults + int warps_m = (dtype == RmsNormSiluDtype::NVFP4) ? 32 : 1; + int warps_n = 1; // always 1 for our engine + int bpl = 4; // default bytes per load + int cpr = 1; // no column splitting for fallback + int occ = 1; + int kcfg = 0; + + // Validation: C must be evenly divisible into vectorized loads. 
+ // NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) + // VEC_COLS = C / NUM_ELTS + // VEC_COLS_PER_LDG = CTAS_PER_ROW * WARPS_N * 32 + // Require: C % NUM_ELTS == 0 AND VEC_COLS % VEC_COLS_PER_LDG == 0 + // Also: LDGS = VEC_COLS / VEC_COLS_PER_LDG <= 1024 (avoid register spill) + + auto validate = [&](int test_bpl, int test_wm) -> bool { + int num_elts = test_bpl / input_size; + if (num_elts <= 0 || C % num_elts != 0) return false; + int vec_cols = C / num_elts; + int vec_cols_per_ldg = cpr * warps_n * 32; + if (vec_cols_per_ldg <= 0 || vec_cols % vec_cols_per_ldg != 0) return false; + int ldgs = vec_cols / vec_cols_per_ldg; + if (ldgs > 1024) return false; // reject extreme LDGS to avoid register spilling + // Check WARPS_M constraint: if WARPS_M > 1, rows per CTA must divide evenly + if (test_wm > 1 && num_tokens % test_wm != 0) return false; + return true; + }; + + // Try default BPL=4, then cascade through {4, 8, 16, 2} + static constexpr int bpl_candidates[] = {4, 8, 16, 2}; + bool found = false; + for (int candidate : bpl_candidates) { + if (validate(candidate, warps_m)) { + bpl = candidate; + found = true; + break; + } + } + + // If WARPS_M=1 failed, try bumping to WARPS_M=4 for better row coverage + if (!found && warps_m == 1 && num_tokens % 4 == 0) { + warps_m = 4; + for (int candidate : bpl_candidates) { + if (validate(candidate, warps_m)) { + bpl = candidate; + found = true; + break; + } + } + } + + if (!found) return false; + + out.warps_m = static_cast(warps_m); + out.split_cols = 0; // no column splitting + out.kernel_cfg = static_cast(kcfg); + out.occupancy = static_cast(occ); + out.bytes_per_ldg = static_cast(bpl); + return true; +} + +// Look up knob configuration for a given (C, num_tokens, output_dtype, sm_version). +// Tier 1: exact LUT match for SM100 VAE problem sizes (optimal, sweep-tuned on B200). +// Tier 2: fallback heuristic for other archs or arbitrary sizes (functional, conservative). 
+// Returns nullptr only if the problem is fundamentally unsupported. +inline const RmsNormSiluKnobs* lookup_rms_norm_silu_knobs(int C, int num_tokens, + RmsNormSiluDtype dtype, + int sm_version = 100) { + // Tier 1: exact LUT match — only valid for SM100 (swept on B200) + if (sm_version >= 100) { + int c_idx = -1, t_idx = -1; + for (int i = 0; i < kNumC; ++i) { + if (kSupportedC[i] == C) { + c_idx = i; + break; + } + } + for (int i = 0; i < kNumTokens; ++i) { + if (kSupportedTokens[i] == num_tokens) { + t_idx = i; + break; + } + } + if (c_idx >= 0 && t_idx >= 0) { + return &knob_lut[c_idx][t_idx][static_cast(dtype)]; + } + } + + // Tier 2: fallback heuristic for non-SM100 archs or non-LUT problem sizes + static thread_local RmsNormSiluKnobs fallback; + if (compute_default_knobs(C, num_tokens, dtype, fallback)) { + return &fallback; + } + + return nullptr; // fundamentally unsupported (C not divisible by any valid config) +} + +} // namespace experimental +} // namespace cudnn_frontend diff --git a/tests/norm/test_fused_rmsnorm_silu.py b/tests/norm/test_fused_rmsnorm_silu.py new file mode 100644 index 0000000000..63e81ce1d9 --- /dev/null +++ b/tests/norm/test_fused_rmsnorm_silu.py @@ -0,0 +1,493 @@ +# Copyright (c) 2025 by FlashInfer team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Unit tests for Fused RMSNorm + SiLU kernel. +Tolerance and reference methodology matches the cuDNN frontend OSS test suite +(test_sm100_rms_norm_silu_graph_api.py). 
+""" + +import pytest +import torch +import torch.nn.functional as F + + +def get_cc(): + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + + +def rmsnorm_silu_reference(x, weight, eps, output_dtype=None): + """Reference: RMSNorm + SiLU. + + Matches the cuDNN test reference: compute entirely in float32. + If output_dtype is specified, cast the result to that dtype. + """ + rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps) + x_norm = (x.float() / rms) * weight.float() + result = F.silu(x_norm) + if output_dtype is not None: + return result.to(output_dtype) + return result.to(x.dtype) + + +# FP4 E2M1 lookup table (4-bit value -> float) +_FP4_E2M1_TABLE = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, # positive + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, # negative +] + + +def _unpack_fp4_nibbles(packed_bytes, num_tokens, C): + """Unpack FP4 packed bytes into a [num_tokens, C] int tensor of 4-bit nibble values.""" + nibbles = torch.zeros(num_tokens, C, dtype=torch.int32, device=packed_bytes.device) + for col_byte in range(C // 2): + byte_val = packed_bytes[:, col_byte].int() + nibbles[:, col_byte * 2] = byte_val & 0x0F + nibbles[:, col_byte * 2 + 1] = (byte_val >> 4) & 0x0F + return nibbles + + +def _quantize_to_fp4_reference(values_f32, C): + """Quantize float32 values to FP4 E2M1 nibbles matching the kernel's algorithm. + + Matches cuDNN BlockScaleRowHelper: + 1. amax = max(|block of 16 elements|) + 2. scale = max(amax / 6.0, FLT_MIN) + 3. 
quantized = nv_fp4x2_e2m1(value / scale) + """ + BLOCK_SIZE = 16 + FP4_MAX = 6.0 + FLT_MIN = 1.17549435082228750796873653722224568e-38 + num_tokens = values_f32.shape[0] + num_blocks = C // BLOCK_SIZE + + fp4_positive = torch.tensor(_FP4_E2M1_TABLE[:8], dtype=torch.float32) + nibbles = torch.zeros(num_tokens, C, dtype=torch.int32, device=values_f32.device) + + for b in range(num_blocks): + col_start = b * BLOCK_SIZE + col_end = col_start + BLOCK_SIZE + block_vals = values_f32[:, col_start:col_end].cpu().float() + + amax = block_vals.abs().max(dim=1, keepdim=True).values + scale = torch.clamp(amax / FP4_MAX, min=FLT_MIN) + + scaled = block_vals / scale + magnitudes = scaled.abs() + signs = (scaled < 0).int() + + diffs = (magnitudes.unsqueeze(2) - fp4_positive.unsqueeze(0).unsqueeze(0)).abs() + mag_nibbles = diffs.argmin(dim=2) + + block_nibbles = mag_nibbles + signs * 8 + nibbles[:, col_start:col_end] = block_nibbles.to(values_f32.device) + + return nibbles + + +def dequantize_nvfp4(packed_bytes, scale_row_fp8, num_tokens, C): + """Dequantize NVFP4 1D1X1X output to float32.""" + BLOCK_SIZE = 16 + scale_f32 = scale_row_fp8.view(torch.float8_e4m3fn).float() + nibbles = _unpack_fp4_nibbles(packed_bytes, num_tokens, C) + + output = torch.zeros(num_tokens, C, dtype=torch.float32, device=packed_bytes.device) + for col in range(C): + block = col // BLOCK_SIZE + fp4_vals = torch.tensor( + [_FP4_E2M1_TABLE[v] for v in nibbles[:, col].cpu().tolist()], + dtype=torch.float32, + device=packed_bytes.device, + ) + output[:, col] = fp4_vals * scale_f32[:, block] + return output + + +@pytest.fixture(autouse=True) +def skip_if_not_sm100(): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + if get_cc() < 100: + pytest.skip("Fused RMSNorm+SiLU requires SM100+") + + +SUPPORTED_C = [64, 128, 160, 256, 320, 512, 640, 1024] +SUPPORTED_TOKENS = [1560, 6240, 24960, 99840, 399360] + +ALL_LUT_SHAPES = [(tokens, C) for C in SUPPORTED_C for tokens in SUPPORTED_TOKENS] + + 
+# ============================================================ +# bf16 output — atol=2e-2, rtol=2e-2, zero mismatches +# (matches cuDNN test _run_rmsnorm_silu_test) +# ============================================================ + + +@pytest.mark.parametrize( + "num_tokens,hidden_size", + ALL_LUT_SHAPES, + ids=[f"t{t}_C{c}" for t, c in ALL_LUT_SHAPES], +) +def test_lut_bf16(num_tokens, hidden_size): + """All 40 LUT shapes for bf16 output.""" + import flashinfer + + torch.manual_seed(42) + x = ( + torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda") * 5.0 + + 5.0 + ) + weight = torch.rand(hidden_size, dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5 + + out = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6) + ref = rmsnorm_silu_reference(x, weight, eps=1e-6) + + mismatches = ~torch.isclose(out.float(), ref.float(), atol=2e-2, rtol=2e-2) + num_mismatches = mismatches.sum().item() + max_diff = (out.float() - ref.float()).abs().max().item() + assert num_mismatches == 0, ( + f"C={hidden_size}, tokens={num_tokens}: " + f"{num_mismatches}/{out.numel()} mismatches (max_diff={max_diff:.6e})" + ) + + +# ============================================================ +# FP8 output — atol=0.125, rtol=0.125, zero mismatches +# Reference in float32, then cast to FP8 (avoids bf16 double-rounding) +# (matches cuDNN test _run_fp8_rmsnorm_silu_test) +# ============================================================ + + +@pytest.mark.parametrize( + "num_tokens,hidden_size", + ALL_LUT_SHAPES, + ids=[f"t{t}_C{c}" for t, c in ALL_LUT_SHAPES], +) +def test_lut_fp8(num_tokens, hidden_size): + """All 40 LUT shapes for FP8 (E4M3) output.""" + import flashinfer + + torch.manual_seed(42) + x = ( + torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda") * 5.0 + + 5.0 + ) + weight = torch.rand(hidden_size, dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5 + out = torch.empty(num_tokens, hidden_size, dtype=torch.float8_e4m3fn, device="cuda") + + result = 
flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6, out=out) + + ref_f32 = rmsnorm_silu_reference(x, weight, eps=1e-6, output_dtype=torch.float32) + ref_fp8 = ref_f32.clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + + z_float = result.float() + ref_float = ref_fp8.float() + mismatches = ~torch.isclose(z_float, ref_float, atol=0.125, rtol=0.125) + num_mismatches = mismatches.sum().item() + max_diff = (z_float - ref_float).abs().max().item() + assert num_mismatches == 0, ( + f"FP8 C={hidden_size}, tokens={num_tokens}: " + f"{num_mismatches}/{result.numel()} mismatches (max_diff={max_diff:.6e})" + ) + + +# ============================================================ +# NVFP4 output — nibble-level comparison, <=1 ULP allowed +# (matches cuDNN test _run_nvfp4_rmsnorm_silu_test) +# ============================================================ + +has_fp4_dtype = hasattr(torch, "float4_e2m1fn_x2") + + +@pytest.mark.skipif(not has_fp4_dtype, reason="torch.float4_e2m1fn_x2 not available") +@pytest.mark.parametrize( + "num_tokens,hidden_size", + ALL_LUT_SHAPES, + ids=[f"t{t}_C{c}" for t, c in ALL_LUT_SHAPES], +) +def test_lut_nvfp4(num_tokens, hidden_size): + """All 40 LUT shapes for NVFP4 (FP4_E2M1) 1D1X1X block-scale output.""" + import flashinfer + + torch.manual_seed(42) + C = hidden_size + x = torch.randn(num_tokens, C, dtype=torch.bfloat16, device="cuda") * 5.0 + 5.0 + weight = torch.rand(C, dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5 + + # FP4 packs 2 values per byte + out = torch.empty(num_tokens, C // 2, dtype=torch.float4_e2m1fn_x2, device="cuda") + result = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6, out=out) + + ref_f32 = rmsnorm_silu_reference(x, weight, eps=1e-6, output_dtype=torch.float32) + + # Unpack kernel output nibbles + z_packed = result.view(torch.uint8).reshape(num_tokens, C // 2) + kernel_nibbles = _unpack_fp4_nibbles(z_packed, num_tokens, C) + + # Quantize reference using the same FP4 algorithm + ref_nibbles = 
_quantize_to_fp4_reference(ref_f32, C) + + # Allow <=1 nibble index difference (1 FP4 ULP) + nibble_diff = (kernel_nibbles - ref_nibbles).abs() + mismatches = nibble_diff > 1 + num_mismatches = mismatches.sum().item() + max_nibble_diff = nibble_diff.max().item() + assert num_mismatches == 0, ( + f"NVFP4 C={C}, tokens={num_tokens}: " + f"{num_mismatches}/{num_tokens * C} nibbles differ by >{1} ULP " + f"(max_nibble_diff={max_nibble_diff})" + ) + + +# ============================================================ +# Random / non-LUT shapes (fallback heuristics) — bf16 +# ============================================================ + +RANDOM_SHAPES = [ + (1, 64), + (7, 128), + (32, 256), + (100, 512), + (1024, 1024), + (2048, 640), + (4096, 320), + (8192, 160), + (16384, 128), + (500, 64), + (3000, 256), + (50000, 512), + (200000, 1024), +] + + +@pytest.mark.parametrize( + "num_tokens,hidden_size", + RANDOM_SHAPES, + ids=[f"t{t}_C{c}" for t, c in RANDOM_SHAPES], +) +def test_fallback_knobs_bf16(num_tokens, hidden_size): + """Non-LUT shapes that use fallback default knobs.""" + import flashinfer + + torch.manual_seed(42) + x = ( + torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda") * 5.0 + + 5.0 + ) + weight = torch.rand(hidden_size, dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5 + + out = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6) + ref = rmsnorm_silu_reference(x, weight, eps=1e-6) + + mismatches = ~torch.isclose(out.float(), ref.float(), atol=2e-2, rtol=2e-2) + num_mismatches = mismatches.sum().item() + max_diff = (out.float() - ref.float()).abs().max().item() + assert num_mismatches == 0, ( + f"C={hidden_size}, tokens={num_tokens}: " + f"{num_mismatches}/{out.numel()} mismatches (max_diff={max_diff:.6e})" + ) + + +# ============================================================ +# Random / non-LUT shapes — FP8 +# ============================================================ + +RANDOM_SHAPES_FP8 = [ + (32, 256), + (1024, 512), + (2048, 
1024), + (4096, 128), + (8192, 640), +] + + +@pytest.mark.parametrize( + "num_tokens,hidden_size", + RANDOM_SHAPES_FP8, + ids=[f"t{t}_C{c}" for t, c in RANDOM_SHAPES_FP8], +) +def test_fallback_knobs_fp8(num_tokens, hidden_size): + """Non-LUT shapes with FP8 output using fallback knobs.""" + import flashinfer + + torch.manual_seed(42) + x = ( + torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda") * 5.0 + + 5.0 + ) + weight = torch.rand(hidden_size, dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5 + out = torch.empty(num_tokens, hidden_size, dtype=torch.float8_e4m3fn, device="cuda") + + result = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6, out=out) + + ref_f32 = rmsnorm_silu_reference(x, weight, eps=1e-6, output_dtype=torch.float32) + ref_fp8 = ref_f32.clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + + z_float = result.float() + ref_float = ref_fp8.float() + mismatches = ~torch.isclose(z_float, ref_float, atol=0.125, rtol=0.125) + num_mismatches = mismatches.sum().item() + max_diff = (z_float - ref_float).abs().max().item() + assert num_mismatches == 0, ( + f"FP8 C={hidden_size}, tokens={num_tokens}: " + f"{num_mismatches}/{result.numel()} mismatches (max_diff={max_diff:.6e})" + ) + + +# ============================================================ +# Random / non-LUT shapes — NVFP4 +# ============================================================ + +RANDOM_SHAPES_NVFP4 = [ + (32, 256), + (1024, 512), + (2048, 1024), + (4096, 128), +] + + +@pytest.mark.skipif(not has_fp4_dtype, reason="torch.float4_e2m1fn_x2 not available") +@pytest.mark.parametrize( + "num_tokens,hidden_size", + RANDOM_SHAPES_NVFP4, + ids=[f"t{t}_C{c}" for t, c in RANDOM_SHAPES_NVFP4], +) +def test_fallback_knobs_nvfp4(num_tokens, hidden_size): + """Non-LUT shapes with NVFP4 output using fallback knobs.""" + import flashinfer + + torch.manual_seed(42) + C = hidden_size + x = torch.randn(num_tokens, C, dtype=torch.bfloat16, device="cuda") * 5.0 + 5.0 + weight = torch.rand(C, 
dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5 + out = torch.empty(num_tokens, C // 2, dtype=torch.float4_e2m1fn_x2, device="cuda") + + result = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6, out=out) + + ref_f32 = rmsnorm_silu_reference(x, weight, eps=1e-6, output_dtype=torch.float32) + + z_packed = result.view(torch.uint8).reshape(num_tokens, C // 2) + kernel_nibbles = _unpack_fp4_nibbles(z_packed, num_tokens, C) + ref_nibbles = _quantize_to_fp4_reference(ref_f32, C) + + nibble_diff = (kernel_nibbles - ref_nibbles).abs() + mismatches = nibble_diff > 1 + num_mismatches = mismatches.sum().item() + assert num_mismatches == 0, ( + f"NVFP4 C={C}, tokens={num_tokens}: " + f"{num_mismatches}/{num_tokens * C} nibbles differ by >1 ULP" + ) + + +# ============================================================ +# Pre-allocated output +# ============================================================ + + +def test_preallocated_output_bf16(): + import flashinfer + + num_tokens, hidden_size = 1560, 1024 + torch.manual_seed(42) + x = ( + torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda") * 5.0 + + 5.0 + ) + weight = torch.rand(hidden_size, dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5 + out = torch.empty_like(x) + + result = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6, out=out) + ref = rmsnorm_silu_reference(x, weight, eps=1e-6) + + assert result.data_ptr() == out.data_ptr() + mismatches = ~torch.isclose(out.float(), ref.float(), atol=2e-2, rtol=2e-2) + assert mismatches.sum().item() == 0 + + +def test_preallocated_output_fp8(): + import flashinfer + + num_tokens, hidden_size = 1560, 1024 + torch.manual_seed(42) + x = ( + torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda") * 5.0 + + 5.0 + ) + weight = torch.rand(hidden_size, dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5 + out = torch.empty(num_tokens, hidden_size, dtype=torch.float8_e4m3fn, device="cuda") + + result = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6, 
out=out) + ref_f32 = rmsnorm_silu_reference(x, weight, eps=1e-6, output_dtype=torch.float32) + ref_fp8 = ref_f32.clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + + assert result.data_ptr() == out.data_ptr() + mismatches = ~torch.isclose(result.float(), ref_fp8.float(), atol=0.125, rtol=0.125) + assert mismatches.sum().item() == 0 + + +# ============================================================ +# Numerical edge cases +# ============================================================ + + +def test_epsilon_sensitivity(): + import flashinfer + + num_tokens, hidden_size = 6240, 512 + torch.manual_seed(42) + x = ( + torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda") * 5.0 + + 5.0 + ) + weight = torch.rand(hidden_size, dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5 + + for eps in [1e-5, 1e-6, 1e-8]: + out = flashinfer.fused_rmsnorm_silu(x, weight, eps=eps) + ref = rmsnorm_silu_reference(x, weight, eps=eps) + mismatches = ~torch.isclose(out.float(), ref.float(), atol=2e-2, rtol=2e-2) + assert mismatches.sum().item() == 0, ( + f"eps={eps}: {mismatches.sum().item()} mismatches" + ) + + +def test_uniform_weight(): + import flashinfer + + num_tokens, hidden_size = 1560, 256 + torch.manual_seed(42) + x = ( + torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda") * 5.0 + + 5.0 + ) + weight = torch.ones(hidden_size, dtype=torch.bfloat16, device="cuda") + + out = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6) + ref = rmsnorm_silu_reference(x, weight, eps=1e-6) + + mismatches = ~torch.isclose(out.float(), ref.float(), atol=2e-2, rtol=2e-2) + assert mismatches.sum().item() == 0 From 0aede25a681a513f058985705104c500beebd9db Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Thu, 2 Apr 2026 11:54:58 -0700 Subject: [PATCH 02/17] remove cudnn references --- flashinfer/jit/rmsnorm_silu.py | 2 +- flashinfer/norm/__init__.py | 6 ++--- .../flashinfer/norm/ln_fwd_silu_kernel.cuh | 3 +-- include/flashinfer/norm/ln_silu_headers.cuh | 23 
+++++++++---------- .../norm/sm100_rms_norm_silu_knobs.h | 8 +++---- tests/norm/test_fused_rmsnorm_silu.py | 13 +++++------ 6 files changed, 26 insertions(+), 29 deletions(-) diff --git a/flashinfer/jit/rmsnorm_silu.py b/flashinfer/jit/rmsnorm_silu.py index 292873b072..407baedd54 100644 --- a/flashinfer/jit/rmsnorm_silu.py +++ b/flashinfer/jit/rmsnorm_silu.py @@ -192,7 +192,7 @@ def select_knobs(C: int, num_tokens: int, dtype: str): def _estimate_ctas_per_row( C: int, split_cols: int, kernel_cfg: int, bytes_per_ldg: int, warps_n: int = 1 ) -> int: - """Estimate CTAS_PER_ROW from knobs (matches cuDNN's estimate_ctas_per_row).""" + """Estimate CTAS_PER_ROW from knobs.""" if split_cols != 4 or kernel_cfg == 2: return 1 input_size = 2 # bf16 diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index 01fdd34fc9..c7df87f068 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -520,7 +520,7 @@ def _layernorm_fake( # ============================================================ -# Fused RMSNorm + SiLU (ported from cuDNN frontend OSS engine) +# Fused RMSNorm + SiLU kernel (SM100 optimized) # ============================================================ from ..jit.rmsnorm_silu import ( @@ -599,8 +599,8 @@ def fused_rmsnorm_silu( where ``SiLU(x) = x / (1 + exp(-x))`` - This kernel is ported from the cuDNN frontend OSS Sm100RmsNormSiluEngine - and is optimized for WAN VAE decoder workloads on B200. + Optimized for WAN VAE decoder workloads on SM100 (B200). + Uses sweep-tuned knobs for all standard VAE problem sizes. Parameters ---------- diff --git a/include/flashinfer/norm/ln_fwd_silu_kernel.cuh b/include/flashinfer/norm/ln_fwd_silu_kernel.cuh index 2b1030e327..c137d6f996 100644 --- a/include/flashinfer/norm/ln_fwd_silu_kernel.cuh +++ b/include/flashinfer/norm/ln_fwd_silu_kernel.cuh @@ -1,6 +1,5 @@ #pragma once -// Extracted from cudnn_frontend ln_fwd_silu_kernel.h for RmsNorm+SiLU kernel. 
-// Original: cudnn_frontend/include/.../generated/rms_norm_silu/sm100/ln_fwd_silu_kernel.h +// Fused RmsNorm+SiLU forward kernel (SM100 optimized). // // IMPORTANT: Include ln_silu_headers.cuh and the config .inc BEFORE this file. // The config must define Ktraits, DESIRED_OCCUPANCY, and all constexpr flags. diff --git a/include/flashinfer/norm/ln_silu_headers.cuh b/include/flashinfer/norm/ln_silu_headers.cuh index 2172776e04..e71b510e68 100644 --- a/include/flashinfer/norm/ln_silu_headers.cuh +++ b/include/flashinfer/norm/ln_silu_headers.cuh @@ -1,12 +1,12 @@ #pragma once -// Extracted from cudnn_frontend ln_headers.h for RmsNorm+SiLU kernel. -// Original: cudnn_frontend/include/.../generated/rms_norm_silu/sm100/ln_headers.h +// LayerNorm kernel headers for RmsNorm+SiLU (SM100 optimized). +// Kernel_traits, Reducer, Stats, PersistentLnFwdParams, reduced_divisor, etc. #pragma once // Auto-generated lightweight LN kernel header. // Replaces the 151K-line persistent_ln_headers_13.0.h with: // - Standard CUDA TK #include directives (resolved via --include-path at NVRTC compile time) -// - ~2900 lines of cuDNN-authored LN-specific code extracted from the original +// - ~2900 lines of LN-specific code extracted from the original // // CGA/cluster support (USE_CLUSTER) is disabled: // - #include and cooperative_groups are NOT included @@ -116,17 +116,16 @@ typedef unsigned long long uint64_t; #endif // ============================================================ -// cuDNN-authored LN kernel utilities -// Extracted from persistent_ln_headers_13.0.h +// LN kernel utilities // ============================================================ -// cuDNN type aliases for FP8 types (from after inlined cuda_fp8.h) +// Type aliases for FP8 types typedef __nv_fp8_e4m3 nv_fp8_e4m3; typedef __nv_fp8x2_e4m3 nv_fp8x2_e4m3; typedef __nv_fp8_e5m2 nv_fp8_e5m2; typedef __nv_fp8x2_e5m2 nv_fp8x2_e5m2; -// cuDNN type aliases for FP4 types (from after inlined cuda_fp4.h) +// Type aliases for FP4 
types typedef __nv_fp4_e2m1 nv_fp4_e2m1; typedef __nv_fp4x2_e2m1 nv_fp4x2_e2m1; typedef __nv_fp4x4_e2m1 nv_fp4x4_e2m1; @@ -1458,7 +1457,7 @@ struct Kernel_traits : public Base { //////////////////////////////////////////////////////////////////////////////////////////////////// -constexpr float CUDNN_FLT_MIN = +constexpr float LN_FLT_MIN = 1.17549435082228750796873653722224568e-38F; // Minimum positive normalized float value /** * Base class to provide common block scaling utilities, private to `BlockScaleHelper` @@ -1489,8 +1488,8 @@ class BlockScaleHelperBase { static __device__ compute_t scale(const compute_t value) { constexpr compute_t nvfp4_max = 6.0f; compute_t result = value / nvfp4_max; - // Clamp to CUDNN_FLT_MIN to match reference implementation and avoid 0 scaling factors - return fmaxf(result, CUDNN_FLT_MIN); + // Clamp to LN_FLT_MIN to match reference implementation and avoid 0 scaling factors + return fmaxf(result, LN_FLT_MIN); } static __device__ scale_t fp32ToE4M3(float value) { return scale_t(value); } @@ -1538,8 +1537,8 @@ class BlockScaleHelperBase { static __device__ compute_t scale(const compute_t value) { static constexpr float FP8_MAX = std::is_same::value ? 
57344.f : 448.f; compute_t result = value / FP8_MAX; - // Clamp to CUDNN_FLT_MIN to match reference implementation and avoid 0 scaling factors - return fmaxf(result, CUDNN_FLT_MIN); + // Clamp to LN_FLT_MIN to match reference implementation and avoid 0 scaling factors + return fmaxf(result, LN_FLT_MIN); } template diff --git a/include/flashinfer/norm/sm100_rms_norm_silu_knobs.h b/include/flashinfer/norm/sm100_rms_norm_silu_knobs.h index e89f83e72e..2ec1b2ef65 100644 --- a/include/flashinfer/norm/sm100_rms_norm_silu_knobs.h +++ b/include/flashinfer/norm/sm100_rms_norm_silu_knobs.h @@ -6,8 +6,8 @@ #include -namespace cudnn_frontend { -namespace experimental { +namespace flashinfer { +namespace norm { enum class RmsNormSiluDtype : uint8_t { BF16 = 0, @@ -212,5 +212,5 @@ inline const RmsNormSiluKnobs* lookup_rms_norm_silu_knobs(int C, int num_tokens, return nullptr; // fundamentally unsupported (C not divisible by any valid config) } -} // namespace experimental -} // namespace cudnn_frontend +} // namespace norm +} // namespace flashinfer diff --git a/tests/norm/test_fused_rmsnorm_silu.py b/tests/norm/test_fused_rmsnorm_silu.py index 63e81ce1d9..589c7ab20f 100644 --- a/tests/norm/test_fused_rmsnorm_silu.py +++ b/tests/norm/test_fused_rmsnorm_silu.py @@ -13,8 +13,7 @@ # limitations under the License. """ Unit tests for Fused RMSNorm + SiLU kernel. -Tolerance and reference methodology matches the cuDNN frontend OSS test suite -(test_sm100_rms_norm_silu_graph_api.py). +Tests cover bf16, FP8, and NVFP4 output for all 40 LUT shapes plus fallback knobs. """ import pytest @@ -30,7 +29,7 @@ def get_cc(): def rmsnorm_silu_reference(x, weight, eps, output_dtype=None): """Reference: RMSNorm + SiLU. - Matches the cuDNN test reference: compute entirely in float32. + Compute entirely in float32 for maximum reference accuracy. If output_dtype is specified, cast the result to that dtype. 
""" rms = torch.sqrt(torch.mean(x.float() ** 2, dim=-1, keepdim=True) + eps) @@ -75,7 +74,7 @@ def _unpack_fp4_nibbles(packed_bytes, num_tokens, C): def _quantize_to_fp4_reference(values_f32, C): """Quantize float32 values to FP4 E2M1 nibbles matching the kernel's algorithm. - Matches cuDNN BlockScaleRowHelper: + Matches the kernel's block-scale quantization: 1. amax = max(|block of 16 elements|) 2. scale = max(amax / 6.0, FLT_MIN) 3. quantized = nv_fp4x2_e2m1(value / scale) @@ -144,7 +143,7 @@ def skip_if_not_sm100(): # ============================================================ # bf16 output — atol=2e-2, rtol=2e-2, zero mismatches -# (matches cuDNN test _run_rmsnorm_silu_test) +# atol=2e-2, rtol=2e-2, zero mismatches required # ============================================================ @@ -179,7 +178,7 @@ def test_lut_bf16(num_tokens, hidden_size): # ============================================================ # FP8 output — atol=0.125, rtol=0.125, zero mismatches # Reference in float32, then cast to FP8 (avoids bf16 double-rounding) -# (matches cuDNN test _run_fp8_rmsnorm_silu_test) +# atol=0.125, rtol=0.125, zero mismatches; reference in float32 then cast to FP8 # ============================================================ @@ -218,7 +217,7 @@ def test_lut_fp8(num_tokens, hidden_size): # ============================================================ # NVFP4 output — nibble-level comparison, <=1 ULP allowed -# (matches cuDNN test _run_nvfp4_rmsnorm_silu_test) +# Nibble-level comparison, <=1 ULP allowed # ============================================================ has_fp4_dtype = hasattr(torch, "float4_e2m1fn_x2") From 1f425eb2ec6793f8118507be538e1d89112531f9 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Thu, 2 Apr 2026 15:05:59 -0700 Subject: [PATCH 03/17] support checks --- flashinfer/jit/rmsnorm_silu.py | 11 +++++-- flashinfer/norm/__init__.py | 58 ++++++++++++++++++++++++++++++++-- 2 files changed, 63 insertions(+), 6 deletions(-) diff --git 
a/flashinfer/jit/rmsnorm_silu.py b/flashinfer/jit/rmsnorm_silu.py index 407baedd54..ed54ed85be 100644 --- a/flashinfer/jit/rmsnorm_silu.py +++ b/flashinfer/jit/rmsnorm_silu.py @@ -181,10 +181,15 @@ def _compute_default_knobs(C: int, dtype: str): return None -def select_knobs(C: int, num_tokens: int, dtype: str): - """Select knobs from LUT or fallback heuristic. Returns (warps_m, split_cols, kernel_cfg, occupancy, bytes_per_ldg).""" +def select_knobs(C: int, num_tokens: int, dtype: str, sm_version: int = 100): + """Select knobs from LUT or fallback heuristic. + + For parity with the original integration: + - SM100+: use sweep-tuned LUT for known shapes. + - non-SM100 or non-LUT shapes: use conservative fallback heuristic. + """ key = (C, num_tokens, dtype) - if key in _KNOB_LUT: + if sm_version >= 100 and key in _KNOB_LUT: return _KNOB_LUT[key] return _compute_default_knobs(C, dtype) diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index c7df87f068..1d62755a47 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -32,7 +32,12 @@ import torch from ..api_logging import flashinfer_api -from ..utils import device_support_pdl, register_custom_op, register_fake_op +from ..utils import ( + device_support_pdl, + get_compute_capability, + register_custom_op, + register_fake_op, +) # Always import gen_norm_module for JIT warmup and CUDA fallback from ..jit.norm import gen_norm_module @@ -583,7 +588,10 @@ def _torch_dtype_to_str(dtype): return "fp8" elif hasattr(torch, "float4_e2m1fn_x2") and dtype == torch.float4_e2m1fn_x2: return "nvfp4" - return "bf16" + raise ValueError( + "Unsupported output dtype for fused_rmsnorm_silu: " + f"{dtype}. Supported dtypes: bfloat16, float8_e4m3fn, float4_e2m1fn_x2" + ) @flashinfer_api @@ -618,14 +626,58 @@ def fused_rmsnorm_silu( output: torch.Tensor Normalized + SiLU activated tensor, shape ``(num_tokens, hidden_size)``. 
""" + if input.device.type != "cuda": + raise ValueError("fused_rmsnorm_silu requires CUDA tensors") + if input.dtype != torch.bfloat16: + raise ValueError(f"input must be torch.bfloat16, got {input.dtype}") + if weight.dtype != torch.bfloat16: + raise ValueError(f"weight must be torch.bfloat16, got {weight.dtype}") + if input.ndim != 2: + raise ValueError( + f"input must be 2D [num_tokens, hidden_size], got ndim={input.ndim}" + ) + if weight.ndim != 1: + raise ValueError(f"weight must be 1D [hidden_size], got ndim={weight.ndim}") + if weight.device != input.device: + raise ValueError("weight must be on the same device as input") + if out is None: out = torch.empty_like(input) + if out.device != input.device: + raise ValueError("out must be on the same device as input") num_tokens = input.size(0) C = input.size(1) + if weight.size(0) != C: + raise ValueError( + f"weight shape mismatch: expected [{C}], got {tuple(weight.shape)}" + ) output_dtype_str = _torch_dtype_to_str(out.dtype) - knobs = select_knobs(C, num_tokens, output_dtype_str) + if output_dtype_str in ("bf16", "fp8"): + if tuple(out.shape) != tuple(input.shape): + raise ValueError( + f"out shape mismatch for {output_dtype_str}: expected {tuple(input.shape)}, got {tuple(out.shape)}" + ) + elif output_dtype_str == "nvfp4": + expected_shape = (num_tokens, C // 2) + if C % 2 != 0: + raise ValueError(f"nvfp4 output requires even hidden size, got C={C}") + if tuple(out.shape) != expected_shape: + raise ValueError( + f"out shape mismatch for nvfp4: expected {expected_shape}, got {tuple(out.shape)}" + ) + + major, minor = get_compute_capability(input.device) + sm_version = major * 10 + minor + if sm_version < 80: + raise RuntimeError("fused_rmsnorm_silu requires SM80+") + if output_dtype_str == "fp8" and sm_version < 89: + raise RuntimeError("FP8 output requires SM89+ (Ada/Hopper)") + if output_dtype_str == "nvfp4" and sm_version < 100: + raise RuntimeError("NVFP4 output requires SM100+ (Blackwell)") + + knobs = 
select_knobs(C, num_tokens, output_dtype_str, sm_version) if knobs is None: raise ValueError( f"Unsupported problem size for fused_rmsnorm_silu: " From 5e192b09309077d4c4c94984c4e7e604adf35ce8 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Thu, 2 Apr 2026 15:44:50 -0700 Subject: [PATCH 04/17] Remove unused C++ knob LUT; Python LUT is the sole source of truth MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The C++ header sm100_rms_norm_silu_knobs.h was never included by any source file — all knob selection happens in Python at JIT compile time via flashinfer/jit/rmsnorm_silu.py. Keeping a duplicate 120-entry LUT in C++ was a maintenance burden with no benefit. AI-assisted. Made-with: Cursor --- flashinfer/jit/rmsnorm_silu.py | 2 +- .../norm/sm100_rms_norm_silu_knobs.h | 216 ------------------ 2 files changed, 1 insertion(+), 217 deletions(-) delete mode 100644 include/flashinfer/norm/sm100_rms_norm_silu_knobs.h diff --git a/flashinfer/jit/rmsnorm_silu.py b/flashinfer/jit/rmsnorm_silu.py index ed54ed85be..8225db670e 100644 --- a/flashinfer/jit/rmsnorm_silu.py +++ b/flashinfer/jit/rmsnorm_silu.py @@ -21,7 +21,7 @@ from .utils import write_if_different -# Knob LUT ported from sm100_rms_norm_silu_knobs.h +# Sweep-tuned knob LUT for SM100 (B200) VAE problem sizes. # Format: (warps_m, split_cols, kernel_cfg, occupancy, bytes_per_ldg) _KNOB_LUT = { # C=64 diff --git a/include/flashinfer/norm/sm100_rms_norm_silu_knobs.h b/include/flashinfer/norm/sm100_rms_norm_silu_knobs.h deleted file mode 100644 index 2ec1b2ef65..0000000000 --- a/include/flashinfer/norm/sm100_rms_norm_silu_knobs.h +++ /dev/null @@ -1,216 +0,0 @@ -#pragma once - -// Auto-generated knob selection logic for Sm100RmsNormSiluEngine. -// Generated from optimal knob sweep results on B200 (SM100). 
-// Knob mapping: knobTileRows {0:1,1:4,2:8,3:32}, knobLoadSize {0:2,1:4,2:8,3:16} - -#include - -namespace flashinfer { -namespace norm { - -enum class RmsNormSiluDtype : uint8_t { - BF16 = 0, - FP8 = 1, - NVFP4 = 2, -}; - -// Compact knob configuration per (C, tokens, dtype) entry. -// WARPS_N is always 1. All values are the ACTUAL kernel parameters -// (not knob indices). -struct RmsNormSiluKnobs { - uint8_t warps_m; // WARPS_M value: 1, 4, 8, or 32 - uint8_t split_cols; // knobSplitCols: 0 = no split, 4 = estimated CTAS_PER_ROW - uint8_t kernel_cfg; // knobKernelCfg: 0, 1, or 2 - uint8_t occupancy; // DESIRED_OCCUPANCY: 0-16 - uint8_t bytes_per_ldg; // BYTES_PER_LDG: 2, 4, 8, or 16 -}; - -static constexpr int kSupportedC[] = {64, 128, 160, 256, 320, 512, 640, 1024}; -static constexpr int kSupportedTokens[] = {1560, 6240, 24960, 99840, 399360}; -static constexpr int kNumC = 8; -static constexpr int kNumTokens = 5; -static constexpr int kNumDtypes = 3; - -// Knob LUT indexed as: knob_lut[c_idx][tokens_idx][dtype_idx] -// c_idx: 0=64, 1=128, 2=160, 3=256, 4=320, 5=512, 6=640, 7=1024 -// tokens_idx: 0=1560, 1=6240, 2=24960, 3=99840, 4=399360 -// dtype_idx: 0=bf16, 1=fp8, 2=nvfp4 -static constexpr RmsNormSiluKnobs knob_lut[kNumC][kNumTokens][kNumDtypes] = { - { - // C=64 - {{8, 0, 0, 2, 4}, {8, 4, 0, 6, 4}, {8, 0, 2, 1, 4}}, // tokens=1560 - {{32, 4, 0, 2, 4}, {8, 0, 0, 3, 2}, {8, 4, 0, 4, 4}}, // tokens=6240 - {{32, 4, 0, 2, 4}, {8, 0, 0, 7, 4}, {8, 0, 1, 6, 4}}, // tokens=24960 - {{8, 0, 1, 8, 4}, {8, 0, 1, 6, 2}, {32, 4, 1, 2, 4}}, // tokens=99840 - {{4, 0, 1, 16, 4}, {32, 0, 1, 2, 2}, {32, 4, 1, 2, 4}}, // tokens=399360 - }, - { - // C=128 - {{8, 4, 0, 3, 4}, {8, 0, 0, 3, 4}, {8, 4, 0, 3, 8}}, // tokens=1560 - {{8, 0, 0, 3, 4}, {8, 0, 0, 4, 8}, {8, 0, 0, 5, 8}}, // tokens=6240 - {{8, 0, 0, 6, 4}, {8, 0, 0, 8, 8}, {8, 0, 1, 8, 8}}, // tokens=24960 - {{32, 4, 0, 2, 4}, {32, 0, 0, 2, 8}, {32, 0, 1, 2, 8}}, // tokens=99840 - {{8, 0, 0, 8, 4}, {32, 0, 0, 2, 8}, {32, 
0, 1, 2, 8}}, // tokens=399360 - }, - { - // C=160 - {{8, 0, 0, 4, 2}, {8, 0, 0, 2, 2}, {4, 4, 0, 4, 2}}, // tokens=1560 - {{8, 0, 0, 4, 2}, {8, 0, 0, 4, 2}, {8, 0, 1, 4, 2}}, // tokens=6240 - {{8, 4, 1, 6, 2}, {8, 4, 0, 6, 2}, {8, 4, 1, 8, 2}}, // tokens=24960 - {{32, 4, 1, 2, 2}, {32, 4, 1, 2, 2}, {32, 4, 0, 1, 2}}, // tokens=99840 - {{32, 4, 1, 2, 2}, {32, 4, 1, 2, 2}, {32, 0, 1, 2, 2}}, // tokens=399360 - }, - { - // C=256 - {{8, 0, 0, 6, 16}, {8, 4, 0, 2, 4}, {8, 0, 2, 1, 16}}, // tokens=1560 - {{8, 0, 0, 4, 4}, {8, 0, 0, 4, 4}, {8, 0, 2, 1, 16}}, // tokens=6240 - {{8, 0, 0, 8, 16}, {8, 4, 0, 8, 16}, {8, 4, 1, 6, 16}}, // tokens=24960 - {{4, 4, 0, 16, 16}, {4, 0, 0, 16, 16}, {32, 0, 1, 1, 16}}, // tokens=99840 - {{4, 0, 0, 16, 16}, {32, 0, 0, 2, 16}, {32, 0, 1, 2, 16}}, // tokens=399360 - }, - { - // C=320 - {{8, 4, 1, 4, 4}, {8, 0, 0, 2, 4}, {4, 4, 0, 9, 4}}, // tokens=1560 - {{8, 4, 0, 5, 4}, {8, 0, 0, 5, 4}, {4, 0, 0, 9, 4}}, // tokens=6240 - {{8, 0, 0, 5, 4}, {8, 0, 0, 5, 4}, {8, 0, 1, 8, 4}}, // tokens=24960 - {{4, 0, 1, 16, 4}, {32, 0, 1, 2, 4}, {32, 4, 1, 2, 4}}, // tokens=99840 - {{32, 4, 0, 2, 4}, {32, 0, 1, 2, 4}, {32, 4, 1, 2, 4}}, // tokens=399360 - }, - { - // C=512 - {{8, 0, 0, 2, 16}, {8, 0, 0, 2, 8}, {4, 4, 0, 3, 16}}, // tokens=1560 - {{8, 0, 0, 5, 16}, {8, 0, 0, 4, 8}, {4, 0, 0, 9, 16}}, // tokens=6240 - {{4, 0, 0, 8, 16}, {4, 0, 0, 9, 8}, {4, 0, 2, 1, 16}}, // tokens=24960 - {{4, 0, 2, 1, 8}, {32, 4, 1, 2, 8}, {32, 4, 0, 1, 16}}, // tokens=99840 - {{4, 0, 2, 1, 4}, {32, 4, 1, 2, 8}, {32, 0, 0, 1, 16}}, // tokens=399360 - }, - { - // C=640 - {{4, 0, 0, 4, 4}, {4, 0, 0, 3, 8}, {4, 4, 0, 5, 8}}, // tokens=1560 - {{4, 0, 0, 5, 4}, {8, 0, 0, 4, 8}, {4, 0, 1, 9, 8}}, // tokens=6240 - {{4, 0, 0, 5, 4}, {8, 0, 0, 4, 8}, {4, 0, 2, 1, 8}}, // tokens=24960 - {{4, 0, 2, 1, 8}, {4, 4, 0, 9, 8}, {32, 0, 1, 1, 8}}, // tokens=99840 - {{4, 0, 2, 1, 8}, {32, 4, 1, 2, 8}, {32, 4, 1, 1, 8}}, // tokens=399360 - }, - { - // C=1024 - {{4, 4, 0, 3, 16}, {4, 0, 0, 
3, 4}, {4, 4, 0, 7, 16}}, // tokens=1560 - {{4, 0, 0, 5, 16}, {4, 0, 0, 5, 8}, {4, 0, 2, 1, 16}}, // tokens=6240 - {{4, 4, 1, 10, 16}, {1, 4, 0, 16, 8}, {4, 0, 2, 1, 16}}, // tokens=24960 - {{8, 0, 2, 1, 16}, {4, 0, 1, 9, 8}, {32, 0, 1, 1, 16}}, // tokens=99840 - {{8, 0, 2, 1, 16}, {32, 4, 1, 1, 8}, {32, 4, 1, 1, 16}}, // tokens=399360 - }, -}; - -// Compute conservative default knobs for arbitrary problem sizes not in the LUT. -// Uses safe defaults (WARPS_M=1, BPL=4, occupancy=1) and validates vectorization -// divisibility constraints before accepting a configuration. -// Returns true if a valid configuration was found, false otherwise. -inline bool compute_default_knobs(int C, int num_tokens, RmsNormSiluDtype dtype, - RmsNormSiluKnobs& out) { - // Conservative defaults: - // CTAS_PER_ROW = 1, WARPS_M = 1, WARPS_N = 1, BPL = 4, occupancy = 1, kernel_cfg = 0 - // For block-scale output (NVFP4): WARPS_M = 32 - - int input_size = 2; // bf16 input always - - // Start with conservative defaults - int warps_m = (dtype == RmsNormSiluDtype::NVFP4) ? 32 : 1; - int warps_n = 1; // always 1 for our engine - int bpl = 4; // default bytes per load - int cpr = 1; // no column splitting for fallback - int occ = 1; - int kcfg = 0; - - // Validation: C must be evenly divisible into vectorized loads. 
- // NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) - // VEC_COLS = C / NUM_ELTS - // VEC_COLS_PER_LDG = CTAS_PER_ROW * WARPS_N * 32 - // Require: C % NUM_ELTS == 0 AND VEC_COLS % VEC_COLS_PER_LDG == 0 - // Also: LDGS = VEC_COLS / VEC_COLS_PER_LDG <= 1024 (avoid register spill) - - auto validate = [&](int test_bpl, int test_wm) -> bool { - int num_elts = test_bpl / input_size; - if (num_elts <= 0 || C % num_elts != 0) return false; - int vec_cols = C / num_elts; - int vec_cols_per_ldg = cpr * warps_n * 32; - if (vec_cols_per_ldg <= 0 || vec_cols % vec_cols_per_ldg != 0) return false; - int ldgs = vec_cols / vec_cols_per_ldg; - if (ldgs > 1024) return false; // reject extreme LDGS to avoid register spilling - // Check WARPS_M constraint: if WARPS_M > 1, rows per CTA must divide evenly - if (test_wm > 1 && num_tokens % test_wm != 0) return false; - return true; - }; - - // Try default BPL=4, then cascade through {4, 8, 16, 2} - static constexpr int bpl_candidates[] = {4, 8, 16, 2}; - bool found = false; - for (int candidate : bpl_candidates) { - if (validate(candidate, warps_m)) { - bpl = candidate; - found = true; - break; - } - } - - // If WARPS_M=1 failed, try bumping to WARPS_M=4 for better row coverage - if (!found && warps_m == 1 && num_tokens % 4 == 0) { - warps_m = 4; - for (int candidate : bpl_candidates) { - if (validate(candidate, warps_m)) { - bpl = candidate; - found = true; - break; - } - } - } - - if (!found) return false; - - out.warps_m = static_cast(warps_m); - out.split_cols = 0; // no column splitting - out.kernel_cfg = static_cast(kcfg); - out.occupancy = static_cast(occ); - out.bytes_per_ldg = static_cast(bpl); - return true; -} - -// Look up knob configuration for a given (C, num_tokens, output_dtype, sm_version). -// Tier 1: exact LUT match for SM100 VAE problem sizes (optimal, sweep-tuned on B200). -// Tier 2: fallback heuristic for other archs or arbitrary sizes (functional, conservative). 
-// Returns nullptr only if the problem is fundamentally unsupported. -inline const RmsNormSiluKnobs* lookup_rms_norm_silu_knobs(int C, int num_tokens, - RmsNormSiluDtype dtype, - int sm_version = 100) { - // Tier 1: exact LUT match — only valid for SM100 (swept on B200) - if (sm_version >= 100) { - int c_idx = -1, t_idx = -1; - for (int i = 0; i < kNumC; ++i) { - if (kSupportedC[i] == C) { - c_idx = i; - break; - } - } - for (int i = 0; i < kNumTokens; ++i) { - if (kSupportedTokens[i] == num_tokens) { - t_idx = i; - break; - } - } - if (c_idx >= 0 && t_idx >= 0) { - return &knob_lut[c_idx][t_idx][static_cast(dtype)]; - } - } - - // Tier 2: fallback heuristic for non-SM100 archs or non-LUT problem sizes - static thread_local RmsNormSiluKnobs fallback; - if (compute_default_knobs(C, num_tokens, dtype, fallback)) { - return &fallback; - } - - return nullptr; // fundamentally unsupported (C not divisible by any valid config) -} - -} // namespace norm -} // namespace flashinfer From 44829ed004dad6680010226c1f2ec4e1c563912f Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Thu, 2 Apr 2026 16:31:21 -0700 Subject: [PATCH 05/17] address gemini-code-assist comment --- flashinfer/norm/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index 1d62755a47..699a0046e3 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -536,9 +536,9 @@ def _layernorm_fake( @functools.cache -def _get_rmsnorm_silu_sm_count(): - """Cache the SM count for the current device.""" - props = torch.cuda.get_device_properties(torch.cuda.current_device()) +def _get_rmsnorm_silu_sm_count(device_id: int): + """Cache the SM count per device.""" + props = torch.cuda.get_device_properties(device_id) return props.multi_processor_count @@ -686,7 +686,7 @@ def fused_rmsnorm_silu( warps_m, split_cols, kernel_cfg, occupancy, bytes_per_ldg = knobs ctas_per_row = _estimate_ctas_per_row(C, split_cols, 
kernel_cfg, bytes_per_ldg) - sm_count = _get_rmsnorm_silu_sm_count() + sm_count = _get_rmsnorm_silu_sm_count(input.device.index) module = _get_rmsnorm_silu_module( C, output_dtype_str, warps_m, ctas_per_row, bytes_per_ldg, kernel_cfg, occupancy From f41322a97d629d19e5b6047f79d943c1d20d25fa Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Thu, 2 Apr 2026 16:55:12 -0700 Subject: [PATCH 06/17] fix and clean up --- csrc/flashinfer_rmsnorm_silu_binding.cu | 2 +- csrc/rmsnorm_silu.cu | 2 +- flashinfer/jit/rmsnorm_silu.py | 2 +- tests/norm/test_fused_rmsnorm_silu.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/flashinfer_rmsnorm_silu_binding.cu b/csrc/flashinfer_rmsnorm_silu_binding.cu index 4f3fea6c76..d2301ad733 100644 --- a/csrc/flashinfer_rmsnorm_silu_binding.cu +++ b/csrc/flashinfer_rmsnorm_silu_binding.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025 by FlashInfer team. + * Copyright (c) 2026 by FlashInfer team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/csrc/rmsnorm_silu.cu b/csrc/rmsnorm_silu.cu index 5629de80aa..ca34759f2d 100644 --- a/csrc/rmsnorm_silu.cu +++ b/csrc/rmsnorm_silu.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025 by FlashInfer team. + * Copyright (c) 2026 by FlashInfer team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/flashinfer/jit/rmsnorm_silu.py b/flashinfer/jit/rmsnorm_silu.py index 8225db670e..1c680f4f3c 100644 --- a/flashinfer/jit/rmsnorm_silu.py +++ b/flashinfer/jit/rmsnorm_silu.py @@ -1,5 +1,5 @@ """ -Copyright (c) 2025 by FlashInfer team. +Copyright (c) 2026 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/tests/norm/test_fused_rmsnorm_silu.py b/tests/norm/test_fused_rmsnorm_silu.py index 589c7ab20f..a7b03c7f6b 100644 --- a/tests/norm/test_fused_rmsnorm_silu.py +++ b/tests/norm/test_fused_rmsnorm_silu.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 by FlashInfer team. +# Copyright (c) 2026 by FlashInfer team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 6fb1746d81eb48dc4770dec9aa18af1e8444e4d5 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Thu, 2 Apr 2026 17:48:14 -0700 Subject: [PATCH 07/17] Fix include order in rmsnorm_silu.cu to match header dependencies ln_fwd_silu_kernel.cuh requires Ktraits, PersistentLnFwdParams, and other types to be defined before inclusion. The correct order is: 1. ln_silu_headers.cuh (type definitions) 2. rmsnorm_silu_config.inc (Ktraits typedef, constexpr flags) 3. ln_fwd_silu_kernel.cuh (kernel using the above) Protected with clang-format off/on since alphabetical sorting would break this dependency chain. AI-assisted. Made-with: Cursor --- csrc/rmsnorm_silu.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/rmsnorm_silu.cu b/csrc/rmsnorm_silu.cu index ca34759f2d..6204ffb679 100644 --- a/csrc/rmsnorm_silu.cu +++ b/csrc/rmsnorm_silu.cu @@ -14,12 +14,14 @@ * limitations under the License. 
*/ +// clang-format off // Include order matters: headers → config (defines Ktraits) → kernel (uses Ktraits) #include -#include #include - #include "rmsnorm_silu_config.inc" +#include +// clang-format on + #include "tvm_ffi_utils.h" void rmsnorm_silu(TensorView output, TensorView input, TensorView weight, double eps, From 56097582d7ab02575943dfe4caa2dfaa489dee56 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Thu, 2 Apr 2026 17:58:00 -0700 Subject: [PATCH 08/17] add fallback logic (if misses LUT) to aot precompile for better dynamic shape support --- flashinfer/aot.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/flashinfer/aot.py b/flashinfer/aot.py index e83997876f..4d0aa0924e 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -91,6 +91,7 @@ gen_rmsnorm_silu_module, select_knobs, _estimate_ctas_per_row, + _compute_default_knobs, _SUPPORTED_C, _SUPPORTED_TOKENS, ) @@ -578,6 +579,31 @@ def gen_all_modules( jit_specs.append( gen_rmsnorm_silu_module(C, dtype, wm, cpr, bpl, kcfg, occ) ) + # Fallback configs for common hidden sizes not in the LUT. + # Fallback knobs depend only on (C, dtype), not num_tokens, + # so one module per (C, dtype) covers all token counts. 
+ _FALLBACK_C = [ + 768, + 1280, + 1536, + 2048, + 2560, + 3072, + 4096, + 5120, + 6144, + 8192, + ] + for C in _FALLBACK_C: + for dtype in ["bf16", "fp8", "nvfp4"]: + knobs = _compute_default_knobs(C, dtype) + if knobs is None: + continue + wm, sc, kcfg, occ, bpl = knobs + cpr = _estimate_ctas_per_row(C, sc, kcfg, bpl) + jit_specs.append( + gen_rmsnorm_silu_module(C, dtype, wm, cpr, bpl, kcfg, occ) + ) # selective_state_update: one module per dtype combo per GPU arch _ssu_dtype_combos = [ # (state, input, weight, matrixA, stateIndex, state_scale_dtype) From 3611deaa64480a31680dda81b1dc94abed885b9d Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Fri, 3 Apr 2026 13:12:26 -0700 Subject: [PATCH 09/17] nvfp4 return Union of y and block_scale --- flashinfer/norm/__init__.py | 31 ++++++++++++++++++++++++--- tests/norm/test_fused_rmsnorm_silu.py | 11 ++++++++-- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index 699a0046e3..497fbd818c 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -600,7 +600,7 @@ def fused_rmsnorm_silu( weight: torch.Tensor, eps: float = 1e-6, out: Optional[torch.Tensor] = None, -) -> torch.Tensor: +) -> Union[torch.Tensor, tuple]: r"""Fused RMSNorm + SiLU activation. ``out[i] = SiLU(RMSNorm(input[i], weight, eps))`` @@ -620,11 +620,20 @@ def fused_rmsnorm_silu( Epsilon for numerical stability. out: Optional[torch.Tensor] Output tensor. If None, allocated as same shape/dtype as input. + For NVFP4 output (``torch.float4_e2m1fn_x2``), shape must be + ``(num_tokens, hidden_size // 2)``. Returns ------- - output: torch.Tensor - Normalized + SiLU activated tensor, shape ``(num_tokens, hidden_size)``. + output: torch.Tensor or Tuple[torch.Tensor, torch.Tensor] + For bf16/fp8: normalized + SiLU activated tensor, + shape ``(num_tokens, hidden_size)``. + + For NVFP4: a tuple ``(y_fp4, block_scale)`` following the same + convention as :func:`rmsnorm_fp4quant`. 
``y_fp4`` has shape + ``(num_tokens, hidden_size // 2)`` with dtype ``float4_e2m1fn_x2``, + and ``block_scale`` has shape ``(num_tokens, hidden_size // 16)`` + with dtype ``float8_e4m3fn`` (one E4M3 scale per 16-element block). """ if input.device.type != "cuda": raise ValueError("fused_rmsnorm_silu requires CUDA tensors") @@ -705,6 +714,22 @@ def fused_rmsnorm_silu( workspace = torch.empty(ws_size, dtype=torch.uint8, device=input.device) module.rmsnorm_silu(out, input, weight, eps, workspace, sm_count) + + if output_dtype_str == "nvfp4": + # Extract block_scale from workspace (matches C++ layout in rmsnorm_silu.cu). + # Layout: [rs: rows*4, align128] [fp8_scale: 4, align128] [scale_row: ...] + scale_row_offset = num_tokens * 4 # rs + scale_row_offset = ((scale_row_offset + 127) // 128) * 128 + scale_row_offset += 4 # fp8_scale + scale_row_offset = ((scale_row_offset + 127) // 128) * 128 + num_blocks = (C + 15) // 16 + scale_row_bytes = num_tokens * num_blocks + block_scale = workspace[scale_row_offset : scale_row_offset + scale_row_bytes] + block_scale = block_scale.view(torch.float8_e4m3fn).reshape( + num_tokens, num_blocks + ) + return out, block_scale + return out diff --git a/tests/norm/test_fused_rmsnorm_silu.py b/tests/norm/test_fused_rmsnorm_silu.py index a7b03c7f6b..ebb93a3fb9 100644 --- a/tests/norm/test_fused_rmsnorm_silu.py +++ b/tests/norm/test_fused_rmsnorm_silu.py @@ -240,7 +240,11 @@ def test_lut_nvfp4(num_tokens, hidden_size): # FP4 packs 2 values per byte out = torch.empty(num_tokens, C // 2, dtype=torch.float4_e2m1fn_x2, device="cuda") - result = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6, out=out) + result, block_scale = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6, out=out) + + assert result.data_ptr() == out.data_ptr() + assert block_scale.shape == (num_tokens, C // 16) + assert block_scale.dtype == torch.float8_e4m3fn ref_f32 = rmsnorm_silu_reference(x, weight, eps=1e-6, output_dtype=torch.float32) @@ -386,7 +390,10 @@ def 
test_fallback_knobs_nvfp4(num_tokens, hidden_size): weight = torch.rand(C, dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5 out = torch.empty(num_tokens, C // 2, dtype=torch.float4_e2m1fn_x2, device="cuda") - result = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6, out=out) + result, block_scale = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6, out=out) + + assert block_scale.shape == (num_tokens, C // 16) + assert block_scale.dtype == torch.float8_e4m3fn ref_f32 = rmsnorm_silu_reference(x, weight, eps=1e-6, output_dtype=torch.float32) From 837768ca6db0f7eacc7ae2476eed020ca52cfaeb Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Fri, 3 Apr 2026 13:25:27 -0700 Subject: [PATCH 10/17] address https://github.com/flashinfer-ai/flashinfer/pull/2965#discussion_r3031072384 --- include/flashinfer/norm/ln_silu_headers.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flashinfer/norm/ln_silu_headers.cuh b/include/flashinfer/norm/ln_silu_headers.cuh index e71b510e68..a26a322cb9 100644 --- a/include/flashinfer/norm/ln_silu_headers.cuh +++ b/include/flashinfer/norm/ln_silu_headers.cuh @@ -1282,7 +1282,7 @@ struct Stats { // Count leading zeros - start from most significant bit. 
static int clz(int32_t x) { for (int32_t i = 31; i >= 0; --i) - if (((1 << i) & x) != 0) { + if (((1u << i) & x) != 0) { return 31 - i; } return 32; From 23bc9080e8bc947acd36a6472dbb41189f99a3fb Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Fri, 3 Apr 2026 13:48:48 -0700 Subject: [PATCH 11/17] address https://github.com/flashinfer-ai/flashinfer/pull/2965#discussion_r3031072384 --- flashinfer/norm/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index 497fbd818c..a351ad6ee0 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -670,8 +670,10 @@ def fused_rmsnorm_silu( ) elif output_dtype_str == "nvfp4": expected_shape = (num_tokens, C // 2) - if C % 2 != 0: - raise ValueError(f"nvfp4 output requires even hidden size, got C={C}") + if C % 16 != 0: + raise ValueError( + f"nvfp4 output requires hidden_size divisible by 16, got C={C}" + ) if tuple(out.shape) != expected_shape: raise ValueError( f"out shape mismatch for nvfp4: expected {expected_shape}, got {tuple(out.shape)}" @@ -722,7 +724,7 @@ def fused_rmsnorm_silu( scale_row_offset = ((scale_row_offset + 127) // 128) * 128 scale_row_offset += 4 # fp8_scale scale_row_offset = ((scale_row_offset + 127) // 128) * 128 - num_blocks = (C + 15) // 16 + num_blocks = C // 16 scale_row_bytes = num_tokens * num_blocks block_scale = workspace[scale_row_offset : scale_row_offset + scale_row_bytes] block_scale = block_scale.view(torch.float8_e4m3fn).reshape( From 92a2edd2d6422ffd0dd71ad515aeca3cde21770e Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Mon, 6 Apr 2026 09:36:52 -0700 Subject: [PATCH 12/17] changes --- csrc/flashinfer_rmsnorm_silu_binding.cu | 2 +- csrc/rmsnorm_silu.cu | 8 ++-- flashinfer/norm/__init__.py | 24 ++++-------- tests/norm/test_fused_rmsnorm_silu.py | 50 +++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 22 deletions(-) diff --git a/csrc/flashinfer_rmsnorm_silu_binding.cu 
b/csrc/flashinfer_rmsnorm_silu_binding.cu index d2301ad733..1542b2b697 100644 --- a/csrc/flashinfer_rmsnorm_silu_binding.cu +++ b/csrc/flashinfer_rmsnorm_silu_binding.cu @@ -16,6 +16,6 @@ #include "tvm_ffi_utils.h" void rmsnorm_silu(TensorView output, TensorView input, TensorView weight, double eps, - TensorView workspace, int64_t sm_count); + TensorView workspace, TensorView scale_row_out, int64_t sm_count); TVM_FFI_DLL_EXPORT_TYPED_FUNC(rmsnorm_silu, rmsnorm_silu); diff --git a/csrc/rmsnorm_silu.cu b/csrc/rmsnorm_silu.cu index 6204ffb679..f0a79f2e5e 100644 --- a/csrc/rmsnorm_silu.cu +++ b/csrc/rmsnorm_silu.cu @@ -25,7 +25,7 @@ #include "tvm_ffi_utils.h" void rmsnorm_silu(TensorView output, TensorView input, TensorView weight, double eps, - TensorView workspace, int64_t sm_count) { + TensorView workspace, TensorView scale_row_out, int64_t sm_count) { CHECK_LAST_DIM_CONTIGUOUS_INPUT(input); CHECK_LAST_DIM_CONTIGUOUS_INPUT(output); CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight); @@ -92,11 +92,9 @@ void rmsnorm_silu(TensorView output, TensorView input, TensorView weight, double off += sizeof(float); off = ((off + 127) / 128) * 128; - // [aligned] scale_row: rows * ceil(C/16) bytes (NVFP4 only) + // scale_row: passed as separate output tensor (NVFP4 only) if (isFP4Out) { - params.scale_row = ws_ptr + off; - off += static_cast(rows) * ((cols + 15) / 16); - off = ((off + 127) / 128) * 128; + params.scale_row = scale_row_out.data_ptr(); } // [aligned] cooperative workspace + barriers (multi-CTA only) diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index a351ad6ee0..4fe8a3b1c4 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -561,10 +561,6 @@ def _compute_rmsnorm_silu_workspace_size( # fp8_scale ws += 4 ws = ((ws + 127) // 128) * 128 - # scale_row (NVFP4 only) - if output_dtype == "nvfp4": - ws += rows * ((cols + 15) // 16) - ws = ((ws + 127) // 128) * 128 # cooperative workspace (multi-CTA) if ctas_per_row > 1: 
ctas_per_col_max = (rows + warps_m - 1) // warps_m @@ -715,21 +711,17 @@ def fused_rmsnorm_silu( ) workspace = torch.empty(ws_size, dtype=torch.uint8, device=input.device) - module.rmsnorm_silu(out, input, weight, eps, workspace, sm_count) - if output_dtype_str == "nvfp4": - # Extract block_scale from workspace (matches C++ layout in rmsnorm_silu.cu). - # Layout: [rs: rows*4, align128] [fp8_scale: 4, align128] [scale_row: ...] - scale_row_offset = num_tokens * 4 # rs - scale_row_offset = ((scale_row_offset + 127) // 128) * 128 - scale_row_offset += 4 # fp8_scale - scale_row_offset = ((scale_row_offset + 127) // 128) * 128 num_blocks = C // 16 - scale_row_bytes = num_tokens * num_blocks - block_scale = workspace[scale_row_offset : scale_row_offset + scale_row_bytes] - block_scale = block_scale.view(torch.float8_e4m3fn).reshape( - num_tokens, num_blocks + block_scale = torch.empty( + num_tokens, num_blocks, dtype=torch.float8_e4m3fn, device=input.device ) + else: + block_scale = torch.empty(0, dtype=torch.uint8, device=input.device) + + module.rmsnorm_silu(out, input, weight, eps, workspace, block_scale, sm_count) + + if output_dtype_str == "nvfp4": return out, block_scale return out diff --git a/tests/norm/test_fused_rmsnorm_silu.py b/tests/norm/test_fused_rmsnorm_silu.py index ebb93a3fb9..b33b1f6b24 100644 --- a/tests/norm/test_fused_rmsnorm_silu.py +++ b/tests/norm/test_fused_rmsnorm_silu.py @@ -497,3 +497,53 @@ def test_uniform_weight(): mismatches = ~torch.isclose(out.float(), ref.float(), atol=2e-2, rtol=2e-2) assert mismatches.sum().item() == 0 + + +# ============================================================ +# NVFP4 round-trip dequantization (verifies block_scale is usable) +# ============================================================ + +ROUNDTRIP_SHAPES = [ + (1560, 256), + (6240, 512), + (24960, 1024), +] + + +@pytest.mark.skipif(not has_fp4_dtype, reason="torch.float4_e2m1fn_x2 not available") +@pytest.mark.parametrize( + "num_tokens,hidden_size", + 
ROUNDTRIP_SHAPES, + ids=[f"t{t}_C{c}" for t, c in ROUNDTRIP_SHAPES], +) +def test_nvfp4_roundtrip_dequantize(num_tokens, hidden_size): + """Verify that (y_fp4, block_scale) can round-trip back to float via dequantization.""" + import flashinfer + + torch.manual_seed(42) + C = hidden_size + x = torch.randn(num_tokens, C, dtype=torch.bfloat16, device="cuda") * 5.0 + 5.0 + weight = torch.rand(C, dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5 + + out = torch.empty(num_tokens, C // 2, dtype=torch.float4_e2m1fn_x2, device="cuda") + y_fp4, block_scale = flashinfer.fused_rmsnorm_silu(x, weight, eps=1e-6, out=out) + + z_packed = y_fp4.view(torch.uint8).reshape(num_tokens, C // 2) + dequantized = dequantize_nvfp4(z_packed, block_scale, num_tokens, C) + + ref_f32 = rmsnorm_silu_reference(x, weight, eps=1e-6, output_dtype=torch.float32) + + # FP4 has very limited precision (3-bit mantissa equivalent), so the + # dequantized values won't match exactly. We check relative error is + # bounded: each FP4 value is within one block-scale quantum of the reference. 
+ abs_err = (dequantized - ref_f32).abs() + rel_err = abs_err / (ref_f32.abs() + 1e-6) + median_rel_err = rel_err.median().item() + assert median_rel_err < 0.5, ( + f"NVFP4 round-trip median relative error too large: {median_rel_err:.4f}" + ) + # Also check no catastrophic outliers (>2x the reference magnitude) + max_rel_err = rel_err.max().item() + assert max_rel_err < 2.0, ( + f"NVFP4 round-trip max relative error too large: {max_rel_err:.4f}" + ) From c9e785cc3556151671c0adbbc88ea2b64b73af59 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Mon, 6 Apr 2026 16:23:16 -0700 Subject: [PATCH 13/17] add optional user-pre-allocated-input for block scale output tensor --- flashinfer/norm/__init__.py | 41 ++++++++++++---- tests/norm/test_fused_rmsnorm_silu.py | 67 +++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 8 deletions(-) diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index 4fe8a3b1c4..318e87a2a6 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -596,6 +596,7 @@ def fused_rmsnorm_silu( weight: torch.Tensor, eps: float = 1e-6, out: Optional[torch.Tensor] = None, + block_scale: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple]: r"""Fused RMSNorm + SiLU activation. @@ -615,9 +616,20 @@ def fused_rmsnorm_silu( eps: float Epsilon for numerical stability. out: Optional[torch.Tensor] - Output tensor. If None, allocated as same shape/dtype as input. - For NVFP4 output (``torch.float4_e2m1fn_x2``), shape must be - ``(num_tokens, hidden_size // 2)``. + Output tensor. If ``None``, allocated as ``bfloat16`` matching input. + The dtype of ``out`` selects the output format: + + - ``torch.bfloat16``: shape ``(num_tokens, hidden_size)``. + - ``torch.float8_e4m3fn``: FP8 E4M3 output, shape ``(num_tokens, hidden_size)``. + Requires SM89+ (Ada/Hopper). + - ``torch.float4_e2m1fn_x2``: NVFP4 block-scaled output, shape + ``(num_tokens, hidden_size // 2)``. 
Requires SM100+ (Blackwell) + and ``hidden_size`` divisible by 16. + block_scale: Optional[torch.Tensor] + Pre-allocated output tensor for per-block scale factors (NVFP4 only). + Shape ``(num_tokens, hidden_size // 16)``, dtype ``torch.float8_e4m3fn``. + If ``None``, allocated automatically when ``out`` is NVFP4. + Ignored for bf16/fp8 output. Returns ------- @@ -713,13 +725,26 @@ def fused_rmsnorm_silu( if output_dtype_str == "nvfp4": num_blocks = C // 16 - block_scale = torch.empty( - num_tokens, num_blocks, dtype=torch.float8_e4m3fn, device=input.device - ) + if block_scale is None: + block_scale = torch.empty( + num_tokens, num_blocks, dtype=torch.float8_e4m3fn, device=input.device + ) + else: + expected_shape = (num_tokens, num_blocks) + if tuple(block_scale.shape) != expected_shape: + raise ValueError( + f"block_scale shape mismatch: expected {expected_shape}, " + f"got {tuple(block_scale.shape)}" + ) + if block_scale.dtype != torch.float8_e4m3fn: + raise ValueError( + f"block_scale must be float8_e4m3fn, got {block_scale.dtype}" + ) + scale_row_out = block_scale else: - block_scale = torch.empty(0, dtype=torch.uint8, device=input.device) + scale_row_out = torch.empty(0, dtype=torch.uint8, device=input.device) - module.rmsnorm_silu(out, input, weight, eps, workspace, block_scale, sm_count) + module.rmsnorm_silu(out, input, weight, eps, workspace, scale_row_out, sm_count) if output_dtype_str == "nvfp4": return out, block_scale diff --git a/tests/norm/test_fused_rmsnorm_silu.py b/tests/norm/test_fused_rmsnorm_silu.py index b33b1f6b24..cd8dffc7a4 100644 --- a/tests/norm/test_fused_rmsnorm_silu.py +++ b/tests/norm/test_fused_rmsnorm_silu.py @@ -456,6 +456,73 @@ def test_preallocated_output_fp8(): assert mismatches.sum().item() == 0 +@pytest.mark.skipif(not has_fp4_dtype, reason="torch.float4_e2m1fn_x2 not available") +def test_preallocated_output_nvfp4(): + """Pre-allocated out AND block_scale for NVFP4.""" + import flashinfer + + num_tokens, hidden_size = 1560, 
256 + C = hidden_size + torch.manual_seed(42) + x = torch.randn(num_tokens, C, dtype=torch.bfloat16, device="cuda") * 5.0 + 5.0 + weight = torch.rand(C, dtype=torch.bfloat16, device="cuda") * 1.5 + 0.5 + + out = torch.empty(num_tokens, C // 2, dtype=torch.float4_e2m1fn_x2, device="cuda") + block_scale = torch.empty( + num_tokens, C // 16, dtype=torch.float8_e4m3fn, device="cuda" + ) + + y_fp4, bs = flashinfer.fused_rmsnorm_silu( + x, weight, eps=1e-6, out=out, block_scale=block_scale + ) + + assert y_fp4.data_ptr() == out.data_ptr() + assert bs.data_ptr() == block_scale.data_ptr() + assert bs.shape == (num_tokens, C // 16) + assert bs.dtype == torch.float8_e4m3fn + + ref_f32 = rmsnorm_silu_reference(x, weight, eps=1e-6, output_dtype=torch.float32) + z_packed = y_fp4.view(torch.uint8).reshape(num_tokens, C // 2) + kernel_nibbles = _unpack_fp4_nibbles(z_packed, num_tokens, C) + ref_nibbles = _quantize_to_fp4_reference(ref_f32, C) + nibble_diff = (kernel_nibbles - ref_nibbles).abs() + assert (nibble_diff > 1).sum().item() == 0 + + +@pytest.mark.skipif(not has_fp4_dtype, reason="torch.float4_e2m1fn_x2 not available") +def test_preallocated_block_scale_wrong_shape(): + """block_scale with wrong shape should raise ValueError.""" + import flashinfer + + num_tokens, C = 1560, 256 + x = torch.randn(num_tokens, C, dtype=torch.bfloat16, device="cuda") + weight = torch.rand(C, dtype=torch.bfloat16, device="cuda") + out = torch.empty(num_tokens, C // 2, dtype=torch.float4_e2m1fn_x2, device="cuda") + bad_scale = torch.empty(num_tokens, 1, dtype=torch.float8_e4m3fn, device="cuda") + + with pytest.raises(ValueError, match="block_scale shape mismatch"): + flashinfer.fused_rmsnorm_silu( + x, weight, eps=1e-6, out=out, block_scale=bad_scale + ) + + +@pytest.mark.skipif(not has_fp4_dtype, reason="torch.float4_e2m1fn_x2 not available") +def test_preallocated_block_scale_wrong_dtype(): + """block_scale with wrong dtype should raise ValueError.""" + import flashinfer + + num_tokens, C = 
1560, 256 + x = torch.randn(num_tokens, C, dtype=torch.bfloat16, device="cuda") + weight = torch.rand(C, dtype=torch.bfloat16, device="cuda") + out = torch.empty(num_tokens, C // 2, dtype=torch.float4_e2m1fn_x2, device="cuda") + bad_scale = torch.empty(num_tokens, C // 16, dtype=torch.float32, device="cuda") + + with pytest.raises(ValueError, match="block_scale must be float8_e4m3fn"): + flashinfer.fused_rmsnorm_silu( + x, weight, eps=1e-6, out=out, block_scale=bad_scale + ) + + # ============================================================ # Numerical edge cases # ============================================================ From 9c4d9ad6541e8663cd70c624a4524fbbf363a4c7 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Mon, 6 Apr 2026 16:47:49 -0700 Subject: [PATCH 14/17] add notes --- flashinfer/norm/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index 318e87a2a6..14cb0ebdb3 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -642,6 +642,15 @@ def fused_rmsnorm_silu( ``(num_tokens, hidden_size // 2)`` with dtype ``float4_e2m1fn_x2``, and ``block_scale`` has shape ``(num_tokens, hidden_size // 16)`` with dtype ``float8_e4m3fn`` (one E4M3 scale per 16-element block). + + Notes + ----- + Kernel tuning knobs are sweep-optimized on B200 (SM100) for WAN VAE + decoder problem sizes: ``hidden_size`` in {64, 128, 160, 256, 320, 512, + 640, 1024} and ``num_tokens`` in {1560, 6240, 24960, 99840, 399360}. + Other problem sizes use conservative fallback heuristics that are + functionally correct but may not achieve peak throughput. Performance + on non-SM100 architectures (SM80/SM89) uses the same fallback path. 
""" if input.device.type != "cuda": raise ValueError("fused_rmsnorm_silu requires CUDA tensors") From fc18fbbf57388da8014acf6021e3b640a958a770 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Tue, 7 Apr 2026 11:23:13 -0700 Subject: [PATCH 15/17] clarification on sm100 optimized, sm80+ supported --- flashinfer/norm/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flashinfer/norm/__init__.py b/flashinfer/norm/__init__.py index 14cb0ebdb3..0f9911a6ed 100644 --- a/flashinfer/norm/__init__.py +++ b/flashinfer/norm/__init__.py @@ -604,8 +604,8 @@ def fused_rmsnorm_silu( where ``SiLU(x) = x / (1 + exp(-x))`` - Optimized for WAN VAE decoder workloads on SM100 (B200). - Uses sweep-tuned knobs for all standard VAE problem sizes. + Optimized for SM100 (B200) for WAN VAE decoder problem sizes. + Other shapes and architectures (SM80+) use conservative fallback heuristics. Parameters ---------- @@ -650,7 +650,7 @@ def fused_rmsnorm_silu( 640, 1024} and ``num_tokens`` in {1560, 6240, 24960, 99840, 399360}. Other problem sizes use conservative fallback heuristics that are functionally correct but may not achieve peak throughput. Performance - on non-SM100 architectures (SM80/SM89) uses the same fallback path. + on non-SM100 architectures uses the same fallback path. """ if input.device.type != "cuda": raise ValueError("fused_rmsnorm_silu requires CUDA tensors") From 86e38f07ab8413d80b8757ff6d1b0bc8e56c3cb3 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Tue, 7 Apr 2026 16:55:36 -0700 Subject: [PATCH 16/17] update docs --- docs/api/norm.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/api/norm.rst b/docs/api/norm.rst index 058940995d..ea1671acec 100644 --- a/docs/api/norm.rst +++ b/docs/api/norm.rst @@ -17,3 +17,4 @@ Kernels for normalization layers. 
gemma_rmsnorm gemma_fused_add_rmsnorm layernorm + fused_rmsnorm_silu From 89df4cb903ccf0d5ecb835cbe8309b727fd45507 Mon Sep 17 00:00:00 2001 From: kahyunnam Date: Tue, 7 Apr 2026 16:56:25 -0700 Subject: [PATCH 17/17] add to microbenchmarks --- .../routines/flashinfer_benchmark_utils.py | 1 + benchmarks/routines/norm.py | 121 ++++++++++++++++++ benchmarks/samples/sample_testlist.txt | 8 ++ 3 files changed, 130 insertions(+) diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index f17305fcef..62a7d349d3 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -197,6 +197,7 @@ "fused_add_rmsnorm_quant", "rmsnorm_fp4quant", "add_rmsnorm_fp4quant", + "fused_rmsnorm_silu", ], "quantization": [ "mxfp8_quantize", diff --git a/benchmarks/routines/norm.py b/benchmarks/routines/norm.py index 6daf93c62c..48aa65761e 100644 --- a/benchmarks/routines/norm.py +++ b/benchmarks/routines/norm.py @@ -51,6 +51,8 @@ def run_norm_test(args): return testRmsnormFp4quant(args) elif args.routine == "add_rmsnorm_fp4quant": return testAddRmsnormFp4quant(args) + elif args.routine == "fused_rmsnorm_silu": + return testFusedRmsnormSilu(args) else: raise ValueError(f"Unsupported routine: {args.routine}") @@ -1072,3 +1074,122 @@ def run_backend(backend, input_tensor, residual_tensor, weight): cur_res["case_tag"] = args.case_tag res.append(cur_res) return res + + +def testFusedRmsnormSilu(args): + """ + Test fused_rmsnorm_silu API (RMSNorm + SiLU activation). + + This test: + 1. Generates random input tensors + 2. Runs fused_rmsnorm_silu with bf16 output + 3. Optionally runs reference check + 4. 
Measures performance metrics (memory bandwidth) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testFusedRmsnormSilu") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + batch_size = args.batch_size + hidden_size = args.hidden_size + eps = args.eps + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype != torch.bfloat16: + raise ValueError( + f"fused_rmsnorm_silu requires bfloat16 input, got {args.input_dtype}" + ) + + input_shape = (batch_size, hidden_size) + input_tensor = torch.randn(input_shape, dtype=torch.bfloat16, device=device) + weight = torch.rand(hidden_size, dtype=torch.bfloat16, device=device) * 1.5 + 0.5 + out = torch.empty(input_shape, dtype=torch.bfloat16, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_tensor.shape = }") + print(f"[VVERBOSE] {input_tensor.dtype = }") + print(f"[VVERBOSE] {weight.shape = }") + + def run_fn(input_tensor, weight, out): + return flashinfer.fused_rmsnorm_silu(input_tensor, weight, eps=eps, out=out) + + has_reference_output = False + if run_refcheck: + rms = torch.sqrt( + torch.mean(input_tensor.float() ** 2, dim=-1, keepdim=True) + eps + ) + x_norm = input_tensor.float() / rms * weight.float() + reference_output = torch.nn.functional.silu(x_norm).to(torch.bfloat16) + has_reference_output = True + + if run_refcheck: + test_out = run_fn(input_tensor, weight, out) + if has_reference_output: + ( + num_different_elements, + num_elements, + num_different_elements_percentage, + ) = is_close_stats(reference_output, test_out, rtol=2e-2, atol=2e-2) + if 
num_different_elements > 0: + print( + f"[ERROR] Output tensor mismatch: " + f"{num_different_elements}/{num_elements} ({num_different_elements_percentage:.2f}%) elements differ" + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Output mismatch with {num_different_elements} elements" + ) + + times = bench_gpu_time( + fn=run_fn, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(input_tensor, weight, out), + ) + + if len(times) > 0: + median_time = np.median(times) + std_time = np.std(times) + + num_elements = np.prod(input_shape) + problem_bytes = ( + num_elements * input_dtype.itemsize # input read + + hidden_size * input_dtype.itemsize # weight read + + num_elements * input_dtype.itemsize # output write + ) + problem_flops = num_elements * 7 # rmsnorm (5) + silu (2: exp + div) + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics("cuda", median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["eps"] = eps + cur_res["backend"] = "cuda" + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res diff --git a/benchmarks/samples/sample_testlist.txt b/benchmarks/samples/sample_testlist.txt index 0cc7bca6a6..3c4aab7620 100644 --- a/benchmarks/samples/sample_testlist.txt +++ b/benchmarks/samples/sample_testlist.txt @@ -133,6 +133,14 @@ # Both SF layouts with MXFP4 format --routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_mxfp4_both_sf" +## 
Fused RMSNorm + SiLU (SM80+, sweep-tuned on SM100/B200) +# VAE decoder shapes (LUT-optimized on B200) +--routine fused_rmsnorm_silu --batch_size 1560 --hidden_size 1024 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "fused_rmsnorm_silu_vae_small" +--routine fused_rmsnorm_silu --batch_size 24960 --hidden_size 512 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "fused_rmsnorm_silu_vae_mid" +--routine fused_rmsnorm_silu --batch_size 99840 --hidden_size 256 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "fused_rmsnorm_silu_vae_large" +# Non-VAE shapes (fallback heuristics) +--routine fused_rmsnorm_silu --batch_size 2048 --hidden_size 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "fused_rmsnorm_silu_llama" + ## Quantization (Blackwell SM10.0+ only) # MxFP8 Quantization - basic --routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "mxfp8_quantize_basic"