
M4: add pg_collection into setup and wire into train.py#1062

Merged
yaoyu-33 merged 10 commits into main from m4/0_prepare
Nov 6, 2025
Conversation

@yaoyu-33 (Contributor) commented Oct 23, 2025

Summary by CodeRabbit

  • Refactor
    • Updated distributed training infrastructure to use centralized process group management for batch size computation.
    • Modified distributed communication group configuration for loss aggregation during training.

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@copy-pr-bot (bot) commented Oct 23, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai (Contributor, bot) commented Oct 23, 2025

Walkthrough

The changes introduce a ProcessGroupCollection structure to replace direct parallel_state calls for data-parallel operations. A new pg_collection attribute is added to GlobalState and initialized during setup from ProcessGroupCollection.use_mpu_process_groups(). Training code now retrieves DP size and communication groups from this centralized collection.
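The pattern described here can be sketched with plain-Python stand-ins (a minimal sketch: `FakeGroup`, `FakePGCollection`, and `GlobalStateSketch` are simplified stand-ins for illustration, not the actual megatron-core or Megatron Bridge classes):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeGroup:
    """Stand-in for a torch.distributed process group."""
    world_size: int

    def size(self) -> int:
        return self.world_size


@dataclass
class FakePGCollection:
    """Stand-in for megatron-core's ProcessGroupCollection."""
    dp: FakeGroup      # data-parallel group
    dp_cp: FakeGroup   # data-parallel + context-parallel group


@dataclass
class GlobalStateSketch:
    """Stand-in for GlobalState: pg_collection starts as None until setup()."""
    pg_collection: Optional[FakePGCollection] = None


# setup(): fetch the process groups once and store them on the global state.
state = GlobalStateSketch()
state.pg_collection = FakePGCollection(dp=FakeGroup(8), dp_cp=FakeGroup(16))

# train(): read the DP size from the centralized collection instead of
# calling parallel_state directly.
micro_batch_size, num_microbatches = 2, 4
batch_size = state.pg_collection.dp.size() * micro_batch_size * num_microbatches
print(batch_size)  # 64
```

The point of the refactor is that `train()` only reads `state.pg_collection`; it no longer depends on the module-level `parallel_state` globals directly.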

Changes

State management: src/megatron/bridge/training/state.py
  • Added a new pg_collection attribute to the GlobalState class, initialized to None.
Setup initialization: src/megatron/bridge/training/setup.py
  • Introduced pg_collection retrieval from ProcessGroupCollection.use_mpu_process_groups() and exposed it via state.pg_collection. Extended _update_model_config_funcs to accept an optional pg_collection parameter. Updated the finalize_model_grads_func wiring to inject pg_collection via partial().
Training execution: src/megatron/bridge/training/train.py
  • Replaced parallel_state.get_data_parallel_world_size() calls with global_state.pg_collection.dp.size(). Replaced parallel_state.get_data_parallel_group(with_context_parallel=True) with global_state.pg_collection.dp_cp for the distributed all-reduce operations in loss aggregation and iteration-skip logic.

Sequence Diagram

sequenceDiagram
    participant Setup as setup()
    participant PGC as ProcessGroupCollection
    participant State as GlobalState
    participant UpdateConfig as _update_model_config_funcs()
    participant Train as train()
    
    Setup->>PGC: use_mpu_process_groups()
    PGC-->>Setup: pg_collection instance
    Setup->>State: state.pg_collection = pg_collection
    Setup->>UpdateConfig: call with pg_collection param
    UpdateConfig->>UpdateConfig: inject pg_collection into finalize_model_grads_func
    Train->>State: retrieve state.pg_collection
    Train->>State: dp_size = pg_collection.dp.size()
    Train->>State: dp_cp_group = pg_collection.dp_cp
    Train->>Train: compute batch_size with dp_size
    Train->>Train: all_reduce(loss, group=dp_cp_group)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

The changes span three files with mixed complexity: trivial attribute addition, moderate setup refactoring with partial injection, and higher-density distributed training logic updates. While the pattern is repetitive (parallel_state → pg_collection), the critical nature of DP group operations and batch size calculations warrants careful verification.

Poem

🐰 From parallel_state calls we hop away,
To pg_collection's brighter day,
Process groups now centralized and neat,
Distributed training made complete! ✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The pull request title "M4 PR 0: add pg_collection into global state and setup call" accurately describes the core changes in the changeset. The modifications add a new pg_collection attribute to GlobalState, extend the setup infrastructure to retrieve and pass this collection through _update_model_config_funcs, and wire it into the model configuration via finalize_model_grads_func. While train.py also contains changes that replace parallel_state calls with pg_collection accessors, these represent the practical application of the infrastructure changes rather than the primary focus of the PR. The title is clear, concise, and specific enough for teammates to understand the primary infrastructure change.
  • Docstring Coverage ✅ Passed: Docstring coverage is 100.00%, which meets the required threshold of 80.00%.

@yaoyu-33 changed the title from "M4 PR0: add pg_collection into global state and setup call." to "M4 PR 0: add pg_collection into global state and setup call." on Oct 23, 2025
@yaoyu-33 (Contributor Author) commented:

@coderabbitai review

@coderabbitai (Contributor, bot) commented Oct 23, 2025

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai (Contributor, bot) left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/megatron/bridge/training/train.py (1)

574-576: Use centralized DP size for LR scheduler increment (consistency).

Mixing cfg.data_parallel_size with pg_collection risks drift. Use pg_collection.dp.size().

-        increment = get_num_microbatches() * train_config.micro_batch_size * cfg.data_parallel_size
+        increment = get_num_microbatches() * train_config.micro_batch_size * global_state.pg_collection.dp.size()
🧹 Nitpick comments (7)
src/megatron/bridge/training/state.py (1)

138-138: Type the new attribute and keep imports light.

Annotate pg_collection as Optional[ProcessGroupCollection] using TYPE_CHECKING to avoid runtime deps; improves IDE/type safety.

Apply:

-from typing import Any, Optional
+from typing import Any, Optional, TYPE_CHECKING
+if TYPE_CHECKING:
+    from megatron.core.process_groups_config import ProcessGroupCollection
...
-        self.pg_collection = None
+        self.pg_collection: Optional["ProcessGroupCollection"] = None
src/megatron/bridge/training/train.py (3)

326-328: Guard pg_collection and avoid per-iter lookup.

Add an upfront assert and cache dp_size once to reduce overhead and fail fast if setup() wasn’t called.

@@
-    start_iteration = global_state.train_state.step
+    start_iteration = global_state.train_state.step
+    # Ensure process groups are wired via setup().
+    assert global_state.pg_collection is not None, "GlobalState.pg_collection is None; call setup() before train()."
+    dp_size_cached = global_state.pg_collection.dp.size()
@@
-        global_state.train_state.step += 1
-        dp_size = global_state.pg_collection.dp.size()
-        batch_size = dp_size * train_config.micro_batch_size * get_num_microbatches()
+        global_state.train_state.step += 1
+        batch_size = dp_size_cached * train_config.micro_batch_size * get_num_microbatches()

595-597: Verify dp_cp group equivalence and add a safe fallback.

This assumes pg_collection.dp_cp maps to the prior get_data_parallel_group(with_context_parallel=True) group. Add a fallback to parallel_state for robustness.

-                dp_cp_group = global_state.pg_collection.dp_cp
-                torch.distributed.all_reduce(val, group=dp_cp_group)
+                dp_cp_group = (
+                    global_state.pg_collection.dp_cp
+                    if global_state.pg_collection is not None
+                    else parallel_state.get_data_parallel_group(with_context_parallel=True)
+                )
+                torch.distributed.all_reduce(val, group=dp_cp_group)

1068-1071: Add the same guard and reuse cached dp_size for skip path.

Mirror the train-loop guard for correctness and consistency.

-    global_state.train_state.step += 1
-    dp_size = global_state.pg_collection.dp.size()
-    batch_size = dp_size * cfg.train.micro_batch_size * get_num_microbatches()
+    global_state.train_state.step += 1
+    assert global_state.pg_collection is not None, "GlobalState.pg_collection is None; call setup() before train()."
+    batch_size = global_state.pg_collection.dp.size() * cfg.train.micro_batch_size * get_num_microbatches()
src/megatron/bridge/training/setup.py (3)

29-29: Import is fine; keep type-only imports elsewhere.

This runtime import is needed here. In state.py, prefer TYPE_CHECKING to avoid hard dep during state construction.


149-153: Good: PGs fetched after initialize_megatron(). Add brief debug log.

Fetching via ProcessGroupCollection.use_mpu_process_groups() post-init is correct. Consider logging ranks/sizes to aid debugging.

Example:

print_rank_0(f"PGCollection: dp={pg_collection.dp.size()}, tp={pg_collection.tp.size()}, pp={pg_collection.pp.size()}")

276-277: New param default looks safe. Document expectation.

Add a short docstring note that pg_collection should come from ProcessGroupCollection.use_mpu_process_groups().

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between abf5a02 and 3f7ff31.

📒 Files selected for processing (3)
  • src/megatron/bridge/training/setup.py (5 hunks)
  • src/megatron/bridge/training/state.py (1 hunks)
  • src/megatron/bridge/training/train.py (3 hunks)
🔇 Additional comments (2)
src/megatron/bridge/training/setup.py (2)

220-227: Threading pg_collection into model-config wiring is consistent.

Passing pg_collection down keeps a single source of truth.

Please confirm no remaining call sites rely on parallel_state DP size/group in model-config functions.


296-299: finalize_model_grads supports pg_collection parameter—code is correct.

The finalize_model_grads function in megatron-core accepts pg_collection as an optional parameter, which allows specifying explicit process groups for distributed collectives. The parameter type is megatron.core.process_groups_config.ProcessGroupCollection, and if not provided, Megatron uses its global initialized process groups. The code at lines 296-299 correctly passes pg_collection via partial(), and no runtime error will occur. The current dependency constraint (megatron-core>=0.14.0a0,<0.16.0) supports this API.
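The partial() injection described above can be illustrated with a stub (a sketch only: `finalize_model_grads_stub` mirrors the optional-pg_collection shape the review describes, and is not the real megatron-core function):

```python
from functools import partial


def finalize_model_grads_stub(model, num_tokens=None, pg_collection=None):
    """Return which process-group source would be used.

    Mirrors the described behavior: if pg_collection is not provided,
    the global initialized process groups are used instead.
    """
    if pg_collection is not None:
        return "explicit pg_collection"
    return "global parallel_state"


# setup() binds the collection once; downstream callers never pass it again.
pg_collection = object()  # stand-in for ProcessGroupCollection.use_mpu_process_groups()
finalize_func = partial(finalize_model_grads_stub, pg_collection=pg_collection)

print(finalize_func(model=[]))        # uses the explicit pg_collection
print(finalize_model_grads_stub([]))  # falls back to the global groups
```

This keeps a single source of truth: the collection is resolved once in setup() and baked into the callback, so train-loop code cannot accidentally mix it with stale globals.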

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33 (Contributor Author) commented:

/ok to test 70ae249

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33 (Contributor Author) commented:

/ok to test 97bcd25

@copy-pr-bot (bot) commented Oct 30, 2025

/ok to test 97bcd25

@yaoyu-33, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

@yaoyu-33 changed the title from "M4 PR 0: add pg_collection into global state and setup call." to "M4: add pg_collection into setup and wire into train.py" on Oct 30, 2025
@yaoyu-33 (Contributor Author) commented:

/ok to test 14607ba

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33 (Contributor Author) commented:

/ok to test 224e1a3

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33 (Contributor Author) commented:

/ok to test 7ca7dee

