-
Notifications
You must be signed in to change notification settings - Fork 840
bugfix: fix multi-cta top-k implementation when k value is different for different row #2325
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -897,6 +897,16 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( | |
| input[row_idx * stride + chunk_start + i]; | ||
| } | ||
| } | ||
| // Clear histogram for next iteration (in case it's k < length) | ||
| if constexpr (!SINGLE_CTA) { | ||
| constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8; | ||
| uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3; | ||
| if (cta_in_group == 0) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| state->histogram[next_first_hist_idx][i] = 0; | ||
| } | ||
| } | ||
| } | ||
|
Comment on lines
+900
to
+909
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add an intra-CTA sync after the “next-iteration histogram clear” to avoid partial clears being “released” early. In these blocks, multiple threads in CTA0 write. Minimal pattern to apply in each location: if (cta_in_group == 0) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
state->histogram[next_first_hist_idx][i] = 0;
}
+ __syncthreads(); // ensure CTA0 finished clearing before the next iteration can signal
}
- // No sync needed - next iteration's barrier will ensure visibility
+ // Next iteration's inter-CTA barrier will handle cross-CTA visibility/order.
Also: the comment “Ensure all threads in leading CTA finish clearing…” (Line 364) is slightly misleading—this Also applies to: 937-947, 955-965, 1136-1147, 1425-1437 🤖 Prompt for AI Agents |
||
| continue; | ||
| } | ||
| } else if constexpr (MODE == RadixTopKMode::PageTableTransform) { | ||
|
|
@@ -906,6 +916,16 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( | |
| for (uint32_t i = tx; i < top_k_val; i += BLOCK_THREADS) { | ||
| row_output[i] = (i < length) ? src_page_entry[i] : static_cast<IdType>(-1); | ||
| } | ||
| // Clear histogram for next iteration | ||
| if constexpr (!SINGLE_CTA) { | ||
| constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8; | ||
| uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3; | ||
| if (cta_in_group == 0) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| state->histogram[next_first_hist_idx][i] = 0; | ||
| } | ||
| } | ||
| } | ||
| continue; | ||
| } | ||
| } else { // RaggedTransform | ||
|
|
@@ -914,6 +934,16 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified( | |
| for (uint32_t i = tx; i < top_k_val; i += BLOCK_THREADS) { | ||
| row_output[i] = (i < length) ? static_cast<IdType>(i) + offset : static_cast<IdType>(-1); | ||
| } | ||
| // Clear histogram for next iteration | ||
| if constexpr (!SINGLE_CTA) { | ||
| constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8; | ||
| uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3; | ||
| if (cta_in_group == 0) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| state->histogram[next_first_hist_idx][i] = 0; | ||
| } | ||
| } | ||
| } | ||
| continue; | ||
| } | ||
| } | ||
|
|
@@ -1084,6 +1114,19 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKMaskLogitsKernel_Multi | |
| masked_logits[row_idx * vocab_size + chunk_start + i] = | ||
| logits[row_idx * vocab_size + chunk_start + i]; | ||
| } | ||
|
|
||
| // Clear histogram for next iteration (in case it's k < vocab_size) | ||
| // Only needed for multi-CTA mode; single-CTA uses shared memory cleared each iteration | ||
| if constexpr (!SINGLE_CTA) { | ||
| constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8; // ORDERED_BITS / RADIX_BITS | ||
| uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3; | ||
| if (cta_in_group == 0) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| state->histogram[next_first_hist_idx][i] = 0; | ||
| } | ||
| } | ||
| // No sync needed - next iteration's barrier will ensure visibility | ||
| } | ||
| continue; | ||
| } | ||
|
|
||
|
|
@@ -1360,6 +1403,20 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKRenormProbKernel_Multi | |
| renormed_prob[row_idx * vocab_size + chunk_start + i] = | ||
| DType(float(probs[row_idx * vocab_size + chunk_start + i]) * normalizer); | ||
| } | ||
|
|
||
| // Clear histogram for next iteration (in case it's k < vocab_size) | ||
| // Only needed for multi-CTA mode; single-CTA uses shared memory cleared each iteration | ||
| // Next iteration (iter+1) will use histogram[((iter+1)*NUM_ROUNDS) % 3] for its first round | ||
| if constexpr (!SINGLE_CTA) { | ||
| constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8; // ORDERED_BITS / RADIX_BITS | ||
| uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3; | ||
| if (cta_in_group == 0) { | ||
| for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) { | ||
| state->histogram[next_first_hist_idx][i] = 0; | ||
| } | ||
| } | ||
| // No sync needed - next iteration's barrier will ensure visibility | ||
| } | ||
| continue; | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -513,6 +513,83 @@ def test_top_k_renorm_probs(batch_size, vocab_size, k, distribution, dtype): | |
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) | ||
| def test_top_k_renorm_probs_mixed_k_persistent_loop(dtype): | ||
| """Test top_k_renorm_probs with mixed k values in persistent loop (multi-CTA mode). | ||
|
|
||
| This test catches a specific bug where: | ||
| - Large batch size triggers the persistent loop (multiple iterations per CTA group) | ||
| - Large vocab_size triggers multi-CTA mode (multiple CTAs per row) | ||
| - Mixed k values: some rows have k >= vocab_size (skip radix select), | ||
| others have k < vocab_size (use radix select) | ||
|
|
||
| The bug was that k >= vocab_size iterations would skip radix select | ||
| without clearing the histogram buffers, leaving stale data that corrupted | ||
| subsequent k < vocab_size iterations. | ||
| """ | ||
| batch_size = 1024 # Large batch to trigger persistent loop | ||
| vocab_size = 128 * 1024 # Large vocab to trigger multi-CTA mode | ||
|
|
||
| torch.manual_seed(42) | ||
| generator = torch.Generator(device="cuda:0").manual_seed(42) | ||
|
|
||
| # Generate random logits | ||
| logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator) | ||
|
|
||
| # Generate k values: mix of small k and k == vocab_size | ||
| generator = torch.Generator(device="cuda:0").manual_seed(42) | ||
| k_values = torch.randint( | ||
| 1, 1000, (batch_size,), device="cuda:0", generator=generator | ||
| ) | ||
|
|
||
| # Randomly set some rows to k == vocab_size (about 50%) | ||
| generator = torch.Generator(device="cuda:0").manual_seed(42) | ||
| mask = torch.randint( | ||
| 0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0" | ||
| ) | ||
| k_values.masked_fill_(mask, vocab_size) | ||
|
Comment on lines
+533
to
+550
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)
# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
# Generate k values: mix of small k and k == vocab_size
k_values = torch.randint(
1, 1000, (batch_size,), device="cuda:0", generator=generator
)
# Randomly set some rows to k == vocab_size (about 50%)
mask = torch.randint(
0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
)
k_values.masked_fill_(mask, vocab_size) |
||
|
|
||
| # Convert to probs | ||
| probs = torch.softmax(logits, dim=-1).to(dtype) | ||
|
|
||
| # Run FlashInfer top_k_renorm_probs | ||
| renorm_probs = flashinfer.sampling.top_k_renorm_probs(probs, k_values) | ||
|
|
||
| # Verify output dtype | ||
| assert renorm_probs.dtype == dtype | ||
|
|
||
| # Verify sum to 1 | ||
| sums = renorm_probs.float().sum(dim=-1) | ||
| torch.testing.assert_close(sums, torch.ones_like(sums), rtol=1e-2, atol=1e-2) | ||
|
|
||
| # Verify non-zero count matches k for each row | ||
| nonzero_counts = (renorm_probs > 0).sum(dim=-1) | ||
|
|
||
| # For rows with k >= vocab_size, all elements should be non-zero | ||
| # For rows with k < vocab_size, non-zero count should be >= k (may be more due to ties) | ||
| for i in range(batch_size): | ||
| k = k_values[i].item() | ||
| count = nonzero_counts[i].item() | ||
|
|
||
| if k >= vocab_size: | ||
| # All elements should be non-zero | ||
| assert count == vocab_size, ( | ||
| f"Row {i}: k >= vocab_size but count={count} != {vocab_size}" | ||
| ) | ||
| else: | ||
| # Count should be at least k (may be more due to ties at the threshold) | ||
| row_probs = probs[i].float() | ||
| topk_vals, _ = torch.topk(row_probs, k, sorted=True) | ||
| threshold = topk_vals[-1] | ||
| expected_ge_threshold = (row_probs >= threshold).sum().item() | ||
|
|
||
| # Allow small tolerance for floating point | ||
| assert count >= k - 1, f"Row {i}: k={k} but only {count} non-zero elements" | ||
| assert count <= expected_ge_threshold + 1, ( | ||
| f"Row {i}: k={k}, expected at most {expected_ge_threshold} but got {count}" | ||
| ) | ||
|
|
||
|
Comment on lines
+516
to
+591
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These new tests are likely to OOM / timeout in CI; also avoid
Proposed tightening (reduce OOM risk + reduce syncs) @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_top_k_renorm_probs_mixed_k_persistent_loop(dtype):
+ if not torch.cuda.is_available():
+ pytest.skip("CUDA required")
+ # Per tests guideline: consider using flashinfer.utils helpers to skip unsupported archs.
+
- batch_size = 1024 # Large batch to trigger persistent loop
- vocab_size = 128 * 1024 # Large vocab to trigger multi-CTA mode
+ # Still triggers multi-CTA (vocab_size > 57344) but avoids common OOM.
+ vocab_size = 65536
+ batch_size = 256 if dtype == torch.float32 else 512
torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)
# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
@@
- mask = torch.randint(
- 0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
- )
+ mask = torch.randint(0, 2, (batch_size,), generator=generator, device="cuda:0").bool()
k_values.masked_fill_(mask, vocab_size)
@@
- for i in range(batch_size):
+ # Only spot-check a small subset of k < vocab rows (avoid 1000x topk calls + .item() sync).
+ small_rows = torch.nonzero(k_values < vocab_size, as_tuple=False).flatten()[:16].tolist()
+ for i in small_rows:
k = k_values[i].item()
count = nonzero_counts[i].item()
@@
- if k >= vocab_size:
- # All elements should be non-zero
- assert count == vocab_size, (
- f"Row {i}: k >= vocab_size but count={count} != {vocab_size}"
- )
- else:
+ if k < vocab_size:
# Count should be at least k (may be more due to ties at the threshold)
row_probs = probs[i].float()
topk_vals, _ = torch.topk(row_probs, k, sorted=True)
threshold = topk_vals[-1]
expected_ge_threshold = (row_probs >= threshold).sum().item()π€ Prompt for AI Agents
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have enough GPU memory on CI.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @yzh119, understood! Since you've confirmed the CI environment has sufficient GPU memory for these tests, that addresses the main concern. Feel free to keep the test as-is. π§ Learnings used |
||
|
|
||
| @pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) | ||
| @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) | ||
| @pytest.mark.parametrize("k", [10, 100, 500]) | ||
|
|
@@ -597,6 +674,68 @@ def test_top_k_mask_logits( | |
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) | ||
| def test_top_k_mask_logits_mixed_k_persistent_loop(dtype): | ||
| """Test top_k_mask_logits with mixed k values in persistent loop (multi-CTA mode). | ||
|
|
||
| This test catches the same bug as test_top_k_renorm_probs_mixed_k_persistent_loop | ||
| but for the mask_logits variant. | ||
| """ | ||
| batch_size = 1024 # Large batch to trigger persistent loop | ||
| vocab_size = 128 * 1024 # Large vocab to trigger multi-CTA mode | ||
|
|
||
| torch.manual_seed(42) | ||
| generator = torch.Generator(device="cuda:0").manual_seed(42) | ||
|
|
||
| # Generate random logits | ||
| logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator) | ||
| logits = logits.to(dtype) | ||
|
|
||
| # Generate k values: mix of small k and k == vocab_size | ||
| generator = torch.Generator(device="cuda:0").manual_seed(42) | ||
| k_values = torch.randint( | ||
| 1, 1000, (batch_size,), device="cuda:0", generator=generator | ||
| ) | ||
|
|
||
| # Randomly set some rows to k == vocab_size (about 50%) | ||
| generator = torch.Generator(device="cuda:0").manual_seed(42) | ||
| mask = torch.randint( | ||
| 0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0" | ||
| ) | ||
| k_values.masked_fill_(mask, vocab_size) | ||
|
Comment on lines
+687
to
+705
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)
# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
logits = logits.to(dtype)
# Generate k values: mix of small k and k == vocab_size
k_values = torch.randint(
1, 1000, (batch_size,), device="cuda:0", generator=generator
)
# Randomly set some rows to k == vocab_size (about 50%)
mask = torch.randint(
0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
)
k_values.masked_fill_(mask, vocab_size) |
||
|
|
||
| # Run FlashInfer top_k_mask_logits | ||
| masked_logits = flashinfer.sampling.top_k_mask_logits(logits, k_values) | ||
|
|
||
| # Verify output dtype | ||
| assert masked_logits.dtype == dtype | ||
|
|
||
| # Verify finite count matches k for each row | ||
| finite_counts = torch.isfinite(masked_logits).sum(dim=-1) | ||
|
|
||
| for i in range(batch_size): | ||
| k = k_values[i].item() | ||
| count = finite_counts[i].item() | ||
|
|
||
| if k >= vocab_size: | ||
| # All elements should be finite | ||
| assert count == vocab_size, ( | ||
| f"Row {i}: k >= vocab_size but finite count={count} != {vocab_size}" | ||
| ) | ||
| else: | ||
| # Count should be at least k (may be more due to ties at the threshold) | ||
| row_logits = logits[i].float() | ||
| topk_vals, _ = torch.topk(row_logits, k, sorted=True) | ||
| threshold = topk_vals[-1] | ||
| expected_ge_threshold = (row_logits >= threshold).sum().item() | ||
|
|
||
| # Allow small tolerance for floating point | ||
| assert count >= k - 1, f"Row {i}: k={k} but only {count} finite elements" | ||
| assert count <= expected_ge_threshold + 1, ( | ||
| f"Row {i}: k={k}, expected at most {expected_ge_threshold} but got {count}" | ||
| ) | ||
|
|
||
|
Comment on lines
+677
to
+737
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same OOM/timeout concerns apply; also consider validating the For Minimal adjustments mirroring the other test def test_top_k_mask_logits_mixed_k_persistent_loop(dtype):
+ if not torch.cuda.is_available():
+ pytest.skip("CUDA required")
batch_size = 1024 # Large batch to trigger persistent loop
vocab_size = 128 * 1024 # Large vocab to trigger multi-CTA mode
@@
- mask = torch.randint(
- 0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
- )
+ mask = torch.randint(0, 2, (batch_size,), generator=generator, device="cuda:0").bool()
k_values.masked_fill_(mask, vocab_size)
+
+ # Stronger + cheaper: rows with k>=vocab are copies
+ full_rows = (k_values >= vocab_size).nonzero(as_tuple=False).flatten()
+ if full_rows.numel() > 0:
+ torch.testing.assert_close(masked_logits[full_rows], logits[full_rows])
🤖 Prompt for AI Agents
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @yzh119, understood! I've noted your response. |
||
|
|
||
| @pytest.mark.parametrize("batch_size", [1, 99, 989]) | ||
| @pytest.mark.parametrize("vocab_size", [111, 32000, 128256]) | ||
| @pytest.mark.parametrize("num_speculate_tokens", [1, 3, 5, 7]) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic to clear the histogram for the next iteration is duplicated in 5 places across
`RadixTopKKernel_Unified`, `RadixTopKMaskLogitsKernel_MultiCTA`, and `RadixTopKRenormProbKernel_MultiCTA`. To improve maintainability and reduce code duplication, consider refactoring this into a `__device__ __forceinline__` helper function. Also, the calculation for
`NUM_ROUNDS` can be simplified. `sizeof(OrderedType) * 8 / 8` is equivalent to just `sizeof(OrderedType)`. Here's an example of how you could define and use the helper function:
This would make the code cleaner and easier to maintain.