Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions vllm/env_override.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

import torch
from packaging import version

from vllm.logger import init_logger

Expand All @@ -21,3 +22,120 @@
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# see https://github.com/vllm-project/vllm/issues/10619
torch._inductor.config.compile_threads = 1


# ========================================
# torch 2.9 Inductor Scheduler monkeypatch
# ========================================
# This change monkeypatches a function in Inductor to work around the following
# bug: https://github.com/vllm-project/vllm/issues/26678
#
# The bug occurs when `use_inductor_graph_partition` is turned on and there
# exists operators inside of `splitting_ops` that have an in-place mutation. In
# vllm, this specifically occurs on the operator
# vllm.unified_attention_with_output. In this case, inductor does not populate
# the inductor IR's `origin_node` field, causing an assertion error when trying
# to access the node's `origin_node` field.
#
# So, we will monkeypatch torch._inductor.scheduler.Scheduler.should_partition
# so that it does not access the inductor IR node's `origin_node` field and just
# returns True if a node is registered as having a custom partition function.
# This is ok for now since vllm's implementation of the custom partition
# functions just return True.
# ========================================


def should_partition_patched(self, node, should_log: bool = False) -> bool:
    # This is a patched version of
    # torch._inductor.scheduler.Scheduler.should_partition that modifies
    # the following piece of code so that we always return True:
    # https://github.com/pytorch/pytorch/blob/ecb53078faf86ca1b33277df33b82985675bb011/torch/_inductor/scheduler.py#L4712-L4724
    #
    # The only intended behavioral difference from upstream is in the
    # FallbackKernel branch below: upstream consults the node's `origin_node`
    # (which is unpopulated for in-place-mutating custom ops, triggering an
    # assertion — see https://github.com/vllm-project/vllm/issues/26678),
    # while this version unconditionally returns True for any op registered
    # in `_custom_should_partition_fns`. Everything else is kept verbatim so
    # future diffs against upstream stay easy.
    """Return True if we should partition the inductor graph on this node"""

    # Imports are deferred (function-scope) because this module is imported
    # very early and these are private torch internals.
    import torch._inductor.ir as ir
    from torch._inductor.scheduler import (
        BaseSchedulerNode,
        FusedSchedulerNode,
        _custom_should_partition_fns,
    )
    from torch._inductor.utils import (
        _unstable_customized_partition_wrapper,
        is_cudagraph_unsafe_op,
        maybe_log_cudagraph_partition,
    )

    # Allow users to manually specify if a node should be partitioned
    # Can only do this for FallbackKernels
    # NOTE: this is the patched branch — upstream would look at
    # ir_node.get_origin_node() here; we skip that and return True directly.
    ir_node = node.node
    if isinstance(ir_node, ir.FallbackKernel):
        operator = ir_node.op_overload
        if operator is not None and operator in _custom_should_partition_fns:
            return True

    # When not using cudagraphs, keep all kernels in the `call` function
    # instead of graph partition functions, since graph partition only brings
    # benefit to cudagraph
    if (
        not torch._inductor.config.triton.cudagraphs
        and _unstable_customized_partition_wrapper.wrapper is None
    ):
        return True

    # avoid duplicating logs when should_partition is called multiple times
    # on the same node
    def noop_log(msg: str, node: BaseSchedulerNode | None) -> None:
        return

    log_partition_reason = maybe_log_cudagraph_partition if should_log else noop_log

    # A fused node is partitioned if any of its constituent nodes is.
    if isinstance(node, FusedSchedulerNode):
        return any(self.should_partition(snode) for snode in node.snodes)

    assert node.node is not None

    # The remaining checks mirror upstream: partition around anything that
    # cannot safely live inside a CUDA graph.
    if not node.is_gpu():
        log_partition_reason("non gpu ops", node=node)

        return True

    if isinstance(node.node, ir.DeviceCopy):
        log_partition_reason("DeviceCopy ops", node=node)
        return True

    if isinstance(node.node, ir.Conditional):
        log_partition_reason("Conditional ops", node=node)
        return True

    if getattr(node.node, "unbacked_bindings", None):
        log_partition_reason("unbacked binding ops", node=node)
        return True

    if is_cudagraph_unsafe_op(node.node):
        log_partition_reason("CUDAGraph-unsafe custom ops", node=node)
        return True

    return False


def _update_scheduler_patched(self) -> None:
    # Copied from torch._inductor.graph.GraphLowering._update_scheduler. Patches
    # this method so that we can patch Scheduler.should_partition with the
    # function above
    """
    (Re)initializes the scheduler member. When initializing the scheduler, no CUBIN
    files should be generated (to avoid biasing any benchmarks and pessimizing
    fusion decisions).
    """
    # Deferred imports: torch internals, only needed when inductor actually
    # builds a scheduler.
    import torch._inductor.config as config
    from torch._inductor.scheduler import Scheduler

    # Install the workaround just before the Scheduler is constructed, so it
    # applies regardless of how/when torch imported the scheduler module.
    Scheduler.should_partition = should_partition_patched

    with config.patch("triton.store_cubin", False):
        self.scheduler = Scheduler(self.operations)


# Apply the workaround only on torch 2.9.0, the release affected by the bug.
#
# NOTE: we must compare `base_version` rather than the full version. Installed
# torch builds typically carry a PEP 440 *local* version segment (e.g.
# "2.9.0+cu128"), and `Version("2.9.0+cu128") == Version("2.9.0")` is False,
# which would silently skip the patch on every real-world build.
if version.parse(str(torch.__version__)).base_version == "2.9.0":
    from torch._inductor.graph import GraphLowering

    GraphLowering._update_scheduler = _update_scheduler_patched
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it work if we just use Scheduler.should_partition = should_partition_patched instead of _update_scheduler_patched?