From d739e5e0d0a7ec81391363395ac8626e447abdea Mon Sep 17 00:00:00 2001 From: root Date: Fri, 28 Mar 2025 18:44:22 +0000 Subject: [PATCH 1/7] [V1][TPU] Add a TPU-optimized top-p algorithm. Top-k and top-p are slow on TPU because existing algorithms use torch.scatter. For some reason torch.scatter is extremely slow on TPU. There's ongoing work to optimize it, but until that's done, we need an alternative algorithm that circumvents scattering. The algorithm in this PR avoids using torch.scatter by finding a "cut-off" element in the original logit, and after thresholding the logit using this cut-off, the remaining elements shall constitute the top-p set. A caveat of the above approach is that ties are not correctly handled -- if there are duplicate cutoff elements present in the logit, then the resulting top-p set will be incorrect. To address this problem, we introduce a tiny perturbation to the probabilities (after softmax) to break any potential ties. The added perturbation is tiny so it should not alter the end results significantly, but it still makes this algorithm approximate rather than an exact one. Signed-off-by: Hyesoo Yang --- tests/v1/tpu/test_topk_topp_sampler.py | 120 ++++++++++++++++++++++++ vllm/v1/sample/ops/topk_topp_sampler.py | 74 ++++++++++++--- 2 files changed, 181 insertions(+), 13 deletions(-) create mode 100644 tests/v1/tpu/test_topk_topp_sampler.py diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py new file mode 100644 index 000000000000..3447ca01b367 --- /dev/null +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +import math + +import torch + +from vllm.platforms import current_platform +from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu + +if current_platform.is_tpu(): + import torch_xla.core.xla_model as xm + +DEVICE = xm.xla_device() if current_platform.is_tpu() else torch.device("cuda") + +BATCH_SIZE = 1024 +VOCAB_SIZE = 128 * 1024 + + +def test_topk_and_no_op_topp(): + with torch.device(DEVICE): + if current_platform.is_tpu(): + xm.set_rng_state(seed=33) + else: + torch.manual_seed(33) + + logits = torch.rand((BATCH_SIZE, VOCAB_SIZE)) + + # Random top-k values between 1 and 9. + k = torch.randint(1, 10, (BATCH_SIZE, )) + + # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). + k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), + VOCAB_SIZE) + + # Top-k only implementation + result1 = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None) + + # Top-p + top-k + no_op_top_p = torch.tensor([1.0]) + result2 = apply_top_k_top_p_tpu(logits=logits.clone(), + k=k, + p=no_op_top_p) + + assert torch.allclose(result1, result2) + + +def test_topp_basic(): + with torch.device(DEVICE): + logits = torch.tensor([[math.log(0.2), + math.log(0.3), + math.log(0.5)], + [math.log(0.5), + math.log(0.1), + math.log(0.4)]]) + + result = apply_top_k_top_p_tpu(logits=logits.clone(), + k=torch.tensor([3, 3]), + p=torch.tensor([0.79, 0.79])) + + # Expect the smallest elements to be dropped. 
+ expected_result = logits.clone() + expected_result[0, 0] = float("-inf") + expected_result[1, 1] = float("-inf") + assert torch.allclose(expected_result, result) + + +def test_topp_select_all(): + with torch.device(DEVICE): + logits = torch.tensor([[math.log(0.2), + math.log(0.3), + math.log(0.5)], + [math.log(0.5), + math.log(0.1), + math.log(0.4)]]) + + result = apply_top_k_top_p_tpu(logits=logits.clone(), + k=torch.tensor([3, 3]), + p=torch.tensor([1.0, 1.0])) + + assert torch.allclose(logits, result) + + +def test_topp_with_ties(): + with torch.device(DEVICE): + # Input has multiple math.log(0.3). + logits = torch.tensor( + [[math.log(0.3), + math.log(0.3), + math.log(0.3), + math.log(0.1)]]) + + result = apply_top_k_top_p_tpu(logits=logits.clone(), + k=torch.tensor([4]), + p=torch.tensor([0.2])) + + # Expect math.log(0.3) to be the only selected element. + expected_result = torch.tensor([math.log(0.3)]) + assert torch.allclose(expected_result, result[result.isfinite()]) + + +def test_both_topk_topp(): + with torch.device(DEVICE): + logits = torch.tensor([[math.log(0.2), + math.log(0.3), + math.log(0.5)], + [math.log(0.5), + math.log(0.1), + math.log(0.4)]]) + + # Set k=1 for the first batch. + result = apply_top_k_top_p_tpu(logits=logits.clone(), + k=torch.tensor([1, 3]), + p=torch.tensor([0.79, 0.79])) + + # Since for the first batch k=1, expect only the largest element gets + # selected. + expected_result = logits.clone() + expected_result[0, 0] = float("-inf") + expected_result[0, 1] = float("-inf") + expected_result[1, 1] = float("-inf") + assert torch.allclose(expected_result, result) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index d4bc23364c57..bd1aa3eda9c8 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -79,6 +79,10 @@ def __init__(self): "which could be very slow.") self.forward = self.forward_native else: + logger.info( + "Using approximate top-p optimized for TPU. Result may in " + "theory differ from the exact algorithm if there are " + "tokens with near-identical probabilities (< 1e-9 diff).") self.forward = self.forward_tpu else: self.forward = self.forward_native @@ -122,23 +126,65 @@ def forward_tpu( k: Optional[torch.Tensor], p: Optional[torch.Tensor], ) -> torch.Tensor: - # If only top-k is specified, use pytorch's builtin topk op. This leads - # to significant speed up on TPU compared to using apply_top_k_top_p. - if k is not None and p is None: - topk_values, topk_indices = torch.topk(logits, k, dim=-1) - - mask = torch.ones_like(logits, dtype=torch.bool) - mask.scatter_(-1, topk_indices, False) - logits.masked_fill_(mask, float('-inf')) - else: - # TODO Placeholder for TPU optimized topp kernel - # logits = apply_top_k_top_p(logits, k, p) - pass - + logits = apply_top_k_top_p_tpu(logits, k, p) probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators) +def apply_top_k_top_p_tpu( + logits: torch.Tensor, + k: torch.Tensor, + p: torch.Tensor, +) -> torch.Tensor: + if k is not None: + logits = apply_top_k_only(logits, k) + + if p is not None: + logits = apply_approx_top_p(logits, p) + + return logits + + +def apply_approx_top_p( + logits: torch.Tensor, + p: torch.Tensor, +) -> torch.Tensor: + """ + Apply approximate top-p that is optimized for TPU. + + This algorithm avoids using torch.scatter which is extremely slow on TPU. 
+ This is achieved by finding a "cut-off" element in the original logit, and + after thresholding the logit using this cut-off, the remaining elements + shall constitute the top-p set. + + A caveat of the above approach is that ties are not correctly handled -- + if there are duplicate cutoff elements present in the logit, then the + resulting top-p set will be incorrect. To address this problem, we + introduce a tiny perturbation to the probabilities (after softmax) to + break any potential ties. The added perturbation is tiny so it should + not alter the end results significantly, but it still makes this algorithm + approximate rather than an exact one. + """ + probs = logits.softmax(dim=-1) + + # Add a small, random perturbation to the probabilities, and re-normalize. + epsilon = torch.empty(probs.shape).uniform_(-1e-9, 1e-9) + probs += epsilon + probs /= probs.sum(dim=-1, keepdim=True) + + probs_sort, sorted_idx = probs.sort(dim=-1, descending=False) + cumprob = torch.cumsum(probs_sort, dim=-1) + top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = False # at least one + + top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) + top_p_cutoff = probs_sort.gather(-1, top_p_count) + elements_to_discard = probs < top_p_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) + + return logits + + def apply_top_k_top_p( logits: torch.Tensor, k: Optional[torch.Tensor], @@ -201,6 +247,8 @@ def apply_top_k_only( # Convert top k to 0-based index in range [0, max_top_k). k_index = k.sub_(1).unsqueeze(1) top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long()) + k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1) + top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long()) # Handle non-topk rows. top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf")) logits.masked_fill_(logits < top_k_mask, -float("inf")) From 1cfef373fc292476c56c1b5c175ac1589abb1ccb Mon Sep 17 00:00:00 2001 From: Hyesoo Yang Date: Mon, 31 Mar 2025 14:44:35 +0000 Subject: [PATCH 2/7] Bugfix: Make sure uniform sampling happens on XLA not CPU. Previously the uniform sampling happens on CPU and resulted in a slowdown. (Running 32 elasped time is 39 ms). Moving the sampling to XLA sped things up significantly, the benchmark time is down to 5 ms. Signed-off-by: Hyesoo Yang --- vllm/v1/sample/ops/topk_topp_sampler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index bd1aa3eda9c8..c413e87634fb 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -168,7 +168,8 @@ def apply_approx_top_p( probs = logits.softmax(dim=-1) # Add a small, random perturbation to the probabilities, and re-normalize. - epsilon = torch.empty(probs.shape).uniform_(-1e-9, 1e-9) + epsilon = torch.empty(probs.shape, + device=logits.device).uniform_(-1e-9, 1e-9) probs += epsilon probs /= probs.sum(dim=-1, keepdim=True) From f02397906664397ab33b803b9cbacf30b809272b Mon Sep 17 00:00:00 2001 From: Hyesoo Yang Date: Mon, 31 Mar 2025 15:52:34 +0000 Subject: [PATCH 3/7] Update tests. * Added a more comprehensive correctness test for top-p. * Included tests/v1/tpu/test_topk_topp_sampler.py in run-tpu-v1-test.sh. 
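
The correctness test checks the property that defines a valid top-p set
directly: after masking, the probability mass of the surviving tokens must
still be at least p (up to a small numerical tolerance). For example, with
per-token probabilities [0.2, 0.3, 0.5] and p=0.79, keeping {0.5, 0.3}
covers 0.8 >= 0.79, whereas also dropping 0.3 would leave only 0.5 < p.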
Signed-off-by: Hyesoo Yang --- .buildkite/run-tpu-v1-test.sh | 4 +- tests/v1/tpu/test_topk_topp_sampler.py | 53 ++++++++++++-------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/.buildkite/run-tpu-v1-test.sh b/.buildkite/run-tpu-v1-test.sh index 4aac57cca94c..5b7ce9a7677e 100755 --- a/.buildkite/run-tpu-v1-test.sh +++ b/.buildkite/run-tpu-v1-test.sh @@ -36,7 +36,9 @@ docker run --privileged --net host --shm-size=16G -it \ && echo TEST_6 \ && pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \ && echo TEST_7 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \ + && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \ + && echo TEST_8 \ + && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py" \ # TODO: This test fails because it uses RANDOM_SEED sampling diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index 3447ca01b367..fdee53d844bf 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -1,50 +1,47 @@ # SPDX-License-Identifier: Apache-2.0 import math +import pytest import torch from vllm.platforms import current_platform from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu -if current_platform.is_tpu(): - import torch_xla.core.xla_model as xm - -DEVICE = xm.xla_device() if current_platform.is_tpu() else torch.device("cuda") +if not current_platform.is_tpu(): + pytest.skip("This test needs a TPU.", allow_module_level=True) +import torch_xla.core.xla_model as xm BATCH_SIZE = 1024 VOCAB_SIZE = 128 * 1024 +TOLERANCE = 1e-4 -def test_topk_and_no_op_topp(): - with torch.device(DEVICE): - if current_platform.is_tpu(): - xm.set_rng_state(seed=33) - else: - torch.manual_seed(33) +def test_topp_result_sums_past_p(): + with torch.device(xm.xla_device()): + xm.set_rng_state(seed=33) logits = torch.rand((BATCH_SIZE, VOCAB_SIZE)) + probs = logits.softmax(dim=-1) - # Random top-k values between 1 and 9. - k = torch.randint(1, 10, (BATCH_SIZE, )) - - # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). - k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), - VOCAB_SIZE) + # Random top-p values between 0 and 1. + p = torch.rand((BATCH_SIZE, )) - # Top-k only implementation - result1 = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None) + # Set p=1 for ~50% of requests in the batch (top-p disabled). + p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), 1) - # Top-p + top-k - no_op_top_p = torch.tensor([1.0]) - result2 = apply_top_k_top_p_tpu(logits=logits.clone(), - k=k, - p=no_op_top_p) + no_op_k = torch.tensor([VOCAB_SIZE]) + logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), + k=no_op_k, + p=p) - assert torch.allclose(result1, result2) + # Verify that the masked logit's probability sums to at least p. 
+ probs.masked_fill_(logits_masked.isinf(), 0) + masked_prob_sum = probs.sum(dim=-1) + assert torch.all(torch.ge(masked_prob_sum + TOLERANCE, p)) def test_topp_basic(): - with torch.device(DEVICE): + with torch.device(xm.xla_device()): logits = torch.tensor([[math.log(0.2), math.log(0.3), math.log(0.5)], @@ -64,7 +61,7 @@ def test_topp_basic(): def test_topp_select_all(): - with torch.device(DEVICE): + with torch.device(xm.xla_device()): logits = torch.tensor([[math.log(0.2), math.log(0.3), math.log(0.5)], @@ -80,7 +77,7 @@ def test_topp_select_all(): def test_topp_with_ties(): - with torch.device(DEVICE): + with torch.device(xm.xla_device()): # Input has multiple math.log(0.3). logits = torch.tensor( [[math.log(0.3), @@ -98,7 +95,7 @@ def test_topp_with_ties(): def test_both_topk_topp(): - with torch.device(DEVICE): + with torch.device(xm.xla_device()): logits = torch.tensor([[math.log(0.2), math.log(0.3), math.log(0.5)], From 85952b6f26fdfdbc0fa3fe069af1e23677cb23b5 Mon Sep 17 00:00:00 2001 From: Hyesoo Yang Date: Mon, 31 Mar 2025 23:35:46 +0000 Subject: [PATCH 4/7] Change tie-breaking behavior * Do not break ties. Instead, include all tied tokens in the return set and leave tie breaking to the final sampling stage (since all tie tokens will have equal probability of being chosen). * Removed random perturbation. * Removed warning regarding the algorithm being approx. * Edited tests. Signed-off-by: Hyesoo Yang --- tests/v1/tpu/test_topk_topp_sampler.py | 11 +++-- vllm/v1/sample/ops/topk_topp_sampler.py | 62 ++++++++----------------- 2 files changed, 27 insertions(+), 46 deletions(-) diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index fdee53d844bf..e3f29fc513dc 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -13,7 +13,7 @@ BATCH_SIZE = 1024 VOCAB_SIZE = 128 * 1024 -TOLERANCE = 1e-4 +TOLERANCE = 1e-6 def test_topp_result_sums_past_p(): @@ -89,9 +89,12 @@ def test_topp_with_ties(): k=torch.tensor([4]), p=torch.tensor([0.2])) - # Expect math.log(0.3) to be the only selected element. - expected_result = torch.tensor([math.log(0.3)]) - assert torch.allclose(expected_result, result[result.isfinite()]) + # All tie values are included in the top-p set. Tie breaking is left + # to be done during final sampling (all tie tokens have equal + # probability of being chosen). + expected_result = logits.clone() + expected_result[0, 3] = float("-inf") + assert torch.allclose(expected_result, result) def test_both_topk_topp(): diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index c413e87634fb..8041cccd2fcc 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -79,10 +79,6 @@ def __init__(self): "which could be very slow.") self.forward = self.forward_native else: - logger.info( - "Using approximate top-p optimized for TPU. 
Result may in " - "theory differ from the exact algorithm if there are " - "tokens with near-identical probabilities (< 1e-9 diff).") self.forward = self.forward_tpu else: self.forward = self.forward_native @@ -135,53 +131,35 @@ def apply_top_k_top_p_tpu( logits: torch.Tensor, k: torch.Tensor, p: torch.Tensor, -) -> torch.Tensor: - if k is not None: - logits = apply_top_k_only(logits, k) - - if p is not None: - logits = apply_approx_top_p(logits, p) - - return logits - - -def apply_approx_top_p( - logits: torch.Tensor, - p: torch.Tensor, ) -> torch.Tensor: """ - Apply approximate top-p that is optimized for TPU. + Apply top-k and top-p optimized for TPU. This algorithm avoids using torch.scatter which is extremely slow on TPU. This is achieved by finding a "cut-off" element in the original logit, and after thresholding the logit using this cut-off, the remaining elements shall constitute the top-p set. - A caveat of the above approach is that ties are not correctly handled -- - if there are duplicate cutoff elements present in the logit, then the - resulting top-p set will be incorrect. To address this problem, we - introduce a tiny perturbation to the probabilities (after softmax) to - break any potential ties. The added perturbation is tiny so it should - not alter the end results significantly, but it still makes this algorithm - approximate rather than an exact one. + Note: in the case of tie (i.e. multipple cut-off elements present in the + logit), all tie elements are included in the top-p set. In other words, + this function does not break ties. Instead, these tie tokens have equal + chance of being chosen during final sampling, so we can consider the tie + being broken then. """ - probs = logits.softmax(dim=-1) - - # Add a small, random perturbation to the probabilities, and re-normalize. - epsilon = torch.empty(probs.shape, - device=logits.device).uniform_(-1e-9, 1e-9) - probs += epsilon - probs /= probs.sum(dim=-1, keepdim=True) - - probs_sort, sorted_idx = probs.sort(dim=-1, descending=False) - cumprob = torch.cumsum(probs_sort, dim=-1) - top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) - top_p_mask[:, -1] = False # at least one - - top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) - top_p_cutoff = probs_sort.gather(-1, top_p_count) - elements_to_discard = probs < top_p_cutoff - logits.masked_fill_(elements_to_discard, -float("inf")) + if k is not None: + logits = apply_top_k_only(logits, k) + + if p is not None: + probs = logits.softmax(dim=-1) + probs_sort, _ = probs.sort(dim=-1, descending=False) + cumprob = torch.cumsum(probs_sort, dim=-1) + top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1) + top_p_mask[:, -1] = False # at least one + + top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1) + top_p_cutoff = probs_sort.gather(-1, top_p_count) + elements_to_discard = probs < top_p_cutoff + logits.masked_fill_(elements_to_discard, -float("inf")) return logits From 71d8cad22644aa061c94f194abe0e6067be54abd Mon Sep 17 00:00:00 2001 From: Hyesoo Yang Date: Wed, 2 Apr 2025 00:30:47 +0000 Subject: [PATCH 5/7] Update tests to perform assertion on CPU. 
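
xm.mark_step() is inserted before each assertion so that the pending
lazily-traced XLA computation is materialized before its result is
compared against the expected values.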
Signed-off-by: Hyesoo Yang --- tests/v1/tpu/test_topk_topp_sampler.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index e3f29fc513dc..eb236bc56e34 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -34,6 +34,8 @@ def test_topp_result_sums_past_p(): k=no_op_k, p=p) + xm.mark_step() + with torch.device("cpu"): # Verify that the masked logit's probability sums to at least p. probs.masked_fill_(logits_masked.isinf(), 0) masked_prob_sum = probs.sum(dim=-1) @@ -53,6 +55,8 @@ def test_topp_basic(): k=torch.tensor([3, 3]), p=torch.tensor([0.79, 0.79])) + xm.mark_step() + with torch.device("cpu"): # Expect the smallest elements to be dropped. expected_result = logits.clone() expected_result[0, 0] = float("-inf") @@ -73,6 +77,8 @@ def test_topp_select_all(): k=torch.tensor([3, 3]), p=torch.tensor([1.0, 1.0])) + xm.mark_step() + with torch.device("cpu"): assert torch.allclose(logits, result) @@ -89,6 +95,8 @@ def test_topp_with_ties(): k=torch.tensor([4]), p=torch.tensor([0.2])) + xm.mark_step() + with torch.device("cpu"): # All tie values are included in the top-p set. Tie breaking is left # to be done during final sampling (all tie tokens have equal # probability of being chosen). @@ -111,6 +119,8 @@ def test_both_topk_topp(): k=torch.tensor([1, 3]), p=torch.tensor([0.79, 0.79])) + xm.mark_step() + with torch.device("cpu"): # Since for the first batch k=1, expect only the largest element gets # selected. expected_result = logits.clone() From 78836ca7d10aaa59eb1c114299a197c2b0f27279 Mon Sep 17 00:00:00 2001 From: Hyesoo Yang Date: Wed, 2 Apr 2025 18:19:25 +0000 Subject: [PATCH 6/7] Update tests to use explicit .cpu() calls. Signed-off-by: Hyesoo Yang --- tests/v1/tpu/test_topk_topp_sampler.py | 54 +++++++++++++------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py index eb236bc56e34..dce0303e68d5 100644 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ b/tests/v1/tpu/test_topk_topp_sampler.py @@ -34,12 +34,14 @@ def test_topp_result_sums_past_p(): k=no_op_k, p=p) - xm.mark_step() - with torch.device("cpu"): # Verify that the masked logit's probability sums to at least p. probs.masked_fill_(logits_masked.isinf(), 0) masked_prob_sum = probs.sum(dim=-1) - assert torch.all(torch.ge(masked_prob_sum + TOLERANCE, p)) + + xm.mark_step() + + # Perform assertion on CPU. + assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu())) def test_topp_basic(): @@ -56,12 +58,12 @@ def test_topp_basic(): p=torch.tensor([0.79, 0.79])) xm.mark_step() - with torch.device("cpu"): - # Expect the smallest elements to be dropped. - expected_result = logits.clone() - expected_result[0, 0] = float("-inf") - expected_result[1, 1] = float("-inf") - assert torch.allclose(expected_result, result) + + # Expect the smallest elements to be dropped. 
+ expected_result = logits.clone().cpu() + expected_result[0, 0] = float("-inf") + expected_result[1, 1] = float("-inf") + assert torch.allclose(expected_result, result.cpu()) def test_topp_select_all(): @@ -78,8 +80,8 @@ def test_topp_select_all(): p=torch.tensor([1.0, 1.0])) xm.mark_step() - with torch.device("cpu"): - assert torch.allclose(logits, result) + + assert torch.allclose(logits.cpu(), result.cpu()) def test_topp_with_ties(): @@ -96,13 +98,13 @@ def test_topp_with_ties(): p=torch.tensor([0.2])) xm.mark_step() - with torch.device("cpu"): - # All tie values are included in the top-p set. Tie breaking is left - # to be done during final sampling (all tie tokens have equal - # probability of being chosen). - expected_result = logits.clone() - expected_result[0, 3] = float("-inf") - assert torch.allclose(expected_result, result) + + # All tie values are included in the top-p set. Tie breaking is left + # to be done during final sampling (all tie tokens have equal + # probability of being chosen). + expected_result = logits.clone().cpu() + expected_result[0, 3] = float("-inf") + assert torch.allclose(expected_result, result.cpu()) def test_both_topk_topp(): @@ -120,11 +122,11 @@ def test_both_topk_topp(): p=torch.tensor([0.79, 0.79])) xm.mark_step() - with torch.device("cpu"): - # Since for the first batch k=1, expect only the largest element gets - # selected. - expected_result = logits.clone() - expected_result[0, 0] = float("-inf") - expected_result[0, 1] = float("-inf") - expected_result[1, 1] = float("-inf") - assert torch.allclose(expected_result, result) + + # Since for the first batch k=1, expect only the largest element gets + # selected. + expected_result = logits.clone().cpu() + expected_result[0, 0] = float("-inf") + expected_result[0, 1] = float("-inf") + expected_result[1, 1] = float("-inf") + assert torch.allclose(expected_result, result.cpu()) From e055c96da156996a495549c87418ef6ccebcd5b1 Mon Sep 17 00:00:00 2001 From: Hyesoo Yang Date: Wed, 2 Apr 2025 19:55:55 +0000 Subject: [PATCH 7/7] Fix merge mistake Signed-off-by: Hyesoo Yang --- vllm/v1/sample/ops/topk_topp_sampler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 8041cccd2fcc..f69623edd632 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -224,8 +224,6 @@ def apply_top_k_only( max_top_k = k.max() # topk.values tensor has shape [batch_size, max_top_k]. # Convert top k to 0-based index in range [0, max_top_k). - k_index = k.sub_(1).unsqueeze(1) - top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long()) k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1) top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long()) # Handle non-topk rows.
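
For reference, the scatter-free masking that this series converges on can be
sketched as standalone PyTorch roughly as follows. This is a minimal sketch,
not the patched vLLM code itself: the function name is made up, the fast
paths of apply_top_k_only for rows with top-k disabled are omitted, and k and
p are assumed to be per-row tensors already living on the same device as
logits.

    import torch

    def mask_top_k_top_p_no_scatter(logits: torch.Tensor,
                                    k: torch.Tensor,
                                    p: torch.Tensor) -> torch.Tensor:
        # Top-k without scatter: the k-th largest value of each row is a
        # threshold, gathered from the sorted top-k values, so nothing ever
        # has to be scattered back into the original index order.
        max_top_k = int(k.max())
        k_index = (k - 1).unsqueeze(1).long()  # 0-based position of the cutoff
        top_k_cutoff = logits.topk(max_top_k, dim=-1).values.gather(1, k_index)
        logits = logits.masked_fill(logits < top_k_cutoff, float("-inf"))

        # Top-p without scatter: sort probabilities ascending, count how many
        # of the smallest ones can be dropped while the remaining mass still
        # covers p, gather that cutoff probability, and threshold the unsorted
        # probabilities against it. Ties at the cutoff are all kept, matching
        # the final (non-tie-breaking) behavior of the series.
        probs = logits.softmax(dim=-1)
        probs_sort, _ = probs.sort(dim=-1, descending=False)
        cumprob = torch.cumsum(probs_sort, dim=-1)
        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
        top_p_mask[:, -1] = False  # always keep at least one token
        top_p_count = top_p_mask.sum(dim=-1, keepdim=True)
        top_p_cutoff = probs_sort.gather(-1, top_p_count)
        return logits.masked_fill(probs < top_p_cutoff, float("-inf"))

The only ops involved are topk, sort, cumsum, gather and masked_fill, which is
the point of the change: the torch.scatter used by the generic
apply_top_k_top_p path never appears.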