Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/megatron/bridge/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,13 @@ class DistributedInitConfig:
global parallel state (mpu) variables. When True, parallel groups are obtained from
the pg_collection object rather than the global megatron.core.parallel_state module."""

create_all_gather_group: bool = False
"""Create separate all-gather process groups for AG/RS overlap optimization.
When True, creates separate process groups for all-gather operations, enabling
overlap between all-gather (forward pass) and reduce-scatter (backward pass)
communications in FSDP training. This improves training performance by hiding
communication latency."""


@dataclass
class RerunStateMachineConfig:
Expand Down
42 changes: 42 additions & 0 deletions src/megatron/bridge/training/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def _create_pg_collection(
num_distributed_optimizer_instances: int,
get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
create_all_gather_group: bool = False,
) -> ProcessGroupCollection:
"""Create all process groups via HyperCommGrid and return a ProcessGroupCollection."""
world_size = torch.distributed.get_world_size()
Expand Down Expand Up @@ -499,6 +500,23 @@ def _create_pg_collection(
# combine tp-ep-pp ranks across the intra-partial DP slice.
intra_dist_opt_pg = expert_grid.create_pg(["tp", "ep", inner_dp_dim, "pp"])

# Create all-gather groups for AG/RS overlap if requested
dp_cp_ag_pg = None
expt_dp_ag_pg = None
if create_all_gather_group:
# Create regular DP all-gather group with same ranks as dp_cp_pg
# Use HyperCommGrid to enumerate ranks for dp-cp groups
dp_cp_rank_lists = grid._gen_rank_enum(["dp", "cp"])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is `grid.create_pg(...)` not working here? Ideally we shouldn't use an internal API (`_gen_rank_enum`) — it's a bit risky.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I can tell, grid.create_pg(["dp", "cp"]) can't be used here because it was already called on line 415 to create dp_cp_pg.

If I understand correctly, calling it again would raise a KeyError — create_pg keys by dimension names only (line 151 of hyper_comm_grid.py), so any second call with ["dp", "cp"] would collide with the existing "dp-cp" entry, regardless of group_desc or pg_options.

The AG group needs the same ranks but as an independent NCCL communicator, so I used _gen_rank_enum to get the rank lists and passed them to new_subgroups_by_enumeration directly. Same situation for the expert grid on line 516.

That said, I'm not super familiar with the HyperCommGrid internals — happy to refactor if there's a preferred way to create a second PG with the same rank topology?

if dp_cp_rank_lists:
dp_cp_ag_pg, _ = torch.distributed.new_subgroups_by_enumeration(dp_cp_rank_lists, backend="nccl")

# Create expert DP all-gather group if expert parallelism is enabled
if ep_size > 1:
# Use expert grid to enumerate ranks for expert dp groups
expt_dp_rank_lists = expert_grid._gen_rank_enum(dp_group_dims)
if expt_dp_rank_lists:
expt_dp_ag_pg, _ = torch.distributed.new_subgroups_by_enumeration(expt_dp_rank_lists, backend="nccl")

# Build ProcessGroupCollection with available groups.
pg_collection = ProcessGroupCollection(
tp=tp_pg,
Expand All @@ -522,6 +540,13 @@ def _create_pg_collection(
inter_dist_opt=inter_dist_opt_pg,
intra_dist_opt=intra_dist_opt_pg,
)

# Add AG groups to ProcessGroupCollection if created
if create_all_gather_group:
pg_collection.dp_cp_ag = dp_cp_ag_pg
if expt_dp_ag_pg is not None:
pg_collection.expt_dp_ag = expt_dp_ag_pg

return pg_collection


Expand Down Expand Up @@ -602,6 +627,7 @@ def _initialize_distributed(
num_distributed_optimizer_instances,
get_embedding_ranks=get_embedding_ranks,
get_position_embedding_ranks=get_position_embedding_ranks,
create_all_gather_group=dist_config.create_all_gather_group,
)
if get_rank_safe() == 0:
tp = int(model_config.tensor_model_parallel_size)
Expand Down Expand Up @@ -644,6 +670,22 @@ def _initialize_distributed(
f"> initialized pipeline model parallel with size "
f"{parallel_state.get_pipeline_model_parallel_world_size()}"
)

# Create AG groups if requested
if dist_config.create_all_gather_group:
for_expert_parallelism = (getattr(model_config, "expert_model_parallel_size", 1) or 1) > 1
dp_cp_ag, expt_dp_ag = parallel_state.create_all_gather_groups(
for_expert_parallelism=for_expert_parallelism,
timeout=datetime.timedelta(minutes=dist_config.distributed_timeout_minutes),
nccl_comm_cfgs=None, # Could use dist_config.nccl_communicator_config_path if needed
)
# Get ProcessGroupCollection and populate with AG groups
pg_collection = ProcessGroupCollection.use_mpu_process_groups()
pg_collection.dp_cp_ag = dp_cp_ag
if expt_dp_ag is not None:
pg_collection.expt_dp_ag = expt_dp_ag
return pg_collection
Comment on lines +674 to +687
Copy link
Copy Markdown
Contributor

@cspades cspades Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: Adding this to the PG collection will be passed to the Megatron-FSDP DDP FullyShardedDataParallel adapter which then passes it to the FSDPDistIndex / MegatronFSDP class API.


# Return a ProcessGroupCollection using mpu process groups
return ProcessGroupCollection.use_mpu_process_groups()

Expand Down
Loading