Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "silent"

[[tool.mypy.overrides]]
module = "vllm.compilation.*"
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note for reviewer: this is not `strict = true` because that would cause cascading strict checks of other modules (via `follow_imports`) when checked with pre-commit only

Instead we default to a few sensible options here

disallow_untyped_defs = true
disallow_incomplete_defs = true
warn_return_any = true
follow_imports = "silent"

[tool.pytest.ini_options]
markers = [
"slow_test",
Expand Down
2 changes: 1 addition & 1 deletion tests/compile/test_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_bad_callable():
pass_manager.configure(config)

with pytest.raises(AssertionError):
pass_manager.add(simple_callable)
pass_manager.add(simple_callable) # type: ignore[arg-type]


# Pass that inherits from InductorPass
Expand Down
40 changes: 37 additions & 3 deletions tools/pre_commit/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@
"vllm/v1/attention/ops",
]

# Directories whose files are type-checked with mypy's --strict flag
# (see is_strict_file).
STRICT_DIRS = ["vllm/compilation"]


def group_files(changed_files: list[str]) -> dict[str, list[str]]:
"""
Expand Down Expand Up @@ -107,11 +112,17 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
return file_groups


def is_strict_file(filepath: str) -> bool:
    """Return True when ``filepath`` lives under a directory in STRICT_DIRS."""
    for strict_dir in STRICT_DIRS:
        if filepath.startswith(strict_dir):
            return True
    return False


def mypy(
targets: list[str],
python_version: str | None,
follow_imports: str | None,
file_group: str,
strict: bool = False,
) -> int:
"""
Run mypy on the given targets.
Expand All @@ -123,6 +134,7 @@ def mypy(
follow_imports: Value for the --follow-imports option or None to use
the default mypy behavior.
file_group: The file group name for logging purposes.
strict: If True, run mypy with --strict flag.

Returns:
The return code from mypy.
Expand All @@ -132,6 +144,8 @@ def mypy(
args += ["--python-version", python_version]
if follow_imports is not None:
args += ["--follow-imports", follow_imports]
if strict:
args += ["--strict"]
print(f"$ {' '.join(args)} {file_group}")
return subprocess.run(args + targets, check=False).returncode

Expand All @@ -148,9 +162,29 @@ def main():
for file_group, changed_files in file_groups.items():
follow_imports = None if ci and file_group == "" else "skip"
if changed_files:
returncode |= mypy(
changed_files, python_version, follow_imports, file_group
)
# Separate files into strict and non-strict groups
strict_files = [f for f in changed_files if is_strict_file(f)]
non_strict_files = [f for f in changed_files if not is_strict_file(f)]

# Run mypy on non-strict files
if non_strict_files:
returncode |= mypy(
non_strict_files,
python_version,
follow_imports,
file_group,
strict=False,
)

# Run mypy on strict files with --strict flag
if strict_files:
returncode |= mypy(
strict_files,
python_version,
follow_imports,
f"{file_group} (strict)",
strict=True,
)
return returncode


Expand Down
2 changes: 1 addition & 1 deletion vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def make_copy_and_call(
A wrapper function that copies inputs and calls the compiled function
"""

def copy_and_call(*args):
def copy_and_call(*args: Any) -> Any:
list_args = list(args)
for i, index in enumerate(sym_tensor_indices):
runtime_tensor = list_args[index]
Expand Down
23 changes: 12 additions & 11 deletions vllm/compilation/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ class StandaloneCompiledArtifacts:
split on attn)
"""

def __init__(self) -> None:
    """Initialize empty stores for serialized and loaded submodule artifacts."""
    # Maps "<submodule>_<shape>" key to the hash of its serialized bytes.
    self.submodule_bytes: dict[str, str] = {}
    # Maps byte hash to the raw serialized bytes.
    self.submodule_bytes_store: dict[str, bytes] = {}
    # Maps byte hash to the deserialized (loaded) module.
    self.loaded_submodule_store: dict[str, Any] = {}

def insert(self, submod_name: str, shape: str, entry: bytes):
def insert(self, submod_name: str, shape: str, entry: bytes) -> None:
hasher = hashlib.sha256()
hasher.update(entry)
hex_digest = hasher.hexdigest()
Expand Down Expand Up @@ -86,7 +86,7 @@ def get(self, submod_name: str, shape: str) -> bytes:
self.submodule_bytes[f"{submod_name}_{shape}"]
]

def get_loaded(self, submod_name: str, shape: str):
def get_loaded(self, submod_name: str, shape: str) -> Any:
logger.debug(
"getting artifact for submod %s with shape %s",
submod_name,
Expand Down Expand Up @@ -119,7 +119,7 @@ def load_all(self) -> None:

from torch._inductor.standalone_compile import AOTCompiledArtifact

def _load_entry(entry_bytes) -> AOTCompiledArtifact:
def _load_entry(entry_bytes: bytes) -> AOTCompiledArtifact:
entry = pickle.loads(entry_bytes)
return AOTCompiledArtifact.deserialize(entry)

Expand All @@ -132,13 +132,13 @@ def _load_entry(entry_bytes) -> AOTCompiledArtifact:

logger.debug("loaded all %s submodules", self.num_artifacts())

def __getstate__(self) -> dict[str, dict[str, str] | dict[str, bytes]]:
    """Pickle only the serialized byte stores; loaded modules are rebuilt lazily."""
    state = {
        "submodule_bytes": self.submodule_bytes,
        "submodule_bytes_store": self.submodule_bytes_store,
    }
    return state

def __setstate__(self, state):
def __setstate__(self, state: dict[str, dict[str, Any]]) -> None:
self.submodule_bytes = state["submodule_bytes"]
self.submodule_bytes_store = state["submodule_bytes_store"]
self.loaded_submodule_store = {}
Expand Down Expand Up @@ -387,7 +387,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
standalone_compile_artifacts.load_all()

submod_names = standalone_compile_artifacts.submodule_names()
compiled_callables: dict[str, dict[str, Callable]] = {}
compiled_callables: dict[str, dict[str, Callable[..., Any]]] = {}

for cache_key in standalone_compile_artifacts.submodule_bytes:
submod_name, shape_str = cache_key.rsplit("_", 1)
Expand Down Expand Up @@ -495,9 +495,10 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
# e.g. exec(). We can't actually check these.
continue
hash_content.append(content)
return safe_hash(
result: str = safe_hash(
"\n".join(hash_content).encode(), usedforsecurity=False
).hexdigest()
return result


def _compute_code_hash(files: set[str]) -> str:
Expand Down
28 changes: 15 additions & 13 deletions vllm/compilation/collective_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,15 @@

FP8_DTYPE = current_platform.fp8_dtype()

# flashinfer support is optional: keep flashinfer_comm as None unless the
# package is importable AND exposes the fused allreduce kernel we rely on.
flashinfer_comm: ModuleType | None = None
if find_spec("flashinfer"):
    try:
        import flashinfer.comm as _flashinfer_comm

        has_fusion = hasattr(_flashinfer_comm, "trtllm_allreduce_fusion")
        if has_fusion:
            flashinfer_comm = _flashinfer_comm
    except ImportError:
        pass

logger = init_logger(__name__)

Expand Down Expand Up @@ -441,7 +437,7 @@ def is_applicable_for_range(self, compile_range: Range) -> bool:
):
return True
tp_size = get_tensor_model_parallel_world_size()
return compile_range.is_single_size() and compile_range.end % tp_size == 0
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how is this not already bool haha

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With `follow_imports = "skip"` mode it gets treated as `Any` by the type checker — I know it is a bit silly, but IMO the benefits of strict type checking outweigh the extra syntax


@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
Expand Down Expand Up @@ -516,7 +512,7 @@ def call_trtllm_fused_allreduce_norm(
# Get one shot input size limit for the current world size
# for the current device capability
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
device_capability, # type: ignore[arg-type]
device_capability, # type: ignore[arg-type, unused-ignore]
{},
).get(world_size, None)
# Use one shot if no max size is specified
Expand Down Expand Up @@ -666,6 +662,7 @@ def replacement(
) -> tuple[torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
rms_result = torch.empty_like(input)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
Expand Down Expand Up @@ -722,6 +719,7 @@ def pattern(
def replacement(
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
Expand Down Expand Up @@ -800,6 +798,7 @@ def replacement(
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
Expand Down Expand Up @@ -875,6 +874,7 @@ def replacement(
scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
Expand Down Expand Up @@ -960,6 +960,7 @@ def replacement(
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
residual = torch.zeros_like(input)
result_rms = torch.empty_like(input)
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
Expand Down Expand Up @@ -1055,6 +1056,7 @@ def replacement(
weight: torch.Tensor,
input_global_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert flashinfer_comm is not None, "FlashInfer must be enabled"
allreduce = auto_functionalized(
flashinfer_trtllm_fused_allreduce_norm,
allreduce_in=input,
Expand Down Expand Up @@ -1131,7 +1133,7 @@ def __init__(self, config: VllmConfig) -> None:
)

self.ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( # type: ignore[misc]
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank,
tp_size=self.tp_size,
max_token_num=self.max_token_num,
Expand Down Expand Up @@ -1204,7 +1206,7 @@ def is_applicable_for_range(self, compile_range: Range) -> bool:
if self.disabled:
logger.warning_once("AllReduce fusion pass is disabled.")
return False
return compile_range.end <= self.max_token_num
return bool(compile_range.end <= self.max_token_num)

@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
Expand Down
12 changes: 6 additions & 6 deletions vllm/compilation/compiler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:

def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
:10
]
hash_str: str = safe_hash(
str(factors).encode(), usedforsecurity=False
).hexdigest()[:10]
Comment on lines +204 to +206
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should just use the utils for hashing here that we use for other compile_hash functions (I think sha256)

Copy link
Copy Markdown
Contributor Author

@Lucaskabela Lucaskabela Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be the same util I believe (safe_hash) — we are just making the type checker happy by annotating it as `str` (since we call hexdigest()), and the reformatting is due to the change in line length

return hash_str

def initialize_cache(
Expand Down Expand Up @@ -319,9 +319,9 @@ class InductorAdaptor(CompilerInterface):

def compute_hash(self, vllm_config: VllmConfig) -> str:
    """Return a short (10-char) stable hash of the inductor factors."""
    factors = get_inductor_factors()
    digest = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
    hash_str: str = digest[:10]
    return hash_str

def initialize_cache(
Expand Down
Loading