fix(mla): widen page index to int64_t to avoid 32-bit overflow #3136
Tracin wants to merge 1 commit into flashinfer-ai:main
Conversation
…address computation

In the MLA decode/prefill KV load path, `indices[q] * ckv_stride_page` was computed in 32-bit because `IdType` is `int32_t` and `*_stride_page` is `uint32_t`; the product wraps modulo 2^32 before any widening to `int64_t` (Hopper) or pointer arithmetic (FA2). For large page pools (e.g. page_idx ~1M with page_size=32, kv_lora_rank=512, stride=16384) the true product exceeds 2^32 and the kernel reads the wrong page, producing all-zero outputs. Cast the selected page index to `int64_t` at all three sites (mla.cuh NUM_MMA_KV==1 and !=1 branches, and mla_hopper.cuh prefetch_offset) so the multiply executes in 64-bit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
📝 Walkthrough

Two CUDA attention kernel files are modified to prevent integer overflow in pointer arithmetic.
Code Review
This pull request addresses potential 32-bit integer overflows in KV cache offset calculations within mla.cuh and mla_hopper.cuh by casting page indices to int64_t. Feedback suggests that the entire offset calculation should be promoted to 64-bit to prevent overflows in subsequent additions and to improve future-proofing.
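The feedback about promoting the entire offset calculation can be illustrated with a hypothetical helper (the function and parameter names below are illustrative, not FlashInfer's actual API): widening the leftmost operand of each product promotes the whole expression chain, so the additions also execute in 64-bit.

```cpp
#include <cstdint>

// Hypothetical offset computation mirroring the shape of the kernel's
// address math; names are illustrative, not FlashInfer's real ones.
int64_t kv_offset(int32_t page_idx, uint32_t stride_page,
                  uint32_t entry_idx, uint32_t stride_n) {
  // Each static_cast forces its multiply into 64-bit; because the left
  // operand of `+` is then int64_t, the additions are 64-bit as well.
  // Leaving either product in 32-bit would let it wrap independently.
  return static_cast<int64_t>(page_idx) * stride_page +
         static_cast<int64_t>(entry_idx) * stride_n;
}
```

The design point: it is not enough that the final sum is stored in an `int64_t`; every sub-expression that can exceed 2^32 must already be evaluated in 64-bit.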
Nice fix — the cast is in the right place and all three call sites are covered. One suggestion: add a minimal regression test that forces `indices[q] * ckv_stride_page` past 2^32. You don't need a huge KV cache — a sparse page table with a single large page index is enough to trigger the wraparound.
@qsang-nv Thanks for the review! However I do not get how |
📌 Description
In the MLA decode/prefill KV load path, `indices[q] * ckv_stride_page` was computed in 32-bit because `IdType` is `int32_t` and `*_stride_page` is `uint32_t`; the product wraps modulo 2^32 before any widening to `int64_t` (Hopper) or pointer arithmetic (FA2). For large page pools (e.g. page_idx ~1M with page_size=32, kv_lora_rank=512, stride=16384) the true product exceeds 2^32 and the kernel reads the wrong page, producing all-zero outputs. Cast the selected page index to `int64_t` at all three sites (mla.cuh NUM_MMA_KV==1 and !=1 branches, and mla_hopper.cuh prefetch_offset) so the multiply executes in 64-bit.

🔍 Related Issues
#3130
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (`unittest`, etc.).

Reviewer Notes