3 changes: 2 additions & 1 deletion tests/v1/tpu/test_sampler.py
@@ -39,7 +39,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
     sampling_params = SamplingParams(
         temperature=0.7,
         # top_p=0.6, # TODO too slow!
-        # top_k=10,
+        top_k=10,
         min_p=0.2,
         max_tokens=16)
     s = time()
@@ -49,6 +49,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
     # Second request with different params, but for which we
     # compiled for in previous eager iteration.
     sampling_params = SamplingParams(temperature=0.1,
+                                     top_k=12,
                                      min_p=0.8,
                                      max_tokens=24)
     s = time()
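For context, a minimal end-to-end sketch (not part of this diff) of how sampling parameters like these are exercised through vLLM's offline LLM API; the model name and prompt below are placeholders:

from vllm import LLM, SamplingParams

# Hypothetical usage sketch; model and prompt are illustrative, not from the PR.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
sampling_params = SamplingParams(temperature=0.7, top_k=10, min_p=0.2,
                                 max_tokens=16)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)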
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -95,6 +95,7 @@
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
+    VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
Collaborator: I don't think we need this. Can you add a follow-up PR to remove it?

Contributor (author): Yes, I will send a follow-up PR to remove this once things look stable after merging.
 def get_default_cache_root():
@@ -623,6 +624,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # an environment with potentially malicious users.
     "VLLM_V0_USE_OUTLINES_CACHE":
     lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
+
+    # If set, disables TPU-specific optimization for top-k & top-p sampling
+    "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
+    lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
+    if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
 }

 # end-env-vars-definition
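A sketch (outside the diff) of the tri-state semantics the lambda above implements: unset reads as None, while "0"/"1" parse to False/True. The helper name here is hypothetical:

import os

FLAG = "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"

# read_flag is a hypothetical helper mirroring the lambda above:
# unset -> None, "0" -> False, "1" -> True (non-integer values would raise).
def read_flag():
    if FLAG not in os.environ:
        return None
    return bool(int(os.environ[FLAG]))

os.environ.pop(FLAG, None)
assert read_flag() is None
os.environ[FLAG] = "1"
assert read_flag() is True

In practice the flag would be set on the command line, e.g. VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION=1, before launching vLLM.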
24 changes: 21 additions & 3 deletions vllm/v1/sample/ops/topk_topp_sampler.py
@@ -66,7 +66,14 @@ def __init__(self):
"best performance, please install FlashInfer.")
self.forward = self.forward_native
elif current_platform.is_tpu():
self.forward = self.forward_tpu
if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
logger.warning(
"TPU-specific optimization for top-k & top-p sampling are "
"disabled, falling back to PyTorch-native implementation "
"which could be very slow.")
self.forward = self.forward_native
else:
self.forward = self.forward_tpu
else:
self.forward = self.forward_native

@@ -105,8 +112,19 @@ def forward_tpu(
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        # TODO Placeholder for TPU optimized topk/p kernel
-        # logits = apply_top_k_top_p(logits, k, p)
+        # If only top-k is specified, use pytorch's builtin topk op. This leads
+        # to significant speed up on TPU compared to using apply_top_k_top_p.
+        if k is not None and p is None:
+            topk_values, topk_indices = torch.topk(logits, k, dim=-1)
+
+            mask = torch.ones_like(logits, dtype=torch.bool)
+            mask.scatter_(-1, topk_indices, False)
+            logits.masked_fill_(mask, float('-inf'))
+        else:
+            # TODO Placeholder for TPU optimized topp kernel
+            # logits = apply_top_k_top_p(logits, k, p)
+            pass
+
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)

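To make the masking step above concrete, here is a self-contained sketch of the scatter-based top-k filter, using a scalar k for illustration:

import torch

# Toy batch of logits; keep only the top-2 entries per row.
logits = torch.tensor([[1.0, 3.0, 2.0, 0.5],
                       [0.1, 0.2, 4.0, 3.5]])
k = 2
_, topk_indices = torch.topk(logits, k, dim=-1)

# Start with everything masked, then un-mask the top-k positions.
mask = torch.ones_like(logits, dtype=torch.bool)
mask.scatter_(-1, topk_indices, False)
filtered = logits.masked_fill(mask, float('-inf'))
print(filtered)
# tensor([[  -inf, 3.0000, 2.0000,   -inf],
#         [  -inf,   -inf, 4.0000, 3.5000]])

Taking a softmax over the filtered logits then assigns zero probability to every position outside the top-k set, which is what forward_tpu needs before calling random_sample.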