Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 24 additions & 18 deletions include/flashinfer/topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1994,8 +1994,10 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
score_vec.cast_load(&score[base]);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
const auto bin = Traits::ToCoarseKey(score_vec[j]);
atomicAdd(&s_histogram[bin], 1);
if (base + j < length) {
const auto bin = Traits::ToCoarseKey(score_vec[j]);
atomicAdd(&s_histogram[bin], 1);
}
}
}
__syncthreads();
Expand Down Expand Up @@ -2038,10 +2040,12 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
score_vec.cast_load(&score[base]);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
const auto bin = static_cast<int>(Traits::ToCoarseKey(score_vec[j]));
if (bin > threshold_bin) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = base + j;
if (base + j < length) {
const auto bin = static_cast<int>(Traits::ToCoarseKey(score_vec[j]));
if (bin > threshold_bin) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = base + j;
}
}
}
}
Expand All @@ -2057,18 +2061,20 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
score_vec.cast_load(&score[base]);
#pragma unroll
for (int j = 0; j < VEC_SIZE; ++j) {
const auto raw_input = score_vec[j];
const auto bin = static_cast<int>(Traits::ToCoarseKey(raw_input));
if (bin > threshold_bin) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = base + j;
} else if (bin == threshold_bin) {
const auto pos = atomicAdd(&s_num_input[0], 1);
if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) {
s_input_idx[0][pos] = base + j;
const auto ordered = Traits::ToOrdered(raw_input);
const auto sub_bin = (ordered >> FIRST_SHIFT) & 0xFF;
atomicAdd(&s_histogram[sub_bin], 1);
if (base + j < length) {
const auto raw_input = score_vec[j];
const auto bin = static_cast<int>(Traits::ToCoarseKey(raw_input));
if (bin > threshold_bin) {
const auto pos = atomicAdd(&s_counter, 1);
s_indices[pos] = base + j;
} else if (bin == threshold_bin) {
const auto pos = atomicAdd(&s_num_input[0], 1);
if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) {
s_input_idx[0][pos] = base + j;
const auto ordered = Traits::ToOrdered(raw_input);
const auto sub_bin = (ordered >> FIRST_SHIFT) & 0xFF;
atomicAdd(&s_histogram[sub_bin], 1);
}
}
}
}
Expand Down
36 changes: 36 additions & 0 deletions tests/utils/test_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,42 @@ def test_ragged_transform_offset_correctness():
)


@pytest.mark.parametrize("length_offset", [1, 2, 3, 5, 6, 7])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_ragged_transform_unaligned_length_bounds(length_offset, dtype):
    """Regression test: unaligned row lengths must not yield out-of-bounds indices.

    Vectorized loads in the kernel (VEC_SIZE=4 for float32, 8 for float16)
    previously read past the valid length, so indices >= length could be returned.
    """
    torch.manual_seed(42)
    device = "cuda"
    num_rows, max_len, k = 4, 4096, 256
    length = max_len - length_offset

    scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
    offsets = torch.arange(
        0, num_rows * max_len, max_len, device=device, dtype=torch.int32
    )
    lengths = torch.full((num_rows,), length, device=device, dtype=torch.int32)

    output = flashinfer.top_k_ragged_transform(scores, offsets, lengths, k)

    # Every surviving (non-negative) index must land in [offset, offset + length)
    # for its row; anything at or beyond `length` means the kernel read past the
    # valid tail of the row.
    for i, offset in enumerate(offsets.tolist()):
        row = output[i]
        valid_indices = row[row >= 0]
        relative_indices = valid_indices - offset

        assert (relative_indices >= 0).all(), f"Row {i} has negative relative indices"
        assert (relative_indices < length).all(), (
            f"Row {i} has out-of-bounds index >= {length}. "
            f"Max index: {relative_indices.max().item()}"
        )


# ===================== SGLang-style Reference Implementation =====================


Expand Down