Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions include/flashinfer/attention/mla.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,17 @@ __device__ __forceinline__ void load_kv(
uint32_t packed_block_iter = packed_block_iter_base + lane_idx / 8 + warp_idx_in_wg * 4;
block_size.divmod(packed_block_iter, q, r);

DTypeKV* ckv_ptr = ckv +
(packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
DTypeKV* kpe_ptr = kpe +
(packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
// Cast page index to int64_t before multiplying to avoid overflow.
DTypeKV* ckv_ptr =
ckv +
static_cast<int64_t>(packed_block_iter < packed_kv_bound ? indices[q] : 0) *
ckv_stride_page +
Comment thread
Tracin marked this conversation as resolved.
r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
DTypeKV* kpe_ptr =
kpe +
static_cast<int64_t>(packed_block_iter < packed_kv_bound ? indices[q] : 0) *
kpe_stride_page +
r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();

#pragma unroll
for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 4; ++mma_d) {
Expand Down Expand Up @@ -234,12 +239,17 @@ __device__ __forceinline__ void load_kv(
(warpgroup_idx + mma_kv * 2) * 16 + warp_idx_in_wg * 4;
block_size.divmod(packed_block_iter, q, r);

DTypeKV* ckv_ptr = ckv +
(packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
DTypeKV* kpe_ptr = kpe +
(packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
// See comment above: widen to int64_t to avoid 32-bit overflow when indices[q] is large.
DTypeKV* ckv_ptr =
ckv +
static_cast<int64_t>(packed_block_iter < packed_kv_bound ? indices[q] : 0) *
ckv_stride_page +
r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
DTypeKV* kpe_ptr =
kpe +
static_cast<int64_t>(packed_block_iter < packed_kv_bound ? indices[q] : 0) *
kpe_stride_page +
r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();

#pragma unroll
for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 4; ++mma_d) {
Expand Down
7 changes: 5 additions & 2 deletions include/flashinfer/attention/mla_hopper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,14 @@ __device__ __forceinline__ void prefetch_offset(
uint32_t packed_block_iter =
packed_block_iter_base + lane_idx / 8 + (j + mma_kv * 2) * 16 + warp_idx_in_wg * 4;
block_size.divmod(packed_block_iter, q, r);
// Widen page index to int64_t before multiplying to avoid overflow.
ckv_offset[mma_kv][j] =
(packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
static_cast<int64_t>(packed_block_iter < packed_kv_bound ? indices[q] : 0) *
ckv_stride_page +
r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
kpe_offset[mma_kv][j] =
(packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
static_cast<int64_t>(packed_block_iter < packed_kv_bound ? indices[q] : 0) *
kpe_stride_page +
r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
}
}
Expand Down
82 changes: 82 additions & 0 deletions tests/attention/test_mla_decode_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch import nn

import flashinfer
from flashinfer.utils import is_sm90a_supported
from tests.test_helpers.rope_reference import apply_rotary_emb, precompute_freqs_cis
from tests.test_helpers.test_helpers import skip_on_gpu_arch_error
from tvm_ffi import use_torch_stream
Expand Down Expand Up @@ -508,6 +509,87 @@ def test_mla_decode_kernel(bsz, kv_len, page_size):
print(f"mse_use_flashinfer = {mse_use_flashinfer}")


@skip_on_gpu_arch_error
@pytest.mark.parametrize("backend", ["fa2", "fa3"])
def test_mla_page_index_uint32_overflow_regression(backend):
    """Regression test: MLA kernels must widen page indices to int64.

    Guards the int64 widening in mla.cuh / mla_hopper.cuh
    (``indices[q] * ckv_stride_page``). For a contiguous
    [num_pages, page_size, head_dim_ckv] cache with page_size=32 and
    head_dim_ckv=512, ckv_stride_page = 16384 elements. Any page index
    >= 2^32 / 16384 = 262144 makes the multiplication overflow uint32 and
    — pre-fix — silently wraps to the wrong page (no crash, wrong output).
    """
    device = torch.device("cuda:0")
    if backend == "fa3" and not is_sm90a_supported(device):
        pytest.skip("fa3 backend requires SM90a")

    page_size, head_dim_ckv, head_dim_kpe, num_heads = 32, 512, 64, 128
    # 262144 * (32 * 512) = 2^32 exactly — the smallest index that overflows.
    OVERFLOW_START = 262144
    NUM_PAGES = 26  # matches the 26-page decode scenario from the original repro
    total_num_pages = OVERFLOW_START + NUM_PAGES  # 262170
    kv_len = NUM_PAGES * page_size

    # Big cache alone is ~9.66 GiB (bf16/fp16). Skip on small-memory runners.
    if torch.cuda.mem_get_info(device)[0] < 12 * (1 << 30):
        pytest.skip("needs ≥12 GiB free VRAM to force the 32-bit overflow")

    torch.manual_seed(0)
    torch.set_grad_enabled(False)
    dtype = torch.float16
    # NOTE(review): 128 + 64 presumably is qk_nope_head_dim + head_dim_kpe —
    # confirm against the MLA head-dim convention used elsewhere in this file.
    sm_scale = 1.0 / ((128 + 64) ** 0.5)

    real_ckv = torch.randn(
        NUM_PAGES, page_size, head_dim_ckv, device=device, dtype=dtype
    )
    real_kpe = torch.randn(
        NUM_PAGES, page_size, head_dim_kpe, device=device, dtype=dtype
    )
    q_nope = torch.randn(1, num_heads, head_dim_ckv, device=device, dtype=dtype)
    q_pe = torch.randn(1, num_heads, head_dim_kpe, device=device, dtype=dtype)
    workspace = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)

    def _run(ckv_cache, kpe_cache, page_indices):
        # Plan + run a single-query decode over the given paged KV cache;
        # only `page_indices` (and the cache tensors) differ between calls.
        w = flashinfer.mla.BatchMLAPagedAttentionWrapper(workspace, backend=backend)
        w.plan(
            torch.tensor([0, 1], dtype=torch.int32, device=device),  # qo_indptr
            torch.tensor([0, len(page_indices)], dtype=torch.int32, device=device),
            page_indices,
            torch.tensor([kv_len], dtype=torch.int32, device=device),
            num_heads,
            head_dim_ckv,
            head_dim_kpe,
            page_size,
            False,  # causal — decode over the full KV range
            sm_scale,
            dtype,
            dtype,
        )
        return w.run(q_nope, q_pe, ckv_cache, kpe_cache)

    # Overflow path: big contiguous cache; real data lives at [OVERFLOW_START, end).
    # stride(0) = page_size * head_dim_ckv = 16384 matches the reference below,
    # so only the page-index arithmetic differs between the two runs.
    ckv_big = torch.zeros(
        total_num_pages, page_size, head_dim_ckv, device=device, dtype=dtype
    )
    kpe_big = torch.zeros(
        total_num_pages, page_size, head_dim_kpe, device=device, dtype=dtype
    )
    ckv_big[OVERFLOW_START:] = real_ckv
    kpe_big[OVERFLOW_START:] = real_kpe
    big_indices = torch.arange(
        OVERFLOW_START, total_num_pages, dtype=torch.int32, device=device
    )
    out = _run(ckv_big, kpe_big, big_indices)
    # Free the ~10 GiB cache before allocating the reference tensors.
    del ckv_big, kpe_big
    torch.cuda.empty_cache()

    # Reference: same data, same stride(0), but page indices < overflow threshold.
    ref_indices = torch.arange(NUM_PAGES, dtype=torch.int32, device=device)
    ref = _run(real_ckv, real_kpe, ref_indices)

    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
bsz = 6
kv_len = 640
Expand Down
Loading