8 changes: 7 additions & 1 deletion vllm/config/vllm.py
@@ -883,7 +883,13 @@ def has_blocked_weights():
self.compilation_config.pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False

-if self.compilation_config.fast_moe_cold_start is None:
+from vllm.utils.torch_utils import HAS_OPAQUE_TYPE
+
+if HAS_OPAQUE_TYPE:
+    # On torch >= 2.11 the hoisted OpaqueObject approach supersedes
+    # fast_moe_cold_start, so force it off.
+    self.compilation_config.fast_moe_cold_start = False
+elif self.compilation_config.fast_moe_cold_start is None:
# resolve default behavior: try to be as safe as possible
# this config is unsafe if any spec decoding draft model has a MOE.
# We'll conservatively turn it off if we see spec decoding.
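The rest of the default-resolution branch is elided above. A condensed, hypothetical sketch of the overall resolution order this hunk implements — the function name is invented, and the spec-decoding default is paraphrased from the comments rather than copied from the elided code:

    def resolve_fast_moe_cold_start(user_value, has_opaque_type, spec_decoding_enabled):
        if has_opaque_type:        # torch >= 2.11: opaque hoisting supersedes the flag
            return False
        if user_value is None:     # default: conservatively off whenever spec decoding is on
            return not spec_decoding_enabled
        return user_value          # an explicit user setting is respected otherwise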
41 changes: 41 additions & 0 deletions vllm/env_override.py
@@ -482,3 +482,44 @@ def _patch_get_raw_stream_if_needed():

PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
GraphLowering._update_scheduler = _update_scheduler_patched

# ===================================================
# torch 2.11 Inductor constrain_to_fx_strides monkeypatch
# ===================================================
# Patch the inductor's `constrain_to_fx_strides` to handle opaque
# (non-tensor) arguments. The original calls `.stride()` on every FX
# arg's meta value, which crashes on FakeScriptObject (the compile-time
# proxy for hoisted opaque types). The patched version skips args
# whose meta value is not a torch.Tensor.
# Upstream issue: https://github.com/pytorch/pytorch/issues/175973

from vllm.utils.torch_utils import is_torch_equal_or_newer

if is_torch_equal_or_newer("2.11.0.dev"):
import torch._inductor.ir as _ir
import torch._inductor.lowering as _lowering
from torch._inductor.virtualized import V as _V

_orig_constrain = _lowering.constrain_to_fx_strides

def _patched_constrain_to_fx_strides(fx_node, *args, **kwargs):
def apply_constraint(arg, fx_arg):
if isinstance(arg, _ir.IRNode):
meta_val = fx_arg.meta.get("val")
if isinstance(meta_val, torch.Tensor):
stride_order = _ir.get_stride_order(
meta_val.stride(), _V.graph.sizevars.shape_env
)
return _ir.ExternKernel.require_stride_order(arg, stride_order)
return arg
if isinstance(arg, dict):
return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg}
return arg

args = tuple(
apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
)
kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
return args, kwargs

_lowering.constrain_to_fx_strides = _patched_constrain_to_fx_strides
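For context, a minimal self-contained sketch of the failure mode this patch guards against — `OpaqueProxy` is a hypothetical stand-in for torch's FakeScriptObject, not the real class:

    import torch

    class OpaqueProxy:
        """Stand-in for a compile-time opaque proxy: it has no .stride()."""

    meta_vals = [torch.empty(2, 3), OpaqueProxy()]

    # Unguarded, like the original helper: crashes on the opaque proxy.
    try:
        [v.stride() for v in meta_vals]
    except AttributeError as exc:
        print(f"unguarded crash: {exc}")

    # Guarded, mirroring the patch: only tensor meta values get stride constraints.
    print([v.stride() for v in meta_vals if isinstance(v, torch.Tensor)])  # [(3, 1)]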
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import nullcontext
+from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F
@@ -30,6 +31,8 @@
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import (
+HAS_OPAQUE_TYPE,
+ModuleName,
aux_stream,
current_stream,
direct_register_custom_op,
@@ -56,13 +59,27 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module:
return forward_context.no_compile_layers[layer_name]


# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object;
# on older versions it remains a plain str.
if TYPE_CHECKING:
from typing import TypeAlias

_layer_name_type: TypeAlias = str | ModuleName
else:
_layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str


def _resolve_layer_name(layer_name: str | ModuleName) -> str:
return layer_name.value if isinstance(layer_name, ModuleName) else layer_name


def _moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
-layer_name: str,
+layer_name: _layer_name_type,
) -> torch.Tensor:
-layer = get_layer_from_name(layer_name)
+layer = get_layer_from_name(_resolve_layer_name(layer_name))
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
return layer.runner.forward_impl(
@@ -74,7 +91,7 @@ def _moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
-layer_name: str,
+layer_name: _layer_name_type,
) -> torch.Tensor:
return torch.empty_like(hidden_states)

@@ -83,9 +100,9 @@ def _moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
-layer_name: str,
+layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
-layer = get_layer_from_name(layer_name)
+layer = get_layer_from_name(_resolve_layer_name(layer_name))
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
return layer.runner.forward_impl(
@@ -97,20 +114,18 @@ def _moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
-layer_name: str,
+layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
# Output shapes:
# - fused_out: same as hidden_states (routed experts use transformed size)
# - shared_out: same as shared_experts_input if provided, else same as
# hidden_states
# (For latent MoE: shared experts use original hidden_size, not latent size)
fused_out = torch.empty_like(hidden_states)

if shared_experts_input is not None:
shared_out = torch.empty_like(shared_experts_input)
else:
shared_out = torch.empty_like(hidden_states)

return shared_out, fused_out


@@ -367,7 +382,9 @@ def reduce_and_trunc(x: torch.Tensor, trunc_size: int) -> torch.Tensor:
assert len(trunc_sizes) == 1
return func(states, trunc_sizes[0])

-def _encode_layer_name(self) -> str:
+def _encode_layer_name(self) -> str | ModuleName:
+    if HAS_OPAQUE_TYPE:
+        return ModuleName(self.layer_name)
# Can be unavailable or None in unittests
if (
is_forward_context_available()
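A hedged usage sketch of the encode/resolve round trip — the layer name is made up, `resolve_layer_name` is a standalone copy of the `_resolve_layer_name` logic above, and the encoding line mirrors only the new torch >= 2.11 branch of `_encode_layer_name`:

    from vllm.utils.torch_utils import HAS_OPAQUE_TYPE, ModuleName

    def resolve_layer_name(layer_name):
        # Unwrap on torch >= 2.11, pass a plain str through otherwise.
        return layer_name.value if isinstance(layer_name, ModuleName) else layer_name

    name = "model.layers.0.mlp.experts"
    encoded = ModuleName(name) if HAS_OPAQUE_TYPE else name
    assert resolve_layer_name(encoded) == name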
35 changes: 35 additions & 0 deletions vllm/utils/torch_utils.py
@@ -740,6 +740,41 @@ def is_torch_equal(target: str) -> bool:
return Version(importlib.metadata.version("torch")) == Version(target)


HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.11.0.dev")

if HAS_OPAQUE_TYPE:
from torch._opaque_base import OpaqueBase
else:
OpaqueBase = object # type: ignore[misc, assignment]


class ModuleName(OpaqueBase): # type: ignore[misc]
"""Wraps a module name string for use as a torch opaque type.

When torch >= 2.11, this is registered as a hoisted value-type opaque
object so that torch.compile lifts it as a graph input instead of baking
it as a constant. This avoids per-layer recompilation for MOE ops.
"""

def __init__(self, value: str):
self.value = value

def __eq__(self, other):
return isinstance(other, ModuleName) and self.value == other.value

def __hash__(self):
return hash(self.value)

def __fx_repr__(self):
return (f"ModuleName({self.value!r})", {ModuleName})


if HAS_OPAQUE_TYPE:
from torch._library.opaque_object import register_opaque_type

register_opaque_type(ModuleName, typ="value", hoist=True)
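A quick sanity sketch of the value semantics this relies on (illustrative only, not part of the patch): two ModuleName instances wrapping the same string compare equal and hash identically, which is what lets compiled graphs treat them as interchangeable inputs rather than distinct baked-in constants.

    a = ModuleName("model.layers.0.mlp")
    b = ModuleName("model.layers.0.mlp")
    assert a == b and hash(a) == hash(b)
    assert a != ModuleName("model.layers.1.mlp")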


# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
def supports_xccl() -> bool:
return torch.distributed.is_xccl_available()