28 changes: 16 additions & 12 deletions vllm/compilation/backends.py
@@ -23,6 +23,7 @@
from torch.fx._lazy_graph_module import _use_lazy_graph_module

import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.codegen import (
compile_execution_fn,
generate_execution_code,
@@ -929,20 +930,23 @@ def collect_standalone_compile_artifacts(
def configure_post_pass(self) -> None:
# TODO proper PassManager?
pre_grad_pass_key = "pre_grad_custom_pass"
assert self.pass_key != pre_grad_pass_key
assert pre_grad_pass_key not in self.inductor_config
self.inductor_config[pre_grad_pass_key] = VllmIRInplaceFunctionalizationPass(
self.vllm_config
)
# Keep the regular pre-grad pass pipeline for other backends. ROCm
# AITER skips this pass to avoid HIP graph replay corruption.
if not rocm_aiter_ops.is_enabled():
assert self.pass_key != pre_grad_pass_key
assert pre_grad_pass_key not in self.inductor_config
self.inductor_config[pre_grad_pass_key] = VllmIRInplaceFunctionalizationPass(
self.vllm_config
)

# Make sure pre_grad_custom_pass is not pickled
# as part of AOTAutograd built-in cache key
# TODO(luka) is there a cleaner way to do this
import torch._inductor.config as inductor_config
# Make sure pre_grad_custom_pass is not pickled
# as part of AOTAutograd built-in cache key
# TODO(luka) is there a cleaner way to do this
import torch._inductor.config as inductor_config

ignore = inductor_config._cache_config_ignore_prefix + [pre_grad_pass_key]
assert "_cache_config_ignore_prefix" not in self.inductor_config
self.inductor_config["_cache_config_ignore_prefix"] = ignore
ignore = inductor_config._cache_config_ignore_prefix + [pre_grad_pass_key]
assert "_cache_config_ignore_prefix" not in self.inductor_config
self.inductor_config["_cache_config_ignore_prefix"] = ignore

# Configure the (nominally post-grad) pass manager
self.pass_manager.configure(self.vllm_config)
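
Note (not part of the diff): the net effect of the hunk above is that Inductor's pre_grad_custom_pass hook is only registered when AITER is disabled. A minimal, self-contained sketch of that gating pattern follows; is_aiter_enabled and pre_grad_pass are made-up stand-ins, not vLLM objects.

from typing import Any, Callable

import torch.fx as fx


def build_inductor_config_patches(
    is_aiter_enabled: bool, pre_grad_pass: Callable[[fx.Graph], None]
) -> dict[str, Any]:
    # Config patches handed to Inductor; the custom pre-grad pass is only
    # attached for backends that can tolerate it.
    patches: dict[str, Any] = {}
    if not is_aiter_enabled:
        patches["pre_grad_custom_pass"] = pre_grad_pass

        # Keep the pass callable out of the cache key, mirroring the
        # _cache_config_ignore_prefix handling in the hunk above.
        import torch._inductor.config as inductor_config

        patches["_cache_config_ignore_prefix"] = (
            inductor_config._cache_config_ignore_prefix + ["pre_grad_custom_pass"]
        )
    return patches
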
24 changes: 13 additions & 11 deletions vllm/compilation/passes/pass_manager.py
@@ -18,10 +18,10 @@
from .ir.lowering_pass import VllmIRLoweringPass
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

if rocm_aiter_ops.is_enabled() or current_platform.is_cuda():
from .fusion.allreduce_rms_fusion import AllReduceFusionPass

if rocm_aiter_ops.is_enabled():
from .fusion.allreduce_rms_fusion import (
RocmAiterAllReduceFusionPass,
)
from .fusion.rocm_aiter_fusion import (
MLADualRMSNormFusionPass,
RocmAiterRMSNormQuantFusionPass,
@@ -41,7 +41,6 @@
from .utility.split_coalescing import SplitCoalescingPass

if current_platform.is_cuda():
from .fusion.allreduce_rms_fusion import AllReduceFusionPass
from .fusion.collective_fusion import AsyncTPPass
from .fusion.minimax_qk_norm_fusion import MiniMaxQKNormPass

@@ -116,8 +115,11 @@ def __call__(self, graph: fx.Graph) -> None:
# DCE handles mutating ops correctly as well.
self.ir_lowering(graph)
VllmInductorPass.dump_prefix += 1
self.clone_elimination(graph)
VllmInductorPass.dump_prefix += 1
# ROCm AITER relies on HIP graph replay; this unsafe pass can alter
# aliases in ways that corrupt replayed decode graphs.
if not rocm_aiter_ops.is_enabled():
self.clone_elimination(graph)
VllmInductorPass.dump_prefix += 1

# clean up after lowering again
self.post_cleanup(graph)
@@ -143,10 +145,9 @@ def configure(self, config: VllmConfig) -> None:
self.passes += [AsyncTPPass(config)]

if self.pass_config.fuse_allreduce_rms:
if rocm_aiter_ops.is_enabled():
self.passes += [RocmAiterAllReduceFusionPass(config)]
else:
self.passes += [AllReduceFusionPass(config)]
# The ROCm AITER allreduce fusion path can corrupt HIP graph
# replay; use the standard pass when AITER is enabled.
self.passes += [AllReduceFusionPass(config)]

if self.pass_config.fuse_minimax_qk_norm:
self.passes += [MiniMaxQKNormPass(config)]
@@ -205,7 +206,8 @@ def uuid(self) -> str:

passes.append(self.post_cleanup.uuid())
passes.append(self.ir_lowering.uuid())
passes.append(self.clone_elimination.uuid())
if not rocm_aiter_ops.is_enabled():
passes.append(self.clone_elimination.uuid())
passes.append(self.post_cleanup.uuid())
passes.append(self.fix_functionalization.uuid())

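Note (not part of the diff): the aliasing comment above is the heart of why clone_elimination is skipped. Once a defensive clone is removed, an in-place op writes straight into a buffer that a replayed graph may still be reading. A small, self-contained illustration in plain PyTorch, not vLLM code:

import torch

# A persistent buffer of the kind reused across graph replays.
buf = torch.zeros(4)


def step_with_clone(x: torch.Tensor) -> torch.Tensor:
    y = x.clone()  # y owns fresh storage
    y.add_(1)      # the in-place update touches only the copy
    return y


def step_without_clone(x: torch.Tensor) -> torch.Tensor:
    y = x          # y aliases the persistent buffer
    y.add_(1)      # the in-place update now mutates the buffer itself
    return y


print(step_with_clone(buf), buf)     # buf stays all zeros
print(step_without_clone(buf), buf)  # buf has been silently modified
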
10 changes: 1 addition & 9 deletions vllm/distributed/parallel_state.py
@@ -472,7 +472,6 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None
# only cuda uses this function,
# so we don't abstract it into the base class
maybe_ca_context = nullcontext()
maybe_aiter_context = nullcontext()
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator,
)
@@ -483,20 +482,13 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None
if ca_comm is not None:
maybe_ca_context = ca_comm.capture() # type: ignore

from vllm._aiter_ops import rocm_aiter_ops

if rocm_aiter_ops.is_enabled():
aiter_ar = rocm_aiter_ops.get_aiter_allreduce()
if aiter_ar is not None:
maybe_aiter_context = aiter_ar.capture() # type: ignore

# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch.cuda.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)

with torch.cuda.stream(stream), maybe_ca_context, maybe_aiter_context:
with torch.cuda.stream(stream), maybe_ca_context:
yield graph_capture_context

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
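
Note (not part of the diff): with the AITER allreduce context dropped, graph_capture only conditionally enters the custom-allreduce capture context. A minimal sketch of that optional-context pattern, where ca_comm and its capture() context manager stand in for the vLLM communicator (this is not the actual implementation):

from contextlib import contextmanager, nullcontext

import torch


@contextmanager
def graph_capture(stream: torch.cuda.Stream, ca_comm=None):
    # Enter the custom-allreduce capture context only when a communicator
    # with a capture() context manager is available.
    maybe_ca_context = ca_comm.capture() if ca_comm is not None else nullcontext()

    # Ensure work queued on the current stream finishes before the capture
    # stream starts recording.
    curr_stream = torch.cuda.current_stream()
    if curr_stream != stream:
        stream.wait_stream(curr_stream)

    with torch.cuda.stream(stream), maybe_ca_context:
        yield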