diff --git a/include/flashinfer/topk.cuh b/include/flashinfer/topk.cuh
index 975b778a21..71371c2e9e 100644
--- a/include/flashinfer/topk.cuh
+++ b/include/flashinfer/topk.cuh
@@ -897,6 +897,16 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
               input[row_idx * stride + chunk_start + i];
         }
       }
+      // Clear histogram for next iteration (in case the next iteration has k < length)
+      if constexpr (!SINGLE_CTA) {
+        constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;
+        uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
+        if (cta_in_group == 0) {
+          for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
+            state->histogram[next_first_hist_idx][i] = 0;
+          }
+        }
+      }
       continue;
     }
   } else if constexpr (MODE == RadixTopKMode::PageTableTransform) {
@@ -906,6 +916,16 @@
       for (uint32_t i = tx; i < top_k_val; i += BLOCK_THREADS) {
         row_output[i] = (i < length) ? src_page_entry[i] : static_cast<IdType>(-1);
       }
+      // Clear histogram for next iteration
+      if constexpr (!SINGLE_CTA) {
+        constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;
+        uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
+        if (cta_in_group == 0) {
+          for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
+            state->histogram[next_first_hist_idx][i] = 0;
+          }
+        }
+      }
       continue;
     }
   } else {  // RaggedTransform
@@ -914,6 +934,16 @@
       for (uint32_t i = tx; i < top_k_val; i += BLOCK_THREADS) {
         row_output[i] = (i < length) ? static_cast<IdType>(i) + offset : static_cast<IdType>(-1);
       }
+      // Clear histogram for next iteration
+      if constexpr (!SINGLE_CTA) {
+        constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;
+        uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
+        if (cta_in_group == 0) {
+          for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
+            state->histogram[next_first_hist_idx][i] = 0;
+          }
+        }
+      }
       continue;
     }
   }
@@ -1084,6 +1114,19 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKMaskLogitsKernel_Multi
         masked_logits[row_idx * vocab_size + chunk_start + i] =
             logits[row_idx * vocab_size + chunk_start + i];
       }
+
+      // Clear histogram for next iteration (in case the next iteration has k < vocab_size)
+      // Only needed for multi-CTA mode; single-CTA uses shared memory cleared each iteration
+      if constexpr (!SINGLE_CTA) {
+        constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;  // ORDERED_BITS / RADIX_BITS
+        uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
+        if (cta_in_group == 0) {
+          for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
+            state->histogram[next_first_hist_idx][i] = 0;
+          }
+        }
+        // No sync needed - next iteration's barrier will ensure visibility
+      }
       continue;
     }
 
@@ -1360,6 +1403,20 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKRenormProbKernel_Multi
         renormed_prob[row_idx * vocab_size + chunk_start + i] =
             DType(float(probs[row_idx * vocab_size + chunk_start + i]) * normalizer);
       }
+
+      // Clear histogram for next iteration (in case the next iteration has k < vocab_size)
+      // Only needed for multi-CTA mode; single-CTA uses shared memory cleared each iteration
+      // Next iteration (iter+1) will use histogram[((iter+1)*NUM_ROUNDS) % 3] for its first round
+      if constexpr (!SINGLE_CTA) {
+        constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;  // ORDERED_BITS / RADIX_BITS
+        uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
+        if (cta_in_group == 0) {
+          for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
+            state->histogram[next_first_hist_idx][i] = 0;
+          }
+        }
+        // No sync needed - next iteration's barrier will ensure visibility
+      }
       continue;
     }
 
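Note on the index arithmetic: per the patch comments, the multi-CTA kernels rotate through three device-memory histogram buffers, and `((iter + 1) * NUM_ROUNDS) % 3` is the buffer the next persistent-loop iteration's first radix round will read. When an iteration skips radix select entirely (k >= length), none of the per-round bookkeeping runs, so that buffer must be zeroed explicitly. The sketch below is a minimal illustration of the rotation, assuming a 32-bit OrderedType with 8-bit radix digits (NUM_ROUNDS = 4); `buffer_for` is a hypothetical helper for illustration, not a function in topk.cuh.

# Minimal sketch (not FlashInfer code) of the triple-buffered histogram
# rotation implied by the `% 3` indexing above. Assumes 32-bit keys with
# 8-bit radix digits, so NUM_ROUNDS = sizeof(OrderedType) * 8 / 8 = 4.
NUM_ROUNDS = 4

def buffer_for(it: int, rnd: int) -> int:
    # Hypothetical mapping: round `rnd` of persistent-loop iteration `it`
    # uses one of the three histogram buffers, in rotation.
    return (it * NUM_ROUNDS + rnd) % 3

for it in range(4):
    used = [buffer_for(it, r) for r in range(NUM_ROUNDS)]
    print(f"iter {it}: rounds -> buffers {used}; "
          f"iter {it + 1} starts on buffer {buffer_for(it + 1, 0)}")
# iter 0: rounds -> buffers [0, 1, 2, 0]; iter 1 starts on buffer 1
# iter 1: rounds -> buffers [1, 2, 0, 1]; iter 2 starts on buffer 2
# iter 2: rounds -> buffers [2, 0, 1, 2]; iter 3 starts on buffer 0
# iter 3: rounds -> buffers [0, 1, 2, 0]; iter 4 starts on buffer 1
# A skipped iteration runs none of these rounds, so the buffer that
# iteration it+1 reads first may still hold stale counts -- hence the
# explicit clear added in the patch.
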
diff --git a/tests/utils/test_sampling.py b/tests/utils/test_sampling.py
index 99ff6a3e2b..2ae453ecf0 100644
--- a/tests/utils/test_sampling.py
+++ b/tests/utils/test_sampling.py
@@ -513,6 +513,83 @@ def test_top_k_renorm_probs(batch_size, vocab_size, k, distribution, dtype):
     )
 
 
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+def test_top_k_renorm_probs_mixed_k_persistent_loop(dtype):
+    """Test top_k_renorm_probs with mixed k values in persistent loop (multi-CTA mode).
+
+    This test catches a specific bug where:
+    - Large batch size triggers the persistent loop (multiple iterations per CTA group)
+    - Large vocab_size triggers multi-CTA mode (multiple CTAs per row)
+    - Mixed k values: some rows have k >= vocab_size (skip radix select),
+      others have k < vocab_size (use radix select)
+
+    The bug was that k >= vocab_size iterations would skip radix select
+    without clearing the histogram buffers, leaving stale data that corrupted
+    subsequent k < vocab_size iterations.
+    """
+    batch_size = 1024  # Large batch to trigger persistent loop
+    vocab_size = 128 * 1024  # Large vocab to trigger multi-CTA mode
+
+    torch.manual_seed(42)
+    generator = torch.Generator(device="cuda:0").manual_seed(42)
+
+    # Generate random logits
+    logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
+
+    # Generate k values: mix of small k and k == vocab_size
+    generator = torch.Generator(device="cuda:0").manual_seed(42)
+    k_values = torch.randint(
+        1, 1000, (batch_size,), device="cuda:0", generator=generator
+    )
+
+    # Randomly set some rows to k == vocab_size (about 50%)
+    generator = torch.Generator(device="cuda:0").manual_seed(42)
+    mask = torch.randint(
+        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
+    )
+    k_values.masked_fill_(mask, vocab_size)
+
+    # Convert to probs
+    probs = torch.softmax(logits, dim=-1).to(dtype)
+
+    # Run FlashInfer top_k_renorm_probs
+    renorm_probs = flashinfer.sampling.top_k_renorm_probs(probs, k_values)
+
+    # Verify output dtype
+    assert renorm_probs.dtype == dtype
+
+    # Verify each row sums to 1
+    sums = renorm_probs.float().sum(dim=-1)
+    torch.testing.assert_close(sums, torch.ones_like(sums), rtol=1e-2, atol=1e-2)
+
+    # Verify the non-zero count is consistent with k for each row
+    nonzero_counts = (renorm_probs > 0).sum(dim=-1)
+
+    # For rows with k >= vocab_size, all elements should be non-zero
+    # For rows with k < vocab_size, non-zero count should be >= k (may be more due to ties)
+    for i in range(batch_size):
+        k = k_values[i].item()
+        count = nonzero_counts[i].item()
+
+        if k >= vocab_size:
+            # All elements should be non-zero
+            assert count == vocab_size, (
+                f"Row {i}: k >= vocab_size but count={count} != {vocab_size}"
+            )
+        else:
+            # Count should be at least k (may be more due to ties at the threshold)
+            row_probs = probs[i].float()
+            topk_vals, _ = torch.topk(row_probs, k, sorted=True)
+            threshold = topk_vals[-1]
+            expected_ge_threshold = (row_probs >= threshold).sum().item()
+
+            # Allow small tolerance for floating point
+            assert count >= k - 1, f"Row {i}: k={k} but only {count} non-zero elements"
+            assert count <= expected_ge_threshold + 1, (
+                f"Row {i}: k={k}, expected at most {expected_ge_threshold + 1} but got {count}"
+            )
+
+
 @pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
 @pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
 @pytest.mark.parametrize("k", [10, 100, 500])
@@ -597,6 +674,68 @@ def test_top_k_mask_logits(
     )
 
 
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
+def test_top_k_mask_logits_mixed_k_persistent_loop(dtype):
+    """Test top_k_mask_logits with mixed k values in persistent loop (multi-CTA mode).
+
+    This test catches the same bug as test_top_k_renorm_probs_mixed_k_persistent_loop
+    but for the mask_logits variant.
+    """
+    batch_size = 1024  # Large batch to trigger persistent loop
+    vocab_size = 128 * 1024  # Large vocab to trigger multi-CTA mode
+
+    torch.manual_seed(42)
+    generator = torch.Generator(device="cuda:0").manual_seed(42)
+
+    # Generate random logits
+    logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
+    logits = logits.to(dtype)
+
+    # Generate k values: mix of small k and k == vocab_size
+    generator = torch.Generator(device="cuda:0").manual_seed(42)
+    k_values = torch.randint(
+        1, 1000, (batch_size,), device="cuda:0", generator=generator
+    )
+
+    # Randomly set some rows to k == vocab_size (about 50%)
+    generator = torch.Generator(device="cuda:0").manual_seed(42)
+    mask = torch.randint(
+        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
+    )
+    k_values.masked_fill_(mask, vocab_size)
+
+    # Run FlashInfer top_k_mask_logits
+    masked_logits = flashinfer.sampling.top_k_mask_logits(logits, k_values)
+
+    # Verify output dtype
+    assert masked_logits.dtype == dtype
+
+    # Verify the finite count is consistent with k for each row
+    finite_counts = torch.isfinite(masked_logits).sum(dim=-1)
+
+    for i in range(batch_size):
+        k = k_values[i].item()
+        count = finite_counts[i].item()
+
+        if k >= vocab_size:
+            # All elements should be finite
+            assert count == vocab_size, (
+                f"Row {i}: k >= vocab_size but finite count={count} != {vocab_size}"
+            )
+        else:
+            # Count should be at least k (may be more due to ties at the threshold)
+            row_logits = logits[i].float()
+            topk_vals, _ = torch.topk(row_logits, k, sorted=True)
+            threshold = topk_vals[-1]
+            expected_ge_threshold = (row_logits >= threshold).sum().item()
+
+            # Allow small tolerance for floating point
+            assert count >= k - 1, f"Row {i}: k={k} but only {count} finite elements"
+            assert count <= expected_ge_threshold + 1, (
+                f"Row {i}: k={k}, expected at most {expected_ge_threshold + 1} but got {count}"
+            )
+
+
 @pytest.mark.parametrize("batch_size", [1, 99, 989])
 @pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
 @pytest.mark.parametrize("num_speculate_tokens", [1, 3, 5, 7])
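For a quick manual check outside the test suite, the same mixed-k, large-batch, large-vocab shape can be driven directly. This is a condensed sketch of the tests above (it assumes an installed flashinfer build and a CUDA device), not additional coverage:

# Condensed repro mirroring the tests above; assumes flashinfer is installed
# and a CUDA device is available.
import torch
import flashinfer

batch_size, vocab_size = 1024, 128 * 1024  # persistent loop + multi-CTA mode
probs = torch.softmax(torch.rand(batch_size, vocab_size, device="cuda:0"), dim=-1)
k = torch.randint(1, 1000, (batch_size,), device="cuda:0")
k[::2] = vocab_size  # every other row takes the skip path (k == vocab_size)

out = flashinfer.sampling.top_k_renorm_probs(probs, k)
# With stale histograms, rows with k < vocab_size that ran after a skipped
# row could be corrupted; with the fix every renormalized row sums to ~1.
ones = torch.ones(batch_size, device="cuda:0")
torch.testing.assert_close(out.float().sum(-1), ones, rtol=1e-2, atol=1e-2)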