diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py index 5d4355a5b2b4..cd9a6bcfc370 100644 --- a/vllm/compilation/passes/pass_manager.py +++ b/vllm/compilation/passes/pass_manager.py @@ -19,9 +19,6 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass if rocm_aiter_ops.is_enabled(): - from .fusion.allreduce_rms_fusion import ( - RocmAiterAllReduceFusionPass, - ) from .fusion.rocm_aiter_fusion import ( MLADualRMSNormFusionPass, RocmAiterRMSNormQuantFusionPass, @@ -142,11 +139,8 @@ def configure(self, config: VllmConfig) -> None: if self.pass_config.fuse_gemm_comms: self.passes += [AsyncTPPass(config)] - if self.pass_config.fuse_allreduce_rms: - if rocm_aiter_ops.is_enabled(): - self.passes += [RocmAiterAllReduceFusionPass(config)] - else: - self.passes += [AllReduceFusionPass(config)] + if self.pass_config.fuse_allreduce_rms and current_platform.is_cuda(): + self.passes += [AllReduceFusionPass(config)] if self.pass_config.fuse_minimax_qk_norm: self.passes += [MiniMaxQKNormPass(config)]