diff --git a/vllm/v1/sample/ops/logprobs.py b/vllm/v1/sample/ops/logprobs.py index a4d65485140e..82875b7c8452 100644 --- a/vllm/v1/sample/ops/logprobs.py +++ b/vllm/v1/sample/ops/logprobs.py @@ -4,8 +4,10 @@ import torch +from vllm.platforms import current_platform -@torch.compile(dynamic=True) + +@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def batched_count_greater_than(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor: """