Merged
Changes from 6 commits
4 changes: 3 additions & 1 deletion .buildkite/run-tpu-v1-test.sh
@@ -36,7 +36,9 @@ docker run --privileged --net host --shm-size=16G -it \
&& echo TEST_6 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
&& echo TEST_7 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \
&& echo TEST_8 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py" \


# TODO: This test fails because it uses RANDOM_SEED sampling
132 changes: 132 additions & 0 deletions tests/v1/tpu/test_topk_topp_sampler.py
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
Collaborator
can we add this test to run-tpu-v1-test.sh?

Contributor Author
Done.

import math

import pytest
import torch

from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu

if not current_platform.is_tpu():
pytest.skip("This test needs a TPU.", allow_module_level=True)
import torch_xla.core.xla_model as xm

BATCH_SIZE = 1024
VOCAB_SIZE = 128 * 1024
TOLERANCE = 1e-6


def test_topp_result_sums_past_p():
with torch.device(xm.xla_device()):
xm.set_rng_state(seed=33)

logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
probs = logits.softmax(dim=-1)

# Random top-p values between 0 and 1.
p = torch.rand((BATCH_SIZE, ))

# Set p=1 for ~50% of requests in the batch (top-p disabled).
p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), 1)

no_op_k = torch.tensor([VOCAB_SIZE])
logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(),
k=no_op_k,
p=p)

# Verify that the masked logit's probability sums to at least p.
probs.masked_fill_(logits_masked.isinf(), 0)
masked_prob_sum = probs.sum(dim=-1)
Collaborator
Usually it's a good idea to put an xm.mark_step before the assert to clear the graph.

Contributor Author
Done.


xm.mark_step()

# Perform assertion on CPU.
assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu()))


def test_topp_basic():
with torch.device(xm.xla_device()):
logits = torch.tensor([[math.log(0.2),
math.log(0.3),
math.log(0.5)],
[math.log(0.5),
math.log(0.1),
math.log(0.4)]])

result = apply_top_k_top_p_tpu(logits=logits.clone(),
k=torch.tensor([3, 3]),
p=torch.tensor([0.79, 0.79]))

xm.mark_step()

# Expect the smallest elements to be dropped.
expected_result = logits.clone().cpu()
expected_result[0, 0] = float("-inf")
expected_result[1, 1] = float("-inf")
assert torch.allclose(expected_result, result.cpu())


def test_topp_select_all():
with torch.device(xm.xla_device()):
logits = torch.tensor([[math.log(0.2),
math.log(0.3),
math.log(0.5)],
[math.log(0.5),
math.log(0.1),
math.log(0.4)]])

result = apply_top_k_top_p_tpu(logits=logits.clone(),
k=torch.tensor([3, 3]),
p=torch.tensor([1.0, 1.0]))

xm.mark_step()

assert torch.allclose(logits.cpu(), result.cpu())


def test_topp_with_ties():
with torch.device(xm.xla_device()):
# Input has multiple math.log(0.3).
logits = torch.tensor(
[[math.log(0.3),
math.log(0.3),
math.log(0.3),
math.log(0.1)]])

result = apply_top_k_top_p_tpu(logits=logits.clone(),
k=torch.tensor([4]),
p=torch.tensor([0.2]))

xm.mark_step()

# All tie values are included in the top-p set. Tie breaking is left
# to be done during final sampling (all tie tokens have equal
# probability of being chosen).
expected_result = logits.clone().cpu()
expected_result[0, 3] = float("-inf")
assert torch.allclose(expected_result, result.cpu())


def test_both_topk_topp():
with torch.device(xm.xla_device()):
logits = torch.tensor([[math.log(0.2),
math.log(0.3),
math.log(0.5)],
[math.log(0.5),
math.log(0.1),
math.log(0.4)]])

# Set k=1 for the first batch.
result = apply_top_k_top_p_tpu(logits=logits.clone(),
k=torch.tensor([1, 3]),
p=torch.tensor([0.79, 0.79]))

xm.mark_step()

# Since for the first batch k=1, expect only the largest element gets
# selected.
expected_result = logits.clone().cpu()
expected_result[0, 0] = float("-inf")
expected_result[0, 1] = float("-inf")
expected_result[1, 1] = float("-inf")
assert torch.allclose(expected_result, result.cpu())
53 changes: 40 additions & 13 deletions vllm/v1/sample/ops/topk_topp_sampler.py
@@ -122,23 +122,48 @@ def forward_tpu(
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
# If only top-k is specified, use pytorch's builtin topk op. This leads
# to significant speed up on TPU compared to using apply_top_k_top_p.
if k is not None and p is None:
topk_values, topk_indices = torch.topk(logits, k, dim=-1)

mask = torch.ones_like(logits, dtype=torch.bool)
mask.scatter_(-1, topk_indices, False)
logits.masked_fill_(mask, float('-inf'))
else:
# TODO Placeholder for TPU optimized topp kernel
# logits = apply_top_k_top_p(logits, k, p)
pass

logits = apply_top_k_top_p_tpu(logits, k, p)

Thanks @hyeygit, this seems reasonable to me in the interim but I'll let the other folks chime in on the appropriateness. Cc @yaochengji @yarongmu-google

Given this can slightly impact the generated output during ties, this really feels like something we should be warning the user about. Not every time the function is called, of course, but at a minimum, we should be warning users when the argument is set to anything other than the default. I couldn't find a warning log but I'm also on my phone, so apologies if I just missed it.

Contributor Author
Thanks for the review @brittrock. I added a one-time log message about this algorithm being approximate in theory.

In practice, I think the tiny 1e-9 probability perturbation doesn't alter the result in any meaningful way. The only situation where the output differs from the exact algorithm is if there are multiple tokens whose probabilities are within 1e-9 (one in a billion) of each other. This means they practically have the same probability, so including either one of them in the top-p set should be acceptable.

Contributor
thanks for adding!

I agree in practice it's probably ok, but this could break accuracy tests, so it's a good idea to include it in any case.

nice job, again!

Contributor
I signed in from my phone and must have created another GitHub account >.<

ignore my alter ego's request for review @hyeygit 😆

Contributor Author (@hyeygit, Mar 31, 2025)
> I agree in practice it's probably ok, but this could break accuracy tests, so it's a good idea to include it in any case.

Agreed, makes sense!

> ignore my alter ego's request for review @hyeygit 😆

Haha no worries!
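The one-time log message referenced above does not appear in the hunks shown here; a minimal sketch of such a warning, using only the standard logging module (the function name, guard variable, message wording, and placement are assumptions, not taken from the PR), could look like this:

```python
import logging

logger = logging.getLogger(__name__)
_warned_tpu_top_p_approx = False  # hypothetical module-level guard


def _warn_tpu_top_p_approx_once() -> None:
    """Warn once that TPU top-p uses an approximate cut-off (sketch only)."""
    global _warned_tpu_top_p_approx
    if not _warned_tpu_top_p_approx:
        logger.warning(
            "TPU top-p sampling uses an approximate cut-off; tokens whose "
            "probabilities differ by less than ~1e-9 may be treated as ties.")
        _warned_tpu_top_p_approx = True
```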

probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)


def apply_top_k_top_p_tpu(
logits: torch.Tensor,
k: torch.Tensor,
p: torch.Tensor,
) -> torch.Tensor:
"""
Apply top-k and top-p optimized for TPU.

This algorithm avoids using torch.scatter, which is extremely slow on TPU.
Instead, it finds a "cut-off" element in the original logits; after
thresholding the logits at this cut-off, the remaining elements
constitute the top-p set.

Note: in the case of a tie (i.e. multiple cut-off elements present in the
logits), all tied elements are included in the top-p set. In other words,
this function does not break ties. Instead, the tied tokens have an equal
chance of being chosen during final sampling, so the tie is effectively
broken at that stage.
"""
if k is not None:
logits = apply_top_k_only(logits, k)

if p is not None:
probs = logits.softmax(dim=-1)
probs_sort, _ = probs.sort(dim=-1, descending=False)
cumprob = torch.cumsum(probs_sort, dim=-1)
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False # at least one

top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
top_p_cutoff = probs_sort.gather(-1, top_p_count)
elements_to_discard = probs < top_p_cutoff
logits.masked_fill_(elements_to_discard, -float("inf"))

return logits
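Traced on a tiny CPU tensor, the cut-off logic behaves as follows. This is an illustrative sketch only, not part of the diff; the values are borrowed from test_topp_basic above.

```python
import math

import torch

# One row of logits whose softmax is exactly [0.2, 0.3, 0.5], with p = 0.79.
logits = torch.tensor([[math.log(0.2), math.log(0.3), math.log(0.5)]])
p = torch.tensor([0.79])

probs = logits.softmax(dim=-1)                         # [[0.2, 0.3, 0.5]]
probs_sort, _ = probs.sort(dim=-1, descending=False)   # ascending order
cumprob = torch.cumsum(probs_sort, dim=-1)             # [[0.2, 0.5, 1.0]]

# Elements whose cumulative mass fits inside the bottom (1 - p) fall outside
# the top-p set; the last column is forced to False so at least one survives.
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)         # [[True, False, False]]
top_p_mask[:, -1] = False

# The cut-off is the smallest probability that survives (0.3 here).
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)      # [[1]]
top_p_cutoff = probs_sort.gather(-1, top_p_count)      # [[0.3]]

# Everything strictly below the cut-off is discarded; no torch.scatter needed.
masked = logits.masked_fill(probs < top_p_cutoff, -float("inf"))
print(masked)  # only log(0.2) becomes -inf, matching test_topp_basic row 0
```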


def apply_top_k_top_p(
logits: torch.Tensor,
k: Optional[torch.Tensor],
@@ -201,6 +226,8 @@ def apply_top_k_only(
# Convert top k to 0-based index in range [0, max_top_k).
k_index = k.sub_(1).unsqueeze(1)
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
k_index = k.sub_(1).unsqueeze(1).expand(logits.shape[0], 1)
Member
@hyeygit is this because of a TPU torch broadcasting limitation?

Contributor Author (@hyeygit, Apr 2, 2025)
Yes I think so. Without the explicit expand this fails on XLA due to shape mismatch.

top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
# Handle non-topk rows.
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
logits.masked_fill_(logits < top_k_mask, -float("inf"))
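On the broadcasting question in the thread above, the index shapes involved can be checked with a small CPU-only sketch. The claim that XLA needs the explicit expand comes from the review discussion and is not verified here; all names and values below are illustrative.

```python
import torch

torch.manual_seed(0)
batch, vocab, max_top_k = 4, 16, 5
logits = torch.randn(batch, vocab)
k = torch.tensor([1, 3, 5, 2])  # per-row top-k values

# 0-based index of each row's k-th largest value among the top max_top_k values.
k_index = k.sub(1).unsqueeze(1)     # shape (batch, 1)
k_index = k_index.expand(batch, 1)  # explicit shape; a no-op on CPU, but it
                                    # spells out what broadcasting would infer

top_vals = logits.topk(max_top_k, dim=1).values    # shape (batch, max_top_k)
top_k_mask = top_vals.gather(1, k_index.long())    # per-row threshold, shape (batch, 1)

masked = logits.masked_fill(logits < top_k_mask, -float("inf"))
assert (masked > -float("inf")).sum(dim=1).tolist() == k.tolist()
```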