diff --git a/examples/recipes/decentralized_pg/README.md b/examples/recipes/decentralized_pg/README.md new file mode 100755 index 0000000000..7849e1ee5f --- /dev/null +++ b/examples/recipes/decentralized_pg/README.md @@ -0,0 +1,168 @@ +# Decentralized Process Groups Examples + +This directory contains examples demonstrating how to use **decentralized process groups** (`use_decentralized_pg=True`) in Megatron-Bridge for distributed training. + +## Overview + +Instead of relying on Megatron-Core's global parallel state (mpu) module, you can use a `ProcessGroupCollection` that is explicitly passed to all components. This gives you full control over the parallelism topology and is useful for: + +1. **Reinforcement Learning**: Multiple model instances (policy, value, reference) with different parallelism +2. **Multi-Model Pipelines**: Complex workflows requiring explicit control over communication +3. **Testing/Debugging**: Isolated process groups without global state side effects + +## Files + +| File | Description | +|------|-------------| +| `pretrain_qwen3_simple.py` | **Simple**: Use a recipe and enable `use_decentralized_pg=True` | +| `pretrain_qwen3_with_decentralized_pg.py` | **Advanced**: Manually create process groups with `HyperCommGrid` | + +## Quick Start + +### Simple Approach (Recommended) + +Just use an existing recipe and enable decentralized process groups: + +```bash +# 8 GPUs: TP2 x PP2 x DP2 +uv run python -m torch.distributed.run --nproc_per_node=8 examples/recipes/decentralized_pg/pretrain_qwen3_simple.py + +# 4 GPUs: TP2 x PP2 x DP1 +uv run python -m torch.distributed.run --nproc_per_node=4 examples/recipes/decentralized_pg/pretrain_qwen3_simple.py +``` + +The key is just two lines: + +```python +from megatron.bridge.recipes.qwen.qwen3 import qwen3_4b_pretrain_config + +cfg = qwen3_4b_pretrain_config( + tensor_model_parallel_size=2, + pipeline_model_parallel_size=2, + # ... 
other settings +) + +# Enable decentralized process groups +cfg.dist.use_decentralized_pg = True +cfg.dist.use_gloo_process_groups = False # Gloo not supported +``` + +### Advanced Approach (Manual Process Group Creation) + +For full control over process groups: + +```bash +# 8 GPUs: TP2 x PP2 x DP2 +uv run python -m torch.distributed.run --nproc_per_node=8 examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py + +# 4 GPUs: TP2 x PP2 x DP1 +uv run python -m torch.distributed.run --nproc_per_node=4 examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py \ + --tp-size 2 --pp-size 2 + +# 2 GPUs: TP2 x PP1 x DP1 +uv run python -m torch.distributed.run --nproc_per_node=2 examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py \ + --tp-size 2 --pp-size 1 +``` + +## Manual Process Group Creation (Advanced) + +### Step 1: Initialize torch.distributed + +```python +torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank) +``` + +### Step 2: Create ProcessGroupCollection with HyperCommGrid + +```python +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.process_groups_config import ProcessGroupCollection + +# Create a grid with shape [TP, CP, DP, PP] +grid = HyperCommGrid( + shape=[tp_size, cp_size, dp_size, pp_size], + dim_names=["tp", "cp", "dp", "pp"], + rank_offset=0, + backend="nccl", +) + +# Create process groups by selecting dimensions +tp_pg = grid.create_pg(["tp"]) # Ranks differ only in TP dimension +pp_pg = grid.create_pg(["pp"]) # Ranks differ only in PP dimension +dp_pg = grid.create_pg(["dp"]) # Ranks differ only in DP dimension +mp_pg = grid.create_pg(["tp", "pp"]) # Model parallel = TP + PP + +# Bundle into ProcessGroupCollection +pg_collection = ProcessGroupCollection( + tp=tp_pg, + pp=pp_pg, + dp=dp_pg, + mp=mp_pg, + # ... 
more groups +) +``` + +### Step 3: Set Random Seeds (REQUIRED) + +```python +from megatron.core import tensor_parallel +from megatron.core.utils import get_pg_rank + +# Get TP rank from our process group +tp_rank = get_pg_rank(pg_collection.tp) + +# Initialize CUDA RNG tracker - REQUIRED before model creation! +tensor_parallel.model_parallel_cuda_manual_seed( + seed=1234, + te_rng_tracker=False, + inference_rng_tracker=False, + use_cudagraphable_rng=False, + tp_rank=tp_rank, + ep_rank=0, + etp_rank=tp_rank, +) +``` + +### Step 4: Pass pg_collection Explicitly to Components + +```python +# Model creation +model = cfg.model.provide_distributed_model( + pg_collection=pg_collection, # <-- Pass here! + ... +) + +# Optimizer setup +optimizer, scheduler = setup_optimizer( + pg_collection=pg_collection, # <-- Pass here! + ... +) + +# Data loaders use the DP group +train_data_iterator = setup_data_iterators( + dp_group=pg_collection.dp, # <-- Use DP group for data sharding! + ... +) + +# Training loop +train( + pg_collection=pg_collection, # <-- Pass here! + ... +) +``` + +## HyperCommGrid Explained + +`HyperCommGrid` creates a multi-dimensional grid of ranks. 
The grid shape `[TP, CP, DP, PP]` defines how ranks are organized: + +When you call `grid.create_pg(["tp"])`, it creates groups of ranks that share the same DP and PP coordinates but differ in TP: +- Group 1: [rank 0, rank 1] (DP=0, PP=0) +- Group 2: [rank 2, rank 3] (DP=1, PP=0) +- Group 3: [rank 4, rank 5] (DP=0, PP=1) +- Group 4: [rank 6, rank 7] (DP=1, PP=1) + +## Limitations + +- Gloo process groups are not supported (NCCL only) +- ModelOpt sharded checkpointing is disabled +- Distillation tensor shape adjustment is disabled diff --git a/examples/recipes/decentralized_pg/pretrain_qwen3_simple.py b/examples/recipes/decentralized_pg/pretrain_qwen3_simple.py new file mode 100644 index 0000000000..5a62680e91 --- /dev/null +++ b/examples/recipes/decentralized_pg/pretrain_qwen3_simple.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +============================================================================== +Example: Qwen3 Pretraining with Decentralized Process Groups (Simple) +============================================================================== + +This example demonstrates the simplest way to enable decentralized process groups: +just use an existing recipe and set `cfg.dist.use_decentralized_pg = True`.
+ +The setup() function inside pretrain() will automatically create the +ProcessGroupCollection using HyperCommGrid based on the parallelism settings. + +How to Run +---------- +# 8 GPUs: TP2 x PP2 x DP2 +uv run python -m torch.distributed.run --nproc_per_node=8 examples/recipes/decentralized_pg/pretrain_qwen3_simple.py + +# 4 GPUs: TP2 x PP2 x DP1 +uv run python -m torch.distributed.run --nproc_per_node=4 examples/recipes/decentralized_pg/pretrain_qwen3_simple.py +""" + +import torch + +from megatron.bridge.recipes.qwen.qwen3 import qwen3_4b_pretrain_config +from megatron.bridge.training.gpt_step import forward_step +from megatron.bridge.training.pretrain import pretrain + + +def main() -> None: + """Run Qwen3 pretraining with decentralized process groups enabled.""" + # Get the standard Qwen3 4B pretrain config with overrides + cfg = qwen3_4b_pretrain_config( + # Use mock data for demo + mock=True, + # Parallelism + tensor_model_parallel_size=2, + pipeline_model_parallel_size=2, + # Training settings (small for demo) + train_iters=100, + seq_length=1024, + global_batch_size=32, + micro_batch_size=1, + # LR schedule (must fit within train_iters) + lr_warmup_iters=10, + lr_decay_iters=100, + ) + # known issue with share_embeddings_and_output_weights + cfg.model.share_embeddings_and_output_weights = False + + # ========================================================================= + # KEY: Enable decentralized process groups + # ========================================================================= + cfg.dist.use_decentralized_pg = True + cfg.dist.use_gloo_process_groups = False # Gloo not supported with decentralized PG + + pretrain(config=cfg, forward_step_func=forward_step) + + # Cleanup + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py 
b/examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py new file mode 100644 index 0000000000..1f63955d0a --- /dev/null +++ b/examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py @@ -0,0 +1,694 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +============================================================================== +Example: Qwen3 Pretraining with Decentralized Process Groups (Advanced/Manual) +============================================================================== + +This example demonstrates how to MANUALLY create process groups using +HyperCommGrid and ProcessGroupCollection for distributed training. + +Instead of relying on the automatic setup in `pretrain()`, this example shows +the explicit steps to: + 1. Initialize torch.distributed + 2. Create HyperCommGrid with desired topology + 3. Create all required process groups from the grid + 4. Build ProcessGroupCollection + 5. Pass pg_collection explicitly to model, optimizer, and training loop + +This gives you full control over the parallelism topology. + +For a simpler approach that uses a recipe with automatic pg_collection creation, +see `pretrain_qwen3_simple.py`. 
+ +How to Run +---------- +# 8 GPUs: TP2 x PP2 x DP2 +uv run python -m torch.distributed.run --nproc_per_node=8 examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py + +# 4 GPUs: TP2 x PP2 x DP1 +uv run python -m torch.distributed.run --nproc_per_node=4 examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py \ + --tp-size 2 --pp-size 2 + +# 2 GPUs: TP2 x PP1 x DP1 +uv run python -m torch.distributed.run --nproc_per_node=2 examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py \ + --tp-size 2 --pp-size 1 +""" + +import argparse +import os +import tempfile + +import torch +import torch.distributed + +# ============================================================================== +# Core Megatron imports for manual process group creation +# ============================================================================== +from megatron.core import parallel_state, tensor_parallel +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.num_microbatches_calculator import init_num_microbatches_calculator +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.utils import get_pg_rank + +# ============================================================================== +# Megatron-Bridge imports +# ============================================================================== +from megatron.bridge.data.loaders import setup_data_iterators +from megatron.bridge.data.utils import get_dataset_provider +from megatron.bridge.models.qwen import Qwen3ModelProvider4B +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + DistributedDataParallelConfig, + DistributedInitConfig, + LoggerConfig, + MockGPTDatasetConfig, + OptimizerConfig, + RNGConfig, + SchedulerConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.gpt_step import forward_step +from megatron.bridge.training.optim import setup_optimizer +from 
megatron.bridge.training.state import GlobalState +from megatron.bridge.training.tokenizers.tokenizer import build_tokenizer +from megatron.bridge.training.train import train +from megatron.bridge.utils.common_utils import get_rank_safe, print_rank_0 + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Qwen3 Pretraining with Manual Decentralized Process Groups") + + # Parallelism settings + parser.add_argument("--tp-size", type=int, default=2, help="Tensor parallel size (default: 2)") + parser.add_argument("--pp-size", type=int, default=2, help="Pipeline parallel size (default: 2)") + parser.add_argument("--cp-size", type=int, default=1, help="Context parallel size (default: 1)") + + # Training settings + parser.add_argument("--num-layers", type=int, default=4, help="Number of layers (default: 4)") + parser.add_argument("--seq-length", type=int, default=1024, help="Sequence length (default: 1024)") + parser.add_argument("--train-iters", type=int, default=100, help="Training iterations (default: 100)") + parser.add_argument("--global-batch-size", type=int, default=32, help="Global batch size (default: 32)") + parser.add_argument("--micro-batch-size", type=int, default=1, help="Micro batch size (default: 1)") + parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate (default: 1e-4)") + + return parser.parse_args() + + +# ============================================================================== +# STEP 1: Initialize torch.distributed +# ============================================================================== +def initialize_torch_distributed() -> None: + """ + Initialize torch.distributed process group. + + This must be called before creating any process groups. + In production, this is typically handled by torchrun. 
+ """ + if torch.distributed.is_initialized(): + print_rank_0("torch.distributed already initialized, skipping...") + return + + # Get rank/world_size from environment (set by torchrun) + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + + # Set CUDA device before distributed init + torch.cuda.set_device(local_rank) + + print_rank_0(f"> Initializing torch.distributed with world_size={world_size}") + + torch.distributed.init_process_group( + backend="nccl", + world_size=world_size, + rank=rank, + ) + + # Barrier to ensure all ranks are ready + torch.distributed.barrier() + + +# ============================================================================== +# STEP 2: Create ProcessGroupCollection using HyperCommGrid +# ============================================================================== +def create_process_group_collection( + tp_size: int, + pp_size: int, + cp_size: int = 1, +) -> ProcessGroupCollection: + """ + Manually create all process groups using HyperCommGrid. + + This is the CORE of this example - showing explicit process group creation + instead of relying on mpu's global state. + + Args: + tp_size: Tensor parallel size + pp_size: Pipeline parallel size + cp_size: Context parallel size (default: 1) + + Returns: + ProcessGroupCollection containing all required process groups + + The HyperCommGrid creates a multi-dimensional grid of ranks: + shape = [TP, CP, DP, PP] + + From this grid, we create various process groups by selecting dimensions: + - tp_pg: select ["tp"] -> ranks within same TP group + - pp_pg: select ["pp"] -> ranks within same PP group + - dp_pg: select ["dp"] -> ranks within same DP group + - mp_pg: select ["tp", "pp"] -> model parallel (TP + PP) + - tp_cp_pg: select ["tp", "cp"] -> tensor + context parallel + - etc. 
+ """ + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # =========================================================================== + # Calculate data parallel size from available world size + # =========================================================================== + model_parallel_size = tp_size * pp_size * cp_size + if world_size % model_parallel_size != 0: + raise RuntimeError(f"world_size ({world_size}) must be divisible by TP*PP*CP ({model_parallel_size})") + dp_size = world_size // model_parallel_size + + if rank == 0: + print(f"\n{'=' * 60}") + print("Creating ProcessGroupCollection with HyperCommGrid") + print(f"{'=' * 60}") + print(f" World Size: {world_size}") + print(f" Tensor Parallel (TP): {tp_size}") + print(f" Pipeline Parallel (PP): {pp_size}") + print(f" Context Parallel (CP): {cp_size}") + print(f" Data Parallel (DP): {dp_size}") + print(f" Grid Shape [TP, CP, DP, PP]: [{tp_size}, {cp_size}, {dp_size}, {pp_size}]") + print(f"{'=' * 60}\n") + + # =========================================================================== + # Create HyperCommGrid with the parallelism topology + # =========================================================================== + # The grid arranges all ranks in a multi-dimensional structure. + # Dimension order: [TP, CP, DP, PP] + grid = HyperCommGrid( + shape=[tp_size, cp_size, dp_size, pp_size], + dim_names=["tp", "cp", "dp", "pp"], + rank_offset=0, # Start from global rank 0 + backend="nccl", # Use NCCL for GPU communication + ) + + # =========================================================================== + # Create CORE process groups from the grid + # =========================================================================== + # Each create_pg() call creates a process group containing ranks that share + # the SAME coordinates on all dimensions NOT listed in the argument. 
+ + # Tensor Parallel group: ranks that differ only in TP dimension + # Used for: column/row parallel linear layers, all-reduce in attention + tp_pg = grid.create_pg(["tp"]) + + # Context Parallel group: ranks that differ only in CP dimension + # Used for: ring attention, sequence splitting + cp_pg = grid.create_pg(["cp"]) + + # Pipeline Parallel group: ranks that differ only in PP dimension + # Used for: send/recv between pipeline stages + pp_pg = grid.create_pg(["pp"]) + + # Data Parallel group: ranks that differ only in DP dimension + # Used for: gradient all-reduce, optimizer state sharding + dp_pg = grid.create_pg(["dp"]) + + # =========================================================================== + # Create COMPOUND process groups + # =========================================================================== + # Model Parallel: combines TP and PP (all ranks in same model replica) + mp_pg = grid.create_pg(["tp", "pp"]) + + # Tensor + Context Parallel: used for some attention computations + tp_cp_pg = grid.create_pg(["tp", "cp"]) + + # TP + DP + CP: used for distributed optimizer across non-PP dimensions + tp_dp_cp_pg = grid.create_pg(["tp", "dp", "cp"]) + + # DP + CP: data parallel including context parallel + dp_cp_pg = grid.create_pg(["dp", "cp"]) + + # =========================================================================== + # Create embedding/position embedding groups + # =========================================================================== + # Embedding group connects first and last PP stages (for tied embeddings) + # Position embedding group is just the first PP stage + pp_rank_lists = grid._gen_rank_enum(["pp"]) + + embedding_rank_lists = [] + pos_embedding_rank_lists = [] + for ranks in pp_rank_lists: + if not ranks: + continue + # Embedding: first and last stage (or just first if pp_size==1) + embedding_rank_lists.append([ranks[0]] if len(ranks) == 1 else [ranks[0], ranks[-1]]) + # Position embedding: only first stage + 
pos_embedding_rank_lists.append([ranks[0]]) + + embd_pg, _ = torch.distributed.new_subgroups_by_enumeration(embedding_rank_lists, backend="nccl") + pos_embd_pg, _ = torch.distributed.new_subgroups_by_enumeration(pos_embedding_rank_lists, backend="nccl") + + # =========================================================================== + # Create Expert/MoE groups (simplified - no expert parallelism here) + # =========================================================================== + # For MoE models, you would create additional expert-specific groups. + # Here we reuse TP groups since we're not using expert parallelism. + ep_pg = None # No expert parallelism in this example + expt_tp_pg = tp_pg # Expert TP same as regular TP + tp_ep_pg = tp_pg # TP + EP = just TP when EP=1 + tp_ep_pp_pg = mp_pg # TP + EP + PP = MP when EP=1 + expt_dp_pg = dp_pg # Expert DP same as regular DP + + # =========================================================================== + # Initialize global memory buffer (required by Megatron-Core) + # =========================================================================== + parallel_state._set_global_memory_buffer() + + # =========================================================================== + # Build the ProcessGroupCollection + # =========================================================================== + # This is the single object that contains ALL process groups and gets + # passed through function calls in decentralized process groups mode. 
+ pg_collection = ProcessGroupCollection( + # Core parallelism groups + tp=tp_pg, + pp=pp_pg, + mp=mp_pg, + cp=cp_pg, + dp=dp_pg, + dp_cp=dp_cp_pg, + tp_cp=tp_cp_pg, + tp_dp_cp=tp_dp_cp_pg, + # Embedding groups + embd=embd_pg, + pos_embd=pos_embd_pg, + # Expert/MoE groups (simplified) + ep=ep_pg, + expt_tp=expt_tp_pg, + tp_ep=tp_ep_pg, + tp_ep_pp=tp_ep_pp_pg, + expt_dp=expt_dp_pg, + intra_dp_cp=dp_cp_pg, + intra_expt_dp=expt_dp_pg, + # Hierarchical context parallel (not used) + hcp=None, + # Distributed optimizer groups (not using partial optimizer here) + inter_dist_opt=None, + intra_dist_opt=None, + ) + + if rank == 0: + print("ProcessGroupCollection created successfully!") + print(f" tp_pg world size: {torch.distributed.get_world_size(tp_pg)}") + print(f" pp_pg world size: {torch.distributed.get_world_size(pp_pg)}") + print(f" dp_pg world size: {torch.distributed.get_world_size(dp_pg)}") + print() + + return pg_collection + + +# ============================================================================== +# STEP 3: Set random seeds (required for model initialization) +# ============================================================================== +def set_random_seeds( + seed: int, + pg_collection: ProcessGroupCollection, + data_parallel_random_init: bool = False, +) -> None: + """ + Set random seeds for reproducibility. + + This is REQUIRED before creating the model because Megatron-Core's + tensor parallel layers use a CUDA RNG tracker for weight initialization. 
+ + The RNG tracker ensures that: + - Different TP ranks initialize different weight partitions correctly + - Different PP stages get different seeds for reproducibility + - (Optionally) different DP ranks can have different initialization + + Args: + seed: Base random seed + pg_collection: ProcessGroupCollection containing all process groups + data_parallel_random_init: If True, vary seed by DP rank + """ + import random + + import numpy as np + + current_rank = torch.distributed.get_rank() + + # Different PP stages get different seeds (for reproducibility across stages) + pp_rank = torch.distributed.get_group_rank(pg_collection.pp, current_rank) + adjusted_seed = seed + (100 * pp_rank) + + # Optionally vary by DP rank (for different random init per replica) + if data_parallel_random_init: + dp_rank = torch.distributed.get_group_rank(pg_collection.dp, current_rank) + adjusted_seed = adjusted_seed + (10 * dp_rank) + + # Set seeds for Python, NumPy, PyTorch + random.seed(adjusted_seed) + np.random.seed(adjusted_seed) + torch.manual_seed(adjusted_seed) + + # =========================================================================== + # CRITICAL: Initialize CUDA RNG tracker for tensor parallelism + # =========================================================================== + # This sets up the "model-parallel-rng" state used by ColumnParallelLinear, + # RowParallelLinear, and other TP layers during weight initialization. 
+ if torch.cuda.device_count() > 0: + # Get TP rank from our process group + tp_rank = get_pg_rank(pg_collection.tp) + # EP rank (no expert parallelism in this example) + ep_rank = get_pg_rank(pg_collection.ep) if pg_collection.ep is not None else 0 + # Expert TP rank + etp_rank = get_pg_rank(pg_collection.expt_tp) + + # This function creates the CUDA RNG tracker with "model-parallel-rng" state + tensor_parallel.model_parallel_cuda_manual_seed( + adjusted_seed, + te_rng_tracker=False, # Transformer Engine RNG tracker + inference_rng_tracker=False, # Inference-specific RNG + use_cudagraphable_rng=False, # CUDA graph compatible RNG + tp_rank=tp_rank, + ep_rank=ep_rank, + etp_rank=etp_rank, + ) + + print_rank_0(f"Random seeds set (base={seed}, adjusted={adjusted_seed})") + + +# ============================================================================== +# STEP 4: Create model, optimizer, and run training +# ============================================================================== +def run_training(args: argparse.Namespace, pg_collection: ProcessGroupCollection) -> None: + """ + Create model, optimizer, dataloaders, and run the training loop. + + This shows how to pass pg_collection explicitly to all components. 
+ """ + rank = get_rank_safe() + world_size = torch.distributed.get_world_size() + + # Calculate DP size + dp_size = world_size // (args.tp_size * args.pp_size * args.cp_size) + + # =========================================================================== + # Create output directories + # =========================================================================== + base_dir = tempfile.mkdtemp(prefix="mbridge_decentralized_pg_") + checkpoint_dir = os.path.join(base_dir, "checkpoints") + tensorboard_dir = os.path.join(base_dir, "tensorboard") + + if rank == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(tensorboard_dir, exist_ok=True) + print(f"Output directory: {base_dir}\n") + + torch.distributed.barrier() + + # =========================================================================== + # Create ConfigContainer with use_decentralized_pg=True + # =========================================================================== + # IMPORTANT: When use_decentralized_pg=True, the setup functions + # expect pg_collection to be passed explicitly rather than reading from mpu. 
+ + model_cfg = Qwen3ModelProvider4B( + # Parallelism - must match what we used to create pg_collection + tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + context_parallel_size=args.cp_size, + sequence_parallel=(args.tp_size > 1), + # Model architecture (scaled down for demo) + num_layers=args.num_layers, + seq_length=args.seq_length, + share_embeddings_and_output_weights=False, + # Precision + pipeline_dtype=torch.bfloat16, + bf16=True, + attention_softmax_in_fp32=True, + make_vocab_size_divisible_by=128, + vocab_size=None, + ) + + train_cfg = TrainingConfig( + train_iters=args.train_iters, + eval_interval=args.train_iters, + eval_iters=0, + global_batch_size=args.global_batch_size, + micro_batch_size=args.micro_batch_size, + exit_signal_handler=True, + ) + + optimizer_cfg = OptimizerConfig( + optimizer="adam", + bf16=True, + use_distributed_optimizer=True, + clip_grad=1.0, + lr=args.lr, + weight_decay=0.01, + min_lr=args.lr / 10, + ) + + scheduler_cfg = SchedulerConfig( + lr_decay_style="cosine", + lr_warmup_iters=10, + lr_warmup_init=0.0, + lr_decay_iters=args.train_iters, + override_opt_param_scheduler=True, + start_weight_decay=0.01, + end_weight_decay=0.01, + weight_decay_incr_style="constant", + ) + + ddp_cfg = DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + # Disable overlap features for simplicity in this manual setup example + overlap_grad_reduce=False, + overlap_param_gather=False, + use_distributed_optimizer=True, + ) + + # KEY: use_decentralized_pg=True tells Megatron-Bridge that we're + # managing process groups ourselves via pg_collection + dist_cfg = DistributedInitConfig( + use_decentralized_pg=True, + use_gloo_process_groups=False, # Gloo not supported with decentralized PG + ) + + dataset_cfg = MockGPTDatasetConfig( + random_seed=1234, + seq_length=args.seq_length, + dataloader_type="single", + num_workers=1, + reset_position_ids=False, + reset_attention_mask=False, 
+ eod_mask_loss=False, + ) + + tokenizer_cfg = TokenizerConfig(tokenizer_type="NullTokenizer", vocab_size=10000) + logger_cfg = LoggerConfig(log_interval=10, tensorboard_dir=tensorboard_dir) + checkpoint_cfg = CheckpointConfig(save_interval=args.train_iters, save=checkpoint_dir) + rng_cfg = RNGConfig(seed=1234) + + cfg = ConfigContainer( + model=model_cfg, + train=train_cfg, + optimizer=optimizer_cfg, + scheduler=scheduler_cfg, + ddp=ddp_cfg, + dist=dist_cfg, + dataset=dataset_cfg, + logger=logger_cfg, + tokenizer=tokenizer_cfg, + checkpoint=checkpoint_cfg, + rng=rng_cfg, + ) + + # =========================================================================== + # Initialize microbatch calculator + # =========================================================================== + init_num_microbatches_calculator( + rank=rank, + rampup_batch_size=None, + global_batch_size=args.global_batch_size, + micro_batch_size=args.micro_batch_size, + data_parallel_size=dp_size, + ) + + # =========================================================================== + # Build tokenizer and set vocab size + # =========================================================================== + tokenizer = build_tokenizer(tokenizer_cfg) + cfg.model.vocab_size = tokenizer.vocab_size + cfg.dataset.tokenizer = tokenizer + cfg.validate() + + # =========================================================================== + # Create model - PASS pg_collection explicitly + # =========================================================================== + print_rank_0("Creating model with pg_collection...") + + model = cfg.model.provide_distributed_model( + ddp_config=ddp_cfg, + use_megatron_fsdp=False, + use_torch_fsdp2=False, + overlap_param_gather_with_optimizer_step=False, + data_parallel_random_init=False, + pg_collection=pg_collection, # <-- Explicitly pass our pg_collection! 
+ ) + + print_rank_0(f"Model created: {len(model)} chunks") + + # =========================================================================== + # Create optimizer - PASS pg_collection explicitly + # =========================================================================== + print_rank_0("Creating optimizer with pg_collection...") + + optimizer, scheduler = setup_optimizer( + optimizer_config=optimizer_cfg, + scheduler_config=scheduler_cfg, + model=model, + use_gloo_process_groups=False, + pg_collection=pg_collection, # <-- Explicitly pass our pg_collection! + ) + + print_rank_0("Optimizer created") + + # =========================================================================== + # Create GlobalState (singleton pattern - no args, then set cfg) + # =========================================================================== + state = GlobalState() + state.cfg = cfg + + # =========================================================================== + # Create data iterators - use dp_group from pg_collection + # =========================================================================== + print_rank_0("Creating data iterators...") + + # Get the dataset provider based on the dataset config type + # MockGPTDatasetConfig will create mock datasets for testing/demo + dataset_provider = get_dataset_provider(cfg.dataset) + + # The data iterators need the DP group for sharding data across DP ranks + train_data_iterator, valid_data_iterator, test_data_iterator = setup_data_iterators( + cfg=cfg, + train_state=state.train_state, + model_length=len(model), + train_valid_test_datasets_provider=dataset_provider, + dp_group=pg_collection.dp, # <-- Use DP group from our pg_collection! 
+ ) + + print_rank_0("Data iterators created\n") + + print_rank_0("=" * 60) + print_rank_0("Starting training with manually created process groups") + print_rank_0("=" * 60) + print_rank_0(f" pg_collection.tp world size: {torch.distributed.get_world_size(pg_collection.tp)}") + print_rank_0(f" pg_collection.pp world size: {torch.distributed.get_world_size(pg_collection.pp)}") + print_rank_0(f" pg_collection.dp world size: {torch.distributed.get_world_size(pg_collection.dp)}") + print_rank_0("") + + # Run the training loop + train( + forward_step_func=forward_step, + model=model, + optimizer=optimizer, + scheduler=scheduler, + train_data_iterator=train_data_iterator, + valid_data_iterator=valid_data_iterator, + global_state=state, + checkpointing_context={}, + pg_collection=pg_collection, # <-- Pass to training loop! + ) + + print_rank_0("\nTraining complete!") + + +def main() -> None: + """Main entry point demonstrating manual process group creation.""" + args = parse_args() + + print_rank_0("=" * 70) + print_rank_0("Qwen3 Pretraining with MANUALLY Created Decentralized Process Groups") + print_rank_0("=" * 70) + print_rank_0("") + print_rank_0("This example shows how to:") + print_rank_0(" 1. Initialize torch.distributed") + print_rank_0(" 2. Create HyperCommGrid with your parallelism topology") + print_rank_0(" 3. Create ProcessGroupCollection from the grid") + print_rank_0(" 4. Set random seeds (required for model weight initialization)") + print_rank_0(" 5. 
Pass pg_collection explicitly to model, optimizer, training") + print_rank_0("") + + # ========================================================================= + # STEP 1: Initialize torch.distributed + # ========================================================================= + print_rank_0("STEP 1: Initializing torch.distributed...") + initialize_torch_distributed() + + # Validate parallelism settings + world_size = torch.distributed.get_world_size() + required = args.tp_size * args.pp_size * args.cp_size + if world_size < required: + raise RuntimeError( + f"Need at least {required} GPUs for TP={args.tp_size}, PP={args.pp_size}, CP={args.cp_size}" + ) + if args.num_layers % args.pp_size != 0: + raise RuntimeError(f"num_layers ({args.num_layers}) must be divisible by PP ({args.pp_size})") + + # ========================================================================= + # STEP 2: Create ProcessGroupCollection manually + # ========================================================================= + print_rank_0("\nSTEP 2: Creating ProcessGroupCollection with HyperCommGrid...") + pg_collection = create_process_group_collection( + tp_size=args.tp_size, + pp_size=args.pp_size, + cp_size=args.cp_size, + ) + + # ========================================================================= + # STEP 3: Set random seeds (REQUIRED before model creation) + # ========================================================================= + print_rank_0("STEP 3: Setting random seeds for CUDA RNG tracker...") + set_random_seeds(seed=1234, pg_collection=pg_collection) + + # ========================================================================= + # STEP 4: Run training with our pg_collection + # ========================================================================= + print_rank_0("\nSTEP 4: Creating model/optimizer and running training...") + run_training(args, pg_collection) + + # ========================================================================= + # Cleanup + # 
========================================================================= + torch.distributed.barrier() + torch.distributed.destroy_process_group() + print_rank_0("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/src/megatron/bridge/models/model_provider.py b/src/megatron/bridge/models/model_provider.py index 97b92705fb..c4c024f0a2 100644 --- a/src/megatron/bridge/models/model_provider.py +++ b/src/megatron/bridge/models/model_provider.py @@ -124,6 +124,7 @@ def provide_distributed_model( | None = None, post_wrap_hook: Callable[[list[MegatronModule]], list[MegatronModule]] | None = None, mixed_precision_wrapper: Callable[[Any, MegatronModule], MegatronModule] | None = Float16Module, + pg_collection: ProcessGroupCollection | None = None, ) -> list[ModelT]: """Instantiate and wrap the model for distributed training. @@ -151,6 +152,9 @@ def provide_distributed_model( this will override all hooks registered via `register_post_wrap_hook`. mixed_precision_wrapper: A module wrapper (e.g., `Float16Module`) applied when fp16/bf16 is enabled. If None, no mixed precision wrapper is applied. + pg_collection: Optional pre-initialized ProcessGroupCollection. If provided, skips + model parallel initialization and uses the provided collection directly. + This is used when `use_decentralized_pg=True` in the distributed config. Returns: A list containing the wrapped model instance. @@ -166,10 +170,13 @@ def provide_distributed_model( torch.cuda.set_device(get_local_rank_preinit()) torch.distributed.init_process_group("nccl") - if not parallel_state.is_initialized(): - print("Model parallel not initialized, initializing...") - self.initialize_model_parallel(seed=0) - pg_collection = ProcessGroupCollection.use_mpu_process_groups() + # If pg_collection is provided (e.g., from use_decentralized_pg=True), + # use it directly. Otherwise, initialize model parallel state and get pg_collection from MPU. 
+ if pg_collection is None: + if not parallel_state.is_initialized(): + print("Model parallel not initialized, initializing...") + self.initialize_model_parallel(seed=0) + pg_collection = ProcessGroupCollection.use_mpu_process_groups() # Providers (GPT, Mamba, Gemma, etc.) expect pg_collection on self for PP/TP role checks. setattr(self, "_pg_collection", pg_collection) diff --git a/src/megatron/bridge/training/checkpointing.py b/src/megatron/bridge/training/checkpointing.py index 48163c4c1e..1d821b851e 100644 --- a/src/megatron/bridge/training/checkpointing.py +++ b/src/megatron/bridge/training/checkpointing.py @@ -29,7 +29,7 @@ import numpy as np import torch import torch.nn.functional as F -from megatron.core import dist_checkpointing, mpu, tensor_parallel +from megatron.core import dist_checkpointing, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedObject, ShardedStateDict from megatron.core.dist_checkpointing.serialization import ( StateDict, @@ -46,6 +46,7 @@ from megatron.core.num_microbatches_calculator import update_num_microbatches from megatron.core.optimizer import DistributedOptimizer, MegatronOptimizer from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.transformer import MegatronModule from megatron.core.utils import unwrap_model @@ -74,6 +75,7 @@ read_train_state, ) from megatron.bridge.training.utils.log_utils import append_to_progress_log +from megatron.bridge.training.utils.pg_utils import get_pg_collection from megatron.bridge.utils.common_utils import ( get_rank_safe, is_last_rank, @@ -372,7 +374,12 @@ def is_empty_async_queue(global_state: GlobalState) -> bool: return async_queue.get_num_unfinalized_calls() == 0 -def get_rng_state(data_parallel_random_init: bool, ckpt_format: str = "torch_dist") -> Union[ShardedObject, 
dict]: +def get_rng_state( + data_parallel_random_init: bool, + ckpt_format: str = "torch_dist", + *, + pg_collection: ProcessGroupCollection, +) -> ShardedObject | dict: """Get the random number generator states for all necessary libraries. Collects states from random, numpy, torch, cuda, and the Megatron RNG tracker. @@ -396,27 +403,27 @@ def get_rng_state(data_parallel_random_init: bool, ckpt_format: str = "torch_dis } rng_state_list = None - if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init: - rng_state_list = [None for i in range(mpu.get_data_parallel_world_size())] - torch.distributed.all_gather_object(rng_state_list, rng_state, group=mpu.get_data_parallel_group()) + if torch.distributed.is_initialized() and pg_collection.dp_cp.size() > 1 and data_parallel_random_init: + rng_state_list = [None for i in range(pg_collection.dp_cp.size())] + torch.distributed.all_gather_object(rng_state_list, rng_state, group=pg_collection.dp_cp) else: rng_state_list = [rng_state] if ckpt_format == "torch_dist": - pp_rank = mpu.get_pipeline_model_parallel_rank() - pp_size = mpu.get_pipeline_model_parallel_world_size() - tp_rank = mpu.get_tensor_model_parallel_rank() - tp_size = mpu.get_tensor_model_parallel_world_size() + pp_rank = pg_collection.pp.rank() + pp_size = pg_collection.pp.size() + tp_rank = pg_collection.tp.rank() + tp_size = pg_collection.tp.size() rng_state_list = ShardedObject( "rng_state", rng_state_list, (pp_size, tp_size), (pp_rank, tp_rank), - replica_id=mpu.get_data_parallel_rank(with_context_parallel=True), + replica_id=pg_collection.dp_cp.rank(), ) elif ckpt_format == "fsdp_dtensor": - pp_rank = mpu.get_pipeline_model_parallel_rank() - tp_rank = mpu.get_tensor_model_parallel_rank() + pp_rank = pg_collection.pp.rank() + tp_rank = pg_collection.tp.rank() rng_state_list = {f"({pp_rank}, {tp_rank})": rng_state_list} return rng_state_list @@ -443,6 +450,7 @@ def save_checkpoint( train_data_iterator: 
Optional[Any] = None, preprocess_common_state_dict_fn: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None, prebuilt_state_dict: Optional[dict[str, Any]] = None, + pg_collection: Optional[ProcessGroupCollection] = None, ) -> None: """Save a model checkpoint. @@ -466,6 +474,8 @@ def save_checkpoint( prebuilt_state_dict: Optional pre-built state dict. When provided, skips state dict generation and uses this directly. Used for low-memory save mode where factories are expanded and model deleted before save. + pg_collection: Optional ProcessGroupCollection. When provided, uses this instead of + extracting from model. Required when model is empty (e.g., low-memory save). """ train_state = state.train_state @@ -512,8 +522,12 @@ def save_checkpoint( print_rank_0(f"saving checkpoint at iteration {train_state.step:7d} to {save_dir} in {ckpt_format} format") # Collect rng state across data parallel ranks. + if pg_collection is None: + pg_collection = get_pg_collection(model) rng_state = get_rng_state( - data_parallel_random_init=cfg.rng.data_parallel_random_init, ckpt_format=ckpt_cfg.ckpt_format + data_parallel_random_init=cfg.rng.data_parallel_random_init, + ckpt_format=ckpt_cfg.ckpt_format, + pg_collection=pg_collection, ) # Collect rerun state across all ranks @@ -527,13 +541,17 @@ def save_checkpoint( checkpoint_name = get_checkpoint_name(save_dir, train_state.step, release=False) # Save dataloader state if the dataloader supports it (currently only Megatron Energon). 
- maybe_save_dataloader_state(train_data_iterator, train_state.step, getattr(cfg.dataset, "dataloader_save", None)) + maybe_save_dataloader_state( + model, + train_data_iterator, + train_state.step, + getattr(cfg.dataset, "dataloader_save", None), + pg_collection=pg_collection, + ) # Save LayerWiseDistributedOptimizer - if isinstance( - optimizer, LayerWiseDistributedOptimizer - ): # replacement of getattr(args, "optimizer", "adam").startswith("dist_") - dp_rank = mpu.get_data_parallel_rank() + if isinstance(optimizer, LayerWiseDistributedOptimizer): + dp_rank = pg_collection.dp.rank() optim_checkpoint_name = os.path.join(os.path.dirname(checkpoint_name), f"layer_wise_optimizer_{dp_rank}.pt") ensure_directory_exists(optim_checkpoint_name) if not optimizer.is_stub_optimizer: @@ -550,6 +568,7 @@ def save_checkpoint( # Collect cfg, model, RNG. sharded_sd_metadata = _build_sharded_state_dict_metadata(cfg.optimizer.use_distributed_optimizer, ckpt_cfg) + sharded_sd_metadata["dp_cp_group"] = pg_collection.dp_cp if cfg.optimizer.use_distributed_optimizer: print_rank_0( f"Storing distributed optimizer sharded state of type {sharded_sd_metadata['distrib_optim_sharding_type']}" @@ -571,6 +590,7 @@ def save_checkpoint( optim_sd_kwargs=dict(metadata=sharded_sd_metadata), model_sd_kwargs=dict(metadata=sharded_sd_metadata), rerun_state=rerun_state, + pg_collection=pg_collection, ) # Apply PEFT filtering to save adapter-only checkpoints @@ -620,7 +640,7 @@ def save_checkpoint( if ckpt_cfg.fully_parallel_save: save_strategy = FullyParallelSaveStrategyWrapper( save_strategy, - mpu.get_data_parallel_group(with_context_parallel=True), + pg_collection.dp_cp, ckpt_cfg.ckpt_assume_constant_structure, ) # Store save strategy for future checkpoint saves @@ -639,7 +659,9 @@ def save_checkpoint( ) # [ModelOpt]: save sharded modelopt_state (skip if model is empty, e.g., low-memory save mode) if model: - save_sharded_modelopt_state(model, checkpoint_name, (ckpt_cfg.ckpt_format, 1)) + # 
cfg.dist can be None during checkpoint conversion (save_megatron_model) + if not (cfg.dist and cfg.dist.use_decentralized_pg): + save_sharded_modelopt_state(model, checkpoint_name, (ckpt_cfg.ckpt_format, 1)) else: # [ModelOpt]: Inject modelopt_state into state_dict (skip if model is empty) if ckpt_type == CheckpointType.LOCAL: @@ -663,7 +685,7 @@ def save_checkpoint( state_dict, algo=algo, cached_metadata=cached_metadata, - parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True), + parallelization_group=pg_collection.dp_cp, ) async_save_request = checkpointing_context["local_checkpoint_manager"].save( state_dict_for_save, train_state.step, is_async=bool(ckpt_cfg.async_save) @@ -722,10 +744,10 @@ def train_state_finalize_fn() -> None: if tokenizer_instance is not None: save_tokenizer_assets(tokenizer_instance, cfg.tokenizer, checkpoint_name) - tp_rank = (tensor_rank if tensor_rank is not None else mpu.get_tensor_model_parallel_rank()) + 1 - tp_world_size = mpu.get_tensor_model_parallel_world_size() - pp_rank = (pipeline_rank if pipeline_rank is not None else mpu.get_pipeline_model_parallel_rank()) + 1 - pp_world_size = mpu.get_pipeline_model_parallel_world_size() + tp_rank = (tensor_rank if tensor_rank is not None else pg_collection.tp.rank()) + 1 + tp_world_size = pg_collection.tp.size() + pp_rank = (pipeline_rank if pipeline_rank is not None else pg_collection.pp.rank()) + 1 + pp_world_size = pg_collection.pp.size() print_rank_0( f" successfully saved checkpoint from iteration {train_state_dict['step'].item():7d} " f"to {ckpt_cfg.save} [ t {tp_rank}/{tp_world_size}, p {pp_rank}/{pp_world_size} ]" @@ -830,7 +852,14 @@ def remove_iter_ckpts(_iter_ckpts): remove_iter_ckpts(rm_iter_ckpts) -def maybe_save_dataloader_state(train_iterator: Any, iteration: int, dataloader_save_path: Optional[str]) -> None: +def maybe_save_dataloader_state( + model: list[MegatronModule] | MegatronModule, + train_iterator: Any, + iteration: int, + dataloader_save_path: 
str | None = None, + *, + pg_collection: ProcessGroupCollection | None = None, +) -> None: """Save the dataloader state if the iterator supports it. Checks if the train_iterator has a `save_state` method and calls it. @@ -848,12 +877,13 @@ def maybe_save_dataloader_state(train_iterator: Any, iteration: int, dataloader_ if not hasattr(train_iterator.iterable, "save_state"): raise RuntimeError(f"Could not find a save_state for the train_iterator of type {type(train_iterator)}") - # Save dataloader state for each data parallel rank only once. - first_rank = mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0 - if not first_rank: + # Resolve process groups and save dataloader state for each DP rank only once. + pg_collection = pg_collection or get_pg_collection(model) + is_first_rank = (pg_collection.pp.rank() == 0) and (pg_collection.tp.rank() == 0) + if not is_first_rank: return - dp_rank = mpu.get_data_parallel_rank() + dp_rank = pg_collection.dp.rank() print_rank_0(f"saving dataloader checkpoint at iteration {iteration} to {dataloader_save_path}") train_dataloader_state_dict = train_iterator.iterable.save_state() # Get the base directory for the current iteration @@ -861,12 +891,12 @@ def maybe_save_dataloader_state(train_iterator: Any, iteration: int, dataloader_ # Construct the specific filename within that iteration directory data_state_save_path = os.path.join(iter_dir, f"train_dataloader_dprank{dp_rank:03d}.pt") - torch.distributed.barrier(group=mpu.get_data_parallel_group()) + torch.distributed.barrier(group=pg_collection.dp) - if mpu.get_data_parallel_rank() == 0: + if pg_collection.dp.rank() == 0: ensure_directory_exists(data_state_save_path) - torch.distributed.barrier(group=mpu.get_data_parallel_group()) + torch.distributed.barrier(group=pg_collection.dp) dataloader_save_dict = {} dataloader_save_dict["dataloader_state_dict"] = train_dataloader_state_dict @@ -1012,6 +1042,8 @@ def _generate_model_state_dict( model: 
list[MegatronModule], model_sd_kwargs: Optional[dict[str, Any]] = None, ckpt_format: str = "torch_dist", + *, + pg_collection: ProcessGroupCollection | None = None, ) -> dict[str, ShardedStateDict]: """Generate the model subset of the state dictionary to be saved in a checkpoint. @@ -1052,6 +1084,8 @@ def generate_state_dict( optim_sd_kwargs: Optional[dict[str, Any]] = None, model_sd_kwargs: Optional[dict[str, Any]] = None, rerun_state: Optional[dict[str, Any]] = None, + *, + pg_collection: ProcessGroupCollection | None = None, ) -> dict[str, Any]: """Generate the state dictionary to be saved in a checkpoint. @@ -1075,7 +1109,9 @@ def generate_state_dict( if iteration is not None: state_dict["iteration"] = iteration - state_dict.update(_generate_model_state_dict(model, model_sd_kwargs, ckpt_cfg.ckpt_format)) + state_dict.update( + _generate_model_state_dict(model, model_sd_kwargs, ckpt_cfg.ckpt_format, pg_collection=pg_collection) + ) # Optimizer stuff. if ckpt_cfg.save_optim: @@ -1198,13 +1234,13 @@ def _load_model_weights_from_checkpoint( restore_modelopt_state(model, state_dict) model = unwrap_model(model) - sharded_state_dict = _generate_model_state_dict(model, model_sd_kwargs) + pg_collection = get_pg_collection(model) + sharded_state_dict = _generate_model_state_dict(model, model_sd_kwargs, pg_collection=pg_collection) load_strategy = get_default_load_sharded_strategy(checkpoint_path) if fully_parallel_load: - load_strategy = FullyParallelLoadStrategyWrapper( - load_strategy, mpu.get_data_parallel_group(with_context_parallel=True) - ) + pg_collection = get_pg_collection(model) + load_strategy = FullyParallelLoadStrategyWrapper(load_strategy, pg_collection.dp_cp) state_dict = dist_checkpointing.load( sharded_state_dict, checkpoint_path, load_strategy, strict=dist_ckpt_strictness ) @@ -1327,6 +1363,7 @@ def _load_checkpoint_from_path( """ cfg = state.cfg model = unwrap_model(model) + pg_collection = get_pg_collection(model) ckpt_format = 
cfg.checkpoint.ckpt_format # Step 1: Load base checkpoint with rank0=True (torch_dist only) @@ -1338,6 +1375,7 @@ def _load_checkpoint_from_path( checkpointing_context=checkpointing_context, ignore_ckpt_step=ignore_ckpt_step, cfg=cfg, + pg_collection=pg_collection, ) # Step 2: Initialize scaffolding @@ -1376,7 +1414,9 @@ def _load_checkpoint_from_path( and cfg.checkpoint.load_rng and run_config["checkpoint"]["save_rng"] ): - gen_sd_rng_state = get_rng_state(cfg.rng.data_parallel_random_init, ckpt_format) + gen_sd_rng_state = get_rng_state( + cfg.rng.data_parallel_random_init, ckpt_format, pg_collection=pg_collection + ) else: ignore_rng_state = True gen_sd_rng_state = None @@ -1436,6 +1476,7 @@ def _load_checkpoint_from_path( if ckpt_tp_pp != run_tp_pp: print_rank_0("{}: Rerun state will be ignored".format(mismatch_msg)) + sharded_sd_metadata["dp_cp_group"] = pg_collection.dp_cp optim_sd_kwargs = dict(metadata=sharded_sd_metadata, is_loading=True) model_sd_kwargs = dict(metadata=sharded_sd_metadata) @@ -1453,6 +1494,7 @@ def _load_checkpoint_from_path( optim_sd_kwargs=optim_sd_kwargs, model_sd_kwargs=model_sd_kwargs, rerun_state=gen_sd_rerun_state, + pg_collection=pg_collection, ) elif ckpt_format == "fsdp_dtensor": @@ -1492,7 +1534,9 @@ def _load_checkpoint_from_path( data_iterator=None, ckpt_format=ckpt_format, force=True ) if cfg.checkpoint.load_rng: - gen_sd_rng_state = get_rng_state(cfg.rng.data_parallel_random_init, ckpt_format) + gen_sd_rng_state = get_rng_state( + cfg.rng.data_parallel_random_init, ckpt_format, pg_collection=pg_collection + ) if cfg.checkpoint.load_optim: gen_sd_optim = optimizer gen_sd_opt_param_scheduler = opt_param_scheduler @@ -1511,6 +1555,7 @@ def _load_checkpoint_from_path( optim_sd_kwargs=optim_sd_kwargs, rerun_state=gen_sd_rerun_state, iteration=1, + pg_collection=pg_collection, ) # Store model reference for preprocessing during load state_dict["_model"] = model @@ -1542,6 +1587,7 @@ def _load_checkpoint_from_path( 
checkpointing_context=checkpointing_context, ignore_ckpt_step=ignore_ckpt_step, cfg=cfg, + pg_collection=pg_collection, **load_kwargs, ) @@ -1595,14 +1641,7 @@ def _load_checkpoint_from_path( # Load optimizer and scheduler if not release and not cfg.checkpoint.finetune and cfg.checkpoint.load_optim: try: - if isinstance(optimizer, LayerWiseDistributedOptimizer) and cfg.checkpoint.ckpt_format == "torch": - # LayerWiseDistributedOptimizer load optimizer state from file on different ranks - dp_rank = mpu.get_data_parallel_rank() - optim_checkpoint_name = os.path.join( - os.path.dirname(checkpoint_name), f"layer_wise_optimizer_{dp_rank}.pt" - ) - optimizer.load_state_dict_from_file(optim_checkpoint_name) - elif ( + if ( not skip_load_to_model_and_opt and optimizer is not None and not getattr(optimizer, "is_stub_optimizer", False) @@ -1643,8 +1682,8 @@ def _load_checkpoint_from_path( if "rng_state" in state_dict: if ckpt_format == "fsdp_dtensor": # FSDP DTensor format: {(pp_rank, tp_rank): rng_state_list} - tp_rank = mpu.get_tensor_model_parallel_rank() - pp_rank = mpu.get_pipeline_model_parallel_rank() + tp_rank = pg_collection.tp.rank() + pp_rank = pg_collection.pp.rank() key = f"({pp_rank}, {tp_rank})" if key in state_dict["rng_state"]: rng_state_list = state_dict["rng_state"][key] @@ -1652,14 +1691,14 @@ def _load_checkpoint_from_path( print_rank_0("WARNING: RNG state not found for current TP/PP rank") rng_state_list = next(iter(state_dict["rng_state"].values())) rng_state = ( - rng_state_list[mpu.get_data_parallel_rank()] + rng_state_list[pg_collection.dp.rank()] if cfg.rng.data_parallel_random_init else rng_state_list[0] ) else: # torch_dist format: ShardedObject rng_state = ( - state_dict["rng_state"][mpu.get_data_parallel_rank()] + state_dict["rng_state"][pg_collection.dp.rank()] if cfg.rng.data_parallel_random_init else state_dict["rng_state"][0] ) @@ -1693,8 +1732,8 @@ def _load_checkpoint_from_path( print_rank_0( f" successfully loaded checkpoint from 
{load_dir} " - f"[ t {mpu.get_tensor_model_parallel_rank() + 1}/{mpu.get_tensor_model_parallel_world_size()}, " - f"p {mpu.get_pipeline_model_parallel_rank() + 1}/{mpu.get_pipeline_model_parallel_world_size()} ] " + f"[ t {pg_collection.tp.rank() + 1}/{pg_collection.tp.size()}, " + f"p {pg_collection.pp.rank() + 1}/{pg_collection.pp.size()} ] " f"at iteration {state.train_state.step}" ) @@ -1932,6 +1971,8 @@ def _load_non_persistent_base_checkpoint( sharded_state_dict: Optional[dict[str, Any]], non_persistent_iteration: int, checkpointing_context: Optional[dict[str, Any]] = None, + *, + pg_collection: ProcessGroupCollection, ) -> tuple[dict[str, Any], str, bool, CheckpointType]: """Load the base state_dict from a non-persistent distributed checkpoint.""" assert ckpt_cfg.non_persistent_ckpt_type is not None @@ -1946,13 +1987,14 @@ non_persistent_iteration, False, checkpointing_context=checkpointing_context, + pg_collection=pg_collection, ) elif ckpt_cfg.non_persistent_ckpt_type == "local": intermediate_state_dict, checkpoint_name = checkpointing_context["local_checkpoint_manager"].load() state_dict = intermediate_state_dict.to_state_dict( sharded_state_dict, algo=ckpt_cfg.non_persistent_local_ckpt_algo, - parallelization_group=mpu.get_data_parallel_group(with_context_parallel=True), + parallelization_group=pg_collection.dp_cp, ) return state_dict, checkpoint_name, False, CheckpointType.LOCAL else: @@ -1969,6 +2011,8 @@ def _load_global_dist_base_checkpoint( iteration: int, release: bool, checkpointing_context: Optional[dict[str, Any]] = None, + *, + pg_collection: ProcessGroupCollection, ) -> tuple[dict[str, Any], str, bool, CheckpointType]: """Load the base state_dict from the given directory containing the global distributed checkpoint.""" if rank0: @@ -1982,9 +2026,7 @@ checkpoint_name = get_checkpoint_name(load_dir, iteration, release) load_strategy =
get_default_load_sharded_strategy(checkpoint_name) if ckpt_cfg.fully_parallel_load: - load_strategy = FullyParallelLoadStrategyWrapper( - load_strategy, mpu.get_data_parallel_group(with_context_parallel=True) - ) + load_strategy = FullyParallelLoadStrategyWrapper(load_strategy, pg_collection.dp_cp) if checkpointing_context is not None: checkpointing_context["load_strategy"] = load_strategy state_dict = dist_checkpointing.load( @@ -2001,6 +2043,8 @@ def _load_base_checkpoint( checkpointing_context: Optional[dict[str, Any]] = None, ignore_ckpt_step: bool = False, cfg: Optional[ConfigContainer] = None, + *, + pg_collection: ProcessGroupCollection, ) -> tuple[Optional[dict[str, Any]], str, bool, Optional[CheckpointType]]: """Load the base state_dict from the given directory. @@ -2046,6 +2090,7 @@ def _load_base_checkpoint( sharded_state_dict, non_persistent_iteration, checkpointing_context, + pg_collection=pg_collection, ) else: print_rank_0("WARNING: non-persistent checkpoints are older than persistent checkpoint") @@ -2086,6 +2131,7 @@ def _load_base_checkpoint( iteration, release, checkpointing_context=checkpointing_context, + pg_collection=pg_collection, ) elif ckpt_format == "fsdp_dtensor": return _load_fsdp_dtensor_base_checkpoint( diff --git a/src/megatron/bridge/training/config.py b/src/megatron/bridge/training/config.py index f33bd2809a..0aa19b3ff7 100644 --- a/src/megatron/bridge/training/config.py +++ b/src/megatron/bridge/training/config.py @@ -196,6 +196,11 @@ class DistributedInitConfig: disable_jit_fuser: bool = False """Disable the JIT fuser.""" + use_decentralized_pg: bool = False + """Use ProcessGroupCollection passed through functions instead of relying on mcore's + global parallel state (mpu) variables. 
When True, parallel groups are obtained from + the pg_collection object rather than the global megatron.core.parallel_state module.""" + @dataclass class RerunStateMachineConfig: @@ -1426,6 +1431,13 @@ def validate(self) -> None: self.model.use_cpu_initialization = self.model.use_cpu_initialization or self.dist.lazy_init + # Gloo process groups are not supported when using decentralized process groups (NCCL only). + if self.dist.use_decentralized_pg: + assert not self.dist.use_gloo_process_groups, ( + "Gloo process groups are not supported when use_decentralized_pg=True. " + "Decentralized process groups only support NCCL backend." + ) + # Make sure all functionality that requires Gloo process groups is disabled. if not self.dist.use_gloo_process_groups: if self.optimizer.use_distributed_optimizer: diff --git a/src/megatron/bridge/training/eval.py b/src/megatron/bridge/training/eval.py index a8a89537d0..5dc20bb19b 100644 --- a/src/megatron/bridge/training/eval.py +++ b/src/megatron/bridge/training/eval.py @@ -20,9 +20,11 @@ from megatron.core.full_cuda_graph import FullCudaGraphWrapper from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from megatron.core.pipeline_parallel.utils import is_pp_last_stage from megatron.core.rerun_state_machine import RerunDataIterator, RerunMode, get_rerun_state_machine from megatron.core.transformer import MegatronModule +from megatron.core.utils import get_model_config from megatron.bridge.data.finetuning import prepare_finetuning_batch from megatron.bridge.data.iterator_utils import make_data_iterator_list @@ -74,8 +76,9 @@ def evaluate( for model_module in model: model_module.eval() - # Retrieve process group collection from the model + # Retrieve process group collection and model config from the model pg_collection = get_pg_collection(model) + model_config = 
get_model_config(model[0]) # Disable result validation during evaluation rerun_state_machine = get_rerun_state_machine() @@ -94,10 +97,17 @@ def evaluate( if state.cfg.model.cuda_graph_impl == "local" and "full_iteration" in state.cfg.model.cuda_graph_scope: forward_backward_func = FullCudaGraphWrapper( - get_forward_backward_func(), cuda_graph_warmup_steps=state.cfg.model.cuda_graph_warmup_steps + get_forward_backward_func( + pp_size=pg_collection.pp.size(), + vp_size=state.cfg.model.virtual_pipeline_model_parallel_size, + ), + cuda_graph_warmup_steps=state.cfg.model.cuda_graph_warmup_steps, ) else: - forward_backward_func = get_forward_backward_func() + forward_backward_func = get_forward_backward_func( + pp_size=pg_collection.pp.size(), + vp_size=state.cfg.model.virtual_pipeline_model_parallel_size, + ) iteration = 0 while iteration < state.cfg.train.eval_iters: @@ -128,6 +138,7 @@ def evaluate( # Don't care about timing during evaluation config.timers = None fault_tolerance.on_eval_step_start(state) + p2p_communicator = P2PCommunicator(pp_group=pg_collection.pp, config=model_config) loss_dicts = forward_backward_func( forward_step_func=wrapped_forward_step, data_iterator=eval_data_iterator, @@ -136,6 +147,8 @@ def evaluate( seq_length=seq_length, micro_batch_size=state.cfg.train.micro_batch_size, forward_only=True, + p2p_communicator=p2p_communicator, + pg_collection=pg_collection, ) fault_tolerance.on_eval_step_end(state) config.timers = state.timers @@ -197,6 +210,7 @@ def evaluate( data_iterator=non_loss_microbatch_iterator, ) + p2p_communicator = P2PCommunicator(pp_group=pg_collection.pp, config=model_config) collected_non_loss_data = forward_backward_func( forward_step_func=wrapped_forward_step, data_iterator=non_loss_data_iterator, @@ -206,6 +220,8 @@ def evaluate( micro_batch_size=state.cfg.train.micro_batch_size, forward_only=True, collect_non_loss_data=True, + p2p_communicator=p2p_communicator, + pg_collection=pg_collection, ) # Move model back to the 
train mode. diff --git a/src/megatron/bridge/training/gpt_step.py b/src/megatron/bridge/training/gpt_step.py index a0658425ae..f05981d57f 100644 --- a/src/megatron/bridge/training/gpt_step.py +++ b/src/megatron/bridge/training/gpt_step.py @@ -183,7 +183,7 @@ def get_batch( batch = _partition_packed_batch_for_cp(batch, cp_size) else: # slice batch along sequence dimension for context parallelism - batch = get_batch_on_this_cp_rank(batch) + batch = get_batch_on_this_cp_rank(batch, cp_group=pg_collection.cp) return ( batch["tokens"], diff --git a/src/megatron/bridge/training/initialize.py b/src/megatron/bridge/training/initialize.py index 7e74d292ab..eb10881251 100644 --- a/src/megatron/bridge/training/initialize.py +++ b/src/megatron/bridge/training/initialize.py @@ -24,12 +24,20 @@ from megatron.core.fusions.fused_bias_dropout import bias_dropout_add_fused_train from megatron.core.fusions.fused_bias_gelu import bias_gelu from megatron.core.fusions.fused_bias_swiglu import bias_swiglu +from megatron.core.hyper_comm_grid import HyperCommGrid from megatron.core.num_microbatches_calculator import ( destroy_num_microbatches_calculator, init_num_microbatches_calculator, ) +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler -from megatron.core.utils import configure_nvtx_profiling, get_te_version, is_te_min_version, is_torch_min_version +from megatron.core.utils import ( + configure_nvtx_profiling, + get_pg_rank, + get_te_version, + is_te_min_version, + is_torch_min_version, +) from megatron.bridge.models import GPTModelProvider, T5ModelProvider from megatron.bridge.training.config import ConfigContainer, DistributedInitConfig, RerunStateMachineConfig, RNGConfig @@ -49,7 +57,7 @@ def initialize_megatron( get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None, get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None, 
restart_store: Optional[torch.distributed.Store] = None, -) -> Optional[Callable[[], None]]: +) -> Callable[[], None] | ProcessGroupCollection | None: """Initialize Megatron core components and distributed setup. Sets up logging, initializes distributed environment (torch.distributed), @@ -132,7 +140,7 @@ def torch_dist_init( skip_mpu_initialization: bool, restart_store: Optional[torch.distributed.Store] = None, use_inprocess_restart: bool = False, -) -> Optional[Callable[[], None]]: +) -> Callable[[], None] | ProcessGroupCollection | None: """Initialize torch.distributed and dependent components. Handles the core distributed setup, including process group initialization, @@ -154,9 +162,9 @@ def torch_dist_init( or lazy_mpu_init is True, otherwise None. """ - def finish_mpu_init(): + def finish_mpu_init() -> ProcessGroupCollection: # Pytorch distributed. - _initialize_distributed( + pg_collection = _initialize_distributed( model_config=model_config, dist_config=dist_config, num_distributed_optimizer_instances=num_distributed_optimizer_instances, @@ -175,10 +183,12 @@ def finish_mpu_init(): rng_config.te_rng_tracker, rng_config.inference_rng_tracker, use_cudagraphable_rng=(model_config.cuda_graph_impl != "none"), + pg_collection=pg_collection, ) if model_config.num_moe_experts is not None: MoEAuxLossAutoScaler.set_loss_scale(torch.ones(1, device=torch.cuda.current_device())) + return pg_collection if skip_mpu_initialization: return None @@ -191,15 +201,13 @@ def finish_mpu_init(): # to call when it has DDP initialized parallel_state.set_tensor_model_parallel_rank(get_rank_safe()) return finish_mpu_init - else: - # Megatron's MPU is the master. Complete initialization right away. - finish_mpu_init() + # Megatron's MPU is the master. Complete initialization right away. 
+ pg_collection = finish_mpu_init() - if model_config.tp_comm_overlap: - _initialize_tp_communicators(model_config, micro_batch_size) + if model_config.tp_comm_overlap: + _initialize_tp_communicators(model_config, micro_batch_size) - # No continuation function - return None + return pg_collection def init_rerun_state(rerun_state_machine_config: RerunStateMachineConfig) -> None: @@ -349,6 +357,150 @@ def _initialize_tp_communicators(model_config: GPTModelProvider | T5ModelProvide ) +def _create_pg_collection( + model_config: GPTModelProvider | T5ModelProvider, + num_distributed_optimizer_instances: int, + get_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None, + get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None, +) -> ProcessGroupCollection: + """Create all process groups via HyperCommGrid and return a ProcessGroupCollection.""" + world_size = torch.distributed.get_world_size() + tp_size = int(model_config.tensor_model_parallel_size) + pp_size = int(model_config.pipeline_model_parallel_size) + cp_size = int(model_config.context_parallel_size) if getattr(model_config, "context_parallel_size", 1) else 1 + model_size = tp_size * pp_size * cp_size + if world_size % model_size != 0: + raise RuntimeError(f"world_size ({world_size}) is not divisible by {model_size}") + dp_size = world_size // model_size + + grid = HyperCommGrid( + shape=[tp_size, cp_size, dp_size, pp_size], + dim_names=["tp", "cp", "dp", "pp"], + rank_offset=0, + backend="nccl", + ) + # Core groups + tp_pg = grid.create_pg(["tp"]) + cp_pg = grid.create_pg(["cp"]) + pp_pg = grid.create_pg(["pp"]) + dp_pg = grid.create_pg(["dp"]) + mp_pg = grid.create_pg(["tp", "pp"]) + tp_cp_pg = grid.create_pg(["tp", "cp"]) + tp_dp_cp_pg = grid.create_pg(["tp", "dp", "cp"]) + dp_cp_pg = grid.create_pg(["dp", "cp"]) + + # Expert/MoE related groups (refer to original parallel_state.initialize_model_parallel) + expert_tp_size = ( + 
int(model_config.expert_tensor_parallel_size) + if getattr(model_config, "expert_tensor_parallel_size", None) + else tp_size + ) + ep_size = ( + int(model_config.expert_model_parallel_size) if getattr(model_config, "expert_model_parallel_size", 1) else 1 + ) + # Expert data-parallel size folds CP into DP (as in original expert rank generator) + expt_model_block = expert_tp_size * ep_size * pp_size + if world_size % expt_model_block != 0: + raise RuntimeError( + f"world_size ({world_size}) is not divisible by expert_tensor_model_pipeline size ({expt_model_block})" + ) + expt_dp_size = world_size // expt_model_block + use_optimizer_instance_groups = num_distributed_optimizer_instances > 1 + inner_dp_dim: Optional[str] = None + outer_dp_dim: Optional[str] = None + if use_optimizer_instance_groups: + assert expt_dp_size % num_distributed_optimizer_instances == 0, ( + "Expert DP size must be divisible by the number of optimizer instances." + ) + inner_expt_dp_size = expt_dp_size // num_distributed_optimizer_instances + expert_grid = HyperCommGrid( + shape=[expert_tp_size, ep_size, inner_expt_dp_size, num_distributed_optimizer_instances, pp_size], + dim_names=["tp", "ep", "inner_dp", "outer_dp", "pp"], + rank_offset=0, + backend="nccl", + ) + dp_group_dims: list[str] = ["inner_dp", "outer_dp"] + inner_dp_dim = "inner_dp" + outer_dp_dim = "outer_dp" + else: + expert_grid = HyperCommGrid( + shape=[expert_tp_size, ep_size, expt_dp_size, pp_size], + dim_names=["tp", "ep", "dp", "pp"], + rank_offset=0, + backend="nccl", + ) + dp_group_dims = ["dp"] + ep_pg = expert_grid.create_pg(["ep"]) + expt_tp_pg = expert_grid.create_pg(["tp"]) + tp_ep_pg = expert_grid.create_pg(["tp", "ep"]) + tp_ep_pp_pg = expert_grid.create_pg(["tp", "ep", "pp"]) + expt_dp_pg = expert_grid.create_pg(dp_group_dims) + + # Embedding and position-embedding groups + embd_pg = None + pos_embd_pg = None + # Enumerate ranks per PP group + pp_rank_lists = grid._gen_rank_enum(["pp"]) + # Determine embedding 
ranks for each pp group + embedding_rank_lists: list[list[int]] = [] + pos_embedding_rank_lists: list[list[int]] = [] + for ranks in pp_rank_lists: + if not ranks: + continue + if get_embedding_ranks is not None: + # Use custom callback to determine embedding ranks + embedding_rank_lists.append(get_embedding_ranks(ranks, pp_size)) + else: + # Default: embedding_ranks are first and last pp stage (or only one if pp_size==1) + embedding_rank_lists.append([ranks[0]] if len(ranks) == 1 else [ranks[0], ranks[-1]]) + if get_position_embedding_ranks is not None: + # Use custom callback to determine position embedding ranks + pos_embedding_rank_lists.append(get_position_embedding_ranks(ranks, pp_size)) + else: + # Default: position embedding ranks are first pp stage only + pos_embedding_rank_lists.append([ranks[0]]) + if embedding_rank_lists: + embd_pg, _ = torch.distributed.new_subgroups_by_enumeration(embedding_rank_lists, backend="nccl") + if pos_embedding_rank_lists: + pos_embd_pg, _ = torch.distributed.new_subgroups_by_enumeration(pos_embedding_rank_lists, backend="nccl") + + # Build Partial-Distributed-Optimizer groups for Expert DP when multiple instances are used. + intra_expt_dp_pg = None + inter_dist_opt_pg = None + intra_dist_opt_pg = None + if inner_dp_dim is not None and outer_dp_dim is not None: + intra_expt_dp_pg = expert_grid.create_pg([inner_dp_dim]) + inter_dist_opt_pg = expert_grid.create_pg([outer_dp_dim]) + # Match distributed optimizer instance grouping from parallel_state: + # combine tp-ep-pp ranks across the intra-partial DP slice. + intra_dist_opt_pg = expert_grid.create_pg(["tp", "ep", inner_dp_dim, "pp"]) + + # Build ProcessGroupCollection with available groups. 
+ pg_collection = ProcessGroupCollection( + tp=tp_pg, + pp=pp_pg, + mp=mp_pg, + embd=embd_pg, + pos_embd=pos_embd_pg, + cp=cp_pg, + tp_cp=tp_cp_pg, + hcp=None, + ep=ep_pg, + expt_tp=expt_tp_pg, + tp_ep=tp_ep_pg, + tp_ep_pp=tp_ep_pp_pg, + tp_dp_cp=tp_dp_cp_pg, + dp=dp_pg, + dp_cp=dp_cp_pg, + expt_dp=expt_dp_pg, + intra_dp_cp=dp_cp_pg, + intra_expt_dp=intra_expt_dp_pg if intra_expt_dp_pg is not None else expt_dp_pg, + inter_dist_opt=inter_dist_opt_pg, + intra_dist_opt=intra_dist_opt_pg, + ) + return pg_collection + + def _initialize_distributed( model_config: GPTModelProvider | T5ModelProvider, dist_config: DistributedInitConfig, @@ -357,7 +509,7 @@ def _initialize_distributed( get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]], restart_store: Optional[torch.distributed.Store] = None, use_inprocess_restart: bool = False, -) -> None: +) -> ProcessGroupCollection: """Initialize torch.distributed and core model parallel.""" device_count = torch.cuda.device_count() @@ -412,7 +564,30 @@ def _initialize_distributed( # Set the tensor model-parallel, pipeline model-parallel, and # data-parallel communicators. - if device_count > 0: + + if device_count == 0: + if dist_config.use_decentralized_pg or dist_config.distributed_backend == "nccl": + raise RuntimeError("Cannot initialize parallel groups with no CUDA devices available (device_count=0)") + + if dist_config.use_decentralized_pg: + # Use HyperCommGrid to create local parallel groups passed through functions + # instead of relying on mcore's global parallel state (mpu) variables. 
+ parallel_state._set_global_memory_buffer() + pg_collection = _create_pg_collection( + model_config, + num_distributed_optimizer_instances, + get_embedding_ranks=get_embedding_ranks, + get_position_embedding_ranks=get_position_embedding_ranks, + ) + if get_rank_safe() == 0: + tp = int(model_config.tensor_model_parallel_size) + pp = int(model_config.pipeline_model_parallel_size) + cp = int(model_config.context_parallel_size) if getattr(model_config, "context_parallel_size", 1) else 1 + dp = torch.distributed.get_world_size() // (tp * pp * cp) + print(f"> initialized HyperCommGrid with tp={tp}, pp={pp}, cp={cp}, dp={dp}") + return pg_collection + else: + # Use the original mcore parallel_state.initialize_model_parallel approach if parallel_state.model_parallel_is_initialized(): print("model parallel is already initialized") else: @@ -445,6 +620,8 @@ def _initialize_distributed( f"> initialized pipeline model parallel with size " f"{parallel_state.get_pipeline_model_parallel_world_size()}" ) + # Return a ProcessGroupCollection using mpu process groups + return ProcessGroupCollection.use_mpu_process_groups() def _set_random_seed( @@ -453,6 +630,8 @@ def _set_random_seed( te_rng_tracker: bool = False, inference_rng_tracker: bool = False, use_cudagraphable_rng: bool = False, + *, + pg_collection: ProcessGroupCollection, ) -> None: """Set random seed for reproducability.""" assert seed_ is not None and seed_ > 0, f"Seed ({seed_}) should be a positive integer." @@ -461,17 +640,31 @@ def _set_random_seed( import numpy as np + current_rank = torch.distributed.get_rank() # Ensure that different pipeline MP stages get different seeds. 
- seed = seed_ + (100 * parallel_state.get_pipeline_model_parallel_rank()) + pp_rank = torch.distributed.get_group_rank(pg_collection.pp, current_rank) + seed = seed_ + (100 * pp_rank) # Ensure different data parallel ranks get different seeds if data_parallel_random_init: - seed = seed + (10 * parallel_state.get_data_parallel_rank()) + dp_rank = torch.distributed.get_group_rank(pg_collection.dp, current_rank) + seed = seed + (10 * dp_rank) random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.device_count() > 0: + # Derive TP/EP/ETP ranks from provided process groups using helper utils + tp_rank = get_pg_rank(pg_collection.tp) + ep_rank = get_pg_rank(pg_collection.ep) + etp_rank = get_pg_rank(pg_collection.expt_tp) + tensor_parallel.model_parallel_cuda_manual_seed( - seed, te_rng_tracker, inference_rng_tracker, use_cudagraphable_rng + seed, + te_rng_tracker, + inference_rng_tracker, + use_cudagraphable_rng, + tp_rank=tp_rank, + ep_rank=ep_rank, + etp_rank=etp_rank, ) @@ -511,7 +704,8 @@ def _warmup_jit_function(model_config: GPTModelProvider | T5ModelProvider, micro # Warmup fused bias+dropout+add if model_config.sequence_parallel: - seq_length = model_config.seq_length // parallel_state.get_tensor_model_parallel_world_size() + tp_world_size = int(model_config.tensor_model_parallel_size) + seq_length = model_config.seq_length // tp_world_size else: seq_length = model_config.seq_length input = torch.rand( diff --git a/src/megatron/bridge/training/llava_step.py b/src/megatron/bridge/training/llava_step.py index fbcd286650..1723cfaeaf 100644 --- a/src/megatron/bridge/training/llava_step.py +++ b/src/megatron/bridge/training/llava_step.py @@ -17,8 +17,8 @@ from typing import Iterable import torch -from megatron.core import parallel_state from megatron.core.models.gpt import GPTModel +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage from megatron.core.utils import get_batch_on_this_cp_rank, get_model_config 
from megatron.bridge.training.config import ConfigContainer @@ -27,6 +27,7 @@ ) from megatron.bridge.training.losses import masked_next_token_loss from megatron.bridge.training.state import GlobalState +from megatron.bridge.training.utils.pg_utils import get_pg_collection logger = logging.getLogger(__name__) @@ -35,12 +36,17 @@ def get_batch_from_iterator( data_iterator: Iterable, skip_getting_attention_mask_from_dataset: bool = True, + *, + is_first_pp_stage: bool, + is_last_pp_stage: bool, ) -> dict[str, torch.Tensor]: """Get a batch of data from the iterator. Args: data_iterator: The data iterator to get the batch from. skip_getting_attention_mask_from_dataset: If set, the dataset will pass a None attention mask. + is_first_pp_stage: Whether this is the first pipeline parallel stage. + is_last_pp_stage: Whether this is the last pipeline parallel stage. Returns: dict[str, torch.Tensor]: A dictionary containing the batch data. @@ -62,9 +68,9 @@ def get_batch_from_iterator( required_host_keys.add("cu_seqlens_argmin") required_host_keys.add("max_seqlen") - if parallel_state.is_pipeline_first_stage(): + if is_first_pp_stage: required_device_keys.update(("tokens", "input_ids", "position_ids")) - if parallel_state.is_pipeline_last_stage(): + if is_last_pp_stage: required_device_keys.update(("labels", "loss_mask")) _batch_required_keys = {} @@ -80,7 +86,7 @@ def get_batch_from_iterator( def get_batch( - data_iterator: Iterable, cfg: ConfigContainer + data_iterator: Iterable, cfg: ConfigContainer, *, pg_collection ) -> tuple[ torch.Tensor, torch.Tensor, @@ -98,23 +104,29 @@ def get_batch( Args: data_iterator: Input data iterator cfg: Configuration container + pg_collection: Process group collection for distributed training Returns: tuple of tensors containing tokens, labels, loss_mask, attention_mask, position_ids, cu_seqlens (optional), cu_seqlens_argmin (optional), max_seqlen (optional), images (optional) """ - if (not parallel_state.is_pipeline_first_stage()) and (not 
parallel_state.is_pipeline_last_stage()): + # Determine pipeline stage role via process group collection + is_first = is_pp_first_stage(pg_collection.pp) + is_last = is_pp_last_stage(pg_collection.pp) + if (not is_first) and (not is_last): return None, None, None, None, None, None, None, None, None, None batch = get_batch_from_iterator( data_iterator, getattr(cfg.dataset, "skip_getting_attention_mask_from_dataset", True), + is_first_pp_stage=is_first, + is_last_pp_stage=is_last, ) # Keep optional vision tensors aside to avoid being dropped by CP slicing util images = batch.get("pixel_values") # slice batch along sequence dimension for context parallelism - batch = get_batch_on_this_cp_rank(batch) + batch = get_batch_on_this_cp_rank(batch, cp_group=pg_collection.cp) if images is not None: batch["images"] = images @@ -152,6 +164,8 @@ def forward_step( config = get_model_config(model) + pg_collection = get_pg_collection(model) + timers("batch-generator", log_level=2).start() with straggler_timer(bdata=True): ( @@ -165,7 +179,7 @@ def forward_step( cu_seqlens, cu_seqlens_argmin, max_seqlen, - ) = get_batch(data_iterator, state.cfg) + ) = get_batch(data_iterator, state.cfg, pg_collection=pg_collection) timers("batch-generator").stop() diff --git a/src/megatron/bridge/training/model_load_save.py b/src/megatron/bridge/training/model_load_save.py index 85df8fc3c5..651f81e7b6 100644 --- a/src/megatron/bridge/training/model_load_save.py +++ b/src/megatron/bridge/training/model_load_save.py @@ -514,11 +514,15 @@ def save_megatron_model( generate_state_dict, get_rng_state, ) + from megatron.bridge.training.utils.pg_utils import get_pg_collection logger.info("[LOW_MEMORY_SAVE] Generating state dict...") # Get RNG state (minimal, since save_rng=False) - rng_state = get_rng_state(data_parallel_random_init=False, ckpt_format=ckpt_format) + pg_collection = get_pg_collection(model) + rng_state = get_rng_state( + data_parallel_random_init=False, ckpt_format=ckpt_format, 
pg_collection=pg_collection + ) # Build sharded state dict metadata sharded_sd_metadata = _build_sharded_state_dict_metadata(False, state.cfg.checkpoint) @@ -652,6 +656,7 @@ def _collect_factories(d): opt_param_scheduler=None, num_floating_point_operations_so_far=0, prebuilt_state_dict=state_dict, + pg_collection=pg_collection, ) else: # Save the checkpoint diff --git a/src/megatron/bridge/training/optim.py b/src/megatron/bridge/training/optim.py index c0e1f82517..fb22ac09f3 100644 --- a/src/megatron/bridge/training/optim.py +++ b/src/megatron/bridge/training/optim.py @@ -21,6 +21,7 @@ ) from megatron.core.optimizer.muon import get_megatron_muon_optimizer from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.module import MegatronModule from megatron.bridge.training.config import ( @@ -35,6 +36,7 @@ def setup_optimizer( scheduler_config: SchedulerConfig, model: Union[MegatronModule, list[MegatronModule]], use_gloo_process_groups: bool = False, + pg_collection: Optional[ProcessGroupCollection] = None, optimizer_config_override_provider: Optional[OptimizerConfigOverrideProvider] = None, ) -> tuple[MegatronOptimizer, OptimizerParamScheduler]: """Set up the optimizer and scheduler. 
@@ -44,6 +46,7 @@ def setup_optimizer( scheduler_config: Configuration for the scheduler model: The model to optimize use_gloo_process_groups: Whether to use Gloo process groups + pg_collection: Optional process group collection for distributed training Returns: tuple containing the optimizer and scheduler @@ -62,6 +65,7 @@ def setup_optimizer( model_chunks=model, config_overrides=config_overrides, use_gloo_process_groups=use_gloo_process_groups, + pg_collection=pg_collection, ) else: optimizer = get_megatron_muon_optimizer( @@ -70,6 +74,7 @@ def setup_optimizer( config_overrides=config_overrides, use_gloo_process_groups=use_gloo_process_groups, layer_wise_distributed_optimizer="dist" in optimizer_config.optimizer, + pg_collection=pg_collection, ) scheduler = _get_scheduler(optimizer_config, scheduler_config, optimizer) diff --git a/src/megatron/bridge/training/setup.py b/src/megatron/bridge/training/setup.py index ca7b4f8fe1..85a0dd136c 100644 --- a/src/megatron/bridge/training/setup.py +++ b/src/megatron/bridge/training/setup.py @@ -129,7 +129,10 @@ def setup( set_level_for_all_loggers=cfg.logger.set_level_for_all_loggers, ) - initialize_megatron( + # pg_collection is returned from initialize_megatron: + # - When use_decentralized_pg=True: uses HyperCommGrid to create local process groups + # - When use_decentralized_pg=False: uses mpu's global parallel state + pg_collection = initialize_megatron( cfg=cfg, get_embedding_ranks=get_embedding_ranks, get_position_embedding_ranks=get_position_embedding_ranks, @@ -158,9 +161,6 @@ def setup( print_rank_0("time to initialize megatron (seconds): {:.3f}".format(time.time() - state.start_time)) barrier_and_log("after megatron is initialized") - # Initialize process group collection once and pass through - pg_collection = ProcessGroupCollection.use_mpu_process_groups() - # Context used for persisting some state between checkpoint saves. 
checkpointing_context = init_checkpointing_context(cfg.checkpoint) @@ -217,6 +217,7 @@ def modelopt_pre_wrap_hook(model): use_torch_fsdp2=cfg.dist.use_torch_fsdp2, overlap_param_gather_with_optimizer_step=cfg.optimizer.overlap_param_gather_with_optimizer_step, data_parallel_random_init=cfg.rng.data_parallel_random_init, + pg_collection=pg_collection, ) cfg.model.timers = timers @@ -226,6 +227,9 @@ def modelopt_pre_wrap_hook(model): scheduler_config=cfg.scheduler, model=model, use_gloo_process_groups=cfg.dist.use_gloo_process_groups, + # Only pass pg_collection when use_decentralized_pg is True. + # When False, mcore's optimizer will use parallel_state directly which supports Gloo. + pg_collection=pg_collection if cfg.dist.use_decentralized_pg else None, optimizer_config_override_provider=cfg.optimizer_config_override_provider, ) timers("model-and-optimizer-setup").stop() diff --git a/src/megatron/bridge/training/train.py b/src/megatron/bridge/training/train.py index c4c24681f0..5330079637 100644 --- a/src/megatron/bridge/training/train.py +++ b/src/megatron/bridge/training/train.py @@ -36,7 +36,8 @@ from megatron.core.optimizer.qk_clip import clip_qk from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler from megatron.core.parallel_state import update_pg_timeout -from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func from megatron.core.pipeline_parallel.utils import ( is_pp_first_stage, is_pp_last_stage, @@ -249,7 +250,10 @@ def train( history_wct = deque(maxlen=config.logger.throughput_window_size + 1) # Wrap forward_backward_func for Full iteration CUDA graph - forward_backward_func = get_forward_backward_func() + forward_backward_func = get_forward_backward_func( + pp_size=pg_collection.pp.size(), + vp_size=config.model.virtual_pipeline_model_parallel_size, + ) if 
config.model.cuda_graph_impl == "local" and "full_iteration" in config.model.cuda_graph_scope: forward_backward_func = FullCudaGraphWrapper( forward_backward_func, cuda_graph_warmup_steps=config.model.cuda_graph_warmup_steps @@ -651,13 +655,18 @@ def train_step( ) # [ModelOpt]: Pipeline-parallel Distillation stacks student and teacher tensors - adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation( - model, - seq_length=model_config.seq_length, - micro_batch_size=train_config.micro_batch_size, - decoder_seq_length=model_config.seq_length, - ) + if not cfg.dist.use_decentralized_pg: + adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation( + model, + seq_length=model_config.seq_length, + micro_batch_size=train_config.micro_batch_size, + decoder_seq_length=model_config.seq_length, + ) + else: + adjust_tensor_shapes_fn = None + # Forward pass. + p2p_communicator = P2PCommunicator(pp_group=pg_collection.pp, config=model_config) losses_reduced = forward_backward_func( forward_step_func=forward_step_func, data_iterator=forward_backward_data_iterator, @@ -668,6 +677,8 @@ def train_step( decoder_seq_length=seq_length, forward_only=False, adjust_tensor_shapes_fn=adjust_tensor_shapes_fn, + p2p_communicator=p2p_communicator, + pg_collection=pg_collection, ) should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit() if should_exit: @@ -710,7 +721,7 @@ def train_step( if train_config.empty_unused_memory_level >= 2: torch.cuda.empty_cache() - if pg_collection.pp.rank() == pg_collection.pp.size() - 1: + if is_pp_last_stage(pg_collection.pp): # Average loss across microbatches. 
loss_reduced = {} diff --git a/src/megatron/bridge/training/vlm_step.py b/src/megatron/bridge/training/vlm_step.py index fe01ad8e52..1e013eeebd 100644 --- a/src/megatron/bridge/training/vlm_step.py +++ b/src/megatron/bridge/training/vlm_step.py @@ -137,7 +137,7 @@ def get_batch( # Slice only text tensors for context parallelism cp_keys = ("tokens", "input_ids", "labels", "loss_mask", "attention_mask", "position_ids") cp_slice = {k: batch.get(k) for k in cp_keys if k in batch} - cp_slice = get_batch_on_this_cp_rank(cp_slice) + cp_slice = get_batch_on_this_cp_rank(cp_slice, cp_group=pg_collection.cp) for k, v in cp_slice.items(): batch[k] = v diff --git a/tests/functional_tests/data/datasets/test_chat_template.py b/tests/functional_tests/data/datasets/test_chat_template.py index 587a4b4c00..e223f3e279 100644 --- a/tests/functional_tests/data/datasets/test_chat_template.py +++ b/tests/functional_tests/data/datasets/test_chat_template.py @@ -142,13 +142,19 @@ def setup_and_teardown_parallel_state(self): context_parallel_size=1, ) + from megatron.core.process_groups_config import ProcessGroupCollection + from megatron.bridge.training.initialize import _set_random_seed + # Create pg_collection from initialized mpu + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + _set_random_seed( seed_=1234, data_parallel_random_init=False, te_rng_tracker=True, inference_rng_tracker=False, + pg_collection=pg_collection, ) yield @@ -497,13 +503,19 @@ def setup_and_teardown_parallel_state(self): context_parallel_size=1, ) + from megatron.core.process_groups_config import ProcessGroupCollection + from megatron.bridge.training.initialize import _set_random_seed + # Create pg_collection from initialized mpu + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + _set_random_seed( seed_=1234, data_parallel_random_init=False, te_rng_tracker=True, inference_rng_tracker=False, + pg_collection=pg_collection, ) yield diff --git 
a/tests/functional_tests/data/datasets/test_sft.py b/tests/functional_tests/data/datasets/test_sft.py index cccc03d42b..f0d19df8e1 100644 --- a/tests/functional_tests/data/datasets/test_sft.py +++ b/tests/functional_tests/data/datasets/test_sft.py @@ -107,13 +107,19 @@ def setup_and_teardown_parallel_state(self): ) assert parallel_state.model_parallel_is_initialized(), "Model parallel not initialized" + from megatron.core.process_groups_config import ProcessGroupCollection + from megatron.bridge.training.initialize import _set_random_seed + # Create pg_collection from initialized mpu + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + _set_random_seed( seed_=1234, data_parallel_random_init=False, te_rng_tracker=True, inference_rng_tracker=False, + pg_collection=pg_collection, ) yield diff --git a/tests/functional_tests/data/energon/test_base_energon_datamodule.py b/tests/functional_tests/data/energon/test_base_energon_datamodule.py index 6a59ff4684..1b907d1cfa 100644 --- a/tests/functional_tests/data/energon/test_base_energon_datamodule.py +++ b/tests/functional_tests/data/energon/test_base_energon_datamodule.py @@ -67,13 +67,19 @@ def setup_and_teardown_parallel_state(self): assert parallel_state.model_parallel_is_initialized(), "Model parallel not initialized" # Seed + from megatron.core.process_groups_config import ProcessGroupCollection + from megatron.bridge.training.initialize import _set_random_seed + # Create pg_collection from initialized mpu + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + _set_random_seed( seed_=1234, data_parallel_random_init=False, te_rng_tracker=True, inference_rng_tracker=False, + pg_collection=pg_collection, ) yield diff --git a/tests/functional_tests/training/test_decentralized_pg.py b/tests/functional_tests/training/test_decentralized_pg.py new file mode 100644 index 0000000000..a65b791626 --- /dev/null +++ b/tests/functional_tests/training/test_decentralized_pg.py @@ -0,0 +1,854 @@ +# Copyright 
(c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functional tests for the use_decentralized_pg feature. + +This feature enables using ProcessGroupCollection passed through functions instead +of relying on mcore's global parallel state (mpu) variables. When enabled, parallel +groups are obtained from the pg_collection object rather than the global +megatron.core.parallel_state module. +""" + +import os + +import pytest +import torch + +from megatron.bridge.models.llama import Llama32ModelProvider1B +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + DistributedDataParallelConfig, + DistributedInitConfig, + LoggerConfig, + MockGPTDatasetConfig, + OptimizerConfig, + RNGConfig, + SchedulerConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.gpt_step import forward_step +from megatron.bridge.training.initialize import destroy_global_state +from megatron.bridge.training.pretrain import pretrain +from tests.functional_tests.utils import ( + broadcast_path, + clear_directories, + initialize_distributed, + verify_checkpoint_files, +) + + +@pytest.fixture(autouse=True) +def cleanup_megatron_state(): + """Cleanup Megatron global state after each test. + + This fixture ensures that global state is cleaned up even if a test fails, + preventing state leakage between tests when running multiple tests in the + same pytest session. 
+ """ + yield + # Cleanup after test (runs even if test fails) + try: + destroy_global_state() + except Exception: + # Ignore errors during cleanup - state might not have been initialized + pass + + +class TestDecentralizedPgPretrain: + """ + Functional tests for pretraining with use_decentralized_pg enabled. + """ + + @pytest.mark.run_only_on("GPU") + def test_pretrain_with_decentralized_pg(self, tmp_path): + """ + Test end to end training with use_decentralized_pg=True. + + This test verifies that training works correctly when parallel groups + are passed through functions instead of using global mpu state. + """ + initialize_distributed() + shared_base_dir = broadcast_path(tmp_path) + + checkpoint_dir = os.path.join(shared_base_dir, "checkpoints") + tensorboard_dir = os.path.join(shared_base_dir, "tensorboard") + + if torch.distributed.get_rank() == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(tensorboard_dir, exist_ok=True) + + torch.distributed.barrier() + + try: + global_batch_size = 8 + micro_batch_size = 1 + seq_length = 512 + total_iters = 5 + + model_cfg = Llama32ModelProvider1B( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + sequence_parallel=False, + attention_softmax_in_fp32=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + seq_length=seq_length, + make_vocab_size_divisible_by=128, + vocab_size=None, + num_layers=1, + # Disable shared embeddings - not supported with decentralized PG + share_embeddings_and_output_weights=False, + ) + + # Config Container with use_decentralized_pg=True + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=total_iters, + eval_interval=5, + eval_iters=0, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + exit_signal_handler=True, + ), + optimizer=OptimizerConfig( + optimizer="adam", + bf16=True, + fp16=False, + adam_beta1=0.9, + adam_beta2=0.95, + adam_eps=1e-5, + use_distributed_optimizer=True, + 
clip_grad=1.0, + lr=3e-3, + weight_decay=0.01, + min_lr=1e-6, + ), + scheduler=SchedulerConfig( + start_weight_decay=0.033, + end_weight_decay=0.033, + weight_decay_incr_style="constant", + lr_decay_style="cosine", + lr_warmup_iters=1, + lr_warmup_init=0.0, + lr_decay_iters=total_iters, + override_opt_param_scheduler=True, + ), + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ), + dist=DistributedInitConfig( + use_decentralized_pg=True, # Enable the feature + use_gloo_process_groups=False, # Gloo not supported with custom pg_collection + ), + dataset=MockGPTDatasetConfig( + random_seed=1234, + reset_attention_mask=False, + reset_position_ids=False, + eod_mask_loss=False, + seq_length=seq_length, + num_dataset_builder_threads=1, + data_sharding=True, + dataloader_type="single", + num_workers=1, + ), + logger=LoggerConfig( + log_interval=5, + tensorboard_dir=tensorboard_dir, + ), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", + vocab_size=10000, + ), + checkpoint=CheckpointConfig( + save_interval=total_iters, + save=checkpoint_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + async_save=True, + ), + rng=RNGConfig(seed=1234), + ) + + # Run training + pretrain(cfg, forward_step) + + # Verify training completed + torch.distributed.barrier() + verify_checkpoint_files(checkpoint_dir, total_iters) + + finally: + clear_directories(tmp_path) + + @pytest.mark.run_only_on("GPU") + def test_pretrain_with_decentralized_pg_disabled(self, tmp_path): + """ + Test end to end training with use_decentralized_pg=False (default). + + This test verifies that training works correctly with the default + behavior using global mpu state. 
+ """ + initialize_distributed() + shared_base_dir = broadcast_path(tmp_path) + + checkpoint_dir = os.path.join(shared_base_dir, "checkpoints") + tensorboard_dir = os.path.join(shared_base_dir, "tensorboard") + + if torch.distributed.get_rank() == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(tensorboard_dir, exist_ok=True) + + torch.distributed.barrier() + + try: + global_batch_size = 8 + micro_batch_size = 1 + seq_length = 512 + total_iters = 5 + + model_cfg = Llama32ModelProvider1B( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + sequence_parallel=False, + attention_softmax_in_fp32=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + seq_length=seq_length, + make_vocab_size_divisible_by=128, + vocab_size=None, + num_layers=1, + ) + + # Config Container with use_decentralized_pg=False (default) + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=total_iters, + eval_interval=5, + eval_iters=0, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + exit_signal_handler=True, + ), + optimizer=OptimizerConfig( + optimizer="adam", + bf16=True, + fp16=False, + adam_beta1=0.9, + adam_beta2=0.95, + adam_eps=1e-5, + use_distributed_optimizer=True, + clip_grad=1.0, + lr=3e-3, + weight_decay=0.01, + min_lr=1e-6, + ), + scheduler=SchedulerConfig( + start_weight_decay=0.033, + end_weight_decay=0.033, + weight_decay_incr_style="constant", + lr_decay_style="cosine", + lr_warmup_iters=1, + lr_warmup_init=0.0, + lr_decay_iters=total_iters, + override_opt_param_scheduler=True, + ), + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ), + dist=DistributedInitConfig( + use_decentralized_pg=False, # Explicitly disable (default) + ), + dataset=MockGPTDatasetConfig( + random_seed=1234, + 
reset_attention_mask=False, + reset_position_ids=False, + eod_mask_loss=False, + seq_length=seq_length, + num_dataset_builder_threads=1, + data_sharding=True, + dataloader_type="single", + num_workers=1, + ), + logger=LoggerConfig( + log_interval=5, + tensorboard_dir=tensorboard_dir, + ), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", + vocab_size=10000, + ), + checkpoint=CheckpointConfig( + save_interval=total_iters, + save=checkpoint_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + async_save=True, + ), + rng=RNGConfig(seed=1234), + ) + + # Run training + pretrain(cfg, forward_step) + + # Verify training completed + torch.distributed.barrier() + verify_checkpoint_files(checkpoint_dir, total_iters) + + finally: + clear_directories(tmp_path) + + # + @pytest.mark.run_only_on("GPU") + def test_pretrain_with_decentralized_pg_and_pp(self, tmp_path): + """ + Test training with use_decentralized_pg=True and pipeline parallelism. + + This test verifies that the decentralized process groups feature works correctly + with pipeline parallelism enabled. 
+ """ + initialize_distributed() + shared_base_dir = broadcast_path(tmp_path) + + # Skip if world size is not at least 2 for PP + world_size = torch.distributed.get_world_size() + if world_size < 2: + pytest.skip("This test requires at least 2 GPUs for PP=2") + + checkpoint_dir = os.path.join(shared_base_dir, "checkpoints") + tensorboard_dir = os.path.join(shared_base_dir, "tensorboard") + + if torch.distributed.get_rank() == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(tensorboard_dir, exist_ok=True) + + torch.distributed.barrier() + + try: + global_batch_size = 8 + micro_batch_size = 1 + seq_length = 512 + total_iters = 5 + + model_cfg = Llama32ModelProvider1B( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=2, # Enable PP + context_parallel_size=1, + sequence_parallel=False, + attention_softmax_in_fp32=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + seq_length=seq_length, + make_vocab_size_divisible_by=128, + vocab_size=None, + num_layers=2, # Need at least 2 layers for PP=2 + # Disable shared embeddings - not supported with decentralized PG + share_embeddings_and_output_weights=False, + ) + + # Config Container with use_decentralized_pg=True + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=total_iters, + eval_interval=5, + eval_iters=0, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + exit_signal_handler=True, + ), + optimizer=OptimizerConfig( + optimizer="adam", + bf16=True, + fp16=False, + adam_beta1=0.9, + adam_beta2=0.95, + adam_eps=1e-5, + use_distributed_optimizer=True, + clip_grad=1.0, + lr=3e-3, + weight_decay=0.01, + min_lr=1e-6, + ), + scheduler=SchedulerConfig( + start_weight_decay=0.033, + end_weight_decay=0.033, + weight_decay_incr_style="constant", + lr_decay_style="cosine", + lr_warmup_iters=1, + lr_warmup_init=0.0, + lr_decay_iters=total_iters, + override_opt_param_scheduler=True, + ), + ddp=DistributedDataParallelConfig( + 
check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ), + dist=DistributedInitConfig( + use_decentralized_pg=True, # Enable the feature + use_gloo_process_groups=False, # Gloo not supported with custom pg_collection + ), + dataset=MockGPTDatasetConfig( + random_seed=1234, + reset_attention_mask=False, + reset_position_ids=False, + eod_mask_loss=False, + seq_length=seq_length, + num_dataset_builder_threads=1, + data_sharding=True, + dataloader_type="single", + num_workers=1, + ), + logger=LoggerConfig( + log_interval=5, + tensorboard_dir=tensorboard_dir, + ), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", + vocab_size=10000, + ), + checkpoint=CheckpointConfig( + save_interval=total_iters, + save=checkpoint_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + async_save=True, + ), + rng=RNGConfig(seed=1234), + ) + + # Run training + pretrain(cfg, forward_step) + + # Verify training completed + torch.distributed.barrier() + verify_checkpoint_files(checkpoint_dir, total_iters) + + finally: + clear_directories(tmp_path) + + @pytest.mark.run_only_on("GPU") + def test_pretrain_with_decentralized_pg_and_cp(self, tmp_path): + """ + Test training with use_decentralized_pg=True and context parallelism. + + This test verifies that the decentralized process groups feature works correctly + with context parallelism enabled. 
+ """ + initialize_distributed() + shared_base_dir = broadcast_path(tmp_path) + + # Skip if world size is not at least 2 for CP + world_size = torch.distributed.get_world_size() + if world_size < 2: + pytest.skip("This test requires at least 2 GPUs for CP=2") + + checkpoint_dir = os.path.join(shared_base_dir, "checkpoints") + tensorboard_dir = os.path.join(shared_base_dir, "tensorboard") + + if torch.distributed.get_rank() == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(tensorboard_dir, exist_ok=True) + + torch.distributed.barrier() + + try: + global_batch_size = 8 + micro_batch_size = 1 + seq_length = 512 + total_iters = 5 + + model_cfg = Llama32ModelProvider1B( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=2, # Enable CP + sequence_parallel=False, + attention_softmax_in_fp32=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + seq_length=seq_length, + make_vocab_size_divisible_by=128, + vocab_size=None, + num_layers=1, + # Disable shared embeddings - not supported with decentralized PG + share_embeddings_and_output_weights=False, + ) + + # Config Container with use_decentralized_pg=True + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=total_iters, + eval_interval=5, + eval_iters=0, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + exit_signal_handler=True, + ), + optimizer=OptimizerConfig( + optimizer="adam", + bf16=True, + fp16=False, + adam_beta1=0.9, + adam_beta2=0.95, + adam_eps=1e-5, + use_distributed_optimizer=True, + clip_grad=1.0, + lr=3e-3, + weight_decay=0.01, + min_lr=1e-6, + ), + scheduler=SchedulerConfig( + start_weight_decay=0.033, + end_weight_decay=0.033, + weight_decay_incr_style="constant", + lr_decay_style="cosine", + lr_warmup_iters=1, + lr_warmup_init=0.0, + lr_decay_iters=total_iters, + override_opt_param_scheduler=True, + ), + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + 
overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ), + dist=DistributedInitConfig( + use_decentralized_pg=True, # Enable the feature + use_gloo_process_groups=False, # Gloo not supported with custom pg_collection + ), + dataset=MockGPTDatasetConfig( + random_seed=1234, + reset_attention_mask=False, + reset_position_ids=False, + eod_mask_loss=False, + seq_length=seq_length, + num_dataset_builder_threads=1, + data_sharding=True, + dataloader_type="single", + num_workers=1, + ), + logger=LoggerConfig( + log_interval=5, + tensorboard_dir=tensorboard_dir, + ), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", + vocab_size=10000, + ), + checkpoint=CheckpointConfig( + save_interval=total_iters, + save=checkpoint_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + async_save=True, + ), + rng=RNGConfig(seed=1234), + ) + + # Run training + pretrain(cfg, forward_step) + + # Verify training completed + torch.distributed.barrier() + verify_checkpoint_files(checkpoint_dir, total_iters) + + finally: + clear_directories(tmp_path) + + @pytest.mark.run_only_on("GPU") + def test_pretrain_with_decentralized_pg_combined_parallelism(self, tmp_path): + """ + Test training with use_decentralized_pg=True and combined TP+PP. + + This test verifies that the decentralized process groups feature works correctly + with multiple forms of parallelism enabled simultaneously. 
+ """ + initialize_distributed() + shared_base_dir = broadcast_path(tmp_path) + + # Skip if world size is not at least 4 for TP=2, PP=2 + world_size = torch.distributed.get_world_size() + if world_size < 4: + pytest.skip("This test requires at least 4 GPUs for TP=2, PP=2") + + checkpoint_dir = os.path.join(shared_base_dir, "checkpoints") + tensorboard_dir = os.path.join(shared_base_dir, "tensorboard") + + if torch.distributed.get_rank() == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(tensorboard_dir, exist_ok=True) + + torch.distributed.barrier() + + try: + global_batch_size = 8 + micro_batch_size = 1 + seq_length = 512 + total_iters = 5 + + model_cfg = Llama32ModelProvider1B( + tensor_model_parallel_size=2, # Enable TP + pipeline_model_parallel_size=2, # Enable PP + context_parallel_size=1, + sequence_parallel=True, # Usually used with TP + attention_softmax_in_fp32=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + seq_length=seq_length, + make_vocab_size_divisible_by=128, + vocab_size=None, + num_layers=2, # Need at least 2 layers for PP=2 + # Disable shared embeddings - not supported with decentralized PG + share_embeddings_and_output_weights=False, + ) + + # Config Container with use_decentralized_pg=True + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=total_iters, + eval_interval=5, + eval_iters=0, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + exit_signal_handler=True, + ), + optimizer=OptimizerConfig( + optimizer="adam", + bf16=True, + fp16=False, + adam_beta1=0.9, + adam_beta2=0.95, + adam_eps=1e-5, + use_distributed_optimizer=True, + clip_grad=1.0, + lr=3e-3, + weight_decay=0.01, + min_lr=1e-6, + ), + scheduler=SchedulerConfig( + start_weight_decay=0.033, + end_weight_decay=0.033, + weight_decay_incr_style="constant", + lr_decay_style="cosine", + lr_warmup_iters=1, + lr_warmup_init=0.0, + lr_decay_iters=total_iters, + override_opt_param_scheduler=True, + ), + 
ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ), + dist=DistributedInitConfig( + use_decentralized_pg=True, # Enable the feature + use_gloo_process_groups=False, # Gloo not supported with custom pg_collection + ), + dataset=MockGPTDatasetConfig( + random_seed=1234, + reset_attention_mask=False, + reset_position_ids=False, + eod_mask_loss=False, + seq_length=seq_length, + num_dataset_builder_threads=1, + data_sharding=True, + dataloader_type="single", + num_workers=1, + ), + logger=LoggerConfig( + log_interval=5, + tensorboard_dir=tensorboard_dir, + ), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", + vocab_size=10000, + ), + checkpoint=CheckpointConfig( + save_interval=total_iters, + save=checkpoint_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + async_save=True, + ), + rng=RNGConfig(seed=1234), + ) + + # Run training + pretrain(cfg, forward_step) + + # Verify training completed + torch.distributed.barrier() + verify_checkpoint_files(checkpoint_dir, total_iters) + + finally: + clear_directories(tmp_path) + + @pytest.mark.run_only_on("GPU") + def test_pretrain_with_decentralized_pg_and_tp(self, tmp_path): + """ + Test training with use_decentralized_pg=True and tensor parallelism. + + This test verifies that the decentralized process groups feature works correctly + with tensor parallelism enabled. 
+ """ + initialize_distributed() + shared_base_dir = broadcast_path(tmp_path) + + # Skip if world size is not at least 2 for TP + world_size = torch.distributed.get_world_size() + if world_size < 2: + pytest.skip("This test requires at least 2 GPUs for TP=2") + + checkpoint_dir = os.path.join(shared_base_dir, "checkpoints") + tensorboard_dir = os.path.join(shared_base_dir, "tensorboard") + + if torch.distributed.get_rank() == 0: + os.makedirs(checkpoint_dir, exist_ok=True) + os.makedirs(tensorboard_dir, exist_ok=True) + + torch.distributed.barrier() + + try: + global_batch_size = 8 + micro_batch_size = 1 + seq_length = 512 + total_iters = 5 + + model_cfg = Llama32ModelProvider1B( + tensor_model_parallel_size=2, # Enable TP + pipeline_model_parallel_size=1, + context_parallel_size=1, + sequence_parallel=True, # Usually used with TP + attention_softmax_in_fp32=True, + pipeline_dtype=torch.bfloat16, + bf16=True, + seq_length=seq_length, + make_vocab_size_divisible_by=128, + vocab_size=None, + num_layers=1, + # Disable shared embeddings - not supported with decentralized PG + share_embeddings_and_output_weights=False, + ) + + # Config Container with use_decentralized_pg=True + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=total_iters, + eval_interval=5, + eval_iters=0, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + exit_signal_handler=True, + ), + optimizer=OptimizerConfig( + optimizer="adam", + bf16=True, + fp16=False, + adam_beta1=0.9, + adam_beta2=0.95, + adam_eps=1e-5, + use_distributed_optimizer=True, + clip_grad=1.0, + lr=3e-3, + weight_decay=0.01, + min_lr=1e-6, + ), + scheduler=SchedulerConfig( + start_weight_decay=0.033, + end_weight_decay=0.033, + weight_decay_incr_style="constant", + lr_decay_style="cosine", + lr_warmup_iters=1, + lr_warmup_init=0.0, + lr_decay_iters=total_iters, + override_opt_param_scheduler=True, + ), + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + 
grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ), + dist=DistributedInitConfig( + use_decentralized_pg=True, # Enable the feature + use_gloo_process_groups=False, # Gloo not supported with custom pg_collection + ), + dataset=MockGPTDatasetConfig( + random_seed=1234, + reset_attention_mask=False, + reset_position_ids=False, + eod_mask_loss=False, + seq_length=seq_length, + num_dataset_builder_threads=1, + data_sharding=True, + dataloader_type="single", + num_workers=1, + ), + logger=LoggerConfig( + log_interval=5, + tensorboard_dir=tensorboard_dir, + ), + tokenizer=TokenizerConfig( + tokenizer_type="NullTokenizer", + vocab_size=10000, + ), + checkpoint=CheckpointConfig( + save_interval=total_iters, + save=checkpoint_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + async_save=True, + ), + rng=RNGConfig(seed=1234), + ) + + # Run training + pretrain(cfg, forward_step) + + # Verify training completed + torch.distributed.barrier() + verify_checkpoint_files(checkpoint_dir, total_iters) + + finally: + clear_directories(tmp_path) diff --git a/tests/unit_tests/data/datasets/test_sft.py b/tests/unit_tests/data/datasets/test_sft.py index e9db70f8c9..c819a7668a 100755 --- a/tests/unit_tests/data/datasets/test_sft.py +++ b/tests/unit_tests/data/datasets/test_sft.py @@ -135,14 +135,19 @@ def setup_and_teardown_parallel_state(self): ) assert parallel_state.model_parallel_is_initialized(), "Model parallel not initialized" + from megatron.core.process_groups_config import ProcessGroupCollection from megatron.bridge.training.initialize import _set_random_seed + # Create pg_collection from initialized mpu + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + _set_random_seed( seed_=1234, data_parallel_random_init=False, te_rng_tracker=True, inference_rng_tracker=False, + pg_collection=pg_collection, ) yield diff --git 
a/tests/unit_tests/peft/test_canonical_lora.py b/tests/unit_tests/peft/test_canonical_lora.py index 21279cfa3a..5b0ec0f2fc 100644 --- a/tests/unit_tests/peft/test_canonical_lora.py +++ b/tests/unit_tests/peft/test_canonical_lora.py @@ -786,13 +786,19 @@ def setup_and_teardown_parallel_state(self): ) assert parallel_state.model_parallel_is_initialized(), "Model parallel not initialized" + from megatron.core.process_groups_config import ProcessGroupCollection + from megatron.bridge.training.initialize import _set_random_seed + # Create pg_collection from initialized mpu + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + _set_random_seed( seed_=1234, data_parallel_random_init=False, te_rng_tracker=True, inference_rng_tracker=False, + pg_collection=pg_collection, ) yield diff --git a/tests/unit_tests/peft/test_dora.py b/tests/unit_tests/peft/test_dora.py index 5e8dfe87b9..4faf5a8348 100644 --- a/tests/unit_tests/peft/test_dora.py +++ b/tests/unit_tests/peft/test_dora.py @@ -543,13 +543,19 @@ def setup_and_teardown_parallel_state(self): ) assert parallel_state.model_parallel_is_initialized(), "Model parallel not initialized" + from megatron.core.process_groups_config import ProcessGroupCollection + from megatron.bridge.training.initialize import _set_random_seed + # Create pg_collection from initialized mpu + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + _set_random_seed( seed_=1234, data_parallel_random_init=False, te_rng_tracker=True, inference_rng_tracker=False, + pg_collection=pg_collection, ) yield diff --git a/tests/unit_tests/peft/test_lora.py b/tests/unit_tests/peft/test_lora.py index 5524d76350..6432b3b33c 100644 --- a/tests/unit_tests/peft/test_lora.py +++ b/tests/unit_tests/peft/test_lora.py @@ -642,13 +642,19 @@ def setup_and_teardown_parallel_state(self): ) assert parallel_state.model_parallel_is_initialized(), "Model parallel not initialized" + from megatron.core.process_groups_config import ProcessGroupCollection + 
from megatron.bridge.training.initialize import _set_random_seed + # Create pg_collection from initialized mpu + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + _set_random_seed( seed_=1234, data_parallel_random_init=False, te_rng_tracker=True, inference_rng_tracker=False, + pg_collection=pg_collection, ) yield diff --git a/tests/unit_tests/peft/test_lora_layers.py b/tests/unit_tests/peft/test_lora_layers.py index 576a4445b7..a426905861 100644 --- a/tests/unit_tests/peft/test_lora_layers.py +++ b/tests/unit_tests/peft/test_lora_layers.py @@ -377,13 +377,19 @@ def setup_and_teardown_parallel_state(self): assert parallel_state.model_parallel_is_initialized(), "Model parallel not initialized" + from megatron.core.process_groups_config import ProcessGroupCollection + from megatron.bridge.training.initialize import _set_random_seed + # Create pg_collection from initialized mpu + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + _set_random_seed( seed_=1234, data_parallel_random_init=False, te_rng_tracker=True, inference_rng_tracker=False, + pg_collection=pg_collection, ) yield diff --git a/tests/unit_tests/training/test_checkpointing.py b/tests/unit_tests/training/test_checkpointing.py index 00fd30c9ed..085f27fe19 100644 --- a/tests/unit_tests/training/test_checkpointing.py +++ b/tests/unit_tests/training/test_checkpointing.py @@ -249,14 +249,13 @@ def test_checkpoint_type_enum(self): class TestRNGState: """Test RNG state collection.""" - @patch("megatron.bridge.training.checkpointing.mpu") @patch("megatron.bridge.training.checkpointing.tensor_parallel") @patch("torch.distributed.is_initialized") @patch("torch.cuda.get_rng_state") @patch("torch.get_rng_state") @patch("numpy.random.get_state") @patch("random.getstate") - def test_get_rng_state(self, mock_random, mock_np, mock_torch, mock_cuda, mock_dist_init, mock_tp, mock_mpu): + def test_get_rng_state(self, mock_random, mock_np, mock_torch, mock_cuda, mock_dist_init, mock_tp): """Test 
RNG state collection.""" # Setup mocks mock_dist_init.return_value = False @@ -268,13 +267,18 @@ def test_get_rng_state(self, mock_random, mock_np, mock_torch, mock_cuda, mock_d mock_tracker.get_states.return_value = "tracker_states" mock_tp.get_cuda_rng_tracker.return_value = mock_tracker - mock_mpu.get_pipeline_model_parallel_rank.return_value = 0 - mock_mpu.get_pipeline_model_parallel_world_size.return_value = 1 - mock_mpu.get_tensor_model_parallel_rank.return_value = 0 - mock_mpu.get_tensor_model_parallel_world_size.return_value = 1 - mock_mpu.get_data_parallel_rank.return_value = 0 - - result = get_rng_state(data_parallel_random_init=False) + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.pp.rank.return_value = 0 + mock_pg_collection.pp.size.return_value = 1 + mock_pg_collection.tp.rank.return_value = 0 + mock_pg_collection.tp.size.return_value = 1 + mock_pg_collection.dp_cp.rank.return_value = 0 + mock_pg_collection.dp_cp.size.return_value = 1 + + result = get_rng_state( + data_parallel_random_init=False, ckpt_format="torch_dist", pg_collection=mock_pg_collection + ) # Verify the result is a ShardedObject assert result.key == "rng_state" @@ -347,6 +351,8 @@ def save_checkpoint_fixtures(): mock_cfg.to_yaml = Mock() # Mock config YAML export mock_cfg.logger = Mock() mock_cfg.logger.log_progress = False + mock_cfg.dist = Mock() + mock_cfg.dist.use_decentralized_pg = False mock_state.cfg = mock_cfg @@ -377,7 +383,7 @@ class TestSaveCheckpoint: @patch("megatron.bridge.training.checkpointing.get_rerun_state_machine") @patch("megatron.bridge.training.checkpointing.generate_state_dict") @patch("megatron.bridge.training.checkpointing.dist_checkpointing") - @patch("megatron.bridge.training.checkpointing.mpu") + @patch("megatron.bridge.training.checkpointing.get_pg_collection") @patch("megatron.bridge.training.checkpointing.fault_tolerance") @patch("megatron.bridge.training.checkpointing.is_empty_async_queue") 
@patch("megatron.bridge.training.checkpointing.get_rank_safe") @@ -400,7 +406,7 @@ def test_save_checkpoint_global( mock_get_rank_safe, mock_empty_queue, mock_ft, - mock_mpu, + mock_get_pg_collection, mock_dist_ckpt, mock_gen_state, mock_rerun, @@ -424,11 +430,16 @@ def test_save_checkpoint_global( mock_get_rng.return_value = Mock() mock_rerun.return_value.state_dict.return_value = {} mock_gen_state.return_value = {"model": {"param1": "value1", "param2": "value2"}} - mock_mpu.get_expert_data_parallel_rank.return_value = 0 - mock_mpu.get_tensor_model_parallel_rank.return_value = 0 - mock_mpu.get_tensor_model_parallel_world_size.return_value = 1 - mock_mpu.get_pipeline_model_parallel_rank.return_value = 0 - mock_mpu.get_pipeline_model_parallel_world_size.return_value = 1 + + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.expt_dp.rank.return_value = 0 + mock_pg_collection.tp.rank.return_value = 0 + mock_pg_collection.tp.size.return_value = 1 + mock_pg_collection.pp.rank.return_value = 0 + mock_pg_collection.pp.size.return_value = 1 + mock_get_pg_collection.return_value = mock_pg_collection + mock_get_strategy.return_value = Mock() mock_dist_ckpt.save.return_value = None # Synchronous save mock_save_modelopt.return_value = None # Mock ModelOpt save @@ -521,6 +532,8 @@ def load_checkpoint_fixtures(): mock_cfg.optimizer.use_distributed_optimizer = False mock_cfg.checkpoint.ckpt_format = "torch_dist" mock_cfg.checkpoint.non_persistent_save_interval = None + mock_cfg.dist = Mock() + mock_cfg.dist.use_decentralized_pg = False mock_state.cfg = mock_cfg @@ -588,11 +601,12 @@ def test_load_checkpoint_not_found( @patch("megatron.bridge.training.checkpointing.wandb_utils") @patch("megatron.bridge.training.checkpointing.is_last_rank") @patch("megatron.bridge.training.checkpointing.print_rank_0") - @patch("megatron.bridge.training.checkpointing.mpu") + @patch("megatron.bridge.training.checkpointing.get_pg_collection") 
@patch("megatron.bridge.training.checkpointing.get_rerun_state_machine") @patch("megatron.bridge.training.checkpointing.tensor_parallel") @patch("megatron.bridge.training.checkpointing.generate_state_dict") @patch("megatron.bridge.training.checkpointing.get_rng_state") + @patch("megatron.bridge.training.checkpointing.dist_checkpointing") @patch("random.setstate") @patch("numpy.random.set_state") @patch("torch.set_rng_state") @@ -611,11 +625,12 @@ def test_load_checkpoint_found( mock_torch_set_rng, mock_np_set_state, mock_random_setstate, + mock_dist_ckpt, mock_get_rng_state, mock_generate_state_dict, mock_tensor_parallel, mock_rerun_machine, - mock_mpu, + mock_get_pg_collection, mock_print_rank_0, mock_is_last_rank, mock_wandb, @@ -658,12 +673,19 @@ def test_load_checkpoint_found( mock_rng_tracker.set_states = Mock() mock_tensor_parallel.get_cuda_rng_tracker.return_value = mock_rng_tracker - # Mock MPU functions - mock_mpu.get_tensor_model_parallel_rank.return_value = 0 - mock_mpu.get_tensor_model_parallel_world_size.return_value = 1 - mock_mpu.get_pipeline_model_parallel_rank.return_value = 0 - mock_mpu.get_pipeline_model_parallel_world_size.return_value = 1 - mock_mpu.get_data_parallel_rank.return_value = 0 + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.tp.rank.return_value = 0 + mock_pg_collection.tp.size.return_value = 1 + mock_pg_collection.pp.rank.return_value = 0 + mock_pg_collection.pp.size.return_value = 1 + mock_pg_collection.dp.rank.return_value = 0 + mock_pg_collection.dp_cp.rank.return_value = 0 + mock_get_pg_collection.return_value = mock_pg_collection + + # Mock dist_checkpointing + mock_dist_ckpt.load_content_metadata.return_value = {} + mock_dist_ckpt.load.return_value = {} # Mock rerun state machine mock_rerun_machine.return_value.load_state_dict = Mock() @@ -897,17 +919,31 @@ def base_config(self): mock_cfg = Mock(spec=CheckpointConfig) mock_cfg.exit_on_missing_checkpoint = False mock_cfg.ckpt_step = None + 
mock_cfg.non_persistent_ckpt_type = None return mock_cfg + @pytest.fixture + def mock_pg_collection(self): + """Fixture for mock pg_collection.""" + mock_pg = Mock() + mock_pg.dp_cp.rank.return_value = 0 + mock_pg.dp_cp.size.return_value = 1 + mock_pg.pp.rank.return_value = 0 + mock_pg.pp.size.return_value = 1 + mock_pg.tp.rank.return_value = 0 + mock_pg.tp.size.return_value = 1 + return mock_pg + @patch("megatron.bridge.training.checkpointing._get_non_persistent_iteration") - @patch("megatron.bridge.training.checkpointing.read_train_state") - @patch("os.path.isfile") - def test_load_base_checkpoint_no_checkpoint(self, mock_isfile, mock_read_state, mock_get_np_iter, base_config): + @patch("megatron.bridge.training.checkpointing.file_exists") + def test_load_base_checkpoint_no_checkpoint( + self, mock_file_exists, mock_get_np_iter, base_config, mock_pg_collection + ): """Test when no checkpoint is found.""" mock_get_np_iter.return_value = -1 - mock_isfile.return_value = False + mock_file_exists.return_value = False - result = _load_base_checkpoint("/fake/dir", base_config) + result = _load_base_checkpoint("/fake/dir", base_config, pg_collection=mock_pg_collection) assert result == (None, "", False, None) @@ -917,7 +953,14 @@ def test_load_base_checkpoint_no_checkpoint(self, mock_isfile, mock_read_state, @patch("megatron.bridge.training.checkpointing.file_exists") @patch("os.path.exists") def test_load_base_checkpoint_non_distributed_error( - self, mock_os_exists, mock_file_exists, mock_dist_ckpt, mock_read_state, mock_get_np_iter, base_config + self, + mock_os_exists, + mock_file_exists, + mock_dist_ckpt, + mock_read_state, + mock_get_np_iter, + base_config, + mock_pg_collection, ): """Test error when trying to load non-distributed checkpoint.""" mock_get_np_iter.return_value = -1 @@ -931,8 +974,8 @@ def test_load_base_checkpoint_non_distributed_error( # Mock that .metadata file does NOT exist (so it's not fsdp_dtensor) mock_os_exists.return_value = False - with 
pytest.raises(RuntimeError) as exc_info: - _load_base_checkpoint("/fake/dir", base_config) + with pytest.raises(NotImplementedError) as exc_info: + _load_base_checkpoint("/fake/dir", base_config, pg_collection=mock_pg_collection) assert "Unknown checkpoint format" in str(exc_info.value) @@ -990,10 +1033,10 @@ def mock_metadata(self): @patch("megatron.bridge.training.checkpointing._load_model_state_dict") @patch("megatron.bridge.training.checkpointing.get_default_load_sharded_strategy") @patch("megatron.bridge.training.checkpointing.FullyParallelLoadStrategyWrapper") - @patch("megatron.bridge.training.checkpointing.mpu") + @patch("megatron.bridge.training.checkpointing.get_pg_collection") def test_load_model_weights_single_model_success( self, - mock_mpu, + mock_get_pg_collection, mock_fully_parallel_wrapper, mock_get_strategy, mock_load_state_dict, @@ -1014,6 +1057,11 @@ def test_load_model_weights_single_model_success( mock_generate_state_dict.return_value = {"model": {"weight": torch.randn(10, 10)}} mock_unwrap_model.return_value = mock_model + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.dp_cp = Mock() + mock_get_pg_collection.return_value = mock_pg_collection + # Call the function from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint @@ -1078,10 +1126,10 @@ def test_load_model_weights_calls_delete_extra_state( @patch("megatron.bridge.training.checkpointing._load_model_state_dict") @patch("megatron.bridge.training.checkpointing.get_default_load_sharded_strategy") @patch("megatron.bridge.training.checkpointing.FullyParallelLoadStrategyWrapper") - @patch("megatron.bridge.training.checkpointing.mpu") + @patch("megatron.bridge.training.checkpointing.get_pg_collection") def test_load_model_weights_multiple_models_success( self, - mock_mpu, + mock_get_pg_collection, mock_fully_parallel_wrapper, mock_get_strategy, mock_load_state_dict, @@ -1105,6 +1153,11 @@ def 
test_load_model_weights_multiple_models_success( } mock_unwrap_model.return_value = mock_multiple_models + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.dp_cp = Mock() + mock_get_pg_collection.return_value = mock_pg_collection + # Call the function from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint @@ -1134,10 +1187,10 @@ def test_load_model_weights_multiple_models_success( @patch("megatron.bridge.training.checkpointing._load_model_state_dict") @patch("megatron.bridge.training.checkpointing.get_default_load_sharded_strategy") @patch("megatron.bridge.training.checkpointing.FullyParallelLoadStrategyWrapper") - @patch("megatron.bridge.training.checkpointing.mpu") + @patch("megatron.bridge.training.checkpointing.get_pg_collection") def test_load_model_weights_fully_parallel_load( self, - mock_mpu, + mock_get_pg_collection, mock_fully_parallel_wrapper, mock_get_strategy, mock_load_state_dict, @@ -1157,7 +1210,12 @@ def test_load_model_weights_fully_parallel_load( mock_fully_parallel_wrapper.return_value = Mock() mock_generate_state_dict.return_value = {"model": {"weight": torch.randn(10, 10)}} mock_unwrap_model.return_value = mock_model - mock_mpu.get_data_parallel_group.return_value = Mock() + + # Create mock pg_collection + mock_pg_collection = Mock() + mock_dp_cp_group = Mock() + mock_pg_collection.dp_cp = mock_dp_cp_group + mock_get_pg_collection.return_value = mock_pg_collection # Call the function from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint @@ -1170,11 +1228,8 @@ def test_load_model_weights_fully_parallel_load( strict=True, ) - # Verify fully parallel wrapper was used - mock_fully_parallel_wrapper.assert_called_once_with( - mock_strategy, mock_mpu.get_data_parallel_group.return_value - ) - mock_mpu.get_data_parallel_group.assert_called_once_with(with_context_parallel=True) + # Verify fully parallel wrapper was used with pg_collection.dp_cp + 
mock_fully_parallel_wrapper.assert_called_once_with(mock_strategy, mock_dp_cp_group) @patch("megatron.bridge.training.checkpointing.dist_checkpointing") @patch("megatron.bridge.training.checkpointing.unwrap_model") @@ -1182,10 +1237,10 @@ def test_load_model_weights_fully_parallel_load( @patch("megatron.bridge.training.checkpointing._load_model_state_dict") @patch("megatron.bridge.training.checkpointing.get_default_load_sharded_strategy") @patch("megatron.bridge.training.checkpointing.FullyParallelLoadStrategyWrapper") - @patch("megatron.bridge.training.checkpointing.mpu") + @patch("megatron.bridge.training.checkpointing.get_pg_collection") def test_load_model_weights_none_state_dict( self, - mock_mpu, + mock_get_pg_collection, mock_fully_parallel_wrapper, mock_get_strategy, mock_load_state_dict, @@ -1203,6 +1258,11 @@ def test_load_model_weights_none_state_dict( mock_generate_state_dict.return_value = {"model": {"weight": torch.randn(10, 10)}} mock_unwrap_model.return_value = mock_model + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.dp_cp = Mock() + mock_get_pg_collection.return_value = mock_pg_collection + # Call the function and expect assertion error from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint @@ -1221,10 +1281,10 @@ def test_load_model_weights_none_state_dict( @patch("megatron.bridge.training.checkpointing._load_model_state_dict") @patch("megatron.bridge.training.checkpointing.get_default_load_sharded_strategy") @patch("megatron.bridge.training.checkpointing.FullyParallelLoadStrategyWrapper") - @patch("megatron.bridge.training.checkpointing.mpu") + @patch("megatron.bridge.training.checkpointing.get_pg_collection") def test_return_state_dict( self, - mock_mpu, + mock_get_pg_collection, mock_fully_parallel_wrapper, mock_get_strategy, mock_load_state_dict, @@ -1245,6 +1305,11 @@ def test_return_state_dict( mock_generate_state_dict.return_value = {"model": {"weight": torch.randn(10, 10)}} 
mock_unwrap_model.return_value = mock_model + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.dp_cp = Mock() + mock_get_pg_collection.return_value = mock_pg_collection + # Call the function from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint @@ -1367,23 +1432,28 @@ def test_extract_megatron_lm_args_from_state_dict_missing_args(self): assert "Legacy checkpoint missing 'args' field" in str(exc_info.value) @patch("megatron.bridge.training.checkpointing.read_metadata") - @patch("os.path.exists") - def test_load_base_checkpoint_legacy_tracker(self, mock_isfile, mock_read_metadata): + @patch("megatron.bridge.training.checkpointing.file_exists") + def test_load_base_checkpoint_legacy_tracker(self, mock_file_exists, mock_read_metadata): """Test loading checkpoint with legacy Megatron-LM tracker file.""" mock_cfg = Mock(spec=CheckpointConfig) mock_cfg.non_persistent_ckpt_type = None mock_cfg.exit_on_missing_checkpoint = False mock_cfg.ckpt_step = None + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.dp_cp.rank.return_value = 0 + mock_pg_collection.dp_cp.size.return_value = 1 + # Mock file existence: NeMo-LM tracker doesn't exist, legacy tracker does - def mock_isfile_side_effect(path): + def mock_file_exists_side_effect(path): if "latest_train_state.pt" in path: return False elif "latest_checkpointed_iteration.txt" in path: return True return False - mock_isfile.side_effect = mock_isfile_side_effect + mock_file_exists.side_effect = mock_file_exists_side_effect mock_read_metadata.return_value = (1000, False) with patch("megatron.bridge.training.checkpointing._get_non_persistent_iteration", return_value=-1): @@ -1392,7 +1462,7 @@ def mock_isfile_side_effect(path): with patch("megatron.bridge.training.checkpointing._load_global_dist_base_checkpoint") as mock_load: mock_load.return_value = ({"test": "data"}, "/ckpt/path", False, CheckpointType.GLOBAL) - result = 
_load_base_checkpoint("/test/dir", mock_cfg, rank0=True) + result = _load_base_checkpoint("/test/dir", mock_cfg, rank0=True, pg_collection=mock_pg_collection) state_dict, checkpoint_name, release, ckpt_type = result assert state_dict == {"test": "data"} @@ -1440,10 +1510,11 @@ def test_load_checkpoint_legacy_config_extraction(self, mock_exists, mock_read_c @patch("megatron.bridge.training.checkpointing.wandb_utils") @patch("megatron.bridge.training.checkpointing.is_last_rank") @patch("megatron.bridge.training.checkpointing.print_rank_0") - @patch("megatron.bridge.training.checkpointing.mpu") + @patch("megatron.bridge.training.checkpointing.get_pg_collection") @patch("megatron.bridge.training.checkpointing.get_rerun_state_machine") @patch("megatron.bridge.training.checkpointing.generate_state_dict") @patch("megatron.bridge.training.checkpointing.get_rng_state") + @patch("megatron.bridge.training.checkpointing.dist_checkpointing") @patch("torch.distributed.is_initialized") @patch("torch.distributed.barrier") @patch("torch.cuda.empty_cache") @@ -1454,10 +1525,11 @@ def test_load_checkpoint_full_legacy_integration( mock_empty_cache, mock_barrier, mock_dist_init, + mock_dist_ckpt, mock_get_rng_state, mock_generate_state_dict, mock_rerun_machine, - mock_mpu, + mock_get_pg_collection, mock_print_rank_0, mock_is_last_rank, mock_wandb, @@ -1512,10 +1584,20 @@ def mock_exists_side_effect(path): # Mock other required functions mock_generate_state_dict.return_value = {"test": "state"} mock_get_rng_state.return_value = Mock() - mock_mpu.get_tensor_model_parallel_rank.return_value = 0 - mock_mpu.get_tensor_model_parallel_world_size.return_value = 2 - mock_mpu.get_pipeline_model_parallel_rank.return_value = 0 - mock_mpu.get_pipeline_model_parallel_world_size.return_value = 1 + + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.tp.rank.return_value = 0 + mock_pg_collection.tp.size.return_value = 2 + mock_pg_collection.pp.rank.return_value = 0 + 
mock_pg_collection.pp.size.return_value = 1 + mock_pg_collection.dp_cp.rank.return_value = 0 + mock_get_pg_collection.return_value = mock_pg_collection + + # Mock dist_checkpointing + mock_dist_ckpt.load_content_metadata.return_value = {} + mock_dist_ckpt.load.return_value = {} + mock_rerun_machine.return_value.load_state_dict = Mock() # Create test fixtures @@ -1544,6 +1626,8 @@ def mock_exists_side_effect(path): mock_cfg.optimizer = Mock() mock_cfg.optimizer.use_distributed_optimizer = False mock_cfg.peft = None # No PEFT for this test + mock_cfg.dist = Mock() + mock_cfg.dist.use_decentralized_pg = False mock_state.cfg = mock_cfg @@ -2005,66 +2089,66 @@ def test_get_checkpoint_format_unknown(self, mock_exists, mock_check_dist_ckpt): with pytest.raises(NotImplementedError, match="Unknown checkpoint format"): _get_checkpoint_format("/path/to/checkpoint") - @patch("megatron.core.mpu.get_pipeline_model_parallel_rank") - @patch("megatron.core.mpu.get_tensor_model_parallel_rank") - @patch("megatron.core.mpu.get_data_parallel_world_size") + @patch("megatron.bridge.training.checkpointing.tensor_parallel") @patch("torch.distributed.is_initialized") - def test_get_rng_state_fsdp_dtensor_format(self, mock_dist_init, mock_dp_world_size, mock_tp_rank, mock_pp_rank): + def test_get_rng_state_fsdp_dtensor_format(self, mock_dist_init, mock_tp): """Test get_rng_state returns correct format for fsdp_dtensor.""" mock_dist_init.return_value = False # Simplify - mock_dp_world_size.return_value = 1 - mock_tp_rank.return_value = 0 - mock_pp_rank.return_value = 0 + mock_tracker = Mock() + mock_tracker.get_states.return_value = "tracker_states" + mock_tp.get_cuda_rng_tracker.return_value = mock_tracker + + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.dp_cp.rank.return_value = 0 + mock_pg_collection.dp_cp.size.return_value = 1 + mock_pg_collection.tp.rank.return_value = 0 + mock_pg_collection.tp.size.return_value = 1 + 
mock_pg_collection.pp.rank.return_value = 0 + mock_pg_collection.pp.size.return_value = 1 with ( patch("random.getstate"), patch("numpy.random.get_state"), patch("torch.get_rng_state"), patch("torch.cuda.get_rng_state"), - patch("megatron.core.tensor_parallel.get_cuda_rng_tracker"), ): - result = get_rng_state(data_parallel_random_init=False, ckpt_format="fsdp_dtensor") + result = get_rng_state( + data_parallel_random_init=False, ckpt_format="fsdp_dtensor", pg_collection=mock_pg_collection + ) # Should return dict format for fsdp_dtensor assert isinstance(result, dict) assert "(0, 0)" in result - @patch("megatron.core.mpu.get_pipeline_model_parallel_rank") - @patch("megatron.core.mpu.get_pipeline_model_parallel_world_size") - @patch("megatron.core.mpu.get_tensor_model_parallel_rank") - @patch("megatron.core.mpu.get_tensor_model_parallel_world_size") - @patch("megatron.core.mpu.get_data_parallel_rank") - @patch("megatron.core.mpu.get_data_parallel_world_size") + @patch("megatron.bridge.training.checkpointing.tensor_parallel") @patch("torch.distributed.is_initialized") - def test_get_rng_state_torch_dist_format( - self, - mock_dist_init, - mock_dp_world_size, - mock_dp_rank, - mock_tp_world_size, - mock_tp_rank, - mock_pp_world_size, - mock_pp_rank, - ): + def test_get_rng_state_torch_dist_format(self, mock_dist_init, mock_tp): """Test get_rng_state returns ShardedObject for torch_dist.""" - # The ShardedObject is only created for torch_dist format regardless of distributed state - mock_dp_world_size.return_value = 1 - mock_dp_rank.return_value = 0 - mock_tp_rank.return_value = 0 - mock_tp_world_size.return_value = 1 - mock_pp_rank.return_value = 0 - mock_pp_world_size.return_value = 1 mock_dist_init.return_value = False + mock_tracker = Mock() + mock_tracker.get_states.return_value = "tracker_states" + mock_tp.get_cuda_rng_tracker.return_value = mock_tracker + + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.dp_cp.rank.return_value = 
0 + mock_pg_collection.dp_cp.size.return_value = 1 + mock_pg_collection.tp.rank.return_value = 0 + mock_pg_collection.tp.size.return_value = 1 + mock_pg_collection.pp.rank.return_value = 0 + mock_pg_collection.pp.size.return_value = 1 with ( patch("random.getstate"), patch("numpy.random.get_state"), patch("torch.get_rng_state"), patch("torch.cuda.get_rng_state"), - patch("megatron.core.tensor_parallel.get_cuda_rng_tracker"), patch("megatron.bridge.training.checkpointing.ShardedObject") as mock_sharded_obj, ): - _ = get_rng_state(data_parallel_random_init=False, ckpt_format="torch_dist") + _ = get_rng_state( + data_parallel_random_init=False, ckpt_format="torch_dist", pg_collection=mock_pg_collection + ) # Should create ShardedObject for torch_dist format # The exact arguments depend on the RNG state, but we just verify it was called diff --git a/tests/unit_tests/training/test_decentralized_pg.py b/tests/unit_tests/training/test_decentralized_pg.py new file mode 100644 index 0000000000..1b48d360f2 --- /dev/null +++ b/tests/unit_tests/training/test_decentralized_pg.py @@ -0,0 +1,745 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for the use_decentralized_pg feature. + +This feature enables using ProcessGroupCollection passed through functions instead +of relying on mcore's global parallel state (mpu) variables. 
When enabled, parallel +groups are obtained from the pg_collection object rather than the global +megatron.core.parallel_state module. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from megatron.bridge.training.config import DistributedInitConfig + + +class TestDistributedInitConfigDecentralizedPg: + """Tests for DistributedInitConfig.use_decentralized_pg configuration.""" + + def test_use_decentralized_pg_default_is_false(self): + """Test that use_decentralized_pg defaults to False.""" + config = DistributedInitConfig() + assert config.use_decentralized_pg is False + + def test_use_decentralized_pg_can_be_enabled(self): + """Test that use_decentralized_pg can be set to True.""" + config = DistributedInitConfig(use_decentralized_pg=True) + assert config.use_decentralized_pg is True + + def test_use_decentralized_pg_can_be_explicitly_disabled(self): + """Test that use_decentralized_pg can be explicitly set to False.""" + config = DistributedInitConfig(use_decentralized_pg=False) + assert config.use_decentralized_pg is False + + +class TestCreatePgCollectionFunction: + """Tests for the _create_pg_collection function.""" + + @pytest.fixture + def mock_model_config(self): + """Create a mock model configuration for testing.""" + config = MagicMock() + config.tensor_model_parallel_size = 1 + config.pipeline_model_parallel_size = 1 + config.context_parallel_size = 1 + config.expert_tensor_parallel_size = None + config.expert_model_parallel_size = 1 + return config + + @patch("megatron.bridge.training.initialize.HyperCommGrid") + @patch("torch.distributed.get_world_size", return_value=1) + @patch("torch.distributed.new_subgroups_by_enumeration") + def test_create_pg_collection_returns_process_group_collection( + self, mock_subgroups, mock_world_size, mock_hyper_grid, mock_model_config + ): + """Test that _create_pg_collection returns a ProcessGroupCollection.""" + from megatron.bridge.training.initialize import _create_pg_collection + + # Setup 
mock + mock_grid_instance = MagicMock() + mock_grid_instance.create_pg.return_value = MagicMock() + mock_grid_instance._gen_rank_enum.return_value = [[0]] + mock_hyper_grid.return_value = mock_grid_instance + mock_subgroups.return_value = (MagicMock(), []) + + # Execute + result = _create_pg_collection(mock_model_config, num_distributed_optimizer_instances=1) + + # Verify + from megatron.core.process_groups_config import ProcessGroupCollection + + assert isinstance(result, ProcessGroupCollection) + + @patch("megatron.bridge.training.initialize.HyperCommGrid") + @patch("torch.distributed.get_world_size", return_value=8) + @patch("torch.distributed.new_subgroups_by_enumeration") + def test_create_pg_collection_with_tp(self, mock_subgroups, mock_world_size, mock_hyper_grid, mock_model_config): + """Test _create_pg_collection with tensor parallelism.""" + from megatron.bridge.training.initialize import _create_pg_collection + + # Setup with TP=2 + mock_model_config.tensor_model_parallel_size = 2 + + # Setup mock + mock_grid_instance = MagicMock() + mock_grid_instance.create_pg.return_value = MagicMock() + mock_grid_instance._gen_rank_enum.return_value = [[0, 1, 2, 3, 4, 5, 6, 7]] + mock_hyper_grid.return_value = mock_grid_instance + mock_subgroups.return_value = (MagicMock(), []) + + # Execute + _create_pg_collection(mock_model_config, num_distributed_optimizer_instances=1) + + # Verify grid was created with correct shape + mock_hyper_grid.assert_called() + call_kwargs = mock_hyper_grid.call_args[1] + assert call_kwargs["shape"][0] == 2 # TP size + + @patch("megatron.bridge.training.initialize.HyperCommGrid") + @patch("torch.distributed.get_world_size", return_value=8) + @patch("torch.distributed.new_subgroups_by_enumeration") + def test_create_pg_collection_with_pp(self, mock_subgroups, mock_world_size, mock_hyper_grid, mock_model_config): + """Test _create_pg_collection with pipeline parallelism.""" + from megatron.bridge.training.initialize import 
_create_pg_collection + + # Setup with PP=2 + mock_model_config.pipeline_model_parallel_size = 2 + + # Setup mock + mock_grid_instance = MagicMock() + mock_grid_instance.create_pg.return_value = MagicMock() + mock_grid_instance._gen_rank_enum.return_value = [[0, 1], [2, 3], [4, 5], [6, 7]] + mock_hyper_grid.return_value = mock_grid_instance + mock_subgroups.return_value = (MagicMock(), []) + + # Execute + _create_pg_collection(mock_model_config, num_distributed_optimizer_instances=1) + + # Verify grid was created with correct shape + mock_hyper_grid.assert_called() + call_kwargs = mock_hyper_grid.call_args[1] + assert call_kwargs["shape"][3] == 2 # PP size + + +class TestSetRandomSeedWithPgCollection: + """Tests for _set_random_seed function with pg_collection parameter.""" + + @patch("torch.distributed.get_rank", return_value=0) + @patch("torch.distributed.get_group_rank") + @patch("torch.cuda.device_count", return_value=1) + @patch("megatron.core.tensor_parallel.model_parallel_cuda_manual_seed") + @patch("megatron.core.utils.get_pg_rank", return_value=0) + @patch("torch.cuda.manual_seed") + @patch("torch.manual_seed") + def test_set_random_seed_uses_pg_collection_for_pp_rank( + self, + mock_torch_manual_seed, + mock_cuda_manual_seed, + mock_get_pg_rank, + mock_model_parallel_seed, + mock_cuda_device_count, + mock_group_rank, + mock_get_rank, + ): + """Test that _set_random_seed uses pg_collection for PP rank.""" + from megatron.bridge.training.initialize import _set_random_seed + + # Setup mock pg_collection + mock_pg_collection = MagicMock() + mock_pg_collection.pp = MagicMock() + mock_pg_collection.dp = MagicMock() + mock_pg_collection.tp = MagicMock() + mock_pg_collection.ep = MagicMock() + mock_pg_collection.expt_tp = MagicMock() + + # Mock get_group_rank to return PP rank + mock_group_rank.side_effect = lambda pg, rank: 0 + + # Execute + _set_random_seed( + seed_=42, + data_parallel_random_init=False, + te_rng_tracker=False, + inference_rng_tracker=False, + 
use_cudagraphable_rng=False, + pg_collection=mock_pg_collection, + ) + + # Verify get_group_rank was called with pg_collection.pp + mock_group_rank.assert_any_call(mock_pg_collection.pp, 0) + + @patch("torch.distributed.get_rank", return_value=0) + @patch("torch.distributed.get_group_rank") + @patch("torch.cuda.device_count", return_value=1) + @patch("megatron.core.tensor_parallel.model_parallel_cuda_manual_seed") + @patch("megatron.core.utils.get_pg_rank", return_value=0) + @patch("torch.cuda.manual_seed") + @patch("torch.manual_seed") + def test_set_random_seed_uses_pg_collection_for_dp_rank( + self, + mock_torch_manual_seed, + mock_cuda_manual_seed, + mock_get_pg_rank, + mock_model_parallel_seed, + mock_cuda_device_count, + mock_group_rank, + mock_get_rank, + ): + """Test that _set_random_seed uses pg_collection for DP rank with data_parallel_random_init.""" + from megatron.bridge.training.initialize import _set_random_seed + + # Setup mock pg_collection + mock_pg_collection = MagicMock() + mock_pg_collection.pp = MagicMock() + mock_pg_collection.dp = MagicMock() + mock_pg_collection.tp = MagicMock() + mock_pg_collection.ep = MagicMock() + mock_pg_collection.expt_tp = MagicMock() + + # Mock get_group_rank to return different values for PP and DP + def side_effect(pg, rank): + if pg == mock_pg_collection.dp: + return 1 + return 0 + + mock_group_rank.side_effect = side_effect + + # Execute with data_parallel_random_init=True + _set_random_seed( + seed_=42, + data_parallel_random_init=True, + te_rng_tracker=False, + inference_rng_tracker=False, + use_cudagraphable_rng=False, + pg_collection=mock_pg_collection, + ) + + # Verify get_group_rank was called with pg_collection.dp + mock_group_rank.assert_any_call(mock_pg_collection.dp, 0) + + +class TestTorchDistInitReturnValue: + """Tests for torch_dist_init function return value.""" + + def test_torch_dist_init_returns_pg_collection_when_not_lazy(self): + """Test that torch_dist_init returns ProcessGroupCollection when 
not using lazy init.""" + # This test verifies the function signature and return type annotation + import inspect + + from megatron.bridge.training.initialize import torch_dist_init + + sig = inspect.signature(torch_dist_init) + + # Verify the return type annotation includes ProcessGroupCollection + return_annotation = sig.return_annotation + assert "ProcessGroupCollection" in str(return_annotation) or "Callable" in str(return_annotation) + + +class TestInitializeDistributedRaisesOnNoDevices: + """Tests for _initialize_distributed error handling.""" + + @patch("torch.cuda.device_count", return_value=0) + @patch("torch.distributed.is_initialized", return_value=True) + @patch("megatron.bridge.training.initialize.get_rank_safe", return_value=0) + def test_initialize_distributed_raises_on_no_cuda_devices(self, mock_get_rank, mock_is_init, mock_device_count): + """Test that _initialize_distributed raises RuntimeError when no CUDA devices.""" + from megatron.bridge.training.initialize import _initialize_distributed + + mock_model_config = MagicMock() + mock_model_config.tensor_model_parallel_size = 1 + mock_model_config.pipeline_model_parallel_size = 1 + mock_model_config.context_parallel_size = 1 + + mock_dist_config = MagicMock() + + with pytest.raises(RuntimeError, match="Cannot initialize parallel groups with no CUDA devices"): + _initialize_distributed( + model_config=mock_model_config, + dist_config=mock_dist_config, + num_distributed_optimizer_instances=1, + get_embedding_ranks=None, + get_position_embedding_ranks=None, + ) + + +class TestInitializeDistributedBranching: + """Tests for _initialize_distributed branching based on use_decentralized_pg.""" + + @patch("megatron.bridge.training.initialize._create_pg_collection") + @patch("megatron.bridge.training.initialize.parallel_state") + @patch("torch.distributed.get_world_size", return_value=1) + @patch("torch.cuda.device_count", return_value=1) + @patch("torch.distributed.is_initialized", return_value=True) + 
@patch("megatron.bridge.training.initialize.get_rank_safe", return_value=0) + def test_uses_hyper_comm_grid_when_decentralized_pg_enabled( + self, + mock_get_rank, + mock_is_init, + mock_device_count, + mock_world_size, + mock_parallel_state, + mock_create_pg_collection, + ): + """Test that _initialize_distributed uses HyperCommGrid when use_decentralized_pg=True.""" + from megatron.bridge.training.initialize import _initialize_distributed + + mock_model_config = MagicMock() + mock_model_config.tensor_model_parallel_size = 1 + mock_model_config.pipeline_model_parallel_size = 1 + mock_model_config.context_parallel_size = 1 + + mock_dist_config = MagicMock() + mock_dist_config.use_decentralized_pg = True + + mock_pg_collection = MagicMock() + mock_create_pg_collection.return_value = mock_pg_collection + + result = _initialize_distributed( + model_config=mock_model_config, + dist_config=mock_dist_config, + num_distributed_optimizer_instances=1, + get_embedding_ranks=None, + get_position_embedding_ranks=None, + ) + + # Verify _create_pg_collection was called + mock_create_pg_collection.assert_called_once() + # Verify parallel_state.initialize_model_parallel was NOT called + mock_parallel_state.initialize_model_parallel.assert_not_called() + # Verify the result is the pg_collection from _create_pg_collection + assert result == mock_pg_collection + + @patch("megatron.bridge.training.initialize.ProcessGroupCollection") + @patch("megatron.bridge.training.initialize._create_pg_collection") + @patch("megatron.bridge.training.initialize.parallel_state") + @patch("torch.cuda.device_count", return_value=1) + @patch("torch.distributed.is_initialized", return_value=True) + @patch("megatron.bridge.training.initialize.get_rank_safe", return_value=0) + def test_uses_mpu_when_decentralized_pg_disabled( + self, + mock_get_rank, + mock_is_init, + mock_device_count, + mock_parallel_state, + mock_create_pg_collection, + mock_pg_collection_class, + ): + """Test that 
_initialize_distributed uses mpu when use_decentralized_pg=False.""" + from megatron.bridge.training.initialize import _initialize_distributed + + mock_model_config = MagicMock() + mock_model_config.tensor_model_parallel_size = 1 + mock_model_config.pipeline_model_parallel_size = 1 + mock_model_config.context_parallel_size = 1 + mock_model_config.virtual_pipeline_model_parallel_size = None + mock_model_config.pipeline_model_parallel_comm_backend = "nccl" + mock_model_config.hierarchical_context_parallel_sizes = None + mock_model_config.expert_model_parallel_size = 1 + mock_model_config.expert_tensor_parallel_size = None + + mock_dist_config = MagicMock() + mock_dist_config.use_decentralized_pg = False + mock_dist_config.distributed_timeout_minutes = 30 + mock_dist_config.nccl_communicator_config_path = None + mock_dist_config.use_tp_pp_dp_mapping = False + mock_dist_config.use_gloo_process_groups = False + mock_dist_config.use_sharp = False + mock_dist_config.high_priority_stream_groups = False + mock_dist_config.sharp_enabled_group = None + + mock_parallel_state.model_parallel_is_initialized.return_value = False + mock_pg_collection = MagicMock() + mock_pg_collection_class.use_mpu_process_groups.return_value = mock_pg_collection + + _initialize_distributed( + model_config=mock_model_config, + dist_config=mock_dist_config, + num_distributed_optimizer_instances=1, + get_embedding_ranks=None, + get_position_embedding_ranks=None, + ) + + # Verify _create_pg_collection was NOT called + mock_create_pg_collection.assert_not_called() + # Verify parallel_state.initialize_model_parallel WAS called + mock_parallel_state.initialize_model_parallel.assert_called_once() + + +class TestSetupUsesDecentralizedPg: + """Tests for setup function behavior with use_decentralized_pg.""" + + def test_config_use_decentralized_pg_enabled(self): + """Test that use_decentralized_pg can be enabled in config.""" + from megatron.bridge.training.config import DistributedInitConfig + + config = 
DistributedInitConfig(use_decentralized_pg=True) + + # When use_decentralized_pg=True, _initialize_distributed uses HyperCommGrid + assert config.use_decentralized_pg is True + + def test_config_use_decentralized_pg_disabled_default(self): + """Test that use_decentralized_pg defaults to False.""" + from megatron.bridge.training.config import DistributedInitConfig + + config = DistributedInitConfig(use_decentralized_pg=False) + + # When use_decentralized_pg=False (default), _initialize_distributed uses mpu + assert config.use_decentralized_pg is False + + +class TestSetupOptimizerWithPgCollection: + """Tests for setup_optimizer function with pg_collection parameter.""" + + @patch("megatron.bridge.training.optim.get_megatron_optimizer") + @patch("megatron.bridge.training.optim.OptimizerParamScheduler") + def test_setup_optimizer_passes_pg_collection_to_get_megatron_optimizer(self, mock_scheduler, mock_get_optimizer): + """Test that setup_optimizer passes pg_collection to get_megatron_optimizer.""" + from megatron.core.optimizer import OptimizerConfig + + from megatron.bridge.training.config import SchedulerConfig + from megatron.bridge.training.optim import setup_optimizer + + # Setup mocks + mock_model = MagicMock() + mock_pg_collection = MagicMock() + mock_optimizer = MagicMock() + mock_get_optimizer.return_value = mock_optimizer + + optimizer_config = OptimizerConfig(optimizer="adam", lr=1e-3) + scheduler_config = SchedulerConfig() + + # Execute + setup_optimizer( + optimizer_config=optimizer_config, + scheduler_config=scheduler_config, + model=mock_model, + use_gloo_process_groups=False, + pg_collection=mock_pg_collection, + ) + + # Verify pg_collection was passed + mock_get_optimizer.assert_called_once() + call_kwargs = mock_get_optimizer.call_args[1] + assert call_kwargs["pg_collection"] == mock_pg_collection + + @patch("megatron.bridge.training.optim.get_megatron_optimizer") + @patch("megatron.bridge.training.optim.OptimizerParamScheduler") + def 
test_setup_optimizer_passes_none_pg_collection_when_not_provided(self, mock_scheduler, mock_get_optimizer): + """Test that setup_optimizer passes None pg_collection when not provided.""" + from megatron.core.optimizer import OptimizerConfig + + from megatron.bridge.training.config import SchedulerConfig + from megatron.bridge.training.optim import setup_optimizer + + # Setup mocks + mock_model = MagicMock() + mock_optimizer = MagicMock() + mock_get_optimizer.return_value = mock_optimizer + + optimizer_config = OptimizerConfig(optimizer="adam", lr=1e-3) + scheduler_config = SchedulerConfig() + + # Execute without pg_collection + setup_optimizer( + optimizer_config=optimizer_config, + scheduler_config=scheduler_config, + model=mock_model, + use_gloo_process_groups=False, + ) + + # Verify pg_collection was None (default) + mock_get_optimizer.assert_called_once() + call_kwargs = mock_get_optimizer.call_args[1] + assert call_kwargs["pg_collection"] is None + + @patch("megatron.bridge.training.optim.get_megatron_muon_optimizer") + @patch("megatron.bridge.training.optim.OptimizerParamScheduler") + def test_setup_optimizer_passes_pg_collection_to_muon_optimizer(self, mock_scheduler, mock_get_muon_optimizer): + """Test that setup_optimizer passes pg_collection to muon optimizer.""" + from megatron.core.optimizer import OptimizerConfig + + from megatron.bridge.training.config import SchedulerConfig + from megatron.bridge.training.optim import setup_optimizer + + # Setup mocks + mock_model = MagicMock() + mock_pg_collection = MagicMock() + mock_optimizer = MagicMock() + mock_get_muon_optimizer.return_value = mock_optimizer + + optimizer_config = OptimizerConfig(optimizer="muon", lr=1e-3) + scheduler_config = SchedulerConfig() + + # Execute + setup_optimizer( + optimizer_config=optimizer_config, + scheduler_config=scheduler_config, + model=mock_model, + use_gloo_process_groups=False, + pg_collection=mock_pg_collection, + ) + + # Verify pg_collection was passed to muon 
optimizer + mock_get_muon_optimizer.assert_called_once() + call_kwargs = mock_get_muon_optimizer.call_args[1] + assert call_kwargs["pg_collection"] == mock_pg_collection + + +class TestSetupConditionalPgCollectionPassing: + """Tests for setup function's conditional pg_collection passing to optimizer.""" + + def test_setup_passes_pg_collection_when_use_decentralized_pg_true(self): + """ + Verify that when use_decentralized_pg=True, pg_collection is passed to optimizer. + + This tests the logic at setup.py line 232-234: + pg_collection=pg_collection if cfg.dist.use_decentralized_pg else None + """ + from megatron.bridge.training.config import DistributedInitConfig + + # Create config with use_decentralized_pg=True + config = DistributedInitConfig(use_decentralized_pg=True) + + # Simulate the conditional expression from setup.py + mock_pg_collection = MagicMock() + passed_pg_collection = mock_pg_collection if config.use_decentralized_pg else None + + # Verify pg_collection is passed when use_decentralized_pg=True + assert passed_pg_collection is mock_pg_collection + + def test_setup_passes_none_when_use_decentralized_pg_false(self): + """ + Verify that when use_decentralized_pg=False, None is passed to optimizer. 
+ + This tests the logic at setup.py line 232-234: + pg_collection=pg_collection if cfg.dist.use_decentralized_pg else None + """ + from megatron.bridge.training.config import DistributedInitConfig + + # Create config with use_decentralized_pg=False + config = DistributedInitConfig(use_decentralized_pg=False) + + # Simulate the conditional expression from setup.py + mock_pg_collection = MagicMock() + passed_pg_collection = mock_pg_collection if config.use_decentralized_pg else None + + # Verify None is passed when use_decentralized_pg=False + assert passed_pg_collection is None + + +class TestCheckpointingWithDecentralizedPg: + """Tests for checkpointing behavior based on use_decentralized_pg setting.""" + + def test_modelopt_state_save_skipped_when_use_decentralized_pg_true(self): + """ + Verify that sharded modelopt_state save is skipped when use_decentralized_pg=True. + + This tests the logic at checkpointing.py line 641: + if not cfg.dist.use_decentralized_pg: + save_sharded_modelopt_state(model, checkpoint_name, (ckpt_cfg.ckpt_format, 1)) + """ + from megatron.bridge.training.config import DistributedInitConfig + + # Create config with use_decentralized_pg=True + config = DistributedInitConfig(use_decentralized_pg=True) + + # Simulate the condition from checkpointing.py + should_save_modelopt = not config.use_decentralized_pg + + # Verify modelopt save is skipped when use_decentralized_pg=True + assert should_save_modelopt is False + + def test_modelopt_state_save_executed_when_use_decentralized_pg_false(self): + """ + Verify that sharded modelopt_state save is executed when use_decentralized_pg=False. 
+ + This tests the logic at checkpointing.py line 641: + if not cfg.dist.use_decentralized_pg: + save_sharded_modelopt_state(model, checkpoint_name, (ckpt_cfg.ckpt_format, 1)) + """ + from megatron.bridge.training.config import DistributedInitConfig + + # Create config with use_decentralized_pg=False (default) + config = DistributedInitConfig(use_decentralized_pg=False) + + # Simulate the condition from checkpointing.py + should_save_modelopt = not config.use_decentralized_pg + + # Verify modelopt save is executed when use_decentralized_pg=False + assert should_save_modelopt is True + + +class TestTrainTensorShapesAdjustWithDecentralizedPg: + """Tests for train.py tensor shapes adjust function behavior.""" + + def test_tensor_shapes_adjust_fn_is_none_when_use_decentralized_pg_true(self): + """ + Verify that adjust_tensor_shapes_fn is None when use_decentralized_pg=True. + + This tests the logic at train.py line 658-666: + if not cfg.dist.use_decentralized_pg: + adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation(...) + else: + adjust_tensor_shapes_fn = None + """ + from megatron.bridge.training.config import DistributedInitConfig + + # Create config with use_decentralized_pg=True + config = DistributedInitConfig(use_decentralized_pg=True) + + # Simulate the condition from train.py + if not config.use_decentralized_pg: + adjust_tensor_shapes_fn = "would_call_get_tensor_shapes_adjust_fn" + else: + adjust_tensor_shapes_fn = None + + # Verify adjust_tensor_shapes_fn is None when use_decentralized_pg=True + assert adjust_tensor_shapes_fn is None + + def test_tensor_shapes_adjust_fn_is_set_when_use_decentralized_pg_false(self): + """ + Verify that adjust_tensor_shapes_fn is set when use_decentralized_pg=False. + + This tests the logic at train.py line 658-666: + if not cfg.dist.use_decentralized_pg: + adjust_tensor_shapes_fn = get_tensor_shapes_adjust_fn_for_distillation(...) 
+ else: + adjust_tensor_shapes_fn = None + """ + from megatron.bridge.training.config import DistributedInitConfig + + # Create config with use_decentralized_pg=False (default) + config = DistributedInitConfig(use_decentralized_pg=False) + + # Simulate the condition from train.py + if not config.use_decentralized_pg: + adjust_tensor_shapes_fn = "would_call_get_tensor_shapes_adjust_fn" + else: + adjust_tensor_shapes_fn = None + + # Verify adjust_tensor_shapes_fn is set when use_decentralized_pg=False + assert adjust_tensor_shapes_fn is not None + + +class TestCreatePgCollectionWithContextParallelism: + """Tests for _create_pg_collection with context parallelism.""" + + @patch("megatron.bridge.training.initialize.HyperCommGrid") + @patch("torch.distributed.get_world_size", return_value=8) + @patch("torch.distributed.new_subgroups_by_enumeration") + def test_create_pg_collection_with_cp(self, mock_subgroups, mock_world_size, mock_hyper_grid): + """Test _create_pg_collection with context parallelism.""" + from megatron.bridge.training.initialize import _create_pg_collection + + # Create a fresh mock config with CP=2 directly + mock_model_config = MagicMock() + mock_model_config.tensor_model_parallel_size = 1 + mock_model_config.pipeline_model_parallel_size = 1 + mock_model_config.context_parallel_size = 2 # CP=2 + mock_model_config.expert_tensor_parallel_size = None + mock_model_config.expert_model_parallel_size = 1 + + # Setup mock + mock_grid_instance = MagicMock() + mock_grid_instance.create_pg.return_value = MagicMock() + mock_grid_instance._gen_rank_enum.return_value = [[0, 1, 2, 3, 4, 5, 6, 7]] + mock_hyper_grid.return_value = mock_grid_instance + mock_subgroups.return_value = (MagicMock(), []) + + # Execute + _create_pg_collection(mock_model_config, num_distributed_optimizer_instances=1) + + # Verify grid was created with correct shape + # HyperCommGrid is called multiple times (main grid and expert grid) + # The first call is for the main grid which includes CP 
+ # With TP=1, CP=2, PP=1, world_size=8: dp_size = 8 / (1*2*1) = 4 + # Shape should be [1, 2, 4, 1] = [TP, CP, DP, PP] + mock_hyper_grid.assert_called() + first_call_kwargs = mock_hyper_grid.call_args_list[0][1] + assert first_call_kwargs["shape"][1] == 2 # CP size at index 1 + + @patch("megatron.bridge.training.initialize.HyperCommGrid") + @patch("torch.distributed.get_world_size", return_value=8) + @patch("torch.distributed.new_subgroups_by_enumeration") + def test_create_pg_collection_with_tp_cp_pp(self, mock_subgroups, mock_world_size, mock_hyper_grid): + """Test _create_pg_collection with combined TP, CP, and PP.""" + from megatron.bridge.training.initialize import _create_pg_collection + + # Create a fresh mock config with TP=2, CP=2, PP=2 directly + mock_model_config = MagicMock() + mock_model_config.tensor_model_parallel_size = 2 + mock_model_config.pipeline_model_parallel_size = 2 + mock_model_config.context_parallel_size = 2 + mock_model_config.expert_tensor_parallel_size = None + mock_model_config.expert_model_parallel_size = 1 + + # Setup mock + mock_grid_instance = MagicMock() + mock_grid_instance.create_pg.return_value = MagicMock() + mock_grid_instance._gen_rank_enum.return_value = [[0, 1], [2, 3], [4, 5], [6, 7]] + mock_hyper_grid.return_value = mock_grid_instance + mock_subgroups.return_value = (MagicMock(), []) + + # Execute + _create_pg_collection(mock_model_config, num_distributed_optimizer_instances=1) + + # Verify grid was created with correct shape [TP, CP, DP, PP] + # With TP=2, CP=2, PP=2, world_size=8: dp_size = 8 / (2*2*2) = 1 + # Shape should be [2, 2, 1, 2] = [TP, CP, DP, PP] + mock_hyper_grid.assert_called() + first_call_kwargs = mock_hyper_grid.call_args_list[0][1] + assert first_call_kwargs["shape"] == [2, 2, 1, 2] # TP=2, CP=2, DP=1, PP=2 + + +class TestCreatePgCollectionWithDistributedOptimizerInstances: + """Tests for _create_pg_collection with multiple distributed optimizer instances.""" + + @pytest.fixture + def 
mock_model_config(self): + """Create a mock model configuration for testing.""" + config = MagicMock() + config.tensor_model_parallel_size = 1 + config.pipeline_model_parallel_size = 1 + config.context_parallel_size = 1 + config.expert_tensor_parallel_size = None + config.expert_model_parallel_size = 1 + return config + + @patch("megatron.bridge.training.initialize.HyperCommGrid") + @patch("torch.distributed.get_world_size", return_value=8) + @patch("torch.distributed.new_subgroups_by_enumeration") + def test_create_pg_collection_with_multiple_optimizer_instances( + self, mock_subgroups, mock_world_size, mock_hyper_grid, mock_model_config + ): + """Test _create_pg_collection with multiple distributed optimizer instances.""" + from megatron.bridge.training.initialize import _create_pg_collection + + # Setup mock + mock_grid_instance = MagicMock() + mock_grid_instance.create_pg.return_value = MagicMock() + mock_grid_instance._gen_rank_enum.return_value = [[0, 1, 2, 3, 4, 5, 6, 7]] + mock_hyper_grid.return_value = mock_grid_instance + mock_subgroups.return_value = (MagicMock(), []) + + # Execute with multiple optimizer instances + result = _create_pg_collection(mock_model_config, num_distributed_optimizer_instances=2) + + # Verify result is a ProcessGroupCollection + from megatron.core.process_groups_config import ProcessGroupCollection + + assert isinstance(result, ProcessGroupCollection) diff --git a/tests/unit_tests/training/test_peft_checkpointing.py b/tests/unit_tests/training/test_peft_checkpointing.py index da12a724c0..11979b7d9d 100644 --- a/tests/unit_tests/training/test_peft_checkpointing.py +++ b/tests/unit_tests/training/test_peft_checkpointing.py @@ -413,8 +413,9 @@ def adapter_key_filter(self, key): @patch("megatron.bridge.training.checkpointing.checkpoint_exists") @patch("megatron.bridge.training.checkpointing.apply_peft_adapter_filter_to_state_dict") @patch("megatron.bridge.training.checkpointing.generate_state_dict") + 
@patch("megatron.bridge.training.checkpointing.dist_checkpointing") def test_load_checkpoint_peft_resume_detection( - self, mock_generate_state_dict, mock_filter, mock_checkpoint_exists, mock_load_base + self, mock_dist_ckpt, mock_generate_state_dict, mock_filter, mock_checkpoint_exists, mock_load_base ): """Test that PEFT resume is properly detected and triggers filtering.""" # Setup mocks @@ -463,6 +464,8 @@ def test_load_checkpoint_peft_resume_detection( mock_cfg.checkpoint.auto_detect_ckpt_format = False mock_cfg.checkpoint.ckpt_format = "torch_dist" mock_cfg.checkpoint.non_persistent_save_interval = None + mock_cfg.dist = Mock() + mock_cfg.dist.use_decentralized_pg = False mock_state.cfg = mock_cfg mock_state.train_state = Mock() mock_state.train_state.consumed_train_samples = 0 @@ -471,6 +474,10 @@ def test_load_checkpoint_peft_resume_detection( mock_state.train_state.step = 1000 # Set to integer for comparisons mock_state.train_state.floating_point_operations_so_far = 50000 + # Mock dist_checkpointing + mock_dist_ckpt.load_content_metadata.return_value = {} + mock_dist_ckpt.load.return_value = {} + # Create mock model mock_model = [Mock()] mock_model[0].load_state_dict = Mock() @@ -486,16 +493,22 @@ def test_load_checkpoint_peft_resume_detection( patch("megatron.bridge.training.checkpointing.print_rank_0"), patch("megatron.bridge.training.checkpointing.read_run_config") as mock_read_run_config, patch("megatron.bridge.training.checkpointing.unwrap_model") as mock_unwrap_model, - patch("megatron.bridge.training.checkpointing.mpu.get_tensor_model_parallel_rank", return_value=0), - patch("megatron.bridge.training.checkpointing.mpu.get_tensor_model_parallel_world_size", return_value=1), - patch("megatron.bridge.training.checkpointing.mpu.get_pipeline_model_parallel_rank", return_value=0), - patch("megatron.bridge.training.checkpointing.mpu.get_pipeline_model_parallel_world_size", return_value=1), + patch("megatron.bridge.training.checkpointing.get_pg_collection") 
as mock_get_pg_collection, patch("os.path.exists") as mock_exists, ): mock_read_train_state.return_value = mock_state.train_state mock_get_version.return_value = 3.0 mock_unwrap_model.return_value = mock_model + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.tp.rank.return_value = 0 + mock_pg_collection.tp.size.return_value = 1 + mock_pg_collection.pp.rank.return_value = 0 + mock_pg_collection.pp.size.return_value = 1 + mock_pg_collection.dp_cp.rank.return_value = 0 + mock_get_pg_collection.return_value = mock_pg_collection + # Mock file existence - run_config.yaml exists, train_state.pt doesn't (to use read_train_state mock) def mock_exists_side_effect(path): if "run_config.yaml" in path: @@ -537,7 +550,8 @@ def mock_exists_side_effect(path): @patch("megatron.bridge.training.checkpointing._load_base_checkpoint") @patch("megatron.bridge.training.checkpointing.checkpoint_exists") - def test_load_checkpoint_non_peft_regular_loading(self, mock_checkpoint_exists, mock_load_base): + @patch("megatron.bridge.training.checkpointing.dist_checkpointing") + def test_load_checkpoint_non_peft_regular_loading(self, mock_dist_ckpt, mock_checkpoint_exists, mock_load_base): """Test that non-PEFT scenarios use regular loading without filtering.""" # Setup mocks mock_checkpoint_exists.return_value = True @@ -573,6 +587,8 @@ def test_load_checkpoint_non_peft_regular_loading(self, mock_checkpoint_exists, mock_cfg.checkpoint.auto_detect_ckpt_format = False mock_cfg.checkpoint.ckpt_format = "torch_dist" mock_cfg.checkpoint.non_persistent_save_interval = None + mock_cfg.dist = Mock() + mock_cfg.dist.use_decentralized_pg = False mock_state.cfg = mock_cfg mock_state.train_state = Mock() mock_state.train_state.consumed_train_samples = 0 @@ -581,6 +597,10 @@ def test_load_checkpoint_non_peft_regular_loading(self, mock_checkpoint_exists, mock_state.train_state.step = 1000 # Set to integer for comparisons mock_state.train_state.floating_point_operations_so_far = 
50000 + # Mock dist_checkpointing + mock_dist_ckpt.load_content_metadata.return_value = {} + mock_dist_ckpt.load.return_value = {} + # Create mock model mock_model = [Mock()] mock_model[0].load_state_dict = Mock() @@ -596,16 +616,22 @@ def test_load_checkpoint_non_peft_regular_loading(self, mock_checkpoint_exists, patch("megatron.bridge.training.checkpointing.print_rank_0"), patch("megatron.bridge.training.checkpointing.read_run_config") as mock_read_run_config, patch("megatron.bridge.training.checkpointing.unwrap_model") as mock_unwrap_model, - patch("megatron.bridge.training.checkpointing.mpu.get_tensor_model_parallel_rank", return_value=0), - patch("megatron.bridge.training.checkpointing.mpu.get_tensor_model_parallel_world_size", return_value=1), - patch("megatron.bridge.training.checkpointing.mpu.get_pipeline_model_parallel_rank", return_value=0), - patch("megatron.bridge.training.checkpointing.mpu.get_pipeline_model_parallel_world_size", return_value=1), + patch("megatron.bridge.training.checkpointing.get_pg_collection") as mock_get_pg_collection, patch("os.path.exists") as mock_exists, ): mock_read_train_state.return_value = mock_state.train_state mock_get_version.return_value = 3.0 mock_unwrap_model.return_value = mock_model + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.tp.rank.return_value = 0 + mock_pg_collection.tp.size.return_value = 1 + mock_pg_collection.pp.rank.return_value = 0 + mock_pg_collection.pp.size.return_value = 1 + mock_pg_collection.dp_cp.rank.return_value = 0 + mock_get_pg_collection.return_value = mock_pg_collection + # Mock file existence - run_config.yaml exists, train_state.pt doesn't (to use read_train_state mock) def mock_exists_side_effect(path): if "run_config.yaml" in path: @@ -643,8 +669,9 @@ def mock_exists_side_effect(path): @patch("megatron.bridge.training.checkpointing.checkpoint_exists") @patch("megatron.bridge.training.checkpointing.apply_peft_adapter_filter_to_state_dict") 
@patch("megatron.bridge.training.checkpointing.generate_state_dict") + @patch("megatron.bridge.training.checkpointing.dist_checkpointing") def test_load_checkpoint_peft_resume_multi_model( - self, mock_generate_state_dict, mock_filter, mock_checkpoint_exists, mock_load_base + self, mock_dist_ckpt, mock_generate_state_dict, mock_filter, mock_checkpoint_exists, mock_load_base ): """Test PEFT resume with multiple model chunks (pipeline parallelism).""" # Setup mocks @@ -698,6 +725,8 @@ def test_load_checkpoint_peft_resume_multi_model( mock_cfg.checkpoint.auto_detect_ckpt_format = False mock_cfg.checkpoint.ckpt_format = "torch_dist" mock_cfg.checkpoint.non_persistent_save_interval = None + mock_cfg.dist = Mock() + mock_cfg.dist.use_decentralized_pg = False mock_state.cfg = mock_cfg mock_state.train_state = Mock() mock_state.train_state.consumed_train_samples = 0 @@ -706,6 +735,10 @@ def test_load_checkpoint_peft_resume_multi_model( mock_state.train_state.step = 1000 # Set to integer for comparisons mock_state.train_state.floating_point_operations_so_far = 50000 + # Mock dist_checkpointing + mock_dist_ckpt.load_content_metadata.return_value = {} + mock_dist_ckpt.load.return_value = {} + # Create mock models (2 chunks for pipeline parallelism) mock_model = [Mock(), Mock()] mock_model[0].load_state_dict = Mock() @@ -723,16 +756,22 @@ def test_load_checkpoint_peft_resume_multi_model( patch("megatron.bridge.training.checkpointing.print_rank_0"), patch("megatron.bridge.training.checkpointing.read_run_config") as mock_read_run_config, patch("megatron.bridge.training.checkpointing.unwrap_model") as mock_unwrap_model, - patch("megatron.bridge.training.checkpointing.mpu.get_tensor_model_parallel_rank", return_value=0), - patch("megatron.bridge.training.checkpointing.mpu.get_tensor_model_parallel_world_size", return_value=1), - patch("megatron.bridge.training.checkpointing.mpu.get_pipeline_model_parallel_rank", return_value=0), - 
patch("megatron.bridge.training.checkpointing.mpu.get_pipeline_model_parallel_world_size", return_value=1), + patch("megatron.bridge.training.checkpointing.get_pg_collection") as mock_get_pg_collection, patch("os.path.exists") as mock_exists, ): mock_read_train_state.return_value = mock_state.train_state mock_get_version.return_value = 3.0 mock_unwrap_model.return_value = mock_model + # Create mock pg_collection + mock_pg_collection = Mock() + mock_pg_collection.tp.rank.return_value = 0 + mock_pg_collection.tp.size.return_value = 1 + mock_pg_collection.pp.rank.return_value = 0 + mock_pg_collection.pp.size.return_value = 1 + mock_pg_collection.dp_cp.rank.return_value = 0 + mock_get_pg_collection.return_value = mock_pg_collection + # Mock file existence - run_config.yaml exists, train_state.pt doesn't (to use read_train_state mock) def mock_exists_side_effect(path): if "run_config.yaml" in path: @@ -813,13 +852,19 @@ def setup_and_teardown_parallel_state(self): assert parallel_state.model_parallel_is_initialized(), "Model parallel not initialized" + from megatron.core.process_groups_config import ProcessGroupCollection + from megatron.bridge.training.initialize import _set_random_seed + # Create pg_collection from initialized mpu + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + _set_random_seed( seed_=1234, data_parallel_random_init=False, te_rng_tracker=True, inference_rng_tracker=False, + pg_collection=pg_collection, ) yield diff --git a/tests/unit_tests/training/test_vlm_step.py b/tests/unit_tests/training/test_vlm_step.py index e71a707911..7a4cd47016 100755 --- a/tests/unit_tests/training/test_vlm_step.py +++ b/tests/unit_tests/training/test_vlm_step.py @@ -135,9 +135,18 @@ def test_get_batch_padding_paths(monkeypatch): # Iterator it = _Iterator(batch) + # Create a proper mock pg_collection with rank/size methods + class _MockProcessGroup: + def rank(self): + return 0 + + def size(self): + return 1 + class _PG: def __init__(self): - self.pp = 
object() + self.pp = _MockProcessGroup() + self.cp = _MockProcessGroup() tokens, labels, loss_mask, attention_mask, position_ids, *_ = get_batch( it, cfg, use_mtp=False, pg_collection=_PG() @@ -157,11 +166,25 @@ def test_forward_step_schedule_plan(monkeypatch): # No-op CUDA and CP functions monkeypatch.setattr("megatron.core.utils.get_batch_on_this_cp_rank", lambda x: x, raising=True) + # Create a proper mock process group with rank/size methods + class _MockProcessGroup: + def rank(self): + return 0 + + def size(self): + return 1 + + # Create mock pg_collection with proper process groups + class _MockPGCollection: + def __init__(self): + self.pp = _MockProcessGroup() + self.cp = _MockProcessGroup() + # Dummy model with required interface class _Model: def __init__(self): self.config = type("C", (), {"mtp_num_layers": 0, "overlap_moe_expert_parallel_comm": True})() - self._pg_collection = type("PG", (), {"pp": object()})() + self._pg_collection = _MockPGCollection() @property def pg_collection(self):