Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/megatron/bridge/training/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def _pretrain(
valid_data_iterator = setup_output.valid_data_iterator
test_data_iterator = setup_output.test_data_iterator
ckpt_context = setup_output.checkpointing_context
pg_collection = setup_output.pg_collection

# TRAINING
if not config.train.skip_train:
Expand All @@ -140,6 +141,7 @@ def _pretrain(
valid_data_iterator,
state,
ckpt_context,
pg_collection,
)

barrier_and_log("after training is done")
Expand Down
13 changes: 12 additions & 1 deletion src/megatron/bridge/training/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.core.rerun_state_machine import RerunDataIterator
from megatron.core.transformer import MegatronModule
from megatron.core.process_groups_config import ProcessGroupCollection

from megatron.bridge.data.loaders import setup_data_iterators
from megatron.bridge.models import GPTModelProvider, T5ModelProvider
Expand Down Expand Up @@ -65,6 +66,7 @@ class SetupOutput(NamedTuple):
test_data_iterator: The data iterator for the testing dataset, if applicable.
checkpointing_context: A dictionary holding context for checkpointing operations,
especially for non-persistent local checkpointing.
pg_collection: The process group collection initialized for this run.
"""

state: GlobalState
Expand All @@ -75,6 +77,7 @@ class SetupOutput(NamedTuple):
valid_data_iterator: Optional[RerunDataIterator | list[RerunDataIterator]]
test_data_iterator: Optional[RerunDataIterator | list[RerunDataIterator]]
checkpointing_context: dict[str, Any]
pg_collection: ProcessGroupCollection

def setup(
state: GlobalState,
Expand Down Expand Up @@ -149,6 +152,9 @@ def setup(
print_rank_0("time to initialize megatron (seconds): {:.3f}".format(time.time() - state.start_time))
barrier_and_log("after megatron is initialized")

# Initialize process group collection once and pass through
pg_collection = ProcessGroupCollection.use_mpu_process_groups()

# Context used for persisting some state between checkpoint saves.
checkpointing_context = init_checkpointing_context(cfg.checkpoint)

Expand Down Expand Up @@ -233,6 +239,7 @@ def setup(
cfg.ddp,
optimizer,
align_grad_reduce=cfg.dist.align_grad_reduce,
pg_collection=pg_collection,
)

# Data stuff.
Expand Down Expand Up @@ -272,6 +279,7 @@ def setup(
valid_data_iterator,
test_data_iterator,
checkpointing_context,
pg_collection,
)


Expand All @@ -282,6 +290,7 @@ def _update_model_config_funcs(
optimizer: Optional[MegatronOptimizer],
*,
align_grad_reduce: bool = True,
pg_collection: Optional[ProcessGroupCollection] = None,
) -> None:
"""Update model config sync funcs based on initialized model."""
if isinstance(model[0], (DistributedDataParallel, megatron_FSDP)) and ddp_config.overlap_grad_reduce:
Expand All @@ -301,7 +310,9 @@ def _update_model_config_funcs(
if len(model) == 1:
model_config.param_sync_func = model_config.param_sync_func[0]
if optimizer is not None:
model_config.finalize_model_grads_func = finalize_model_grads
model_config.finalize_model_grads_func = partial(
finalize_model_grads, pg_collection=pg_collection
)
model_config.grad_scale_func = optimizer.scale_loss


Expand Down
24 changes: 14 additions & 10 deletions src/megatron/bridge/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from megatron.core.optimizer.distrib_optimizer import DistributedOptimizer
from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.rerun_state_machine import RerunDataIterator, get_rerun_state_machine
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.cuda_graphs import TECudaGraphHelper
Expand Down Expand Up @@ -82,6 +83,7 @@ def train(
valid_data_iterator: Optional[Union[RerunDataIterator, list[RerunDataIterator]]],
global_state: GlobalState,
checkpointing_context: dict[str, Any],
pg_collection: ProcessGroupCollection,
process_non_loss_data_func: Optional[Callable] = None,
non_loss_data_func: Optional[Callable] = None,
) -> None:
Expand Down Expand Up @@ -280,13 +282,13 @@ def train(
update_num_microbatches(global_state.train_state.consumed_train_samples, consistency_check=True, verbose=True)

# Completely skip iteration if needed.
if _should_skip_and_handle_iteration(global_state, train_data_iterator):
if _should_skip_and_handle_iteration(global_state, train_data_iterator, pg_collection):
continue

# Run training step.
fault_tolerance.on_training_step_start(global_state)
loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = train_step(
wrapped_forward_step_func, train_data_iterator, model, optimizer, scheduler, global_state
wrapped_forward_step_func, train_data_iterator, model, optimizer, scheduler, global_state, pg_collection
)
fault_tolerance.on_training_step_end(global_state)

Expand Down Expand Up @@ -332,9 +334,8 @@ def train(
cuda_graph_helper.cuda_graph_set_manual_hooks()

global_state.train_state.step += 1
batch_size = (
parallel_state.get_data_parallel_world_size() * train_config.micro_batch_size * get_num_microbatches()
)
dp_size = pg_collection.dp.size()
batch_size = dp_size * train_config.micro_batch_size * get_num_microbatches()
global_state.train_state.consumed_train_samples += batch_size
num_skipped_samples_in_batch = get_current_global_batch_size() - get_current_running_global_batch_size()
if train_config.decrease_batch_size_if_needed:
Expand Down Expand Up @@ -496,6 +497,7 @@ def train_step(
optimizer: MegatronOptimizer,
scheduler: OptimizerParamScheduler,
global_state: GlobalState,
pg_collection: ProcessGroupCollection,
) -> tuple[dict[str, torch.Tensor], int, bool, bool, int, Optional[float], Optional[int]]:
"""Single training step.

Expand Down Expand Up @@ -607,9 +609,8 @@ def train_step(
# there is one dict per microbatch. in new reporting, we average
# over the total number of tokens across the global batch.
val = torch.vstack(val).sum(dim=0)
torch.distributed.all_reduce(
val, group=parallel_state.get_data_parallel_group(with_context_parallel=True)
)
dp_cp_group = pg_collection.dp_cp
torch.distributed.all_reduce(val, group=dp_cp_group)
loss_reduced[key] = val[0] / val[1]
elif val[0].numel() == 1:
# legacy behavior, we average over the number of microbatches
Expand Down Expand Up @@ -1059,7 +1060,9 @@ def _finish_train(global_state: GlobalState):


def _should_skip_and_handle_iteration(
global_state: GlobalState, train_data_iterator: Optional[Union[RerunDataIterator, list[RerunDataIterator]]]
global_state: GlobalState,
train_data_iterator: Optional[Union[RerunDataIterator, list[RerunDataIterator]]],
pg_collection: ProcessGroupCollection,
) -> bool:
"""Check if the current iteration should be skipped and handle it if so.

Expand All @@ -1082,7 +1085,8 @@ def _should_skip_and_handle_iteration(

# Update step and sample counters
global_state.train_state.step += 1
batch_size = parallel_state.get_data_parallel_world_size() * cfg.train.micro_batch_size * get_num_microbatches()
dp_size = pg_collection.dp.size()
batch_size = dp_size * cfg.train.micro_batch_size * get_num_microbatches()
global_state.train_state.consumed_train_samples += batch_size
global_state.train_state.skipped_train_samples += batch_size

Expand Down
77 changes: 77 additions & 0 deletions tests/unit_tests/training/test_pg_collection_wiring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import signal
from types import SimpleNamespace


def test_train_step_accepts_pg_collection_argument():
    """train_step's public signature must expose a pg_collection parameter."""
    # Local import so unrelated test collection avoids import-time side effects.
    from megatron.bridge.training import train as train_module

    params = inspect.signature(train_module.train_step).parameters
    assert "pg_collection" in params, "train_step must accept pg_collection param"


def test_should_skip_iteration_uses_passed_pg_collection(monkeypatch):
    """_should_skip_and_handle_iteration must size the batch from the passed pg_collection."""
    # Imported here rather than at module scope to keep side effects local.
    from megatron.bridge.training import train as train_module
    from megatron.bridge.training.state import GlobalState

    state = GlobalState()

    # Minimal config: step 0 is listed in iterations_to_skip, so the
    # function takes the skip path and updates the sample counters.
    state.cfg = SimpleNamespace(
        train=SimpleNamespace(
            iterations_to_skip={0},
            micro_batch_size=4,
            exit_signal=signal.SIGTERM,
        )
    )

    # Stand-in process-group collection exposing only a DP group size.
    class _FakeDpGroup:
        def size(self):
            return 3

    class _FakeCollection:
        def __init__(self):
            self.dp = _FakeDpGroup()

    fake_pg = _FakeCollection()

    # Pin the microbatch count without touching the global calculators.
    monkeypatch.setattr(train_module, "get_num_microbatches", lambda: 2)

    # Neutralize the dummy step so no distributed/pipeline code runs.
    monkeypatch.setattr(train_module, "_dummy_train_step", lambda *args, **kwargs: None)

    # Counters start at zero before the skipped iteration.
    assert state.train_state.step == 0
    assert state.train_state.consumed_train_samples == 0
    assert state.train_state.skipped_train_samples == 0

    did_skip = train_module._should_skip_and_handle_iteration(state, None, fake_pg)

    assert did_skip is True
    # Exactly one iteration was skipped.
    assert state.train_state.step == 1
    # Batch size = dp.size * micro_batch_size * num_microbatches = 3 * 4 * 2 = 24
    expected_batch = 3 * 4 * 2
    assert state.train_state.consumed_train_samples == expected_batch
    assert state.train_state.skipped_train_samples == expected_batch
26 changes: 22 additions & 4 deletions tests/unit_tests/training/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,20 @@ def _create_mock_global_state(self, step=0, iterations_to_skip=None, micro_batch

return mock_state

def _make_fake_pg(self, dp_size: int):
class _DP:
def __init__(self, size: int) -> None:
self._size = size

def size(self) -> int:
return self._size

class _PG:
def __init__(self, size: int) -> None:
self.dp = _DP(size)

return _PG(dp_size)

@patch("megatron.bridge.training.train._dummy_train_step")
@patch("megatron.bridge.training.train.parallel_state.get_data_parallel_world_size", return_value=2)
@patch("megatron.bridge.training.train.get_num_microbatches", return_value=4)
Expand All @@ -817,7 +831,8 @@ def test_should_skip_iteration_when_step_in_skip_list(
train_data_iterator = Mock()

# Call function
result = _should_skip_and_handle_iteration(global_state, train_data_iterator)
fake_pg = self._make_fake_pg(2)
result = _should_skip_and_handle_iteration(global_state, train_data_iterator, fake_pg)

# Verify
assert result is True
Expand All @@ -837,7 +852,8 @@ def test_should_not_skip_iteration_when_step_not_in_skip_list(self, mock_dummy_s
train_data_iterator = Mock()

# Call function
result = _should_skip_and_handle_iteration(global_state, train_data_iterator)
fake_pg = self._make_fake_pg(1)
result = _should_skip_and_handle_iteration(global_state, train_data_iterator, fake_pg)

# Verify
assert result is False
Expand All @@ -856,7 +872,8 @@ def test_should_not_skip_when_skip_list_empty(self, mock_dummy_step):
train_data_iterator = Mock()

# Call function
result = _should_skip_and_handle_iteration(global_state, train_data_iterator)
fake_pg = self._make_fake_pg(1)
result = _should_skip_and_handle_iteration(global_state, train_data_iterator, fake_pg)

# Verify
assert result is False
Expand All @@ -874,7 +891,8 @@ def test_batch_size_calculation_with_different_parallelism(
train_data_iterator = Mock()

# Call function
result = _should_skip_and_handle_iteration(global_state, train_data_iterator)
fake_pg = self._make_fake_pg(8)
result = _should_skip_and_handle_iteration(global_state, train_data_iterator, fake_pg)

# Verify
assert result is True
Expand Down