Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 54 additions & 10 deletions flashinfer/gdn_kernels/gdn_decode_bf16_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,28 @@ def gdn_decode_bf16state_mtp_ilp4_kernel(
i_n = tmp // HV
i_h = i_hv // (HV // H)

cache_idx = h0_indices[i_n]
# Widen pool indices to Int64 BEFORE they multiply ``stride[0]`` of
# ``h0_source``. The reshape ``[pool_size, HV, V, K] -> [pool_size * HV,
# V, K]`` (BF16) gives ``stride[0] = V * K = 16384`` elements, so the
# downstream offset ``(cache_idx * HV + i_hv) * 16384`` reaches 2**31 at
# ``cache_idx >= ceil(2**31 / (HV * V * K)) = 4096`` (HV=32, V=K=128;
# 2048 at HV=64). Past that boundary the Int32 multiplication wraps to
# a negative offset and ``cute.local_tile(h0_source, ...)`` issues a
# load/store to an unmapped global address. Propagating Int64 through
# ``flat_state_idx`` / ``flat_write_state_idx`` (computed below) keeps
# the offset multiplication 64-bit. See
# ``tests/gdn/test_decode_pretranspose_noncontiguous_pool.py
# ::test_decode_pretranspose_pool_int64_offset_bf16`` for the
# regression test.
cache_idx = cutlass.Int64(h0_indices[i_n])
Comment thread
vadiklyutiy marked this conversation as resolved.
if cutlass.const_expr(same_pool):
# Single-pool: alias write to read; nvcc DCEs the write-side LDG /
# IMAD / local_tile entirely in this compile path.
write_cache_idx = cache_idx
else:
write_cache_idx = h0_out_indices[i_n]
write_cache_idx = cutlass.Int64(h0_out_indices[i_n])
if write_cache_idx < 0:
write_cache_idx = cutlass.Int32(0)
write_cache_idx = cutlass.Int64(0)

r_A_log = cutlass.Float32(A_log[i_hv])
r_dt_bias = cutlass.Float32(dt_bias[i_hv])
Expand Down Expand Up @@ -225,7 +238,7 @@ def gdn_decode_bf16state_mtp_ilp4_kernel(
)

if cache_idx < 0:
cache_idx = cutlass.Int32(0)
cache_idx = cutlass.Int64(0)

if cache_idx >= 0:
k_start = lane_in_group * vec_size
Expand Down Expand Up @@ -652,7 +665,16 @@ def gdn_decode_bf16state_mtp_ilp4_kernel(
# initial_state_indices points at slots >= B (i.e. any
# realistic pool_size > B serving config). Fix mirrors
# upstream PR #3145.
flat_idx = i_n * T * HV + i_t * HV + i_hv
# Defense-in-depth: widen to Int64 so the offset
# ``flat_idx * stride[0]`` (= ``flat_idx * V * K``
# = ``flat_idx * 16384`` BF16 elements) into the
# batch-scoped intermediate-states cache cannot
# wrap. This kernel is only reached at
# ``B * HV <= 128`` so the flat_idx itself stays
# well below the wrap threshold, but matching the
# wide_vec kernel below keeps the two paths
# bit-equivalent at large pool sizes.
flat_idx = cutlass.Int64(i_n) * T * HV + i_t * HV + i_hv
ita = cute.local_tile(
intermediate_states,
(1, 1, vec_size),
Expand Down Expand Up @@ -780,7 +802,18 @@ def gdn_wide_vec_kernel(
i_n = tmp // HV
i_h = i_hv // (HV // H)

cache_idx = h0_indices[i_n]
# Widen pool indices to Int64 BEFORE they multiply ``stride[0]`` of
# ``h0_source``. ``h0_source`` is reshaped to ``[pool_size * HV, V,
# K]`` (BF16), so ``stride[0] = V * K = 16384`` elements; the
# downstream offset ``(cache_idx * HV + i_hv) * 16384`` wraps int32
# at ``cache_idx >= ceil(2**31 / (HV * V * K)) = 4096`` (HV=32) /
# 2048 (HV=64). Propagating Int64 through ``flat_state_idx`` /
# ``flat_write_state_idx`` keeps the ``cute.local_tile`` offset
# arithmetic 64-bit at every reachable pool size. See
# ``tests/gdn/test_decode_pretranspose_noncontiguous_pool.py
# ::test_decode_pretranspose_pool_int64_offset_bf16`` for the
# regression test.
cache_idx = cutlass.Int64(h0_indices[i_n])

r_A_log = cutlass.Float32(A_log[i_hv])
r_dt_bias = cutlass.Float32(dt_bias[i_hv])
Expand Down Expand Up @@ -824,7 +857,7 @@ def gdn_wide_vec_kernel(
)

if cache_idx < 0:
cache_idx = cutlass.Int32(0)
cache_idx = cutlass.Int64(0)

# Split-pool write index: distinct slot to write the updated H state.
# When same_pool=True (compile-time, set by the dispatcher whenever the
Expand All @@ -835,9 +868,9 @@ def gdn_wide_vec_kernel(
if cutlass.const_expr(same_pool):
write_cache_idx = cache_idx
else:
write_cache_idx = h0_out_indices[i_n]
write_cache_idx = cutlass.Int64(h0_out_indices[i_n])
if write_cache_idx < 0:
write_cache_idx = cutlass.Int32(0)
write_cache_idx = cutlass.Int64(0)

if cache_idx >= 0:
flat_state_idx = cache_idx * HV + i_hv
Expand Down Expand Up @@ -1169,7 +1202,18 @@ def gdn_wide_vec_kernel(
# initial_state_indices points at slots >= B (i.e. any
# realistic pool_size > B serving config). Fix mirrors
# upstream PR #3145.
flat_idx = i_n * T * HV + i_t * HV + i_hv
# Widen to Int64: ``intermediate_states`` is
# reshaped to ``[B * T * HV, V, K]`` (BF16) with
# ``stride[0] = V * K = 16384`` elements. The
# offset ``flat_idx * 16384`` reaches 2**31 at
# ``flat_idx >= 131072``; with HV=64 + T=8 that's
# already hit at ``i_n >= 256`` (i.e. any
# production-scale MTP decode batch with caching
# enabled). Without the widening the Int32
# multiplication wraps and the
# ``cute.local_tile(intermediate_states, ...)``
# writes corrupt unrelated GMEM.
flat_idx = cutlass.Int64(i_n) * T * HV + i_t * HV + i_hv
it0 = cute.local_tile(
intermediate_states,
(1, 1, vec),
Expand Down
27 changes: 21 additions & 6 deletions flashinfer/gdn_kernels/gdn_decode_pretranspose.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,20 @@ def gdn_decode_kernel_small_batch_pretranspose(

# Compute state index: use pool indexing if enabled.
if cutlass.const_expr(use_pool_indexing):
pool_idx = h0_indices[i_n]
out_pool_idx = h0_out_indices[i_n]
# Widen pool indices to Int64 BEFORE they multiply ``stride[0]`` of
# ``h0_source``. With Int32 indices, the per-slot element offset
# ``pool_idx * stride[0]`` silently wraps once it exceeds INT32_MAX,
# which makes the kernel issue loads/stores to unmapped global
# addresses (illegal memory access). For example, the padded slot
# stride 540672 used by vLLM for Qwen3.5-class GDN models crosses the
# threshold at pool_idx >= ceil(2**31 / 540672) = 3972. See
# ``tests/gdn/test_decode_pretranspose_noncontiguous_pool.py::
# test_decode_pretranspose_pool_int64_offset`` for the regression test.
pool_idx = cutlass.Int64(h0_indices[i_n])
out_pool_idx = cutlass.Int64(h0_out_indices[i_n])
# Redirect negative write indices to null buffer (slot 0)
if out_pool_idx < 0:
out_pool_idx = cutlass.Int32(0)
out_pool_idx = cutlass.Int64(0)
else:
pool_idx = 0
out_pool_idx = 0
Expand Down Expand Up @@ -442,11 +451,17 @@ def gdn_decode_kernel_big_batch_pretranspose(

# Compute state index: use pool indexing if enabled.
if cutlass.const_expr(use_pool_indexing):
pool_idx = h0_indices[i_n]
out_pool_idx = h0_out_indices[i_n]
# Widen pool indices to Int64 BEFORE they multiply ``stride[0]`` of
# ``h0_source``. With Int32 indices, the per-slot element offset
# ``pool_idx * stride[0]`` silently wraps when it exceeds 2**31, which
# makes the kernel issue loads/stores to unmapped global addresses
# (illegal memory access). See the small-batch kernel above for the
# full rationale.
pool_idx = cutlass.Int64(h0_indices[i_n])
out_pool_idx = cutlass.Int64(h0_out_indices[i_n])
# Redirect negative write indices to null buffer (slot 0)
if out_pool_idx < 0:
out_pool_idx = cutlass.Int32(0)
out_pool_idx = cutlass.Int64(0)
else:
pool_idx = 0
out_pool_idx = 0
Expand Down
Loading
Loading