fix(gdn_decode): widen pool indices to Int64 to prevent int32 element-offset overflow #3230
vadiklyutiy wants to merge 1 commit into flashinfer-ai:main
Conversation
📝 Walkthrough

This PR widens pool- and state-index arithmetic from 32-bit to 64-bit (cutlass.Int64) across BF16 GDN decode kernels (MTP ILP4 and wide-vec paths) and both small/big-batch pretranspose kernels, and adds two GPU regression tests exercising 32-bit overflow scenarios.

Changes: Integer Overflow Prevention in Pool/State Indexing
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request fixes a bug where int32 overflow in memory offset calculations led to illegal memory access in GDN kernels. By widening pool indices to Int64 in gdn_decode_bf16state_mtp_kernel and gdn_decode_kernel_*_pretranspose, the code now safely handles large pool indices and strides. Regression tests for both fp32 and bf16 paths have been added to ensure stability for large-scale models like Qwen3.5. I have no feedback to provide.
@kahyunnam @kaixih @yzh119 @bkryu @yongwww
kahyunnam left a comment
LGTM, seems straightforward. Please resolve merge conflicts.
/bot run
[FAILED] Pipeline #50458583: 1/20 passed
fix(gdn_decode): widen pool indices to Int64 to prevent int32 element-offset overflow
Same root cause as the original report — the CuTe-DSL pool+indices kernels
compute the per-slot element offset (cache_idx * HV + i_hv) * stride[0]
(or pool_idx * stride[0] for fp32) using Int32, so once the product
exceeds INT32_MAX the multiplication wraps to a negative offset and the
load/store hits an unmapped global address (CUDA error: an illegal
memory access was encountered). Discovered while integrating
gated_delta_rule_decode_pretranspose into vLLM's GDN decode path for
Qwen3.5-class models.
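A minimal host-side sketch of the wrap, using NumPy int32 as a stand-in for the kernel's Int32 arithmetic (the 540672 stride is the vLLM padded slot stride cited below):

```python
import numpy as np

# vLLM padded slot stride: 540672 fp32 elements per pool slot.
stride0 = np.array([540672], dtype=np.int32)

for pool_idx in (3971, 3972):
    idx = np.array([pool_idx], dtype=np.int32)
    off = idx * stride0          # Int32 multiply: wraps modulo 2**32
    print(pool_idx, int(off[0]))

# 3971  2147008512  (still <= INT32_MAX)
# 3972 -2147418112  (wrapped negative -> load/store hits an unmapped address)
```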
Backends fixed (HV=32, V=K=128 unless noted); the thresholds are derived
in the sketch after this list:
* fp32 pretranspose: gdn_decode_kernel_{small,big}_batch_pretranspose
  pool_idx * stride[0] overflows at pool_idx >= 3972
  (vLLM padded slot stride 540672)
* bf16 small-batch fallback (B*HV <= 128): gdn_decode_bf16state_mtp_ilp4_kernel
  (cache_idx * HV + i_hv) * V * K overflows at cache_idx >= 4096
* bf16 production fast path (B*HV >= 128 at T>=2, B*HV >= 512 at T=1):
  gdn_wide_vec_kernel
  (cache_idx * HV + i_hv) * V * K overflows at cache_idx >= 4096
  (i_n * T * HV + i_t * HV + i_hv) * V * K overflows at B*T*HV >= 131072
  (intermediate-states cache writeback with cache_intermediate_states=True,
  e.g. B >= 256 at T=8, HV=64)
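The thresholds above follow directly from INT32_MAX = 2**31 - 1; a quick exact-integer check in plain Python:

```python
INT32_MAX = 2**31 - 1
HV, V, K = 32, 128, 128

# fp32 pretranspose: pool_idx * stride[0], vLLM padded slot stride 540672.
assert 3971 * 540672 <= INT32_MAX < 3972 * 540672

# bf16 kernels: (cache_idx * HV + i_hv) * V * K, contiguous stride HV*V*K = 524288.
assert (4095 * HV + (HV - 1)) * V * K <= INT32_MAX   # largest in-range offset
assert (4096 * HV + 0) * V * K > INT32_MAX           # first cache_idx that wraps

# intermediate-states writeback: flat_idx * V * K wraps once flat_idx hits 2**17,
# i.e. B*T*HV >= 131072 (the commit's example uses B=256, T=8, HV=64).
assert 131072 * V * K == INT32_MAX + 1
assert 256 * 8 * 64 == 131072
```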
Fix: widen pool_idx / out_pool_idx (pretranspose) and cache_idx /
write_cache_idx (bf16) to Int64 immediately after they are read from
the indices tensors; widen the per-call batch index used for the
intermediate-states cache write (i_n) to Int64 in both bf16 kernels so
that flat_idx = i_n * T * HV + i_t * HV + i_hv inherits Int64 even when
B * T * HV crosses 131072. The downstream flat_state_idx,
flat_write_state_idx and flat_idx all stay Int64, so the offset
multiplications inside cute.local_tile / h0_source[(...)] cannot wrap.
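A NumPy sketch of the widen-at-read pattern (illustrative only; in the kernels the cast is cutlass.Int64(h0_indices[i_n]), per the review thread below):

```python
import numpy as np

HV, V, K = 32, 128, 128
h0_indices = np.array([4096], dtype=np.int32)   # first cache_idx past the threshold

# Broken pattern: the index read stays int32, so the offset product wraps.
cache_idx = h0_indices[:1]                      # int32
flat_bad = (cache_idx * HV + 0) * V * K         # -> -2147483648 (wrapped)

# Fixed pattern: widen immediately at the read site; everything downstream
# (flat_state_idx, flat_idx, ...) inherits 64-bit and cannot wrap.
cache_idx64 = h0_indices[:1].astype(np.int64)   # ~ cutlass.Int64(h0_indices[i_n])
flat_ok = (cache_idx64 * HV + 0) * V * K        # -> 2147483648

print(int(flat_bad[0]), int(flat_ok[0]))
```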
Note on PR scope vs. PR flashinfer-ai#3147 (Ameyn/wide vec t1, merged on main after
this branch was cut): flashinfer-ai#3147 deleted gdn_decode_bf16state_mtp_kernel and
replaced the bf16 surface with gdn_decode_bf16state_mtp_ilp4_kernel
(small-batch fallback) + gdn_wide_vec_kernel (production hot path).
Both surviving kernels reproduce the same Int32 cache_idx offset bug as
the deleted kernel, and gdn_wide_vec_kernel additionally has an Int32
flat_idx offset bug on the intermediate-states writeback that's
reachable at production batch sizes (B >= 256 with MTP T=8, HV=64).
This commit fixes all three sites at once.
Regression test (tests/gdn/test_decode_pretranspose_noncontiguous_pool.py)
covers:
* test_decode_pretranspose_pool_int64_offset[3973, 8191] — fp32
vLLM-padded pool (~8.6 / 17.7 GB) at the pool_idx threshold.
* test_decode_pretranspose_pool_int64_offset_bf16[4096, 4196] — bf16
contiguous pool (~4.3 GB) at the cache_idx threshold; at B=1 HV=32
the dispatcher selects gdn_decode_bf16state_mtp_ilp4_kernel, which
has the identical overflow site and identical Int64 fix as
gdn_wide_vec_kernel. The wide_vec / intermediate-states fixes are
structurally identical to the ilp4 fix exercised here.
Both bf16 / fp32 tests compare the pool path against a gather + direct-state
reference (numerical correctness, not just non-crashing) and assert the
in-place state update matches. VRAM-based skip when free memory is
insufficient.
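The VRAM-based skip can be sketched as a small pytest helper (assuming torch; the helper name and threshold are illustrative, not the test file's actual code):

```python
import pytest
import torch

def skip_if_low_vram(required_gib: float) -> None:
    """Skip the current test unless the device has required_gib of free memory."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA device required")
    free_bytes, _total = torch.cuda.mem_get_info()
    if free_bytes < required_gib * 1024**3:
        pytest.skip(
            f"needs ~{required_gib:.1f} GiB free, "
            f"have {free_bytes / 1024**3:.1f} GiB"
        )

# e.g. the fp32 vLLM-padded pool case needs roughly 8.6 GiB:
# skip_if_low_vram(8.6)
```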
AI-assisted (Cursor / Claude).
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
62cac54 to 825a55b
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Line 179: The assignment to cache_idx in gdn_decode_mtp.py reads
h0_indices[i_n] into a 32-bit int and can overflow when used in flat_state_idx =
cache_idx * HV + i_hv; change the two occurrences where cache_idx is set (around
the h0_indices[i_n] reads at the reported locations) to cast the value to 64-bit
by using cutlass.Int64(h0_indices[i_n]) so all subsequent arithmetic and
indexing (flat_state_idx, and any uses across memory accesses) use 64-bit
integers.
📒 Files selected for processing (3)
- flashinfer/gdn_kernels/gdn_decode_bf16_state.py
- flashinfer/gdn_kernels/gdn_decode_pretranspose.py
- tests/gdn/test_decode_pretranspose_noncontiguous_pool.py
/bot run
@kahyunnam both CI job failures are infra problems and don't relate to the PR
📌 Description
Fix `CUDA error: an illegal memory access was encountered` in `flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose` when the pool+indices API is used with sufficiently large pool indices.

Root cause. The CuTe-DSL kernels compute the per-slot element offset (`pool_idx * stride[0]`, or `(cache_idx * HV + i_hv) * stride[0]` for bf16) using Int32 arithmetic. Once the product exceeds `INT32_MAX`, it wraps to a negative offset and the load/store hits an unmapped global address.

Affects both backends the API can dispatch to (HV=32, V=K=128):
- `gdn_decode_kernel_{small,big}_batch_pretranspose`: `pool_idx >= 3972` (vLLM padded slot stride 540672)
- `gdn_decode_bf16state_mtp_kernel`: `cache_idx >= 4096` (contiguous, `stride[0] = HV*V*K = 524288`)

Discovered while integrating the kernel into vLLM's GDN decode path for Qwen3.5-class models.
Fix. Widen the pool indices to Int64 immediately after they are read; downstream offsets in `cute.local_tile(...)` / `h0_source[(...)]` then promote to Int64 and cannot wrap:
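A hedged schematic of the change (variable names like `pool_indices` / `i_b` are placeholders taken from the commit message, not the literal diff; the real code is CuTe-DSL in the three changed files):

```python
# Schematic only; not runnable kernel code.
#
#   # before: pool_idx stays Int32; pool_idx * stride[0] can wrap
#   pool_idx = pool_indices[i_b]
#
#   # after: widen at the read; cute.local_tile / h0_source offsets stay Int64
#   pool_idx = cutlass.Int64(pool_indices[i_b])
```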
🔍 Related Issues

Same class of bug as #3005 / #3007 (rmsnorm stride overflow), in a different family of CuTe-DSL kernels.
🚀 Pull Request Checklist

✅ Pre-commit Checks
- `pre-commit` installed and hooks installed.
- `pre-commit run --files <changed files>` — all hooks pass.

🧪 Tests
Added `tests/gdn/test_decode_pretranspose_noncontiguous_pool.py`:
- `test_decode_pretranspose_pool_int64_offset[3972, 8191]` — fp32 vLLM-padded pool (~8.6 / 17.7 GB).
- `test_decode_pretranspose_pool_int64_offset_bf16[4096, 4196]` — bf16 contiguous pool (~4.3 GB).

Both compare the pool path against a gather + direct-state reference (numerical correctness, not just non-crashing) and assert the in-place state update matches. VRAM-based skip when free memory is insufficient. The reference-comparison pattern is sketched below.
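The gather + direct-state reference pattern is roughly the following; `run_decode` is a hypothetical stand-in for the `gated_delta_rule_decode_pretranspose` call (whose real signature is not shown here), and the tensor handling is illustrative:

```python
import torch

def check_pool_vs_gathered_reference(run_decode, pool, indices, inputs):
    """Compare the pool+indices path against a gathered contiguous reference.

    run_decode(state, inputs, indices) is a hypothetical wrapper around the
    real flashinfer entry point; indices=None means "direct-state path".
    Both calls are assumed to update the state tensor in place.
    """
    # Reference: gather the referenced slots into a small contiguous state
    # tensor and run the direct-state path on it.
    ref_state = pool[indices.long()].clone()
    ref_out = run_decode(ref_state, inputs, indices=None)

    # Pool path: run against the large (possibly padded) pool via indices.
    out = run_decode(pool, inputs, indices=indices)

    torch.testing.assert_close(out, ref_out)                     # outputs match
    torch.testing.assert_close(pool[indices.long()], ref_state)  # in-place update matches
```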
Verified on NVIDIA B200 (SM100): all 4 new tests crash without the fix and pass with it; existing `pretranspose` and `bf16_state` tests in `tests/gdn/test_decode_delta_rule.py` continue to pass.

Summary by CodeRabbit
Bug Fixes
Tests