Add flashinfer.fused_rmsnorm_silu() with native kernel backend #2965
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

Adds an SM100-targeted fused RMSNorm+SiLU implementation: new NVRTC-compatible headers and kernel, CUDA host entry and TVM FFI binding, JIT generator and AOT integration, Python API surface and re-export, workspace/knob logic, and SM100-gated tests for bf16/FP8/NVFP4 outputs.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User Code
    participant API as fused_rmsnorm_silu API
    participant JIT as JIT Module Generator
    participant Compiler as NVRTC/AOT Compiler
    participant Module as Compiled Module
    participant Kernel as ln_fwd_kernel
    participant GPU as GPU Hardware
    User->>API: fused_rmsnorm_silu(input, weight, eps, out, block_scale?)
    activate API
    API->>API: validate inputs, select_knobs, compute workspace
    API->>JIT: gen_rmsnorm_silu_module(config)
    activate JIT
    JIT->>JIT: generate config, copy CSRC, write inc
    JIT->>Compiler: request build (NVRTC/AOT)
    deactivate JIT
    Compiler-->>Module: compiled module
    API->>Module: load module
    API->>Module: module.rmsnorm_silu(..., workspace, scale_row_out, sm_count)
    Module->>Kernel: launch ln_fwd_kernel<<<grid,block>>>(params)
    activate Kernel
    Kernel->>GPU: init shared memory/barriers, compute stats
    GPU->>GPU: apply RMSNorm, SiLU, optional quantize/block-scale
    Kernel-->>Module: kernel complete
    deactivate Kernel
    Module-->>API: output (and optional block_scale)
    deactivate API
    API-->>User: return output (and optional block_scale)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces a fused RMSNorm and SiLU kernel, ported from the cuDNN frontend, to optimize performance for specific workloads on SM100+ architectures. The implementation includes JIT compilation support, a configuration lookup table for optimal kernel parameters, and comprehensive unit tests. I have provided feedback to ensure that the SM count is retrieved based on the input tensor's device rather than the current CUDA device, which is critical for multi-GPU support.
The C++ header sm100_rms_norm_silu_knobs.h was never included by any source file — all knob selection happens in Python at JIT compile time via flashinfer/jit/rmsnorm_silu.py. Keeping a duplicate 120-entry LUT in C++ was a maintenance burden with no benefit. AI-assisted. Made-with: Cursor
Force-pushed from e35a4db to f41322a
ln_fwd_silu_kernel.cuh requires Ktraits, PersistentLnFwdParams, and other types to be defined before inclusion. The correct order is:

1. ln_silu_headers.cuh (type definitions)
2. rmsnorm_silu_config.inc (Ktraits typedef, constexpr flags)
3. ln_fwd_silu_kernel.cuh (kernel using the above)

Protected with clang-format off/on since alphabetical sorting would break this dependency chain. AI-assisted. Made-with: Cursor
/bot run
Actionable comments posted: 3
🧹 Nitpick comments (3)
`flashinfer/jit/rmsnorm_silu.py` (1)

**339-347: Declare the supported SM majors on this JIT spec.**

`gen_rmsnorm_silu_module()` leaves `supported_major_versions` unset, so this backend has no explicit arch filter at the spec level. Please pass the validated major list here instead of relying on the caller's arch list to constrain compilation. As per coding guidelines: "Specify supported NVIDIA SM major versions in JIT modules using the `supported_major_versions` parameter to limit compilation to specific GPU architectures."

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `flashinfer/jit/rmsnorm_silu.py` around lines 339-347, `gen_rmsnorm_silu_module` currently calls `gen_jit_spec` without setting `supported_major_versions`, so add the validated SM major list to the `gen_jit_spec` call by passing `supported_major_versions=<validated_list>` (use the same validated list computed in this module, e.g., `supported_majors` or `validated_majors`) instead of relying on the caller's arch list; update the `gen_jit_spec` invocation where `uri`, `sources`, `extra_cuda_cflags`, and `extra_include_paths` are passed to include `supported_major_versions` to constrain compilation to the intended NVIDIA SM major versions.

`tests/norm/test_fused_rmsnorm_silu.py` (1)
**24-26: Use the shared GPU-capability helpers for skips.**

This file reimplements the arch gate with `torch.cuda.get_device_capability()` instead of the repo helpers. Please switch the fixture to `flashinfer.utils.get_compute_capability()` / `is_sm100a_supported()` so the skip semantics stay aligned with the rest of the suite. As per coding guidelines: "Use flashinfer.utils functions (`get_compute_capability()`, `is_sm90a_supported()`, `is_sm100a_supported()`) to skip tests on unsupported GPU architectures"

Also applies to: 130-135

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `tests/norm/test_fused_rmsnorm_silu.py` around lines 24-26, the test reimplements GPU arch checks with a local `get_cc()` that calls `torch.cuda.get_device_capability()`; replace that with the repo helpers by importing and using `flashinfer.utils.get_compute_capability()` (and/or the predicate helpers `is_sm100a_supported()` or `is_sm90a_supported()` as appropriate) for skip logic so the test's skip semantics match the rest of the suite. Specifically, remove or replace the local `get_cc()` function and any direct calls to `torch.cuda.get_device_capability()` with calls to `get_compute_capability()` or the boolean helpers (`is_sm100a_supported()`/`is_sm90a_supported()`) used where the test decides to skip; also update the other occurrence of the same pattern later in the file to use the same helpers.

`include/flashinfer/norm/ln_silu_headers.cuh` (1)
**767-769: Make the unsupported cluster branch fail explicitly.**

Lines 768 and 1095 use `static_assert(true, ...)`, which is a no-op. If `USE_CLUSTER` is ever instantiated on an unsupported toolkit/arch, these branches stop returning a value and the failure becomes much harder to understand.

♻️ Suggested guard

```diff
-    static_assert(true, "Cluster enabled on host side but not available on device");
+    static_assert(!USE_CLUSTER,
+                  "Cluster enabled on host side but not available on device");
```

Based on learnings, `static_assert`-based constraints are intentionally kept in the CUDA header close to the implementation for easier auditability.

Also applies to: 1094-1096
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/norm/__init__.py`:
- Around line 662-669: The NVFP4 branch in flashinfer.norm.__init__.py (the
output_dtype_str == "nvfp4" path that validates expected_shape for out and uses
variables num_tokens, C, out and workspace) returns packed FP4 nibbles without
the per-block scale (scale_row) written into workspace by the kernel; update the
API to either surface and return the scale tensor alongside out (read the scale
data from workspace and return a tuple like (out, scale_row) or similar) or
explicitly raise a ValueError rejecting "nvfp4" outputs until scale metadata can
be returned; apply the same change to the other NVFP4 validation block
referenced around lines 705-708 so callers receive scale information or the
dtype is disallowed.
In `@include/flashinfer/norm/ln_silu_headers.cuh`:
- Around line 258-270: The pre-SM80 fallback in struct Converter<float2,
nv_bfloat162>::convert uses a union whose nv_bfloat16 members overlap, so
assigning tmp.x then tmp.y clobbers the first lane; fix by replacing the union
layout so the two nv_bfloat16 lanes occupy distinct storage (e.g., use a struct
or an array like nv_bfloat16 lanes[2] alongside the nv_bfloat162 raw
representation) and assign lanes[0] = __float2bfloat16_rn(x.x); lanes[1] =
__float2bfloat16_rn(x.y); then return the raw nv_bfloat162; update the
auto-generator template the same way so generated headers get the corrected
non-overlapping lane assignments.
- Around line 1283-1297: The clz function uses a signed left-shift (1 << i)
which is undefined for i==31; change clz to operate with unsigned masks by
converting the input to uint32_t (or changing the parameter to uint32_t) and use
1u (or uint32_t(1)) for the shift and comparisons so (1u << i) & ux is used
instead of ((1 << i) & x); keep the return semantics the same so
find_log_2(int32_t, bool) can continue calling clz unchanged.
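Both memory-level findings above can be modeled in plain Python as a sanity check: a `ctypes.Union` mirrors the buggy converter layout where the two `nv_bfloat16` lanes alias the same storage (here `uint16` stands in for `nv_bfloat16`), and an integer mask plays the role of the `uint32_t` cast in the `clz` fix. This is an illustrative sketch, not the repo's code.

```python
import ctypes

class OverlappingLanes(ctypes.Union):
    # Mirrors the buggy layout: both members share offset 0, so writing y clobbers x.
    _fields_ = [("x", ctypes.c_uint16), ("y", ctypes.c_uint16)]

class DistinctLanes(ctypes.Structure):
    # Mirrors the fix: each lane gets its own storage (offsets 0 and 2).
    _fields_ = [("x", ctypes.c_uint16), ("y", ctypes.c_uint16)]

bad, good = OverlappingLanes(), DistinctLanes()
bad.x, bad.y = 1, 2    # the second write overwrites the first lane
good.x, good.y = 1, 2  # both lanes survive independently

def clz32(x: int) -> int:
    """Count leading zeros of a 32-bit value, scanning with unsigned masks."""
    ux = x & 0xFFFFFFFF              # role of the uint32_t cast in the C fix
    for i in range(31, -1, -1):
        if (1 << i) & ux:            # in C: (1u << i) & ux, avoiding the i == 31 shift UB
            return 31 - i
    return 32

print(bad.x, good.x, clz32(1), clz32(0x80000000))  # -> 2 1 31 0
```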
---
Nitpick comments:
In `@flashinfer/jit/rmsnorm_silu.py`:
- Around line 339-347: gen_rmsnorm_silu_module currently calls gen_jit_spec
without setting supported_major_versions, so add the validated SM major list to
the gen_jit_spec call by passing supported_major_versions=<validated_list> (use
the same validated list computed in this module, e.g., supported_majors or
validated_majors) instead of relying on the caller's arch list; update the
gen_jit_spec invocation where uri, sources, extra_cuda_cflags, and
extra_include_paths are passed to include supported_major_versions to constrain
compilation to the intended NVIDIA SM major versions.
In `@tests/norm/test_fused_rmsnorm_silu.py`:
- Around line 24-26: The test reimplements GPU arch checks with a local get_cc()
that calls torch.cuda.get_device_capability(); replace that with the repo
helpers by importing and using flashinfer.utils.get_compute_capability() (and/or
the predicate helpers is_sm100a_supported() or is_sm90a_supported() as
appropriate) for skip logic so the test's skip semantics match the rest of the
suite—specifically, remove or replace the local get_cc() function and any direct
calls to torch.cuda.get_device_capability() with calls to
get_compute_capability() or the boolean helpers
(is_sm100a_supported()/is_sm90a_supported()) used where the test decides to
skip; also update the other occurrence of the same pattern later in the file to
use the same helpers.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c95b7559-ce27-4b87-b226-19ebc642a70a
📒 Files selected for processing (9)
- csrc/flashinfer_rmsnorm_silu_binding.cu
- csrc/rmsnorm_silu.cu
- flashinfer/__init__.py
- flashinfer/aot.py
- flashinfer/jit/rmsnorm_silu.py
- flashinfer/norm/__init__.py
- include/flashinfer/norm/ln_fwd_silu_kernel.cuh
- include/flashinfer/norm/ln_silu_headers.cuh
- tests/norm/test_fused_rmsnorm_silu.py
[FAILED] Pipeline #47591827: 10/20 passed
Actionable comments posted: 3
🧹 Nitpick comments (2)
`flashinfer/norm/__init__.py` (1)

**597-603: Prefer `backend_requirement` for capability-gated public API.**

`fused_rmsnorm_silu` has explicit SM gating logic but is not wired through the repository's API capability-decorator pattern. Aligning this API with `backend_requirement` keeps discoverability and capability checks consistent across public entrypoints. Based on learnings: "Applies to flashinfer/*.py : Use the `backend_requirement` decorator on APIs that have compute capability requirements and provide `is_compute_capability_supported(cc)` and `is_backend_supported()` methods".

Also applies to: 680-687
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `flashinfer/norm/__init__.py` around lines 597-603, the public API fused_rmsnorm_silu currently performs its own SM capability gating; instead annotate the function with the repository's capability decorator (use `@backend_requirement`) and ensure it supplies/uses the required check methods (is_compute_capability_supported(cc) and is_backend_supported()) so capability checks are centralized; update fused_rmsnorm_silu (and the similar API around lines ~680-687) to remove or delegate internal SM gating to the decorator, import and apply backend_requirement to the function, and wire the two helper methods referenced above so the decorator can perform the gating consistently for this public entrypoint.

`tests/norm/test_fused_rmsnorm_silu.py` (1)
**91-108: Avoid per-block CPU transfers in FP4 reference quantization.**

`block_vals = ... .cpu().float()` inside the block loop causes repeated host-device sync/copies and dominates test runtime at large shapes. Keep this path on GPU to reduce runtime and flakiness.

Suggested change

```diff
-        block_vals = values_f32[:, col_start:col_end].cpu().float()
+        block_vals = values_f32[:, col_start:col_end].float()
         ...
-        nibbles[:, col_start:col_end] = block_nibbles.to(values_f32.device)
+        nibbles[:, col_start:col_end] = block_nibbles
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/norm/test_fused_rmsnorm_silu.py` around lines 91 - 108, The reference quantization loop currently moves each block to CPU via block_vals = values_f32[:, col_start:col_end].cpu().float(), causing repeated host-device transfers; keep computation on GPU by removing .cpu() and ensuring block_vals is cast to float on the same device as values_f32 (e.g., use .to(dtype=torch.float32, device=values_f32.device) or .float() while not calling .cpu()), then perform amax, scale, scaled, magnitudes, signs, diffs, argmin (mag_nibbles), and nibbles assignment entirely on the device so no per-block CPU sync occurs; update references to block_vals and any intermediate tensors (amax, scale, scaled, diffs, mag_nibbles, block_nibbles) to operate on GPU and only move data to CPU once if/when needed outside the loop.
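For context, the per-block NVFP4 reference quantization being discussed can be sketched in plain Python. The e2m1 magnitude grid and amax/6 scaling follow the standard FP4 convention; the function and variable names here are illustrative stand-ins, not the test's torch code.

```python
E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # representable FP4 (e2m1) magnitudes

def quantize_block(vals):
    """Quantize one block: per-block scale from amax, nearest-magnitude nibble per value."""
    amax = max(abs(v) for v in vals)
    scale = amax / 6.0 if amax > 0 else 1.0  # map the block's peak onto the largest code
    nibbles, signs = [], []
    for v in vals:
        s = v / scale
        # nearest representable magnitude, mirroring the test's argmin over diffs
        nibbles.append(min(range(len(E2M1)), key=lambda i, m=abs(s): abs(m - E2M1[i])))
        signs.append(s < 0)
    return nibbles, signs, scale

print(quantize_block([6.0, 3.0, -1.5, 0.4]))  # -> ([7, 5, 3, 1], [False, False, True, False], 1.0)
```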
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/norm/__init__.py`:
- Around line 632-636: The docstring for NVFP4 currently claims block_scale has
shape (num_tokens, hidden_size // 16) but the implementation computes num_blocks
= (C + 15) // 16 and returns block_scale with shape (num_tokens,
ceil(hidden_size / 16)); update the docstrings in __init__.py for the NVFP4
sections (the docstring around the NVFP4 description and the similar text near
lines showing the rmsnorm_fp4quant convention) to state block_scale shape as
(num_tokens, ceil(hidden_size / 16)) (or explicitly note num_blocks =
(hidden_size + 15) // 16) to match the implementation, or alternatively enforce
C % 16 == 0 in the code—pick the documentation change to keep behavior stable
and reference rmsnorm_fp4quant and the NVFP4 description when making the edit.
In `@tests/norm/test_fused_rmsnorm_silu.py`:
- Around line 138-141: The test matrix ALL_LUT_SHAPES (built from SUPPORTED_C
and SUPPORTED_TOKENS) is too large for CI; limit default CI to a small smoke
subset and move exhaustive combinations behind a slow marker. Create a
SMALL_SMOKE_LUT_SHAPES (e.g., pick 2 C values and 2 small token values) and
replace uses of ALL_LUT_SHAPES in the default parametrized tests with this smoke
list, and add a new EXHAUSTIVE_LUT_SHAPES = ALL_LUT_SHAPES that is used only in
tests decorated with pytest.mark.slow (or a custom marker) to run bf16/fp8/nvfp4
coverage in extended runs; update references to SUPPORTED_C, SUPPORTED_TOKENS,
and ALL_LUT_SHAPES accordingly in the functions/tests that currently iterate
these lists.
- Around line 24-26: Replace the ad-hoc torch.cuda checks in the test helper
get_cc and other GPU-arch skip logic (e.g., the code around get_cc and the
checks at lines ~130-136) with the flashinfer.utils functions: call
flashinfer.utils.get_compute_capability() instead of
torch.cuda.get_device_capability(), and use
flashinfer.utils.is_sm100a_supported() (or is_sm90a_supported() as appropriate)
to decide skips; update imports to pull these utilities and ensure skip
conditions use those boolean helpers rather than manual major/minor arithmetic.
---
Nitpick comments:
In `@flashinfer/norm/__init__.py`:
- Around line 597-603: The public API fused_rmsnorm_silu currently performs its
own SM capability gating; instead annotate the function with the repository's
capability decorator (use `@backend_requirement`) and ensure it supplies/uses the
required check methods (is_compute_capability_supported(cc) and
is_backend_supported()) so capability checks are centralized; update
fused_rmsnorm_silu (and the similar API around lines ~680-687) to remove or
delegate internal SM gating to the decorator, import and apply
backend_requirement to the function, and wire the two helper methods referenced
above so the decorator can perform the gating consistently for this public
entrypoint.
In `@tests/norm/test_fused_rmsnorm_silu.py`:
- Around line 91-108: The reference quantization loop currently moves each block
to CPU via block_vals = values_f32[:, col_start:col_end].cpu().float(), causing
repeated host-device transfers; keep computation on GPU by removing .cpu() and
ensuring block_vals is cast to float on the same device as values_f32 (e.g., use
.to(dtype=torch.float32, device=values_f32.device) or .float() while not calling
.cpu()), then perform amax, scale, scaled, magnitudes, signs, diffs, argmin
(mag_nibbles), and nibbles assignment entirely on the device so no per-block CPU
sync occurs; update references to block_vals and any intermediate tensors (amax,
scale, scaled, diffs, mag_nibbles, block_nibbles) to operate on GPU and only
move data to CPU once if/when needed outside the loop.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 619f6680-c27d-4cf7-8ca4-50cc28d039d3
📒 Files selected for processing (2)
- flashinfer/norm/__init__.py
- tests/norm/test_fused_rmsnorm_silu.py
/bot run
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/norm/__init__.py`:
- Around line 720-733: Return value retains a view into the temporary workspace
(block_scale) which keeps the whole scratch buffer alive; instead make a
standalone copy before returning. After slicing workspace into block_scale and
converting with .view(torch.float8_e4m3fn), replace the direct view-return with
creating an owned tensor (e.g., clone()/detach() and ensure contiguous memory)
preserving dtype and shape, then reshape to (num_tokens, num_blocks) and return
that copy alongside out so the scratch workspace can be released.
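The lifetime point can be illustrated without torch, using a `memoryview` as a stand-in for a tensor view into the workspace (an analogy sketch; `tensor.clone()` is the torch-side equivalent of the copy below):

```python
workspace = bytearray(1 << 20)             # stand-in for the large scratch buffer
block_scale = memoryview(workspace)[:32]   # slicing yields a view...
assert block_scale.obj is workspace        # ...that pins the entire workspace alive
owned = bytes(block_scale)                 # an owned copy, analogous to tensor.clone()
block_scale.release()
del workspace                              # the scratch buffer can now be freed
print(len(owned))                          # -> 32
```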
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a6e96218-2924-40d8-8fc1-f1c10a1e79d1
📒 Files selected for processing (2)
- flashinfer/norm/__init__.py
- include/flashinfer/norm/ln_silu_headers.cuh
✅ Files skipped from review due to trivial changes (1)
- include/flashinfer/norm/ln_silu_headers.cuh
[FAILED] Pipeline #47663798: 10/20 passed
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/norm/test_fused_rmsnorm_silu.py (1)
**24-26:** ⚠️ Potential issue | 🟡 Minor

Switch this fixture to the repo's SM100 skip helpers.

The ad-hoc `get_device_capability()` / `< 100` check can still misclassify unsupported Blackwell variants, and the `torch.cuda.is_available()` branch hides misconfigured CUDA test jobs in this repo. Based on learnings: "Tests in the repository assume CUDA is available and do not require torch.cuda.is_available() guards in pytest fixtures." As per coding guidelines: "Use flashinfer.utils functions (`get_compute_capability()`, `is_sm90a_supported()`, `is_sm100a_supported()`) to skip tests on unsupported GPU architectures"

💡 Suggested fixture update

```diff
 import pytest
 import torch
 import torch.nn.functional as F
+from flashinfer.utils import get_compute_capability, is_sm100a_supported
-
-
-def get_cc():
-    major, minor = torch.cuda.get_device_capability()
-    return major * 10 + minor
-
-
 @pytest.fixture(autouse=True)
 def skip_if_not_sm100():
-    if not torch.cuda.is_available():
-        pytest.skip("CUDA not available")
-    if get_cc() < 100:
-        pytest.skip("Fused RMSNorm+SiLU requires SM100+")
+    if not is_sm100a_supported(get_compute_capability(torch.device("cuda"))):
+        pytest.skip("Fused RMSNorm+SiLU requires SM100a")
```

Also applies to: 130-135
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/norm/test_fused_rmsnorm_silu.py` around lines 24 - 26, Replace the ad-hoc get_cc() and any torch.cuda.is_available() guards with the repository skip helpers: call flashinfer.utils.get_compute_capability() to obtain capability and use flashinfer.utils.is_sm90a_supported()/is_sm100a_supported() to decide skipping; update the fixture that defines get_cc() (and the similar logic around the other block referenced) to import and use those helpers so tests skip unsupported Blackwell/SM100 variants correctly and avoid masking misconfigured CUDA jobs.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@csrc/rmsnorm_silu.cu`:
- Around line 37-54: Detect the empty-input case (rows == 0) at the top of the
launcher and return early before computing launch geometry or constructing
reduced_divisor(rows); specifically, add a guard right after computing rows/cols
(and after any input/output size checks) that does a no-op return if rows == 0
to avoid the subsequent ctas_per_col math and reduced_divisor(rows) creation
(the same change should be applied to the analogous block around lines 112-114).
Ensure the early-return occurs before using device_guard/get_stream or building
grid dimensions (ctas_per_col) so the launcher is a defined no-op for empty
inputs.
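A minimal model of the suggested guard (the function name and grid heuristic are invented for illustration; the real launcher and `reduced_divisor` live in `csrc/rmsnorm_silu.cu`):

```python
def launch_grid(rows: int, sm_count: int = 148) -> int:
    """Return how many CTAs a persistent launcher like this would schedule."""
    if rows == 0:
        return 0                 # defined no-op: skip grid math and reduced_divisor(rows)
    inverse = 1.0 / rows         # reduced_divisor(rows) precomputes a fast division like this
    ctas_per_col = min(sm_count, rows)  # hypothetical heuristic, capped by available SMs
    return ctas_per_col

print(launch_grid(0), launch_grid(7), launch_grid(10_000))  # -> 0 7 148
```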
---
Duplicate comments:
In `@tests/norm/test_fused_rmsnorm_silu.py`:
- Around line 24-26: Replace the ad-hoc get_cc() and any
torch.cuda.is_available() guards with the repository skip helpers: call
flashinfer.utils.get_compute_capability() to obtain capability and use
flashinfer.utils.is_sm90a_supported()/is_sm100a_supported() to decide skipping;
update the fixture that defines get_cc() (and the similar logic around the other
block referenced) to import and use those helpers so tests skip unsupported
Blackwell/SM100 variants correctly and avoid masking misconfigured CUDA jobs.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: d3195c17-ba72-4f93-89b1-ab3911ac2ac1
📒 Files selected for processing (4)
- csrc/flashinfer_rmsnorm_silu_binding.cu
- csrc/rmsnorm_silu.cu
- flashinfer/norm/__init__.py
- tests/norm/test_fused_rmsnorm_silu.py
🚧 Files skipped from review as they are similar to previous changes (1)
- csrc/flashinfer_rmsnorm_silu_binding.cu
/bot run
**bkryu** left a comment
Thanks @kahyunnam, can you add benchmark scripts or support in the microbenchmark harness? This will help us understand the level of performance we will achieve with these kernels.
[FAILED] Pipeline #47873905: 10/20 passed
Actionable comments posted: 1
🧹 Nitpick comments (1)
`flashinfer/norm/__init__.py` (1)

**708-713: Defensive None check may be unreachable.**

Based on `select_knobs` in `flashinfer/jit/rmsnorm_silu.py`, the function always returns either from the LUT or via the `_compute_default_knobs()` fallback; it never returns `None`. This check is defensive but unreachable. Consider whether it's worth keeping for future-proofing or removing to avoid confusing readers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/norm/__init__.py` around lines 708 - 713, The None check after calling select_knobs is unreachable because select_knobs always returns knobs (from the LUT or via _compute_default_knobs); remove the defensive branch raising ValueError to avoid confusion, or if you want to keep future-proofing, replace the raise with an explicit assertion or a comment documenting that select_knobs never returns None (referencing select_knobs and _compute_default_knobs and the local variable knobs) so readers understand the intent.
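The control flow described above (LUT hit or computed fallback, never `None`) has roughly this shape; every name below is a hypothetical stand-in for the helpers in `flashinfer/jit/rmsnorm_silu.py`, and the knob values are invented for illustration:

```python
_KNOB_LUT = {
    # (hidden_size, num_tokens) -> tuned launch knobs; stand-in for the real 120-entry table
    (4096, 128): {"ctas_per_row": 1, "vec_size": 8},
}

def _compute_default_knobs(hidden_size: int, num_tokens: int) -> dict:
    # heuristic fallback for shapes missing from the LUT
    return {"ctas_per_row": max(1, hidden_size // 4096), "vec_size": 4}

def select_knobs(hidden_size: int, num_tokens: int) -> dict:
    # LUT hit or computed default: returning None is unreachable
    knobs = _KNOB_LUT.get((hidden_size, num_tokens))
    return knobs if knobs is not None else _compute_default_knobs(hidden_size, num_tokens)
```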
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/norm/__init__.py`:
- Line 717: When calling _get_rmsnorm_silu_sm_count, guard against
input.device.index being None by resolving the actual CUDA device index first
(e.g., device_index = input.device.index if input.device.index is not None else
torch.cuda.current_device()) and pass that device_index into
_get_rmsnorm_silu_sm_count; this prevents torch.cuda.get_device_properties(None)
from being called and uses the current CUDA device when tensors were created
with device="cuda".
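The device-index fallback reads as follows in a torch-free sketch (SM counts and helper names are illustrative; the real code would query `torch.cuda.get_device_properties` and `torch.cuda.current_device`):

```python
_SM_COUNTS = {0: 148, 1: 132}  # illustrative multiprocessor counts per CUDA device

def resolve_device_index(tensor_device_index, current_device: int) -> int:
    # tensors created with device="cuda" report index None; fall back to the current device
    return tensor_device_index if tensor_device_index is not None else current_device

def sm_count(tensor_device_index, current_device: int = 0) -> int:
    return _SM_COUNTS[resolve_device_index(tensor_device_index, current_device)]

print(sm_count(None), sm_count(1))  # -> 148 132
```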
---
Nitpick comments:
In `@flashinfer/norm/__init__.py`:
- Around line 708-713: The None check after calling select_knobs is unreachable
because select_knobs always returns knobs (from the LUT or via
_compute_default_knobs); remove the defensive branch raising ValueError to avoid
confusion, or if you want to keep future-proofing, replace the raise with an
explicit assertion or a comment documenting that select_knobs never returns None
(referencing select_knobs and _compute_default_knobs and the local variable
knobs) so readers understand the intent.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 146e8b28-fa53-4bb2-8e6e-5ff0df3dabaa
📒 Files selected for processing (2)
- flashinfer/norm/__init__.py
- tests/norm/test_fused_rmsnorm_silu.py
**bkryu** left a comment
Hi @kahyunnam, forgot to mention the first time.
Can you add a link to the new fused_rmsnorm_silu in the documentation norm.rst?
📌 Description
Originally, this kernel was open-sourced into cuDNN OSS and integrated here: #2691.
However, cuDNN OSS has no native support for an on-disk cache or precompiled PyPI wheels, which limits end-to-end performance because dynamic shapes would not be supported. After scoping out the internal process for releasing a new PyPI wheel, it was decided that this would take too much time.
In this PR, I move the kernel directly into FlashInfer so that we can reuse the existing JIT cache and wheel-packaging architecture.
🔍 Related Issues
Issue #2571
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks against all files with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (e.g., `unittest`, etc.).

Reviewer Notes