Merged

94 commits
71821c2
retrieve PGCollection from legacy globals via parallel_state in setup
yaoyu-33 Oct 23, 2025
911ec14
Merge branch 'main' into m4/0_prepare
yaoyu-33 Oct 23, 2025
3f7ff31
fix setup
yaoyu-33 Oct 23, 2025
57c971a
pass pg_collection directly not leverage global state
yaoyu-33 Oct 29, 2025
b6a2b59
add unit test
yaoyu-33 Oct 30, 2025
70ae249
license
yaoyu-33 Oct 30, 2025
14607ba
lint
yaoyu-33 Oct 30, 2025
e5acfd9
Merge branch 'main' into m4/0_prepare
yaoyu-33 Oct 30, 2025
224e1a3
fix unit tests
yaoyu-33 Oct 30, 2025
7ca7dee
fix pretrain api
yaoyu-33 Oct 31, 2025
05939dc
remove parallel_state from train.py
yaoyu-33 Nov 3, 2025
bac52e2
update gpt_step and vlm_step to not rely on parallel_state
yaoyu-33 Nov 3, 2025
0a2e29f
add util to get pg collection from model
yaoyu-33 Nov 3, 2025
384488b
remove parallel state from train utils
yaoyu-33 Nov 3, 2025
aa82d5e
unit test update
yaoyu-33 Nov 3, 2025
1b54119
unit tests fixes
yaoyu-33 Nov 3, 2025
2f57a7b
Merge branch 'main' into m4/1_train_loops_and_steps
yaoyu-33 Nov 3, 2025
3df91bf
update get_pg_collection to use get_attr_wrapped_model
yaoyu-33 Nov 3, 2025
10acaad
Merge branch 'main' into m4/1_train_loops_and_steps
yaoyu-33 Nov 4, 2025
c8e6636
update model provider to m4
yaoyu-33 Nov 4, 2025
44c9bb4
update model providers for m4
yaoyu-33 Nov 4, 2025
6a5a16b
fix model provider unit tests
yaoyu-33 Nov 5, 2025
ca797f8
fix unit tests
yaoyu-33 Nov 5, 2025
0639dfb
update data part to use m4
yaoyu-33 Nov 6, 2025
adec736
Merge branch 'main' into m4/1_train_loops_and_steps
yaoyu-33 Nov 6, 2025
534781c
Merge branch 'm4/1_train_loops_and_steps' into m4/4_data
yaoyu-33 Nov 6, 2025
3abca31
update unit tests and functional tests
yaoyu-33 Nov 6, 2025
1745d2f
address comments
yaoyu-33 Nov 6, 2025
71f43cb
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Nov 12, 2025
522acef
lint
yaoyu-33 Nov 12, 2025
a1cdf4c
fix unit test
yaoyu-33 Nov 21, 2025
a26232b
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Dec 5, 2025
5ea19fb
add pg_collection in model providers
yaoyu-33 Dec 5, 2025
5e6eedf
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Dec 5, 2025
8ca3d4c
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Dec 6, 2025
cbd5490
update mlm and provider
yaoyu-33 Dec 6, 2025
7012412
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Dec 15, 2025
2dc3e34
merge main
yaoyu-33 Dec 15, 2025
1a2ed53
update to use `_pg_collection`
yaoyu-33 Dec 15, 2025
a4c956e
update to use `_pg_collection`
yaoyu-33 Dec 15, 2025
ae9aa19
Revert "update to use `_pg_collection`"
yaoyu-33 Dec 15, 2025
69de3f4
fix unit test
yaoyu-33 Dec 16, 2025
766a397
fix tests
yaoyu-33 Dec 17, 2025
c416fe1
fix tests
yaoyu-33 Dec 17, 2025
c2b090c
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Dec 17, 2025
32ebce4
try fix
yaoyu-33 Dec 18, 2025
abb1dc1
Simplify process group removal
yaoyu-33 Dec 18, 2025
6d227ed
Merge remote-tracking branch 'origin/main' into m4/3_model_provider
yaoyu-33 Dec 31, 2025
9fe6ebc
remove pg_collection from transformer config when pass in to mcore
yaoyu-33 Dec 31, 2025
b978f7d
Revert "remove pg_collection from transformer config when pass in to …
yaoyu-33 Jan 2, 2026
7c2e39d
override __deepcopy__ for transformer config to get pass the deepcopy…
yaoyu-33 Jan 2, 2026
fd0d1d3
wip draft
yaoyu-33 Jan 5, 2026
eaf93e0
Merge branch 'main' into m4/3_model_provider
yaoyu-33 Jan 5, 2026
9284b84
Revert "wip draft"
yaoyu-33 Jan 5, 2026
4cdcc74
fix import
yaoyu-33 Jan 5, 2026
4a415e9
lint
yaoyu-33 Jan 5, 2026
20df6f7
fix asdict
yaoyu-33 Jan 6, 2026
7a18d79
Merge branch 'refs/heads/main' into m4/3_model_provider
yaoyu-33 Jan 8, 2026
cac6227
revert global_memory_buffer get_tensor change
yaoyu-33 Jan 8, 2026
a955844
Merge branch 'main' into m4/4_data
yaoyu-33 Jan 8, 2026
25a2b58
Merge branch 'm4/3_model_provider' into m4/4_data
yaoyu-33 Jan 8, 2026
1784aa1
Merge branch 'main' into m4/4_data
yaoyu-33 Jan 10, 2026
fb032aa
update from pg_collection to dp_group
yaoyu-33 Jan 10, 2026
4f04a0c
Merge branch 'main' into m4/4_data
yaoyu-33 Jan 14, 2026
e9dd4df
fix unit test
yaoyu-33 Jan 14, 2026
5af94c1
update 3rd party
yaoyu-33 Jan 14, 2026
8691e65
fix
yaoyu-33 Jan 16, 2026
7b3e7ab
Merge branch 'm4/4_data' into m4/5_initialize
yaoyu-33 Jan 20, 2026
c5182f9
feat(config): add use_local_parallel_groups flag to DistributedInitCo…
yaoyu-33 Jan 21, 2026
716a268
feat(initialize): add HyperCommGrid-based local parallel group creation
yaoyu-33 Jan 21, 2026
4c989ab
feat(training): propagate ProcessGroupCollection through training mod…
yaoyu-33 Jan 21, 2026
4b5786a
refactor(checkpointing): use ProcessGroupCollection instead of mpu gl…
yaoyu-33 Jan 21, 2026
08fa42c
feat(model_provider): add pg_collection parameter to get_model()
yaoyu-33 Jan 21, 2026
c610e1a
test(unit): add unit tests for local parallel groups feature
yaoyu-33 Jan 21, 2026
8440087
test(functional): add functional tests for local parallel groups
yaoyu-33 Jan 21, 2026
0b2682d
Add local parallel groups examples for Qwen3 pretraining
yaoyu-33 Jan 21, 2026
bd92363
refactor: rename use_local_parallel_groups to use_decentralized_pg
yaoyu-33 Jan 21, 2026
98dea89
Merge branch 'main' into m4/5_initialize
yaoyu-33 Jan 21, 2026
c3927a8
unit test fix
yaoyu-33 Jan 21, 2026
ce99243
fix comments
yaoyu-33 Jan 21, 2026
fa5af7e
fix tests
yaoyu-33 Jan 21, 2026
982ec64
rename tests and examples
yaoyu-33 Jan 21, 2026
321556f
fix dist not exist case in checkpoint conversion
yaoyu-33 Jan 21, 2026
645ff38
test fix
yaoyu-33 Jan 21, 2026
7161f84
Merge origin/main into m4/5_initialize
yaoyu-33 Jan 23, 2026
bafab8f
Merge branch 'main' into m4/5_initialize
yaoyu-33 Jan 27, 2026
de02e58
fix comments
yaoyu-33 Jan 27, 2026
d293ff8
Update src/megatron/bridge/training/initialize.py
yaoyu-33 Jan 27, 2026
3367db8
Update examples/recipes/decentralized_pg/README.md
yaoyu-33 Jan 27, 2026
b6912a7
fix _create_pg_collection
yaoyu-33 Jan 27, 2026
b6a3c14
[test] fix: Add pg_collection parameter to _set_random_seed in SFT un…
yaoyu-33 Jan 27, 2026
f5495b6
Merge branch 'main' into m4/5_initialize
yaoyu-33 Jan 27, 2026
6e11df4
Fix get_rng_state missing pg_collection argument in low-memory save
yaoyu-33 Jan 28, 2026
d0cb1bf
Pass pg_collection to save_checkpoint in low-memory save mode
yaoyu-33 Jan 28, 2026
168 changes: 168 additions & 0 deletions examples/recipes/decentralized_pg/README.md
@@ -0,0 +1,168 @@
# Decentralized Process Groups Examples

This directory contains examples demonstrating how to use **decentralized process groups** (`use_decentralized_pg=True`) in Megatron-Bridge for distributed training.

## Overview

Instead of relying on Megatron-Core's global parallel state (mpu) module, you can use a `ProcessGroupCollection` that is explicitly passed to all components. This gives you full control over the parallelism topology and is useful for:

1. **Reinforcement Learning**: Multiple model instances (policy, value, reference) with different parallelism
2. **Multi-Model Pipelines**: Complex workflows requiring explicit control over communication
3. **Testing/Debugging**: Isolated process groups without global state side effects

## Files

| File | Description |
|------|-------------|
| `pretrain_qwen3_simple.py` | **Simple**: Use a recipe and enable `use_decentralized_pg=True` |
| `pretrain_qwen3_with_decentralized_pg.py` | **Advanced**: Manually create process groups with `HyperCommGrid` |

## Quick Start

### Simple Approach (Recommended)

Just use an existing recipe and enable decentralized process groups:

```bash
# 8 GPUs: TP2 x PP2 x DP2
uv run python -m torch.distributed.run --nproc_per_node=8 examples/recipes/decentralized_pg/pretrain_qwen3_simple.py

# 4 GPUs: TP2 x PP2 x DP1
uv run python -m torch.distributed.run --nproc_per_node=4 examples/recipes/decentralized_pg/pretrain_qwen3_simple.py
```

The key is just two lines:

```python
from megatron.bridge.recipes.qwen.qwen3 import qwen3_4b_pretrain_config

cfg = qwen3_4b_pretrain_config(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
    # ... other settings
)

# Enable decentralized process groups
cfg.dist.use_decentralized_pg = True
cfg.dist.use_gloo_process_groups = False # Gloo not supported
```

### Advanced Approach (Manual Process Group Creation)

For full control over process groups:

```bash
# 8 GPUs: TP2 x PP2 x DP2
uv run python -m torch.distributed.run --nproc_per_node=8 examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py

# 4 GPUs: TP2 x PP2 x DP1
uv run python -m torch.distributed.run --nproc_per_node=4 examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py \
--tp-size 2 --pp-size 2

# 2 GPUs: TP2 x PP1 x DP1
uv run python -m torch.distributed.run --nproc_per_node=2 examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py \
--tp-size 2 --pp-size 1
```

## Manual Process Group Creation (Advanced)

### Step 1: Initialize torch.distributed

```python
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
```
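Under `torch.distributed.run`, the `rank` and `world_size` values above are conventionally read from the launcher's environment variables. A minimal standalone sketch of that standard pattern (the helper name is ours, not part of Megatron-Bridge):

```python
import os


def launcher_env(env=os.environ):
    """Read the process coordinates that torchrun exports as env vars."""
    return (
        int(env.get("RANK", "0")),          # global rank of this process
        int(env.get("WORLD_SIZE", "1")),    # total number of processes
        int(env.get("LOCAL_RANK", "0")),    # rank within this node
    )


rank, world_size, local_rank = launcher_env()
# Typically you would also bind the process to its GPU before NCCL init:
# torch.cuda.set_device(local_rank)
```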

### Step 2: Create ProcessGroupCollection with HyperCommGrid

```python
from megatron.core.hyper_comm_grid import HyperCommGrid
from megatron.core.process_groups_config import ProcessGroupCollection

# Create a grid with shape [TP, CP, DP, PP]
grid = HyperCommGrid(
    shape=[tp_size, cp_size, dp_size, pp_size],
    dim_names=["tp", "cp", "dp", "pp"],
    rank_offset=0,
    backend="nccl",
)

# Create process groups by selecting dimensions
tp_pg = grid.create_pg(["tp"])  # Ranks differ only in the TP dimension
pp_pg = grid.create_pg(["pp"])  # Ranks differ only in the PP dimension
dp_pg = grid.create_pg(["dp"])  # Ranks differ only in the DP dimension
mp_pg = grid.create_pg(["tp", "pp"])  # Model parallel = TP + PP

# Bundle into a ProcessGroupCollection
pg_collection = ProcessGroupCollection(
    tp=tp_pg,
    pp=pp_pg,
    dp=dp_pg,
    mp=mp_pg,
    # ... more groups
)
```

### Step 3: Set Random Seeds (REQUIRED)

```python
from megatron.core import tensor_parallel
from megatron.core.utils import get_pg_rank

# Get TP rank from our process group
tp_rank = get_pg_rank(pg_collection.tp)

# Initialize CUDA RNG tracker - REQUIRED before model creation!
tensor_parallel.model_parallel_cuda_manual_seed(
    seed=1234,
    te_rng_tracker=False,
    inference_rng_tracker=False,
    use_cudagraphable_rng=False,
    tp_rank=tp_rank,
    ep_rank=0,
    etp_rank=tp_rank,
)
```

### Step 4: Pass pg_collection Explicitly to Components

```python
# Model creation
model = cfg.model.provide_distributed_model(
    pg_collection=pg_collection,  # <-- Pass here!
    ...
)

# Optimizer setup
optimizer, scheduler = setup_optimizer(
    pg_collection=pg_collection,  # <-- Pass here!
    ...
)

# Data loaders use the DP group
train_data_iterator = setup_data_iterators(
    dp_group=pg_collection.dp,  # <-- Use the DP group for data sharding!
    ...
)

# Training loop
train(
    pg_collection=pg_collection,  # <-- Pass here!
    ...
)
```

## HyperCommGrid Explained

`HyperCommGrid` creates a multi-dimensional grid of ranks. The grid shape `[TP, CP, DP, PP]` defines how ranks are organized.

For example, with 8 ranks and shape `[2, 1, 2, 2]` (TP2 x CP1 x DP2 x PP2), calling `grid.create_pg(["tp"])` creates groups of ranks that share the same DP and PP coordinates but differ only in TP:
- Group 1: [rank 0, rank 1] (DP=0, PP=0)
- Group 2: [rank 2, rank 3] (DP=0, PP=1)
- Group 3: [rank 4, rank 5] (DP=1, PP=0)
- Group 4: [rank 6, rank 7] (DP=1, PP=1)

## Limitations

- Gloo process groups are not supported (NCCL only)
- ModelOpt sharded checkpointing is disabled
- Distillation tensor shape adjustment is disabled
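The first two constraints can be checked up front before launching a job. A hypothetical validation helper (our sketch, not part of Megatron-Bridge's API) might look like:

```python
def check_decentralized_pg_config(
    use_decentralized_pg: bool,
    use_gloo_process_groups: bool,
    backend: str,
) -> list[str]:
    """Collect violations of the limitations listed above."""
    problems = []
    if use_decentralized_pg:
        if use_gloo_process_groups:
            problems.append(
                "Gloo process groups are not supported; "
                "set use_gloo_process_groups=False"
            )
        if backend != "nccl":
            problems.append(
                f"decentralized PG requires the NCCL backend, got {backend!r}"
            )
    return problems


# A misconfigured run reports both problems:
print(check_decentralized_pg_config(True, True, "gloo"))
```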
79 changes: 79 additions & 0 deletions examples/recipes/decentralized_pg/pretrain_qwen3_simple.py
@@ -0,0 +1,79 @@
#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
==============================================================================
Example: Qwen3 Pretraining with Decentralized Process Groups (Simple)
==============================================================================

This example demonstrates the simplest way to enable decentralized process groups:
just use an existing recipe and set `cfg.dist.use_decentralized_pg = True`.

The setup() function inside pretrain() will automatically create the
ProcessGroupCollection using HyperCommGrid based on the parallelism settings.

How to Run
----------
# 8 GPUs: TP2 x PP2 x DP2
uv run python -m torch.distributed.run --nproc_per_node=8 examples/recipes/decentralized_pg/pretrain_qwen3_simple.py

# 4 GPUs: TP2 x PP2 x DP1
uv run python -m torch.distributed.run --nproc_per_node=4 examples/recipes/decentralized_pg/pretrain_qwen3_simple.py
"""

import torch

from megatron.bridge.recipes.qwen.qwen3 import qwen3_4b_pretrain_config
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain


def main() -> None:
    """Run Qwen3 pretraining with decentralized process groups enabled."""
    # Get the standard Qwen3 4B pretrain config with overrides
    cfg = qwen3_4b_pretrain_config(
        # Use mock data for the demo
        mock=True,
        # Parallelism
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=2,
        # Training settings (small for the demo)
        train_iters=100,
        seq_length=1024,
        global_batch_size=32,
        micro_batch_size=1,
        # LR schedule (must fit within train_iters)
        lr_warmup_iters=10,
        lr_decay_iters=100,
    )
    # Known issue with share_embeddings_and_output_weights
    cfg.model.share_embeddings_and_output_weights = False

    # =========================================================================
    # KEY: Enable decentralized process groups
    # =========================================================================
    cfg.dist.use_decentralized_pg = True
    cfg.dist.use_gloo_process_groups = False  # Gloo not supported with decentralized PG

    pretrain(config=cfg, forward_step_func=forward_step)

    # Cleanup
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()