
[Core] Optimize top-k + top-p sampling by avoiding full vocabulary sort#32234

Open
nadavrot wants to merge 1 commit into vllm-project:main from nadavrot:export-D90454233

Conversation


@nadavrot nadavrot commented Jan 13, 2026

Summary:
This diff accelerates sampling on targets that have a slow sort implementation. When both top-k and top-p filtering are applied, the current implementation sorts the entire vocabulary (204K tokens for Llama 4). This diff adds an optimized fast path that uses topk() to extract only the top-k elements and then sorts just those, which is faster.

| Batch |   Vocab |    k | Baseline (ms) | Optimized (ms) | Speedup |
|-------|---------|------|---------------|----------------|---------|
|  1024 |  32,768 |  100 |          7.53 |           1.33 |   5.67x |
|  1024 |  32,768 | 1000 |          7.53 |           1.55 |   4.87x |
|  2048 | 131,072 |  100 |         57.59 |           7.70 |   7.48x |
|  2048 | 131,072 | 1000 |         57.61 |           8.22 |   7.00x |
|  2048 | 262,144 |  100 |        115.14 |          14.97 |   7.69x |
|  2048 | 262,144 | 1000 |        115.09 |          15.55 |   7.40x |
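The fast path can be sketched as a plain-Python, single-row reference (illustrative only; the actual change operates on batched torch tensors, and the name `top_k_top_p_mask` is made up for this sketch):

```python
import math

def top_k_top_p_mask(logits, k, p):
    """Reference sketch of the optimized path: extract only the top-k
    candidates, then apply top-p filtering within that subset.
    Returns a copy of `logits` with filtered entries set to -inf."""
    n = len(logits)
    # Partial "topk": indices of the k largest logits, instead of
    # sorting all n entries (torch.topk plays this role in the PR).
    top_idx = sorted(range(n), key=lambda i: logits[i], reverse=True)[:k]
    # Sort only the k survivors in ascending order, as the existing
    # full-sort path does before the cumulative-sum check.
    top_idx.sort(key=lambda i: logits[i])
    vals = [logits[i] for i in top_idx]
    # Softmax over the top-k survivors (matching the existing behaviour
    # of applying softmax after the top-k mask).
    m = max(vals)
    exps = [math.exp(v - m) for v in vals]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Top-p: mask entries whose cumulative probability is <= 1 - p,
    # always keeping at least the largest logit; scatter survivors back.
    out = [float("-inf")] * n
    cum = 0.0
    for j, i in enumerate(top_idx):
        cum += probs[j]
        if cum > 1.0 - p or j == k - 1:
            out[i] = logits[i]
    return out
```

With logits [1.0, 2.0, 3.0, 4.0, 5.0], k=3, p=0.9, only the two largest logits survive, since the third (3.0) carries less than 1 - p of the renormalized mass.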

Signed-off-by: Nadav Rotem <nrotem@meta.com>

Test Plan:
Added two tests and ran the sample tests:

  • test_v1_sampler
  • test_v1_topk_topp_sampler
  • test_kernels_top_k_per_row

Reviewed By: diviramon, wushidonguc

Differential Revision: D90454233


Note

Performance: Optimizes apply_top_k_top_p when both k and p are set by introducing apply_top_k_and_top_p, which uses topk to extract only the top‑k logits and sorts that subset. Falls back to the general full‑sort path when any row has k == vocab_size.

  • Adds apply_top_k_and_top_p to mask via partial sort and scatter back in place, preserving behavior
  • Wires optimized path behind a max_k < vocab_size check in apply_top_k_top_p
  • Tests: new correctness and equivalence tests for the optimized path; extends existing sampler tests and device fixture in tests/v1/sample/test_topk_topp_sampler.py

Written by Cursor Bugbot for commit 7c66555248199fe60354f0cd839b19afcf41dc55.



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an optimization for top-k + top-p sampling by avoiding a full vocabulary sort. The new implementation uses torch.topk to select the top-k candidates and then applies sorting and top-p filtering only on this smaller set, which significantly speeds up sampling for large vocabularies. The changes are well-implemented and include new tests for correctness and equivalence with the existing path. My review focuses on strengthening one of the new correctness tests to ensure it fully validates the top-p filtering logic.

Comment on lines +140 to +143
# Top-3 values are 5.0, 4.0, 3.0 (indices 4, 3, 2).
# After top-k mask, indices 0, 1 should be -inf.
assert result[0, 0] == float("-inf")
assert result[0, 1] == float("-inf")
Contributor


high

This test is not fully verifying the top-p filtering logic. With the given inputs (k=3, p=0.9), the top-p filtering should mask an additional logit (3.0), but the current assertions only check the masking from top-k.

To ensure the correctness of the new optimized path, it's important to have a robust test that validates both top-k and top-p filtering. I'm suggesting more explicit assertions to verify the complete behavior.

Suggested change
# Top-3 values are 5.0, 4.0, 3.0 (indices 4, 3, 2).
# After top-k mask, indices 0, 1 should be -inf.
assert result[0, 0] == float("-inf")
assert result[0, 1] == float("-inf")
# After top-k, logits for 1.0 and 2.0 are -inf.
assert result[0, 0] == float("-inf")
assert result[0, 1] == float("-inf")
# After top-p, logit for 3.0 should also be -inf.
# Softmax of [3.0, 4.0, 5.0] is approx. [0.09, 0.24, 0.67].
# Cumsum is [0.09, 0.33, 1.0]. With p=0.9, 1-p=0.1.
# The first element (logit 3.0) is masked.
assert result[0, 2] == float("-inf")
# The other top-k values should not be masked.
assert result[0, 3] != float("-inf")
assert result[0, 4] != float("-inf")
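The cumulative-probability arithmetic the suggestion relies on can be checked in isolation (a standalone sketch of the sampler's `cumsum <= 1 - p` masking rule applied to the same [3.0, 4.0, 5.0] survivors):

```python
import math

# Softmax over the top-k survivors [3.0, 4.0, 5.0].
vals = [3.0, 4.0, 5.0]
exps = [math.exp(v) for v in vals]
total = sum(exps)
probs = [e / total for e in exps]  # approx. [0.090, 0.245, 0.665]

# Cumulative sum in ascending order, as the sampler computes it.
cumsum = [sum(probs[: i + 1]) for i in range(len(probs))]

# With p = 0.9, entries whose cumulative probability is <= 1 - p = 0.1
# are masked; only the logit 3.0 (cumsum ~0.090) falls below the cut.
masked = [c <= 1.0 - 0.9 for c in cumsum]
print(masked)  # [True, False, False]
```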

Contributor

mergify bot commented Jan 13, 2026

Hi @nadavrot, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint


sorted_values.masked_fill_(top_k_mask, -float("inf"))

# Apply top-p on the sorted values.
probs_sort = sorted_values.softmax(dim=-1)
Member


You need to do softmax on the full logits and then gather based on sort_indices, or the probabilities will be incorrect.

Member


Oh sorry I guess maybe I'm wrong since in the existing impl it's applied to the already top-k masked logits. I wonder if that's actually correct/expected behaviour.
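For what it's worth, applying softmax after the top-k mask is equivalent to taking the full-vocabulary softmax and renormalizing over the survivors, so the two framings produce the same distribution and differ only in which probabilities the 1 - p threshold is measured against. A quick standalone check (plain Python, not the vLLM code):

```python
import math

def softmax(xs):
    """Softmax that tolerates -inf entries (masked positions get 0.0)."""
    m = max(x for x in xs if x != float("-inf"))
    exps = [0.0 if x == float("-inf") else math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

logits = [1.0, 2.0, 3.0, 4.0, 5.0]
keep = [2, 3, 4]  # survivors of a top-k mask with k = 3

# Softmax over the already-masked logits (existing behaviour).
masked = [v if i in keep else float("-inf") for i, v in enumerate(logits)]
p_masked = softmax(masked)

# Full-vocabulary softmax, then renormalized over the survivors.
p_full = softmax(logits)
z = sum(p_full[i] for i in keep)
p_renorm = [p_full[i] / z if i in keep else 0.0 for i in range(len(logits))]

# The two agree elementwise; the open question is only which
# distribution the 1 - p cumulative threshold should apply to.
assert all(abs(a - b) < 1e-9 for a, b in zip(p_masked, p_renorm))
```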



@nadavrot nadavrot changed the title from "Optimize top-k + top-p sampling by avoiding full vocabulary sort" to "[Core] Optimize top-k + top-p sampling by avoiding full vocabulary sort" on Jan 13, 2026
Member

@yewentao256 yewentao256 left a comment


Thanks for the work! A few thoughts

# Use optimized path only when max_k is small enough to benefit.
# When k = vocab_size for any row, fall back to the general path.
vocab_size = logits.shape[1]
max_k = int(k.max().item())
Member


item() forces a device-to-host sync here, which can hurt performance

fall back to the general path when any row has k = vocab_size.
"""
# Get the maximum k value to determine how many elements to extract.
max_k = int(k.max().item())
Member


Same here

return apply_top_k_only(logits, k)

if p is not None and k is not None:
# Use optimized path only when max_k is small enough to benefit.
Member


Suggested change
# Use optimized path only when max_k is small enough to benefit.
# Use optimized path only when max_k < vocab_size

Member

njhill commented Jan 16, 2026

Thanks @nadavrot @yewentao256, yes I think we want to avoid item() now with async scheduling, so I don't think torch.topk is a good option anymore, even for the top_k-only case. I'm actually testing an alternative approach, will share later today.

Member

njhill commented Jan 20, 2026

Please see #32558!
