diff --git a/pyproject.toml b/pyproject.toml index a5d41c673866..629f58c132bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,13 @@ ignore_missing_imports = true check_untyped_defs = true follow_imports = "silent" +[[tool.mypy.overrides]] +module = "vllm.compilation.*" +disallow_untyped_defs = true +disallow_incomplete_defs = true +warn_return_any = true +follow_imports = "silent" + [tool.pytest.ini_options] markers = [ "slow_test", diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index 6ed77b0085f5..df8e5b69fc51 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -28,7 +28,7 @@ def test_bad_callable(): pass_manager.configure(config) with pytest.raises(AssertionError): - pass_manager.add(simple_callable) + pass_manager.add(simple_callable) # type: ignore[arg-type] # Pass that inherits from InductorPass diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index 48803930d7b5..5bbe01f958c9 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -76,6 +76,11 @@ "vllm/v1/attention/ops", ] +# Directories that should be checked with --strict +STRICT_DIRS = [ + "vllm/compilation", +] + def group_files(changed_files: list[str]) -> dict[str, list[str]]: """ @@ -107,11 +112,17 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]: return file_groups +def is_strict_file(filepath: str) -> bool: + """Check if a file should be checked with strict mode.""" + return any(filepath.startswith(strict_dir) for strict_dir in STRICT_DIRS) + + def mypy( targets: list[str], python_version: str | None, follow_imports: str | None, file_group: str, + strict: bool = False, ) -> int: """ Run mypy on the given targets. @@ -123,6 +134,7 @@ def mypy( follow_imports: Value for the --follow-imports option or None to use the default mypy behavior. file_group: The file group name for logging purposes. + strict: If True, run mypy with --strict flag. Returns: The return code from mypy. @@ -132,6 +144,8 @@ def mypy( args += ["--python-version", python_version] if follow_imports is not None: args += ["--follow-imports", follow_imports] + if strict: + args += ["--strict"] print(f"$ {' '.join(args)} {file_group}") return subprocess.run(args + targets, check=False).returncode @@ -148,9 +162,29 @@ def main(): for file_group, changed_files in file_groups.items(): follow_imports = None if ci and file_group == "" else "skip" if changed_files: - returncode |= mypy( - changed_files, python_version, follow_imports, file_group - ) + # Separate files into strict and non-strict groups + strict_files = [f for f in changed_files if is_strict_file(f)] + non_strict_files = [f for f in changed_files if not is_strict_file(f)] + + # Run mypy on non-strict files + if non_strict_files: + returncode |= mypy( + non_strict_files, + python_version, + follow_imports, + file_group, + strict=False, + ) + + # Run mypy on strict files with --strict flag + if strict_files: + returncode |= mypy( + strict_files, + python_version, + follow_imports, + f"{file_group} (strict)", + strict=True, + ) return returncode diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 4e797e42dfd2..454d81317ebd 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -68,7 +68,7 @@ def make_copy_and_call( A wrapper function that copies inputs and calls the compiled function """ - def copy_and_call(*args): + def copy_and_call(*args: Any) -> Any: list_args = list(args) for i, index in enumerate(sym_tensor_indices): runtime_tensor = list_args[index] diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index 6df23c6a4581..07f9db4190b9 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -43,15 +43,15 @@ class StandaloneCompiledArtifacts: split on attn) """ - def __init__(self): + def __init__(self) -> None: # dict from submodule name to byte hash - self.submodule_bytes = {} + self.submodule_bytes: dict[str, str] = {} # dict from byte hash to bytes - self.submodule_bytes_store = {} + self.submodule_bytes_store: dict[str, bytes] = {} # dict from byte hash to loaded module - self.loaded_submodule_store = {} + self.loaded_submodule_store: dict[str, Any] = {} - def insert(self, submod_name: str, shape: str, entry: bytes): + def insert(self, submod_name: str, shape: str, entry: bytes) -> None: hasher = hashlib.sha256() hasher.update(entry) hex_digest = hasher.hexdigest() @@ -86,7 +86,7 @@ def get(self, submod_name: str, shape: str) -> bytes: self.submodule_bytes[f"{submod_name}_{shape}"] ] - def get_loaded(self, submod_name: str, shape: str): + def get_loaded(self, submod_name: str, shape: str) -> Any: logger.debug( "getting artifact for submod %s with shape %s", submod_name, @@ -119,7 +119,7 @@ def load_all(self) -> None: from torch._inductor.standalone_compile import AOTCompiledArtifact - def _load_entry(entry_bytes) -> AOTCompiledArtifact: + def _load_entry(entry_bytes: bytes) -> AOTCompiledArtifact: entry = pickle.loads(entry_bytes) return AOTCompiledArtifact.deserialize(entry) @@ -132,13 +132,13 @@ def _load_entry(entry_bytes) -> AOTCompiledArtifact: logger.debug("loaded all %s submodules", self.num_artifacts()) - def __getstate__(self): + def __getstate__(self) -> dict[str, dict[str, str] | dict[str, bytes]]: return { "submodule_bytes": self.submodule_bytes, "submodule_bytes_store": self.submodule_bytes_store, } - def __setstate__(self, state): + def __setstate__(self, state: dict[str, dict[str, Any]]) -> None: self.submodule_bytes = state["submodule_bytes"] self.submodule_bytes_store = state["submodule_bytes_store"] self.loaded_submodule_store = {} @@ -387,7 +387,7 @@ def reconstruct_serializable_fn_from_mega_artifact( standalone_compile_artifacts.load_all() submod_names = standalone_compile_artifacts.submodule_names() - compiled_callables: dict[str, dict[str, Callable]] = {} + compiled_callables: dict[str, dict[str, Callable[..., Any]]] = {} for cache_key in standalone_compile_artifacts.submodule_bytes: submod_name, shape_str = cache_key.rsplit("_", 1) @@ -495,9 +495,10 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str: # e.g. exec(). We can't actually check these. continue hash_content.append(content) - return safe_hash( + result: str = safe_hash( "\n".join(hash_content).encode(), usedforsecurity=False ).hexdigest() + return result def _compute_code_hash(files: set[str]) -> str: diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 4200071310ac..b5d162209d2f 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -30,19 +30,15 @@ FP8_DTYPE = current_platform.fp8_dtype() +flashinfer_comm: ModuleType | None = None if find_spec("flashinfer"): try: - import flashinfer.comm as flashinfer_comm + import flashinfer.comm as _flashinfer_comm - flashinfer_comm: ModuleType | None = ( # type: ignore[no-redef] - flashinfer_comm - if hasattr(flashinfer_comm, "trtllm_allreduce_fusion") - else None - ) + if hasattr(_flashinfer_comm, "trtllm_allreduce_fusion"): + flashinfer_comm = _flashinfer_comm except ImportError: - flashinfer_comm = None # type: ignore[assignment] -else: - flashinfer_comm = None # type: ignore[assignment] + pass logger = init_logger(__name__) @@ -441,7 +437,7 @@ def is_applicable_for_range(self, compile_range: Range) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return compile_range.is_single_size() and compile_range.end % tp_size == 0 + return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph) -> None: @@ -516,7 +512,7 @@ def call_trtllm_fused_allreduce_norm( # Get one shot input size limit for the current world size # for the current device capability max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get( - device_capability, # type: ignore[arg-type] + device_capability, # type: ignore[arg-type, unused-ignore] {}, ).get(world_size, None) # Use one shot if no max size is specified @@ -666,6 +662,7 @@ def replacement( ) -> tuple[torch.Tensor, torch.Tensor]: residual = torch.zeros_like(input) rms_result = torch.empty_like(input) + assert flashinfer_comm is not None, "FlashInfer must be enabled" allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -722,6 +719,7 @@ def pattern( def replacement( residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: + assert flashinfer_comm is not None, "FlashInfer must be enabled" allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -800,6 +798,7 @@ def replacement( residual = torch.zeros_like(input) result_rms = torch.empty_like(input) result_quant = torch.empty_like(input, dtype=self.quant_dtype) + assert flashinfer_comm is not None, "FlashInfer must be enabled" allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -875,6 +874,7 @@ def replacement( scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: result_quant = torch.empty_like(input, dtype=self.quant_dtype) + assert flashinfer_comm is not None, "FlashInfer must be enabled" allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -960,6 +960,7 @@ def replacement( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: residual = torch.zeros_like(input) result_rms = torch.empty_like(input) + assert flashinfer_comm is not None, "FlashInfer must be enabled" allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -1055,6 +1056,7 @@ def replacement( weight: torch.Tensor, input_global_scale: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert flashinfer_comm is not None, "FlashInfer must be enabled" allreduce = auto_functionalized( flashinfer_trtllm_fused_allreduce_norm, allreduce_in=input, @@ -1131,7 +1133,7 @@ def __init__(self, config: VllmConfig) -> None: ) self.ipc_handles, workspace_tensor = ( - flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( # type: ignore[misc] + flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( tp_rank=rank, tp_size=self.tp_size, max_token_num=self.max_token_num, @@ -1204,7 +1206,7 @@ def is_applicable_for_range(self, compile_range: Range) -> bool: if self.disabled: logger.warning_once("AllReduce fusion pass is disabled.") return False - return compile_range.end <= self.max_token_num + return bool(compile_range.end <= self.max_token_num) @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph) -> None: diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 3c7e44da8beb..8f7c1cfcd072 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -201,9 +201,9 @@ def __init__(self, save_format: Literal["binary", "unpacked"]) -> None: def compute_hash(self, vllm_config: VllmConfig) -> str: factors = get_inductor_factors() - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[ - :10 - ] + hash_str: str = safe_hash( + str(factors).encode(), usedforsecurity=False + ).hexdigest()[:10] return hash_str def initialize_cache( @@ -319,9 +319,9 @@ class InductorAdaptor(CompilerInterface): def compute_hash(self, vllm_config: VllmConfig) -> str: factors = get_inductor_factors() - hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[ - :10 - ] + hash_str: str = safe_hash( + str(factors).encode(), usedforsecurity=False + ).hexdigest()[:10] return hash_str def initialize_cache( diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 0699bd67c016..3df5ea6d636d 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -45,10 +45,10 @@ IGNORE_COMPILE_KEY = "_ignore_compile_vllm" -_T = TypeVar("_T", bound=type[nn.Module]) +_T = TypeVar("_T", bound=nn.Module) -def ignore_torch_compile(cls: _T) -> _T: +def ignore_torch_compile(cls: type[_T]) -> type[_T]: """ A decorator to ignore support_torch_compile decorator on the class. This is useful when a parent class has @@ -68,7 +68,7 @@ def ignore_torch_compile(cls: _T) -> _T: return cls -def _should_ignore_torch_compile(cls: _T) -> bool: +def _should_ignore_torch_compile(cls: type[_T]) -> bool: """ Check if the class should be ignored for torch.compile. """ @@ -79,21 +79,21 @@ def _should_ignore_torch_compile(cls: _T) -> bool: def support_torch_compile( *, enable_if: Callable[[VllmConfig], bool] | None = None, -) -> Callable[[_T], _T]: ... +) -> Callable[[type[_T]], type[_T]]: ... @overload def support_torch_compile( *, dynamic_arg_dims: dict[str, int | list[int]] | None, -) -> Callable[[_T], _T]: ... +) -> Callable[[type[_T]], type[_T]]: ... @overload def support_torch_compile( *, mark_unbacked_dims: dict[str, int | list[int]] | None, -) -> Callable[[_T], _T]: ... +) -> Callable[[type[_T]], type[_T]]: ... @overload @@ -101,21 +101,21 @@ def support_torch_compile( *, dynamic_arg_dims: dict[str, int | list[int]] | None, mark_unbacked_dims: dict[str, int | list[int]] | None, -) -> Callable[[_T], _T]: ... +) -> Callable[[type[_T]], type[_T]]: ... @overload -def support_torch_compile(cls: _T) -> _T: ... +def support_torch_compile(cls: type[_T]) -> type[_T]: ... def support_torch_compile( - cls: _T | None = None, + cls: type[_T] | None = None, *, dynamic_arg_dims: dict[str, int | list[int]] | None = None, mark_unbacked_dims: dict[str, int | list[int]] | None = None, enable_if: Callable[[VllmConfig], bool] | None = None, shape_invariants: Callable[..., None] = lambda *args, **kwargs: None, -) -> Callable[[_T], _T] | _T: +) -> Callable[[type[_T]], type[_T]] | type[_T]: """ A decorator to add support for compiling the forward method of a class. @@ -182,7 +182,7 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ... errors. """ - def cls_decorator_helper(cls: _T) -> _T: + def cls_decorator_helper(cls: type[_T]) -> type[_T]: # helper to pass `dynamic_arg_dims` to `_support_torch_compile` # to avoid too much indentation for `_support_torch_compile` if not hasattr(cls, "forward"): @@ -263,12 +263,12 @@ def _verify_source_unchanged( def _support_torch_compile( - cls: _T, + cls: type[_T], dynamic_arg_dims: dict[str, int | list[int]], mark_unbacked_dims: dict[str, int | list[int]] | None = None, enable_if: Callable[[VllmConfig], bool] | None = None, shape_invariants: Callable[..., None] = lambda *args, **kwargs: None, -) -> _T: +) -> type[_T]: """ A decorator to add support for compiling the forward method of a class. """ @@ -325,12 +325,12 @@ def __init__( self.compiled = False # Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class - TorchCompileWithNoGuardsWrapper.__init__(self) # type: ignore[arg-type] + TorchCompileWithNoGuardsWrapper.__init__(self) cls.__init__ = __init__ def _mark_dynamic_inputs( - mod: _T, ds_type: DynamicShapesType, *args: Any, **kwargs: Any + mod: type[_T], ds_type: DynamicShapesType, *args: Any, **kwargs: Any ) -> None: def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None: if ds_type == DynamicShapesType.UNBACKED: @@ -382,7 +382,7 @@ def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None: else: torch._dynamo.decorators.mark_unbacked(arg, dims) - def __call__(self: _T, *args: Any, **kwargs: Any) -> Any: + def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any: # torch.compiler.is_compiling() means we are inside the compilation # e.g. TPU has the compilation logic in model runner, so we don't # need to compile the model inside. @@ -564,7 +564,7 @@ def patched_inline_call(self_: Any) -> Any: return output # triggers VllmSerializableFunction.serialize() - def save_aot_compiled_function(self): + def save_aot_compiled_function(self: type[_T]) -> None: if self.was_aot_compile_fn_loaded_from_disk: logger.debug("AOT compiled function was loaded from cache, skipping save") return diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py index 7bb98db5e9f6..8fd7a46173e9 100644 --- a/vllm/compilation/matcher_utils.py +++ b/vllm/compilation/matcher_utils.py @@ -141,15 +141,18 @@ def forward_native( key: torch.Tensor | None, cos_sin_cache: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor | None]: - return RotaryEmbedding.forward_static( - positions, - query, - key, - self.head_size, - self.rotary_dim, - cos_sin_cache, - self.is_neox, + result: tuple[torch.Tensor, torch.Tensor | None] = ( + RotaryEmbedding.forward_static( + positions, + query, + key, + self.head_size, + self.rotary_dim, + cos_sin_cache, + self.is_neox, + ) ) + return result class MatcherRMSNorm(MatcherCustomOp): @@ -275,9 +278,10 @@ def forward_native( weight: torch.Tensor, residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: - return RMSNorm.forward_static( + result: tuple[torch.Tensor, torch.Tensor] = RMSNorm.forward_static( input, self.epsilon, input.size(-1), self.model_dtype, weight, residual ) + return result class MatcherQuantFP8(MatcherCustomOp): diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 83d1a73bdb92..e7595e3063db 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -25,7 +25,7 @@ class RangeEntry: compile_range: Range compiled: bool = False - runnable: Callable = None # type: ignore + runnable: Callable[..., Any] = None # type: ignore class PiecewiseBackend: @@ -38,7 +38,7 @@ def __init__( sym_shape_indices: list[int], vllm_backend: VllmBackend, returns_tuple: bool, - compiled_runnables: dict[str, Callable] | None = None, + compiled_runnables: dict[str, Callable[..., Any]] | None = None, ): """ The backend for piecewise compilation. @@ -138,8 +138,10 @@ def __init__( self.on_compilation_complete = _on_compilation_complete_callback.get() - def get_compiled_graph_wrapper(self, compiled_graph): - def compiled_graph_wrapper(*args): + def get_compiled_graph_wrapper( + self, compiled_graph: Callable[..., Any] + ) -> Callable[..., Any]: + def compiled_graph_wrapper(*args: Any) -> Any: graph_output = compiled_graph(*args) # unpack the tuple if needed # TODO(rzou): the implication is that we're not @@ -163,7 +165,7 @@ def check_for_ending_compilation(self) -> None: def to_bytes(self) -> dict[str, bytes]: class StandaloneCompiledArtifactsPickler(Pickler): - def reducer_override(self, obj): + def reducer_override(self, obj: object) -> Any: if isinstance(obj, CachingAutotuner): obj.prepare_for_pickle() return pickle.loads, ( @@ -173,7 +175,7 @@ def reducer_override(self, obj): ) return NotImplemented - def serialize(fn) -> bytes: + def serialize(fn: Callable[..., Any]) -> bytes: assert hasattr(fn, "serialize"), "fn must have serialize method" with torch._functorch.config.patch("bundled_autograd_cache", True): entry = fn.serialize() diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index b35c192dfd23..dda653c5f788 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -358,7 +358,10 @@ def is_applicable_for_range(self, compile_range: Range) -> bool: ): return True tp_size = get_tensor_model_parallel_world_size() - return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0) + result: bool = (compile_range.is_single_size()) and ( + compile_range.end % tp_size == 0 + ) + return result @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph) -> None: