diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 4f855fc1d7c2..cf66d2277721 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -9,7 +9,7 @@ import os import pprint import time -from collections.abc import Callable, Sequence +from collections.abc import Callable, Generator, Sequence from contextlib import contextmanager from copy import deepcopy from functools import partial @@ -90,7 +90,7 @@ class CompilerManager: support int as key. """ - def __init__(self, compilation_config: CompilationConfig): + def __init__(self, compilation_config: CompilationConfig) -> None: self.cache: dict[tuple[Range, int, str], Any] = dict() self.is_cache_updated = False self.compilation_config = compilation_config @@ -100,7 +100,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @contextmanager - def compile_context(self, compile_range: Range): + def compile_context(self, compile_range: Range) -> Generator[None, None, None]: """Provide compilation context for the duration of compilation to set any torch global properties we want to scope to a single Inductor compilation (e.g. partition rules, pass context).""" @@ -115,7 +115,7 @@ def compile_context(self, compile_range: Range): def initialize_cache( self, cache_dir: str, disable_cache: bool = False, prefix: str = "" - ): + ) -> None: """ Initialize the cache directory for the compiler. @@ -143,7 +143,7 @@ def initialize_cache( # do not use eval(), it is unsafe. cache = ast.literal_eval(f.read()) - def check_type(value, ty): + def check_type(value: Any, ty: type) -> None: if not isinstance(value, ty): raise TypeError(f"Expected {ty} but got {type(value)} for {value}") @@ -165,7 +165,7 @@ def parse_key(key: Any) -> tuple[Range, int, str]: cache_dir=cache_dir, disable_cache=disable_cache, prefix=prefix ) - def save_to_file(self): + def save_to_file(self) -> None: if self.disable_cache or not self.is_cache_updated: return printer = pprint.PrettyPrinter(indent=4) @@ -198,7 +198,7 @@ def load( def compile( self, graph: fx.GraphModule, - example_inputs, + example_inputs: list[Any], additional_inductor_config, compilation_config: CompilationConfig, compile_range: Range, @@ -373,7 +373,7 @@ def __init__( compile_submod_names: list[str], vllm_config: VllmConfig, vllm_backend: "VllmBackend", - ): + ) -> None: super().__init__(module) from torch._guards import detect_fake_mode @@ -385,7 +385,7 @@ def __init__( # When True, it annoyingly dumps the torch.fx.Graph on errors. self.extra_traceback = False - def run(self, *args): + def run(self, *args: Any) -> Any: # maybe instead just assert inputs are fake? fake_args = [ self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t @@ -467,7 +467,7 @@ def call_module( @contextmanager -def set_model_tag(tag: str, is_encoder: bool = False): +def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]: """Context manager to set the model tag.""" global model_tag global model_is_encoder @@ -521,7 +521,7 @@ def __init__( vllm_config: VllmConfig, prefix: str = "", is_encoder: bool = False, - ): + ) -> None: # if the model is initialized with a non-empty prefix, # then usually it's enough to use that prefix, # e.g. language_model, vision_model, etc. @@ -558,7 +558,7 @@ def __init__( # `torch.compile` is JIT compiled, so we don't need to # do anything here - def configure_post_pass(self): + def configure_post_pass(self) -> None: self.pass_manager.configure(self.vllm_config) # Post-grad custom passes are run using the post_grad_custom_post_pass @@ -580,7 +580,7 @@ def configure_post_pass(self): self.inductor_config[self.pass_key] = self.pass_manager def __call__( - self, graph: fx.GraphModule, example_inputs + self, graph: fx.GraphModule, example_inputs: Sequence[Any] ) -> VllmSerializableFunction: vllm_config = self.vllm_config # Minimal hashing here with existing utilities, reused below. diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 57bd94c7e8ad..a67d63614297 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -50,7 +50,7 @@ class BasePattern: - def __init__(self, dtype: torch.dtype, device: str): + def __init__(self, dtype: torch.dtype, device: str | None) -> None: self.dtype = dtype self.device = device self.tp = get_tp_group() @@ -637,7 +637,7 @@ def __init__( self, epsilon: float, dtype: torch.dtype, - device: str, + device: str | None, allreduce_params: FlashInferFusedAllReduceParams, ): super().__init__(dtype, device) @@ -692,7 +692,7 @@ def __init__( self, epsilon: float, dtype: torch.dtype, - device: str, + device: str | None, allreduce_params: FlashInferFusedAllReduceParams, ): super().__init__(dtype, device) @@ -759,7 +759,7 @@ def __init__( self, epsilon: float, dtype: torch.dtype, - device: str, + device: str | None, allreduce_params: FlashInferFusedAllReduceParams, ): super().__init__(dtype, device) @@ -828,7 +828,7 @@ def __init__( self, epsilon: float, dtype: torch.dtype, - device: str, + device: str | None, allreduce_params: FlashInferFusedAllReduceParams, ): super().__init__(dtype, device) @@ -902,7 +902,7 @@ def __init__( self, epsilon: float, dtype: torch.dtype, - device: str, + device: str | None, allreduce_params: FlashInferFusedAllReduceParams, ): super().__init__(dtype, device) @@ -988,7 +988,7 @@ def __init__( self, epsilon: float, dtype: torch.dtype, - device: str, + device: str | None, allreduce_params: FlashInferFusedAllReduceParams, ): super().__init__(dtype, device) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index ab56d3561c56..b7cf3614e86e 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -31,7 +31,7 @@ class CompilerInterface: def initialize_cache( self, cache_dir: str, disable_cache: bool = False, prefix: str = "" - ): + ) -> None: """ when the vLLM process uses `cache_dir` as the cache directory, the compiler should initialize itself with the cache directory, @@ -66,7 +66,7 @@ def compile( compiler_config: dict[str, Any], compile_range: Range, key: str | None = None, - ) -> tuple[Callable | None, Any | None]: + ) -> tuple[Callable[..., Any] | None, Any | None]: """ Compile the graph with the given example inputs and compiler config, with a range. The `compile_range` specifies the range of the inputs, @@ -100,7 +100,7 @@ def load( example_inputs: list[Any], graph_index: int, compile_range: Range, - ) -> Callable: + ) -> Callable[..., Any]: """ Load the compiled function from the handle. Raises an error if the handle is invalid. @@ -138,13 +138,13 @@ class AlwaysHitShapeEnv: def __init__(self) -> None: self.guards: list[Any] = [] - def evaluate_guards_expression(self, *args, **kwargs): + def evaluate_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[True]: return True - def get_pruned_guards(self, *args, **kwargs): + def get_pruned_guards(self, *args: Any, **kwargs: Any) -> list[Any]: return [] - def produce_guards_expression(self, *args, **kwargs): + def produce_guards_expression(self, *args: Any, **kwargs: Any) -> Literal[""]: return "" @@ -193,7 +193,7 @@ class InductorStandaloneAdaptor(CompilerInterface): name = "inductor_standalone" - def __init__(self, save_format: Literal["binary", "unpacked"]): + def __init__(self, save_format: Literal["binary", "unpacked"]) -> None: self.save_format = save_format def compute_hash(self, vllm_config: VllmConfig) -> str: @@ -205,7 +205,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: def initialize_cache( self, cache_dir: str, disable_cache: bool = False, prefix: str = "" - ): + ) -> None: self.cache_dir = cache_dir def compile( @@ -215,7 +215,7 @@ def compile( compiler_config: dict[str, Any], compile_range: Range, key: str | None = None, - ) -> tuple[Callable | None, Any | None]: + ) -> tuple[Callable[..., Any] | None, Any | None]: compilation_counter.num_inductor_compiles += 1 current_config = {} if compiler_config is not None: @@ -252,7 +252,7 @@ def load( example_inputs: list[Any], graph_index: int, compile_range: Range, - ) -> Callable: + ) -> Callable[..., Any]: assert isinstance(handle, tuple) assert isinstance(handle[0], str) assert isinstance(handle[1], str) @@ -264,7 +264,7 @@ def load( returns_tuple = graph_returns_tuple(graph) - def compiled_graph_wrapper(*args): + def compiled_graph_wrapper(*args: Any) -> tuple[Any, ...] | Any: graph_output = inductor_compiled_graph(*args) # unpack the tuple if needed # TODO(rzou): the implication is that we're not @@ -293,7 +293,7 @@ def compute_hash(self, vllm_config: VllmConfig) -> str: def initialize_cache( self, cache_dir: str, disable_cache: bool = False, prefix: str = "" - ): + ) -> None: self.cache_dir = cache_dir self.prefix = prefix self.base_cache_dir = cache_dir[: -len(prefix)] if prefix else cache_dir @@ -317,7 +317,7 @@ def compile( compiler_config: dict[str, Any], compile_range: Range, key: str | None = None, - ) -> tuple[Callable | None, Any | None]: + ) -> tuple[Callable[..., Any] | None, Any | None]: compilation_counter.num_inductor_compiles += 1 from torch._inductor.compile_fx import compile_fx @@ -348,7 +348,7 @@ def compile( original_load = FxGraphCache.load original_load_name = "torch._inductor.codecache.FxGraphCache.load" - def hijack_load(*args, **kwargs): + def hijack_load(*args: Any, **kwargs: Any) -> Any: inductor_compiled_graph = original_load(*args, **kwargs) nonlocal file_path compiled_fn = inductor_compiled_graph.current_callable @@ -375,7 +375,7 @@ def hijack_load(*args, **kwargs): # function renamed in 2.6 original_load_name = None - def hijacked_compile_fx_inner(*args, **kwargs): + def hijacked_compile_fx_inner(*args: Any, **kwargs: Any) -> Any: output = torch._inductor.compile_fx.compile_fx_inner(*args, **kwargs) nonlocal hash_str inductor_compiled_graph = output @@ -401,13 +401,13 @@ def hijacked_compile_fx_inner(*args, **kwargs): hash_str = inductor_compiled_graph._fx_graph_cache_key return output - def hijack_compiled_fx_graph_hash(*args, **kwargs): + def hijack_compiled_fx_graph_hash(*args: Any, **kwargs: Any) -> Any: out = compiled_fx_graph_hash(*args, **kwargs) nonlocal hash_str hash_str = out[0] return out - def _check_can_cache(*args, **kwargs): + def _check_can_cache(*args: Any, **kwargs: Any) -> None: # no error means it can be cached. # Inductor refuses to cache the graph outside of Dynamo # tracing context, and also disables caching for graphs @@ -513,7 +513,7 @@ def load( example_inputs: list[Any], graph_index: int, compile_range: Range, - ) -> Callable: + ) -> Callable[..., Any]: assert isinstance(handle, tuple) assert isinstance(handle[0], str) assert isinstance(handle[1], str) @@ -572,7 +572,7 @@ def load( returns_tuple = graph_returns_tuple(graph) # this is the callable we return to Dynamo to run - def compiled_graph(*args): + def compiled_graph(*args: Any) -> tuple[Any, ...] | Any: # convert args to list list_args = list(args) graph_output = inductor_compiled_graph(list_args) @@ -584,7 +584,7 @@ def compiled_graph(*args): return compiled_graph - def metrics_context(self) -> contextlib.AbstractContextManager: + def metrics_context(self) -> contextlib.AbstractContextManager[Any]: """ This method returns the Dynamo metrics context (if it exists, otherwise a null context). It is used by various compile components. @@ -603,12 +603,12 @@ def metrics_context(self) -> contextlib.AbstractContextManager: if is_torch_equal_or_newer("2.6"): import torch._dynamo.utils - return torch._dynamo.utils.get_metrics_context() + return torch._dynamo.utils.get_metrics_context() # type: ignore[no-any-return] else: return contextlib.nullcontext() -def set_inductor_config(config, compile_range: Range): +def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None: if compile_range.is_single_size(): # for a specific batch size, tuning triton kernel parameters # can be beneficial @@ -618,7 +618,7 @@ def set_inductor_config(config, compile_range: Range): ) -def set_functorch_config(): +def set_functorch_config() -> None: torch._functorch.config.bundled_autograd_cache = False @@ -632,7 +632,7 @@ def compile( compiler_config: dict[str, Any], compile_range: Range, key: str | None = None, - ) -> tuple[Callable | None, Any | None]: + ) -> tuple[Callable[..., Any] | None, Any | None]: compilation_counter.num_eager_compiles += 1 # we don't need to compile the graph, just return the graph itself. # It does not support caching, return None for the handle. diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index 20918099f169..29d3045aac64 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -3,7 +3,9 @@ import copy import dataclasses +from collections.abc import Generator from contextlib import contextmanager +from typing import Any @dataclasses.dataclass @@ -34,7 +36,7 @@ def clone(self) -> "CompilationCounter": return copy.deepcopy(self) @contextmanager - def expect(self, **kwargs): + def expect(self, **kwargs: Any) -> Generator[None, None, None]: old = self.clone() yield for k, v in kwargs.items(): diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 08cae27b1276..fa5dce976f9f 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -219,6 +219,7 @@ def __call__(self, *args, **kwargs): # runtime modes. return self.runnable(*args, **kwargs) + assert batch_descriptor is not None if batch_descriptor not in self.concrete_cudagraph_entries: # create a new entry for this batch descriptor self.concrete_cudagraph_entries[batch_descriptor] = CUDAGraphEntry( diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py index 3650ee6b4174..5c2e7ac93e66 100644 --- a/vllm/compilation/fx_utils.py +++ b/vllm/compilation/fx_utils.py @@ -7,10 +7,11 @@ from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._ops import OpOverload, OpOverloadPacket +from torch.fx.node import Target -def is_func(node: fx.Node, target) -> bool: - return node.op == "call_function" and node.target == target +def is_func(node: fx.Node, target: Target) -> bool: + return bool(node.op == "call_function" and node.target == target) def is_auto_func(node: fx.Node, op: OpOverload) -> bool: diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index dbf154eeb86a..56b4554c88ef 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -8,9 +8,9 @@ import inspect import json import types -from collections.abc import Callable +from collections.abc import Callable, Generator from contextlib import contextmanager -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar import torch from torch import fx @@ -30,6 +30,8 @@ ) _pass_context = None +P = ParamSpec("P") +R = TypeVar("R") class PassContext: @@ -44,7 +46,7 @@ def get_pass_context() -> PassContext: @contextmanager -def pass_context(compile_range: Range): +def pass_context(compile_range: Range) -> Generator[None, None, None]: """A context manager that stores the current pass context, usually it is a list of sizes to specialize. """ @@ -57,7 +59,7 @@ def pass_context(compile_range: Range): _pass_context = prev_context -class InductorPass(CustomGraphPass): +class InductorPass(CustomGraphPass): # type: ignore[misc] """ A custom graph pass that uses a hash of its source as the UUID. This is defined as a convenience and should work in most cases. @@ -73,7 +75,7 @@ def uuid(self) -> Any: return InductorPass.hash_source(self) @staticmethod - def hash_source(*srcs: str | Any): + def hash_source(*srcs: str | Any) -> str: """ Utility method to hash the sources of functions or objects. :param srcs: strings or objects to add to the hash. @@ -93,7 +95,7 @@ def hash_source(*srcs: str | Any): return hasher.hexdigest() @staticmethod - def hash_dict(dict_: dict[Any, Any]): + def hash_dict(dict_: dict[Any, Any]) -> str: """ Utility method to hash a dictionary, can alternatively be used for uuid. :return: A sha256 hash of the json rep of the dictionary. @@ -101,7 +103,7 @@ def hash_dict(dict_: dict[Any, Any]): encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).hexdigest() - def is_applicable_for_range(self, compile_range: Range): + def is_applicable_for_range(self, compile_range: Range) -> bool: return True @@ -111,25 +113,27 @@ class CallableInductorPass(InductorPass): implementation of the UUID. """ - def __init__(self, callable: Callable[[fx.Graph], None], uuid: Any | None = None): + def __init__( + self, callable: Callable[[fx.Graph], None], uuid: Any | None = None + ) -> None: self.callable = callable self._uuid = self.hash_source(callable) if uuid is None else uuid - def __call__(self, graph: torch.fx.Graph): + def __call__(self, graph: torch.fx.Graph) -> None: self.callable(graph) def uuid(self) -> Any: return self._uuid -def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]: +def enable_fake_mode(fn: Callable[P, R]) -> Callable[P, R]: """ Applies a FakeTensorMode context. This is useful when you don't want to create or run things with real tensors. """ @functools.wraps(fn) - def fn_new(*args, **kwargs) -> Any: + def fn_new(*args: P.args, **kwargs: P.kwargs) -> R: with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): result = fn(*args, **kwargs) diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index 660fb9887e2c..2bad5f0a16fc 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -12,7 +12,7 @@ torch_compile_start_time: float = 0.0 -def start_monitoring_torch_compile(vllm_config: VllmConfig): +def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None: global torch_compile_start_time torch_compile_start_time = time.time() @@ -28,7 +28,7 @@ def start_monitoring_torch_compile(vllm_config: VllmConfig): context_manager.__enter__() -def end_monitoring_torch_compile(vllm_config: VllmConfig): +def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None: compilation_config: CompilationConfig = vllm_config.compilation_config if compilation_config.mode == CompilationMode.VLLM_COMPILE: logger.info_once( @@ -45,7 +45,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig): cudagraph_capturing_enabled: bool = True -def validate_cudagraph_capturing_enabled(): +def validate_cudagraph_capturing_enabled() -> None: # used to monitor whether a cudagraph capturing is legal at runtime. # should be called before any cudagraph capturing. # if an illegal cudagraph capturing happens, raise an error. @@ -57,6 +57,6 @@ def validate_cudagraph_capturing_enabled(): ) -def set_cudagraph_capturing_enabled(enabled: bool): +def set_cudagraph_capturing_enabled(enabled: bool) -> None: global cudagraph_capturing_enabled cudagraph_capturing_enabled = enabled diff --git a/vllm/compilation/partition_rules.py b/vllm/compilation/partition_rules.py index 08bd27e80952..18ebb15d1112 100644 --- a/vllm/compilation/partition_rules.py +++ b/vllm/compilation/partition_rules.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib +from collections.abc import Generator import torch @@ -38,7 +39,9 @@ def should_split(node: torch.fx.Node, splitting_ops: list[str]) -> bool: @contextlib.contextmanager -def inductor_partition_rule_context(splitting_ops: list[str]): +def inductor_partition_rule_context( + splitting_ops: list[str] | None, +) -> Generator[None, None, None]: """Context manager to temporarily register Inductor partition rules. Registers custom partition rules for specified operators, forcing the diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index a4046356bcda..bf81a62f257d 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -41,8 +41,8 @@ def __init__( self, epsilon: float, dtype: torch.dtype, - device: str, - ): + device: str | None, + ) -> None: self.epsilon = epsilon self.dtype = dtype self.device = device @@ -64,7 +64,7 @@ def _all_gather(self, x: torch.Tensor) -> torch.Tensor: class FirstAllReduceRMSNormPattern(_SequenceParallelPatternHelper): - def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None) -> None: super().__init__(epsilon, dtype, device) self.rmsnorm_matcher = MatcherRMSNorm(epsilon) @@ -74,7 +74,7 @@ def get_inputs(self): return [input, arg3_1] - def register(self, pm_pass: PatternMatcherPass): + def register(self, pm_pass: PatternMatcherPass) -> None: def pattern( input: torch.Tensor, arg3_1: torch.Tensor, @@ -100,7 +100,7 @@ def replacement( class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): - def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None): super().__init__(epsilon, dtype, device) self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) @@ -162,7 +162,7 @@ def __init__( self, epsilon: float, dtype: torch.dtype, - device: str, + device: str | None, ): super().__init__(epsilon, dtype, device) self.rmsnorm_matcher = MatcherRMSNorm(epsilon) @@ -203,7 +203,7 @@ def replacement( class MiddleAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): - def __init__(self, epsilon: float, dtype: torch.dtype, device: str): + def __init__(self, epsilon: float, dtype: torch.dtype, device: str | None): super().__init__(epsilon, dtype, device) self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon) self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym) diff --git a/vllm/compilation/torch25_custom_graph_pass.py b/vllm/compilation/torch25_custom_graph_pass.py index 1031856cdf00..2da4190c416a 100644 --- a/vllm/compilation/torch25_custom_graph_pass.py +++ b/vllm/compilation/torch25_custom_graph_pass.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from typing import Any +from typing import Any, NoReturn import torch @@ -29,14 +29,14 @@ def uuid(self) -> Any | None: Return None to skip inductor code caching entirely. """ - def __getstate__(self): + def __getstate__(self) -> Any | None: """ Pickling is used instead of uuid() in torch<2.6. Just return uuid() to enable subclasses to only have to implement uuid. """ return self.uuid() - def __setstate__(self, state): + def __setstate__(self, state: Any) -> NoReturn: raise ValueError( "Cannot unpickle CustomGraphPass because pickling" " is used for cache key uuid. Use torch>=2.6 with" diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index 08721e3ae4a2..b64c892881f5 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -3,6 +3,7 @@ import functools import operator import time +from collections.abc import Callable from dataclasses import dataclass from typing import ClassVar @@ -43,13 +44,17 @@ def __init__(self, config: VllmConfig): ) self.pass_config = config.compilation_config.pass_config self.model_dtype = config.model_config.dtype if config.model_config else None - self.device = config.device_config.device if config.device_config else None + self.device: str | None = ( + config.device_config.device if config.device_config else None + ) self.pass_name = self.__class__.__name__ @staticmethod - def time_and_log(call_fn): + def time_and_log( + call_fn: Callable[["VllmInductorPass", torch.fx.Graph], None], + ) -> Callable[["VllmInductorPass", torch.fx.Graph], None]: @functools.wraps(call_fn) - def wrapped(self: VllmInductorPass, graph: torch.fx.Graph): + def wrapped(self: VllmInductorPass, graph: torch.fx.Graph) -> None: self.begin() self.dump_graph(graph, "before") call_fn(self, graph) @@ -58,17 +63,17 @@ def wrapped(self: VllmInductorPass, graph: torch.fx.Graph): return wrapped - def dump_graph(self, graph: torch.fx.Graph, stage: str): + def dump_graph(self, graph: torch.fx.Graph, stage: str) -> None: i = VllmInductorPass.dump_prefix i_str = "" if i is None else f".{i}" lazy_format_graph_code( f"post_grad{i_str}.{self.pass_name}.{stage}", graph.owning_module ) - def begin(self): + def begin(self) -> None: self._start_time = time.perf_counter_ns() - def end_and_log(self): + def end_and_log(self) -> None: self._end_time = time.perf_counter_ns() duration_ms = float(self._end_time - self._start_time) / 1.0e6 logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) @@ -92,12 +97,14 @@ class VllmPatternMatcherPass(VllmInductorPass): def _replace_op_overloads(self, string: str) -> str: """Replace with nicer formulations""" - return self._OP_OVERLOAD_PATTERN.sub( - lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}", - string, + return str( + self._OP_OVERLOAD_PATTERN.sub( + lambda m: f"torch.ops.{m.group(1)}.{m.group(2)}", + string, + ) ) - def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): + def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass) -> None: """ If debug dumping is enabled, dump the Inductor pattern-matcher patterns into the debug_dump_path folder next to the dumped fx graphs. @@ -165,9 +172,9 @@ def dump_patterns(self, config: VllmConfig, pm_pass: PatternMatcherPass): class PrinterInductorPass(VllmInductorPass): - def __init__(self, name: str, config: VllmConfig): + def __init__(self, name: str, config: VllmConfig) -> None: super().__init__(config) self.name = name - def __call__(self, graph: torch.fx.Graph): + def __call__(self, graph: torch.fx.Graph) -> None: self.dump_graph(graph, self.name)