diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index cf66d2277721..5262e39b2a53 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -179,7 +179,7 @@ def load( example_inputs: list[Any], graph_index: int, compile_range: Range, - ) -> Callable | None: + ) -> Callable[..., Any] | None: if (compile_range, graph_index, self.compiler.name) not in self.cache: return None handle = self.cache[(compile_range, graph_index, self.compiler.name)] @@ -199,7 +199,7 @@ def compile( self, graph: fx.GraphModule, example_inputs: list[Any], - additional_inductor_config, + additional_inductor_config: dict[str, Any], compilation_config: CompilationConfig, compile_range: Range, graph_index: int = 0, @@ -355,7 +355,7 @@ def split_graph( compilation_start_time = 0.0 -class PiecewiseCompileInterpreter(torch.fx.Interpreter): +class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. It runs the given graph with fake inputs, and compile some submodules specified by `compile_submod_names` with the given @@ -506,9 +506,9 @@ class VllmBackend: # the stiching graph module for all the piecewise graphs split_gm: fx.GraphModule piecewise_graphs: list[SplitItem] - returned_callable: Callable + returned_callable: Callable[..., Any] # Inductor passes to run on the graph pre-defunctionalization - post_grad_passes: Sequence[Callable] + post_grad_passes: Sequence[Callable[..., Any]] sym_tensor_indices: list[int] input_buffers: list[torch.Tensor] compiler_manager: CompilerManager @@ -821,7 +821,7 @@ def __call__( ] # this is the callable we return to Dynamo to run - def copy_and_call(*args): + def copy_and_call(*args: Any) -> Any: list_args = list(args) for i, index in enumerate(self.sym_tensor_indices): runtime_tensor = list_args[index] diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index 8c9ec87bcad5..3d945e2ddd5f 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -4,6 +4,8 @@ import inspect import os import pickle +from collections.abc import Callable, Sequence +from typing import Any, Literal from unittest.mock import patch import torch @@ -25,7 +27,7 @@ logger = init_logger(__name__) -class VllmSerializableFunction(SerializableCallable): +class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] """ A wrapper around a compiled function by vllm. It will forward the tensor inputs to the compiled function and return the result. @@ -38,8 +40,13 @@ class VllmSerializableFunction(SerializableCallable): """ def __init__( - self, graph_module, example_inputs, prefix, optimized_call, is_encoder=False - ): + self, + graph_module: torch.fx.GraphModule, + example_inputs: Sequence[Any], + prefix: str, + optimized_call: Callable[..., Any], + is_encoder: bool = False, + ) -> None: assert isinstance(graph_module, torch.fx.GraphModule) self.graph_module = graph_module self.example_inputs = example_inputs @@ -53,7 +60,7 @@ def __init__( if sym_input is not None: self.shape_env = sym_input.node.shape_env - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.optimized_call(*args, **kwargs) @classmethod @@ -73,7 +80,9 @@ def serialize_compile_artifacts( graph_reducer_override = GraphPickler.reducer_override - def _graph_reducer_override(self, obj): + def _graph_reducer_override( + self: GraphPickler, obj: Any + ) -> tuple[Callable[..., Any], tuple[Any, ...]] | Any: if ( inspect.isclass(obj) and issubclass(obj, sympy.Function) @@ -114,7 +123,7 @@ def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction get_current_vllm_config(), state["prefix"], is_encoder ) - def optimized_call(*example_inputs): + def optimized_call(*example_inputs: Any) -> Any: """ On the first run of the optimized call, we rerun the compiler backend which should result in a cache hit. After the backend @@ -136,7 +145,7 @@ def optimized_call(*example_inputs): return fn @property - def co_name(self): + def co_name(self) -> Literal["VllmSerializableFunction"]: """ Used for depyf debugging. """ diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index fa5dce976f9f..7ffa74d0d7e6 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -42,7 +42,9 @@ class CUDAGraphLogging: "Count", ] - def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None): + def __init__( + self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None + ) -> None: self.reset() self.cg_mode = str(cg_mode) self.cg_capture_sizes = str(cg_capture_sizes or []) @@ -54,10 +56,10 @@ def __init__(self, cg_mode: CUDAGraphMode, cg_capture_sizes: list[int] | None): "**CUDAGraph Stats:**\n\n" ) - def reset(self): - self.stats = [] + def reset(self) -> None: + self.stats: list[CUDAGraphStat] = [] - def observe(self, cudagraph_stat: CUDAGraphStat): + def observe(self, cudagraph_stat: CUDAGraphStat) -> None: self.stats.append(cudagraph_stat) def generate_metric_table(self) -> str: @@ -109,7 +111,7 @@ def generate_metric_table(self) -> str: + "\n" ) - def log(self, log_fn=logger.info): + def log(self, log_fn: Callable[..., Any] = logger.info) -> None: if not self.stats: return log_fn(self.generate_metric_table()) @@ -161,11 +163,11 @@ class CUDAGraphWrapper: def __init__( self, - runnable: Callable, + runnable: Callable[..., Any], vllm_config: VllmConfig, runtime_mode: CUDAGraphMode, cudagraph_options: CUDAGraphOptions | None = None, - ): + ) -> None: self.runnable = runnable self.vllm_config = vllm_config self.runtime_mode = runtime_mode @@ -189,7 +191,7 @@ def __init__( # cudagraphs for. self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {} - def __getattr__(self, key: str): + def __getattr__(self, key: str) -> Any: # allow accessing the attributes of the runnable. if hasattr(self.runnable, key): return getattr(self.runnable, key) @@ -198,11 +200,11 @@ def __getattr__(self, key: str): f"cudagraph wrapper: {self.runnable}" ) - def unwrap(self) -> Callable: + def unwrap(self) -> Callable[..., Any]: # in case we need to access the original runnable. return self.runnable - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any | None: forward_context = get_forward_context() batch_descriptor = forward_context.batch_descriptor cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 457fec930397..943220244fac 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -6,8 +6,8 @@ import inspect import os import sys -from collections.abc import Callable -from typing import TypeVar, overload +from collections.abc import Callable, Generator +from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload from unittest.mock import patch import torch @@ -32,6 +32,14 @@ from .monitor import start_monitoring_torch_compile +if TYPE_CHECKING: + # Only added on nightly/2.10 so wrap + try: + from torch._dynamo.package import SourceInfo + except ImportError: + # Fallback for old versions not supporting + SourceInfo = Any + logger = init_logger(__name__) IGNORE_COMPILE_KEY = "_ignore_compile_vllm" @@ -59,7 +67,7 @@ def ignore_torch_compile(cls: _T) -> _T: return cls -def _should_ignore_torch_compile(cls) -> bool: +def _should_ignore_torch_compile(cls: _T) -> bool: """ Check if the class should be ignored for torch.compile. """ @@ -224,7 +232,7 @@ def cls_decorator_helper(cls: _T) -> _T: return cls_decorator_helper -def _model_hash_key(fn) -> str: +def _model_hash_key(fn: Callable[..., Any]) -> str: import vllm sha256_hash = hashlib.sha256() @@ -234,7 +242,9 @@ def _model_hash_key(fn) -> str: return sha256_hash.hexdigest() -def _verify_source_unchanged(source_info, vllm_config) -> None: +def _verify_source_unchanged( + source_info: "SourceInfo", vllm_config: VllmConfig +) -> None: from .caching import _compute_code_hash, _compute_code_hash_with_content file_contents = {} @@ -275,8 +285,12 @@ def _support_torch_compile( setattr(cls, IGNORE_COMPILE_KEY, False) def __init__( - self, *, vllm_config: VllmConfig | None = None, prefix: str = "", **kwargs - ): + self: _T, + *, + vllm_config: VllmConfig | None = None, + prefix: str = "", + **kwargs: Any, + ) -> None: if vllm_config is None: vllm_config = get_current_vllm_config() @@ -309,13 +323,17 @@ def __init__( compilation_counter.num_models_seen += 1 self.compiled = False - TorchCompileWithNoGuardsWrapper.__init__(self) + + # Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class + TorchCompileWithNoGuardsWrapper.__init__(self) # type: ignore[arg-type] cls.__init__ = __init__ - def _mark_dynamic_inputs(mod, type, *args, **kwargs): - def mark_dynamic(arg, dims): - if type == DynamicShapesType.UNBACKED: + def _mark_dynamic_inputs( + mod: _T, ds_type: DynamicShapesType, *args: Any, **kwargs: Any + ) -> None: + def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None: + if ds_type == DynamicShapesType.UNBACKED: if is_torch_equal_or_newer("2.10.0.dev"): for dim in dims: torch._dynamo.decorators.mark_unbacked( @@ -326,7 +344,7 @@ def mark_dynamic(arg, dims): else: torch._dynamo.mark_dynamic(arg, dims) - sig = inspect.signature(mod.__class__.forward) + sig = inspect.signature(mod.__class__.forward) # type: ignore[attr-defined] bound_args = sig.bind(mod, *args, **kwargs) bound_args.apply_defaults() for k, dims in dynamic_arg_dims.items(): @@ -364,7 +382,7 @@ def mark_dynamic(arg, dims): else: torch._dynamo.decorators.mark_unbacked(arg, dims) - def __call__(self, *args, **kwargs): + def __call__(self: _T, *args: Any, **kwargs: Any) -> Any: # torch.compiler.is_compiling() means we are inside the compilation # e.g. TPU has the compilation logic in model runner, so we don't # need to compile the model inside. @@ -444,7 +462,7 @@ def __call__(self, *args, **kwargs): not envs.VLLM_USE_AOT_COMPILE or self.vllm_config.compilation_config.backend == "eager" ) - return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) + return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type] # This is the path for the first compilation. # the first compilation needs to have dynamic shapes marked @@ -477,7 +495,7 @@ def __call__(self, *args, **kwargs): # during Dynamo tracing, and their corresponding files inline_call = InliningInstructionTranslator.inline_call_ - def patched_inline_call(self_): + def patched_inline_call(self_: Any) -> Any: code = self_.f_code self.compilation_config.traced_files.add(code.co_filename) return inline_call(self_) @@ -535,7 +553,7 @@ def patched_inline_call(self_): str(e), ) else: - output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) + output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type] self.compiled = True return output @@ -545,7 +563,9 @@ def patched_inline_call(self_): @contextlib.contextmanager -def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig): +def maybe_use_cudagraph_partition_wrapper( + vllm_config: VllmConfig, +) -> Generator[None, None, None]: """ Context manager to set/unset customized cudagraph partition wrappers. @@ -572,7 +592,9 @@ def maybe_use_cudagraph_partition_wrapper(vllm_config: VllmConfig): current_platform.get_static_graph_wrapper_cls() ) - def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata): + def customized_cudagraph_wrapper( + f: Callable[..., Any], metadata: CUDAGraphWrapperMetadata + ) -> Any: partition_id = metadata.partition_index num_partitions = metadata.num_partitions return static_graph_wrapper_class( @@ -600,7 +622,7 @@ def customized_cudagraph_wrapper(f, metadata: CUDAGraphWrapperMetadata): @contextlib.contextmanager -def _torch27_patch_tensor_subclasses(): +def _torch27_patch_tensor_subclasses() -> Generator[None, None, None]: """ Add support for using tensor subclasses (ie `BasevLLMParameter`, ect) when using torch 2.7.0. This enables using weight_loader_v2 and the use of @@ -614,7 +636,7 @@ def _torch27_patch_tensor_subclasses(): _ColumnvLLMParameter, ) - def return_false(*args, **kwargs): + def return_false(*args: Any, **kwargs: Any) -> Literal[False]: return False if version.parse("2.7") <= version.parse(torch.__version__) < version.parse("2.8"): diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 2625562aadd3..ce37968c9918 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -26,7 +26,7 @@ class FixFunctionalizationPass(VllmInductorPass): """ @VllmInductorPass.time_and_log - def __call__(self, graph: torch.fx.Graph): + def __call__(self, graph: torch.fx.Graph) -> None: # XPU does not support auto-functionalization yet. # Will enable this when switch to vllm-xpu-kernels. if current_platform.is_xpu(): @@ -179,7 +179,7 @@ def __call__(self, graph: torch.fx.Graph): ) self.nodes_to_remove.clear() - def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]): + def _remove(self, node_or_nodes: torch.fx.Node | Iterable[torch.fx.Node]) -> None: """ Stage a node (or nodes) for removal at the end of the pass. """ @@ -194,7 +194,7 @@ def defunctionalize( node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str], args: tuple[torch.fx.Node | str, ...] | None = None, - ): + ) -> None: """ De-functionalize a node by replacing it with a call to the original. It also replaces the getitem users with the mutated arguments. @@ -206,7 +206,7 @@ def defunctionalize( def replace_users_with_mutated_args( self, node: torch.fx.Node, mutated_args: dict[int, torch.fx.Node | str] - ): + ) -> None: """ Replace all getitem users of the auto-functionalized node with the mutated arguments. @@ -237,7 +237,7 @@ def insert_defunctionalized( graph: torch.fx.Graph, node: torch.fx.Node, args: tuple[torch.fx.Node | str, ...] | None = None, - ): + ) -> None: """ Insert a new defunctionalized node into the graph before node. If one of the kwargs is 'out', provide args directly, diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 56b4554c88ef..93b2612f2a59 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -29,6 +29,9 @@ Torch25CustomGraphPass as CustomGraphPass, ) +# Re-export CustomGraphPass for external usage +__all__ = ["CustomGraphPass"] + _pass_context = None P = ParamSpec("P") R = TypeVar("R") diff --git a/vllm/compilation/noop_elimination.py b/vllm/compilation/noop_elimination.py index 06e1771bac96..9af904b457a6 100644 --- a/vllm/compilation/noop_elimination.py +++ b/vllm/compilation/noop_elimination.py @@ -65,7 +65,7 @@ class NoOpEliminationPass(VllmInductorPass): """ @VllmInductorPass.time_and_log - def __call__(self, graph: torch.fx.Graph): + def __call__(self, graph: torch.fx.Graph) -> None: count = 0 # Remove no-op reshapes/views: for node in graph.nodes: @@ -117,7 +117,7 @@ def dims_equivalent(self, dim: int | SymInt, i_dim: int | SymInt) -> bool: 2. The dimensions both correspond to the same SymInt """ # Case 1 - return statically_known_true(dim == i_dim) + return statically_known_true(dim == i_dim) # type: ignore[no-any-return] def all_dims_equivalent( self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt] diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index 4c2dee505a94..a207edd93905 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools +from collections.abc import Callable +from typing import Any, ParamSpec, TypeVar from torch import fx as fx @@ -40,8 +42,11 @@ logger = init_logger(__name__) +P = ParamSpec("P") +R = TypeVar("R") -def with_pattern_match_debug(fn): + +def with_pattern_match_debug(fn: Callable[P, R]) -> Callable[P, R]: """ Function decorator that turns on inductor pattern match debug for the duration of the call. @@ -49,7 +54,7 @@ def with_pattern_match_debug(fn): """ @functools.wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: if (debug_val := envs.VLLM_PATTERN_MATCH_DEBUG) is not None: # optionally check rank here with set_env_var("TORCHINDUCTOR_PATTERN_MATCH_DEBUG", debug_val): @@ -59,7 +64,7 @@ def wrapper(*args, **kwargs): return wrapper -class PostGradPassManager(CustomGraphPass): +class PostGradPassManager(CustomGraphPass): # type: ignore[misc] """ The pass manager for post-grad passes. It handles configuration, adding custom passes, and running passes. @@ -74,11 +79,11 @@ class PostGradPassManager(CustomGraphPass): This way, all passes operate on a functionalized graph. """ - def __init__(self): + def __init__(self) -> None: self.passes: list[InductorPass] = [] @with_pattern_match_debug - def __call__(self, graph: fx.Graph): + def __call__(self, graph: fx.Graph) -> None: VllmInductorPass.dump_prefix = 0 # reset dump index compile_range = get_pass_context().compile_range @@ -98,7 +103,7 @@ def __call__(self, graph: fx.Graph): self.fix_functionalization(graph) VllmInductorPass.dump_prefix = None # Cleanup index - def configure(self, config: VllmConfig): + def configure(self, config: VllmConfig) -> None: self.pass_config = config.compilation_config.pass_config # Set the current vllm config to allow tracing CustomOp instances @@ -135,23 +140,25 @@ def configure(self, config: VllmConfig): self.post_cleanup = PostCleanupPass(config) self.fix_functionalization = FixFunctionalizationPass(config) - def add(self, pass_: InductorPass): + def add(self, pass_: InductorPass) -> None: assert isinstance(pass_, InductorPass) self.passes.append(pass_) - def uuid(self): + def uuid(self) -> str: """ The PostGradPassManager is set as a custom pass in the Inductor and affects compilation caching. Its uuid depends on the UUIDs of all dependent passes and the pass config. See InductorPass for more info. """ - state = {"pass_config": self.pass_config.compute_hash(), "passes": []} + passes = [] + + state: dict[str, Any] = {"pass_config": self.pass_config.compute_hash()} for pass_ in self.passes: - state["passes"].append(pass_.uuid()) - state["passes"].append(self.fix_functionalization.uuid()) + passes.append(pass_.uuid()) + passes.append(self.fix_functionalization.uuid()) # Include the compile range in the uuid to ensure that inductor # recompiles the graph for the new dynamic compile range. state["compile_range"] = str(get_pass_context().compile_range) - + state["passes"] = passes return InductorPass.hash_dict(state) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 12cc49971e08..29d6f89990cd 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -86,27 +86,36 @@ def __init__( self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges) # We only keep compilation management inside this class directly. - for size in self.compile_sizes: - range = Range(start=size, end=size) - if range not in self.compile_ranges: - self.range_entries[range] = RangeEntry( - compile_range=range, - ) - self.to_be_compiled_ranges.add(range) + if self.compile_sizes is not None: + for size in self.compile_sizes: + if isinstance(size, str): + assert size == "cudagraph_capture_sizes" + raise NotImplementedError( + "cudagraph_capture_sizes not supported in compile_sizes." + "This should be handled in `post_init_cudagraph_sizes`." + ) + else: + assert isinstance(size, int) + range = Range(start=size, end=size) + if range not in self.compile_ranges: + self.range_entries[range] = RangeEntry( + compile_range=range, + ) + self.to_be_compiled_ranges.add(range) for range in self.compile_ranges: self.range_entries[range] = RangeEntry( compile_range=range, ) - def check_for_ending_compilation(self): + def check_for_ending_compilation(self) -> None: if self.is_last_graph and not self.to_be_compiled_ranges: # no specific sizes to compile # save the hash of the inductor graph for the next run self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) - def _fakify_args(self, args: list[Any]) -> list[Any]: + def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]: # We need to pass fake example_inputs, otherwise torch.compile # will fakify the example_inputs potentially causing some non dynamic # dimension to be be duck shaped to other existing shapes that have hints @@ -127,7 +136,9 @@ def _fakify_args(self, args: list[Any]) -> list[Any]: assert len(fake_example_inputs) == len(args) return fake_example_inputs - def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: + def _maybe_compile_for_range_entry( + self, range_entry: RangeEntry, args: tuple[Any, ...] + ) -> Any: if not range_entry.compiled: range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) @@ -136,14 +147,14 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: # fakify for range, real args for concrete size. # For concrete size, we clear the shape env in # compiler_manager.compile() so no need to fakify. - args = ( + args_list = ( self._fakify_args(args) if not range_entry.compile_range.is_single_size() - else args + else list(args) ) range_entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, - args, + args_list, self.vllm_backend.inductor_config, self.compilation_config, compile_range=range_entry.compile_range, @@ -153,10 +164,13 @@ def _maybe_compile_for_range_entry(self, range_entry: RangeEntry, args) -> Any: self.check_for_ending_compilation() - def _find_range_for_shape(self, runtime_shape: int) -> Range | None: + def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None: # First we try to find the range entry for the concrete compile size # If not found, we search for the range entry # that contains the runtime shape. + if self.compile_sizes is None: + return None + if runtime_shape in self.compile_sizes: return self.range_entries[Range(start=runtime_shape, end=runtime_shape)] else: @@ -165,7 +179,7 @@ def _find_range_for_shape(self, runtime_shape: int) -> Range | None: return self.range_entries[range] return None - def __call__(self, *args) -> Any: + def __call__(self, *args: Any) -> Any: runtime_shape = args[self.sym_shape_indices[0]] range_entry = self._find_range_for_shape(runtime_shape) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 02e974b0f9e8..62574d8072d2 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -4,9 +4,10 @@ import os import sys from abc import abstractmethod +from collections.abc import Callable, Generator from contextlib import contextmanager, nullcontext from types import CodeType -from typing import Any +from typing import Any, ParamSpec, TypeVar import torch import torch._C._dynamo.guards @@ -19,19 +20,26 @@ logger = init_logger(__name__) +R = TypeVar("R") +P = ParamSpec("P") -def _noop_add_global_state_guard(self, *args, **kwargs): + +def _noop_add_global_state_guard( + self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any +) -> None: """No-op to skip the GLOBAL_STATE guard entirely""" pass -def _noop_add_torch_function_mode_stack_guard(self, *args, **kwargs): +def _noop_add_torch_function_mode_stack_guard( + self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any +) -> None: """No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely""" pass @contextmanager -def _compilation_context(): +def _compilation_context() -> Generator[None, None, None]: """Context manager for compilation settings and patches. This manager: @@ -88,13 +96,15 @@ class TorchCompileWithNoGuardsWrapper: since we drop all guards. """ - def check_invariants_and_forward(self, *args, **kwargs): + def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any: assert hasattr(self, "_check_shape_invariants") self._check_shape_invariants(*args, **kwargs) return self.forward(*args, **kwargs) - def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs): + def _call_with_optional_nvtx_range( + self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs + ) -> Any: if self.layerwise_nvtx_tracing_enabled: args_list = list(args) kwargs_dict = dict(kwargs) @@ -108,7 +118,7 @@ def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs): return ctx.result return callable_fn(*args, **kwargs) - def __init__(self): + def __init__(self) -> None: self.compiled = False vllm_config = get_current_vllm_config() @@ -192,9 +202,9 @@ def __init__(self): if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE: torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) - self._compiled_bytecode = None + self._compiled_bytecode: CodeType | None = None - def aot_compile(self, *args, **kwargs): + def aot_compile(self, *args: Any, **kwargs: Any) -> Any: if not hasattr(self._compiled_callable, "aot_compile"): raise RuntimeError( "aot_compile is not supported by the current configuration. " @@ -203,7 +213,7 @@ def aot_compile(self, *args, **kwargs): ) return self._compiled_callable.aot_compile((args, kwargs)) - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: if envs.VLLM_USE_BYTECODE_HOOK: if ( self.vllm_config.compilation_config.mode @@ -236,13 +246,13 @@ def __call__(self, *args, **kwargs): ) @abstractmethod - def forward(self, *args, **kwargs): ... + def forward(self, *args: Any, **kwargs: Any) -> Any: ... def original_code_object(self) -> CodeType: """Return the original code object of the forward method.""" return self.__class__.forward.__code__ - def bytecode_hook(self, old_code: CodeType, new_code: CodeType): + def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None: """Hook to save the compiled bytecode for direct execution.""" if old_code is not self.original_code_object(): return @@ -299,7 +309,7 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): raise RuntimeError(msg) @contextmanager - def _dispatch_to_compiled_code(self): + def _dispatch_to_compiled_code(self) -> Generator[None, None, None]: # noqa: E501 """ Context manager to dispatch to internally compiled code for torch<2.8. diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c8a531c02fd7..ed94418b2316 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -32,6 +32,9 @@ logger = init_logger(__name__) +# Explicitly exports Range +__all__ = ["Range"] + class CompilationMode(enum.IntEnum): """The compilation approach used for torch.compile-based compilation of the diff --git a/vllm/distributed/device_communicators/pynccl_allocator.py b/vllm/distributed/device_communicators/pynccl_allocator.py index 2e5d94de9d01..0ce307bc596c 100644 --- a/vllm/distributed/device_communicators/pynccl_allocator.py +++ b/vllm/distributed/device_communicators/pynccl_allocator.py @@ -60,7 +60,7 @@ def is_symmetric_memory_tensor(tensor: torch.Tensor): return False -def set_graph_pool_id(graph_pool_id): +def set_graph_pool_id(graph_pool_id: Any) -> None: global _graph_pool_id _graph_pool_id = graph_pool_id