diff --git a/include/flashinfer/topk.cuh b/include/flashinfer/topk.cuh
index 942d90eb3a..975b778a21 100644
--- a/include/flashinfer/topk.cuh
+++ b/include/flashinfer/topk.cuh
@@ -1989,8 +1989,9 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
   vec_t score_vec;
+  const int aligned_length = (length / VEC_SIZE) * VEC_SIZE;
 #pragma unroll 2
-  for (int base = tx * VEC_SIZE; base < length; base += BLOCK_SIZE * VEC_SIZE) {
+  for (int base = tx * VEC_SIZE; base < aligned_length; base += BLOCK_SIZE * VEC_SIZE) {
     score_vec.cast_load(&score[base]);
 #pragma unroll
     for (int j = 0; j < VEC_SIZE; ++j) {
@@ -1998,6 +1999,11 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
       atomicAdd(&s_histogram[bin], 1);
     }
   }
+  // Handle tail
+  for (int i = aligned_length + tx; i < length; i += BLOCK_SIZE) {
+    const auto bin = Traits::ToCoarseKey(score[i]);
+    atomicAdd(&s_histogram[bin], 1);
+  }
   __syncthreads();
 
   // Suffix sum
@@ -2034,7 +2040,7 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
   if (topk == 0) {
     // Collect indices where bin > threshold
 #pragma unroll 2
-    for (int base = tx * VEC_SIZE; base < length; base += BLOCK_SIZE * VEC_SIZE) {
+    for (int base = tx * VEC_SIZE; base < aligned_length; base += BLOCK_SIZE * VEC_SIZE) {
       score_vec.cast_load(&score[base]);
 #pragma unroll
       for (int j = 0; j < VEC_SIZE; ++j) {
@@ -2045,6 +2051,14 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
         }
       }
     }
+    // Handle tail
+    for (int i = aligned_length + tx; i < length; i += BLOCK_SIZE) {
+      const auto bin = static_cast<int>(Traits::ToCoarseKey(score[i]));
+      if (bin > threshold_bin) {
+        const auto pos = atomicAdd(&s_counter, 1);
+        s_indices[pos] = i;
+      }
+    }
     __syncthreads();
   } else {
     __syncthreads();
@@ -2052,27 +2066,33 @@ __global__ void __launch_bounds__(FILTERED_TOPK_BLOCK_THREADS)
   __syncthreads();
 
   // Filter + histogram for refinement
+  auto filter_and_add_to_histogram = [&](auto raw_input, int index) {
+    const auto bin = static_cast<int>(Traits::ToCoarseKey(raw_input));
+    if (bin > threshold_bin) {
+      const auto pos = atomicAdd(&s_counter, 1);
+      s_indices[pos] = index;
+    } else if (bin == threshold_bin) {
+      const auto pos = atomicAdd(&s_num_input[0], 1);
+      if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) {
+        s_input_idx[0][pos] = index;
+        const auto ordered = Traits::ToOrdered(raw_input);
+        const auto sub_bin = (ordered >> FIRST_SHIFT) & 0xFF;
+        atomicAdd(&s_histogram[sub_bin], 1);
+      }
+    }
+  };
 #pragma unroll 2
-  for (int base = tx * VEC_SIZE; base < length; base += BLOCK_SIZE * VEC_SIZE) {
+  for (int base = tx * VEC_SIZE; base < aligned_length; base += BLOCK_SIZE * VEC_SIZE) {
     score_vec.cast_load(&score[base]);
 #pragma unroll
     for (int j = 0; j < VEC_SIZE; ++j) {
-      const auto raw_input = score_vec[j];
-      const auto bin = static_cast<int>(Traits::ToCoarseKey(raw_input));
-      if (bin > threshold_bin) {
-        const auto pos = atomicAdd(&s_counter, 1);
-        s_indices[pos] = base + j;
-      } else if (bin == threshold_bin) {
-        const auto pos = atomicAdd(&s_num_input[0], 1);
-        if (__builtin_expect(pos < SMEM_INPUT_SIZE, 1)) {
-          s_input_idx[0][pos] = base + j;
-          const auto ordered = Traits::ToOrdered(raw_input);
-          const auto sub_bin = (ordered >> FIRST_SHIFT) & 0xFF;
-          atomicAdd(&s_histogram[sub_bin], 1);
-        }
-      }
+      filter_and_add_to_histogram(score_vec[j], base + j);
     }
   }
+  // Handle tail
+  for (int i = aligned_length + tx; i < length; i += BLOCK_SIZE) {
+    filter_and_add_to_histogram(score[i], i);
+  }
   __syncthreads();
 
   // Stage 2: refine with 8bit radix passes
diff --git a/tests/utils/test_topk.py b/tests/utils/test_topk.py
index 91829250ca..921fc7c5ec 100644
--- a/tests/utils/test_topk.py
+++ b/tests/utils/test_topk.py
@@ -419,6 +419,54 @@ def test_top_k_ragged_transform(num_rows, max_len, k, dtype):
     assert accuracy >= min_accuracy, f"Accuracy {accuracy:.4f} < {min_accuracy}"
 
 
+@pytest.mark.parametrize("num_rows", [1, 8, 32])
+@pytest.mark.parametrize("max_len", [1024, 4096, 8192])
+@pytest.mark.parametrize("k", [64, 256, 512])
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+def test_top_k_ragged_transform_out_of_length(num_rows, max_len, k, dtype):
+    """Test top_k_ragged_transform returns correct indices with offsets."""
+    if k > max_len:
+        pytest.skip("k should be less than max_len")
+
+    torch.manual_seed(42)
+    device = "cuda"
+
+    # Generate random scores
+    scores = torch.randn(num_rows, max_len, device=device, dtype=dtype)
+
+    # Generate naive offsets (cumulative sum style)
+    offsets = torch.zeros(num_rows, device=device, dtype=torch.int32)
+
+    # Random in [1, max_len]
+    lengths = torch.randint(
+        1, max_len + 1, (num_rows,), device=device, dtype=torch.int32
+    )
+
+    # Test flashinfer implementation
+    output = flashinfer.top_k_ragged_transform(scores, offsets, lengths, k)
+
+    # Reference implementation
+    ref_output = reference_ragged_transform(scores, offsets, lengths, k)
+
+    # Check output shape
+    assert output.shape == (num_rows, k), (
+        f"Expected shape {(num_rows, k)}, got {output.shape}"
+    )
+    assert output.dtype == torch.int32
+
+    # Check accuracy
+    accuracy = compute_transform_accuracy(output, ref_output, num_rows, k)
+    min_accuracy = 0.95
+    assert accuracy >= min_accuracy, f"Accuracy {accuracy:.4f} < {min_accuracy}"
+    # Check out of length
+    valid_min = offsets
+    valid_max = offsets + lengths
+    output = output.clamp_min(0)
+    assert torch.all((output >= valid_min[:, None]) & (output < valid_max[:, None])), (
+        f"Out of length Error. {valid_min=}, {valid_max=}, {output.max(dim=1).values=}, {output.min(dim=1).values=}"
+    )
+
+
 @pytest.mark.parametrize("num_rows", [4, 16])
 @pytest.mark.parametrize("max_len", [2048])
 @pytest.mark.parametrize("k", [256, 512])