diff --git a/vllm/v1/sample/ops/topk_topp_sampler.py b/vllm/v1/sample/ops/topk_topp_sampler.py index 33f7090e4e3d..dcae8f974f2f 100644 --- a/vllm/v1/sample/ops/topk_topp_sampler.py +++ b/vllm/v1/sample/ops/topk_topp_sampler.py @@ -248,7 +248,7 @@ def apply_top_k_top_p( if p is None and k is None: return logits - if HAS_TRITON and logits.shape[0] >= 8: + if HAS_TRITON and logits.shape[0] >= 8 and logits.is_cuda: return apply_top_k_top_p_triton(logits, k, p) # Use pytorch sort implementation for small batch sizes.