
fix: use sym_int64 for strides in rmsnorm CuTe DSL kernels to prevent int32 overflow#3007

Merged

bkryu merged 3 commits into flashinfer-ai:main from bkryu:norm_stride_overflow_fix on Apr 8, 2026

Conversation

@bkryu
Collaborator

@bkryu bkryu commented Apr 7, 2026

📌 Description

Summary:

  • Fix integer overflow in rmsnorm CuTe DSL kernels when tensor strides exceed INT32_MAX (~2.1B)
  • Change cute.sym_int() → cute.sym_int64() for all stride symbols across 5 compiled kernel functions in rmsnorm.py and fused_add_rmsnorm.py
  • Add 5 regression tests using torch.as_strided with stride = 2^31

Details
Fixes #3005

Tensor strides are products of higher dimensions and can exceed int32 range for large sequence lengths. The WAN 2.2 model (hidden_dim=5120, QKV fused projection) hits this at seq_len > 139,810 where batch_stride = 15360 × seq_len overflows.
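For reference, the arithmetic behind those numbers (pure Python, no GPU needed):

```python
# Quick check of the overflow threshold quoted above.
INT32_MAX = 2**31 - 1   # 2,147,483,647
qkv_cols = 3 * 5120     # 15,360 columns in the fused QKV projection at hidden_dim=5120

assert qkv_cols * 139_810 <= INT32_MAX  # longest sequence whose batch stride still fits in int32
assert qkv_cols * 139_811 > INT32_MAX   # one token longer and the batch stride overflows
```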

The overflow is caught by TVM-FFI before kernel launch (see the previously failing test output below).

Added a unit test to detect this case. Without the fix, the test fails:

________________________________________________________________________________________________________________________________________________ test_qknorm_int64_stride _________________________________________________________________________________________________________________________________________________

    def test_qknorm_int64_stride():
        """3D qk_rmsnorm with batch stride > INT32_MAX (issue #3005)."""
        num_heads, head_dim = 4, 128
        dtype = torch.bfloat16
        buf = torch.randn(1, num_heads, head_dim, dtype=dtype, device="cuda")
        w = torch.randn(head_dim, dtype=dtype, device="cuda")
    
        x = torch.as_strided(
            buf, (1, num_heads, head_dim), (_INT64_STRIDE, head_dim, 1)
        )
>       y = flashinfer.norm.rmsnorm(x, w)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

../tests/utils/test_norm.py:385: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../flashinfer/norm/__init__.py:125: in rmsnorm
    _rmsnorm_impl(out, input, weight, eps, enable_pdl)
../flashinfer/norm/__init__.py:143: in _rmsnorm_impl
    qk_rmsnorm_cute(
../flashinfer/norm/kernels/rmsnorm.py:1337: in qk_rmsnorm_cute
    kernel(input, weight, output, batch_size, num_heads, eps)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

>   ???
E   ValueError: Out of bound mX.strides[0] on argument #0 when calling: `__call__(mX: Tensor([n0, n1, 128], bfloat16), mW: Tensor([128], bfloat16), mY: Tensor([n0, n1, 128], bfloat16), B: int32, N: int32, eps: float32)`, expected to be in int32 range [-2147483648, 2147483647]

python/tvm_ffi/cython/function.pxi:929: ValueError
================================================================================================================================================= short test summary info =================================================================================================================================================
FAILED ../tests/utils/test_norm.py::test_qknorm_int64_stride - ValueError: Out of bound mX.strides[0] on argument #0 when calling: `__call__(mX: Tensor([n0, n1, 128], bfloat16), mW: Tensor([128], bfloat16), mY: Tensor([n0, n1, 128], bfloat16), B: int32, N: int32, eps: float32)`, expected to be in int32 range [-2147483648, 2147483647]

With the fix, the test passes:

2710 passed in 86.08s (0:01:26)

🔍 Related Issues

#3005

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Improved normalization kernel robustness for tensors with very large stride values to prevent failures on strided views.
  • Tests

    • Added regression tests for RMSNorm, QK-RMSNorm, quantized RMSNorm, and fused normalization variants covering large-stride edge cases; tests skip automatically when GPU memory is insufficient.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 7, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c09f69e8-8163-48c2-b8dc-912f60f6f262

📥 Commits

Reviewing files that changed from the base of the PR and between d7d687856bd2c3a517ee27720e9bd7ea3e8c34ad and 626e2fe.

📒 Files selected for processing (1)
  • tests/utils/test_norm.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/utils/test_norm.py

📝 Walkthrough

Updated CuTe DSL RMSNorm and fused-add+RMSNorm kernel compilation to use 64-bit symbolic stride integers for non-contiguous (strided) tensor paths; added regression tests that exercise kernels with strides exceeding int32 limits.

Changes

  • RMSNorm kernels (flashinfer/norm/kernels/rmsnorm.py): Switched symbolic stride/int symbols from cute.sym_int(...) to cute.sym_int64(...) in the non-contiguous ("strided") compilation branches of _get_compiled_rmsnorm_kernel, _get_compiled_qk_rmsnorm_kernel, and _get_compiled_rmsnorm_quant_kernel.
  • Fused add + RMSNorm kernels (flashinfer/norm/kernels/fused_add_rmsnorm.py): Switched symbolic stride/int symbols from cute.sym_int(...) to cute.sym_int64(...) in the non-contiguous compilation paths for the fused add+rmsnorm and fused add+rmsnorm+quant kernels; contiguous branches unchanged.
  • Regression tests, int64 strides (tests/utils/test_norm.py): Added tests and helpers that construct torch.as_strided views with row/batch strides >= 2**31 and conditionally skip when GPU VRAM is insufficient; they validate rmsnorm, qk-rmsnorm, quantized rmsnorm, fused_add_rmsnorm, and fused_add_rmsnorm_quant against reference implementations.
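A minimal sketch of what such a skip-guarded helper could look like (the helper name, constant, and exact skip condition are illustrative assumptions, not necessarily what tests/utils/test_norm.py does):

```python
import pytest
import torch

_INT64_STRIDE = 2**31  # leading stride just past the int32 range


def _int64_stride_view(rows: int, hidden: int, dtype=torch.bfloat16):
    """Return a (rows, hidden) CUDA view whose leading stride is 2**31,
    skipping the test when the backing buffer would not fit in free VRAM."""
    needed_elems = (rows - 1) * _INT64_STRIDE + hidden
    needed_bytes = needed_elems * torch.finfo(dtype).bits // 8
    free_bytes, _total = torch.cuda.mem_get_info()
    if needed_bytes > free_bytes:
        pytest.skip("not enough GPU memory for the int64-stride buffer")
    buf = torch.randn(needed_elems, dtype=dtype, device="cuda")
    return torch.as_strided(buf, (rows, hidden), (_INT64_STRIDE, 1))
```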

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

  • PR #2243: Related to fused_add_rmsnorm(_quant) kernel work; this change extends those compilation paths to use 64-bit symbolic strides.

Suggested reviewers

  • aleozlx
  • sricketts
  • yzh119
  • cyx-6
  • jimmyzho
  • kahyunnam
  • nv-yunzheq
  • samuellees

Poem

🐰
Stride grew long across the field,
Thirty-two could not hold the yield.
I hopped and changed the symbolic sign,
Now sixty-four keeps tensors fine.
Hooray — big sequences dance in line! 🥕

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
  • Title check ✅ Passed: The PR title clearly and specifically describes the main change: using sym_int64 for strides in rmsnorm CuTe DSL kernels to fix int32 overflow.
  • Description check ✅ Passed: The PR description comprehensively addresses the template sections with a clear problem statement, root-cause analysis, solution details, a linked issue reference, and regression test evidence demonstrating the fix works.
  • Linked Issues check ✅ Passed: All coding objectives from issue #3005 are met: stride symbols changed from cute.sym_int to cute.sym_int64 across five compiled kernels, and comprehensive regression tests added to validate int64-range strides.
  • Out of Scope Changes check ✅ Passed: All changes are directly scoped to fixing the int32 overflow issue: modifications to two kernel files plus test additions for regression validation, with no unrelated changes present.
  • Docstring Coverage ✅ Passed: Docstring coverage is 91.67%, above the required threshold of 80.00%.



@bkryu bkryu self-assigned this Apr 7, 2026
@bkryu bkryu added the v0.6.8 release blocker and op: norm labels Apr 7, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request updates the symbolic integer types for tensor strides from 32-bit to 64-bit in the RMSNorm and fused add-RMSNorm kernels to prevent potential integer overflow issues. Additionally, several regression tests have been added to verify that these kernels correctly handle strides exceeding the INT32_MAX limit. I have no feedback to provide.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/utils/test_norm.py`:
- Around line 405-447: Both tests use torch.as_strided with shape (1, H) which
can appear contiguous and route to the contiguous kernel; instead allocate
minimal two-row buffers and create a 2-row as_strided view so the non-contiguous
code path is taken. Concretely: in test_fused_add_rmsnorm_int64_stride and
test_fused_add_rmsnorm_quant_int64_stride, create buf_x and buf_r with shape (2,
H) (dtype/device unchanged), compute reference outputs from the first-row clones
(use fused_add_rms_norm / fused_add_rms_norm_quant on buf_x[:1].clone(),
buf_r[:1].clone()), then make x = torch.as_strided(buf_x, (2, H),
(_INT64_STRIDE, 1)) and r = torch.as_strided(buf_r, (2, H), (_INT64_STRIDE, 1)),
call flashinfer.fused_add_rmsnorm or flashinfer.norm.fused_add_rmsnorm_quant on
those views, and compare the first-row of results (x[0], r[0] or y[0]) to the
reference outputs; this ensures the non-contiguous kernel is exercised without
changing kernel-selection logic.
- Around line 361-372: The test test_rmsnorm_int64_stride currently uses a shape
(1, H) so PyTorch considers it contiguous and doesn't exercise the
non-contiguous/sym_int64 path; change the buffer and shaped tensor creation to
use two rows (e.g., buf sized (2, H) and x with shape (2, H) using
torch.as_strided and (_INT64_STRIDE, 1)) so is_contiguous() is False, forcing
the non-contiguous kernel path and exercising the symbolic int64 stride handling
when calling flashinfer.norm.rmsnorm and comparing with llama_rms_norm.
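For concreteness, a two-row rmsnorm test along the lines suggested above might look like the sketch below (it needs roughly 4 GiB of free VRAM for the backing buffer; the inline float32 reference and the eps value are assumptions standing in for the test file's own llama_rms_norm helper):

```python
import torch
import flashinfer

_INT64_STRIDE = 2**31
H = 128

buf = torch.randn(_INT64_STRIDE + H, dtype=torch.bfloat16, device="cuda")
w = torch.randn(H, dtype=torch.bfloat16, device="cuda")

# Two rows make the view genuinely non-contiguous, so the strided
# (sym_int64) kernel branch is the one exercised.
x = torch.as_strided(buf, (2, H), (_INT64_STRIDE, 1))
assert not x.is_contiguous()

y = flashinfer.norm.rmsnorm(x, w)

# Plain float32 RMSNorm reference (eps assumed to match the kernel default).
xf = x.float()
ref = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + 1e-6) * w.float()
torch.testing.assert_close(y.float(), ref, rtol=2e-2, atol=2e-2)
```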

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c4f99d51-34b9-401f-ac57-995a138170b7

📥 Commits

Reviewing files that changed from the base of the PR and between edcef4b and d7d687856bd2c3a517ee27720e9bd7ea3e8c34ad.

📒 Files selected for processing (3)
  • flashinfer/norm/kernels/fused_add_rmsnorm.py
  • flashinfer/norm/kernels/rmsnorm.py
  • tests/utils/test_norm.py

@bkryu bkryu force-pushed the norm_stride_overflow_fix branch from d7d6878 to 01fcf9d on April 7, 2026 at 21:43
@bkryu
Collaborator Author

bkryu commented Apr 7, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !521 has been created, and the CI pipeline #47961615 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #47961615: 9/20 passed

@bkryu bkryu merged commit f0edd77 into flashinfer-ai:main Apr 8, 2026
39 checks passed
@bkryu bkryu deleted the norm_stride_overflow_fix branch April 8, 2026 03:07

Labels

op: norm, v0.6.8 release blocker

Projects

None yet

Development

Successfully merging this pull request may close these issues.

integer overflow in rms_norm cutedsl kernel

3 participants