Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions megatron/core/pipeline_parallel/bridge_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
comm_dtype: Optional[torch.dtype] = None,
src_module_name: Optional[str] = None,
dest_module_name: Optional[str] = None,
tensor_ndim: int = 3,
):
"""Initialize the bridge communicator between source and destination grids.

Expand All @@ -76,13 +77,21 @@ def __init__(
dim_mapping: Dictionary mapping logical dimensions to tensor axes.
Expected keys: 's' (sequence), 'b' (batch), 'h' (hidden).
Defaults to {'s': 1, 'b': 0, 'h': 2} if None.
tensor_ndim: Number of dimensions in tensors communicated through this
bridge. For 3D tensors (e.g. [S, B, H]), fan-in/fan-out
operates on dim_mapping['b']. For 2D tensors (e.g. [B*S, H]
where batch is folded into dim 0), fan-in/fan-out operates
on dim 0. Default: 3.
"""
self.src_grid = src_grid
self.dest_grid = dest_grid
self.src_module_name = src_module_name
self.dest_module_name = dest_module_name
self.comm_dtype = comm_dtype

assert tensor_ndim in (2, 3), f"tensor_ndim must be 2 or 3, got {tensor_ndim}"
self.tensor_ndim = tensor_ndim

# TODO (ykarnati, pthombre) - CP support will be added in follow up PR.
if 'cp' in self.src_grid.dim_names:
assert self.src_grid.shape[self.src_grid.dim_names.index('cp')] == 1, (
Expand Down Expand Up @@ -157,6 +166,18 @@ def __init__(
self.build_comm_map(self.src_tp_leaders, self.dest_tp_leaders)
dist.barrier()

@property
def _batch_dim(self) -> int:
    """Tensor axis along which fan-in/fan-out (``torch.cat``/split) operates.

    2D tensors (e.g. ``[B*S, H]`` where the batch is folded into the first
    axis) always use axis 0; 3D tensors (e.g. ``[S, B, H]``) use the axis
    recorded in ``dim_mapping['b']``.
    """
    return 0 if self.tensor_ndim == 2 else self.dim_mapping['b']

@classmethod
def _get_or_create_broadcast_pg(cls, ranks_list: List[List[int]]):
"""Get or create a broadcast PG, caching to avoid duplicate NCCL communicators."""
Expand Down Expand Up @@ -385,7 +406,7 @@ def recv_forward(self) -> torch.Tensor:
f"shape {tensor_to_recv.shape} sum {tensor_to_recv.sum()}"
)
received_tensors_list.append(tensor_to_recv)
aggregated_tensor = torch.cat(received_tensors_list, dim=self.dim_mapping['b'])
aggregated_tensor = torch.cat(received_tensors_list, dim=self._batch_dim)
logging.debug(
f"[Bridge Communicator] [receive_forward] Rank {self.current_rank} "
f"broadcasting tensor {aggregated_tensor.shape} sum {aggregated_tensor.sum()}"
Expand All @@ -409,7 +430,9 @@ def recv_forward(self) -> torch.Tensor:
and self.current_rank in self.dest_grid_broadcast_ranks
):
# Non-leader rank - participate in broadcast
shape_tensor = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
shape_tensor = torch.empty(
(self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64
)
dist.broadcast(
shape_tensor, src=self.dest_local_leader_rank, group=self.dest_grid_broadcast_pg
)
Expand Down Expand Up @@ -514,7 +537,7 @@ def recv_backward(self) -> torch.Tensor:
received_gradients_list.append(grad_tensor)

# Concatenate received gradients
aggregated_gradient = torch.cat(received_gradients_list, dim=self.dim_mapping['b'])
aggregated_gradient = torch.cat(received_gradients_list, dim=self._batch_dim)
logging.debug(
f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} "
f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}"
Expand All @@ -536,7 +559,9 @@ def recv_backward(self) -> torch.Tensor:
):
# Non-leader rank - participate in gather for gradients
# Receive broadcasted tensor shape from leader rank
shape_tensor = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
shape_tensor = torch.empty(
(self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64
)
dist.broadcast(
shape_tensor, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg
)
Expand Down Expand Up @@ -635,7 +660,7 @@ def send_forward_recv_backward(
req.wait()

# Concatenate received gradients
aggregated_gradient = torch.cat(received_gradients_list, dim=self.dim_mapping['b'])
aggregated_gradient = torch.cat(received_gradients_list, dim=self._batch_dim)
logging.debug(
f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} "
f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}"
Expand All @@ -661,7 +686,9 @@ def send_forward_recv_backward(
):
# participate in both gather for gradients
# Receive gradient from leader using broadcast
shape_tensor = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
shape_tensor = torch.empty(
(self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64
)
dist.broadcast(
shape_tensor, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg
)
Expand Down Expand Up @@ -757,9 +784,7 @@ def send_backward_recv_forward(
req.wait()

# Concatenate received activations
aggregated_activation = torch.cat(
received_activations_list, dim=self.dim_mapping['b']
)
aggregated_activation = torch.cat(received_activations_list, dim=self._batch_dim)
logging.debug(
f"[Bridge Communicator] [send_backward_recv_forward] Rank {self.current_rank} "
f"agg act shape {aggregated_activation.shape} sum {aggregated_activation.sum()}"
Expand All @@ -784,7 +809,9 @@ def send_backward_recv_forward(
rank_info.role == CommRole.MEMBER
and self.current_rank in self.dest_grid_broadcast_ranks
):
shape_tensor = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
shape_tensor = torch.empty(
(self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64
)
dist.broadcast(
shape_tensor, src=self.dest_local_leader_rank, group=self.dest_grid_broadcast_pg
)
Expand Down Expand Up @@ -865,7 +892,7 @@ def _communicate_shapes(
if recv_next:
for dest_rank in rank_info.send_to_ranks:
grad_shape_tensor = torch.empty(
(3), device=torch.cuda.current_device(), dtype=torch.int64
(self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64
)
recv_grad_shape_tensors.append(grad_shape_tensor)
ops.append(
Expand All @@ -879,7 +906,7 @@ def _communicate_shapes(
if recv_prev:
for src_rank in rank_info.recv_from_ranks:
forward_shape_tensor = torch.empty(
(3), device=torch.cuda.current_device(), dtype=torch.int64
(self.tensor_ndim,), device=torch.cuda.current_device(), dtype=torch.int64
)
recv_forward_shape_tensors.append(forward_shape_tensor)
ops.append(
Expand Down Expand Up @@ -935,7 +962,6 @@ def _split_tensor_at_batch_dim(
if num_splits <= 0:
raise ValueError(f"num_splits must be positive, got {num_splits}")

batch_dim = self.dim_mapping['b']
splits = torch.tensor_split(aggregated_tensor, num_splits, dim=batch_dim)
splits = torch.tensor_split(aggregated_tensor, num_splits, dim=self._batch_dim)
# PyTorch p2p requires the tensors to be contiguous
return [split.contiguous() for split in splits]
47 changes: 29 additions & 18 deletions megatron/core/pipeline_parallel/multimodule_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,12 @@ class RankModuleInfo:
def _prepare_tensor_for_comm(
tensor: Union[torch.Tensor, List[torch.Tensor], None]
) -> Union[torch.Tensor, List[torch.Tensor], None]:
"""Prepare tensor for P2P/bridge communication by expanding to 3D if needed.
"""Prepare tensor for P2P communication by expanding to 3D if needed.

P2P and bridge communicators expect 3D tensors. 2D tensors are unsqueezed by adding
Only used for intra-module P2P paths. Bridge communicators handle 2D/3D
tensors natively via tensor_ndim and do not need this adapter.

P2P communicators expect 3D tensors. 2D tensors are unsqueezed by adding
a singleton last dimension, and _restore_tensor_from_comm will squeeze it back. 3D
tensors are passed through unchanged.

Expand Down Expand Up @@ -81,7 +84,10 @@ def _prepare_tensor_for_comm(
def _restore_tensor_from_comm(
tensor: Union[torch.Tensor, List[torch.Tensor], None]
) -> Union[torch.Tensor, List[torch.Tensor], None]:
"""Restore tensor shape after P2P/bridge communication by squeezing singleton dim.
"""Restore tensor shape after P2P communication by squeezing singleton dim.

Only used for intra-module P2P paths. Bridge communicators handle 2D/3D
tensors natively via tensor_ndim and do not need this adapter.

Removes the extra dimension added by _prepare_tensor_for_comm if it was singleton.
Handles both single tensors and lists of tensors (for VPP).
Expand Down Expand Up @@ -110,6 +116,7 @@ def __init__(
topology: Dict[str, List[str]],
config: ModelParallelConfig,
dim_mapping: Dict[str, List[int]] = None,
module_output_ndim: Optional[Dict[str, int]] = None,
):
"""
Initialize the MultiModulePipelineCommunicator.
Expand All @@ -136,11 +143,19 @@ def __init__(
Example:
dim_mapping = {'s': 0, 'h': 2, 'b': 1}
Default: None
module_output_ndim (Dict[str, int]): Number of dimensions for each module's
output tensor. Used by bridge communicators for cross-module fan-in/fan-out.
Modules producing 2D tensors [B*S, H] (e.g. vision encoders) should be 2.
Modules not listed default to 3.
Example:
module_output_ndim = {'image_encoder': 2, 'llm': 3}
Default: None (all modules assumed 3D)
"""
self.module_to_grid_map = module_to_grid_map
self.topology = topology
self.config = config
self.dim_mapping = dim_mapping
self.module_output_ndim = module_output_ndim or {}
self.current_rank = dist.get_rank()

# Build bridge communicators for all modules
Expand All @@ -164,6 +179,7 @@ def _build_bridge_comms(self):
comm_dtype=self.config.pipeline_dtype,
src_module_name=src_module_name,
dest_module_name=dest_module_name,
tensor_ndim=self.module_output_ndim.get(src_module_name, 3),
)
self.bridge_comms.append(bridge_comm)

Expand Down Expand Up @@ -327,11 +343,10 @@ def recv_forward(
# from incoming modules.
for bridge_comm in rank_module_info.bridge_comms_as_dest_module:
received_tensor = bridge_comm.recv_forward()
input_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(
received_tensor
)
input_dict[bridge_comm.src_module_name] = received_tensor
else:
# If not first stage, receive forward activation tensor from P2P communicator.
# P2P hardcodes 3D shape buffers, so use adapter for 2D tensors.
received_tensor = rank_module_info.p2p_communicator.recv_forward(
tensor_shapes=tensor_shape, is_first_stage=False
)
Expand All @@ -349,8 +364,7 @@ def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool
# If last stage, and has outgoing modules, send forward activation
# by using bridge communicator.
for bridge_comm in rank_module_info.bridge_comms_as_src_module:
tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name])
bridge_comm.send_forward(tensor_to_send)
bridge_comm.send_forward(output_dict[module_name])
else:
# If not last stage, send forward activation by using P2P communicator.
tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name])
Expand All @@ -377,9 +391,8 @@ def send_forward_recv_backward(
# If last stage, and has outgoing modules, send forward activation and
# receive backward gradient by using bridge communicator.
for bridge_comm in rank_module_info.bridge_comms_as_src_module:
tensor_to_send = _prepare_tensor_for_comm(output_dict[module_name])
grad = bridge_comm.send_forward_recv_backward(tensor_to_send)
grad_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(grad)
grad = bridge_comm.send_forward_recv_backward(output_dict[module_name])
grad_dict[bridge_comm.src_module_name] = grad
else:
# If not last stage, send forward activation and receive backward gradient
# by using P2P communicator.
Expand Down Expand Up @@ -411,11 +424,10 @@ def send_backward_recv_forward(
for bridge_comm in rank_module_info.bridge_comms_as_dest_module:
# If first stage, and has incoming modules, send backward gradient and
# receive forward activation by using bridge communicator.
grad_to_send = _prepare_tensor_for_comm(grad_dict[bridge_comm.src_module_name])
received_tensor = bridge_comm.send_backward_recv_forward(grad_to_send)
input_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(
received_tensor
received_tensor = bridge_comm.send_backward_recv_forward(
grad_dict[bridge_comm.src_module_name]
)
input_dict[bridge_comm.src_module_name] = received_tensor
else:
# If not first stage, send backward gradient and receive forward activation
# by using P2P communicator.
Expand Down Expand Up @@ -448,7 +460,7 @@ def recv_backward(
# by using bridge communicator.
for bridge_comm in rank_module_info.bridge_comms_as_src_module:
grad = bridge_comm.recv_backward()
grad_dict[bridge_comm.src_module_name] = _restore_tensor_from_comm(grad)
grad_dict[bridge_comm.src_module_name] = grad
else:
# If not last stage, receive backward gradient by using P2P communicator.
grad = rank_module_info.p2p_communicator.recv_backward(
Expand All @@ -468,8 +480,7 @@ def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool
# If first stage, and has incoming modules, send backward activation
# by using bridge communicator.
for bridge_comm in rank_module_info.bridge_comms_as_dest_module:
grad_to_send = _prepare_tensor_for_comm(grad_dict[bridge_comm.src_module_name])
bridge_comm.send_backward(grad_to_send)
bridge_comm.send_backward(grad_dict[bridge_comm.src_module_name])
else:
# If not first stage, send backward activation by using P2P communicator.
grad_to_send = _prepare_tensor_for_comm(grad_dict[module_name])
Expand Down
Loading
Loading