Merged
57 changes: 57 additions & 0 deletions include/flashinfer/topk.cuh
@@ -897,6 +897,16 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
input[row_idx * stride + chunk_start + i];
}
}
// Clear histogram for next iteration (in case it's k < length)
if constexpr (!SINGLE_CTA) {
constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;
uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
if (cta_in_group == 0) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
state->histogram[next_first_hist_idx][i] = 0;
}
}
}
Comment on lines +900 to +909
Contributor

medium

This logic to clear the histogram for the next iteration is duplicated in 5 places across RadixTopKKernel_Unified, RadixTopKMaskLogitsKernel_MultiCTA, and RadixTopKRenormProbKernel_MultiCTA. To improve maintainability and reduce code duplication, consider refactoring this into a __device__ __forceinline__ helper function.

Also, the calculation for NUM_ROUNDS can be simplified. sizeof(OrderedType) * 8 / 8 is equivalent to just sizeof(OrderedType).

Here's an example of how you could define and use the helper function:

// Placed somewhere before its first use, e.g., after RadixRowState struct
template <typename OrderedType, uint32_t BLOCK_THREADS>
__device__ __forceinline__ void ClearNextIterationHistogram(
    RadixRowState* state, uint32_t iter, uint32_t cta_in_group, uint32_t tx) {
  constexpr uint32_t RADIX = 256;
  constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType);
  uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
  if (cta_in_group == 0) {
    for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
      state->histogram[next_first_hist_idx][i] = 0;
    }
  }
}

// ... inside the kernel ...
// Clear histogram for next iteration (in case it's k < length)
if constexpr (!SINGLE_CTA) {
  ClearNextIterationHistogram<OrderedType, BLOCK_THREADS>(state, iter, cta_in_group, tx);
}

This would make the code cleaner and easier to maintain.

        // Clear histogram for next iteration (in case it's k < length)
        if constexpr (!SINGLE_CTA) {
          constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType);
          uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
          if (cta_in_group == 0) {
            for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
              state->histogram[next_first_hist_idx][i] = 0;
            }
          }
        }

Comment on lines +900 to +909
Contributor

⚠️ Potential issue | πŸ”΄ Critical

Add an intra-CTA sync after β€œnext-iteration histogram clear” to avoid partial clears being β€œreleased” early.

In these blocks, multiple threads in CTA0 write state->histogram[...], but there's no __syncthreads() to prevent CTA0's tx==0 from running ahead into the next iteration and signaling the next inter-CTA barrier before all clears finish (similar to the issue you already fixed in RadixSelectOneRound).

Minimal pattern to apply in each location
         if (cta_in_group == 0) {
           for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
             state->histogram[next_first_hist_idx][i] = 0;
           }
+          __syncthreads();  // ensure CTA0 finished clearing before the next iteration can signal
         }
-        // No sync needed - next iteration's barrier will ensure visibility
+        // Next iteration's inter-CTA barrier will handle cross-CTA visibility/order.
       }

Also: the comment β€œEnsure all threads in leading CTA finish clearing…” (Line 364) is slightly misleading: this __syncthreads() also ensures that all threads finish the atomicAdds before tx==0 signals the barrier.

Also applies to: 937-947, 955-965, 1136-1147, 1425-1437

πŸ€– Prompt for AI Agents
In @include/flashinfer/topk.cuh around lines 918-927: multiple threads in the
leading CTA write state->histogram for the next iteration without an intra-CTA
barrier, allowing tx==0 to proceed and signal the inter-CTA barrier before all
clears complete; in each affected block (the branches guarded by if constexpr
(!SINGLE_CTA) that compute NUM_ROUNDS, next_first_hist_idx and clear
state->histogram using tx and BLOCK_THREADS) insert a __syncthreads() after the
clearing loop to ensure all threads in the CTA finish clearing (and any
preceding atomicAdds) before tx==0 or other threads can advance to the inter-CTA
synchronization; apply the same insertion to the other listed locations (blocks
around lines 937-947, 955-965, 1136-1147, 1425-1437) referencing the same
symbols (state->histogram, tx, cta_in_group, BLOCK_THREADS, iter, NUM_ROUNDS).

continue;
}
} else if constexpr (MODE == RadixTopKMode::PageTableTransform) {
@@ -906,6 +916,16 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
for (uint32_t i = tx; i < top_k_val; i += BLOCK_THREADS) {
row_output[i] = (i < length) ? src_page_entry[i] : static_cast<IdType>(-1);
}
// Clear histogram for next iteration
if constexpr (!SINGLE_CTA) {
constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;
uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
if (cta_in_group == 0) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
state->histogram[next_first_hist_idx][i] = 0;
}
}
}
continue;
}
} else { // RaggedTransform
@@ -914,6 +934,16 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKKernel_Unified(
for (uint32_t i = tx; i < top_k_val; i += BLOCK_THREADS) {
row_output[i] = (i < length) ? static_cast<IdType>(i) + offset : static_cast<IdType>(-1);
}
// Clear histogram for next iteration
if constexpr (!SINGLE_CTA) {
constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8;
uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
if (cta_in_group == 0) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
state->histogram[next_first_hist_idx][i] = 0;
}
}
}
continue;
}
}
@@ -1084,6 +1114,19 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKMaskLogitsKernel_Multi
masked_logits[row_idx * vocab_size + chunk_start + i] =
logits[row_idx * vocab_size + chunk_start + i];
}

// Clear histogram for next iteration (in case it's k < vocab_size)
// Only needed for multi-CTA mode; single-CTA uses shared memory cleared each iteration
if constexpr (!SINGLE_CTA) {
constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8; // ORDERED_BITS / RADIX_BITS
uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
if (cta_in_group == 0) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
state->histogram[next_first_hist_idx][i] = 0;
}
}
// No sync needed - next iteration's barrier will ensure visibility
}
continue;
}

@@ -1360,6 +1403,20 @@ __global__ void __launch_bounds__(BLOCK_THREADS) RadixTopKRenormProbKernel_Multi
renormed_prob[row_idx * vocab_size + chunk_start + i] =
DType(float(probs[row_idx * vocab_size + chunk_start + i]) * normalizer);
}

// Clear histogram for next iteration (in case it's k < vocab_size)
// Only needed for multi-CTA mode; single-CTA uses shared memory cleared each iteration
// Next iteration (iter+1) will use histogram[((iter+1)*NUM_ROUNDS) % 3] for its first round
if constexpr (!SINGLE_CTA) {
constexpr uint32_t NUM_ROUNDS = sizeof(OrderedType) * 8 / 8; // ORDERED_BITS / RADIX_BITS
uint32_t next_first_hist_idx = ((iter + 1) * NUM_ROUNDS) % 3;
if (cta_in_group == 0) {
for (uint32_t i = tx; i < RADIX; i += BLOCK_THREADS) {
state->histogram[next_first_hist_idx][i] = 0;
}
}
// No sync needed - next iteration's barrier will ensure visibility
}
continue;
}

139 changes: 139 additions & 0 deletions tests/utils/test_sampling.py
@@ -513,6 +513,83 @@ def test_top_k_renorm_probs(batch_size, vocab_size, k, distribution, dtype):
)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_top_k_renorm_probs_mixed_k_persistent_loop(dtype):
"""Test top_k_renorm_probs with mixed k values in persistent loop (multi-CTA mode).

This test catches a specific bug where:
- Large batch size triggers the persistent loop (multiple iterations per CTA group)
- Large vocab_size triggers multi-CTA mode (multiple CTAs per row)
- Mixed k values: some rows have k >= vocab_size (skip radix select),
others have k < vocab_size (use radix select)

The bug was that k >= vocab_size iterations would skip radix select
without clearing the histogram buffers, leaving stale data that corrupted
subsequent k < vocab_size iterations.
"""
batch_size = 1024 # Large batch to trigger persistent loop
vocab_size = 128 * 1024 # Large vocab to trigger multi-CTA mode

torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)

# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)

# Generate k values: mix of small k and k == vocab_size
generator = torch.Generator(device="cuda:0").manual_seed(42)
k_values = torch.randint(
1, 1000, (batch_size,), device="cuda:0", generator=generator
)

# Randomly set some rows to k == vocab_size (about 50%)
generator = torch.Generator(device="cuda:0").manual_seed(42)
mask = torch.randint(
0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
)
k_values.masked_fill_(mask, vocab_size)
Comment on lines +533 to +550
Contributor

medium

The torch.Generator is created and seeded multiple times within this test function. You can create it just once at the beginning of the test and reuse it for all random tensor creations to make the code cleaner and avoid redundancy.

    torch.manual_seed(42)
    generator = torch.Generator(device="cuda:0").manual_seed(42)

    # Generate random logits
    logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)

    # Generate k values: mix of small k and k == vocab_size
    k_values = torch.randint(
        1, 1000, (batch_size,), device="cuda:0", generator=generator
    )

    # Randomly set some rows to k == vocab_size (about 50%)
    mask = torch.randint(
        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
    )
    k_values.masked_fill_(mask, vocab_size)


# Convert to probs
probs = torch.softmax(logits, dim=-1).to(dtype)

# Run FlashInfer top_k_renorm_probs
renorm_probs = flashinfer.sampling.top_k_renorm_probs(probs, k_values)

# Verify output dtype
assert renorm_probs.dtype == dtype

# Verify sum to 1
sums = renorm_probs.float().sum(dim=-1)
torch.testing.assert_close(sums, torch.ones_like(sums), rtol=1e-2, atol=1e-2)

# Verify non-zero count matches k for each row
nonzero_counts = (renorm_probs > 0).sum(dim=-1)

# For rows with k >= vocab_size, all elements should be non-zero
# For rows with k < vocab_size, non-zero count should be >= k (may be more due to ties)
for i in range(batch_size):
k = k_values[i].item()
count = nonzero_counts[i].item()

if k >= vocab_size:
# All elements should be non-zero
assert count == vocab_size, (
f"Row {i}: k >= vocab_size but count={count} != {vocab_size}"
)
else:
# Count should be at least k (may be more due to ties at the threshold)
row_probs = probs[i].float()
topk_vals, _ = torch.topk(row_probs, k, sorted=True)
threshold = topk_vals[-1]
expected_ge_threshold = (row_probs >= threshold).sum().item()

# Allow small tolerance for floating point
assert count >= k - 1, f"Row {i}: k={k} but only {count} non-zero elements"
assert count <= expected_ge_threshold + 1, (
f"Row {i}: k={k}, expected at most {expected_ge_threshold} but got {count}"
)

Comment on lines +516 to +591
Contributor

@coderabbitai coderabbitai bot Jan 10, 2026

⚠️ Potential issue | 🟠 Major

These new tests are likely to OOM / timeout in CI; also avoid torch.randint(..., dtype=torch.bool) here.

  • Memory: 1024 x 131072 for logits + probs + output is very likely to exceed GPU memory (especially for float32).
  • Runtime: the for i in range(batch_size) loop does repeated GPU syncs via .item() and runs torch.topk up to 1024 times.
  • torch.randint(..., dtype=torch.bool, ...) is risky; prefer generating int/binary then .bool() (or torch.rand()<p).
  • Per tests/**/*.py guidelines, please add flashinfer.utils-based skipping for unsupported GPUs and size defensively.
Proposed tightening (reduce OOM risk + reduce syncs)
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
 def test_top_k_renorm_probs_mixed_k_persistent_loop(dtype):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required")
+    # Per tests guideline: consider using flashinfer.utils helpers to skip unsupported archs.
+
-    batch_size = 1024  # Large batch to trigger persistent loop
-    vocab_size = 128 * 1024  # Large vocab to trigger multi-CTA mode
+    # Still triggers multi-CTA (vocab_size > 57344) but avoids common OOM.
+    vocab_size = 65536
+    batch_size = 256 if dtype == torch.float32 else 512

     torch.manual_seed(42)
     generator = torch.Generator(device="cuda:0").manual_seed(42)

     # Generate random logits
     logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)

@@
-    mask = torch.randint(
-        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
-    )
+    mask = torch.randint(0, 2, (batch_size,), generator=generator, device="cuda:0").bool()
     k_values.masked_fill_(mask, vocab_size)

@@
-    for i in range(batch_size):
+    # Only spot-check a small subset of k < vocab rows (avoid 1000x topk calls + .item() sync).
+    small_rows = torch.nonzero(k_values < vocab_size, as_tuple=False).flatten()[:16].tolist()
+    for i in small_rows:
         k = k_values[i].item()
         count = nonzero_counts[i].item()
@@
-        if k >= vocab_size:
-            # All elements should be non-zero
-            assert count == vocab_size, (
-                f"Row {i}: k >= vocab_size but count={count} != {vocab_size}"
-            )
-        else:
+        if k < vocab_size:
             # Count should be at least k (may be more due to ties at the threshold)
             row_probs = probs[i].float()
             topk_vals, _ = torch.topk(row_probs, k, sorted=True)
             threshold = topk_vals[-1]
             expected_ge_threshold = (row_probs >= threshold).sum().item()
πŸ€– Prompt for AI Agents
In @tests/utils/test_sampling.py around lines 516-591: the test
test_top_k_renorm_probs_mixed_k_persistent_loop allocates huge tensors and does
many GPU-syncing per-row ops which will OOM/timeout and uses torch.randint(...,
dtype=torch.bool) which is unsafe; reduce memory and syncs by shrinking
batch_size and vocab_size to CI-safe values, generate mask with
torch.randint(0,2,(batch_size,), device=..., generator=...).bool() or
torch.rand(device=...) < 0.5 instead of dtype=torch.bool, replace the Python
for-loop that calls .item() and per-row torch.topk with vectorized tensor
operations to compute nonzero_counts and thresholds (operate on entire tensors
and use torch.topk once/batched or torch.kthvalue), and add a conditional skip
using flashinfer.utils (e.g., skip_if_gpu_too_small or similar) at the top of
the test to avoid running on unsupported GPUs.
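As a sketch of the vectorized check this prompt describes, assuming the tensor names from the test above (probs, k_values, renorm_probs) and keeping its Β±1 tie tolerance; one batched descending sort replaces the per-row torch.topk calls and .item() syncs:

```python
import torch

def check_topk_counts(probs, k_values, renorm_probs):
    """Vectorized replacement for the per-row verification loop: one batched
    sort instead of up to batch_size topk calls and .item() syncs."""
    vocab_size = probs.shape[-1]
    counts = (renorm_probs > 0).sum(dim=-1)

    # Rows with k >= vocab_size must keep every element non-zero.
    full = k_values >= vocab_size
    assert torch.all(counts[full] == vocab_size)

    small = ~full
    if small.any():
        row_probs = probs[small].float()
        k_small = k_values[small]
        # Sort descending once; the k-th largest entry is the top-k threshold.
        sorted_probs, _ = row_probs.sort(dim=-1, descending=True)
        thresholds = sorted_probs.gather(1, (k_small - 1).unsqueeze(1)).squeeze(1)
        expected = (row_probs >= thresholds.unsqueeze(1)).sum(dim=-1)
        # Same tolerance for float ties as the original per-row loop.
        assert torch.all(counts[small] >= k_small - 1)
        assert torch.all(counts[small] <= expected + 1)
```

This is only a sketch of the review suggestion, not code from the PR; check_topk_counts is a hypothetical helper name.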

Collaborator Author

We have enough GPU memory on CI.

Contributor

@yzh119, understood! Since you've confirmed the CI environment has sufficient GPU memory for these tests, that addresses the main concern. Feel free to keep the test as-is.


🧠 Learnings used
Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Avoid OOM (out-of-memory) errors in tests by using appropriate problem sizes - `tests/conftest.py` provides auto-skipping for OOM tests as a safety net but should not be relied upon

Learnt from: CR
Repo: flashinfer-ai/flashinfer PR: 0
File: CLAUDE.md:0-0
Timestamp: 2025-12-30T09:34:39.900Z
Learning: Applies to tests/**/*.py : Test implementations should use `flashinfer.utils` functions (`get_compute_capability`, `is_sm90a_supported`, `is_sm100a_supported`, etc.) to skip tests on unsupported GPU architectures


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("k", [10, 100, 500])
@@ -597,6 +674,68 @@ def test_top_k_mask_logits(
)


@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_top_k_mask_logits_mixed_k_persistent_loop(dtype):
"""Test top_k_mask_logits with mixed k values in persistent loop (multi-CTA mode).

This test catches the same bug as test_top_k_renorm_probs_mixed_k_persistent_loop
but for the mask_logits variant.
"""
batch_size = 1024 # Large batch to trigger persistent loop
vocab_size = 128 * 1024 # Large vocab to trigger multi-CTA mode

torch.manual_seed(42)
generator = torch.Generator(device="cuda:0").manual_seed(42)

# Generate random logits
logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
logits = logits.to(dtype)

# Generate k values: mix of small k and k == vocab_size
generator = torch.Generator(device="cuda:0").manual_seed(42)
k_values = torch.randint(
1, 1000, (batch_size,), device="cuda:0", generator=generator
)

# Randomly set some rows to k == vocab_size (about 50%)
generator = torch.Generator(device="cuda:0").manual_seed(42)
mask = torch.randint(
0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
)
k_values.masked_fill_(mask, vocab_size)
Comment on lines +687 to +705
Contributor

medium

The torch.Generator is created and seeded multiple times within this test function. You can create it just once at the beginning of the test and reuse it for all random tensor creations to make the code cleaner and avoid redundancy.

    torch.manual_seed(42)
    generator = torch.Generator(device="cuda:0").manual_seed(42)

    # Generate random logits
    logits = torch.rand((batch_size, vocab_size), device="cuda:0", generator=generator)
    logits = logits.to(dtype)

    # Generate k values: mix of small k and k == vocab_size
    k_values = torch.randint(
        1, 1000, (batch_size,), device="cuda:0", generator=generator
    )

    # Randomly set some rows to k == vocab_size (about 50%)
    mask = torch.randint(
        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
    )
    k_values.masked_fill_(mask, vocab_size)


# Run FlashInfer top_k_mask_logits
masked_logits = flashinfer.sampling.top_k_mask_logits(logits, k_values)

# Verify output dtype
assert masked_logits.dtype == dtype

# Verify finite count matches k for each row
finite_counts = torch.isfinite(masked_logits).sum(dim=-1)

for i in range(batch_size):
k = k_values[i].item()
count = finite_counts[i].item()

if k >= vocab_size:
# All elements should be finite
assert count == vocab_size, (
f"Row {i}: k >= vocab_size but finite count={count} != {vocab_size}"
)
else:
# Count should be at least k (may be more due to ties at the threshold)
row_logits = logits[i].float()
topk_vals, _ = torch.topk(row_logits, k, sorted=True)
threshold = topk_vals[-1]
expected_ge_threshold = (row_logits >= threshold).sum().item()

# Allow small tolerance for floating point
assert count >= k - 1, f"Row {i}: k={k} but only {count} finite elements"
assert count <= expected_ge_threshold + 1, (
f"Row {i}: k={k}, expected at most {expected_ge_threshold} but got {count}"
)

Comment on lines +677 to +737
Contributor

@coderabbitai coderabbitai bot Jan 10, 2026

⚠️ Potential issue | 🟠 Major

Same OOM/timeout concerns apply; also consider validating the k>=vocab branch via equality rather than counts.

For top_k_mask_logits, when k >= vocab_size the kernel should be a pure copy; asserting masked_logits == logits for those rows is typically a stronger (and cheaper) check than counting finites, and avoids edge cases if inputs change (e.g., allow -inf later).

Minimal adjustments mirroring the other test
 def test_top_k_mask_logits_mixed_k_persistent_loop(dtype):
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA required")
     batch_size = 1024  # Large batch to trigger persistent loop
     vocab_size = 128 * 1024  # Large vocab to trigger multi-CTA mode
@@
-    mask = torch.randint(
-        0, 2, (batch_size,), generator=generator, dtype=torch.bool, device="cuda:0"
-    )
+    mask = torch.randint(0, 2, (batch_size,), generator=generator, device="cuda:0").bool()
     k_values.masked_fill_(mask, vocab_size)
+
+    # Stronger + cheaper: rows with k>=vocab are copies
+    full_rows = (k_values >= vocab_size).nonzero(as_tuple=False).flatten()
+    if full_rows.numel() > 0:
+        torch.testing.assert_close(masked_logits[full_rows], logits[full_rows])

Committable suggestion skipped: line range outside the PR's diff.

πŸ€– Prompt for AI Agents
In @tests/utils/test_sampling.py around lines 677-737: the test
test_top_k_mask_logits_mixed_k_persistent_loop currently checks k>=vocab_size by
counting finite values which is weaker and can miss cases; update the
k>=vocab_size branch to assert a direct equality copy for those rows by
comparing masked_logits[i] == logits[i] (or .equal() / torch.allclose after
casting to the same dtype) for each i where k_values[i] >= vocab_size,
referencing flashinfer.sampling.top_k_mask_logits, masked_logits, logits, and
k_values to locate the code.
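A minimal sketch of that equality-based check, assuming the names used in the test above (logits, k_values, masked_logits); the helper name is hypothetical:

```python
import torch

def check_full_rows_are_copies(logits, k_values, masked_logits):
    """Rows with k >= vocab_size skip masking entirely, so each such output
    row must be an exact copy of the corresponding input row."""
    vocab_size = logits.shape[-1]
    full_rows = (k_values >= vocab_size).nonzero(as_tuple=False).flatten()
    if full_rows.numel() > 0:
        torch.testing.assert_close(masked_logits[full_rows], logits[full_rows])
```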

Collaborator Author

ditto.

Contributor

@yzh119, understood! I've noted your response.


@pytest.mark.parametrize("batch_size", [1, 99, 989])
@pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
@pytest.mark.parametrize("num_speculate_tokens", [1, 3, 5, 7])