[V1][TPU] TPU-optimized top-p implementation (avoids scattering). #15736
Changes from 6 commits
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
import math

import pytest
import torch

from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu

if not current_platform.is_tpu():
    pytest.skip("This test needs a TPU.", allow_module_level=True)
import torch_xla.core.xla_model as xm

BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024
TOLERANCE = 1e-6


def test_topp_result_sums_past_p():
    with torch.device(xm.xla_device()):
        xm.set_rng_state(seed=33)

        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
        probs = logits.softmax(dim=-1)

        # Random top-p values between 0 and 1.
        p = torch.rand((BATCH_SIZE, ))

        # Set p=1 for ~50% of requests in the batch (top-p disabled).
        p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), 1)

        no_op_k = torch.tensor([VOCAB_SIZE])
        logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(),
                                              k=no_op_k,
                                              p=p)

        # Verify that the masked logits' probabilities sum to at least p.
        probs.masked_fill_(logits_masked.isinf(), 0)
        masked_prob_sum = probs.sum(dim=-1)
Collaborator: Usually it's a good idea to put a

Contributor (Author): Done.
        xm.mark_step()

        # Perform assertion on CPU.
        assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))


def test_topp_basic():
    with torch.device(xm.xla_device()):
        logits = torch.tensor([[math.log(0.2),
                                math.log(0.3),
                                math.log(0.5)],
                               [math.log(0.5),
                                math.log(0.1),
                                math.log(0.4)]])

        result = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([3, 3]),
                                       p=torch.tensor([0.79, 0.79]))

        xm.mark_step()

        # Expect the smallest elements to be dropped.
        expected_result = logits.clone().cpu()
        expected_result[0, 0] = float("-inf")
        expected_result[1, 1] = float("-inf")
        assert torch.allclose(expected_result, result.cpu())


def test_topp_select_all():
    with torch.device(xm.xla_device()):
        logits = torch.tensor([[math.log(0.2),
                                math.log(0.3),
                                math.log(0.5)],
                               [math.log(0.5),
                                math.log(0.1),
                                math.log(0.4)]])

        result = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([3, 3]),
                                       p=torch.tensor([1.0, 1.0]))

        xm.mark_step()

        assert torch.allclose(logits.cpu(), result.cpu())


def test_topp_with_ties():
    with torch.device(xm.xla_device()):
        # Input has multiple math.log(0.3).
        logits = torch.tensor(
            [[math.log(0.3),
              math.log(0.3),
              math.log(0.3),
              math.log(0.1)]])

        result = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([4]),
                                       p=torch.tensor([0.2]))

        xm.mark_step()

        # All tie values are included in the top-p set. Tie breaking is left
        # to be done during final sampling (all tie tokens have equal
        # probability of being chosen).
        expected_result = logits.clone().cpu()
        expected_result[0, 3] = float("-inf")
        assert torch.allclose(expected_result, result.cpu())


def test_both_topk_topp():
    with torch.device(xm.xla_device()):
        logits = torch.tensor([[math.log(0.2),
                                math.log(0.3),
                                math.log(0.5)],
                               [math.log(0.5),
                                math.log(0.1),
                                math.log(0.4)]])

        # Set k=1 for the first batch.
        result = apply_top_k_top_p_tpu(logits=logits.clone(),
                                       k=torch.tensor([1, 3]),
                                       p=torch.tensor([0.79, 0.79]))

        xm.mark_step()
        # Since k=1 for the first batch element, expect only its largest
        # logit to be kept.
        expected_result = logits.clone().cpu()
        expected_result[0, 0] = float("-inf")
        expected_result[0, 1] = float("-inf")
        expected_result[1, 1] = float("-inf")
        assert torch.allclose(expected_result, result.cpu())
vllm/v1/sample/ops/topk_topp_sampler.py

@@ -122,23 +122,48 @@ def forward_tpu(
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        # If only top-k is specified, use pytorch's builtin topk op. This leads
-        # to significant speed up on TPU compared to using apply_top_k_top_p.
-        if k is not None and p is None:
-            topk_values, topk_indices = torch.topk(logits, k, dim=-1)
-
-            mask = torch.ones_like(logits, dtype=torch.bool)
-            mask.scatter_(-1, topk_indices, False)
-            logits.masked_fill_(mask, float('-inf'))
-        else:
-            # TODO Placeholder for TPU optimized topp kernel
-            # logits = apply_top_k_top_p(logits, k, p)
-            pass
-
+        logits = apply_top_k_top_p_tpu(logits, k, p)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
+
+
+def apply_top_k_top_p_tpu(
+    logits: torch.Tensor,
+    k: torch.Tensor,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Apply top-k and top-p optimized for TPU.
+
+    This algorithm avoids using torch.scatter, which is extremely slow on
+    TPU. Instead, it finds a "cut-off" element in the original logits, and
+    after thresholding the logits with this cut-off, the remaining elements
+    constitute the top-p set.
+
+    Note: in the case of a tie (i.e. multiple cut-off elements present in
+    the logits), all tied elements are included in the top-p set. In other
+    words, this function does not break ties. Instead, the tied tokens have
+    an equal chance of being chosen during final sampling, so the tie can be
+    considered broken then.
+    """
+    if k is not None:
+        logits = apply_top_k_only(logits, k)
+
+    if p is not None:
+        probs = logits.softmax(dim=-1)
+        probs_sort, _ = probs.sort(dim=-1, descending=False)
+        cumprob = torch.cumsum(probs_sort, dim=-1)
+        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = False  # at least one
+
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_cutoff = probs_sort.gather(-1, top_p_count)
+        elements_to_discard = probs < top_p_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    return logits
+
+
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: Optional[torch.Tensor],
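As a reading aid rather than part of the change, here is a CPU-only trace of the same cut-off arithmetic on a tiny, made-up batch (the tensors and values below are illustrative, not taken from vLLM), including a tied row to show the no-tie-breaking behavior described in the docstring:

```python
import torch

# Two requests over a 4-token vocab; the second row has a three-way tie.
probs = torch.tensor([[0.2, 0.3, 0.4, 0.1],
                      [0.3, 0.3, 0.3, 0.1]])
p = torch.tensor([0.79, 0.2])

# Sort ascending and accumulate the probability mass.
probs_sort, _ = probs.sort(dim=-1, descending=False)
cumprob = torch.cumsum(probs_sort, dim=-1)

# Positions whose ascending cumulative mass stays <= 1 - p lie outside the
# top-p set; the largest element is always kept.
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False

# The number of droppable positions indexes the smallest *kept* probability,
# i.e. the cut-off; everything strictly below the cut-off is discarded.
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
top_p_cutoff = probs_sort.gather(-1, top_p_count)
discard = probs < top_p_cutoff

print(top_p_cutoff.squeeze(-1))  # tensor([0.2000, 0.3000])
print(discard)
# tensor([[False, False, False,  True],
#         [False, False, False,  True]])
# Row 0 keeps {0.2, 0.3, 0.4} (mass 0.9 >= 0.79); row 1 keeps all three tied
# 0.3 entries even though p = 0.2, because ties are not broken here.
```

Note that only a sort, a cumulative sum, a gather on the sorted tensor, and an element-wise comparison against the cut-off are needed; nothing is scattered back into vocabulary order.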
@@ -201,6 +226,8 @@ def apply_top_k_only(
     # Convert top k to 0-based index in range [0, max_top_k).
-    k_index = k.sub_(1).unsqueeze(1)
-    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
+    k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
Member: @hyeygit is this because of a TPU torch broadcasting limitation?

Contributor (Author): Yes, I think so. Without the explicit expand this fails on XLA due to shape mismatch.
+
+    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
     # Handle non-topk rows.
     top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
     logits.masked_fill_(logits < top_k_mask, -float("inf"))
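Regarding the review exchange above about the explicit expand: a minimal CPU sketch of the index shapes involved (purely illustrative; the tensors and sizes are made up). In eager PyTorch the extra .expand is effectively a no-op here, since unsqueeze already produces a (batch, 1) index, but per the author's reply the explicitly spelled-out shape is what the XLA lowering needs:

```python
import torch

batch_size, max_top_k = 4, 3
logits = torch.randn(batch_size, max_top_k)
k = torch.tensor([2, 1, 3, 2])

# unsqueeze alone already yields a (batch_size, 1) index in eager mode...
k_index = (k - 1).unsqueeze(1)
print(k_index.shape)  # torch.Size([4, 1])

# ...and the explicit expand used in the PR spells that shape out instead of
# leaving it implicit, which (per the discussion above) avoids a shape
# mismatch when the graph is traced through XLA.
k_index = (k - 1).unsqueeze(1).expand(logits.shape[0], 1)
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
print(top_k_mask.shape)  # torch.Size([4, 1])
```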
Review comment: Can we add this test to run-tpu-v1-test.sh?

Reply: Done.