diff --git a/include/flashinfer/topk.cuh b/include/flashinfer/topk.cuh
index 942d90eb3a..6867fdb061 100644
--- a/include/flashinfer/topk.cuh
+++ b/include/flashinfer/topk.cuh
@@ -1994,8 +1994,10 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
       score_vec.cast_load(&score[base]);
#pragma unroll
       for (int j = 0; j < VEC_SIZE; ++j) {
-        const auto bin = Traits::ToCoarseKey(score_vec[j]);
-        atomicAdd(&s_histogram[bin], 1);
+        if (base + j < length) {
+          const auto bin = Traits::ToCoarseKey(score_vec[j]);
+          atomicAdd(&s_histogram[bin], 1);
+        }
       }
     }
     __syncthreads();
@@ -2038,10 +2040,12 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
       score_vec.cast_load(&score[base]);
#pragma unroll
       for (int j = 0; j < VEC_SIZE; ++j) {
-        const auto bin = static_cast<int>(Traits::ToCoarseKey(score_vec[j]));
-        if (bin > threshold_bin) {
-          const auto pos = atomicAdd(&s_counter, 1);
-          s_indices[pos] = base + j;
+        if (base + j < length) {
+          const auto bin = static_cast<int>(Traits::ToCoarseKey(score_vec[j]));
+          if (bin > threshold_bin) {
+            const auto pos = atomicAdd(&s_counter, 1);
+            s_indices[pos] = base + j;
+          }
         }
       }
     }
@@ -2057,18 +2061,20 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
       score_vec.cast_load(&score[base]);
#pragma unroll
       for (int j = 0; j < VEC_SIZE; ++j) {
-        const auto raw_input = score_vec[j];
-        const auto bin = static_cast<int>(Traits::ToCoarseKey(raw_input));
-        if (bin > threshold_bin) {
-          const auto pos = atomicAdd(&s_counter, 1);
-          s_indices[pos] = base + j;
-        } else if (bin == threshold_bin) {
-          const auto pos = atomicAdd(&s_num_input[0], 1);
-          if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) {
-            s_input_idx[0][pos] = base + j;
-            const auto ordered = Traits::ToOrdered(raw_input);
-            const auto sub_bin = (ordered >> FIRST_SHIFT) & 0xFF;
-            atomicAdd(&s_histogram[sub_bin], 1);
+        if (base + j < length) {
+          const auto raw_input = score_vec[j];
+          const auto bin = static_cast<int>(Traits::ToCoarseKey(raw_input));
+          if (bin > threshold_bin) {
+            const auto pos = atomicAdd(&s_counter, 1);
+            s_indices[pos] = base + j;
+          } else if (bin == threshold_bin) {
+            const auto pos = atomicAdd(&s_num_input[0], 1);
+            if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) {
+              s_input_idx[0][pos] = base + j;
+              const auto ordered = Traits::ToOrdered(raw_input);
+              const auto sub_bin = (ordered >> FIRST_SHIFT) & 0xFF;
+              atomicAdd(&s_histogram[sub_bin], 1);
+            }
           }
         }
       }
diff --git a/tests/utils/test_topk.py b/tests/utils/test_topk.py
index 91829250ca..3e89b20d93 100644
--- a/tests/utils/test_topk.py
+++ b/tests/utils/test_topk.py
@@ -783,6 +783,42 @@ def test_ragged_transform_offset_correctness():
     )


+@pytest.mark.parametrize("length_offset", [1, 2, 3, 5, 6, 7])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+def test_ragged_transform_unaligned_length_bounds(length_offset, dtype):
+    """Test ragged transform with unaligned lengths doesn't return out-of-bounds indices.
+
+    Exposes bug where vectorized loads (VEC_SIZE=4 for float32, 8 for float16)
+    can read beyond valid length, returning indices >= length.
+    """
+    torch.manual_seed(42)
+    device = "cuda"
+    num_rows = 4
+    max_len = 4096
+    k = 256
+    length = max_len - length_offset
+
+    scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
+    offsets = torch.arange(
+        0, num_rows * max_len, max_len, device=device, dtype=torch.int32
+    )
+    lengths = torch.full((num_rows,), length, device=device, dtype=torch.int32)
+
+    output = flashinfer.top_k_ragged_transform(scores, offsets, lengths, k)
+
+    # Verify indices are within bounds [offset, offset + length)
+    for i in range(num_rows):
+        offset = offsets[i].item()
+        valid_indices = output[i][output[i] >= 0]
+        relative_indices = valid_indices - offset
+
+        assert (relative_indices >= 0).all(), f"Row {i} has negative relative indices"
+        assert (relative_indices < length).all(), (
+            f"Row {i} has out-of-bounds index >= {length}. "
+            f"Max index: {relative_indices.max().item()}"
+        )
+
+
 # ===================== SGLang-style Reference Implementation =====================