
fix(mla): widen page index to int64_t to avoid 32-bit overflow#3136

Open
Tracin wants to merge 1 commit into flashinfer-ai:main from Tracin:fix_mla

Conversation


@Tracin Tracin commented Apr 21, 2026

📌 Description

In the MLA decode/prefill KV load path, indices[q] * ckv_stride_page was computed in 32-bit because IdType is int32_t and *_stride_page is uint32_t; the product wraps modulo 2^32 before any widening to int64_t (Hopper) or pointer arithmetic (FA2). For large page pools (e.g. page_idx ~1M with page_size=32, kv_lora_rank=512, stride=16384) the true product exceeds 2^32 and the kernel reads the wrong page, producing all-zero outputs. Cast the selected page index to int64_t at all three sites (mla.cuh NUM_MMA_KV==1 and !=1 branches, and mla_hopper.cuh prefetch_offset) so the multiply executes in 64-bit.

🔍 Related Issues

#3130

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes
    • Fixed potential integer overflow in memory address calculations for attention operations with large sequences.

…address computation


Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@coderabbitai
Contributor

coderabbitai Bot commented Apr 21, 2026

📝 Walkthrough

Walkthrough

Two CUDA attention kernel files are modified to prevent integer overflow in pointer arithmetic. The load_kv function in mla.cuh and the prefetch_offset function in mla_hopper.cuh now cast computed KV page indices to int64_t before multiplying by stride values, ensuring correct calculations for large indices.

Changes

Cohort / File(s) Summary
MLA KV Loading
include/flashinfer/attention/mla.cuh
Widened KV page index arithmetic to int64_t before stride multiplication in load_kv function across both code paths (single and multiple MMA KV cases).
MLA Hopper Prefetching
include/flashinfer/attention/mla_hopper.cuh
Widened computed page-offset term to int64_t in prefetch_offset function before multiplying by stride values.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

Suggested labels

run-ci, op: attention

Suggested reviewers

  • yzh119
  • saltyminty
  • bkryu
  • nv-yunzheq

Poem

🐰 Hop along with careful stride,
Where index numbers multiply wide,
From thirty-two to sixty-four,
No overflow knocks down the door,
Attention kernels, safe and spry!

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title accurately describes the main change: widening the page index to int64_t to prevent 32-bit overflow in MLA operations. |
| Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The PR description provides clear context about a 32-bit overflow bug, explains the issue with concrete examples, and mentions the related GitHub issue. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request addresses potential 32-bit integer overflows in KV cache offset calculations within mla.cuh and mla_hopper.cuh by casting page indices to int64_t. Feedback suggests that the entire offset calculation should be promoted to 64-bit to prevent overflows in subsequent additions and to improve future-proofing.

Comment thread: include/flashinfer/attention/mla.cuh
@qsang-nv
Collaborator

Nice fix — cast is in the right place and all three call sites are covered.

One suggestion: add a minimal regression test that forces page_idx * stride_page > 2^32. This is a silent-output-corruption bug (wrong output, no crash), so without a guard test, a future refactor of IdType or *_stride_page types could easily reintroduce it.

You don't need a huge KV cache — a sparse kv_indices with a few large index values pointing at a small real allocation should hit the overflow path in ~30 lines in tests/attention/test_mla_decode_kernel.py.

@Tracin
Author

Tracin commented Apr 22, 2026

@qsang-nv Thanks for the review! However, I don't see how large index values can point at a small real allocation. I suppose we need real backing memory at the address given by page_idx * stride_page.

