# NOTE: excerpt of a merged PR diff adding 117 lines to vllm/env_override.py,
# appended after the existing _patch_fxgraphcache_pickle_if_needed() helper.


# Apply the FxGraphCache pickle patch once at module-import time
# (helper defined earlier in this file).
_patch_fxgraphcache_pickle_if_needed()

# ===================================================
# torch 2.11 Inductor cpp codegen indirect_assert scalar-mask fix
# ===================================================
# CppVecKernel.indirect_assert wraps a scalar mask with
# `VecMask<...>(scalar)`, which is not a valid constructor and triggers a
# C++ compile error during torch.compile of any model that does indirect
# indexing inside a tail-vectorized loop (e.g. Qwen3-VL).
# Failure looks like:
# no matching function for call to 'VecMask<int64_t,2>::VecMask(int&)'
# Upstream fix in PyTorch mainline replaces the call with
# `VecMask<...>::from(scalar)`, see pytorch/pytorch#178148 (lands in 2.12).
# This is a thin backport for torch >= 2.11 and < 2.12; remove once the
# minimum supported torch is 2.12.


def _apply_cpp_indirect_assert_patch():
    """Replace CppVecKernel.indirect_assert with a fixed copy that uses
    `VecMask<...>::from(scalar)` for scalar masks.

    Idempotent: marks the class with `_vllm_indirect_assert_patched` after
    the first apply.
    """
    # Import deferred to call time: torch._inductor.codegen.cpp must not be
    # pulled in during vllm.__init__ (see _patch_cpp_indirect_assert_if_needed).
    from torch._inductor.codegen.cpp import CppVecKernel

    # Already patched (or patched by a re-entrant call) -- nothing to do.
    if getattr(CppVecKernel, "_vllm_indirect_assert_patched", False):
        return

    from torch._inductor.codegen.cpp import CppCSEVariable, cexpr_index

    # Verbatim copy of upstream CppVecKernel.indirect_assert except for the
    # `::from(...)` wrapping of scalar masks (backport of
    # pytorch/pytorch#178148). Kept line-for-line to ease diffing against
    # upstream; do not restyle.
    def patched_indirect_assert(self, var, lower, upper, mask=None):
        assert isinstance(var, CppCSEVariable)
        assert var.dtype is not None
        if not var.is_vec:
            # Scalar index path: collapse a vector mask (if any) to a single
            # bool and delegate to the base-class (scalar) implementation.
            if isinstance(mask, CppCSEVariable) and mask.is_vec:
                mask = f"({mask}).all_masked()"
            return super(CppVecKernel, self).indirect_assert(var, lower, upper, mask)
        # Vectorized index path: keep the scalar bound expressions for the
        # human-readable assert message, broadcast them for the comparison.
        lower_scalar = lower
        upper_scalar = upper
        if lower:
            lower = f"{self._get_vec_type(var.dtype)}({lower})"
        if upper:
            upper = f"{self._get_vec_type(var.dtype)}({upper})"
        if lower and upper:
            cond = f"({lower} <= {var}) & ({var} < {upper})"
            cond_print = f"{lower_scalar} <= {var} < {upper_scalar}"
        elif lower:
            cond = f"{lower} <= {var}"
            cond_print = f"{lower_scalar} <= {var}"
        else:
            assert upper
            cond = f"{var} < {upper}"
            cond_print = f"{var} < {upper_scalar}"
        cond = f"{self._get_mask_type(var.dtype)}({cond})"
        if mask:
            if not mask.is_vec:
                # Backport of pytorch/pytorch#178148 -- use ::from for
                # scalar masks so g++ picks the correct overload.
                mask = f"{self._get_mask_type(var.dtype)}::from({mask})"
            # Lanes where the mask is off are exempt from the bounds check.
            cond = f"({cond}) | ~({mask})"
        if self.tail_size:
            # Tail-vectorized loop: force lanes beyond tail_size to pass.
            cond = (
                f"{self._get_mask_type(var.dtype)}::set("
                f"{self._get_mask_type(var.dtype)}::from(1)"
                f", ({cond}), {cexpr_index(self.tail_size)})"
            )
        cond = f"({cond}).all_masked()"
        return f'{self.assert_function}({cond}, "index out of bounds: {cond_print}")'

    CppVecKernel.indirect_assert = patched_indirect_assert
    CppVecKernel._vllm_indirect_assert_patched = True  # type: ignore[attr-defined]


def _patch_cpp_indirect_assert_if_needed():
    """Apply cpp codegen indirect_assert backport when on torch 2.11.x.

    Defers application until torch._inductor.codegen.cpp is naturally
    imported by Inductor. Importing it eagerly during vllm.__init__ pulls
    in torch._inductor.scheduler, whose top-level
    `import torch._inductor.async_compile` can fail with
    `ModuleNotFoundError: import of torch._inductor.async_compile halted;
    None in sys.modules` depending on the import order on the runner
    (observed in vLLM CPU CI).
    """
    # Only 2.11.x is affected: older torch lacks the bug, 2.12+ carries the
    # upstream fix (pytorch/pytorch#178148).
    if not is_torch_equal_or_newer("2.11.0") or is_torch_equal_or_newer("2.12.0.dev"):
        return

    import sys

    target_name = "torch._inductor.codegen.cpp"
    if target_name in sys.modules:
        # Module already loaded -- patch it in place, no hook needed.
        _apply_cpp_indirect_assert_patch()
        return

    import importlib.abc
    # importlib.util is a submodule: `import importlib.abc` does not
    # guarantee it is loaded, so import it explicitly for find_spec below.
    import importlib.util

    class _CppCodegenPatchFinder(importlib.abc.MetaPathFinder):
        # One-shot finder: patches the target module right after its first
        # natural import, then stays out of the way.
        def find_spec(self, fullname, path, target=None):
            if fullname != target_name:
                return None
            # Remove ourselves before delegating so the nested find_spec
            # call does not recurse back into this finder.
            sys.meta_path.remove(self)
            spec = importlib.util.find_spec(fullname)
            if spec is None or spec.loader is None:
                return None
            original_exec = spec.loader.exec_module

            def _exec_then_patch(module):
                # Run the real module body first, then swap in the fix.
                original_exec(module)
                _apply_cpp_indirect_assert_patch()

            spec.loader.exec_module = _exec_then_patch  # type: ignore[method-assign]
            return spec

    sys.meta_path.insert(0, _CppCodegenPatchFinder())


# Install the indirect_assert backport (directly, or via the import hook)
# once at module-import time.
_patch_cpp_indirect_assert_if_needed()