M4: add pg_collection into setup and wire into train.py #1062
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Walkthrough

The changes introduce a ProcessGroupCollection structure to replace direct parallel_state calls for data-parallel operations. A new pg_collection attribute is added to GlobalState and initialized during setup from ProcessGroupCollection.use_mpu_process_groups(). Training code now retrieves DP size and communication groups from this centralized collection.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Setup as setup()
    participant PGC as ProcessGroupCollection
    participant State as GlobalState
    participant UpdateConfig as _update_model_config_funcs()
    participant Train as train()
    Setup->>PGC: use_mpu_process_groups()
    PGC-->>Setup: pg_collection instance
    Setup->>State: state.pg_collection = pg_collection
    Setup->>UpdateConfig: call with pg_collection param
    UpdateConfig->>UpdateConfig: inject pg_collection into finalize_model_grads_func
    Train->>State: retrieve state.pg_collection
    Train->>State: dp_size = pg_collection.dp.size()
    Train->>State: dp_cp_group = pg_collection.dp_cp
    Train->>Train: compute batch_size with dp_size
    Train->>Train: all_reduce(loss, group=dp_cp_group)
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

The changes span three files with mixed complexity: a trivial attribute addition, moderate setup refactoring with partial injection, and higher-density distributed training logic updates. While the pattern is repetitive (parallel_state → pg_collection), the critical nature of DP group operations and batch size calculations warrants careful verification.
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
@coderabbitai review
✅ Actions performed: Review triggered.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/megatron/bridge/training/train.py (1)
574-576: Use centralized DP size for LR scheduler increment (consistency).

Mixing cfg.data_parallel_size with pg_collection risks drift. Use pg_collection.dp.size().
```diff
-    increment = get_num_microbatches() * train_config.micro_batch_size * cfg.data_parallel_size
+    increment = get_num_microbatches() * train_config.micro_batch_size * global_state.pg_collection.dp.size()
```
🧹 Nitpick comments (7)
src/megatron/bridge/training/state.py (1)
138-138: Type the new attribute and keep imports light.

Annotate pg_collection as Optional[ProcessGroupCollection] using TYPE_CHECKING to avoid runtime deps; improves IDE/type safety.
Apply:
```diff
-from typing import Any, Optional
+from typing import Any, Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from megatron.core.process_groups_config import ProcessGroupCollection
 ...
-        self.pg_collection = None
+        self.pg_collection: Optional["ProcessGroupCollection"] = None
```

src/megatron/bridge/training/train.py (3)
326-328: Guard pg_collection and avoid per-iter lookup.

Add an upfront assert and cache dp_size once to reduce overhead and fail fast if setup() wasn’t called.
```diff
@@
-    start_iteration = global_state.train_state.step
+    start_iteration = global_state.train_state.step
+    # Ensure process groups are wired via setup().
+    assert global_state.pg_collection is not None, "GlobalState.pg_collection is None; call setup() before train()."
+    dp_size_cached = global_state.pg_collection.dp.size()
@@
-        global_state.train_state.step += 1
-        dp_size = global_state.pg_collection.dp.size()
-        batch_size = dp_size * train_config.micro_batch_size * get_num_microbatches()
+        global_state.train_state.step += 1
+        batch_size = dp_size_cached * train_config.micro_batch_size * get_num_microbatches()
```
595-597: Verify dp_cp group equivalence and add a safe fallback.

Assuming pg_collection.dp_cp maps to the prior get_data_parallel_group(with_context_parallel=True). Add a fallback to parallel_state for robustness.
```diff
-    dp_cp_group = global_state.pg_collection.dp_cp
-    torch.distributed.all_reduce(val, group=dp_cp_group)
+    dp_cp_group = (
+        global_state.pg_collection.dp_cp
+        if global_state.pg_collection is not None
+        else parallel_state.get_data_parallel_group(with_context_parallel=True)
+    )
+    torch.distributed.all_reduce(val, group=dp_cp_group)
```
1068-1071: Add the same guard and reuse cached dp_size for the skip path.

Mirror the train-loop guard for correctness and consistency.
```diff
-        global_state.train_state.step += 1
-        dp_size = global_state.pg_collection.dp.size()
-        batch_size = dp_size * cfg.train.micro_batch_size * get_num_microbatches()
+        global_state.train_state.step += 1
+        assert global_state.pg_collection is not None, "GlobalState.pg_collection is None; call setup() before train()."
+        batch_size = global_state.pg_collection.dp.size() * cfg.train.micro_batch_size * get_num_microbatches()
```

src/megatron/bridge/training/setup.py (3)
29-29: Import is fine; keep type-only imports elsewhere.

This runtime import is needed here. In state.py, prefer TYPE_CHECKING to avoid a hard dependency during state construction.
149-153: Good: PGs fetched after initialize_megatron(). Add a brief debug log.

Fetching via ProcessGroupCollection.use_mpu_process_groups() post-init is correct. Consider logging ranks/sizes to aid debugging.
Example:
```python
print_rank_0(f"PGCollection: dp={pg_collection.dp.size()}, tp={pg_collection.tp.size()}, pp={pg_collection.pp.size()}")
```
276-277: New param default looks safe. Document the expectation.

Add a short docstring note that pg_collection should come from ProcessGroupCollection.use_mpu_process_groups().
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/megatron/bridge/training/setup.py (5 hunks)
- src/megatron/bridge/training/state.py (1 hunks)
- src/megatron/bridge/training/train.py (3 hunks)
🔇 Additional comments (2)
src/megatron/bridge/training/setup.py (2)
220-227: Threading pg_collection into model-config wiring is consistent.

Passing pg_collection down keeps a single source of truth.
Please confirm no remaining call sites rely on parallel_state DP size/group in model-config functions.
296-299: finalize_model_grads supports the pg_collection parameter; the code is correct.

The finalize_model_grads function in megatron-core accepts pg_collection as an optional parameter, which allows specifying explicit process groups for distributed collectives. The parameter type is megatron.core.process_groups_config.ProcessGroupCollection, and if it is not provided, Megatron uses its globally initialized process groups. The code at lines 296-299 correctly passes pg_collection via partial(), and no runtime error will occur. The current dependency constraint (megatron-core>=0.14.0a0,<0.16.0) supports this API.
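The partial() wiring this comment verifies can be illustrated with stand-ins. Here finalize_grads is a hypothetical stub standing in for megatron-core's finalize_model_grads, and FakeCollection for ProcessGroupCollection; only the binding pattern is the point.

```python
# Sketch of injecting pg_collection via functools.partial, as setup.py does.
# finalize_grads and FakeCollection are hypothetical stand-ins, not the
# real megatron-core finalize_model_grads / ProcessGroupCollection.
from functools import partial


class FakeCollection:
    """Stand-in for ProcessGroupCollection."""


def finalize_grads(model, num_tokens=None, pg_collection=None):
    # Mirrors the documented behavior: with no explicit collection,
    # the real function falls back to the globally initialized groups.
    return "explicit" if pg_collection is not None else "global"


pg_collection = FakeCollection()
# Bind the collection once; callers keep using the original signature.
finalize_model_grads_func = partial(finalize_grads, pg_collection=pg_collection)

print(finalize_model_grads_func(model=[]))  # explicit
print(finalize_grads(model=[]))             # global
```

Binding the argument once at setup time means downstream call sites need no changes, which is why the partial-injection refactor stays small.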
/ok to test 70ae249
/ok to test 97bcd25
@yaoyu-33, there was an error processing your request: See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/
/ok to test 14607ba
/ok to test 224e1a3
/ok to test 7ca7dee