
feat: Enhance dataset loading efficiency with tensor parallelism#2405

Open
izikgo wants to merge 7 commits into main from izikg/broadcast-data-across-tp

Conversation


@izikgo izikgo commented Feb 17, 2026

What does this PR do?

Adds a broadcast_data_across_tp configuration flag that restores the TP-rank-0 broadcast data loading path removed in PR #491, allowing users on high-latency storage backends (e.g. VAST) to avoid the severe I/O contention caused by replicated loading across all TP ranks.

Changelog

  • Added broadcast_data_across_tp: bool = False to DataloaderConfig, inherited by all dataset configs.
  • Updated is_dataset_built_on_rank to support optional pg_collection, falling back to global parallel state so it can be passed directly as a zero-argument callable to BlendedMegatronDatasetBuilder.
  • Modified pretrain_train_valid_test_datasets_provider to conditionally pass is_dataset_built_on_rank (broadcast) or lambda: True (replicated) based on the flag.
  • Modified get_batch in gpt_step.py to conditionally use get_batch_on_this_tp_rank (broadcast) or get_batch_from_iterator (replicated) based on the flag.
  • Appended extra-keys broadcast to get_batch_on_this_tp_rank via broadcast_object_list to forward any non-standard batch keys (e.g. packed-sequence metadata) without modifying the existing direct tensor broadcast path.

GitHub Actions CI

See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

Additional Information

Context: PR #491 replaced the TP-rank-0 broadcast data loading path with replicated loading on all ranks. This works correctly on Lustre-backed clusters but introduces severe data loading stalls on VAST (network-attached storage) at multi-node scale due to redundant I/O amplification (8x with TP=8). On 64 nodes / 512 GPUs, step times alternate between 22s (compute) and 250-400s (data stall).
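
The amplification claimed above can be checked with back-of-envelope arithmetic for this setup (64 nodes x 8 GPUs = 512 GPUs, TP=8); the variable names are illustrative only:

```python
# Reader counts against the shared storage backend under each strategy.
gpus = 512
tp = 8

replicated_readers = gpus        # every rank opens the dataset itself
broadcast_readers = gpus // tp   # only TP-rank-0 of each TP group reads
amplification = replicated_readers // broadcast_readers  # factor of TP
```

512 concurrent readers drop to 64, an 8x reduction matching the TP degree.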

Approach: Rather than a hard revert, this PR adds a broadcast_data_across_tp flag (default False) so users can select the optimal strategy for their storage backend. When enabled, only TP-rank-0 loads data and broadcasts to other TP ranks, reducing VAST readers by a factor of TP.
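
The opt-in wiring described above might look like this sketch; the config class and predicate helper are simplified, and only the field name `broadcast_data_across_tp` and the `lambda: True` replicated path come from the PR:

```python
from dataclasses import dataclass


@dataclass
class DataloaderConfig:
    # Default False keeps replicated loading; opt in on high-latency
    # storage backends such as VAST.
    broadcast_data_across_tp: bool = False


def _built_on_this_rank() -> bool:
    # Stand-in for is_dataset_built_on_rank (TP-rank-0 of edge PP stages);
    # hard-coded False here to model a non-zero TP rank.
    return False


def choose_build_predicate(cfg: DataloaderConfig):
    """Return the zero-argument callable handed to the dataset builder."""
    if cfg.broadcast_data_across_tp:
        return _built_on_this_rank   # only TP-rank-0 builds, then broadcasts
    return lambda: True              # every rank builds its own copy


cfg = DataloaderConfig(broadcast_data_across_tp=True)
predicate = choose_build_predicate(cfg)
```

Keeping the default at False preserves existing behavior for Lustre-backed clusters that benefit from replicated loading.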

Validation: In benchmarks on 8x B200 over NVLink, the direct tensor broadcast adds <0.7 ms even at 256K sequence length. The extra-keys path (broadcast_object_list) adds negligible overhead in the common pretraining case where no extra keys are present.

Summary by CodeRabbit

  • New Features
    • Added broadcast_data_across_tp configuration flag to enable data broadcasting across tensor-parallel ranks. When activated, data is loaded on a single tensor-parallel rank and broadcast to others, providing an alternative to per-rank data loading.
    • Improved batch synchronization to broadcast additional metadata and non-fixed keys across all tensor-parallel ranks, ensuring consistent data distribution during training.

- Added `broadcast_data_across_tp` configuration to control data loading across tensor parallel ranks, reducing I/O overhead.
- Updated `is_dataset_built_on_rank` to utilize global parallel state when no explicit process group is provided.
- Modified dataset building logic in `BlendedMegatronDatasetBuilder` to accommodate the new broadcast mechanism.
- Updated `get_batch_on_this_tp_rank` to support packed sequences keys.
@izikgo izikgo requested a review from ananthsub February 17, 2026 16:49

copy-pr-bot bot commented Feb 17, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Feb 17, 2026

📝 Walkthrough

Adds optional ProcessGroupCollection parameter to dataset building checks, introduces broadcast_data_across_tp configuration flag, implements conditional batch loading across tensor parallel ranks, and broadcasts non-fixed batch metadata across TP ranks via distributed synchronization.

Changes

  • Configuration (src/megatron/bridge/training/config.py): Added new broadcast_data_across_tp boolean flag to DataloaderConfig to control whether TP-rank-0 loads and broadcasts data across tensor-parallel ranks.
  • Dataset Building (src/megatron/bridge/data/utils.py): Updated is_dataset_built_on_rank to accept an optional ProcessGroupCollection parameter with fallback to global parallel_state when not provided; added conditional dataset building logic based on the broadcast_data_across_tp configuration.
  • Batch Loading & Synchronization (src/megatron/bridge/training/gpt_step.py, src/megatron/bridge/training/utils/batch_utils.py): Implemented conditional batch loading where TP-rank-0 loads and broadcasts to other ranks when broadcast_data_across_tp is enabled; added broadcasting of non-fixed batch metadata across TP ranks via distributed synchronization.

Sequence Diagram(s)

sequenceDiagram
    participant TP0 as TP-Rank-0
    participant TP_Other as Other TP-Ranks
    participant DataLoader as Data Iterator
    participant Broadcast as Broadcast Operation

    TP0->>DataLoader: get_batch_from_iterator()
    DataLoader-->>TP0: batch (fixed-shape + extra metadata)
    
    TP0->>Broadcast: broadcast_data(fixed_tensors)
    Broadcast-->>TP_Other: receive fixed_tensors
    
    TP0->>Broadcast: broadcast_object_list(extras)
    Broadcast-->>TP_Other: receive extras
    
    TP_Other->>TP_Other: merge extras into batch
    
    TP0->>TP0: merge extras into batch

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • ananthsub
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Test Results For Major Changes ⚠️ Warning: PR introduces major data loading changes affecting performance and behavior across tensor-parallel ranks, but test results and validation information are not documented in the PR description or codebase. Resolution: Add performance benchmarks with before-and-after numbers, test files validating new functionality, regression testing results, and testing methodology documentation to the PR.

✅ Passed checks (3 passed)

  • Description Check ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed: The title accurately describes the main change: adding a configurable tensor parallelism feature to optimize dataset loading efficiency. It is concise, clear, and directly related to the primary objective of the PR.
  • Docstring Coverage ✅ Passed: Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.


@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (2)
src/megatron/bridge/data/utils.py (1)

41-62: Dual-path logic is sound; minor type hint guideline deviation.

The optional pg_collection parameter with fallback to parallel_state globals is a clean design that lets this function serve as both an explicit-group call and a zero-argument callable for BlendedMegatronDatasetBuilder.

Per coding guidelines, prefer pg_collection: ProcessGroupCollection | None = None over Optional[ProcessGroupCollection]. The rest of the file also uses Optional, so this is consistent within the file but deviates from the guideline.

🔧 Guideline-aligned type hint
-def is_dataset_built_on_rank(pg_collection: Optional[ProcessGroupCollection] = None) -> bool:
+def is_dataset_built_on_rank(pg_collection: ProcessGroupCollection | None = None) -> bool:

As per coding guidelines: "Use 'T | None' for nullable types instead of 'Optional[T]'".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/data/utils.py` around lines 41 - 62, Update the nullable
type hint on the is_dataset_built_on_rank function to use the union syntax
instead of Optional: change the parameter annotation from
Optional[ProcessGroupCollection] to ProcessGroupCollection | None (keeping the
default = None and all logic unchanged) so it follows the "T | None" guideline
while still referencing the same ProcessGroupCollection type used throughout the
file.
src/megatron/bridge/training/utils/batch_utils.py (1)

152-162: broadcast_object_list runs unconditionally — consider skipping when extra keys are absent.

Every micro-batch pays the broadcast_object_list latency (pickle round-trip + collective) even when extra is an empty dict {}. For the common pre-training case without packed sequences, this is pure overhead on every step.

A simple guard could avoid the collective entirely when there is nothing to broadcast:

♻️ Proposed optimization
     # Broadcast any extra keys (e.g. packed-sequence metadata) that are not
     # covered by the fixed-shape direct broadcasts above.
-    extra = {k: v for k, v in data.items() if k not in batch} if is_tp_rank0 else None
-    obj_list = [extra]
-    torch.distributed.broadcast_object_list(obj_list, src=tp_ranks[0], group=tp_group)
-    extra = obj_list[0]
-    for key, val in extra.items():
-        if isinstance(val, torch.Tensor):
-            batch[key] = val.cuda(non_blocking=True)
-        else:
-            batch[key] = val
+    if is_tp_rank0:
+        extra = {k: v for k, v in data.items() if k not in batch}
+        has_extra = [bool(extra)]
+    else:
+        extra = None
+        has_extra = [False]
+    torch.distributed.broadcast_object_list(has_extra, src=tp_ranks[0], group=tp_group)
+    if has_extra[0]:
+        obj_list = [extra]
+        torch.distributed.broadcast_object_list(obj_list, src=tp_ranks[0], group=tp_group)
+        extra = obj_list[0]
+        for key, val in extra.items():
+            if isinstance(val, torch.Tensor):
+                batch[key] = val.cuda(non_blocking=True)
+            else:
+                batch[key] = val

Note: this still incurs one small broadcast_object_list for the boolean, but avoids the heavier payload broadcast when there's nothing extra. Alternatively, if extra keys are guaranteed to be present/absent for the lifetime of a run, you could cache the decision once.
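
The "cache the decision once" alternative mentioned at the end could be sketched as below; `broadcast_fn` stands in for torch.distributed.broadcast_object_list, the tensor `.cuda()` handling from the PR is omitted, and the assumption is that extra keys are present or absent for the lifetime of the run:

```python
_HAS_EXTRA = None  # decided on the first micro-batch, then reused


def merge_extras(batch, data, is_tp_rank0, broadcast_fn):
    """Broadcast non-fixed keys only when the first step showed they exist."""
    global _HAS_EXTRA
    extra = {k: v for k, v in data.items() if k not in batch} if is_tp_rank0 else None
    if _HAS_EXTRA is None:
        flag = [bool(extra)]
        broadcast_fn(flag)       # tiny collective, first step only
        _HAS_EXTRA = flag[0]
    if _HAS_EXTRA:
        obj_list = [extra]
        broadcast_fn(obj_list)   # heavy pickled payload only when actually needed
        batch.update(obj_list[0])
    return batch
```

After the first step, runs without extra keys pay no collective at all on this path, at the cost of assuming key presence is stable.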

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/training/utils/batch_utils.py` around lines 152 - 162,
The current code always calls torch.distributed.broadcast_object_list for
`extra` which wastes latency when `extra` is empty; change it to first broadcast
a small boolean presence flag and only call broadcast_object_list for the heavy
`extra` payload when that flag is True: compute has_extra = bool(extra) on the
source rank (use the existing is_tp_rank0 and tp_ranks[0]/tp_group), broadcast
the flag via a single-object list, read back the flag on all ranks, and only
when the flag is True call torch.distributed.broadcast_object_list for the
actual `extra` object and then merge into `batch` (using the same handling for
torch.Tensor -> .cuda(non_blocking=True)); otherwise skip the heavy broadcast
entirely.

@@ -0,0 +1,104 @@
# Broadcast Data Across Tensor-Parallel Ranks
Contributor:

thank you for adding this documentation! I believe it'd be useful to have this be part of a general data loading page, which can later be expanded to cover the different datasets supported in Megatron Bridge and how to plug in custom datasets, rather than a page specifically dedicated to TP broadcast vs. replicated loading, but I'll defer to @yaoyu-33 and @cuichenx for this

@@ -0,0 +1,173 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Contributor:

is it expected for this benchmark to be run in the test suite? or do you want this as a standalone script?

@yaoyu-33 does it make more sense to have a top-level benchmarks / microbenchmarks directory for scripts like this?

Contributor (Author):

you're right, I'll move it. I didn't intend for this to be a unit test

is_first_pp_stage=is_first,
is_last_pp_stage=is_last,
)
broadcast_data = getattr(cfg.dataset, "broadcast_data_across_tp", False)
Contributor (Author):

added

izikgo and others added 2 commits February 23, 2026 11:21
… loading

- Added `bench_broadcast_tp.py` to measure the performance of direct tensor broadcasting versus `broadcast_object_list` for tensor parallel data loading.
- Updated `get_batch_on_this_tp_rank` to support broadcasting configurations and optimized the handling of standard and extra keys.
- Enhanced unit tests to validate broadcasting behavior across pipeline parallel stages.
if pg_collection is not None:
    return (is_pp_first_stage(pg_collection.pp) or is_pp_last_stage(pg_collection.pp)) and (
        pg_collection.tp.rank() == 0
    )


Check if rank contains MTP layer, required to support placing MTP layers into standalone stages (Not the last PP stage)

https://github.com/NVIDIA/Megatron-LM/blob/3d1a4ba71ecc49f1a0c9480c90f819d2b00f9915/pretrain_gpt.py#L209

Contributor (Author):

added
