Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import pprint
import time
from collections.abc import Callable, Sequence
from collections.abc import Callable, Generator, Sequence
from contextlib import contextmanager
from copy import deepcopy
from functools import partial
Expand Down Expand Up @@ -90,7 +90,7 @@ class CompilerManager:
support int as key.
"""

def __init__(self, compilation_config: CompilationConfig):
def __init__(self, compilation_config: CompilationConfig) -> None:
self.cache: dict[tuple[Range, int, str], Any] = dict()
self.is_cache_updated = False
self.compilation_config = compilation_config
Expand All @@ -100,7 +100,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str:
return self.compiler.compute_hash(vllm_config)

@contextmanager
def compile_context(self, compile_range: Range):
def compile_context(self, compile_range: Range) -> Generator[None, None, None]:
"""Provide compilation context for the duration of compilation to set
any torch global properties we want to scope to a single Inductor
compilation (e.g. partition rules, pass context)."""
Expand All @@ -115,7 +115,7 @@ def compile_context(self, compile_range: Range):

def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
) -> None:
"""
Initialize the cache directory for the compiler.

Expand Down Expand Up @@ -143,7 +143,7 @@ def initialize_cache(
# do not use eval(), it is unsafe.
cache = ast.literal_eval(f.read())

def check_type(value, ty):
def check_type(value: Any, ty: type) -> None:
if not isinstance(value, ty):
raise TypeError(f"Expected {ty} but got {type(value)} for {value}")

Expand All @@ -165,7 +165,7 @@ def parse_key(key: Any) -> tuple[Range, int, str]:
cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix
)

def save_to_file(self):
def save_to_file(self) -> None:
if self.disable_cache or not self.is_cache_updated:
return
printer = pprint.PrettyPrinter(indent=4)
Expand Down Expand Up @@ -198,7 +198,7 @@ def load(
def compile(
self,
graph: fx.GraphModule,
example_inputs,
example_inputs: list[Any],
additional_inductor_config,
compilation_config: CompilationConfig,
compile_range: Range,
Expand Down Expand Up @@ -373,7 +373,7 @@ def __init__(
compile_submod_names: list[str],
vllm_config: VllmConfig,
vllm_backend: "VllmBackend",
):
) -> None:
super().__init__(module)
from torch._guards import detect_fake_mode

Expand All @@ -385,7 +385,7 @@ def __init__(
# When True, it annoyingly dumps the torch.fx.Graph on errors.
self.extra_traceback = False

def run(self, *args):
def run(self, *args: Any) -> Any:
# maybe instead just assert inputs are fake?
fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
Expand Down Expand Up @@ -467,7 +467,7 @@ def call_module(


@contextmanager
def set_model_tag(tag: str, is_encoder: bool = False):
def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
"""Context manager to set the model tag."""
global model_tag
global model_is_encoder
Expand Down Expand Up @@ -521,7 +521,7 @@ def __init__(
vllm_config: VllmConfig,
prefix: str = "",
is_encoder: bool = False,
):
) -> None:
# if the model is initialized with a non-empty prefix,
# then usually it's enough to use that prefix,
# e.g. language_model, vision_model, etc.
Expand Down Expand Up @@ -558,7 +558,7 @@ def __init__(
# `torch.compile` is JIT compiled, so we don't need to
# do anything here

def configure_post_pass(self):
def configure_post_pass(self) -> None:
self.pass_manager.configure(self.vllm_config)

# Post-grad custom passes are run using the post_grad_custom_post_pass
Expand All @@ -580,7 +580,7 @@ def configure_post_pass(self):
self.inductor_config[self.pass_key] = self.pass_manager

def __call__(
self, graph: fx.GraphModule, example_inputs
self, graph: fx.GraphModule, example_inputs: Sequence[Any]
) -> VllmSerializableFunction:
vllm_config = self.vllm_config
# Minimal hashing here with existing utilities, reused below.
Expand Down
14 changes: 7 additions & 7 deletions vllm/compilation/collective_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@


class BasePattern:
def __init__(self, dtype: torch.dtype, device: str):
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
self.dtype = dtype
self.device = device
self.tp = get_tp_group()
Expand Down Expand Up @@ -637,7 +637,7 @@ def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
super().__init__(dtype, device)
Expand Down Expand Up @@ -692,7 +692,7 @@ def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
super().__init__(dtype, device)
Expand Down Expand Up @@ -759,7 +759,7 @@ def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
super().__init__(dtype, device)
Expand Down Expand Up @@ -828,7 +828,7 @@ def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
super().__init__(dtype, device)
Expand Down Expand Up @@ -902,7 +902,7 @@ def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
super().__init__(dtype, device)
Expand Down Expand Up @@ -988,7 +988,7 @@ def __init__(
self,
epsilon: float,
dtype: torch.dtype,
device: str,
device: str | None,
allreduce_params: FlashInferFusedAllReduceParams,
):
super().__init__(dtype, device)
Expand Down
48 changes: 24 additions & 24 deletions vllm/compilation/compiler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class CompilerInterface:

def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
) -> None:
"""
when the vLLM process uses `cache_dir` as the cache directory,
the compiler should initialize itself with the cache directory,
Expand Down Expand Up @@ -66,7 +66,7 @@ def compile(
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
) -> tuple[Callable[..., Any] | None, Any | None]:
"""
Compile the graph with the given example inputs and compiler config,
with a range. The `compile_range` specifies the range of the inputs,
Expand Down Expand Up @@ -100,7 +100,7 @@ def load(
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable:
) -> Callable[..., Any]:
"""
Load the compiled function from the handle.
Raises an error if the handle is invalid.
Expand Down Expand Up @@ -138,13 +138,13 @@ class AlwaysHitShapeEnv:
def __init__(self) -> None:
self.guards: list[Any] = []

def evaluate_guards_expression(self, *args, **kwargs):
def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]:
return True

def get_pruned_guards(self, *args, **kwargs):
def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]:
return []

def produce_guards_expression(self, *args, **kwargs):
def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]:
return ""


Expand Down Expand Up @@ -193,7 +193,7 @@ class InductorStandaloneAdaptor(CompilerInterface):

name = "inductor_standalone"

def __init__(self, save_format: Literal["binary", "unpacked"]):
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
self.save_format = save_format

def compute_hash(self, vllm_config: VllmConfig) -> str:
Expand All @@ -205,7 +205,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str:

def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
) -> None:
self.cache_dir = cache_dir

def compile(
Expand All @@ -215,7 +215,7 @@ def compile(
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
current_config = {}
if compiler_config is not None:
Expand Down Expand Up @@ -252,7 +252,7 @@ def load(
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable:
) -> Callable[..., Any]:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
Expand All @@ -264,7 +264,7 @@ def load(

returns_tuple = graph_returns_tuple(graph)

def compiled_graph_wrapper(*args):
def compiled_graph_wrapper(*args: Any) -> tuple[Any, ...] | Any:
graph_output = inductor_compiled_graph(*args)
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
Expand Down Expand Up @@ -293,7 +293,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str:

def initialize_cache(
self, cache_dir: str, disable_cache: bool = False, prefix: str = ""
):
) -> None:
self.cache_dir = cache_dir
self.prefix = prefix
self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir
Expand All @@ -317,7 +317,7 @@ def compile(
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_inductor_compiles += 1
from torch._inductor.compile_fx import compile_fx

Expand Down Expand Up @@ -348,7 +348,7 @@ def compile(
original_load = FxGraphCache.load
original_load_name = "torch._inductor.codecache.FxGraphCache.load"

def hijack_load(*args, **kwargs):
def hijack_load(*args: Any, **kwargs: Any) -> Any:
inductor_compiled_graph = original_load(*args, **kwargs)
nonlocal file_path
compiled_fn = inductor_compiled_graph.current_callable
Expand All @@ -375,7 +375,7 @@ def hijack_load(*args, **kwargs):
# function renamed in 2.6
original_load_name = None

def hijacked_compile_fx_inner(*args, **kwargs):
def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any:
output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs)
nonlocal hash_str
inductor_compiled_graph = output
Expand All @@ -401,13 +401,13 @@ def hijacked_compile_fx_inner(*args, **kwargs):
hash_str = inductor_compiled_graph._fx_graph_cache_key
return output

def hijack_compiled_fx_graph_hash(*args, **kwargs):
def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any:
out = compiled_fx_graph_hash(*args, **kwargs)
nonlocal hash_str
hash_str = out[0]
return out

def _check_can_cache(*args, **kwargs):
def _check_can_cache(*args: Any, **kwargs: Any) -> None:
# no error means it can be cached.
# Inductor refuses to cache the graph outside of Dynamo
# tracing context, and also disables caching for graphs
Expand Down Expand Up @@ -513,7 +513,7 @@ def load(
example_inputs: list[Any],
graph_index: int,
compile_range: Range,
) -> Callable:
) -> Callable[..., Any]:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
Expand Down Expand Up @@ -572,7 +572,7 @@ def load(
returns_tuple = graph_returns_tuple(graph)

# this is the callable we return to Dynamo to run
def compiled_graph(*args):
def compiled_graph(*args: Any) -> tuple[Any, ...] | Any:
# convert args to list
list_args = list(args)
graph_output = inductor_compiled_graph(list_args)
Expand All @@ -584,7 +584,7 @@ def compiled_graph(*args):

return compiled_graph

def metrics_context(self) -> contextlib.AbstractContextManager:
def metrics_context(self) -> contextlib.AbstractContextManager[Any]:
"""
This method returns the Dynamo metrics context (if it exists,
otherwise a null context). It is used by various compile components.
Expand All @@ -603,12 +603,12 @@ def metrics_context(self) -> contextlib.AbstractContextManager:
if is_torch_equal_or_newer("2.6"):
import torch._dynamo.utils

return torch._dynamo.utils.get_metrics_context()
return torch._dynamo.utils.get_metrics_context() # type: ignore[no-any-return]
else:
return contextlib.nullcontext()


def set_inductor_config(config, compile_range: Range):
def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
if compile_range.is_single_size():
# for a specific batch size, tuning triton kernel parameters
# can be beneficial
Expand All @@ -618,7 +618,7 @@ def set_inductor_config(config, compile_range: Range):
)


def set_functorch_config():
def set_functorch_config() -> None:
torch._functorch.config.bundled_autograd_cache = False


Expand All @@ -632,7 +632,7 @@ def compile(
compiler_config: dict[str, Any],
compile_range: Range,
key: str | None = None,
) -> tuple[Callable | None, Any | None]:
) -> tuple[Callable[..., Any] | None, Any | None]:
compilation_counter.num_eager_compiles += 1
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, return None for the handle.
Expand Down
4 changes: 3 additions & 1 deletion vllm/compilation/counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import copy
import dataclasses
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any


@dataclasses.dataclass
Expand Down Expand Up @@ -34,7 +36,7 @@ def clone(self) -> "CompilationCounter":
return copy.deepcopy(self)

@contextmanager
def expect(self, **kwargs):
def expect(self, **kwargs: Any) -> Generator[None, None, None]:
old = self.clone()
yield
for k, v in kwargs.items():
Expand Down
Loading