Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@
"fused_add_rmsnorm_quant",
"rmsnorm_fp4quant",
"add_rmsnorm_fp4quant",
"fused_rmsnorm_silu",
],
"quantization": [
"mxfp8_quantize",
Expand Down
121 changes: 121 additions & 0 deletions benchmarks/routines/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def run_norm_test(args):
return testRmsnormFp4quant(args)
elif args.routine == "add_rmsnorm_fp4quant":
return testAddRmsnormFp4quant(args)
elif args.routine == "fused_rmsnorm_silu":
return testFusedRmsnormSilu(args)
else:
raise ValueError(f"Unsupported routine: {args.routine}")

Expand Down Expand Up @@ -1078,3 +1080,122 @@ def run_backend(backend, input_tensor, residual_tensor, weight):
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res


def testFusedRmsnormSilu(args):
    """
    Test fused_rmsnorm_silu API (RMSNorm + SiLU activation).

    This test:
    1. Generates random input tensors
    2. Runs fused_rmsnorm_silu with bf16 output
    3. Optionally runs reference check
    4. Measures performance metrics (memory bandwidth)

    Args:
        args: Parsed command line arguments containing test configuration

    Returns:
        list: List of dictionaries containing performance results
    """
    if args.verbose >= 1:
        print("[INFO] Running testFusedRmsnormSilu")
        print(f"[INFO] FlashInfer version: {flashinfer.__version__}")

    device = get_device(args)
    if args.generate_repro_command:
        print(
            f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
        )

    batch_size = args.batch_size
    hidden_size = args.hidden_size
    eps = args.eps
    is_cuda_graph_compatible = not args.no_cuda_graph
    run_refcheck = args.refcheck
    res = []

    input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
    # The fused kernel only supports bf16 in/out; fail fast on anything else.
    if input_dtype != torch.bfloat16:
        raise ValueError(
            f"fused_rmsnorm_silu requires bfloat16 input, got {args.input_dtype}"
        )

    input_shape = (batch_size, hidden_size)
    input_tensor = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
    # Weight drawn from [0.5, 2.0) so scales are never near zero, which would
    # otherwise mask mismatches in the refcheck.
    weight = torch.rand(hidden_size, dtype=torch.bfloat16, device=device) * 1.5 + 0.5
    out = torch.empty(input_shape, dtype=torch.bfloat16, device=device)

    if args.verbose >= 2:
        print(f"[VVERBOSE] {input_tensor.shape = }")
        print(f"[VVERBOSE] {input_tensor.dtype = }")
        print(f"[VVERBOSE] {weight.shape = }")

    def run_fn(input_tensor, weight, out):
        return flashinfer.fused_rmsnorm_silu(input_tensor, weight, eps=eps, out=out)

    if run_refcheck:
        # fp32 reference: y = silu(x / rms(x) * w), computed in float for
        # accuracy and cast back to bf16 for comparison.
        rms = torch.sqrt(
            torch.mean(input_tensor.float() ** 2, dim=-1, keepdim=True) + eps
        )
        x_norm = input_tensor.float() / rms * weight.float()
        reference_output = torch.nn.functional.silu(x_norm).to(torch.bfloat16)

        test_out = run_fn(input_tensor, weight, out)
        (
            num_different_elements,
            num_elements,
            num_different_elements_percentage,
        ) = is_close_stats(reference_output, test_out, rtol=2e-2, atol=2e-2)
        if num_different_elements > 0:
            print(
                f"[ERROR] Output tensor mismatch: "
                f"{num_different_elements}/{num_elements} ({num_different_elements_percentage:.2f}%) elements differ"
            )
            if not args.allow_output_mismatch:
                raise AssertionError(
                    f"[ERROR] Output mismatch with {num_different_elements} elements"
                )

    times = bench_gpu_time(
        fn=run_fn,
        dry_run_iters=args.dry_run_iters,
        repeat_iters=args.num_iters,
        enable_cupti=args.use_cupti,
        use_cuda_graph=is_cuda_graph_compatible,
        input_args=(input_tensor, weight, out),
    )

    if len(times) > 0:
        median_time = np.median(times)
        std_time = np.std(times)

        num_elements = np.prod(input_shape)
        # The kernel is memory-bound: one read of input, one read of weight,
        # one write of output, all in bf16.
        problem_bytes = (
            num_elements * input_dtype.itemsize  # input read
            + hidden_size * input_dtype.itemsize  # weight read
            + num_elements * input_dtype.itemsize  # output write
        )
        problem_flops = num_elements * 7  # rmsnorm (5) + silu (2: exp + div)
        tflops = problem_flops / (10**9 * median_time)
        tb_per_sec = problem_bytes / (10**9 * median_time)

        print_perf_metrics("cuda", median_time, std_time, tflops, tb_per_sec)

        if args.output_path is not None:
            cur_res = defaultdict(str)
            cur_res["routine"] = args.routine
            cur_res["median_time"] = median_time
            cur_res["std_time"] = std_time
            cur_res["tflops"] = tflops
            cur_res["tb_per_sec"] = tb_per_sec
            cur_res["input_dtype"] = str(input_dtype)
            cur_res["eps"] = eps
            cur_res["backend"] = "cuda"
            cur_res["case_tag"] = args.case_tag
            res.append(cur_res)
    return res
8 changes: 8 additions & 0 deletions benchmarks/samples/sample_testlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@
# Both SF layouts with MXFP4 format
--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 --output_both_sf_layouts -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_mxfp4_both_sf"

## Fused RMSNorm + SiLU (SM80+, sweep-tuned on SM100/B200)
# VAE decoder shapes (LUT-optimized on B200)
--routine fused_rmsnorm_silu --batch_size 1560 --hidden_size 1024 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "fused_rmsnorm_silu_vae_small"
--routine fused_rmsnorm_silu --batch_size 24960 --hidden_size 512 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "fused_rmsnorm_silu_vae_mid"
--routine fused_rmsnorm_silu --batch_size 99840 --hidden_size 256 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "fused_rmsnorm_silu_vae_large"
# Non-VAE shapes (fallback heuristics)
--routine fused_rmsnorm_silu --batch_size 2048 --hidden_size 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "fused_rmsnorm_silu_llama"

## Quantization (Blackwell SM10.0+ only)
# MxFP8 Quantization - basic
--routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "mxfp8_quantize_basic"
Expand Down
21 changes: 21 additions & 0 deletions csrc/flashinfer_rmsnorm_silu_binding.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright (c) 2026 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tvm_ffi_utils.h"

// Forward declaration of the fused RMSNorm + SiLU launcher implemented in
// csrc/rmsnorm_silu.cu.
void rmsnorm_silu(TensorView output, TensorView input, TensorView weight, double eps,
                  TensorView workspace, TensorView scale_row_out, int64_t sm_count);

// Export the launcher through the TVM FFI under the same symbol name.
TVM_FFI_DLL_EXPORT_TYPED_FUNC(rmsnorm_silu, rmsnorm_silu);
115 changes: 115 additions & 0 deletions csrc/rmsnorm_silu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright (c) 2026 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// clang-format off
// Include order matters: headers → config (defines Ktraits) → kernel (uses Ktraits)
#include <algorithm>
#include <flashinfer/norm/ln_silu_headers.cuh>
#include "rmsnorm_silu_config.inc"
#include <flashinfer/norm/ln_fwd_silu_kernel.cuh>
// clang-format on

#include "tvm_ffi_utils.h"

// Launch the fused RMSNorm + SiLU kernel.
//
// input/output: [rows, cols] (cols must equal the compile-time HIDDEN_SIZE
// this translation unit was specialized for via rmsnorm_silu_config.inc);
// weight: [cols]; workspace: caller-provided scratch buffer carved into
// 128-byte-aligned segments below; scale_row_out: per-row scale output, used
// only when the config compiles for NVFP4 output; sm_count: device SM count
// used to size the grid in the non-persistent configurations.
void rmsnorm_silu(TensorView output, TensorView input, TensorView weight, double eps,
                  TensorView workspace, TensorView scale_row_out, int64_t sm_count) {
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(output);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(weight);
  CHECK_DEVICE(input, weight);
  CHECK_DIM(2, input);
  CHECK_DIM(2, output);
  CHECK_DIM(1, weight);

  int rows = input.size(0);
  int cols = input.size(1);
  TVM_FFI_ICHECK_EQ(cols, HIDDEN_SIZE) << "Input cols must match compiled HIDDEN_SIZE";
  TVM_FFI_ICHECK_EQ(output.size(0), rows);

  ffi::CUDADeviceGuard device_guard(input.device().device_id);
  const cudaStream_t stream = get_stream(input.device());

  // Grid dimensions (same logic as Sm100RmsNormSiluEngine::execute).
  // KERNEL_CFG == 2 launches one CTA group per row chunk; otherwise the grid
  // is capped by occupancy (sm_count * DESIRED_OCCUPANCY / CTAS_PER_ROW).
  int ctas_per_col_max = (rows + WARPS_M - 1) / WARPS_M;
  int ctas_per_col;
  if (KERNEL_CFG == 2) {
    ctas_per_col = ctas_per_col_max;
  } else {
    ctas_per_col =
        std::min(static_cast<int>(sm_count) * DESIRED_OCCUPANCY / CTAS_PER_ROW, ctas_per_col_max);
  }
  ctas_per_col = std::max(ctas_per_col, 1);

  dim3 grid(CTAS_PER_ROW * ctas_per_col);
  dim3 block(WARPS_M * WARPS_N * 32);

  // Pack kernel params
  PersistentLnFwdParams params{};
  params.rows = rows;
  params.cols = cols;
  params.ctas_per_col = ctas_per_col;
  params.isRMSNorm = true;
  params.noScale = false;
  params.noBias = true;
  params.isBatchFirst = true;
  params.batchSize = 1;
  params.seqLen = rows;
  params.epsilon = static_cast<float>(eps);
  params.x = input.data_ptr();
  params.z = output.data_ptr();
  params.gamma = weight.data_ptr();

  // Workspace layout (128-byte aligned segments)
  char* ws_ptr = static_cast<char*>(workspace.data_ptr());

  // [0] rs: rows * sizeof(float)
  params.rs = ws_ptr;
  int64_t off = static_cast<int64_t>(rows) * sizeof(float);
  off = ((off + 127) / 128) * 128;

  // [aligned] fp8_scale: sizeof(float)
  if (isFP8Out) {
    params.fp8_out = true;
    float* default_scale = reinterpret_cast<float*>(ws_ptr + off);
    // Set scale = 1.0f via cudaMemcpyAsync from host. `one` is static so the
    // host source outlives the async copy regardless of when it drains.
    static const float one = 1.0f;
    cudaMemcpyAsync(default_scale, &one, sizeof(float), cudaMemcpyHostToDevice, stream);
    params.scale = default_scale;
  }
  // Advance past the scale slot unconditionally so the layout is identical
  // for all output dtypes.
  off += sizeof(float);
  off = ((off + 127) / 128) * 128;

  // scale_row: passed as separate output tensor (NVFP4 only)
  if (isFP4Out) {
    params.scale_row = scale_row_out.data_ptr();
  }

  // [aligned] cooperative workspace + barriers (multi-CTA only)
  if (CTAS_PER_ROW > 1) {
    params.workspace = ws_ptr + off;
    int64_t coop_ws_size =
        static_cast<int64_t>(ctas_per_col) * WARPS_M * CTAS_PER_ROW * sizeof(float) * 2 * 2;
    off += coop_ws_size;
    off = ((off + 127) / 128) * 128;

    // Barriers must start zeroed; the kernel flips between the two halves.
    params.barrier = reinterpret_cast<int*>(ws_ptr + off);
    cudaMemsetAsync(params.barrier, 0, 2 * ctas_per_col * sizeof(int32_t), stream);
  }

  reduced_divisor divisor(rows);

  ln_fwd_kernel<<<grid, block, 0, stream>>>(params, divisor);
}
1 change: 1 addition & 0 deletions docs/api/norm.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ Kernels for normalization layers.
gemma_rmsnorm
gemma_fused_add_rmsnorm
layernorm
fused_rmsnorm_silu
1 change: 1 addition & 0 deletions flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
from .norm import gemma_rmsnorm as gemma_rmsnorm
from .norm import rmsnorm as rmsnorm
from .norm import rmsnorm_quant as rmsnorm_quant
from .norm import fused_rmsnorm_silu as fused_rmsnorm_silu

try:
from .norm import rmsnorm_fp4quant as rmsnorm_fp4quant
Expand Down
46 changes: 46 additions & 0 deletions flashinfer/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@
)
from .jit.mla import gen_mla_module
from .jit.norm import gen_norm_module
from .jit.rmsnorm_silu import (
gen_rmsnorm_silu_module,
select_knobs,
_estimate_ctas_per_row,
_compute_default_knobs,
_SUPPORTED_C,
_SUPPORTED_TOKENS,
)
from .jit.page import gen_page_module
from .jit.quantization import gen_quantization_module
from .jit.rope import gen_rope_module
Expand Down Expand Up @@ -558,6 +566,44 @@ def gen_all_modules(
gen_sampling_module(),
gen_topk_module(),
]
# Fused RMSNorm+SiLU: pre-compile all LUT configs (SM100+ only)
if has_sm100:
for C in _SUPPORTED_C:
for tokens in _SUPPORTED_TOKENS:
for dtype in ["bf16", "fp8", "nvfp4"]:
knobs = select_knobs(C, tokens, dtype)
if knobs is None:
continue
wm, sc, kcfg, occ, bpl = knobs
cpr = _estimate_ctas_per_row(C, sc, kcfg, bpl)
jit_specs.append(
gen_rmsnorm_silu_module(C, dtype, wm, cpr, bpl, kcfg, occ)
)
# Fallback configs for common hidden sizes not in the LUT.
# Fallback knobs depend only on (C, dtype), not num_tokens,
# so one module per (C, dtype) covers all token counts.
_FALLBACK_C = [
768,
1280,
1536,
2048,
2560,
3072,
4096,
5120,
6144,
8192,
]
for C in _FALLBACK_C:
for dtype in ["bf16", "fp8", "nvfp4"]:
knobs = _compute_default_knobs(C, dtype)
if knobs is None:
continue
wm, sc, kcfg, occ, bpl = knobs
cpr = _estimate_ctas_per_row(C, sc, kcfg, bpl)
jit_specs.append(
gen_rmsnorm_silu_module(C, dtype, wm, cpr, bpl, kcfg, occ)
)
# selective_state_update: one module per dtype combo per GPU arch
_ssu_dtype_combos = [
# (state, input, weight, matrixA, stateIndex, state_scale_dtype)
Expand Down
Loading
Loading