[V1][TPU] TPU-optimized top-p implementation (avoids scattering). #15736
Conversation
brittrock
left a comment
Great work, thank you for adding this @hyeygit, including the notes on expected speedups. I just have minor usability feedback :)
Thanks @hyeygit, this seems reasonable to me in the interim, but I'll let the other folks chime in on the appropriateness. Cc @yaochengji @yarongmu-google
Given this can slightly impact the generated output during ties, this really feels like something we should be warning the user about. Not every time the function is called, of course, but at a minimum, we should be warning users when the argument is set to anything other than the default. I couldn't find a warning log but I'm also on my phone, so apologies if I just missed it.
Thanks for the review @brittrock. I added a one-time log message noting that this algorithm is approximate in theory.
In practice, I think the tiny 1e-9 probability perturbation doesn't alter the result in any meaningful way. The only situation where the output differs from the exact algo is if there are multiple tokens whose probabilities are within 1e-9 (one in a billion) of each other. This means they practically have the same probability, so including either one of them in the top-p set should be acceptable.
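(For concreteness, the perturbation amounts to something like the sketch below; this is illustrative, not the exact code in this PR, and the shapes are just example values.)

```python
import torch

logits = torch.randn(4, 32000)  # [batch, vocab], example shapes
probs = torch.softmax(logits, dim=-1)
# Break exact ties with a negligible random offset so that the cut-off
# element of the sorted distribution is unique. 1e-9 is far below any
# meaningful gap between token probabilities.
probs = probs + torch.rand_like(probs) * 1e-9
```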
thanks for adding!
I agree in practice, probably ok, but this could break accuracy tests and so good idea to include in any case.
nice job, again!
I signed in from my phone and must have created another github account >.<
ignore my alter ego's request for review @hyeygit 😆
> I agree in practice, probably ok, but this could break accuracy tests and so good idea to include in any case.
Agreed, makes sense!
> ignore my alter ego's request for review @hyeygit 😆
Haha no worries!
NickLucche
left a comment
Great job here @hyeygit ! Left some comments about tests.
Please remember to enable top-p like I've done for top-k here: https://github.com/vllm-project/vllm/pull/15489/files.
Otherwise I can enable both in the same PR.
I think if the test is under v1/tpu we shouldn't test for CUDA but skip if platform is not tpu. Otherwise we can move the test into the shared directory.
Yep makes sense. Updated to TPU only.
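(For reference, a TPU-only skip in pytest typically looks like the sketch below; `current_platform.is_tpu()` is assumed to be the platform check available in vLLM here, so treat the exact names as illustrative rather than the merged test code.)

```python
import pytest

from vllm.platforms import current_platform

# Skip every test in this module on anything that is not a TPU.
pytestmark = pytest.mark.skipif(not current_platform.is_tpu(),
                                reason="This test requires a TPU.")
```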
nice test!
Credit to @njhill's apply_top_k_only test!
can we add this test to run-tpu-v1-test.sh?
Done.
is this top-p (to be used as top-p-only) implementation also needed on GPU? @njhill
IIUC scattering isn't a bottleneck on GPU, so this impl wouldn't bring much benefit (plus, this impl still involves a full vocab sort, the same as the forward_native version).
good point
This pull request has merge conflicts that must be resolved before it can be merged.
It might be cleaner to enable top-p in your PR (perhaps rename the PR to "Enable both Top-k and Top-p") since it's blocked by this and the to-be-sent-out top-k PR. Let me know if that sounds alright! @NickLucche
@hyeygit Works for me!
Thanks @hyeygit, this looks good. However, I don't really understand the need for the random perturbation / tiebreaking. The behaviour without doing this will already be effectively the same as an arbitrary tie-break, and the random perturbation doesn't even make it deterministic. If we want it to be deterministic we can set […]. Also, this implicit tie-breaking behaviour is the same in the existing implementation; I don't see how it's something specific to the new one.

Re your related comment on the top_k PR: […]

Again, this is no different to how the pre-existing impl works. And I think it's OK for the actual shortlist to comprise more than k tokens in the case that there's a tie for the kth highest probability. I would not characterize that as "incorrect" since this case is not really well-defined, and it's reasonable from an intuition POV since the tied tokens have equal likelihood. So hopefully this PR can be simplified to remove that part?

It would be interesting to also test whether this is meaningfully faster on GPUs. I assume most of the overhead is the sort, but if this is even slightly faster we might as well change the existing impl to do the count + mask approach.
yaochengji
left a comment
LGTM, thanks!
Please see the failing v1 test; it looks like the CUDA sampler tests are failing: https://buildkite.com/vllm/ci/builds/16819/steps?jid=0195f7bd-4a4f-4517-af78-dc4a6772ba71
# topk.values tensor has shape [batch_size, max_top_k].
# Convert top k to 0-based index in range [0, max_top_k).
- k_index = k.sub_(1).unsqueeze(1)
+ k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
@hyeygit is this because of a TPU torch broadcasting limitation?
Yes I think so. Without the explicit expand this fails on XLA due to shape mismatch.
Oh, this is probably caused by my incorrect rebase -- I had some duplicate lines in the sampler. After resolving the conflicts the tests seem to pass. Thanks for the approval!
Previously we found that using torch.topk resulted in a significant speed-up for TPU. It turns out that's not a viable solution because the return shape of torch.topk depends on k, which means an XLA recompilation is triggered every time k changes. Additionally, we realized that torch.scatter was the main bottleneck for the original top-k impl on TPU. This PR circumvents both problems by using a threshold-based approach to find the top-k set. The algorithm is nearly identical to that of top-p; see vllm-project#15736 for more details.
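The threshold-based top-k described in that commit message boils down to something like the sketch below (assuming `logits` of shape `[batch, vocab]` and `k` as a `[batch]` long tensor; an illustration of the approach, not the exact vLLM code):

```python
import torch

def top_k_mask(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Mask logits outside the top-k set without torch.scatter or torch.topk.

    logits: [batch, vocab]; k: [batch] long tensor, values in [1, vocab].
    """
    logits_sorted = logits.sort(dim=-1, descending=True).values
    # The k-th largest logit per row is the cut-off value. gather with a
    # [batch, 1] index keeps all shapes static, so XLA does not recompile
    # when k changes (unlike torch.topk, whose output shape depends on k).
    cutoff = logits_sorted.gather(dim=-1, index=(k - 1).unsqueeze(1))
    # Keep everything >= the cut-off; ties at the k-th value are kept too.
    return logits.masked_fill(logits < cutoff, float("-inf"))
```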
Top-k and top-p are slow on TPU because existing algorithms use torch.scatter, and for some reason torch.scatter is extremely slow on TPU. There's ongoing work to optimize it, but until that's done, we need an alternative algorithm that circumvents scattering.

The algorithm in this PR avoids torch.scatter by finding a "cut-off" element in the original logits; after thresholding the logits using this cut-off, the remaining elements constitute the top-p set. This is inspired by the apply_top_k_only algorithm created by @njhill in #15478.
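To make the description concrete, here is a minimal sketch of the threshold idea (assuming `logits` of shape `[batch, vocab]` and `p` as a `[batch]` tensor of top-p values; the actual implementation in this PR differs in details such as the tie-breaking perturbation discussed above):

```python
import torch

def top_p_mask(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Mask logits outside the top-p set without torch.scatter.

    logits: [batch, vocab]; p: [batch], values in (0, 1].
    """
    probs = logits.softmax(dim=-1)
    probs_sorted = probs.sort(dim=-1, descending=True).values
    cum = probs_sorted.cumsum(dim=-1)
    # A sorted token is in the top-p set if the probability mass *before*
    # it (the exclusive cumulative sum) is still below p; the set then has
    # just enough tokens for its total mass to reach p.
    in_set = (cum - probs_sorted) < p.unsqueeze(1)
    # The smallest kept probability per row is the cut-off value.
    cutoff = torch.where(in_set, probs_sorted,
                         torch.ones_like(probs_sorted)
                         ).min(dim=-1, keepdim=True).values
    # Threshold the original (unsorted) probabilities; no scatter needed.
    return logits.masked_fill(probs < cutoff, float("-inf"))
```

As a quick sanity check: with p = 1.0 every token survives, while a sharply peaked distribution keeps only the argmax. Tokens tied with the cut-off probability are all kept, which is the tie behaviour discussed in the review thread above.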
Benchmark
Microbenchmark (on v6e-1) shows a significant speed-up: "Running 32 elapsed time" is ~5 ms, down from the original scatter-based algorithm's ~500 ms, a 100x improvement.
Microbenchmark full results on v6e-1
Extra notes
The VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION env (introduced in #15242) can now be removed. Not done in this PR since @NickLucche's pending PR #15489 already handles it (thanks!).