diff --git a/tests/compile/passes/distributed/test_fusion_all_reduce.py b/tests/compile/passes/distributed/test_fusion_all_reduce.py
index d48f22970313..73bad7941c10 100644
--- a/tests/compile/passes/distributed/test_fusion_all_reduce.py
+++ b/tests/compile/passes/distributed/test_fusion_all_reduce.py
@@ -323,3 +323,102 @@ def all_reduce_fusion_pass_on_test_model(
     backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
     backend.check_after_ops(model.ops_in_model_after())
     del all_reduce_fusion_pass
+
+
+@multi_gpu_test(num_gpus=4)
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
+@pytest.mark.skip(
+    reason="Disabled until flashinfer fixes device_idx=tp_rank "
+    "in SymmDeviceMemory (wrong GPU in DP+TP configurations)",
+)
+def test_all_reduce_fusion_pass_dp_tp():
+    """Test AllReduceFusionPass with DP=2, TP=2 (4 GPUs total).
+
+    Regression test for https://github.com/vllm-project/vllm/issues/34401
+    where workspace creation used the global process group instead of the
+    TP-scoped group, causing NCCL errors in DP+TP configurations.
+    """
+    torch.multiprocessing.spawn(
+        all_reduce_fusion_pass_on_test_model_dp_tp,
+        args=(4,),
+        nprocs=4,
+    )
+
+
+def all_reduce_fusion_pass_on_test_model_dp_tp(
+    local_rank: int,
+    world_size: int,
+):
+    tp_size = 2
+    dtype = torch.bfloat16
+    hidden_size = 64
+    batch_size = 8
+    seq_len = 8
+
+    set_random_seed(0)
+
+    device = torch.device(f"cuda:{local_rank}")
+    torch.cuda.set_device(device)
+    torch.set_default_device(device)
+    torch.set_default_dtype(dtype)
+
+    update_environment_variables(
+        {
+            "RANK": str(local_rank),
+            "LOCAL_RANK": str(local_rank),
+            "WORLD_SIZE": str(world_size),
+            "MASTER_ADDR": "localhost",
+            "MASTER_PORT": "12346",
+        }
+    )
+
+    init_distributed_environment()
+
+    # Create vllm_config with dp_size=2 BEFORE initialize_model_parallel,
+    # because initialize_model_parallel reads data_parallel_size from
+    # the current vllm config.
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE, custom_ops=[]
+        )
+    )
+    vllm_config.compilation_config.pass_config = PassConfig(
+        fuse_allreduce_rms=True, eliminate_noops=True
+    )
+    vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
+    vllm_config.parallel_config.rank = local_rank
+    vllm_config.parallel_config.data_parallel_size = 2
+
+    model_name = "RedHatAI/Llama-3.2-1B-Instruct-FP8"
+    vllm_config.model_config = ModelConfig(
+        model=model_name, trust_remote_code=True, dtype=dtype, seed=42
+    )
+    with set_current_vllm_config(vllm_config):
+        # With dp=2, tp=2 on 4 ranks, this creates:
+        # TP groups: [0,1], [2,3]
+        # DP groups: [0,2], [1,3]
+        initialize_model_parallel(tensor_model_parallel_size=tp_size)
+
+        all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
+        noop_pass = NoOpEliminationPass(vllm_config)
+        func_pass = FixFunctionalizationPass(vllm_config)
+        cleanup_pass = PostCleanupPass(vllm_config)
+
+        backend = TestBackend(
+            noop_pass, all_reduce_fusion_pass, func_pass, cleanup_pass
+        )
+
+        token_num = batch_size * seq_len
+        model = TestAllReduceRMSNormModel(hidden_size, token_num)
+
+        hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
+
+        compiled_model = torch.compile(model, backend=backend)
+        compiled_model(hidden_states)
+
+        assert all_reduce_fusion_pass.matched_count == 4, (
+            f"{all_reduce_fusion_pass.matched_count=}"
+        )
+        backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
+        backend.check_after_ops(model.ops_in_model_after())
+        del all_reduce_fusion_pass
diff --git a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
index b613d4424ee3..5d21c56505f6 100644
--- a/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
+++ b/vllm/compilation/passes/fusion/allreduce_rms_fusion.py
@@ -41,11 +41,43 @@
        _flashinfer_comm, "create_allreduce_fusion_workspace"
    ):
        flashinfer_comm = _flashinfer_comm
+        from flashinfer.comm.mnnvl import TorchDistBackend
+
+        class _TPCommBackend(TorchDistBackend):
+            """CommBackend scoped to the TP process group.
+
+            Fixes two flashinfer issues:
+            1. TorchDistBackend.bcast passes a group-local root as a
+               global rank to broadcast_object_list. We use group_src
+               instead.
+            2. IPC socket opIds collide across TP groups because
+               random.randint produces identical values under vllm's
+               deterministic seeding. We offset the opId by the
+               global rank of the group root.
+            """
+
+            def __init__(self, group):
+                super().__init__(group=group)
+                self._global_root = self._dist.get_global_rank(group, 0)
+
+            def bcast(self, data, root=0):
+                object_list = [data]
+                self._dist.broadcast_object_list(
+                    object_list, group_src=root, group=self._group
+                )
+                result = object_list[0]
+                # Offset opId by global root rank so each TP group
+                # gets a unique IPC socket path. Only opIds (int,
+                # not bool) flow through the TRTLLM backend path.
+                if isinstance(result, int) and not isinstance(result, bool):
+                    result += self._global_root
+                return result
+
 except ImportError:
     pass
 
 logger = init_logger(__name__)
 
 if hasattr(torch.ops._C, "scaled_fp4_quant"):
     STATIC_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant.default
 
@@ -687,6 +719,15 @@ def __init__(self, config: VllmConfig) -> None:
         if self.tp_size <= 1:
             logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.")
             return
+        if config.parallel_config.data_parallel_size > 1:
+            # flashinfer uses device_idx=tp_rank in SymmDeviceMemory,
+            # which maps to the wrong GPU for DP groups > 0.
+            # See: https://github.com/vllm-project/vllm/issues/34401
+            logger.warning_once(
+                "AllReduce fusion pass is disabled for DP+TP due to "
+                "a flashinfer device assignment limitation."
+            )
+            return
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="all_reduce_fusion_pass"
         )
@@ -736,6 +777,7 @@ def __init__(self, config: VllmConfig) -> None:
             max_token_num=self.max_token_num,
             hidden_dim=self.hidden_dim,
             dtype=self.model_dtype,
+            comm_backend=_TPCommBackend(group=self.group),
         )
 
         global _FI_WORKSPACE