Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1602,6 +1602,18 @@ def __init__(
if self.dist_index.get_outer_fsdp_group() is not None:
# Outer/Inter-FSDP group when using hybrid FSDP
self.ubr_groups.append(self.dist_index.get_outer_fsdp_group())
if (
self.dist_index.get_fsdp_group(
is_expert_parallel=False, independent_all_gather=True
)
is not None
):
# All-gather group used when overlapping all-gather and gradient reduction.
self.ubr_groups.append(
self.dist_index.get_fsdp_group(
is_expert_parallel=False, independent_all_gather=True
)
)

if torch.distributed.get_rank() == 0:
logging.info(
Expand Down Expand Up @@ -1888,6 +1900,18 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
is_expert_parallel=group.is_expert_param
)

# When --create-all-gather-group is enabled, use a separate process group for
# all-gather operations (model_weight_buffer) to enable overlap with gradient reduction
# operations (main_grad_buffer). This avoids head-of-line blocking between forward
# all-gather and backward reduce-scatter on the same communicator.
model_wbuf_dp_group = main_buf_dp_group
if not group.is_expert_param and not should_create_hfsdp_wbuf_and_gbuf:
ag_group = self.dist_index.get_fsdp_group(
is_expert_parallel=False, independent_all_gather=True
)
if ag_group is not None:
model_wbuf_dp_group = ag_group

gradient_scaling_factor = (
self.gradient_scaling_factor
if not group.is_expert_param
Expand Down Expand Up @@ -1928,10 +1952,10 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
self.ddp_config,
group.params,
is_data_distributed=is_model_weight_buffer_distributed
and main_buf_dp_group.size() > 1,
and model_wbuf_dp_group.size() > 1,
dtype=param_dtype,
device=self.device,
data_parallel_group=main_buf_dp_group,
data_parallel_group=model_wbuf_dp_group,
is_transpose_buffer=False,
temporary_bucket_allocator=self.weight_alloc,
bucket_id=group_id,
Expand Down
20 changes: 19 additions & 1 deletion megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
from importlib.metadata import version
from typing import Callable, Optional, Sequence, Union

try:
import megatron.core.parallel_state as parallel_state

HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
HAVE_MEGATRON_CORE = False

try:
import einops

Expand Down Expand Up @@ -486,6 +493,13 @@ def __init__(
if contains_submesh(self.device_mesh, self.dp_shard_dim)
else None
)
# The all-gather (AG) group is obtained from parallel_state rather than the device mesh.
# This independent group allows all-gather to overlap with gradient reduction.
self.fsdp_group_ag = None
if HAVE_MEGATRON_CORE and parallel_state.has_separate_all_gather_group():
self.fsdp_group_ag = parallel_state.get_data_parallel_group(
with_context_parallel=True, independent_all_gather=True
)
# Retrieve the outer-FSDP process group from the DeviceMesh.
self.outer_fsdp_group = (
self.device_mesh[self.dp_outer_dim].get_group()
Expand Down Expand Up @@ -620,10 +634,14 @@ def get_dp_group(self, is_expert_parallel: bool = False) -> ProcessGroup:
return self.hybrid_fsdp_group
return self.fsdp_group

def get_fsdp_group(self, is_expert_parallel: bool = False) -> ProcessGroup:
def get_fsdp_group(
    self, is_expert_parallel: bool = False, independent_all_gather: bool = False
) -> ProcessGroup:
    """Return the FSDP process group matching the requested flavor.

    Args:
        is_expert_parallel: Select the expert-parallel FSDP group; this takes
            precedence over ``independent_all_gather``.
        independent_all_gather: Select the dedicated all-gather group used to
            overlap all-gather with gradient reduction (may be ``None`` when
            no separate all-gather group was created).
    """
    if is_expert_parallel:
        return self.expt_fsdp_group
    return self.fsdp_group_ag if independent_all_gather else self.fsdp_group

def get_outer_fsdp_group(self) -> ProcessGroup:
Expand Down
41 changes: 40 additions & 1 deletion megatron/core/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@

# Data parallel group information with context parallel combined.
_DATA_PARALLEL_GROUP_WITH_CP = None
_DATA_PARALLEL_GROUP_WITH_CP_AG = None
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None

Expand Down Expand Up @@ -566,6 +567,7 @@ def initialize_model_parallel(
create_gloo_process_groups: bool = True,
high_priority_stream_groups: Optional[List[str]] = None,
sharp_enabled_group: Optional[str] = None,
create_all_gather_group: Optional[bool] = False,
) -> None:
"""Initialize model data parallel groups.

Expand Down Expand Up @@ -680,6 +682,13 @@ def initialize_model_parallel(
By default (None), it is enabled from dp group.
Available options (choose one): [dp, dp_replica]

create_all_gather_group (bool, default = False):
Create a separate process group for all-gather operations to avoid
head-of-line blocking with reduce-scatter operations. When enabled,
creates an additional NCCL communicator with identical ranks as the
dp-cp group but with independent progress engines for better communication
overlap.

Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
Expand Down Expand Up @@ -816,6 +825,7 @@ def initialize_model_parallel(
global _DATA_PARALLEL_GROUP_GLOO
global _DATA_PARALLEL_GLOBAL_RANKS
global _DATA_PARALLEL_GROUP_WITH_CP
global _DATA_PARALLEL_GROUP_WITH_CP_AG
global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
global _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP
Expand Down Expand Up @@ -847,6 +857,15 @@ def initialize_model_parallel(
pg_options=get_nccl_options("dp_cp", nccl_comm_cfgs),
group_desc="DATA_PARALLEL_GROUP_WITH_CP",
)
if create_all_gather_group:
group_with_cp_ag = create_group(
ranks_with_cp,
timeout=timeout,
pg_options=get_nccl_options("dp_cp", nccl_comm_cfgs),
group_desc="DATA_PARALLEL_GROUP_WITH_CP_AG",
)
else:
group_with_cp_ag = None
if create_gloo_process_groups:
group_with_cp_gloo = create_group(
ranks_with_cp,
Expand All @@ -858,6 +877,7 @@ def initialize_model_parallel(
group_with_cp_gloo = None
if rank in ranks_with_cp:
_DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
_DATA_PARALLEL_GROUP_WITH_CP_AG = group_with_cp_ag
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp

Expand Down Expand Up @@ -1387,14 +1407,21 @@ def get_pipeline_model_parallel_group(check_initialized=True):
return _PIPELINE_MODEL_PARALLEL_GROUP


def get_data_parallel_group(with_context_parallel=False, partial_data_parallel=False):
def get_data_parallel_group(
with_context_parallel=False, partial_data_parallel=False, independent_all_gather=False
):
"""Get the data-parallel group the caller rank belongs to."""
if with_context_parallel:
if partial_data_parallel:
assert (
_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP is not None
), "Intra partial data parallel group is not initialized"
return _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP
if independent_all_gather:
assert (
_DATA_PARALLEL_GROUP_WITH_CP_AG is not None
), "data parallel group with context parallel AG is not initialized"
return _DATA_PARALLEL_GROUP_WITH_CP_AG
assert (
_DATA_PARALLEL_GROUP_WITH_CP is not None
), "data parallel group with context parallel combined is not initialized"
Expand All @@ -1405,6 +1432,15 @@ def get_data_parallel_group(with_context_parallel=False, partial_data_parallel=F
return _DATA_PARALLEL_GROUP


def has_separate_all_gather_group() -> bool:
    """Report whether a dedicated all-gather process group was created.

    The group only exists when model parallelism was initialized with
    ``create_all_gather_group=True``; it lets all-gather traffic progress
    independently of reduce-scatter for better communication overlap.
    """
    return _DATA_PARALLEL_GROUP_WITH_CP_AG is not None


def get_data_parallel_group_gloo(with_context_parallel=False, partial_data_parallel=False):
"""Get the Gloo data-parallel group the caller rank belongs to."""
if with_context_parallel:
Expand Down Expand Up @@ -2065,6 +2101,9 @@ def destroy_model_parallel():
global _DATA_PARALLEL_GROUP_WITH_CP
_DATA_PARALLEL_GROUP_WITH_CP = None

global _DATA_PARALLEL_GROUP_WITH_CP_AG
_DATA_PARALLEL_GROUP_WITH_CP_AG = None

global _CONTEXT_PARALLEL_GROUP
_CONTEXT_PARALLEL_GROUP = None

Expand Down
3 changes: 3 additions & 0 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2859,6 +2859,9 @@ def _add_distributed_args(parser):
help='IB SHARP can be enabled from only one communication group. '
'By default, it is enabled from dp group. '
'Available options: [dp, dp_replica]')
group.add_argument('--create-all-gather-group', action='store_true',
help='Create a separate process group for all-gather operations '
'to overlap reduce-scatter and all-gather operations.')
group.add_argument('--use-megatron-fsdp', action='store_true',
help='Use the Megatron FSDP code path in DDP.')
group.add_argument('--init-model-with-meta-device', action='store_true')
Expand Down
1 change: 1 addition & 0 deletions megatron/training/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s
create_gloo_process_groups=args.enable_gloo_process_groups,
high_priority_stream_groups=args.high_priority_stream_groups,
sharp_enabled_group=args.sharp_enabled_group,
create_all_gather_group=args.create_all_gather_group,
)
if args.rank == 0:
print(
Expand Down
28 changes: 28 additions & 0 deletions tests/unit_tests/test_parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,31 @@ def test_hybrid_dp_cp_groups(world_size, tp_size, cp_size, dp_size):
assert group.size() == group_size

Utils.destroy_model_parallel()


def test_separate_all_gather_group():
    """Verify the optional dedicated all-gather group for communication overlap.

    First checks that no all-gather (AG) group exists by default, then
    re-initializes with ``create_all_gather_group=True`` and confirms the AG
    group covers the same ranks as the regular dp-cp group while being a
    distinct communicator object.
    """
    # Default initialization: no separate all-gather group should be created.
    Utils.initialize_model_parallel(
        context_parallel_size=world_size, create_all_gather_group=False
    )
    assert not ps.has_separate_all_gather_group()
    assert ps._DATA_PARALLEL_GROUP_WITH_CP_AG is None
    Utils.destroy_model_parallel()

    # Re-initialize with the dedicated all-gather group enabled.
    Utils.initialize_model_parallel(
        context_parallel_size=world_size, create_all_gather_group=True
    )
    assert ps.has_separate_all_gather_group()
    assert ps._DATA_PARALLEL_GROUP_WITH_CP_AG is not None

    # Both accessors must return a live group.
    ag_group = ps.get_data_parallel_group(
        with_context_parallel=True, independent_all_gather=True
    )
    dp_cp_group = ps.get_data_parallel_group(
        with_context_parallel=True, independent_all_gather=False
    )
    assert ag_group is not None
    assert dp_cp_group is not None
    # Identical rank membership, independent communicators.
    assert torch.distributed.get_process_group_ranks(
        ag_group
    ) == torch.distributed.get_process_group_ranks(dp_cp_group)

    Utils.destroy_model_parallel()
Loading