[Core] Optimize top-k + top-p sampling by avoiding full vocabulary sort#32234
nadavrot wants to merge 1 commit into vllm-project:main
Conversation
Code Review
This pull request introduces an optimization for top-k + top-p sampling by avoiding a full vocabulary sort. The new implementation uses torch.topk to select the top-k candidates and then applies sorting and top-p filtering only on this smaller set, which significantly speeds up sampling for large vocabularies. The changes are well-implemented and include new tests for correctness and equivalence with the existing path. My review focuses on strengthening one of the new correctness tests to ensure it fully validates the top-p filtering logic.
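For reference, the technique described above can be sketched roughly as follows. This is an illustrative reconstruction, not the PR's actual code: the function name is invented, and the top-p convention here follows the common nucleus rule (drop a token once the cumulative mass before it exceeds p), which matches the behavior discussed in this review.

```python
import torch

def top_k_top_p_sketch(logits: torch.Tensor, k: int, p: float) -> torch.Tensor:
    # torch.topk returns values already sorted in descending order,
    # so no full-vocabulary sort is needed.
    topk_vals, topk_idx = torch.topk(logits, k, dim=-1)
    # Top-p runs on the small [batch, k] slice instead of [batch, vocab].
    probs = topk_vals.softmax(dim=-1)
    cumprobs = probs.cumsum(dim=-1)
    # Drop a token when the cumulative mass *before* it already exceeds p;
    # the most likely token is therefore always kept.
    mask = (cumprobs - probs) > p
    topk_vals = topk_vals.masked_fill(mask, float("-inf"))
    # Scatter the survivors back; every non-top-k position stays at -inf.
    out = torch.full_like(logits, float("-inf"))
    out.scatter_(-1, topk_idx, topk_vals)
    return out
```

On the review's example (`logits = [[1, 2, 3, 4, 5]]`, `k=3`, `p=0.9`) this masks indices 0 and 1 via top-k and index 2 via top-p, leaving only the two most likely tokens.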
```python
# Top-3 values are 5.0, 4.0, 3.0 (indices 4, 3, 2).
# After top-k mask, indices 0, 1 should be -inf.
assert result[0, 0] == float("-inf")
assert result[0, 1] == float("-inf")
```
This test does not fully verify the top-p filtering logic. With the given inputs (k=3, p=0.9), top-p filtering should mask an additional logit (3.0), but the current assertions only check the masking from top-k.

To ensure the correctness of the new optimized path, it's important to have a robust test that validates both top-k and top-p filtering. I'm suggesting more explicit assertions to verify the complete behavior.
```diff
-# Top-3 values are 5.0, 4.0, 3.0 (indices 4, 3, 2).
-# After top-k mask, indices 0, 1 should be -inf.
-assert result[0, 0] == float("-inf")
-assert result[0, 1] == float("-inf")
+# After top-k, logits for 1.0 and 2.0 are -inf.
+assert result[0, 0] == float("-inf")
+assert result[0, 1] == float("-inf")
+# After top-p, logit for 3.0 should also be -inf.
+# Softmax of [3.0, 4.0, 5.0] is approx. [0.09, 0.24, 0.67].
+# Cumsum is [0.09, 0.33, 1.0]. With p=0.9, 1-p=0.1.
+# The first element (logit 3.0) is masked.
+assert result[0, 2] == float("-inf")
+# The other top-k values should not be masked.
+assert result[0, 3] != float("-inf")
+assert result[0, 4] != float("-inf")
```
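The arithmetic in those suggested comments can be checked with a few lines of plain Python (stdlib only; values rounded as in the comments):

```python
import math

# Softmax of the surviving top-k logits, sorted ascending as in the
# existing vLLM top-p path (which masks where cumsum <= 1 - p).
logits = [3.0, 4.0, 5.0]
exps = [math.exp(x) for x in logits]
total = sum(exps)
probs = [e / total for e in exps]                      # ~[0.090, 0.245, 0.665]
cumsum = [sum(probs[: i + 1]) for i in range(len(probs))]  # ~[0.090, 0.335, 1.0]

# With p = 0.9, the mask threshold is 1 - p = 0.1: only the first
# element (logit 3.0) falls under it, so only that logit is masked.
masked = [c <= 0.1 for c in cumsum]
```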
Hi @nadavrot, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
Force-pushed 7c66555 to bf7fa63
```python
sorted_values.masked_fill_(top_k_mask, -float("inf"))

# Apply top-p on the sorted values.
probs_sort = sorted_values.softmax(dim=-1)
```
You need to do softmax on the full logits and then gather based on sort_indices, or the probabilities will be incorrect.
Oh sorry I guess maybe I'm wrong since in the existing impl it's applied to the already top-k masked logits. I wonder if that's actually correct/expected behaviour.
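The distinction being discussed can be checked numerically: softmax over logits whose non-top-k entries are set to -inf equals the *renormalized* top-k probabilities, which is a different distribution than slicing the full-vocabulary softmax. A small illustrative check (not code from the PR):

```python
import torch

logits = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
masked = logits.clone()
masked[:2] = float("-inf")          # keep only the top-3 logits

renorm = masked.softmax(dim=-1)     # softmax after the top-k mask
full = logits.softmax(dim=-1)       # softmax over the full vocabulary
subset = full[2:] / full[2:].sum()  # top-3 slice, renormalized

# Masked softmax == renormalized subset, but != the raw slice of the
# full softmax, so the top-p cutoff sees renormalized probabilities.
same = torch.allclose(renorm[2:], subset)
different = not torch.allclose(renorm[2:], full[2:])
```

So applying softmax to the already top-k-masked logits (as the existing implementation does) means top-p operates on the renormalized top-k distribution, and the new path matches that behavior.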
Force-pushed bf7fa63 to 49ed98a
Force-pushed 49ed98a to 5719beb

Force-pushed 5719beb to cdd036e
Force-pushed cdd036e to 4901049
[Core] Optimize top-k + top-p sampling by avoiding full vocabulary sort (vllm-project#32234)

Summary:
This diff accelerates sampling on targets that have slow sort implementations. When both top-k and top-p filtering are applied, the current implementation sorts the entire vocabulary (204K tokens for Llama 4). This diff adds a new optimized fast path that uses topk() to extract only the top-k elements, then sorts only those, which is faster.

| Batch | Vocab   | k    | Baseline (ms) | Optimized (ms) | Speedup |
|-------|---------|------|---------------|----------------|---------|
| 1024  | 32,768  | 100  | 7.53          | 1.33           | 5.67x   |
| 1024  | 32,768  | 1000 | 7.53          | 1.55           | 4.87x   |
| 2048  | 131,072 | 100  | 57.59         | 7.70           | 7.48x   |
| 2048  | 131,072 | 1000 | 57.61         | 8.22           | 7.00x   |
| 2048  | 262,144 | 100  | 115.14        | 14.97          | 7.69x   |
| 2048  | 262,144 | 1000 | 115.09        | 15.55          | 7.40x   |

Signed-off-by: Nadav Rotem <nrotem@meta.com>

Test Plan:
Added two tests and ran the sample tests:
- test_v1_sampler
- test_v1_topk_topp_sampler
- test_kernels_top_k_per_row

Reviewed By: diviramon, wushidonguc

Differential Revision: D90454233
Force-pushed 4901049 to f787967
yewentao256 left a comment

Thanks for the work! A few thoughts
```python
# Use optimized path only when max_k is small enough to benefit.
# When k = vocab_size for any row, fall back to the general path.
vocab_size = logits.shape[1]
max_k = int(k.max().item())
```
`item()` might cause a sync that affects performance
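For context on this concern: calling `.item()` on a tensor that lives on the GPU copies the value to the host, which blocks the CPU until pending kernels finish. A hedged sketch of one possible workaround, assuming the caller already has the per-request k values on the host (illustrative values, not vLLM's actual data flow):

```python
import torch

# On a CUDA device, this max + .item() forces a device-to-host sync.
k_gpu = torch.tensor([100, 50, 1000])   # imagine this lives on the GPU
max_k = int(k_gpu.max().item())

# If a host-side copy of k is already available, the same value can be
# computed with no device round-trip at all.
k_host = [100, 50, 1000]
max_k_no_sync = max(k_host)
```

Whether this is worthwhile depends on where k originates; if it is only materialized on the device, some synchronization is unavoidable when a Python int is needed for `topk`.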
```python
    fall back to the general path when any row has k = vocab_size.
    """
    # Get the maximum k value to determine how many elements to extract.
    max_k = int(k.max().item())
```
```python
return apply_top_k_only(logits, k)

if p is not None and k is not None:
    # Use optimized path only when max_k is small enough to benefit.
```
```diff
-# Use optimized path only when max_k is small enough to benefit.
+# Use optimized path only when max_k < vocab_size
```
Thanks @nadavrot @yewentao256, yes I think we want to avoid

Please see #32558!
Summary:
This diff accelerates sampling on targets that have slow sort implementations. When both top-k and top-p filtering are applied, the current implementation sorts the entire vocabulary (204K tokens for Llama 4). This diff adds a new optimized fast path that uses topk() to extract only the top-k elements, then sorts only those, which is faster.
Signed-off-by: Nadav Rotem nrotem@meta.com
Test Plan:
Added two tests and ran the sample tests:
- test_v1_sampler
- test_v1_topk_topp_sampler
- test_kernels_top_k_per_row
Reviewed By: diviramon, wushidonguc
Differential Revision: D90454233
Note

Performance: Optimizes `apply_top_k_top_p` when both `k` and `p` are set by introducing `apply_top_k_and_top_p`, which uses `topk` to extract only the top-k logits and sorts that subset. Falls back to the general full-sort path when any row has `k == vocab_size`.

- Adds `apply_top_k_and_top_p` to mask via partial sort and scatter back in place, preserving behavior
- Adds a `max_k < vocab_size` check in `apply_top_k_top_p`
- Adds tests in `tests/v1/sample/test_topk_topp_sampler.py`

Written by Cursor Bugbot for commit 7c66555248199fe60354f0cd839b19afcf41dc55. This will update automatically on new commits.
Note

Improves combined top-k + top-p masking performance and adds tests to validate correctness and parity with the existing path.

- Adds `apply_top_k_and_top_p` to mask by extracting `topk` logits, sorting only those, and scattering back in place; assumes `max(k) < vocab_size`
- New path in `apply_top_k_top_p` guarded by `max_k < vocab_size`, otherwise falls back to the full-sort path
- Adds tests in `tests/v1/sample/test_topk_topp_sampler.py`

Written by Cursor Bugbot for commit bf7fa630d90ab4b5670c450409f9634e726c1d89. This will update automatically on new commits.