diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 5a67415f1030..cb8b54311a0b 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -23,6 +23,7 @@
 from torch.fx._lazy_graph_module import _use_lazy_graph_module
 
 import vllm.envs as envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.compilation.codegen import (
     compile_execution_fn,
     generate_execution_code,
@@ -929,20 +930,23 @@ def collect_standalone_compile_artifacts(
     def configure_post_pass(self) -> None:
         # TODO proper PassManager?
         pre_grad_pass_key = "pre_grad_custom_pass"
-        assert self.pass_key != pre_grad_pass_key
-        assert pre_grad_pass_key not in self.inductor_config
-        self.inductor_config[pre_grad_pass_key] = VllmIRInplaceFunctionalizationPass(
-            self.vllm_config
-        )
+        # Keep the regular pre-grad pass pipeline for other backends. ROCm
+        # AITER skips this pass to avoid HIP graph replay corruption.
+        if not rocm_aiter_ops.is_enabled():
+            assert self.pass_key != pre_grad_pass_key
+            assert pre_grad_pass_key not in self.inductor_config
+            self.inductor_config[pre_grad_pass_key] = VllmIRInplaceFunctionalizationPass(
+                self.vllm_config
+            )
 
-        # Make sure pre_grad_custom_pass is not pickled
-        # as part of AOTAutograd built-in cache key
-        # TODO(luka) is there a cleaner way to do this
-        import torch._inductor.config as inductor_config
+            # Make sure pre_grad_custom_pass is not pickled
+            # as part of AOTAutograd built-in cache key
+            # TODO(luka) is there a cleaner way to do this
+            import torch._inductor.config as inductor_config
 
-        ignore = inductor_config._cache_config_ignore_prefix + [pre_grad_pass_key]
-        assert "_cache_config_ignore_prefix" not in self.inductor_config
-        self.inductor_config["_cache_config_ignore_prefix"] = ignore
+            ignore = inductor_config._cache_config_ignore_prefix + [pre_grad_pass_key]
+            assert "_cache_config_ignore_prefix" not in self.inductor_config
+            self.inductor_config["_cache_config_ignore_prefix"] = ignore
 
         # Configure the (nominally post-grad) pass manager
         self.pass_manager.configure(self.vllm_config)
diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py
index 5d4355a5b2b4..159a51e7e1db 100644
--- a/vllm/compilation/passes/pass_manager.py
+++ b/vllm/compilation/passes/pass_manager.py
@@ -18,10 +18,10 @@
 from .ir.lowering_pass import VllmIRLoweringPass
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
+if rocm_aiter_ops.is_enabled() or current_platform.is_cuda():
+    from .fusion.allreduce_rms_fusion import AllReduceFusionPass
+
 if rocm_aiter_ops.is_enabled():
-    from .fusion.allreduce_rms_fusion import (
-        RocmAiterAllReduceFusionPass,
-    )
     from .fusion.rocm_aiter_fusion import (
         MLADualRMSNormFusionPass,
         RocmAiterRMSNormQuantFusionPass,
@@ -41,7 +41,6 @@
 from .utility.split_coalescing import SplitCoalescingPass
 
 if current_platform.is_cuda():
-    from .fusion.allreduce_rms_fusion import AllReduceFusionPass
     from .fusion.collective_fusion import AsyncTPPass
     from .fusion.minimax_qk_norm_fusion import MiniMaxQKNormPass
 
@@ -116,8 +115,11 @@ def __call__(self, graph: fx.Graph) -> None:
         # DCE handles mutating ops correctly as well.
         self.ir_lowering(graph)
         VllmInductorPass.dump_prefix += 1
-        self.clone_elimination(graph)
-        VllmInductorPass.dump_prefix += 1
+        # ROCm AITER relies on HIP graph replay; this unsafe pass can alter
+        # aliases in ways that corrupt replayed decode graphs.
+        if not rocm_aiter_ops.is_enabled():
+            self.clone_elimination(graph)
+            VllmInductorPass.dump_prefix += 1
 
         # clean up after lowering again
         self.post_cleanup(graph)
@@ -143,10 +145,9 @@ def configure(self, config: VllmConfig) -> None:
             self.passes += [AsyncTPPass(config)]
 
         if self.pass_config.fuse_allreduce_rms:
-            if rocm_aiter_ops.is_enabled():
-                self.passes += [RocmAiterAllReduceFusionPass(config)]
-            else:
-                self.passes += [AllReduceFusionPass(config)]
+            # The ROCm AITER allreduce fusion path can corrupt HIP graph
+            # replay; use the standard pass when AITER is enabled.
+            self.passes += [AllReduceFusionPass(config)]
 
         if self.pass_config.fuse_minimax_qk_norm:
             self.passes += [MiniMaxQKNormPass(config)]
@@ -205,7 +206,8 @@ def uuid(self) -> str:
 
         passes.append(self.post_cleanup.uuid())
         passes.append(self.ir_lowering.uuid())
-        passes.append(self.clone_elimination.uuid())
+        if not rocm_aiter_ops.is_enabled():
+            passes.append(self.clone_elimination.uuid())
         passes.append(self.post_cleanup.uuid())
         passes.append(self.fix_functionalization.uuid())
 
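Reviewer sketch (not part of the patch): the __call__(), configure(), and uuid() hunks above have to stay in lockstep, because the compile-cache key must reflect exactly the passes that actually run. A minimal standalone illustration of that invariant follows; StubPass, build_passes, and cache_key are hypothetical stand-ins for the vLLM pass classes, and aiter_enabled stands in for rocm_aiter_ops.is_enabled().

import hashlib


class StubPass:
    """Hypothetical stand-in for a VllmInductorPass."""

    def __init__(self, name: str) -> None:
        self.name = name

    def uuid(self) -> str:
        # Real passes hash richer state; a name is enough for this sketch.
        return hashlib.sha256(self.name.encode()).hexdigest()


def build_passes(aiter_enabled: bool) -> list[StubPass]:
    # Mirrors __call__: clone elimination is skipped when AITER is enabled.
    passes = [StubPass("post_cleanup"), StubPass("ir_lowering")]
    if not aiter_enabled:
        passes.append(StubPass("clone_elimination"))
    passes.append(StubPass("fix_functionalization"))
    return passes


def cache_key(aiter_enabled: bool) -> str:
    # Mirrors uuid(): the key is derived from the same conditional list, so
    # configurations that run different passes never share a compile cache.
    combined = "".join(p.uuid() for p in build_passes(aiter_enabled))
    return hashlib.sha256(combined.encode()).hexdigest()


if __name__ == "__main__":
    assert cache_key(True) != cache_key(False)
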
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 58c49c09dc54..473acb908b28 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -472,7 +472,6 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None
         # only cuda uses this function,
         # so we don't abstract it into the base class
         maybe_ca_context = nullcontext()
-        maybe_aiter_context = nullcontext()
         from vllm.distributed.device_communicators.cuda_communicator import (
             CudaCommunicator,
         )
@@ -483,20 +482,13 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None
         if ca_comm is not None:
             maybe_ca_context = ca_comm.capture()  # type: ignore
 
-        from vllm._aiter_ops import rocm_aiter_ops
-
-        if rocm_aiter_ops.is_enabled():
-            aiter_ar = rocm_aiter_ops.get_aiter_allreduce()
-            if aiter_ar is not None:
-                maybe_aiter_context = aiter_ar.capture()  # type: ignore
-
         # ensure all initialization operations complete before attempting to
         # capture the graph on another stream
         curr_stream = torch.cuda.current_stream()
         if curr_stream != stream:
             stream.wait_stream(curr_stream)
 
-        with torch.cuda.stream(stream), maybe_ca_context, maybe_aiter_context:
+        with torch.cuda.stream(stream), maybe_ca_context:
            yield graph_capture_context
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
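Reviewer sketch (not part of the patch): the graph_capture() hunk above drops the optional AITER capture context and keeps only the custom-allreduce one. A minimal standalone sketch of the nullcontext() fallback pattern it relies on follows; FakeCaComm is a hypothetical stand-in for the real ca_comm communicator.

from contextlib import contextmanager, nullcontext


class FakeCaComm:
    """Hypothetical stand-in for the custom-allreduce communicator."""

    @contextmanager
    def capture(self):
        print("entering capture mode")
        try:
            yield
        finally:
            print("leaving capture mode")


def graph_capture_sketch(ca_comm: "FakeCaComm | None") -> None:
    # Default to a no-op context and swap in the communicator's capture
    # context only when the communicator exists, as graph_capture() does.
    maybe_ca_context = nullcontext()
    if ca_comm is not None:
        maybe_ca_context = ca_comm.capture()
    with maybe_ca_context:
        pass  # CUDA/HIP graph capture would run here


graph_capture_sketch(None)          # no-op path
graph_capture_sketch(FakeCaComm())  # enters and leaves capture mode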