diff --git a/megatron/core/pipeline_parallel/bridge_communicator.py b/megatron/core/pipeline_parallel/bridge_communicator.py index 20bb8a5bb73..515ddf1743a 100644 --- a/megatron/core/pipeline_parallel/bridge_communicator.py +++ b/megatron/core/pipeline_parallel/bridge_communicator.py @@ -65,6 +65,7 @@ def __init__( comm_dtype: Optional[torch.dtype] = None, src_module_name: Optional[str] = None, dest_module_name: Optional[str] = None, + tensor_ndim: int = 3, ): """Initialize the bridge communicator between source and destination grids. @@ -76,6 +77,11 @@ def __init__( dim_mapping: Dictionary mapping logical dimensions to tensor axes. Expected keys: 's' (sequence), 'b' (batch), 'h' (hidden). Defaults to {'s': 1, 'b': 0, 'h': 2} if None. + tensor_ndim: Number of dimensions in tensors communicated through this + bridge. For 3D tensors (e.g. [S, B, H]), fan-in/fan-out + operates on dim_mapping['b']. For 2D tensors (e.g. [B*S, H] + where batch is folded into dim 0), fan-in/fan-out operates + on dim 0. Default: 3. """ self.src_grid = src_grid self.dest_grid = dest_grid @@ -83,6 +89,9 @@ def __init__( self.dest_module_name = dest_module_name self.comm_dtype = comm_dtype + assert tensor_ndim in (2, 3), f"tensor_ndim must be 2 or 3, got {tensor_ndim}" + self.tensor_ndim = tensor_ndim + # TODO (ykarnati, pthombre) - CP support will be added in follow up PR. if 'cp' in self.src_grid.dim_names: assert self.src_grid.shape[self.src_grid.dim_names.index('cp')] == 1, ( @@ -157,6 +166,18 @@ def __init__( self.build_comm_map(self.src_tp_leaders, self.dest_tp_leaders) dist.barrier() + @property + def _batch_dim(self) -> int: + """Get the tensor dimension used for fan-in/fan-out (cat/split). + + For 3D tensors (e.g. [S, B, H]), this is dim_mapping['b']. + For 2D tensors (e.g. [B*S, H] where batch is folded into the first + dimension), this is 0. + """ + if self.tensor_ndim == 2: + return 0 + return self.dim_mapping['b'] + @classmethod def _get_or_create_broadcast_pg(cls, ranks_list: List[List[int]]): """Get or create a broadcast PG, caching to avoid duplicate NCCL communicators.""" @@ -385,7 +406,7 @@ def recv_forward(self) -> torch.Tensor: f"shape {tensor_to_recv.shape} sum {tensor_to_recv.sum()}" ) received_tensors_list.append(tensor_to_recv) - aggregated_tensor = torch.cat(received_tensors_list, dim=self.dim_mapping['b']) + aggregated_tensor = torch.cat(received_tensors_list, dim=self._batch_dim) logging.debug( f"[Bridge Communicator] [receive_forward] Rank {self.current_rank} " f"broadcasting tensor {aggregated_tensor.shape} sum {aggregated_tensor.sum()}" @@ -409,7 +430,9 @@ def recv_forward(self) -> torch.Tensor: and self.current_rank in self.dest_grid_broadcast_ranks ): # Non-leader rank - participate in broadcast - shape_tensor = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) + shape_tensor = torch.empty( + (self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64 + ) dist.broadcast( shape_tensor, src=self.dest_local_leader_rank, group=self.dest_grid_broadcast_pg ) @@ -514,7 +537,7 @@ def recv_backward(self) -> torch.Tensor: received_gradients_list.append(grad_tensor) # Concatenate received gradients - aggregated_gradient = torch.cat(received_gradients_list, dim=self.dim_mapping['b']) + aggregated_gradient = torch.cat(received_gradients_list, dim=self._batch_dim) logging.debug( f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}" @@ -536,7 +559,9 @@ def recv_backward(self) -> torch.Tensor: ): # Non-leader rank - participate in gather for gradients # Receive broadcasted tensor shape from leader rank - shape_tensor = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) + shape_tensor = torch.empty( + (self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64 + ) dist.broadcast( shape_tensor, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg ) @@ -635,7 +660,7 @@ def send_forward_recv_backward( req.wait() # Concatenate received gradients - aggregated_gradient = torch.cat(received_gradients_list, dim=self.dim_mapping['b']) + aggregated_gradient = torch.cat(received_gradients_list, dim=self._batch_dim) logging.debug( f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} " f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}" @@ -661,7 +686,9 @@ def send_forward_recv_backward( ): # participate in both gather for gradients # Receive gradient from leader using broadcast - shape_tensor = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) + shape_tensor = torch.empty( + (self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64 + ) dist.broadcast( shape_tensor, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg ) @@ -757,9 +784,7 @@ def send_backward_recv_forward( req.wait() # Concatenate received activations - aggregated_activation = torch.cat( - received_activations_list, dim=self.dim_mapping['b'] - ) + aggregated_activation = torch.cat(received_activations_list, dim=self._batch_dim) logging.debug( f"[Bridge Communicator] [send_backward_recv_forward] Rank {self.current_rank} " f"agg act shape {aggregated_activation.shape} sum {aggregated_activation.sum()}" @@ -784,7 +809,9 @@ def send_backward_recv_forward( rank_info.role == CommRole.MEMBER and self.current_rank in self.dest_grid_broadcast_ranks ): - shape_tensor = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) + shape_tensor = torch.empty( + (self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64 + ) dist.broadcast( shape_tensor, src=self.dest_local_leader_rank, group=self.dest_grid_broadcast_pg ) @@ -865,7 +892,7 @@ def _communicate_shapes( if recv_next: for dest_rank in rank_info.send_to_ranks: grad_shape_tensor = torch.empty( - (3), device=torch.cuda.current_device(), dtype=torch.int64 + (self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64 ) recv_grad_shape_tensors.append(grad_shape_tensor) ops.append( @@ -879,7 +906,7 @@ def _communicate_shapes( if recv_prev: for src_rank in rank_info.recv_from_ranks: forward_shape_tensor = torch.empty( - (3), device=torch.cuda.current_device(), dtype=torch.int64 + (self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64 ) recv_forward_shape_tensors.append(forward_shape_tensor) ops.append( @@ -935,7 +962,6 @@ def _split_tensor_at_batch_dim( if num_splits <= 0: raise ValueError(f"num_splits must be positive, got {num_splits}") - batch_dim = self.dim_mapping['b'] - splits = torch.tensor_split(aggregated_tensor, num_splits, dim=batch_dim) + splits = torch.tensor_split(aggregated_tensor, num_splits, dim=self._batch_dim) # PyTorch p2p requires the tensors to be contiguous return [split.contiguous() for split in splits] diff --git a/megatron/core/pipeline_parallel/multimodule_communicator.py b/megatron/core/pipeline_parallel/multimodule_communicator.py index fd1276bcc1c..b2e5682a29d 100644 --- a/megatron/core/pipeline_parallel/multimodule_communicator.py +++ b/megatron/core/pipeline_parallel/multimodule_communicator.py @@ -48,9 +48,12 @@ class RankModuleInfo: def _prepare_tensor_for_comm( tensor: Union[torch.Tensor, List[torch.Tensor], None] ) -> Union[torch.Tensor, List[torch.Tensor], None]: - """Prepare tensor for P2P/bridge communication by expanding to 3D if needed. + """Prepare tensor for P2P communication by expanding to 3D if needed. - P2P and bridge communicators expect 3D tensors. 2D tensors are unsqueezed by adding + Only used for intra-module P2P paths. Bridge communicators handle 2D/3D + tensors natively via tensor_ndim and do not need this adapter. + + P2P communicators expect 3D tensors. 2D tensors are unsqueezed by adding a singleton last dimension, and _restore_tensor_from_comm will squeeze it back. 3D tensors are passed through unchanged. @@ -81,7 +84,10 @@ def _prepare_tensor_for_comm( def _restore_tensor_from_comm( tensor: Union[torch.Tensor, List[torch.Tensor], None] ) -> Union[torch.Tensor, List[torch.Tensor], None]: - """Restore tensor shape after P2P/bridge communication by squeezing singleton dim. + """Restore tensor shape after P2P communication by squeezing singleton dim. + + Only used for intra-module P2P paths. Bridge communicators handle 2D/3D + tensors natively via tensor_ndim and do not need this adapter. Removes the extra dimension added by _prepare_tensor_for_comm if it was singleton. Handles both single tensors and lists of tensors (for VPP). @@ -110,6 +116,7 @@ def __init__( topology: Dict[str, List[str]], config: ModelParallelConfig, dim_mapping: Dict[str, List[int]] = None, + module_output_ndim: Optional[Dict[str, int]] = None, ): """ Initialize the MultiModulePipelineCommunicator. @@ -136,11 +143,19 @@ def __init__( Example: dim_mapping = {'s': 0, 'h': 2, 'b': 1} Default: None + module_output_ndim (Dict[str, int]): Number of dimensions for each module's + output tensor. Used by bridge communicators for cross-module fan-in/fan-out. + Modules producing 2D tensors [B*S, H] (e.g. vision encoders) should be 2. + Modules not listed default to 3. + Example: + module_output_ndim = {'image_encoder': 2, 'llm': 3} + Default: None (all modules assumed 3D) """ self.module_to_grid_map = module_to_grid_map self.topology = topology self.config = config self.dim_mapping = dim_mapping + self.module_output_ndim = module_output_ndim or {} self.current_rank = dist.get_rank() # Build bridge communicators for all modules @@ -164,6 +179,7 @@ def _build_bridge_comms(self): comm_dtype=self.config.pipeline_dtype, src_module_name=src_module_name, dest_module_name=dest_module_name, + tensor_ndim=self.module_output_ndim.get(src_module_name, 3), ) self.bridge_comms.append(bridge_comm) @@ -327,11 +343,10 @@ def recv_forward( # from incoming modules. for bridge_comm in rank_module_info.bridge_comms_as_dest_module: received_tensor = bridge_comm.recv_forward() - input_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm( - received_tensor - ) + input_dict[bridge_comm.src_module_name] = received_tensor else: # If not first stage, receive forward activation tensor from P2P communicator. + # P2P hardcodes 3D shape buffers, so use adapter for 2D tensors. received_tensor = rank_module_info.p2p_communicator.recv_forward( tensor_shapes=tensor_shape, is_first_stage=False ) @@ -349,8 +364,7 @@ def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool # If last stage, and has outgoing modules, send forward activation # by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_src_module: - tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name]) - bridge_comm.send_forward(tensor_to_send) + bridge_comm.send_forward(output_dict[module_name]) else: # If not last stage, send forward activation by using P2P communicator. tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name]) @@ -377,9 +391,8 @@ def send_forward_recv_backward( # If last stage, and has outgoing modules, send forward activation and # receive backward gradient by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_src_module: - tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name]) - grad = bridge_comm.send_forward_recv_backward(tensor_to_send) - grad_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(grad) + grad = bridge_comm.send_forward_recv_backward(output_dict[module_name]) + grad_dict[bridge_comm.src_module_name] = grad else: # If not last stage, send forward activation and receive backward gradient # by using P2P communicator. @@ -411,11 +424,10 @@ def send_backward_recv_forward( for bridge_comm in rank_module_info.bridge_comms_as_dest_module: # If first stage, and has incoming modules, send backward gradient and # receive forward activation by using bridge communicator. - grad_to_send = _prepare_tensor_for_comm(grad_dict[bridge_comm.src_module_name]) - received_tensor = bridge_comm.send_backward_recv_forward(grad_to_send) - input_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm( - received_tensor + received_tensor = bridge_comm.send_backward_recv_forward( + grad_dict[bridge_comm.src_module_name] ) + input_dict[bridge_comm.src_module_name] = received_tensor else: # If not first stage, send backward gradient and receive forward activation # by using P2P communicator. @@ -448,7 +460,7 @@ def recv_backward( # by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_src_module: grad = bridge_comm.recv_backward() - grad_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(grad) + grad_dict[bridge_comm.src_module_name] = grad else: # If not last stage, receive backward gradient by using P2P communicator. grad = rank_module_info.p2p_communicator.recv_backward( @@ -468,8 +480,7 @@ def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool # If first stage, and has incoming modules, send backward activation # by using bridge communicator. for bridge_comm in rank_module_info.bridge_comms_as_dest_module: - grad_to_send = _prepare_tensor_for_comm(grad_dict[bridge_comm.src_module_name]) - bridge_comm.send_backward(grad_to_send) + bridge_comm.send_backward(grad_dict[bridge_comm.src_module_name]) else: # If not first stage, send backward activation by using P2P communicator. grad_to_send = _prepare_tensor_for_comm(grad_dict[module_name]) diff --git a/tests/unit_tests/models/test_mimo_1f1b_schedule.py b/tests/unit_tests/models/test_mimo_1f1b_schedule.py index 1a6b29cc58a..68e43b43443 100644 --- a/tests/unit_tests/models/test_mimo_1f1b_schedule.py +++ b/tests/unit_tests/models/test_mimo_1f1b_schedule.py @@ -506,10 +506,21 @@ def finalize_grads_func(*args, **kwargs): ) communicator = MultiModulePipelineCommunicator( - module_to_grid_map, topology, mimo_model.config, dim_mapping={'s': 0, 'h': 2, 'b': 1} + module_to_grid_map, + topology, + mimo_model.config, + dim_mapping={'s': 0, 'h': 2, 'b': 1}, + module_output_ndim={encoder_name: 2}, ) - # Create data iterator on ranks that need it + # Compute per-rank micro-batch size for asymmetric DP. + # The LLM's MBS is the schedule-level MBS. The encoder's MBS is adjusted + # by the DP ratio so that total work is conserved across the bridge. + llm_mbs = micro_batch_size + encoder_mbs = micro_batch_size * llm_dp // encoder_dp + + # Create data iterator on ranks that need it, with per-role micro-batch size. + # Encoder ranks use encoder_mbs, LLM ranks use llm_mbs. data_iterator = None encoder_needs_data = is_rank_in_grid(encoder_grid) and is_pp_first_stage( encoder_grid.get_pg("pp") @@ -517,7 +528,13 @@ def finalize_grads_func(*args, **kwargs): llm_needs_data = is_rank_in_grid(llm_grid) and ( is_pp_first_stage(llm_grid.get_pg("pp")) or is_pp_last_stage(llm_grid.get_pg("pp")) ) - if encoder_needs_data or llm_needs_data: + if encoder_needs_data and not llm_needs_data: + data_iterator = DataIterator(hidden_size, seq_length, encoder_mbs, vocab_size, encoder_name) + elif llm_needs_data and not encoder_needs_data: + data_iterator = DataIterator(hidden_size, seq_length, llm_mbs, vocab_size, encoder_name) + elif encoder_needs_data and llm_needs_data: + # Colocated: both encoder and LLM on same rank. Use LLM's MBS since + # the LLM drives the schedule. (encoder_dp == llm_dp when colocated) data_iterator = DataIterator( hidden_size, seq_length, micro_batch_size, vocab_size, encoder_name ) @@ -689,3 +706,81 @@ def test_full_pp_8gpu(self): micro_batch_size=2, num_microbatches=4, ) + + def test_fan_in_dp4_to_dp1_llm_tp2_pp2_8gpu(self): + """Fan-in 4→1: Encoder DP=4 → LLM TP=2 PP=2 DP=1, on 8 GPUs. + + High fan-in ratio. Each encoder rank processes MBS=1, bridge concatenates + 4 × [img_seq, H] → [4*img_seq, H]. LLM has both TP and PP. + """ + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + + run_mimo_1f1b_test( + encoder_tp=1, + encoder_pp=1, + encoder_dp=4, + encoder_offset=0, + llm_tp=2, + llm_pp=2, + llm_dp=1, + llm_offset=4, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=4, + num_microbatches=4, + ) + + def test_fan_out_dp1_to_dp4_enc_tp2_pp2_8gpu(self): + """Fan-out 1→4: Encoder TP=2 PP=2 DP=1 → LLM DP=4, on 8 GPUs. + + Encoder has PP and TP. Bridge fan-out splits encoder output into + 4 parts for 4 LLM DP ranks each with MBS=1. + """ + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + + run_mimo_1f1b_test( + encoder_tp=2, + encoder_pp=2, + encoder_dp=1, + encoder_offset=0, + llm_tp=1, + llm_pp=1, + llm_dp=4, + llm_offset=4, + hidden_size=256, + num_layers=2, + vocab_size=1000, + seq_length=64, + micro_batch_size=1, + num_microbatches=4, + ) + + def test_fan_in_dp2_to_dp1_llm_pp3_8gpu(self): + """Fan-in 2→1: Encoder DP=2 → LLM TP=2 PP=3, on 8 GPUs. + + Tests fan-in with deep LLM pipeline (PP=3). The 2D tensor goes through + bridge fan-in then P2P across 3 LLM PP stages. + """ + if self.world_size != 8: + pytest.skip(f"Requires 8 GPUs, got {self.world_size}") + + run_mimo_1f1b_test( + encoder_tp=1, + encoder_pp=1, + encoder_dp=2, + encoder_offset=0, + llm_tp=2, + llm_pp=3, + llm_dp=1, + llm_offset=2, + hidden_size=256, + num_layers=3, + vocab_size=1000, + seq_length=64, + micro_batch_size=2, + num_microbatches=4, + ) diff --git a/tests/unit_tests/pipeline_parallel/test_bridge_communicator.py b/tests/unit_tests/pipeline_parallel/test_bridge_communicator.py index 8c01f59eb29..326ac8b5890 100644 --- a/tests/unit_tests/pipeline_parallel/test_bridge_communicator.py +++ b/tests/unit_tests/pipeline_parallel/test_bridge_communicator.py @@ -485,3 +485,47 @@ def test_get_leader_rank(self, tp, cp, pp, dp, expected_src_leaders, expected_de assert ( dest_leaders == expected_dest_leaders ), f"Dest leaders: Expected {expected_dest_leaders}, got {dest_leaders}" + + def test_2d_fan_in_fwd_bwd(self): + """Fan-in with 2D tensors: 4 src DP ranks → 1 dest DP group, forward + backward.""" + src_grid = create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=4) + dest_grid = create_hypercomm_grid(offset=4, tp=4, cp=1, pp=1, dp=1) + bridge = BridgeCommunicator( + src_grid, + dest_grid, + dim_mapping={'s': 0, 'h': 2, 'b': 1}, + comm_dtype=torch.float32, + tensor_ndim=2, + ) + + rank = dist.get_rank() + if bridge.is_current_rank_in_grid(src_grid): + tensor = torch.full((577, 128), float(rank + 1), device='cuda') + grad = bridge.send_forward_recv_backward(tensor) + assert grad.shape == (577, 128) + else: + grad = torch.randn(577 * 4, 128, device='cuda') + activation = bridge.send_backward_recv_forward(grad) + assert activation.shape == (577 * 4, 128) + + def test_2d_fan_out_fwd_bwd(self): + """Fan-out with 2D tensors: 1 src DP group → 4 dest DP ranks, forward + backward.""" + src_grid = create_hypercomm_grid(offset=0, tp=4, cp=1, pp=1, dp=1) + dest_grid = create_hypercomm_grid(offset=4, tp=1, cp=1, pp=1, dp=4) + bridge = BridgeCommunicator( + src_grid, + dest_grid, + dim_mapping={'s': 0, 'h': 2, 'b': 1}, + comm_dtype=torch.float32, + tensor_ndim=2, + ) + + rank = dist.get_rank() + if bridge.is_current_rank_in_grid(src_grid): + tensor = torch.randn(577 * 4, 128, device='cuda') + grad = bridge.send_forward_recv_backward(tensor) + assert grad.shape == (577 * 4, 128) + else: + grad = torch.full((577, 128), float(rank), device='cuda') + activation = bridge.send_backward_recv_forward(grad) + assert activation.shape == (577, 128)