
Add flashinfer.fused_rmsnorm_silu() with native kernel backend #2965

Merged
kahyunnam merged 18 commits into flashinfer-ai:main from kahyunnam:knam/fused-rmsnorm-silu-option3_direct_kernel
Apr 8, 2026

Conversation

@kahyunnam
Collaborator

@kahyunnam kahyunnam commented Apr 2, 2026

📌 Description

Originally, this kernel was open-sourced into cuDNN OSS and integrated here: #2691.

However, cuDNN OSS has no native support for an on-disk kernel cache or precompiled PyPI wheels. This limits end-to-end performance, since without a cache dynamic shapes would trigger recompilation for every new shape. After scoping out the internal process for releasing a new PyPI wheel, we decided that route would take too much time.

In this PR, I move the kernel directly into FlashInfer, so that we can reuse the existing JIT cache and wheel-packaging architecture.

🔍 Related Issues

Issue #2571

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Added a fused RMSNorm+SiLU primitive with bf16, FP8 (E4M3) and NVFP4 (block-scale) output modes.
  • Performance
    • SM100-optimized kernels for high-throughput GPU execution.
  • JIT / AOT
    • JIT/AOT generation for SM100+ with automatic kernel/config selection and fallback knob heuristics.
  • API
    • Exported flashinfer.fused_rmsnorm_silu supporting out= preallocation; NVFP4 returns per-block scale alongside output.
  • Tests
    • Comprehensive CUDA SM100+ tests covering dtypes, fallback shapes, preallocated outputs, eps sensitivity, and NVFP4 round‑trip checks.
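For reference, the fused primitive summarized above computes SiLU over the RMSNorm output. A minimal pure-Python sketch of the row-wise math (reference semantics only; the function and argument names are illustrative, not the FlashInfer API):

```python
import math

def fused_rmsnorm_silu_ref(x, weight, eps=1e-6):
    # RMSNorm: scale each element by 1/sqrt(mean(x^2) + eps), then by weight.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    normed = [v / rms * w for v, w in zip(x, weight)]
    # SiLU: v * sigmoid(v) == v / (1 + exp(-v)).
    return [v / (1.0 + math.exp(-v)) for v in normed]
```

The real kernel applies this per row over (num_tokens, hidden_size) inputs and optionally quantizes the result to FP8 or NVFP4.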

@coderabbitai
Contributor

coderabbitai bot commented Apr 2, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds an SM100-targeted fused RMSNorm+SiLU implementation: new NVRTC-compatible headers and kernel, CUDA host entry and TVM FFI binding, JIT generator and AOT integration, Python API surface and re-export, workspace/knob logic, and SM100‑gated tests for bf16/FP8/NVFP4 outputs.

Changes

Cohort / File(s) / Summary

  • CUDA bindings & host entry — csrc/flashinfer_rmsnorm_silu_binding.cu, csrc/rmsnorm_silu.cu
    New TVM FFI export rmsnorm_silu and CUDA host entry: tensor validation, device/stream selection, workspace layout & init (including FP8/NVFP4 scale handling), params packing (PersistentLnFwdParams), multi-CTA barrier setup, and kernel launch.
  • Kernel headers & device kernel — include/flashinfer/norm/ln_fwd_silu_kernel.cuh, include/flashinfer/norm/ln_silu_headers.cuh
    New NVRTC/JIT-friendly headers and SM100-optimized ln_fwd kernel: kernel traits, inter-CTA sync, reducers/stats, FP8/NVFP4 block-scale helpers, shared-memory layout, and ln_fwd_kernel entry.
  • JIT generator & AOT integration — flashinfer/jit/rmsnorm_silu.py, flashinfer/aot.py
    Added JIT module generator with LUT and fallback knob selection, CTAs-per-row estimation, config generation & per-config directory, CSRC copying, extra CUDA flags/includes, and integration into AOT gen_all_modules() gated by SM100.
  • Python API surface — flashinfer/__init__.py, flashinfer/norm/__init__.py
    Re-exported fused_rmsnorm_silu and added implementation: input/device/dtype gating, knob selection, workspace sizing/allocation, module build/load invocation, NVFP4 block_scale handling and return semantics.
  • Tests — tests/norm/test_fused_rmsnorm_silu.py
    New SM100-gated end-to-end test suite for bf16/FP8/NVFP4: float32 reference, FP4 helpers, LUT and fallback shapes, preallocated-output behavior, NVFP4 shape/dtype error checks, epsilon/weight edge cases, and NVFP4 round-trip checks.
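The "LUT and fallback knob selection" mentioned above typically takes the form of a shape-keyed table with a heuristic default. A hypothetical sketch — the keys, knob names, and fallback rule here are invented for illustration and are not the actual contents of flashinfer/jit/rmsnorm_silu.py:

```python
# Hypothetical shape -> knob lookup table; real entries live in the JIT generator.
KNOB_LUT = {
    (4096, 1024): {"ctas_per_row": 1, "bytes_per_ldg": 16},
    (8192, 4096): {"ctas_per_row": 2, "bytes_per_ldg": 16},
}

def select_knobs(hidden_size, num_tokens):
    """Return tuned knobs from the LUT, or heuristic defaults for unseen shapes."""
    key = (hidden_size, num_tokens)
    if key in KNOB_LUT:
        return KNOB_LUT[key]
    # Fallback heuristic (illustrative): spread wide rows across more CTAs.
    return {"ctas_per_row": max(1, hidden_size // 8192), "bytes_per_ldg": 16}
```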

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant API as fused_rmsnorm_silu API
    participant JIT as JIT Module Generator
    participant Compiler as NVRTC/AOT Compiler
    participant Module as Compiled Module
    participant Kernel as ln_fwd_kernel
    participant GPU as GPU Hardware

    User->>API: fused_rmsnorm_silu(input, weight, eps, out, block_scale?)
    activate API
    API->>API: validate inputs, select_knobs, compute workspace
    API->>JIT: gen_rmsnorm_silu_module(config)
    activate JIT
    JIT->>JIT: generate config, copy CSRC, write inc
    JIT->>Compiler: request build (NVRTC/AOT)
    deactivate JIT
    Compiler-->>Module: compiled module
    API->>Module: load module
    API->>Module: module.rmsnorm_silu(..., workspace, scale_row_out, sm_count)
    Module->>Kernel: launch ln_fwd_kernel<<<grid,block>>>(params)
    activate Kernel
    Kernel->>GPU: init shared memory/barriers, compute stats
    GPU->>GPU: apply RMSNorm, SiLU, optional quantize/block-scale
    Kernel-->>Module: kernel complete
    deactivate Kernel
    Module-->>API: output (and optional block_scale)
    deactivate API
    API-->>User: return output (and optional block_scale)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

run-ci, op: norm

Suggested reviewers

  • aleozlx
  • yzh119
  • cyx-6
  • jimmyzho
  • nv-yunzheq
  • sricketts
  • samuellees

Poem

🐰 In a burrow of kernels I quietly tune,
RMSNorm and SiLU dance under SM100 moon,
LUTs lay the path, scales snug in a row,
Barriers hum softly where warps ebb and flow,
A rabbit applauds: outputs ready to go.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning — Docstring coverage is 63.64%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)
  • Title check ✅ — The title clearly and specifically describes the main change: adding a fused RMSNorm+SiLU function with a native kernel backend.
  • Description check ✅ — The PR description covers the motivation (moving from cuDNN OSS to a native kernel for JIT cache/wheel support), references related issue #2571, and confirms pre-commit and test requirements are met.


@kahyunnam kahyunnam changed the title from "gnative kernel" to "[wip] rmsnorm + silu native kernel" on Apr 2, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a fused RMSNorm and SiLU kernel, ported from the cuDNN frontend, to optimize performance for specific workloads on SM100+ architectures. The implementation includes JIT compilation support, a configuration lookup table for optimal kernel parameters, and comprehensive unit tests. I have provided feedback to ensure that the SM count is retrieved based on the input tensor's device rather than the current CUDA device, which is critical for multi-GPU support.

The C++ header sm100_rms_norm_silu_knobs.h was never included by any
source file — all knob selection happens in Python at JIT compile time
via flashinfer/jit/rmsnorm_silu.py. Keeping a duplicate 120-entry LUT
in C++ was a maintenance burden with no benefit. AI-assisted.

Made-with: Cursor
@kahyunnam kahyunnam force-pushed the knam/fused-rmsnorm-silu-option3_direct_kernel branch from e35a4db to f41322a on April 3, 2026 at 00:33
ln_fwd_silu_kernel.cuh requires Ktraits, PersistentLnFwdParams, and
other types to be defined before inclusion. The correct order is:
  1. ln_silu_headers.cuh (type definitions)
  2. rmsnorm_silu_config.inc (Ktraits typedef, constexpr flags)
  3. ln_fwd_silu_kernel.cuh (kernel using the above)

Protected with clang-format off/on since alphabetical sorting
would break this dependency chain. AI-assisted.

Made-with: Cursor
@kahyunnam kahyunnam changed the title from "[wip] rmsnorm + silu native kernel" to "Add flashinfer.fused_rmsnorm_silu() with native kernel backend" on Apr 3, 2026
@kahyunnam
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !501 has been created, and the CI pipeline #47591827 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (3)
flashinfer/jit/rmsnorm_silu.py (1)

339-347: Declare the supported SM majors on this JIT spec.

gen_rmsnorm_silu_module() leaves supported_major_versions unset, so this backend has no explicit arch filter at the spec level. Please pass the validated major list here instead of relying on the caller's arch list to constrain compilation.

As per coding guidelines "Specify supported NVIDIA SM major versions in JIT modules using supported_major_versions parameter to limit compilation to specific GPU architectures"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/jit/rmsnorm_silu.py` around lines 339 - 347,
gen_rmsnorm_silu_module currently calls gen_jit_spec without setting
supported_major_versions, so add the validated SM major list to the gen_jit_spec
call by passing supported_major_versions=<validated_list> (use the same
validated list computed in this module, e.g., supported_majors or
validated_majors) instead of relying on the caller's arch list; update the
gen_jit_spec invocation where uri, sources, extra_cuda_cflags, and
extra_include_paths are passed to include supported_major_versions to constrain
compilation to the intended NVIDIA SM major versions.
tests/norm/test_fused_rmsnorm_silu.py (1)

24-26: Use the shared GPU-capability helpers for skips.

This file reimplements the arch gate with torch.cuda.get_device_capability() instead of the repo helpers. Please switch the fixture to flashinfer.utils.get_compute_capability() / is_sm100a_supported() so the skip semantics stay aligned with the rest of the suite.

As per coding guidelines "Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures"

Also applies to: 130-135

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/norm/test_fused_rmsnorm_silu.py` around lines 24 - 26, The test
reimplements GPU arch checks with a local get_cc() that calls
torch.cuda.get_device_capability(); replace that with the repo helpers by
importing and using flashinfer.utils.get_compute_capability() (and/or the
predicate helpers is_sm100a_supported() or is_sm90a_supported() as appropriate)
for skip logic so the test's skip semantics match the rest of the
suite—specifically, remove or replace the local get_cc() function and any direct
calls to torch.cuda.get_device_capability() with calls to
get_compute_capability() or the boolean helpers
(is_sm100a_supported()/is_sm90a_supported()) used where the test decides to
skip; also update the other occurrence of the same pattern later in the file to
use the same helpers.
include/flashinfer/norm/ln_silu_headers.cuh (1)

767-769: Make the unsupported cluster branch fail explicitly.

Lines 768 and 1095 use static_assert(true, ...), which is a no-op. If USE_CLUSTER is ever instantiated on an unsupported toolkit/arch, these branches stop returning a value and the failure becomes much harder to understand.

♻️ Suggested guard
-      static_assert(true, "Cluster enabled on host side but not available on device");
+      static_assert(!USE_CLUSTER,
+                    "Cluster enabled on host side but not available on device");

Based on learnings, static_assert-based constraints are intentionally kept in the CUDA header close to the implementation for easier auditability.

Also applies to: 1094-1096

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/norm/__init__.py`:
- Around line 662-669: The NVFP4 branch in flashinfer.norm.__init__.py (the
output_dtype_str == "nvfp4" path that validates expected_shape for out and uses
variables num_tokens, C, out and workspace) returns packed FP4 nibbles without
the per-block scale (scale_row) written into workspace by the kernel; update the
API to either surface and return the scale tensor alongside out (read the scale
data from workspace and return a tuple like (out, scale_row) or similar) or
explicitly raise a ValueError rejecting "nvfp4" outputs until scale metadata can
be returned; apply the same change to the other NVFP4 validation block
referenced around lines 705-708 so callers receive scale information or the
dtype is disallowed.

In `@include/flashinfer/norm/ln_silu_headers.cuh`:
- Around line 258-270: The pre-SM80 fallback in struct Converter<float2,
nv_bfloat162>::convert uses a union whose nv_bfloat16 members overlap, so
assigning tmp.x then tmp.y clobbers the first lane; fix by replacing the union
layout so the two nv_bfloat16 lanes occupy distinct storage (e.g., use a struct
or an array like nv_bfloat16 lanes[2] alongside the nv_bfloat162 raw
representation) and assign lanes[0] = __float2bfloat16_rn(x.x); lanes[1] =
__float2bfloat16_rn(x.y); then return the raw nv_bfloat162; update the
auto-generator template the same way so generated headers get the corrected
non-overlapping lane assignments.
- Around line 1283-1297: The clz function uses a signed left-shift (1 << i)
which is undefined for i==31; change clz to operate with unsigned masks by
converting the input to uint32_t (or changing the parameter to uint32_t) and use
1u (or uint32_t(1)) for the shift and comparisons so (1u << i) & ux is used
instead of ((1 << i) & x); keep the return semantics the same so
find_log_2(int32_t, bool) can continue calling clz unchanged.
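The hazard this prompt describes is that in C/C++ a signed `1 << 31` is undefined behavior, so the mask must use an unsigned one. A pure-Python model of the corrected clz — Python ints are arbitrary-precision, so 32-bit behavior is emulated by masking; in the CUDA header the equivalent fix is `1u << i`:

```python
def clz32(x: int) -> int:
    """Count leading zeros of a value treated as uint32_t (returns 32 for zero)."""
    ux = x & 0xFFFFFFFF  # reinterpret as unsigned 32-bit, mirroring the C fix
    for i in range(31, -1, -1):
        if ux & (1 << i):  # in C this must be (1u << i) to be defined at i == 31
            return 31 - i
    return 32
```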

---

Nitpick comments:
In `@flashinfer/jit/rmsnorm_silu.py`:
- Around line 339-347: gen_rmsnorm_silu_module currently calls gen_jit_spec
without setting supported_major_versions, so add the validated SM major list to
the gen_jit_spec call by passing supported_major_versions=<validated_list> (use
the same validated list computed in this module, e.g., supported_majors or
validated_majors) instead of relying on the caller's arch list; update the
gen_jit_spec invocation where uri, sources, extra_cuda_cflags, and
extra_include_paths are passed to include supported_major_versions to constrain
compilation to the intended NVIDIA SM major versions.

In `@tests/norm/test_fused_rmsnorm_silu.py`:
- Around line 24-26: The test reimplements GPU arch checks with a local get_cc()
that calls torch.cuda.get_device_capability(); replace that with the repo
helpers by importing and using flashinfer.utils.get_compute_capability() (and/or
the predicate helpers is_sm100a_supported() or is_sm90a_supported() as
appropriate) for skip logic so the test's skip semantics match the rest of the
suite—specifically, remove or replace the local get_cc() function and any direct
calls to torch.cuda.get_device_capability() with calls to
get_compute_capability() or the boolean helpers
(is_sm100a_supported()/is_sm90a_supported()) used where the test decides to
skip; also update the other occurrence of the same pattern later in the file to
use the same helpers.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c95b7559-ce27-4b87-b226-19ebc642a70a

📥 Commits

Reviewing files that changed from the base of the PR and between fc08cd1 and 5609758.

📒 Files selected for processing (9)
  • csrc/flashinfer_rmsnorm_silu_binding.cu
  • csrc/rmsnorm_silu.cu
  • flashinfer/__init__.py
  • flashinfer/aot.py
  • flashinfer/jit/rmsnorm_silu.py
  • flashinfer/norm/__init__.py
  • include/flashinfer/norm/ln_fwd_silu_kernel.cuh
  • include/flashinfer/norm/ln_silu_headers.cuh
  • tests/norm/test_fused_rmsnorm_silu.py

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47591827: 10/20 passed

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (2)
flashinfer/norm/__init__.py (1)

597-603: Prefer backend_requirement for capability-gated public API.

fused_rmsnorm_silu has explicit SM gating logic but is not wired through the repository’s API capability-decorator pattern. Aligning this API with backend_requirement keeps discoverability and capability checks consistent across public entrypoints.

Based on learnings: "Applies to flashinfer/*.py : Use backend_requirement decorator on APIs that have compute capability requirements and provide is_compute_capability_supported(cc) and is_backend_supported() methods".

Also applies to: 680-687

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/norm/__init__.py` around lines 597 - 603, The public API
fused_rmsnorm_silu currently performs its own SM capability gating; instead
annotate the function with the repository's capability decorator (use
`@backend_requirement`) and ensure it supplies/uses the required check methods
(is_compute_capability_supported(cc) and is_backend_supported()) so capability
checks are centralized; update fused_rmsnorm_silu (and the similar API around
lines ~680-687) to remove or delegate internal SM gating to the decorator,
import and apply backend_requirement to the function, and wire the two helper
methods referenced above so the decorator can perform the gating consistently
for this public entrypoint.
tests/norm/test_fused_rmsnorm_silu.py (1)

91-108: Avoid per-block CPU transfers in FP4 reference quantization.

block_vals = ... .cpu().float() inside the block loop causes repeated host-device sync/copies and dominates test runtime at large shapes. Keep this path on GPU to reduce runtime and flakiness.

Suggested change
-        block_vals = values_f32[:, col_start:col_end].cpu().float()
+        block_vals = values_f32[:, col_start:col_end].float()
...
-        nibbles[:, col_start:col_end] = block_nibbles.to(values_f32.device)
+        nibbles[:, col_start:col_end] = block_nibbles
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/norm/test_fused_rmsnorm_silu.py` around lines 91 - 108, The reference
quantization loop currently moves each block to CPU via block_vals =
values_f32[:, col_start:col_end].cpu().float(), causing repeated host-device
transfers; keep computation on GPU by removing .cpu() and ensuring block_vals is
cast to float on the same device as values_f32 (e.g., use
.to(dtype=torch.float32, device=values_f32.device) or .float() while not calling
.cpu()), then perform amax, scale, scaled, magnitudes, signs, diffs, argmin
(mag_nibbles), and nibbles assignment entirely on the device so no per-block CPU
sync occurs; update references to block_vals and any intermediate tensors (amax,
scale, scaled, diffs, mag_nibbles, block_nibbles) to operate on GPU and only
move data to CPU once if/when needed outside the loop.
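For context, a pure-Python model of the per-block FP4 (E2M1) quantization the test references. The magnitude table is the standard E2M1 value set; the per-block amax/6 scale convention is an assumption for illustration, not necessarily the kernel's exact scheme:

```python
FP4_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # E2M1 positive values

def quantize_block_fp4(block):
    """Quantize one block of floats to 4-bit nibbles plus a per-block scale."""
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0 else 1.0  # 6.0 = largest E2M1 magnitude
    nibbles = []
    for v in block:
        scaled = abs(v) / scale
        # Pick the nearest representable magnitude (round-to-nearest table lookup).
        mag = min(range(8), key=lambda i: abs(FP4_MAGNITUDES[i] - scaled))
        sign = 8 if v < 0 else 0  # sign bit in the high bit of the nibble
        nibbles.append(sign | mag)
    return nibbles, scale

def dequantize_block_fp4(nibbles, scale):
    """Inverse mapping, used for round-trip checks."""
    out = []
    for n in nibbles:
        mag = FP4_MAGNITUDES[n & 7] * scale
        out.append(-mag if n & 8 else mag)
    return out
```

Keeping the real test's version of this loop on the GPU, as the comment suggests, avoids a host-device sync per block.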
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/norm/__init__.py`:
- Around line 632-636: The docstring for NVFP4 currently claims block_scale has
shape (num_tokens, hidden_size // 16) but the implementation computes num_blocks
= (C + 15) // 16 and returns block_scale with shape (num_tokens,
ceil(hidden_size / 16)); update the docstrings in __init__.py for the NVFP4
sections (the docstring around the NVFP4 description and the similar text near
lines showing the rmsnorm_fp4quant convention) to state block_scale shape as
(num_tokens, ceil(hidden_size / 16)) (or explicitly note num_blocks =
(hidden_size + 15) // 16) to match the implementation, or alternatively enforce
C % 16 == 0 in the code—pick the documentation change to keep behavior stable
and reference rmsnorm_fp4quant and the NVFP4 description when making the edit.

In `@tests/norm/test_fused_rmsnorm_silu.py`:
- Around line 138-141: The test matrix ALL_LUT_SHAPES (built from SUPPORTED_C
and SUPPORTED_TOKENS) is too large for CI; limit default CI to a small smoke
subset and move exhaustive combinations behind a slow marker. Create a
SMALL_SMOKE_LUT_SHAPES (e.g., pick 2 C values and 2 small token values) and
replace uses of ALL_LUT_SHAPES in the default parametrized tests with this smoke
list, and add a new EXHAUSTIVE_LUT_SHAPES = ALL_LUT_SHAPES that is used only in
tests decorated with pytest.mark.slow (or a custom marker) to run bf16/fp8/nvfp4
coverage in extended runs; update references to SUPPORTED_C, SUPPORTED_TOKENS,
and ALL_LUT_SHAPES accordingly in the functions/tests that currently iterate
these lists.
- Around line 24-26: Replace the ad-hoc torch.cuda checks in the test helper
get_cc and other GPU-arch skip logic (e.g., the code around get_cc and the
checks at lines ~130-136) with the flashinfer.utils functions: call
flashinfer.utils.get_compute_capability() instead of
torch.cuda.get_device_capability(), and use
flashinfer.utils.is_sm100a_supported() (or is_sm90a_supported() as appropriate)
to decide skips; update imports to pull these utilities and ensure skip
conditions use those boolean helpers rather than manual major/minor arithmetic.

---

Nitpick comments:
In `@flashinfer/norm/__init__.py`:
- Around line 597-603: The public API fused_rmsnorm_silu currently performs its
own SM capability gating; instead annotate the function with the repository's
capability decorator (use `@backend_requirement`) and ensure it supplies/uses the
required check methods (is_compute_capability_supported(cc) and
is_backend_supported()) so capability checks are centralized; update
fused_rmsnorm_silu (and the similar API around lines ~680-687) to remove or
delegate internal SM gating to the decorator, import and apply
backend_requirement to the function, and wire the two helper methods referenced
above so the decorator can perform the gating consistently for this public
entrypoint.

In `@tests/norm/test_fused_rmsnorm_silu.py`:
- Around line 91-108: The reference quantization loop currently moves each block
to CPU via block_vals = values_f32[:, col_start:col_end].cpu().float(), causing
repeated host-device transfers; keep computation on GPU by removing .cpu() and
ensuring block_vals is cast to float on the same device as values_f32 (e.g., use
.to(dtype=torch.float32, device=values_f32.device) or .float() while not calling
.cpu()), then perform amax, scale, scaled, magnitudes, signs, diffs, argmin
(mag_nibbles), and nibbles assignment entirely on the device so no per-block CPU
sync occurs; update references to block_vals and any intermediate tensors (amax,
scale, scaled, diffs, mag_nibbles, block_nibbles) to operate on GPU and only
move data to CPU once if/when needed outside the loop.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 619f6680-c27d-4cf7-8ca4-50cc28d039d3

📥 Commits

Reviewing files that changed from the base of the PR and between 5609758 and 3611dea.

📒 Files selected for processing (2)
  • flashinfer/norm/__init__.py
  • tests/norm/test_fused_rmsnorm_silu.py

@kahyunnam
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !501 has been updated with latest changes, and the CI pipeline #47663798 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/norm/__init__.py`:
- Around line 720-733: Return value retains a view into the temporary workspace
(block_scale) which keeps the whole scratch buffer alive; instead make a
standalone copy before returning. After slicing workspace into block_scale and
converting with .view(torch.float8_e4m3fn), replace the direct view-return with
creating an owned tensor (e.g., clone()/detach() and ensure contiguous memory)
preserving dtype and shape, then reshape to (num_tokens, num_blocks) and return
that copy alongside out so the scratch workspace can be released.
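The lifetime issue is generic to zero-copy views: a slice that aliases a large scratch buffer keeps the whole buffer reachable. A stdlib analogue using memoryview in place of a tensor view (illustrative only — in the PR the fix would be tensor.clone() or equivalent):

```python
def owned_copy_of_scale_row(workspace: bytearray, offset: int, nbytes: int) -> bytes:
    view = memoryview(workspace)[offset:offset + nbytes]  # zero-copy: pins workspace
    owned = bytes(view)  # standalone copy, analogous to tensor.clone()
    view.release()       # drop the alias so workspace can be freed independently
    return owned
```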

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: a6e96218-2924-40d8-8fc1-f1c10a1e79d1

📥 Commits

Reviewing files that changed from the base of the PR and between 3611dea and 23bc908.

📒 Files selected for processing (2)
  • flashinfer/norm/__init__.py
  • include/flashinfer/norm/ln_silu_headers.cuh
✅ Files skipped from review due to trivial changes (1)
  • include/flashinfer/norm/ln_silu_headers.cuh

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47663798: 10/20 passed

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
tests/norm/test_fused_rmsnorm_silu.py (1)

24-26: ⚠️ Potential issue | 🟡 Minor

Switch this fixture to the repo's SM100 skip helpers.

The ad-hoc get_device_capability()/< 100 check can still misclassify unsupported Blackwell variants, and the torch.cuda.is_available() branch hides misconfigured CUDA test jobs in this repo.

💡 Suggested fixture update
 import pytest
 import torch
 import torch.nn.functional as F
+from flashinfer.utils import get_compute_capability, is_sm100a_supported
 
-
-def get_cc():
-    major, minor = torch.cuda.get_device_capability()
-    return major * 10 + minor
-
-
 @pytest.fixture(autouse=True)
 def skip_if_not_sm100():
-    if not torch.cuda.is_available():
-        pytest.skip("CUDA not available")
-    if get_cc() < 100:
-        pytest.skip("Fused RMSNorm+SiLU requires SM100+")
+    if not is_sm100a_supported(get_compute_capability(torch.device("cuda"))):
+        pytest.skip("Fused RMSNorm+SiLU requires SM100a")
Based on learnings: "Tests in the repository assume CUDA is available and do not require torch.cuda.is_available() guards in pytest fixtures." As per coding guidelines: "Use flashinfer.utils functions (`get_compute_capability()`, `is_sm90a_supported()`, `is_sm100a_supported()`) to skip tests on unsupported GPU architectures"

Also applies to: 130-135

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/norm/test_fused_rmsnorm_silu.py` around lines 24 - 26, Replace the
ad-hoc get_cc() and any torch.cuda.is_available() guards with the repository
skip helpers: call flashinfer.utils.get_compute_capability() to obtain
capability and use flashinfer.utils.is_sm90a_supported()/is_sm100a_supported()
to decide skipping; update the fixture that defines get_cc() (and the similar
logic around the other block referenced) to import and use those helpers so
tests skip unsupported Blackwell/SM100 variants correctly and avoid masking
misconfigured CUDA jobs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/rmsnorm_silu.cu`:
- Around line 37-54: Detect the empty-input case (rows == 0) at the top of the
launcher and return early before computing launch geometry or constructing
reduced_divisor(rows); specifically, add a guard right after computing rows/cols
(and after any input/output size checks) that does a no-op return if rows == 0
to avoid the subsequent ctas_per_col math and reduced_divisor(rows) creation
(the same change should be applied to the analogous block around lines 112-114).
Ensure the early-return occurs before using device_guard/get_stream or building
grid dimensions (ctas_per_col) so the launcher is a defined no-op for empty
inputs.

---

Duplicate comments:
In `@tests/norm/test_fused_rmsnorm_silu.py`:
- Around line 24-26: Replace the ad-hoc get_cc() and any
torch.cuda.is_available() guards with the repository skip helpers: call
flashinfer.utils.get_compute_capability() to obtain capability and use
flashinfer.utils.is_sm90a_supported()/is_sm100a_supported() to decide skipping;
update the fixture that defines get_cc() (and the similar logic around the other
block referenced) to import and use those helpers so tests skip unsupported
Blackwell/SM100 variants correctly and avoid masking misconfigured CUDA jobs.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: d3195c17-ba72-4f93-89b1-ab3911ac2ac1

📥 Commits

Reviewing files that changed from the base of the PR and between 23bc908 and 92a2edd.

📒 Files selected for processing (4)
  • csrc/flashinfer_rmsnorm_silu_binding.cu
  • csrc/rmsnorm_silu.cu
  • flashinfer/norm/__init__.py
  • tests/norm/test_fused_rmsnorm_silu.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • csrc/flashinfer_rmsnorm_silu_binding.cu

@kahyunnam
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !501 has been updated with latest changes, and the CI pipeline #47873905 is currently running. I'll report back once the pipeline job completes.

Collaborator

@bkryu bkryu left a comment


Thanks @kahyunnam , can you add benchmark scripts or support in the microbenchmark harness? This will help us understand what level of performance we will be achieving with the kernels.

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

[FAILED] Pipeline #47873905: 10/20 passed

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
flashinfer/norm/__init__.py (1)

708-713: Defensive None check may be unreachable.

Based on select_knobs in flashinfer/jit/rmsnorm_silu.py, the function always returns either from the LUT or via _compute_default_knobs() fallback—it never returns None. This check is defensive but unreachable. Consider whether it's worth keeping for future-proofing or removing to avoid confusing readers.
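A sketch of the suggested cleanup, with a hypothetical lookup table and fallback standing in for the real select_knobs/_compute_default_knobs pair: because the fallback always yields concrete knobs, an assertion documents the invariant in place of a dead ValueError branch.

```python
_KNOB_LUT = {(8, 4096): {"rows_per_cta": 4}}  # hypothetical shape -> knobs table

def _compute_default_knobs(key):
    # Heuristic fallback; always produces a concrete knob dict.
    return {"rows_per_cta": 1}

def select_knobs(key):
    # LUT hit or heuristic fallback: this can never return None, so callers
    # can assert the invariant instead of raising an unreachable ValueError.
    return _KNOB_LUT.get(key) or _compute_default_knobs(key)

knobs = select_knobs((1, 12288))  # LUT miss -> fallback heuristic
assert knobs is not None and knobs["rows_per_cta"] == 1
```

The assert costs nothing in practice and tells future readers the None case is impossible by construction, not merely unhandled.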

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/norm/__init__.py` around lines 708 - 713, The None check after
calling select_knobs is unreachable because select_knobs always returns knobs
(from the LUT or via _compute_default_knobs); remove the defensive branch
raising ValueError to avoid confusion, or if you want to keep future-proofing,
replace the raise with an explicit assertion or a comment documenting that
select_knobs never returns None (referencing select_knobs and
_compute_default_knobs and the local variable knobs) so readers understand the
intent.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/norm/__init__.py`:
- Line 717: When calling _get_rmsnorm_silu_sm_count, guard against
input.device.index being None by resolving the actual CUDA device index first
(e.g., device_index = input.device.index if input.device.index is not None else
torch.cuda.current_device()) and pass that device_index into
_get_rmsnorm_silu_sm_count; this prevents torch.cuda.get_device_properties(None)
from being called and uses the current CUDA device when tensors were created
with device="cuda".
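The suggested guard reads roughly as below, with `current_device()` as a plain stand-in for torch.cuda.current_device() so the None-handling itself is what runs; the resolved index is what would then be passed to _get_rmsnorm_silu_sm_count.

```python
def current_device():
    # Stand-in for torch.cuda.current_device(); returns the active index.
    return 0

def resolve_device_index(device_index):
    # Tensors created with device="cuda" (no explicit index) report
    # device.index as None; resolve it before querying device properties,
    # so get_device_properties(None) is never called.
    return device_index if device_index is not None else current_device()

assert resolve_device_index(None) == 0  # bare "cuda" -> current device
assert resolve_device_index(3) == 3     # explicit index passes through
```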

---

Nitpick comments:
In `@flashinfer/norm/__init__.py`:
- Around line 708-713: The None check after calling select_knobs is unreachable
because select_knobs always returns knobs (from the LUT or via
_compute_default_knobs); remove the defensive branch raising ValueError to avoid
confusion, or if you want to keep future-proofing, replace the raise with an
explicit assertion or a comment documenting that select_knobs never returns None
(referencing select_knobs and _compute_default_knobs and the local variable
knobs) so readers understand the intent.


📥 Commits

Reviewing files that changed from the base of the PR and between 92a2edd and fc18fbb.

📒 Files selected for processing (2)
  • flashinfer/norm/__init__.py
  • tests/norm/test_fused_rmsnorm_silu.py

Collaborator

@bkryu bkryu left a comment


Hi @kahyunnam , forgot to mention the first time.

Can you add a link to the new fused_rmsnorm_silu in the documentation norm.rst?

@kahyunnam kahyunnam merged commit c2b4db2 into flashinfer-ai:main Apr 8, 2026
31 of 38 checks passed