Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def load(
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable | None:
) -> Callable[..., Any] | None:
if (compile_range, graph_index, self.compiler.name) not in self.cache:
return None
handle = self.cache[(compile_range, graph_index, self.compiler.name)]
Expand All @@ -199,7 +199,7 @@ def compile(
self,
graph: fx.GraphModule,
example_inputs: list[Any],
additional_inductor_config,
additional_inductor_config: dict[str, Any],
compilation_config: CompilationConfig,
compile_range: Range,
graph_index: int = 0,
Expand Down Expand Up @@ -355,7 +355,7 @@ def split_graph(
compilation_start_time = 0.0


class PiecewiseCompileInterpreter(torch.fx.Interpreter):
class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given graph with fake inputs, and compile some
submodules specified by `compile_submod_names` with the given
Expand Down Expand Up @@ -506,9 +506,9 @@ class VllmBackend:
# the stiching graph module for all the piecewise graphs
split_gm: fx.GraphModule
piecewise_graphs: list[SplitItem]
returned_callable: Callable
returned_callable: Callable[..., Any]
# Inductor passes to run on the graph pre-defunctionalization
post_grad_passes: Sequence[Callable]
post_grad_passes: Sequence[Callable[..., Any]]
sym_tensor_indices: list[int]
input_buffers: list[torch.Tensor]
compiler_manager: CompilerManager
Expand Down Expand Up @@ -821,7 +821,7 @@ def __call__(
]

# this is the callable we return to Dynamo to run
def copy_and_call(*args):
def copy_and_call(*args: Any) -> Any:
list_args = list(args)
for i, index in enumerate(self.sym_tensor_indices):
runtime_tensor = list_args[index]
Expand Down
23 changes: 16 additions & 7 deletions vllm/compilation/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import inspect
import os
import pickle
from collections.abc import Callable, Sequence
from typing import Any, Literal
from unittest.mock import patch

import torch
Expand All @@ -25,7 +27,7 @@
logger = init_logger(__name__)


class VllmSerializableFunction(SerializableCallable):
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
"""
A wrapper around a compiled function by vllm. It will forward the tensor
inputs to the compiled function and return the result.
Expand All @@ -38,8 +40,13 @@ class VllmSerializableFunction(SerializableCallable):
"""

def __init__(
self, graph_module, example_inputs, prefix, optimized_call, is_encoder=False
):
self,
graph_module: torch.fx.GraphModule,
example_inputs: Sequence[Any],
prefix: str,
optimized_call: Callable[..., Any],
is_encoder: bool = False,
) -> None:
assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module
self.example_inputs = example_inputs
Expand All @@ -53,7 +60,7 @@ def __init__(
if sym_input is not None:
self.shape_env = sym_input.node.shape_env

def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.optimized_call(*args, **kwargs)

@classmethod
Expand All @@ -73,7 +80,9 @@ def serialize_compile_artifacts(

graph_reducer_override = GraphPickler.reducer_override

def _graph_reducer_override(self, obj):
def _graph_reducer_override(
self: GraphPickler, obj: Any
) -> tuple[Callable[..., Any], tuple[Any, ...]] | Any:
if (
inspect.isclass(obj)
and issubclass(obj, sympy.Function)
Expand Down Expand Up @@ -114,7 +123,7 @@ def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction
get_current_vllm_config(), state["prefix"], is_encoder
)

def optimized_call(*example_inputs):
def optimized_call(*example_inputs: Any) -> Any:
"""
On the first run of the optimized call, we rerun the compiler
backend which should result in a cache hit. After the backend
Expand All @@ -136,7 +145,7 @@ def optimized_call(*example_inputs):
return fn

@property
def co_name(self):
def co_name(self) -> Literal["VllmSerializableFunction"]:
"""
Used for depyf debugging.
"""
Expand Down
22 changes: 12 additions & 10 deletions vllm/compilation/cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ class CUDAGraphLogging:
"Count",
]

def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None):
def __init__(
self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None
) -> None:
self.reset()
self.cg_mode = str(cg_mode)
self.cg_capture_sizes = str(cg_capture_sizes or [])
Expand All @@ -54,10 +56,10 @@ def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None):
"**CUDAGraph Stats:**\n\n"
)

def reset(self):
self.stats = []
def reset(self) -> None:
self.stats: list[CUDAGraphStat] = []

def observe(self, cudagraph_stat: CUDAGraphStat):
def observe(self, cudagraph_stat: CUDAGraphStat) -> None:
self.stats.append(cudagraph_stat)

def generate_metric_table(self) -> str:
Expand Down Expand Up @@ -109,7 +111,7 @@ def generate_metric_table(self) -> str:
+ "\n"
)

def log(self, log_fn=logger.info):
def log(self, log_fn: Callable[..., Any] = logger.info) -> None:
if not self.stats:
return
log_fn(self.generate_metric_table())
Expand Down Expand Up @@ -161,11 +163,11 @@ class CUDAGraphWrapper:

def __init__(
self,
runnable: Callable,
runnable: Callable[..., Any],
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
cudagraph_options: CUDAGraphOptions | None = None,
):
) -> None:
self.runnable = runnable
self.vllm_config = vllm_config
self.runtime_mode = runtime_mode
Expand All @@ -189,7 +191,7 @@ def __init__(
# cudagraphs for.
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}

def __getattr__(self, key: str):
def __getattr__(self, key: str) -> Any:
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
return getattr(self.runnable, key)
Expand All @@ -198,11 +200,11 @@ def __getattr__(self, key: str):
f"cudagraph wrapper: {self.runnable}"
)

def unwrap(self) -> Callable:
def unwrap(self) -> Callable[..., Any]:
# in case we need to access the original runnable.
return self.runnable

def __call__(self, *args, **kwargs):
def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
Expand Down
62 changes: 42 additions & 20 deletions vllm/compilation/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import inspect
import os
import sys
from collections.abc import Callable
from typing import TypeVar, overload
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
from unittest.mock import patch

import torch
Expand All @@ -32,6 +32,14 @@

from .monitor import start_monitoring_torch_compile

if TYPE_CHECKING:
# Only added on nightly/2.10 so wrap
try:
from torch._dynamo.package import SourceInfo
except ImportError:
# Fallback for old versions not supporting
SourceInfo = Any

logger = init_logger(__name__)

IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
Expand Down Expand Up @@ -59,7 +67,7 @@ def ignore_torch_compile(cls: _T) -> _T:
return cls


def _should_ignore_torch_compile(cls) -> bool:
def _should_ignore_torch_compile(cls: _T) -> bool:
"""
Check if the class should be ignored for torch.compile.
"""
Expand Down Expand Up @@ -224,7 +232,7 @@ def cls_decorator_helper(cls: _T) -> _T:
return cls_decorator_helper


def _model_hash_key(fn) -> str:
def _model_hash_key(fn: Callable[..., Any]) -> str:
import vllm

sha256_hash = hashlib.sha256()
Expand All @@ -234,7 +242,9 @@ def _model_hash_key(fn) -> str:
return sha256_hash.hexdigest()


def _verify_source_unchanged(source_info, vllm_config) -> None:
def _verify_source_unchanged(
source_info: "SourceInfo", vllm_config: VllmConfig
) -> None:
from .caching import _compute_code_hash, _compute_code_hash_with_content

file_contents = {}
Expand Down Expand Up @@ -275,8 +285,12 @@ def _support_torch_compile(
setattr(cls, IGNORE_COMPILE_KEY, False)

def __init__(
self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs
):
self: _T,
*,
vllm_config: VllmConfig | None = None,
prefix: str = "",
**kwargs: Any,
) -> None:
if vllm_config is None:
vllm_config = get_current_vllm_config()

Expand Down Expand Up @@ -309,13 +323,17 @@ def __init__(

compilation_counter.num_models_seen += 1
self.compiled = False
TorchCompileWithNoGuardsWrapper.__init__(self)

# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
TorchCompileWithNoGuardsWrapper.__init__(self) # type: ignore[arg-type]

cls.__init__ = __init__

def _mark_dynamic_inputs(mod, type, *args, **kwargs):
def mark_dynamic(arg, dims):
if type == DynamicShapesType.UNBACKED:
def _mark_dynamic_inputs(
mod: _T, ds_type: DynamicShapesType, *args: Any, **kwargs: Any
) -> None:
def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
if ds_type == DynamicShapesType.UNBACKED:
if is_torch_equal_or_newer("2.10.0.dev"):
for dim in dims:
torch._dynamo.decorators.mark_unbacked(
Expand All @@ -326,7 +344,7 @@ def mark_dynamic(arg, dims):
else:
torch._dynamo.mark_dynamic(arg, dims)

sig = inspect.signature(mod.__class__.forward)
sig = inspect.signature(mod.__class__.forward) # type: ignore[attr-defined]
bound_args = sig.bind(mod, *args, **kwargs)
bound_args.apply_defaults()
for k, dims in dynamic_arg_dims.items():
Expand Down Expand Up @@ -364,7 +382,7 @@ def mark_dynamic(arg, dims):
else:
torch._dynamo.decorators.mark_unbacked(arg, dims)

def __call__(self, *args, **kwargs):
def __call__(self: _T, *args: Any, **kwargs: Any) -> Any:
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
Expand Down Expand Up @@ -444,7 +462,7 @@ def __call__(self, *args, **kwargs):
not envs.VLLM_USE_AOT_COMPILE
or self.vllm_config.compilation_config.backend == "eager"
)
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]

# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
Expand Down Expand Up @@ -477,7 +495,7 @@ def __call__(self, *args, **kwargs):
# during Dynamo tracing, and their corresponding files
inline_call = InliningInstructionTranslator.inline_call_

def patched_inline_call(self_):
def patched_inline_call(self_: Any) -> Any:
code = self_.f_code
self.compilation_config.traced_files.add(code.co_filename)
return inline_call(self_)
Expand Down Expand Up @@ -535,7 +553,7 @@ def patched_inline_call(self_):
str(e),
)
else:
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]

self.compiled = True
return output
Expand All @@ -545,7 +563,9 @@ def patched_inline_call(self_):


@contextlib.contextmanager
def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
def maybe_use_cudagraph_partition_wrapper(
vllm_config: VllmConfig,
) -> Generator[None, None, None]:
"""
Context manager to set/unset customized cudagraph partition wrappers.

Expand All @@ -572,7 +592,9 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig):
current_platform.get_static_graph_wrapper_cls()
)

def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):
def customized_cudagraph_wrapper(
f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata
) -> Any:
partition_id = metadata.partition_index
num_partitions = metadata.num_partitions
return static_graph_wrapper_class(
Expand Down Expand Up @@ -600,7 +622,7 @@ def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata):


@contextlib.contextmanager
def _torch27_patch_tensor_subclasses():
def _torch27_patch_tensor_subclasses() -> Generator[None, None, None]:
"""
Add support for using tensor subclasses (ie `BasevLLMParameter`, ect) when
using torch 2.7.0. This enables using weight_loader_v2 and the use of
Expand All @@ -614,7 +636,7 @@ def _torch27_patch_tensor_subclasses():
_ColumnvLLMParameter,
)

def return_false(*args, **kwargs):
def return_false(*args: Any, **kwargs: Any) -> Literal[False]:
return False

if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"):
Expand Down
Loading