diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py
index 76b8ddb92b78..4e5a57bee327 100644
--- a/tests/v1/tpu/test_sampler.py
+++ b/tests/v1/tpu/test_sampler.py
@@ -39,7 +39,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
     sampling_params = SamplingParams(
         temperature=0.7,
         # top_p=0.6, # TODO too slow!
-        # top_k=10,
+        top_k=10,
         min_p=0.2,
         max_tokens=16)
     s = time()
@@ -49,6 +49,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
     # Second request with different params, but for which we
     # compiled for in previous eager iteration.
     sampling_params = SamplingParams(temperature=0.1,
+                                     top_k=12,
                                      min_p=0.8,
                                      max_tokens=24)
     s = time()
diff --git a/vllm/envs.py b/vllm/envs.py
index 56bf86267476..d88ab3b5e7d0 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -95,6 +95,7 @@
     VLLM_DP_MASTER_PORT: int = 0
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
+    VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
 
 
 def get_default_cache_root():
@@ -623,6 +624,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     # an environment with potentially malicious users.
     "VLLM_V0_USE_OUTLINES_CACHE":
     lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
+
+    # If set, disables the TPU-specific optimization for top-k & top-p sampling
+    "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
+    lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
+    if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
 }
 
 # end-env-vars-definition
diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py
index e1a3e92de493..1dea711874bf 100644
--- a/vllm/v1/sample/ops/topk_topp_sampler.py
+++ b/vllm/v1/sample/ops/topk_topp_sampler.py
@@ -66,7 +66,14 @@ def __init__(self):
                 "best performance, please install FlashInfer.")
             self.forward = self.forward_native
         elif current_platform.is_tpu():
-            self.forward = self.forward_tpu
+            if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
+                logger.warning(
+                    "TPU-specific optimizations for top-k & top-p sampling "
+                    "are disabled, falling back to the PyTorch-native "
+                    "implementation, which can be very slow.")
+                self.forward = self.forward_native
+            else:
+                self.forward = self.forward_tpu
         else:
             self.forward = self.forward_native
 
@@ -105,8 +112,19 @@ def forward_tpu(
         k: Optional[torch.Tensor],
         p: Optional[torch.Tensor],
     ) -> torch.Tensor:
-        # TODO Placeholder for TPU optimized topk/p kernel
-        # logits = apply_top_k_top_p(logits, k, p)
+        # If only top-k is specified, use PyTorch's built-in topk op. This
+        # gives a significant speedup on TPU compared to apply_top_k_top_p.
+        if k is not None and p is None:
+            topk_values, topk_indices = torch.topk(logits, k, dim=-1)
+
+            mask = torch.ones_like(logits, dtype=torch.bool)
+            mask.scatter_(-1, topk_indices, False)
+            logits.masked_fill_(mask, float('-inf'))
+        else:
+            # TODO Placeholder for TPU-optimized top-p kernel
+            # logits = apply_top_k_top_p(logits, k, p)
+            pass
+
         probs = logits.softmax(dim=-1, dtype=torch.float32)
         return random_sample(probs, generators)
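
For reference, here is a minimal, self-contained sketch of the masking trick that `forward_tpu` now uses on the top-k-only path. It is illustrative only: `mask_to_top_k` is a hypothetical helper, and `k` is a plain int here for simplicity, whereas the sampler receives it as an `Optional[torch.Tensor]`.

```python
import torch


def mask_to_top_k(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest logits per row; push everything else to -inf."""
    _, topk_indices = torch.topk(logits, k, dim=-1)
    # Start fully masked, then unmask the top-k positions in each row.
    mask = torch.ones_like(logits, dtype=torch.bool)
    mask.scatter_(-1, topk_indices, False)
    return logits.masked_fill(mask, float('-inf'))


logits = torch.randn(2, 8)  # (batch, vocab)
probs = mask_to_top_k(logits, k=3).softmax(dim=-1)
# exp(-inf) == 0, so exactly k entries per row keep nonzero probability.
assert (probs > 0).sum(dim=-1).eq(3).all()
```

If the fast path ever misbehaves, setting `VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION=1` routes sampling back to `forward_native`, as wired up in `__init__` above.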