[M4] feat: Add M4 end2end support and qwen3 examples #2011
Merged
Changes from 87 commits
94 commits
71821c2
retrieve PGCollection from legacy globals via parallel_state in setup
yaoyu-33 911ec14
Merge branch 'main' into m4/0_prepare
yaoyu-33 3f7ff31
fix setup
yaoyu-33 57c971a
pass pg_collection directly not leverage global state
yaoyu-33 b6a2b59
add unit test
yaoyu-33 70ae249
license
yaoyu-33 14607ba
lint
yaoyu-33 e5acfd9
Merge branch 'main' into m4/0_prepare
yaoyu-33 224e1a3
fix unit tests
yaoyu-33 7ca7dee
fix pretrain api
yaoyu-33 05939dc
remove parallel_state from train.py
yaoyu-33 bac52e2
update gpt_step and vlm_step to not rely on parallel_state
yaoyu-33 0a2e29f
add util to get pg collection from model
yaoyu-33 384488b
remove parallel state from train utils
yaoyu-33 aa82d5e
unit test update
yaoyu-33 1b54119
unit tests fixes
yaoyu-33 2f57a7b
Merge branch 'main' into m4/1_train_loops_and_steps
yaoyu-33 3df91bf
update get_pg_collection to use get_attr_wrapped_model
yaoyu-33 10acaad
Merge branch 'main' into m4/1_train_loops_and_steps
yaoyu-33 c8e6636
update model provider to m4
yaoyu-33 44c9bb4
update model providers for m4
yaoyu-33 6a5a16b
fix model provider unit tests
yaoyu-33 ca797f8
fix unit tests
yaoyu-33 0639dfb
update data part to use m4
yaoyu-33 adec736
Merge branch 'main' into m4/1_train_loops_and_steps
yaoyu-33 534781c
Merge branch 'm4/1_train_loops_and_steps' into m4/4_data
yaoyu-33 3abca31
update unit tests and functional tests
yaoyu-33 1745d2f
address comments
yaoyu-33 71f43cb
Merge branch 'main' into m4/3_model_provider
yaoyu-33 522acef
lint
yaoyu-33 a1cdf4c
fix unit test
yaoyu-33 a26232b
Merge branch 'main' into m4/3_model_provider
yaoyu-33 5ea19fb
add pg_collection in model providers
yaoyu-33 5e6eedf
Merge branch 'main' into m4/3_model_provider
yaoyu-33 8ca3d4c
Merge branch 'main' into m4/3_model_provider
yaoyu-33 cbd5490
update mlm and provider
yaoyu-33 7012412
Merge branch 'main' into m4/3_model_provider
yaoyu-33 2dc3e34
merge main
yaoyu-33 1a2ed53
update to use `_pg_collection`
yaoyu-33 a4c956e
update to use `_pg_collection`
yaoyu-33 ae9aa19
Revert "update to use `_pg_collection`"
yaoyu-33 69de3f4
fix unit test
yaoyu-33 766a397
fix tests
yaoyu-33 c416fe1
fix tests
yaoyu-33 c2b090c
Merge branch 'main' into m4/3_model_provider
yaoyu-33 32ebce4
try fix
yaoyu-33 abb1dc1
Simplify process group removal
yaoyu-33 6d227ed
Merge remote-tracking branch 'origin/main' into m4/3_model_provider
yaoyu-33 9fe6ebc
remove pg_collection from transformer config when pass in to mcore
yaoyu-33 b978f7d
Revert "remove pg_collection from transformer config when pass in to …
yaoyu-33 7c2e39d
override __deepcopy__ for transformer config to get pass the deepcopy…
yaoyu-33 fd0d1d3
wip draft
yaoyu-33 eaf93e0
Merge branch 'main' into m4/3_model_provider
yaoyu-33 9284b84
Revert "wip draft"
yaoyu-33 4cdcc74
fix import
yaoyu-33 4a415e9
lint
yaoyu-33 20df6f7
fix asdict
yaoyu-33 7a18d79
Merge branch 'refs/heads/main' into m4/3_model_provider
yaoyu-33 cac6227
revert global_memory_buffer get_tensor change
yaoyu-33 a955844
Merge branch 'main' into m4/4_data
yaoyu-33 25a2b58
Merge branch 'm4/3_model_provider' into m4/4_data
yaoyu-33 1784aa1
Merge branch 'main' into m4/4_data
yaoyu-33 fb032aa
update from pg_collection to dp_group
yaoyu-33 4f04a0c
Merge branch 'main' into m4/4_data
yaoyu-33 e9dd4df
fix unit test
yaoyu-33 5af94c1
update 3rd party
yaoyu-33 8691e65
fix
yaoyu-33 7b3e7ab
Merge branch 'm4/4_data' into m4/5_initialize
yaoyu-33 c5182f9
feat(config): add use_local_parallel_groups flag to DistributedInitCo…
yaoyu-33 716a268
feat(initialize): add HyperCommGrid-based local parallel group creation
yaoyu-33 4c989ab
feat(training): propagate ProcessGroupCollection through training mod…
yaoyu-33 4b5786a
refactor(checkpointing): use ProcessGroupCollection instead of mpu gl…
yaoyu-33 08fa42c
feat(model_provider): add pg_collection parameter to get_model()
yaoyu-33 c610e1a
test(unit): add unit tests for local parallel groups feature
yaoyu-33 8440087
test(functional): add functional tests for local parallel groups
yaoyu-33 0b2682d
Add local parallel groups examples for Qwen3 pretraining
yaoyu-33 bd92363
refactor: rename use_local_parallel_groups to use_decentralized_pg
yaoyu-33 98dea89
Merge branch 'main' into m4/5_initialize
yaoyu-33 c3927a8
unit test fix
yaoyu-33 ce99243
fix comments
yaoyu-33 fa5af7e
fix tests
yaoyu-33 982ec64
rename tests and examples
yaoyu-33 321556f
fix dist not exist case in checkpoint conversion
yaoyu-33 645ff38
test fix
yaoyu-33 7161f84
Merge origin/main into m4/5_initialize
yaoyu-33 bafab8f
Merge branch 'main' into m4/5_initialize
yaoyu-33 de02e58
fix comments
yaoyu-33 d293ff8
Update src/megatron/bridge/training/initialize.py
yaoyu-33 3367db8
Update examples/recipes/decentralized_pg/README.md
yaoyu-33 b6912a7
fix _create_pg_collection
yaoyu-33 b6a3c14
[test] fix: Add pg_collection parameter to _set_random_seed in SFT un…
yaoyu-33 f5495b6
Merge branch 'main' into m4/5_initialize
yaoyu-33 6e11df4
Fix get_rng_state missing pg_collection argument in low-memory save
yaoyu-33 d0cb1bf
Pass pg_collection to save_checkpoint in low-memory save mode
yaoyu-33
@@ -0,0 +1,179 @@

# Decentralized Process Groups Examples

This directory contains examples demonstrating how to use **decentralized process groups** (`use_decentralized_pg=True`) in Megatron-Bridge for distributed training.

## Overview

Instead of relying on Megatron-Core's global parallel state (mpu) module, you can use a `ProcessGroupCollection` that is explicitly passed to all components. This gives you full control over the parallelism topology and is useful for:

1. **Reinforcement Learning**: Multiple model instances (policy, value, reference) with different parallelism
2. **Multi-Model Pipelines**: Complex workflows requiring explicit control over communication
3. **Testing/Debugging**: Isolated process groups without global state side effects

## Files

| File | Description |
|------|-------------|
| `pretrain_qwen3_simple.py` | **Simple**: Use a recipe and enable `use_decentralized_pg=True` |
| `pretrain_qwen3_with_decentralized_pg.py` | **Advanced**: Manually create process groups with `HyperCommGrid` |

## Quick Start

### Simple Approach (Recommended)

Just use an existing recipe and enable decentralized process groups:

```bash
# 8 GPUs: TP2 x PP2 x DP2
uv run python -m torch.distributed.run --nproc_per_node=8 examples/recipes/decentralized_pg/pretrain_qwen3_simple.py

# 4 GPUs: TP2 x PP2 x DP1
uv run python -m torch.distributed.run --nproc_per_node=4 examples/recipes/decentralized_pg/pretrain_qwen3_simple.py
```

The key is the two `cfg.dist` settings at the end:

```python
from megatron.bridge.recipes.qwen.qwen3 import qwen3_4b_pretrain_config

cfg = qwen3_4b_pretrain_config(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
    # ... other settings
)

# Enable decentralized process groups
cfg.dist.use_decentralized_pg = True
cfg.dist.use_gloo_process_groups = False  # Gloo not supported
```

### Advanced Approach (Manual Process Group Creation)

For full control over process groups:

```bash
# 8 GPUs: TP2 x PP2 x DP2
uv run python -m torch.distributed.run --nproc_per_node=8 examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py

# 4 GPUs: TP2 x PP2 x DP1
uv run python -m torch.distributed.run --nproc_per_node=4 examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py \
    --tp-size 2 --pp-size 2

# 2 GPUs: TP2 x PP1 x DP1
uv run python -m torch.distributed.run --nproc_per_node=2 examples/recipes/decentralized_pg/pretrain_qwen3_with_decentralized_pg.py \
    --tp-size 2 --pp-size 1
```

## Manual Process Group Creation (Advanced)

### Step 1: Initialize torch.distributed

```python
torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=rank)
```
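
Step 1 uses `world_size` and `rank` without showing where they come from. When a script is launched with `torch.distributed.run`, as in the commands above, each worker process receives `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` as environment variables. A minimal sketch of reading them follows; `LaunchInfo` and `read_launch_info` are hypothetical names used only for illustration (they are not part of Megatron-Bridge), and the single-process fallback values are an assumption.

```python
import os
from typing import Mapping, NamedTuple


class LaunchInfo(NamedTuple):
    """Rank information for one worker process (hypothetical helper type)."""

    rank: int
    world_size: int
    local_rank: int


def read_launch_info(environ: Mapping[str, str] = os.environ) -> LaunchInfo:
    """Read the per-worker variables exported by torch.distributed.run.

    Falls back to a single-process layout when the variables are absent
    (an assumption made for this sketch).
    """
    return LaunchInfo(
        rank=int(environ.get("RANK", "0")),
        world_size=int(environ.get("WORLD_SIZE", "1")),
        local_rank=int(environ.get("LOCAL_RANK", "0")),
    )


# Simulate the environment of worker 5 in an 8-process launch.
info = read_launch_info({"RANK": "5", "WORLD_SIZE": "8", "LOCAL_RANK": "5"})
print(info.rank, info.world_size, info.local_rank)
```

The resulting `rank` and `world_size` could then be passed to `torch.distributed.init_process_group` as in Step 1.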

### Step 2: Create ProcessGroupCollection with HyperCommGrid

```python
from megatron.core.hyper_comm_grid import HyperCommGrid
from megatron.core.process_groups_config import ProcessGroupCollection

# Create a grid with shape [TP, CP, DP, PP]
grid = HyperCommGrid(
    shape=[tp_size, cp_size, dp_size, pp_size],
    dim_names=["tp", "cp", "dp", "pp"],
    rank_offset=0,
    backend="nccl",
)

# Create process groups by selecting dimensions
tp_pg = grid.create_pg(["tp"])  # Ranks differ only in TP dimension
pp_pg = grid.create_pg(["pp"])  # Ranks differ only in PP dimension
dp_pg = grid.create_pg(["dp"])  # Ranks differ only in DP dimension
mp_pg = grid.create_pg(["tp", "pp"])  # Model parallel: spans both TP and PP dimensions

# Bundle into ProcessGroupCollection
pg_collection = ProcessGroupCollection(
    tp=tp_pg,
    pp=pp_pg,
    dp=dp_pg,
    mp=mp_pg,
    # ... more groups
)
```

### Step 3: Set Random Seeds (REQUIRED)

```python
from megatron.core import tensor_parallel
from megatron.core.utils import get_pg_rank

# Get TP rank from our process group
tp_rank = get_pg_rank(pg_collection.tp)

# Initialize CUDA RNG tracker - REQUIRED before model creation!
tensor_parallel.model_parallel_cuda_manual_seed(
    seed=1234,
    te_rng_tracker=False,
    inference_rng_tracker=False,
    use_cudagraphable_rng=False,
    tp_rank=tp_rank,
    ep_rank=0,
    etp_rank=tp_rank,
)
```

### Step 4: Pass pg_collection Explicitly to Components

```python
# Model creation
model = cfg.model.provide_distributed_model(
    pg_collection=pg_collection,  # <-- Pass here!
    # ... other arguments
)

# Optimizer setup
optimizer, scheduler = setup_optimizer(
    pg_collection=pg_collection,  # <-- Pass here!
    # ... other arguments
)

# Data loaders use the DP group
train_data_iterator = setup_data_iterators(
    dp_group=pg_collection.dp,  # <-- Use DP group for data sharding!
    # ... other arguments
)

# Training loop
train(
    pg_collection=pg_collection,  # <-- Pass here!
    # ... other arguments
)
```

## HyperCommGrid Explained

`HyperCommGrid` creates a multi-dimensional grid of ranks. The grid shape `[TP, CP, DP, PP]` defines how ranks are organized:

```
World Size = 8, Shape = [2, 1, 2, 2] means:
  TP=2, CP=1, DP=2, PP=2

Rank layout:
  TP=0,DP=0,PP=0: rank 0    TP=1,DP=0,PP=0: rank 1
  TP=0,DP=0,PP=1: rank 2    TP=1,DP=0,PP=1: rank 3
  TP=0,DP=1,PP=0: rank 4    TP=1,DP=1,PP=0: rank 5
  TP=0,DP=1,PP=1: rank 6    TP=1,DP=1,PP=1: rank 7
```

When you call `grid.create_pg(["tp"])`, it creates groups of ranks that share the same DP and PP coordinates but differ in TP:
- Group 1: [rank 0, rank 1] (DP=0, PP=0)
- Group 2: [rank 2, rank 3] (DP=0, PP=1)
- Group 3: [rank 4, rank 5] (DP=1, PP=0)
- Group 4: [rank 6, rank 7] (DP=1, PP=1)
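
The grouping rule can be checked without any GPUs. The sketch below takes the rank layout table above as given and collects the ranks that agree on every coordinate except the selected dimensions. It does not call `HyperCommGrid`; the `LAYOUT` table and the `groups_along` helper are illustrative names only, and whether the real grid assigns ranks in exactly this order should be verified against Megatron-Core.

```python
from collections import defaultdict

# Coordinate order used by the keys below; CP is omitted because CP=1.
DIMS = ("tp", "dp", "pp")

# (tp, dp, pp) -> rank, copied from the rank layout table above.
LAYOUT = {
    (0, 0, 0): 0, (1, 0, 0): 1,
    (0, 0, 1): 2, (1, 0, 1): 3,
    (0, 1, 0): 4, (1, 1, 0): 5,
    (0, 1, 1): 6, (1, 1, 1): 7,
}


def groups_along(dims):
    """Group ranks that differ only in the given dimensions."""
    varying = {DIMS.index(name) for name in dims}
    buckets = defaultdict(list)
    for coord, rank in LAYOUT.items():
        # Ranks land in the same bucket when all other coordinates match.
        fixed = tuple(c for i, c in enumerate(coord) if i not in varying)
        buckets[fixed].append(rank)
    return sorted(sorted(group) for group in buckets.values())


tp_groups = groups_along(["tp"])
dp_groups = groups_along(["dp"])
pp_groups = groups_along(["pp"])
mp_groups = groups_along(["tp", "pp"])
print("tp:", tp_groups)
print("dp:", dp_groups)
print("pp:", pp_groups)
print("mp:", mp_groups)
```

Under that layout the TP groups match the four groups listed above, and selecting both `tp` and `pp` yields one model-parallel group per DP coordinate.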

## Limitations

- Gloo process groups are not supported (NCCL only)
- ModelOpt sharded checkpointing is disabled
- Distillation tensor shape adjustment is disabled

79 changes: 79 additions & 0 deletions
examples/recipes/decentralized_pg/pretrain_qwen3_simple.py
@@ -0,0 +1,79 @@

```python
#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
==============================================================================
Example: Qwen3 Pretraining with Decentralized Process Groups (Simple)
==============================================================================

This example demonstrates the simplest way to enable decentralized process groups:
just use an existing recipe and set `cfg.dist.use_decentralized_pg = True`.

The setup() function inside pretrain() will automatically create the
ProcessGroupCollection using HyperCommGrid based on the parallelism settings.

How to Run
----------
    # 8 GPUs: TP2 x PP2 x DP2
    uv run python -m torch.distributed.run --nproc_per_node=8 examples/recipes/decentralized_pg/pretrain_qwen3_simple.py

    # 4 GPUs: TP2 x PP2 x DP1
    uv run python -m torch.distributed.run --nproc_per_node=4 examples/recipes/decentralized_pg/pretrain_qwen3_simple.py
"""

import torch

from megatron.bridge.recipes.qwen.qwen3 import qwen3_4b_pretrain_config
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain


def main() -> None:
    """Run Qwen3 pretraining with decentralized process groups enabled."""
    # Get the standard Qwen3 4B pretrain config with overrides
    cfg = qwen3_4b_pretrain_config(
        # Use mock data for demo
        mock=True,
        # Parallelism
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=2,
        # Training settings (small for demo)
        train_iters=100,
        seq_length=1024,
        global_batch_size=32,
        micro_batch_size=1,
        # LR schedule (must fit within train_iters)
        lr_warmup_iters=10,
        lr_decay_iters=100,
    )
    # Work around a known issue with share_embeddings_and_output_weights
    cfg.model.share_embeddings_and_output_weights = False

    # =========================================================================
    # KEY: Enable decentralized process groups
    # =========================================================================
    cfg.dist.use_decentralized_pg = True
    cfg.dist.use_gloo_process_groups = False  # Gloo not supported with decentralized PG

    pretrain(config=cfg, forward_step_func=forward_step)

    # Cleanup
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()
```
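
The `How to Run` comments imply that the data-parallel size is whatever remains of the world size once tensor, pipeline, and context parallelism are accounted for. A small sketch of that arithmetic follows, together with the usual Megatron relation between global batch size, micro batch size, and the number of microbatches per step. Both helper names are hypothetical, and this is not code from the PR.

```python
def data_parallel_size(world_size: int, tp: int = 1, pp: int = 1, cp: int = 1) -> int:
    """Return the DP size left over after TP, PP, and CP are accounted for."""
    model_parallel = tp * pp * cp
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by tp*pp*cp={model_parallel}"
        )
    return world_size // model_parallel


def num_microbatches(global_batch_size: int, micro_batch_size: int, dp_size: int) -> int:
    """Return how many microbatches each DP rank runs per training step."""
    per_pass = micro_batch_size * dp_size
    if global_batch_size % per_pass != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} is not divisible by "
            f"micro_batch_size*dp_size={per_pass}"
        )
    return global_batch_size // per_pass


# Settings from the simple example: TP=2, PP=2, global batch 32, micro batch 1.
for gpus in (8, 4):
    dp = data_parallel_size(gpus, tp=2, pp=2)
    print(f"{gpus} GPUs -> DP={dp}, microbatches per step={num_microbatches(32, 1, dp)}")
```

With the example's settings, 8 GPUs give DP=2 and 4 GPUs give DP=1, which agrees with the two launch commands in the docstring.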