
fix(gdn_decode): widen pool indices to Int64 to prevent int32 element-offset overflow#3230

Open
vadiklyutiy wants to merge 1 commit into flashinfer-ai:main from vadiklyutiy:gdn-decode-int32-overflow-fix

Conversation

@vadiklyutiy
Contributor

@vadiklyutiy vadiklyutiy commented May 5, 2026

📌 Description

Fix CUDA error: an illegal memory access was encountered in flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose when the pool+indices API is used with sufficiently large pool indices.

Root cause. The CuTe-DSL kernels compute the per-slot element offset (pool_idx * stride[0], or (cache_idx * HV + i_hv) * stride[0] for bf16) using Int32 arithmetic. Once the product exceeds INT32_MAX, it wraps to a negative offset and the load/store hits an unmapped global address.

Affects both backends the API can dispatch to (HV=32, V=K=128):

  • fp32 pretranspose, gdn_decode_kernel_{small,big}_batch_pretranspose: overflows at pool_idx >= 3972 (vLLM padded slot stride 540672)
  • bf16 fast path, gdn_decode_bf16state_mtp_kernel: overflows at cache_idx >= 4096 (contiguous pool, stride[0] = HV*V*K = 524288)
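
For concreteness, here is a standalone numpy sketch of the wraparound at these thresholds (an editor's illustration of the arithmetic, not the kernel code; numpy's int32 wraps the same way the kernels' Int32 math does):

import numpy as np

# fp32 path: pool_idx * stride[0] with the vLLM padded slot stride
print(np.int32(3972) * np.int32(540672))  # -2147418112: wraps (numpy also
                                          # emits an overflow RuntimeWarning)
print(np.int64(3972) * np.int64(540672))  # 2147549184: correct, > INT32_MAX

# bf16 path: (cache_idx * HV + i_hv) * V * K at HV=32, V=K=128, i_hv=0
print(np.int64(4096) * 32 * 128 * 128)    # 2147483648 == INT32_MAX + 1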

Discovered while integrating the kernel into vLLM's GDN decode path for Qwen3.5-class models.

Fix. Widen the pool indices to Int64 immediately after they are read; downstream offsets in cute.local_tile(...) / h0_source[(...)] then promote to Int64 and cannot wrap:

# fp32 pretranspose (small + big batch)
pool_idx = cutlass.Int64(h0_indices[i_n])
out_pool_idx = cutlass.Int64(h0_out_indices[i_n])

# bf16 MTP — propagates Int64 through flat_state_idx,
# flat_write_idx, and the intermediate-states cache's flat_idx.
cache_idx = cutlass.Int64(h0_indices[i_n])
write_cache_idx = cutlass.Int64(h0_out_indices[i_n])
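
Widening at the point of the read is the narrowest change that works: once one operand of each product is Int64, the usual promotion rules keep every downstream offset 64-bit. A minimal before/after sketch of the pattern, reusing the names from this description (the real kernel bodies differ):

# before: Int32 * Int32 stays Int32 and can wrap past INT32_MAX
flat_state_idx = h0_indices[i_n] * HV + i_hv
# after: widen once at the read; all later products inherit Int64
cache_idx = cutlass.Int64(h0_indices[i_n])
flat_state_idx = cache_idx * HV + i_hv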

🔍 Related Issues

Same class of bug as #3005 / #3007 (rmsnorm stride overflow), in a different family of CuTe-DSL kernels.

🚀 Pull Request Checklist

✅ Pre-commit Checks

  • pre-commit installed and hooks installed.
  • pre-commit run --files <changed files> — all hooks pass.

🧪 Tests

  • Tests added.
  • All tests pass.

Added tests/gdn/test_decode_pretranspose_noncontiguous_pool.py:

  • test_decode_pretranspose_pool_int64_offset[3972, 8191] — fp32 vLLM-padded pool (~8.6 / 17.7 GB).
  • test_decode_pretranspose_pool_int64_offset_bf16[4096, 4196] — bf16 contiguous pool (~4.3 GB).

Both compare the pool path against a gather + direct-state reference (numerical correctness, not just non-crashing) and assert the in-place state update matches. VRAM-based skip when free memory is insufficient.
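
For reference, the VRAM-based skip described above can be as small as the following sketch (hypothetical helper name; the guard in the actual test may be written differently):

import pytest
import torch

def skip_unless_free_vram(required_gib: float) -> None:
    # Skip the current test when the GPU lacks the required free memory.
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    if free_bytes < required_gib * (1 << 30):
        pytest.skip(f"needs ~{required_gib:.1f} GiB free VRAM")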

Verified on NVIDIA B200 (SM100): all 4 new tests crash without the fix and pass with it; existing pretranspose and bf16_state tests in tests/gdn/test_decode_delta_rule.py continue to pass.

Summary by CodeRabbit

  • Bug Fixes

    • Resolved integer overflow in GPU decode kernels by switching pool- and state-index arithmetic to 64-bit, preventing wraparound and out-of-bounds addressing for large pools and batches.
    • Ensured consistent 64-bit handling across all decode paths and negative-index clamping.
  • Tests

    • Added GPU regression tests covering large-pool overflow scenarios for FP32 and BF16 fast paths, with device-capacity guards to avoid OOM on low-VRAM systems.

@coderabbitai
Contributor

coderabbitai Bot commented May 5, 2026

📝 Walkthrough

This PR widens pool- and state-index arithmetic from 32-bit to 64-bit (cutlass.Int64) across BF16 GDN decode kernels (MTP ILP4 and wide-vec paths) and both small/big-batch pretranspose kernels, and adds two GPU regression tests exercising 32-bit overflow scenarios.

Changes

Integer Overflow Prevention in Pool/State Indexing

  • Index type changes / data shape (flashinfer/gdn_kernels/gdn_decode_bf16_state.py, flashinfer/gdn_kernels/gdn_decode_pretranspose.py): pool/index variables (cache_idx, h0_indices[i_n], h0_out_indices[i_n], the intermediate flat_idx, flat_state_idx, flat_write_state_idx) are cast to cutlass.Int64 before arithmetic; negative-index redirection now uses cutlass.Int64(0).
  • Core kernel logic (flashinfer/gdn_kernels/gdn_decode_bf16_state.py): the BF16 MTP ILP4 and wide-vec paths now use 64-bit addressing for read/write slot calculations and intermediate-state indexing, avoiding 32-bit wraparound.
  • Pretranspose wiring (flashinfer/gdn_kernels/gdn_decode_pretranspose.py): the small-batch and big-batch pretranspose kernels cast pool indices to Int64 before computing global-memory address offsets; inline comments document the overflow scenarios.
  • Tests / regression coverage (tests/gdn/test_decode_pretranspose_noncontiguous_pool.py): adds two GPU tests, an FP32 padded-stride overflow case and a BF16 MTP fast-path overflow case; both validate pool-based outputs/states against gather/direct-state references and include VRAM-capacity guards.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes


Suggested reviewers

  • yzh119
  • bkryu
  • yongwww
  • kahyunnam
  • kaixih

🐰 Pool indices grow so tall and wide,
Int32 couldn't hold them with pride!
Int64 comes to save the day,
No overflow shall come our way!
With tests that soar and kernels bright,
Large memories handled just right!

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 warning

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 62.50%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.

✅ Passed checks (4 passed)

  • Title check: the title accurately summarizes the main fix, widening pool indices from Int32 to Int64 to prevent overflow in element-offset calculations across GDN decode kernels.
  • Description check: the description is comprehensive and well structured, covering the root cause, affected backends with thresholds, the fix rationale, related issues, test coverage, and verification results on hardware.
  • Linked Issues check: skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: skipped because no linked issues were found for this pull request.





Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request fixes a bug where int32 overflow in memory offset calculations led to illegal memory access in GDN kernels. By widening pool indices to Int64 in gdn_decode_bf16state_mtp_kernel and gdn_decode_kernel_*_pretranspose, the code now safely handles large pool indices and strides. Regression tests for both fp32 and bf16 paths have been added to ensure stability for large-scale models like Qwen3.5. I have no feedback to provide.

@vadiklyutiy
Contributor Author

@kahyunnam @kaixih @yzh119 @bkryu @yongwww
Could you please take a look at these changes?

Member

@kahyunnam kahyunnam left a comment


LGTM, seems straightforward. Please resolve merge conflicts.

@kahyunnam
Member

/bot run

@flashinfer-bot
Collaborator

GitLab MR !637 has been created, and the CI pipeline #50458583 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Collaborator

[FAILED] Pipeline #50458583: 1/20 passed

fix(gdn_decode): widen pool indices to Int64 to prevent int32 element-offset overflow

Same root cause as the original report — the CuTe-DSL pool+indices kernels
compute the per-slot element offset (cache_idx * HV + i_hv) * stride[0]
(or pool_idx * stride[0] for fp32) using Int32, so once the product
exceeds INT32_MAX the multiplication wraps to a negative offset and the
load/store hits an unmapped global address (CUDA error: an illegal
memory access was encountered). Discovered while integrating
gated_delta_rule_decode_pretranspose into vLLM's GDN decode path for
Qwen3.5-class models.

Backends fixed (HV=32, V=K=128 unless noted):

  fp32 pretranspose:
    gdn_decode_kernel_{small,big}_batch_pretranspose
      pool_idx * stride[0] overflows at pool_idx >= 3972
      (vLLM padded slot stride 540672)

  bf16 small-batch fallback (B*HV <= 128):
    gdn_decode_bf16state_mtp_ilp4_kernel
      (cache_idx * HV + i_hv) * V * K overflows at cache_idx >= 4096

  bf16 production fast path (B*HV >= 128 at T>=2, B*HV >= 512 at T=1):
    gdn_wide_vec_kernel
      (cache_idx * HV + i_hv) * V * K overflows at cache_idx >= 4096
      (i_n * T * HV + i_t * HV + i_hv) * V * K (intermediate-states cache
      writeback) overflows at B*T*HV >= 131072, e.g. B >= 256 at T=8,
      HV=64, with cache_intermediate_states=True

Fix: widen pool_idx / out_pool_idx (pretranspose) and cache_idx /
write_cache_idx (bf16) to Int64 immediately after they are read from
the indices tensors; widen the per-call batch index used for the
intermediate-states cache write (i_n) to Int64 in both bf16 kernels so
that flat_idx = i_n * T * HV + i_t * HV + i_hv inherits Int64 even when
B * T * HV crosses 131072. The downstream flat_state_idx,
flat_write_state_idx and flat_idx all stay Int64, so the offset
multiplications inside cute.local_tile / h0_source[(...)] cannot wrap.

Note on PR scope vs. PR flashinfer-ai#3147 (Ameyn/wide vec t1, merged on main after
this branch was cut): flashinfer-ai#3147 deleted gdn_decode_bf16state_mtp_kernel and
replaced the bf16 surface with gdn_decode_bf16state_mtp_ilp4_kernel
(small-batch fallback) + gdn_wide_vec_kernel (production hot path).
Both surviving kernels reproduce the same Int32 cache_idx offset bug as
the deleted kernel, and gdn_wide_vec_kernel additionally has an Int32
flat_idx offset bug on the intermediate-states writeback that's
reachable at production batch sizes (B >= 256 with MTP T=8, HV=64).
This commit fixes all three sites at once.

Regression test (tests/gdn/test_decode_pretranspose_noncontiguous_pool.py)
covers:
  * test_decode_pretranspose_pool_int64_offset[3973, 8191] — fp32
    vLLM-padded pool (~8.6 / 17.7 GB) at the pool_idx threshold.
  * test_decode_pretranspose_pool_int64_offset_bf16[4096, 4196] — bf16
    contiguous pool (~4.3 GB) at the cache_idx threshold; at B=1 HV=32
    the dispatcher selects gdn_decode_bf16state_mtp_ilp4_kernel, which
    has the identical overflow site and identical Int64 fix as
    gdn_wide_vec_kernel. The wide_vec / intermediate-states fixes are
    structurally identical to the ilp4 fix exercised here.

Both bf16 / fp32 tests compare the pool path against a gather + direct-state
reference (numerical correctness, not just non-crashing) and assert the
in-place state update matches. VRAM-based skip when free memory is
insufficient.

AI-assisted (Cursor / Claude).

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@vadiklyutiy force-pushed the gdn-decode-int32-overflow-fix branch from 62cac54 to 825a55b (May 6, 2026, 19:43)
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1


Inline comments:
In `@flashinfer/gdn_kernels/gdn_decode_bf16_state.py`:
- Line 179: The assignment to cache_idx in gdn_decode_mtp.py reads
h0_indices[i_n] into a 32-bit int and can overflow when used in flat_state_idx =
cache_idx * HV + i_hv; change the two occurrences where cache_idx is set (around
the h0_indices[i_n] reads at the reported locations) to cast the value to 64-bit
by using cutlass.Int64(h0_indices[i_n]) so all subsequent arithmetic and
indexing (flat_state_idx, and any uses across memory accesses) use 64-bit
integers.

📥 Commits

Reviewing files that changed from the base of the PR and between 62cac54 and 825a55b.

📒 Files selected for processing (3)
  • flashinfer/gdn_kernels/gdn_decode_bf16_state.py
  • flashinfer/gdn_kernels/gdn_decode_pretranspose.py
  • tests/gdn/test_decode_pretranspose_noncontiguous_pool.py

@kahyunnam
Member

/bot run

@flashinfer-bot
Collaborator

GitLab MR !637 has been updated with latest changes, and the CI pipeline #50593896 is currently running. I'll report back once the pipeline job completes.

@vadiklyutiy
Contributor Author

@kahyunnam both CI job failures are infra problems and are unrelated to this PR.
