
Conversation

@hyeygit (Contributor) commented Mar 29, 2025

Top-k and top-p sampling are slow on TPU because the existing algorithms rely on torch.scatter, which is extremely slow on TPU. There is ongoing work to optimize it, but until that lands we need an alternative algorithm that avoids scattering.

The algorithm in this PR avoids torch.scatter by finding a "cut-off" element in the original logits; after thresholding the logits with this cut-off value, the remaining elements constitute the top-p set. This is inspired by the apply_top_k_only algorithm created by @njhill in #15478.
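
For intuition, here is a minimal sketch of the threshold-based idea (illustrative only; the function name is hypothetical, and the exact boundary and tie handling in the PR differ, e.g. the perturbation discussed below):

import torch

def apply_top_p_by_threshold(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # Sort once; softmax of the sorted logits gives sorted probabilities.
    logits_sorted, _ = logits.sort(dim=-1, descending=True)
    cum_probs = logits_sorted.softmax(dim=-1).cumsum(dim=-1)
    # Smallest prefix whose cumulative probability reaches p (always >= 1 token).
    top_p_count = (cum_probs < p.unsqueeze(-1)).sum(dim=-1) + 1
    cutoff_index = (top_p_count - 1).clamp(max=logits.shape[-1] - 1).unsqueeze(-1)
    # The cut-off value is the logit of the last token inside the top-p set.
    cutoff = logits_sorted.gather(-1, cutoff_index)
    # Threshold the original logits instead of scattering sorted values back.
    return logits.masked_fill(logits < cutoff, float("-inf"))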

Benchmark

Microbenchmark (on v6e-1) shows a significant speedup -- "Running 32 elapsed time" is ~5 ms, down from ~500 ms with the original scatter-based algorithm, a ~100x improvement.

Microbenchmark full results on v6e-1
$ VLLM_USE_V1=1 python sampler_microbenchmark.py 
INFO 03-31 14:40:10 [__init__.py:239] Automatically detected platform tpu.
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
INFO 03-31 14:40:15 [topk_topp_sampler.py:82] Using approximate top-p optimized for TPU. Result may in theory differ from the exact algorithm if there are tokens with near-identical probabilities (< 1e-9 diff).
Compiling/Warmup 1 elapsed time: 9.270433902740479
Compiling/Warmup 4 elapsed time: 9.104885816574097
Compiling/Warmup 16 elapsed time: 8.811976194381714
Compiling/Warmup 32 elapsed time: 8.926635026931763
Running 1 elapsed time: 0.004515171051025391
Running 1 elapsed time: 0.003937482833862305
Running 1 elapsed time: 0.003930091857910156
Running 1 elapsed time: 0.0038993358612060547
Average time:  0.0040705204010009766
Running 4 elapsed time: 0.0042819976806640625
Running 4 elapsed time: 0.004051685333251953
Running 4 elapsed time: 0.00403141975402832
Running 4 elapsed time: 0.004080057144165039
Average time:  0.004111289978027344
Running 16 elapsed time: 0.00475311279296875
Running 16 elapsed time: 0.0045032501220703125
Running 16 elapsed time: 0.0044858455657958984
Running 16 elapsed time: 0.19616365432739258
Average time:  0.052476465702056885
Running 32 elapsed time: 0.00586247444152832
Running 32 elapsed time: 0.005380868911743164
Running 32 elapsed time: 0.005262851715087891
Running 32 elapsed time: 0.0052433013916015625
Average time:  0.005437374114990234

Extra notes

The VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION environment variable (introduced in #15242) can now be removed. Not done in this PR since @NickLucche's pending PR #15489 already handles it (thanks!).

@github-actions bot

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Mar 29, 2025
@hyeygit hyeygit force-pushed the tpu_topp branch 4 times, most recently from 15534fe to 00ab67b Compare March 30, 2025 18:53
@mergify mergify bot added the tpu Related to Google TPUs label Mar 30, 2025
@hyeygit hyeygit changed the title from "[V1][TPU] Speed up top-p for TPU by avoiding scattering." to "[V1][TPU] TPU-optimized top-p implementation (avoids scattering)." Mar 30, 2025
@hyeygit hyeygit marked this pull request as ready for review March 30, 2025 19:06
@brittrock left a comment

Great work, thank you for adding this @hyeygit, including the notes on expected speedups. I just have minor usability feedback :)

Thanks @hyeygit , this seems reasonable to me in the interim but I'll let the other folks chime in on the appropriateness. Cc @yaochengji @yarongmu-google

Given this can slightly impact the generated output during ties, this really feels like something we should be warning the user about. Not every time the function is called, of course, but at a minimum, we should be warning users when the argument is set to anything other than the default. I couldn't find a warning log but I'm also on my phone, so apologies if I just missed it.

@hyeygit (Contributor Author)

Thanks for the review @brittrock. I added a one-time log message noting that this algorithm is approximate in theory.

In practice, I think the tiny 1e-9 probability perturbation doesn't alter the result in any meaningful way. The only situation where the output differs from the exact algorithm is when there are multiple tokens whose probabilities are within 1e-9 (one in a billion) of each other. In that case they practically have the same probability, so including either one of them in the top-p set should be acceptable.
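
(A hypothetical numeric illustration of the only case where results can differ -- the numbers below are made up:)

import torch

p = 0.8
# Tokens 1 and 2 differ by less than 1e-9, i.e. they are effectively tied.
probs = torch.tensor([0.6, 0.2 + 5e-10, 0.2 - 5e-10])
# Exact top-p keeps tokens 0 and 1 (cumulative probability reaches 0.8).
# After a random perturbation of up to ~1e-9, tokens 1 and 2 may swap order,
# so the approximate set could instead be tokens 0 and 2 -- practically
# equivalent, since the two tokens have essentially the same probability.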

Contributor

thanks for adding!

I agree in practice, probably ok, but this could break accuracy tests and so good idea to include in any case.

nice job, again!

Contributor

I signed in from my phone and must have created another github account >.<

ignore my alter ego's request for review @hyeygit 😆

@hyeygit (Contributor Author) commented Mar 31, 2025

I agree in practice, probably ok, but this could break accuracy tests and so good idea to include in any case.

Agreed, makes sense!

ignore my alter ego's request for review @hyeygit 😆

Haha no worries!

@NickLucche (Collaborator) left a comment

Great job here @hyeygit ! Left some comments about tests.

Please remember to enable top-p like I've done for top-k here: https://github.com/vllm-project/vllm/pull/15489/files. Otherwise I can enable both in the same PR.

Collaborator

I think if the test is under v1/tpu we shouldn't test for CUDA but skip if platform is not tpu. Otherwise we can move the test into the shared directory.

@hyeygit (Contributor Author)

Yep makes sense. Updated to TPU only.

Collaborator

nice test!

@hyeygit (Contributor Author)

Credit to @njhill's apply_top_k_only test!

Collaborator

can we add this test to run-tpu-v1-test.sh?

@hyeygit (Contributor Author)

Done.

Comment on lines 142 to 143
Collaborator

is this top-p (to be used as topp-only) implementation also needed on gpu? @njhill

@hyeygit (Contributor Author)

IIUC scattering isn't a bottleneck on GPU, so this impl wouldn't bring much benefit (plus, this impl still involves a full vocab sort, same as the forward_native version).

Collaborator

good point

@hyeygit hyeygit mentioned this pull request Mar 31, 2025
@mergify mergify bot added the ci/build label Mar 31, 2025

mergify bot commented Mar 31, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @hyeygit.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 31, 2025
@hyeygit (Contributor Author) commented Mar 31, 2025

Please remember to enable top-p like I've done for top-k here: https://github.com/vllm-project/vllm/pull/15489/files. Otherwise I can enable both in the same PR.

It might be cleaner to enable top-p in your PR (perhaps rename it to "Enable both Top-k and Top-p") since it's blocked by this and the to-be-sent-out top-k PR. Let me know if that sounds alright! @NickLucche

@NickLucche (Collaborator)

@hyeygit Works for me!
I may still do the enablement one at a time to track benchmarks.

@njhill (Member) commented Mar 31, 2025

Thanks @hyeygit, this looks good. However, I don't really understand the need for the random perturbation / tie-breaking.

The behaviour without doing this will already be effectively the same as an arbitrary tie-break. The random perturbation doesn't even make it deterministic. If we want it to be deterministic we can set stable=True in the sort operation (but not suggesting that's needed).
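
(For reference, a minimal illustration, assuming the sampler's sort is the standard torch.sort call:)

# stable=True preserves the original order among equal logits, which would
# make the implicit tie-break deterministic.
sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True, stable=True)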

Also, this implicit tie-breaking behaviour is the same in the existing implementation; I don't see how it's something specific to the new one.

Re your related comment on the top_k PR:

However I think one corner case where this would break is if there are duplicate elements in the logit that equal the cut off value (i.e. top_k_mask). For example, given an input of [1, 2, 2, 2, 3] and k=3, the current apply_top_k_only would return [-inf, 2, 2, 2, 3] while the correct result should be [-inf, -inf, 2, 2, 3].

Again this is no different to how the pre-existing impl works. And I think it's ok for the actual shortlist to comprise more than k tokens in the case that there's a tie for the kth highest probability. I would not characterize that as "incorrect" since this case is not really well-defined. And it's reasonable from an intuition pov since the tied tokens have equal likelihood.

So hopefully this PR can be simplified to remove that part?

It would be interesting to also test whether this is meaningfully faster on GPUs. I assume most of the overhead is the sort, but if this is even slightly faster we might as well change the existing impl to use the count + mask approach.
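
(For concreteness, a hedged sketch of that count/threshold-and-mask idea for top-k -- the function name is illustrative, not existing vLLM code. Note that any tokens tied with the k-th value are kept, matching the [1, 2, 2, 2, 3] example above:)

import torch

def apply_top_k_by_threshold(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Per-row threshold = the k-th largest logit; mask everything below it.
    # No scatter: the original logits are compared against the threshold.
    logits_sorted, _ = logits.sort(dim=-1, descending=True)
    cutoff = logits_sorted.gather(-1, (k - 1).unsqueeze(-1).long())
    return logits.masked_fill(logits < cutoff, float("-inf"))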

@yaochengji (Collaborator) left a comment

LGTM, thanks!

@mgoin (Member) commented Apr 2, 2025

Please see the failing v1 test; it looks like the CUDA sampler tests are failing: https://buildkite.com/vllm/ci/builds/16819/steps?jid=0195f7bd-4a4f-4517-af78-dc4a6772ba71

# topk.values tensor has shape [batch_size, max_top_k].
# Convert top k to 0-based index in range [0, max_top_k).
- k_index = k.sub_(1).unsqueeze(1)
+ k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
Member

@hyeygit is this because of a TPU torch broadcasting limitation?

@hyeygit (Contributor Author) Apr 2, 2025

Yes, I think so. Without the explicit expand this fails on XLA due to a shape mismatch.
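
(For illustration, a hedged sketch of the pattern under discussion; topk_values and the surrounding shapes are assumptions, not the exact PR code:)

# topk_values: [batch_size, max_top_k]; k: [batch_size] per-request top-k.
# On XLA the gather index needs an explicit [batch_size, 1] shape; relying on
# implicit broadcasting can fail with a shape-mismatch error.
k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
cutoff = topk_values.gather(dim=1, index=k_index.long())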

@hyeygit (Contributor Author) commented Apr 2, 2025

Please see the failing v1 test; it looks like the CUDA sampler tests are failing: https://buildkite.com/vllm/ci/builds/16819/steps?jid=0195f7bd-4a4f-4517-af78-dc4a6772ba71

Oh this is probably caused by my incorrect rebase -- had some duplicate lines in the sampler. After resolving the conflicts the tests seem to pass.

Thanks for the approval!

@robertgshaw2-redhat robertgshaw2-redhat merged commit 1b84eff into vllm-project:main Apr 3, 2025
31 checks passed
hyeygit added a commit to hyeygit/vllm that referenced this pull request Apr 3, 2025
Previously we found that using torch.topk resulted in a significant
speedup for TPU. Turns out that's not a viable solution because the
return shape of torch.topk depends on k, which means an XLA recompilation
is triggered every time k changes.

Additionally, we realized that torch.scatter was the main bottleneck for
the original top-k impl on TPU. This PR circumvents both problems by using
a threshold-based approach to find the top-k set. The algorithm is nearly
identical to that of top-p; see vllm-project#15736 for more details.

Signed-off-by: Hyesoo Yang <[email protected]>