diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py index 546bc0721e0..9a0ef354c26 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py @@ -1602,6 +1602,18 @@ def __init__( if self.dist_index.get_outer_fsdp_group() is not None: # Outer/Inter-FSDP group when using hybrid FSDP self.ubr_groups.append(self.dist_index.get_outer_fsdp_group()) + if ( + self.dist_index.get_fsdp_group( + is_expert_parallel=False, independent_all_gather=True + ) + is not None + ): + # All-gather group used when overlapping all-gather and gradient reduction. + self.ubr_groups.append( + self.dist_index.get_fsdp_group( + is_expert_parallel=False, independent_all_gather=True + ) + ) if torch.distributed.get_rank() == 0: logging.info( @@ -1888,6 +1900,18 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): is_expert_parallel=group.is_expert_param ) + # When --create-all-gather-group is enabled, use a separate process group for + # all-gather operations (model_weight_buffer) to enable overlap with gradient reduction + # operations (main_grad_buffer). This avoids head-of-line blocking between forward + # all-gather and backward reduce-scatter on the same communicator. + model_wbuf_dp_group = main_buf_dp_group + if not group.is_expert_param and not should_create_hfsdp_wbuf_and_gbuf: + ag_group = self.dist_index.get_fsdp_group( + is_expert_parallel=False, independent_all_gather=True + ) + if ag_group is not None: + model_wbuf_dp_group = ag_group + gradient_scaling_factor = ( self.gradient_scaling_factor if not group.is_expert_param @@ -1928,10 +1952,10 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): self.ddp_config, group.params, is_data_distributed=is_model_weight_buffer_distributed - and main_buf_dp_group.size() > 1, + and model_wbuf_dp_group.size() > 1, dtype=param_dtype, device=self.device, - data_parallel_group=main_buf_dp_group, + data_parallel_group=model_wbuf_dp_group, is_transpose_buffer=False, temporary_bucket_allocator=self.weight_alloc, bucket_id=group_id, diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py index 01523929ae1..d5fbc91fcf8 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/utils.py @@ -21,6 +21,13 @@ from importlib.metadata import version from typing import Callable, Optional, Sequence, Union +try: + import megatron.core.parallel_state as parallel_state + + HAVE_MEGATRON_CORE = True +except (ImportError, ModuleNotFoundError): + HAVE_MEGATRON_CORE = False + try: import einops @@ -486,6 +493,13 @@ def __init__( if contains_submesh(self.device_mesh, self.dp_shard_dim) else None ) + # AG group comes from parallel_state, not the mesh + # the purpose of this independent group is to overlap all-gather and gradient reduction. + self.fsdp_group_ag = None + if HAVE_MEGATRON_CORE and parallel_state.has_separate_all_gather_group(): + self.fsdp_group_ag = parallel_state.get_data_parallel_group( + with_context_parallel=True, independent_all_gather=True + ) # Retrieve the outer-FSDP process group from the DeviceMesh. self.outer_fsdp_group = ( self.device_mesh[self.dp_outer_dim].get_group() @@ -620,10 +634,14 @@ def get_dp_group(self, is_expert_parallel: bool = False) -> ProcessGroup: return self.hybrid_fsdp_group return self.fsdp_group - def get_fsdp_group(self, is_expert_parallel: bool = False) -> ProcessGroup: + def get_fsdp_group( + self, is_expert_parallel: bool = False, independent_all_gather: bool = False + ) -> ProcessGroup: """Get the FSDP process group.""" if is_expert_parallel: return self.expt_fsdp_group + if independent_all_gather: + return self.fsdp_group_ag return self.fsdp_group def get_outer_fsdp_group(self) -> ProcessGroup: diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index c5a73600ee1..7bb96407838 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -120,6 +120,7 @@ # Data parallel group information with context parallel combined. _DATA_PARALLEL_GROUP_WITH_CP = None +_DATA_PARALLEL_GROUP_WITH_CP_AG = None _DATA_PARALLEL_GROUP_WITH_CP_GLOO = None _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None @@ -566,6 +567,7 @@ def initialize_model_parallel( create_gloo_process_groups: bool = True, high_priority_stream_groups: Optional[List[str]] = None, sharp_enabled_group: Optional[str] = None, + create_all_gather_group: Optional[bool] = False, ) -> None: """Initialize model data parallel groups. @@ -680,6 +682,13 @@ def initialize_model_parallel( By default (None), it is enabled from dp group. Available options (choose one): [dp, dp_replica] + create_all_gather_group (bool, default = False): + Create a separate process group for all-gather operations to avoid + head-of-line blocking with reduce-scatter operations. When enabled, + creates an additional NCCL communicator with identical ranks as the + dp-cp group but with independent progress engines for better communication + overlap. + Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize the model pipeline. The present function will @@ -816,6 +825,7 @@ def initialize_model_parallel( global _DATA_PARALLEL_GROUP_GLOO global _DATA_PARALLEL_GLOBAL_RANKS global _DATA_PARALLEL_GROUP_WITH_CP + global _DATA_PARALLEL_GROUP_WITH_CP_AG global _DATA_PARALLEL_GROUP_WITH_CP_GLOO global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP global _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP @@ -847,6 +857,15 @@ def initialize_model_parallel( pg_options=get_nccl_options("dp_cp", nccl_comm_cfgs), group_desc="DATA_PARALLEL_GROUP_WITH_CP", ) + if create_all_gather_group: + group_with_cp_ag = create_group( + ranks_with_cp, + timeout=timeout, + pg_options=get_nccl_options("dp_cp", nccl_comm_cfgs), + group_desc="DATA_PARALLEL_GROUP_WITH_CP_AG", + ) + else: + group_with_cp_ag = None if create_gloo_process_groups: group_with_cp_gloo = create_group( ranks_with_cp, @@ -858,6 +877,7 @@ def initialize_model_parallel( group_with_cp_gloo = None if rank in ranks_with_cp: _DATA_PARALLEL_GROUP_WITH_CP = group_with_cp + _DATA_PARALLEL_GROUP_WITH_CP_AG = group_with_cp_ag _DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp @@ -1387,7 +1407,9 @@ def get_pipeline_model_parallel_group(check_initialized=True): return _PIPELINE_MODEL_PARALLEL_GROUP -def get_data_parallel_group(with_context_parallel=False, partial_data_parallel=False): +def get_data_parallel_group( + with_context_parallel=False, partial_data_parallel=False, independent_all_gather=False +): """Get the data-parallel group the caller rank belongs to.""" if with_context_parallel: if partial_data_parallel: @@ -1395,6 +1417,11 @@ def get_data_parallel_group(with_context_parallel=False, partial_data_parallel=F _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP is not None ), "Intra partial data parallel group is not initialized" return _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP + if independent_all_gather: + assert ( + _DATA_PARALLEL_GROUP_WITH_CP_AG is not None + ), "data parallel group with context parallel AG is not initialized" + return _DATA_PARALLEL_GROUP_WITH_CP_AG assert ( _DATA_PARALLEL_GROUP_WITH_CP is not None ), "data parallel group with context parallel combined is not initialized" @@ -1405,6 +1432,15 @@ def get_data_parallel_group(with_context_parallel=False, partial_data_parallel=F return _DATA_PARALLEL_GROUP +def has_separate_all_gather_group() -> bool: + """Check if a separate all-gather process group has been created. + + Returns True if a dedicated all-gather process group exists for improved + communication overlap, False otherwise. + """ + return _DATA_PARALLEL_GROUP_WITH_CP_AG is not None + + def get_data_parallel_group_gloo(with_context_parallel=False, partial_data_parallel=False): """Get the Gloo data-parallel group the caller rank belongs to.""" if with_context_parallel: @@ -2065,6 +2101,9 @@ def destroy_model_parallel(): global _DATA_PARALLEL_GROUP_WITH_CP _DATA_PARALLEL_GROUP_WITH_CP = None + global _DATA_PARALLEL_GROUP_WITH_CP_AG + _DATA_PARALLEL_GROUP_WITH_CP_AG = None + global _CONTEXT_PARALLEL_GROUP _CONTEXT_PARALLEL_GROUP = None diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 7dbb4ffedab..7fe6b88dcc8 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2859,6 +2859,9 @@ def _add_distributed_args(parser): help='IB SHARP can be enabled from only one communication group. ' 'By default, it is enabled from dp group. ' 'Available options: [dp, dp_replica]') + group.add_argument('--create-all-gather-group', action='store_true', + help='Create a separate process group for all-gather operations ' + 'to overlap reduce-scatter and all-gather operations.') group.add_argument('--use-megatron-fsdp', action='store_true', help='Use the Megatron FSDP code path in DDP.') group.add_argument('--init-model-with-meta-device', action='store_true') diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py index 00fa9ad5088..e300c03218b 100644 --- a/megatron/training/initialize.py +++ b/megatron/training/initialize.py @@ -389,6 +389,7 @@ def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks, s create_gloo_process_groups=args.enable_gloo_process_groups, high_priority_stream_groups=args.high_priority_stream_groups, sharp_enabled_group=args.sharp_enabled_group, + create_all_gather_group=args.create_all_gather_group, ) if args.rank == 0: print( diff --git a/tests/unit_tests/test_parallel_state.py b/tests/unit_tests/test_parallel_state.py index 0c722ee0257..21dc740cdf4 100644 --- a/tests/unit_tests/test_parallel_state.py +++ b/tests/unit_tests/test_parallel_state.py @@ -530,3 +530,31 @@ def test_hybrid_dp_cp_groups(world_size, tp_size, cp_size, dp_size): assert group.size() == group_size Utils.destroy_model_parallel() + + +def test_separate_all_gather_group(): + """Test separate all-gather group for improved communication overlap.""" + # Test without creating AG group (default) + Utils.initialize_model_parallel(context_parallel_size=world_size, create_all_gather_group=False) + assert not ps.has_separate_all_gather_group() + assert ps._DATA_PARALLEL_GROUP_WITH_CP_AG is None + Utils.destroy_model_parallel() + + # Test with creating AG group + Utils.initialize_model_parallel(context_parallel_size=world_size, create_all_gather_group=True) + assert ps.has_separate_all_gather_group() + assert ps._DATA_PARALLEL_GROUP_WITH_CP_AG is not None + + # Verify it returns the correct group + ag_group = ps.get_data_parallel_group(with_context_parallel=True, independent_all_gather=True) + regular_group = ps.get_data_parallel_group( + with_context_parallel=True, independent_all_gather=False + ) + assert ag_group is not None + assert regular_group is not None + # They should have the same ranks but different communicators + ag_ranks = torch.distributed.get_process_group_ranks(ag_group) + regular_ranks = torch.distributed.get_process_group_ranks(regular_group) + assert ag_ranks == regular_ranks + + Utils.destroy_model_parallel()