diff --git a/.buildkite/run-tpu-v1-test.sh b/.buildkite/run-tpu-v1-test.sh
index 4aac57cca94c..5b7ce9a7677e 100755
--- a/.buildkite/run-tpu-v1-test.sh
+++ b/.buildkite/run-tpu-v1-test.sh
@@ -36,7 +36,9 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_6 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
     && echo TEST_7 \
-    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \
+    && echo TEST_8 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py" \
 
 
 # TODO: This test fails because it uses RANDOM_SEED sampling
diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py
new file mode 100644
index 000000000000..dce0303e68d5
--- /dev/null
+++ b/tests/v1/tpu/test_topk_topp_sampler.py
@@ -0,0 +1,132 @@
+# SPDX-License-Identifier: Apache-2.0
+import math
+
+import pytest
+import torch
+
+from vllm.platforms import current_platform
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
+
+if not current_platform.is_tpu():
+    pytest.skip("This test needs a TPU.", allow_module_level=True)
+import torch_xla.core.xla_model as xm
+
+BATCH_SIZE = 1024
+VOCAB_SIZE = 128 * 1024
+TOLERANCE = 1e-6
+
+
+def test_topp_result_sums_past_p():
+    with torch.device(xm.xla_device()):
+        xm.set_rng_state(seed=33)
+
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
+        probs = logits.softmax(dim=-1)
+
+        # Random top-p values between 0 and 1.
+        p = torch.rand((BATCH_SIZE, ))
+
+        # Set p=1 for ~50% of requests in the batch (top-p disabled).
+        p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), 1)
+
+        no_op_k = torch.tensor([VOCAB_SIZE])
+        logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(),
+                                              k=no_op_k,
+                                              p=p)
+
+        # Verify that the unmasked probability mass sums to at least p.
+        probs.masked_fill_(logits_masked.isinf(), 0)
+        masked_prob_sum = probs.sum(dim=-1)
+
+    xm.mark_step()
+
+    # Perform assertion on CPU.
+    assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))
+
+
+def test_topp_basic():
+    with torch.device(xm.xla_device()):
+        logits = torch.tensor([[math.log(0.2),
+                                math.log(0.3),
+                                math.log(0.5)],
+                               [math.log(0.5),
+                                math.log(0.1),
+                                math.log(0.4)]])
+
+        result = apply_top_k_top_p_tpu(logits=logits.clone(),
+                                       k=torch.tensor([3, 3]),
+                                       p=torch.tensor([0.79, 0.79]))
+
+    xm.mark_step()
+
+    # Expect the smallest elements to be dropped.
+    expected_result = logits.clone().cpu()
+    expected_result[0, 0] = float("-inf")
+    expected_result[1, 1] = float("-inf")
+    assert torch.allclose(expected_result, result.cpu())
+
+
+def test_topp_select_all():
+    with torch.device(xm.xla_device()):
+        logits = torch.tensor([[math.log(0.2),
+                                math.log(0.3),
+                                math.log(0.5)],
+                               [math.log(0.5),
+                                math.log(0.1),
+                                math.log(0.4)]])
+
+        result = apply_top_k_top_p_tpu(logits=logits.clone(),
+                                       k=torch.tensor([3, 3]),
+                                       p=torch.tensor([1.0, 1.0]))
+
+    xm.mark_step()
+
+    assert torch.allclose(logits.cpu(), result.cpu())
+
+
+def test_topp_with_ties():
+    with torch.device(xm.xla_device()):
+        # Input has multiple math.log(0.3).
+        logits = torch.tensor(
+            [[math.log(0.3),
+              math.log(0.3),
+              math.log(0.3),
+              math.log(0.1)]])
+
+        result = apply_top_k_top_p_tpu(logits=logits.clone(),
+                                       k=torch.tensor([4]),
+                                       p=torch.tensor([0.2]))
+
+    xm.mark_step()
+
+    # All tied values are included in the top-p set. Tie-breaking is left
+    # to the final sampling step (all tied tokens have an equal
+    # probability of being chosen).
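+    # Concretely: the ascending-sorted probabilities are [0.1, 0.3, 0.3, 0.3]
+    # with cumulative sums [0.1, 0.4, 0.7, 1.0]; with p=0.2 the cut-off
+    # probability is 0.3, so only the 0.1 token is masked.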
+    expected_result = logits.clone().cpu()
+    expected_result[0, 3] = float("-inf")
+    assert torch.allclose(expected_result, result.cpu())
+
+
+def test_both_topk_topp():
+    with torch.device(xm.xla_device()):
+        logits = torch.tensor([[math.log(0.2),
+                                math.log(0.3),
+                                math.log(0.5)],
+                               [math.log(0.5),
+                                math.log(0.1),
+                                math.log(0.4)]])
+
+        # Set k=1 for the first request in the batch.
+        result = apply_top_k_top_p_tpu(logits=logits.clone(),
+                                       k=torch.tensor([1, 3]),
+                                       p=torch.tensor([0.79, 0.79]))
+
+    xm.mark_step()
+
+    # Since k=1 for the first request in the batch, expect only the largest
+    # element to be selected.
+    expected_result = logits.clone().cpu()
+    expected_result[0, 0] = float("-inf")
+    expected_result[0, 1] = float("-inf")
+    expected_result[1, 1] = float("-inf")
+    assert torch.allclose(expected_result, result.cpu())
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index d4bc23364c57..f69623edd632 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -122,23 +122,48 @@ def forward_tpu(
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        # If only top-k is specified, use pytorch's builtin topk op. This leads
-        # to significant speed up on TPU compared to using apply_top_k_top_p.
-        if k is not None and p is None:
-            topk_values, topk_indices = torch.topk(logits, k, dim=-1)
-
-            mask = torch.ones_like(logits, dtype=torch.bool)
-            mask.scatter_(-1, topk_indices, False)
-            logits.masked_fill_(mask, float('-inf'))
-        else:
-            # TODO Placeholder for TPU optimized topp kernel
-            # logits = apply_top_k_top_p(logits, k, p)
-            pass
-
+        logits = apply_top_k_top_p_tpu(logits, k, p)
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
 
 
+def apply_top_k_top_p_tpu(
+    logits: torch.Tensor,
+    k: torch.Tensor,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Apply top-k and top-p optimized for TPU.
+
+    This algorithm avoids using torch.scatter, which is extremely slow on
+    TPU. Instead, it finds a "cut-off" element in the original logits and,
+    after thresholding the logits with this cut-off, the remaining elements
+    constitute the top-p set.
+
+    Note: in the case of a tie (i.e. multiple cut-off elements present in
+    the logits), all tied elements are included in the top-p set. In other
+    words, this function does not break ties. Instead, the tied tokens have
+    an equal chance of being chosen during final sampling, so the tie can be
+    considered broken at that point.
+    """
+    if k is not None:
+        logits = apply_top_k_only(logits, k)
+
+    if p is not None:
+        probs = logits.softmax(dim=-1)
+        probs_sort, _ = probs.sort(dim=-1, descending=False)
+        cumprob = torch.cumsum(probs_sort, dim=-1)
+        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
+        top_p_mask[:, -1] = False  # always keep at least one token
+
+        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
+        top_p_cutoff = probs_sort.gather(-1, top_p_count)
+        elements_to_discard = probs < top_p_cutoff
+        logits.masked_fill_(elements_to_discard, -float("inf"))
+
+    return logits
+
+
 def apply_top_k_top_p(
     logits: torch.Tensor,
     k: Optional[torch.Tensor],
@@ -199,7 +224,7 @@ def apply_top_k_only(
     max_top_k = k.max()
     # topk.values tensor has shape [batch_size, max_top_k].
     # Convert top k to 0-based index in range [0, max_top_k).
-    k_index = k.sub_(1).unsqueeze(1)
+    k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
     top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
     # Handle non-topk rows.
     top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
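
Note (editor's sketch, not part of the diff): since apply_top_k_top_p_tpu uses only portable torch ops, its cut-off logic can be sanity-checked off-TPU. The snippet below is an illustrative assumption (running the function on CPU rather than through the XLA path); it walks through the same numbers as test_topp_basic.

import math

import torch

from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu

# One request with probabilities [0.2, 0.3, 0.5].
logits = torch.tensor([[math.log(0.2), math.log(0.3), math.log(0.5)]])

# Sorted ascending, the cumulative sums are [0.2, 0.5, 1.0]. With p=0.79 the
# threshold 1 - p = 0.21 covers only the 0.2 prefix, so the cut-off
# probability is 0.3 and the 0.2 token is masked to -inf.
out = apply_top_k_top_p_tpu(logits=logits.clone(),
                            k=torch.tensor([3]),  # k == vocab size: no-op
                            p=torch.tensor([0.79]))
print(out)  # tensor([[-inf, log(0.3), log(0.5)]]), up to float precision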