Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,12 @@ def _init_dist_index(self, pg_collection):
single_rank_group = dist.new_group(ranks=[dist.get_rank()])
expt_tp_group = single_rank_group

# Extract AG groups from pg_collection for explicit passing
dp_cp_ag = getattr(pg_collection, 'dp_cp_ag', None) if pg_collection is not None else None
expt_dp_ag = (
getattr(pg_collection, 'expt_dp_ag', None) if pg_collection is not None else None
)

if enable_hsdp:
if expt_dp_group is not None:
expt_mesh = _get_hsdp_tp_mesh(
Expand Down Expand Up @@ -311,6 +317,8 @@ def _init_dist_index(self, pg_collection):
hybrid_fsdp_group=hybrid_fsdp_group,
hybrid_fsdp_expt_group=hybrid_fsdp_expt_group,
expt_device_mesh=expt_device_mesh,
fsdp_group_ag=dp_cp_ag,
expt_fsdp_group_ag=expt_dp_ag,
)
else:
if ep_group is not None:
Expand All @@ -335,6 +343,8 @@ def _init_dist_index(self, pg_collection):
dp_shard_dim="dp_cp",
tp_dim="tp",
expt_device_mesh=expt_device_mesh,
fsdp_group_ag=dp_cp_ag,
expt_fsdp_group_ag=expt_dp_ag,
)

self.tp_group = tp_group
Expand Down
20 changes: 20 additions & 0 deletions megatron/core/distributed/fsdp/src/megatron_fsdp/fully_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def fully_shard_model(
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
hybrid_fsdp_expt_group: Optional[torch.distributed.ProcessGroup] = None,
expt_device_mesh: Optional[DeviceMesh] = None,
fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
expt_fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
zero_dp_strategy: str | int = 3,
outer_dp_sharding_strategy: str | int = 0,
Expand Down Expand Up @@ -142,6 +144,17 @@ class that schedules the sharding lifecycle of the model parameters and gradient
Expert parallel device mesh object defining the topology for MoE distributed training.
Utilizes the mesh dimension names specified by the *_dim arguments.

fsdp_group_ag (Optional[torch.distributed.ProcessGroup]):
Independent all-gather process group for overlapping all-gather and reduce-scatter
operations. When provided, enables AG/RS overlap optimization for regular (non-expert)
parameters. Users should create this group with the same ranks as the dp-cp group.
Defaults to None.

expt_fsdp_group_ag (Optional[torch.distributed.ProcessGroup]):
Independent all-gather process group for expert parameters in MoE models. When provided,
enables AG/RS overlap optimization for expert parameters. Users should create this group
with the same ranks as the expert data parallel group. Defaults to None.

fsdp_unit_modules (Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]]):
List of (sub-)module classes or (sub-)module class import paths that are "units",
which are torch.nn.Module(s) that are sharded and scheduled by Megatron-FSDP.
Expand Down Expand Up @@ -362,6 +375,9 @@ class that schedules the sharding lifecycle of the model parameters and gradient
hsdp_outer_dp_shard=_outer_fsdp_sharding,
# Only required for Megatron-FSDP + EP.
expt_device_mesh=expt_device_mesh,
# AG groups for AG/RS overlap optimization.
fsdp_group_ag=fsdp_group_ag,
expt_fsdp_group_ag=expt_fsdp_group_ag,
)

# Wrap model in Megatron FSDP.
Expand Down Expand Up @@ -621,6 +637,8 @@ def fully_shard(
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
hybrid_fsdp_expt_group: Optional[torch.distributed.ProcessGroup] = None,
expt_device_mesh: Optional[DeviceMesh] = None,
fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
expt_fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
zero_dp_strategy: str | int = 3,
outer_dp_sharding_strategy: str | int = 0,
Expand Down Expand Up @@ -669,6 +687,8 @@ def fully_shard(
hybrid_fsdp_group=hybrid_fsdp_group,
hybrid_fsdp_expt_group=hybrid_fsdp_expt_group,
expt_device_mesh=expt_device_mesh,
fsdp_group_ag=fsdp_group_ag,
expt_fsdp_group_ag=expt_fsdp_group_ag,
fsdp_unit_modules=fsdp_unit_modules,
zero_dp_strategy=zero_dp_strategy,
outer_dp_sharding_strategy=outer_dp_sharding_strategy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1691,9 +1691,6 @@ def __init__(
if self.dist_index.get_fsdp_group(is_expert_parallel=True) is not None:
# Expert-DP group when using EP
self.ubr_groups.append(self.dist_index.get_fsdp_group(is_expert_parallel=True))
if self.dist_index.get_outer_fsdp_group() is not None:
# Outer/Inter-FSDP group when using hybrid FSDP
self.ubr_groups.append(self.dist_index.get_outer_fsdp_group())
if (
self.dist_index.get_fsdp_group(
is_expert_parallel=False, independent_all_gather=True
Expand All @@ -1706,6 +1703,19 @@ def __init__(
is_expert_parallel=False, independent_all_gather=True
)
)
if (
self.dist_index.get_fsdp_group(is_expert_parallel=True, independent_all_gather=True)
is not None
):
# Expert all-gather group used when overlapping all-gather and gradient reduction.
self.ubr_groups.append(
self.dist_index.get_fsdp_group(
is_expert_parallel=True, independent_all_gather=True
)
)
if self.dist_index.get_outer_fsdp_group() is not None:
# Outer/Inter-FSDP group when using hybrid FSDP (IB domain, registered last).
self.ubr_groups.append(self.dist_index.get_outer_fsdp_group())

log_single_rank(
logger,
Expand Down Expand Up @@ -2036,14 +2046,14 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
is_expert_parallel=group.is_expert_param
)

# When --create-all-gather-group is enabled, use a separate process group for
# all-gather operations (model_weight_buffer) to enable overlap with gradient reduction
# operations (main_grad_buffer). This avoids head-of-line blocking between forward
# all-gather and backward reduce-scatter on the same communicator.
# Use separate process group for all-gather operations (model_weight_buffer)
# to enable overlap with gradient reduction operations (main_grad_buffer).
# This avoids head-of-line blocking between forward all-gather and backward
# reduce-scatter on the same communicator.
model_wbuf_dp_group = main_buf_dp_group
if not group.is_expert_param and not should_create_hfsdp_wbuf_and_gbuf:
if not should_create_hfsdp_wbuf_and_gbuf:
ag_group = self.dist_index.get_fsdp_group(
is_expert_parallel=False, independent_all_gather=True
is_expert_parallel=group.is_expert_param, independent_all_gather=True
)
if ag_group is not None:
model_wbuf_dp_group = ag_group
Expand Down
22 changes: 18 additions & 4 deletions megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,8 @@ def __init__(
hybrid_fsdp_expt_group: Optional[torch.distributed.ProcessGroup] = None,
hsdp_outer_dp_shard: bool = False,
expt_device_mesh: Optional[DeviceMesh] = None,
fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
expt_fsdp_group_ag: Optional[torch.distributed.ProcessGroup] = None,
):
"""
Args:
Expand All @@ -502,6 +504,13 @@ def __init__(
just sharding across dp_shard ranks and replicating across dp_outer ranks.
expt_device_mesh (Optional[DeviceMesh]): The expert parallel device mesh
to use for the DistributedIndex.
fsdp_group_ag (Optional[torch.distributed.ProcessGroup]): Independent all-gather
process group for overlapping all-gather and reduce-scatter operations.
When provided, enables AG/RS overlap optimization for regular (non-expert)
parameters.
expt_fsdp_group_ag (Optional[torch.distributed.ProcessGroup]): Independent all-gather
process group for expert parameters in MoE models. When provided, enables AG/RS
overlap optimization for expert parameters.
"""
# Device mesh arguments.
self.device_mesh = device_mesh
Expand All @@ -525,13 +534,16 @@ def __init__(
if contains_submesh(self.device_mesh, self.dp_shard_dim)
else None
)
# AG group comes from parallel_state, not the mesh
# the purpose of this independent group is to overlap all-gather and gradient reduction.
self.fsdp_group_ag = None
if HAVE_MEGATRON_CORE and parallel_state.has_separate_all_gather_group():
# AG groups: use explicit arguments if provided, otherwise fall back to parallel_state.
if fsdp_group_ag is not None:
self.fsdp_group_ag = fsdp_group_ag
elif HAVE_MEGATRON_CORE and parallel_state.has_separate_all_gather_group():
self.fsdp_group_ag = parallel_state.get_data_parallel_group(
with_context_parallel=True, independent_all_gather=True
)
else:
self.fsdp_group_ag = None
self.expt_fsdp_group_ag = expt_fsdp_group_ag
# Retrieve the outer-FSDP process group from the DeviceMesh.
self.outer_fsdp_group = (
self.device_mesh[self.dp_outer_dim].get_group()
Expand Down Expand Up @@ -683,6 +695,8 @@ def get_fsdp_group(
) -> ProcessGroup:
"""Get the FSDP process group."""
if is_expert_parallel:
if independent_all_gather:
return self.expt_fsdp_group_ag
return self.expt_fsdp_group
if independent_all_gather:
return self.fsdp_group_ag
Expand Down
85 changes: 85 additions & 0 deletions megatron/core/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,91 @@ def initialize_model_parallel(
_set_global_memory_buffer()


def create_all_gather_groups(
    for_expert_parallelism=False,
    timeout=None,
    nccl_comm_cfgs=None,
    order='tp-cp-ep-dp-pp',
    rank_offset=0,
):
    """
    Helper function to create all-gather process groups for AG/RS overlap.

    Creates separate communicators with the same ranks as the data parallel
    (dp-cp) group — and optionally the expert data parallel group — so that
    forward all-gather operations can overlap with backward reduce-scatter
    operations instead of serializing on a single communicator.

    Args:
        for_expert_parallelism (bool): If True, also creates an AG group for
            expert parameters (only when expert model parallel size > 1).
        timeout (Optional[timedelta]): Timeout for distributed collectives.
        nccl_comm_cfgs (Optional[dict]): NCCL communicator configurations.
        order (str): Rank-ordering string used to enumerate the parallel
            dimensions. MUST match the ``order`` passed to
            ``initialize_model_parallel()``; otherwise the AG groups will be
            built with rank membership that differs from the real DP groups.
        rank_offset (int): Global-rank offset of this model-parallel
            "universe". MUST match the ``rank_offset`` passed to
            ``initialize_model_parallel()`` for the same reason.

    Returns:
        tuple: (dp_cp_ag_group, expt_dp_ag_group). ``expt_dp_ag_group`` is
        None unless ``for_expert_parallelism`` is True and the expert model
        parallel size is greater than 1.

    Raises:
        RuntimeError: If parallel state has not been initialized.

    Example:
        # After initialize_model_parallel():
        dp_cp_ag, expt_dp_ag = parallel_state.create_all_gather_groups(
            for_expert_parallelism=True
        )

        # Add to ProcessGroupCollection:
        pg_collection = ProcessGroupCollection.use_mpu_process_groups()
        pg_collection.dp_cp_ag = dp_cp_ag
        pg_collection.expt_dp_ag = expt_dp_ag
    """
    if not is_initialized():
        raise RuntimeError(
            "create_all_gather_groups() requires parallel state to be initialized. "
            "Call initialize_model_parallel() first."
        )

    rank = torch.distributed.get_rank()
    pp_size = get_pipeline_model_parallel_world_size()
    cp_size = get_context_parallel_world_size()
    tp_size = get_tensor_model_parallel_world_size()
    ep_size = get_expert_model_parallel_world_size()
    dp_size = get_data_parallel_world_size()
    # Normalize once instead of at every create_group() call site.
    nccl_comm_cfgs = nccl_comm_cfgs or {}

    # Create regular DP all-gather group. The RankGenerator arguments mirror
    # the decoder generator in initialize_model_parallel(); `order` and
    # `rank_offset` are caller-supplied so the enumerated rank lists stay
    # consistent with the groups that were actually initialized.
    dp_cp_ag_group = None
    decoder_rank_gen = RankGenerator(
        tp=tp_size,
        ep=1,
        dp=dp_size,
        pp=pp_size,
        cp=cp_size,
        order=order,
        rank_offset=rank_offset,
    )
    for ranks_with_cp in decoder_rank_gen.get_ranks('dp-cp'):
        # create_group() is a collective: every rank must participate in the
        # creation of every group, but each rank keeps only its own.
        group_with_cp_ag = create_group(
            ranks_with_cp,
            timeout=timeout,
            pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs),
            group_desc='DATA_PARALLEL_GROUP_WITH_CP_AG',
        )
        if rank in ranks_with_cp:
            dp_cp_ag_group = group_with_cp_ag

    # Create expert DP all-gather group if requested.
    expt_dp_ag_group = None
    if for_expert_parallelism and ep_size > 1:
        expert_tp_size = get_expert_tensor_parallel_world_size()
        expert_dp_size = get_expert_data_parallel_world_size()

        # Expert layers have no context parallelism (cp=1), matching the
        # expert generator in initialize_model_parallel().
        expert_rank_gen = RankGenerator(
            tp=expert_tp_size,
            ep=ep_size,
            dp=expert_dp_size,
            pp=pp_size,
            cp=1,
            order=order,
            rank_offset=rank_offset,
        )

        for expert_dp_ranks in expert_rank_gen.get_ranks('dp'):
            expert_dp_ag = create_group(
                expert_dp_ranks,
                timeout=timeout,
                pg_options=get_nccl_options("ep_dp", nccl_comm_cfgs),
                group_desc='EXPERT_DATA_PARALLEL_GROUP_AG',
            )
            if rank in expert_dp_ranks:
                expt_dp_ag_group = expert_dp_ag

    return dp_cp_ag_group, expt_dp_ag_group


def is_initialized():
"""Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is not None
Expand Down
8 changes: 8 additions & 0 deletions megatron/core/process_groups_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,19 @@ class ProcessGroupCollection:
# _DATA_PARALLEL_GROUP_WITH_CP
dp_cp: torch.distributed.ProcessGroup = field(init=False)

# _DATA_PARALLEL_GROUP_WITH_CP_AG
dp_cp_ag: torch.distributed.ProcessGroup = field(init=False)

# MoE layers need expt_dp group for sharded state dict
# we need this workaround until distributed checkpoint is refactored
# to have sharded_state_dict can take the PG and pass it down
# TODO (Hepteract): remove this once distributed checkpoint is refactored
# _EXPERT_DATA_PARALLEL_GROUP
expt_dp: torch.distributed.ProcessGroup = field(init=False)

# _EXPERT_DATA_PARALLEL_GROUP_AG
expt_dp_ag: torch.distributed.ProcessGroup = field(init=False)

# _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP
intra_dp_cp: torch.distributed.ProcessGroup = field(init=False)

Expand Down Expand Up @@ -210,6 +216,7 @@ def use_mpu_process_groups(cls, required_pgs: Optional[List[str]] = None):
),
'dp': parallel_state.get_data_parallel_group,
'dp_cp': partial(parallel_state.get_data_parallel_group, with_context_parallel=True),
'dp_cp_ag': lambda: None,
'intra_dp_cp': partial(
parallel_state.get_data_parallel_group,
with_context_parallel=True,
Expand All @@ -232,6 +239,7 @@ def use_mpu_process_groups(cls, required_pgs: Optional[List[str]] = None):
'expt_dp': partial(
parallel_state.get_expert_data_parallel_group, check_initialized=False
),
'expt_dp_ag': lambda: None,
'tp_dp_cp': partial(
parallel_state.get_tensor_and_data_parallel_group,
check_initialized=False,
Expand Down
9 changes: 7 additions & 2 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2691,8 +2691,13 @@ def _add_distributed_args(parser):
default=False, help='Manually register the FSDP communication buffers to NCCL user buffer.'
'This option is only effective when use-megatron-fsdp and use-nccl-ub is set.')
group.add_argument('--create-all-gather-group', action='store_true',
help='Create a separate process group for all-gather operations '
'to overlap reduce-scatter and all-gather operations.')
help='Enable AG/RS overlap optimization by creating separate '
'all-gather communicators.')
group.add_argument('--megatron-fsdp-pg-collection', action='store_true',
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we exposing this knob supporting both behaviors? Is there something stopping us that I'm missing?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, this seems like an implementation detail, not something the user should have to worry about.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Context here should explain the status quo: #3249 (comment)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like this either.

default=False,
help='[Experimental] Pass ProcessGroupCollection explicitly to '
'Megatron-FSDP instead of using global parallel_state. '
'Only effective when --use-megatron-fsdp is set.')
group.add_argument('--data-parallel-sharding-strategy', type=str, default='no_shard',
choices=['no_shard', 'optim', 'optim_grads', 'optim_grads_params'],
help='Sharding strategy of data parallelism.')
Expand Down
24 changes: 23 additions & 1 deletion megatron/training/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ def set_startup_timestamps(program_start=None, main_entry=None):
from megatron.core.parallel_state import (
destroy_global_memory_buffer,
destroy_model_parallel,
update_pg_timeout
update_pg_timeout,
create_all_gather_groups,
)
from megatron.core.inference.symmetric_memory import SymmetricMemoryManager
from megatron.core.inference.unified_memory import create_unified_mempool
Expand Down Expand Up @@ -1317,6 +1318,19 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
if pg_collection is None:
pg_collection = ProcessGroupCollection.use_mpu_process_groups()

if args.create_all_gather_group:
timeout = timedelta(minutes=args.distributed_timeout_minutes) if args.distributed_timeout_minutes else None
dp_cp_ag, expt_dp_ag = create_all_gather_groups(
for_expert_parallelism=(args.expert_model_parallel_size > 1),
timeout=timeout,
)
pg_collection.dp_cp_ag = dp_cp_ag
pg_collection.expt_dp_ag = expt_dp_ag

print_rank_0("> created all-gather process groups for AG/RS overlap")
if expt_dp_ag is not None:
print_rank_0("> including expert parallelism AG group")

if has_nvidia_modelopt:
from megatron.post_training.checkpointing import has_modelopt_state
# [ModelOpt]: Check if the checkpoint is a ModelOpt checkpoint and
Expand Down Expand Up @@ -1486,6 +1500,13 @@ def build_model():
ddp_stream.wait_stream(torch.cuda.current_stream())
# Make ddp_stream start after whatever the default stream already queued
with torch.cuda.stream(ddp_stream):
# To pass kwargs unique to specific DDP classes.
ddp_init_kwargs = {}
if args.use_megatron_fsdp:
if getattr(args, 'megatron_fsdp_pg_collection', False):
# Pass PG collection distributed environment to Megatron-FSDP.
ddp_init_kwargs["pg_collection"] = pg_collection

model = [
DP(
config=config,
Expand All @@ -1494,6 +1515,7 @@ def build_model():
# Turn off bucketing for model_chunk 2 onwards, since communication
# for these model chunks is overlapped with compute anyway.
disable_bucketing=(model_chunk_idx > 0) or args.overlap_param_gather_with_optimizer_step,
**ddp_init_kwargs,
)
for (model_chunk_idx, model_chunk) in enumerate(model)
]
Expand Down
Loading
Loading