Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion vllm/v1/sample/ops/logprobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

import torch

from vllm.platforms import current_platform

@torch.compile(dynamic=True)

@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Evaluating current_platform.simple_compile_backend at module import time makes the backend choice static for the lifetime of the process. Consider using lazy compilation to allow backend selection based on runtime parameters. [1]

def _batched_count_greater_than_impl(x: torch.Tensor,
                                     values: torch.Tensor) -> torch.Tensor:
    """Count entries of `x` that are strictly greater than `values`.

    This is the plain (uncompiled) kernel that `batched_count_greater_than`
    wraps with `torch.compile`.

    Args:
        x: Tensor whose last dimension holds the elements to compare.
        values: Tensor of thresholds; a trailing axis is appended so that
            each threshold is compared against a whole row of `x`.

    Returns:
        An integer tensor with the last dimension of `x` reduced away,
        holding the number of elements that exceed the threshold.
    """
    # Append a trailing axis so the comparison broadcasts each threshold
    # across the last dimension of `x`.
    thresholds = values.unsqueeze(-1)
    exceeds_threshold = x > thresholds
    return exceeds_threshold.count_nonzero(dim=-1)

_cached_compiled_fn = None

def batched_count_greater_than(x: torch.Tensor,
                               values: torch.Tensor) -> torch.Tensor:
    """
    Count, per row of `x`, how many elements exceed that row's entry in
    `values`.

    The underlying kernel is compiled with `torch.compile` on the first call
    and the compiled callable is reused on every later call.

    Args:
        x: A 2D tensor of shape (num_rows, num_elements).
        values: A 1D tensor of shape (num_rows,).

    Returns:
        The result of the compiled `_batched_count_greater_than_impl`
        applied to `x` and `values`.
    """
    global _cached_compiled_fn

    compiled_fn = _cached_compiled_fn
    if compiled_fn is None:
        # Deferred import: the platform's compile backend is only looked up
        # when the function is first called, not at module import time.
        from vllm.platforms import current_platform

        compile_decorator = torch.compile(
            dynamic=True,
            backend=current_platform.simple_compile_backend,
        )
        compiled_fn = compile_decorator(_batched_count_greater_than_impl)
        # Remember the compiled callable so compilation is set up only once.
        _cached_compiled_fn = compiled_fn

    return compiled_fn(x, values)

Style Guide References

Footnotes

  1. Use lazy compilation to allow backend selection based on runtime parameters.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think using a static compile backend is enough for a specific platform.

def batched_count_greater_than(x: torch.Tensor,
values: torch.Tensor) -> torch.Tensor:
"""
Expand Down