diff --git a/src/megatron/bridge/training/pretrain.py b/src/megatron/bridge/training/pretrain.py
index 56294e3060..147eb3a7ac 100644
--- a/src/megatron/bridge/training/pretrain.py
+++ b/src/megatron/bridge/training/pretrain.py
@@ -126,6 +126,7 @@ def _pretrain(
     valid_data_iterator = setup_output.valid_data_iterator
     test_data_iterator = setup_output.test_data_iterator
     ckpt_context = setup_output.checkpointing_context
+    pg_collection = setup_output.pg_collection
 
     # TRAINING
     if not config.train.skip_train:
@@ -140,6 +141,7 @@ def _pretrain(
             valid_data_iterator,
             state,
             ckpt_context,
+            pg_collection,
         )
 
         barrier_and_log("after training is done")
diff --git a/src/megatron/bridge/training/setup.py b/src/megatron/bridge/training/setup.py
index 754c737e70..6c7ba5b355 100644
--- a/src/megatron/bridge/training/setup.py
+++ b/src/megatron/bridge/training/setup.py
@@ -26,6 +26,7 @@
 from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
+from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.rerun_state_machine import RerunDataIterator
 from megatron.core.transformer import MegatronModule
 
 from megatron.bridge.data.loaders import setup_data_iterators
 from megatron.bridge.models import GPTModelProvider, T5ModelProvider
@@ -65,6 +66,7 @@ class SetupOutput(NamedTuple):
         test_data_iterator: The data iterator for the testing dataset, if applicable.
         checkpointing_context: A dictionary holding context for checkpointing operations,
             especially for non-persistent local checkpointing.
+        pg_collection: The process group collection initialized for this run.
     """
 
     state: GlobalState
@@ -75,6 +77,7 @@ class SetupOutput(NamedTuple):
     valid_data_iterator: Optional[RerunDataIterator | list[RerunDataIterator]]
     test_data_iterator: Optional[RerunDataIterator | list[RerunDataIterator]]
     checkpointing_context: dict[str, Any]
+    pg_collection: ProcessGroupCollection
 
 def setup(
     state: GlobalState,
@@ -149,6 +152,9 @@ def setup(
     print_rank_0("time to initialize megatron (seconds): {:.3f}".format(time.time() - state.start_time))
     barrier_and_log("after megatron is initialized")
 
+    # Initialize the process group collection once and pass it through to training.
+    pg_collection = ProcessGroupCollection.use_mpu_process_groups()
+
     # Context used for persisting some state between checkpoint saves.
     checkpointing_context = init_checkpointing_context(cfg.checkpoint)
 
@@ -233,6 +239,7 @@ def setup(
         cfg.ddp,
         optimizer,
         align_grad_reduce=cfg.dist.align_grad_reduce,
+        pg_collection=pg_collection,
     )
 
     # Data stuff.
@@ -272,6 +279,7 @@ def setup(
         valid_data_iterator,
         test_data_iterator,
         checkpointing_context,
+        pg_collection,
     )
 
 
@@ -282,6 +290,7 @@ def _update_model_config_funcs(
     optimizer: Optional[MegatronOptimizer],
     *,
     align_grad_reduce: bool = True,
+    pg_collection: Optional[ProcessGroupCollection] = None,
 ) -> None:
     """Update model config sync funcs based on initialized model."""
     if isinstance(model[0], (DistributedDataParallel, megatron_FSDP)) and ddp_config.overlap_grad_reduce:
@@ -301,7 +310,9 @@ def _update_model_config_funcs(
         if len(model) == 1:
             model_config.param_sync_func = model_config.param_sync_func[0]
     if optimizer is not None:
-        model_config.finalize_model_grads_func = finalize_model_grads
+        model_config.finalize_model_grads_func = partial(
+            finalize_model_grads, pg_collection=pg_collection
+        )
         model_config.grad_scale_func = optimizer.scale_loss
 
 
diff --git a/src/megatron/bridge/training/train.py b/src/megatron/bridge/training/train.py
index c80a26347f..3af70fe42a 100644
--- a/src/megatron/bridge/training/train.py
+++ b/src/megatron/bridge/training/train.py
@@ -34,6 +34,7 @@
 from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer
 from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
 from megatron.core.pipeline_parallel import get_forward_backward_func
+from megatron.core.process_groups_config import ProcessGroupCollection
 from megatron.core.rerun_state_machine import RerunDataIterator, get_rerun_state_machine
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.cuda_graphs import TECudaGraphHelper
@@ -82,6 +83,7 @@ def train(
     valid_data_iterator: Optional[Union[RerunDataIterator, list[RerunDataIterator]]],
     global_state: GlobalState,
     checkpointing_context: dict[str, Any],
+    pg_collection: ProcessGroupCollection,
     process_non_loss_data_func: Optional[Callable] = None,
     non_loss_data_func: Optional[Callable] = None,
 ) -> None:
@@ -280,13 +282,13 @@ def train(
 
         update_num_microbatches(global_state.train_state.consumed_train_samples, consistency_check=True, verbose=True)
 
         # Completely skip iteration if needed.
-        if _should_skip_and_handle_iteration(global_state, train_data_iterator):
+        if _should_skip_and_handle_iteration(global_state, train_data_iterator, pg_collection):
             continue
 
         # Run training step.
         fault_tolerance.on_training_step_start(global_state)
         loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = train_step(
-            wrapped_forward_step_func, train_data_iterator, model, optimizer, scheduler, global_state
+            wrapped_forward_step_func, train_data_iterator, model, optimizer, scheduler, global_state, pg_collection
         )
         fault_tolerance.on_training_step_end(global_state)
@@ -332,9 +334,8 @@ def train(
                 cuda_graph_helper.cuda_graph_set_manual_hooks()
 
         global_state.train_state.step += 1
-        batch_size = (
-            parallel_state.get_data_parallel_world_size() * train_config.micro_batch_size * get_num_microbatches()
-        )
+        dp_size = pg_collection.dp.size()
+        batch_size = dp_size * train_config.micro_batch_size * get_num_microbatches()
         global_state.train_state.consumed_train_samples += batch_size
         num_skipped_samples_in_batch = get_current_global_batch_size() - get_current_running_global_batch_size()
         if train_config.decrease_batch_size_if_needed:
@@ -496,6 +497,7 @@ def train_step(
     optimizer: MegatronOptimizer,
     scheduler: OptimizerParamScheduler,
     global_state: GlobalState,
+    pg_collection: ProcessGroupCollection,
 ) -> tuple[dict[str, torch.Tensor], int, bool, bool, int, Optional[float], Optional[int]]:
     """Single training step.
 
@@ -607,9 +609,8 @@ def train_step(
                    # there is one dict per microbatch. in new reporting, we average
                    # over the total number of tokens across the global batch.
                    val = torch.vstack(val).sum(dim=0)
-                    torch.distributed.all_reduce(
-                        val, group=parallel_state.get_data_parallel_group(with_context_parallel=True)
-                    )
+                    dp_cp_group = pg_collection.dp_cp
+                    torch.distributed.all_reduce(val, group=dp_cp_group)
                    loss_reduced[key] = val[0] / val[1]
                elif val[0].numel() == 1:
                    # legacy behavior, we average over the number of microbatches
@@ -1059,7 +1060,9 @@ def _finish_train(global_state: GlobalState):
 
 
 def _should_skip_and_handle_iteration(
-    global_state: GlobalState, train_data_iterator: Optional[Union[RerunDataIterator, list[RerunDataIterator]]]
+    global_state: GlobalState,
+    train_data_iterator: Optional[Union[RerunDataIterator, list[RerunDataIterator]]],
+    pg_collection: ProcessGroupCollection,
 ) -> bool:
     """Check if the current iteration should be skipped and handle it if so.
 
@@ -1082,7 +1085,8 @@ def _should_skip_and_handle_iteration(
 
     # Update step and sample counters
     global_state.train_state.step += 1
-    batch_size = parallel_state.get_data_parallel_world_size() * cfg.train.micro_batch_size * get_num_microbatches()
+    dp_size = pg_collection.dp.size()
+    batch_size = dp_size * cfg.train.micro_batch_size * get_num_microbatches()
     global_state.train_state.consumed_train_samples += batch_size
     global_state.train_state.skipped_train_samples += batch_size
 
diff --git a/tests/unit_tests/training/test_pg_collection_wiring.py b/tests/unit_tests/training/test_pg_collection_wiring.py
new file mode 100644
index 0000000000..f92462556d
--- /dev/null
+++ b/tests/unit_tests/training/test_pg_collection_wiring.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import signal
+from types import SimpleNamespace
+
+
+def test_train_step_accepts_pg_collection_argument():
+    # Import locally to avoid import-time side effects in unrelated modules.
+    from megatron.bridge.training import train as train_module
+
+    sig = inspect.signature(train_module.train_step)
+    assert "pg_collection" in sig.parameters, "train_step must accept a pg_collection parameter"
+
+
+def test_should_skip_iteration_uses_passed_pg_collection(monkeypatch):
+    # Arrange a minimal GlobalState with only the fields that are used.
+    from megatron.bridge.training import train as train_module
+    from megatron.bridge.training.state import GlobalState
+
+    state = GlobalState()
+
+    # Set up the minimal config needed by _should_skip_and_handle_iteration.
+    # We skip step 0 so that the function executes the skip path.
+    state.cfg = SimpleNamespace(
+        train=SimpleNamespace(
+            iterations_to_skip={0},
+            micro_batch_size=4,
+            exit_signal=signal.SIGTERM,
+        )
+    )
+
+    # Fake pg_collection exposing a data-parallel group with a fixed size.
+    class _DP:
+        def size(self):
+            return 3
+
+    class _PG:
+        def __init__(self):
+            self.dp = _DP()
+
+    fake_pg = _PG()
+
+    # Ensure a deterministic microbatch count without touching global calculators.
+    monkeypatch.setattr(train_module, "get_num_microbatches", lambda: 2)
+
+    # Avoid any distributed or pipeline logic inside the dummy step.
+    monkeypatch.setattr(train_module, "_dummy_train_step", lambda *args, **kwargs: None)
+
+    # Pre-check counters.
+    assert state.train_state.step == 0
+    assert state.train_state.consumed_train_samples == 0
+    assert state.train_state.skipped_train_samples == 0
+
+    # Act.
+    did_skip = train_module._should_skip_and_handle_iteration(state, None, fake_pg)
+
+    # Assert.
+    assert did_skip is True
+    # One iteration was skipped.
+    assert state.train_state.step == 1
+    # Batch size = dp.size() * micro_batch_size * num_microbatches = 3 * 4 * 2 = 24.
+    expected_batch = 3 * 4 * 2
+    assert state.train_state.consumed_train_samples == expected_batch
+    assert state.train_state.skipped_train_samples == expected_batch
diff --git a/tests/unit_tests/training/test_train.py b/tests/unit_tests/training/test_train.py
index 3a6494c65d..6ea6069c76 100644
--- a/tests/unit_tests/training/test_train.py
+++ b/tests/unit_tests/training/test_train.py
@@ -805,6 +805,20 @@ def _create_mock_global_state(self, step=0, iterations_to_skip=None, micro_batch
 
         return mock_state
 
+    def _make_fake_pg(self, dp_size: int):
+        class _DP:
+            def __init__(self, size: int) -> None:
+                self._size = size
+
+            def size(self) -> int:
+                return self._size
+
+        class _PG:
+            def __init__(self, size: int) -> None:
+                self.dp = _DP(size)
+
+        return _PG(dp_size)
+
     @patch("megatron.bridge.training.train._dummy_train_step")
     @patch("megatron.bridge.training.train.parallel_state.get_data_parallel_world_size", return_value=2)
     @patch("megatron.bridge.training.train.get_num_microbatches", return_value=4)
@@ -817,7 +831,8 @@ def test_should_skip_iteration_when_step_in_skip_list(
         train_data_iterator = Mock()
 
         # Call function
-        result = _should_skip_and_handle_iteration(global_state, train_data_iterator)
+        fake_pg = self._make_fake_pg(2)
+        result = _should_skip_and_handle_iteration(global_state, train_data_iterator, fake_pg)
 
         # Verify
         assert result is True
@@ -837,7 +852,8 @@ def test_should_not_skip_iteration_when_step_not_in_skip_list(self, mock_dummy_s
         train_data_iterator = Mock()
 
         # Call function
-        result = _should_skip_and_handle_iteration(global_state, train_data_iterator)
+        fake_pg = self._make_fake_pg(1)
+        result = _should_skip_and_handle_iteration(global_state, train_data_iterator, fake_pg)
 
         # Verify
         assert result is False
@@ -856,7 +872,8 @@ def test_should_not_skip_when_skip_list_empty(self, mock_dummy_step):
         train_data_iterator = Mock()
 
         # Call function
-        result = _should_skip_and_handle_iteration(global_state, train_data_iterator)
+        fake_pg = self._make_fake_pg(1)
+        result = _should_skip_and_handle_iteration(global_state, train_data_iterator, fake_pg)
 
         # Verify
         assert result is False
@@ -874,7 +891,8 @@ def test_batch_size_calculation_with_different_parallelism(
         train_data_iterator = Mock()
 
         # Call function
-        result = _should_skip_and_handle_iteration(global_state, train_data_iterator)
+        fake_pg = self._make_fake_pg(8)
+        result = _should_skip_and_handle_iteration(global_state, train_data_iterator, fake_pg)
 
         # Verify
         assert result is True
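
A minimal sketch of the equivalence this patch relies on, not part of the patch itself: it assumes the collection returned by ProcessGroupCollection.use_mpu_process_groups() wraps the same mpu groups that parallel_state exposes, so pg_collection.dp.size() can stand in for get_data_parallel_world_size() and pg_collection.dp_cp for get_data_parallel_group(with_context_parallel=True). It also assumes functools.partial is in scope in setup.py for the finalize_model_grads change. A single-process check could look like the following; the gloo backend, rank 0, and world size 1 are illustrative assumptions:

    import os
    import torch.distributed as dist
    from megatron.core import parallel_state
    from megatron.core.process_groups_config import ProcessGroupCollection

    # Stand up a trivial single-rank process group so the mpu groups can be built.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    parallel_state.initialize_model_parallel()  # all parallel sizes default to 1

    pg_collection = ProcessGroupCollection.use_mpu_process_groups()

    # pg_collection.dp.size() replaces parallel_state.get_data_parallel_world_size().
    assert pg_collection.dp.size() == parallel_state.get_data_parallel_world_size()

    # pg_collection.dp_cp replaces get_data_parallel_group(with_context_parallel=True).
    assert pg_collection.dp_cp.size() == dist.get_world_size(
        group=parallel_state.get_data_parallel_group(with_context_parallel=True)
    )

    parallel_state.destroy_model_parallel()
    dist.destroy_process_group()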