fix: use sym_int64 for strides in rmsnorm CuTe DSL kernels to prevent int32 overflow #3007
Conversation
📝 Walkthrough

Updated CuTe DSL RMSNorm and fused-add+RMSNorm kernel compilation to use 64-bit symbolic stride integers for non-contiguous (strided) tensor paths; added regression tests that exercise kernels with strides exceeding int32 limits.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 5/5 passed
Code Review
This pull request updates the symbolic integer types for tensor strides from 32-bit to 64-bit in the RMSNorm and fused add-RMSNorm kernels to prevent potential integer overflow issues. Additionally, several regression tests have been added to verify that these kernels correctly handle strides exceeding the INT32_MAX limit. I have no feedback to provide.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/utils/test_norm.py`:
- Around lines 405-447: Both tests use `torch.as_strided` with shape (1, H), which can appear contiguous and route to the contiguous kernel; instead, allocate minimal two-row buffers and create a 2-row `as_strided` view so the non-contiguous code path is taken. Concretely, in `test_fused_add_rmsnorm_int64_stride` and `test_fused_add_rmsnorm_quant_int64_stride`:
  - create `buf_x` and `buf_r` with shape (2, H) (dtype/device unchanged);
  - compute reference outputs from first-row clones (use `fused_add_rms_norm` / `fused_add_rms_norm_quant` on `buf_x[:1].clone()`, `buf_r[:1].clone()`);
  - make `x = torch.as_strided(buf_x, (2, H), (_INT64_STRIDE, 1))` and `r = torch.as_strided(buf_r, (2, H), (_INT64_STRIDE, 1))`;
  - call `flashinfer.fused_add_rmsnorm` or `flashinfer.norm.fused_add_rmsnorm_quant` on those views and compare the first row of the results (`x[0]`, `r[0]` or `y[0]`) to the reference outputs.

  This ensures the non-contiguous kernel is exercised without changing kernel-selection logic.
- Around lines 361-372: The test `test_rmsnorm_int64_stride` currently uses shape (1, H), so PyTorch considers the tensor contiguous and the non-contiguous/`sym_int64` path is never exercised. Change the buffer and strided-tensor creation to use two rows (e.g., `buf` sized (2, H) and `x` of shape (2, H) via `torch.as_strided` with strides `(_INT64_STRIDE, 1)`) so that `is_contiguous()` is False, forcing the non-contiguous kernel path and exercising the symbolic int64 stride handling when calling `flashinfer.norm.rmsnorm` and comparing with `llama_rms_norm`.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c4f99d51-34b9-401f-ac57-995a138170b7
📥 Commits
Reviewing files that changed from the base of the PR and between edcef4b and d7d687856bd2c3a517ee27720e9bd7ea3e8c34ad.
📒 Files selected for processing (3)
- flashinfer/norm/kernels/fused_add_rmsnorm.py
- flashinfer/norm/kernels/rmsnorm.py
- tests/utils/test_norm.py
Force-pushed from d7d6878 to 01fcf9d.
/bot run

[FAILED] Pipeline #47961615: 9/20 passed
📌 Description
Summary:
- `cute.sym_int()` → `cute.sym_int64()` for all stride symbols across 5 compiled kernel functions in `rmsnorm.py` and `fused_add_rmsnorm.py`
- Added regression tests using `torch.as_strided` with stride = 2^31

Details
Fixes #3005
Tensor strides are products of higher dimensions and can exceed int32 range for large sequence lengths. The WAN 2.2 model (hidden_dim=5120, QKV fused projection) hits this at seq_len > 139,810 where batch_stride = 15360 × seq_len overflows.
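The arithmetic behind that threshold can be checked in plain Python (`ctypes` is used only to mimic int32 wraparound; the exact point at which the real kernel misbehaves depends on how it stores the stride):

```python
import ctypes

INT32_MAX = 2**31 - 1
hidden_dim = 5120
# For WAN 2.2's fused QKV projection, each token contributes
# 3 * hidden_dim = 15360 elements to the batch stride.
stride_per_token = 3 * hidden_dim

# Largest sequence length whose batch stride still fits in int32:
assert INT32_MAX // stride_per_token == 139810

# One token past that limit, the stride exceeds INT32_MAX ...
seq_len = 139811
batch_stride = stride_per_token * seq_len
print(batch_stride)  # 2147496960, i.e. > 2147483647

# ... and a kernel that keeps it in a 32-bit integer sees a wrapped,
# negative value instead.
wrapped = ctypes.c_int32(batch_stride).value
print(wrapped)  # -2147470336
```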
The overflow is caught by TVM-FFI before kernel launch (see the previously failing test output).
Added a unit test to detect this type of case.
Without the fix, the unit test fails; with the current fix, it passes.
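A library-agnostic sketch of the regression-test idea, using a hypothetical llama-style RMSNorm reference in NumPy rather than the actual flashinfer kernels: build a view whose stride exceeds INT32_MAX without allocating that much memory, then check it produces the same result as the contiguous reference.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def rms_norm_ref(x, w, eps=1e-6):
    # llama-style RMSNorm reference: x / rms(x) * w
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps) * w

H = 8
rng = np.random.default_rng(0)
buf = rng.standard_normal(H)
w = rng.standard_normal(H)

# A view whose row stride (2**31 elements) would overflow int32; only the
# size-1 leading row is ever dereferenced, so H elements of real memory
# suffice (the real test uses torch.as_strided the same way).
x = as_strided(buf, shape=(1, H), strides=(2**31 * buf.itemsize, buf.itemsize))

out_strided = rms_norm_ref(x, w)
out_contig = rms_norm_ref(buf[None, :].copy(), w)
assert np.allclose(out_strided, out_contig)
```

With 64-bit stride handling, the strided and contiguous paths must agree; an int32 kernel would instead fail or read from a garbage offset.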
🔍 Related Issues
#3005
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Relevant tests have been added or updated (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit

- Bug Fixes: RMSNorm and fused add+RMSNorm kernels now use 64-bit stride arithmetic for strided tensors, preventing int32 overflow.
- Tests: Added regression tests exercising strides that exceed int32 limits.