diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py index b613d4424ee3..b6a1314af9ef 100644 --- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py @@ -729,14 +729,26 @@ def __init__(self, config: VllmConfig) -> None: scope="global", ) - self.workspace = flashinfer_comm.create_allreduce_fusion_workspace( - backend="trtllm", - world_size=self.tp_size, - rank=rank, - max_token_num=self.max_token_num, - hidden_dim=self.hidden_dim, - dtype=self.model_dtype, - ) + try: + self.workspace = flashinfer_comm.create_allreduce_fusion_workspace( + backend="trtllm", + world_size=self.tp_size, + rank=rank, + max_token_num=self.max_token_num, + hidden_dim=self.hidden_dim, + dtype=self.model_dtype, + ) + except RuntimeError as e: + if "multicast" not in str(e).lower(): + raise + logger.warning_once( + "AllReduce fusion pass is disabled: flashinfer workspace " + "creation failed: %s. This is expected on GPUs without " + "NVSwitch (e.g., NVLink bridge-only or PCIe topologies). " + "Falling back to non-fused allreduce.", + str(e), + ) + return global _FI_WORKSPACE _FI_WORKSPACE = self.workspace