feat: Enhance dataset loading efficiency with tensor parallelism#2405
Conversation
- Added a `broadcast_data_across_tp` configuration to control data loading across tensor parallel ranks, reducing I/O overhead.
- Updated `is_dataset_built_on_rank` to utilize global parallel state when no explicit process group is provided.
- Modified dataset building logic in `BlendedMegatronDatasetBuilder` to accommodate the new broadcast mechanism.
- Updated `get_batch_on_this_tp_rank` to support packed-sequence keys.
📝 Walkthrough

Adds an optional `ProcessGroupCollection` parameter to dataset building checks, introduces the `broadcast_data_across_tp` configuration flag, implements conditional batch loading across tensor parallel ranks, and broadcasts non-fixed batch metadata across TP ranks via distributed synchronization.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant TP0 as TP-Rank-0
    participant TP_Other as Other TP-Ranks
    participant DataLoader as Data Iterator
    participant Broadcast as Broadcast Operation
    TP0->>DataLoader: get_batch_from_iterator()
    DataLoader-->>TP0: batch (fixed-shape + extra metadata)
    TP0->>Broadcast: broadcast_data(fixed_tensors)
    Broadcast-->>TP_Other: receive fixed_tensors
    TP0->>Broadcast: broadcast_object_list(extras)
    Broadcast-->>TP_Other: receive extras
    TP_Other->>TP_Other: merge extras into batch
    TP0->>TP0: merge extras into batch
```
🧹 Nitpick comments (2)
src/megatron/bridge/data/utils.py (1)
41-62: Dual-path logic is sound; minor type hint guideline deviation.

The optional `pg_collection` parameter with fallback to `parallel_state` globals is a clean design that lets this function serve as both an explicit-group call and a zero-argument callable for `BlendedMegatronDatasetBuilder`.

Per coding guidelines, prefer `pg_collection: ProcessGroupCollection | None = None` over `Optional[ProcessGroupCollection]`. The rest of the file also uses `Optional`, so this is consistent within the file but deviates from the guideline.

🔧 Guideline-aligned type hint

```diff
-def is_dataset_built_on_rank(pg_collection: Optional[ProcessGroupCollection] = None) -> bool:
+def is_dataset_built_on_rank(pg_collection: ProcessGroupCollection | None = None) -> bool:
```

As per coding guidelines: "Use 'T | None' for nullable types instead of 'Optional[T]'".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/data/utils.py` around lines 41-62, update the nullable type hint on the `is_dataset_built_on_rank` function to use the union syntax instead of `Optional`: change the parameter annotation from `Optional[ProcessGroupCollection]` to `ProcessGroupCollection | None` (keeping the default `= None` and all logic unchanged) so it follows the "T | None" guideline while still referencing the same `ProcessGroupCollection` type used throughout the file.

src/megatron/bridge/training/utils/batch_utils.py (1)
152-162: `broadcast_object_list` runs unconditionally; consider skipping when extra keys are absent.

Every micro-batch pays the `broadcast_object_list` latency (pickle round-trip + collective) even when `extra` is an empty dict `{}`. For the common pre-training case without packed sequences, this is pure overhead on every step.

A simple guard could avoid the collective entirely when there is nothing to broadcast:
♻️ Proposed optimization

```diff
 # Broadcast any extra keys (e.g. packed-sequence metadata) that are not
 # covered by the fixed-shape direct broadcasts above.
-extra = {k: v for k, v in data.items() if k not in batch} if is_tp_rank0 else None
-obj_list = [extra]
-torch.distributed.broadcast_object_list(obj_list, src=tp_ranks[0], group=tp_group)
-extra = obj_list[0]
-for key, val in extra.items():
-    if isinstance(val, torch.Tensor):
-        batch[key] = val.cuda(non_blocking=True)
-    else:
-        batch[key] = val
+if is_tp_rank0:
+    extra = {k: v for k, v in data.items() if k not in batch}
+    has_extra = [bool(extra)]
+else:
+    extra = None
+    has_extra = [False]
+torch.distributed.broadcast_object_list(has_extra, src=tp_ranks[0], group=tp_group)
+if has_extra[0]:
+    obj_list = [extra]
+    torch.distributed.broadcast_object_list(obj_list, src=tp_ranks[0], group=tp_group)
+    extra = obj_list[0]
+    for key, val in extra.items():
+        if isinstance(val, torch.Tensor):
+            batch[key] = val.cuda(non_blocking=True)
+        else:
+            batch[key] = val
```
Note: this still incurs one small `broadcast_object_list` for the boolean, but avoids the heavier payload broadcast when there's nothing extra. Alternatively, if extra keys are guaranteed to be present or absent for the lifetime of a run, you could cache the decision once.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/megatron/bridge/training/utils/batch_utils.py` around lines 152 - 162, The current code always calls torch.distributed.broadcast_object_list for `extra` which wastes latency when `extra` is empty; change it to first broadcast a small boolean presence flag and only call broadcast_object_list for the heavy `extra` payload when that flag is True: compute has_extra = bool(extra) on the source rank (use the existing is_tp_rank0 and tp_ranks[0]/tp_group), broadcast the flag via a single-object list, read back the flag on all ranks, and only when the flag is True call torch.distributed.broadcast_object_list for the actual `extra` object and then merge into `batch` (using the same handling for torch.Tensor -> .cuda(non_blocking=True)); otherwise skip the heavy broadcast entirely.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@src/megatron/bridge/data/utils.py`:
- Around line 41-62: Update the nullable type hint on the
is_dataset_built_on_rank function to use the union syntax instead of Optional:
change the parameter annotation from Optional[ProcessGroupCollection] to
ProcessGroupCollection | None (keeping the default = None and all logic
unchanged) so it follows the "T | None" guideline while still referencing the
same ProcessGroupCollection type used throughout the file.
In `@src/megatron/bridge/training/utils/batch_utils.py`:
- Around line 152-162: The current code always calls
torch.distributed.broadcast_object_list for `extra` which wastes latency when
`extra` is empty; change it to first broadcast a small boolean presence flag and
only call broadcast_object_list for the heavy `extra` payload when that flag is
True: compute has_extra = bool(extra) on the source rank (use the existing
is_tp_rank0 and tp_ranks[0]/tp_group), broadcast the flag via a single-object
list, read back the flag on all ranks, and only when the flag is True call
torch.distributed.broadcast_object_list for the actual `extra` object and then
merge into `batch` (using the same handling for torch.Tensor ->
.cuda(non_blocking=True)); otherwise skip the heavy broadcast entirely.
```diff
@@ -0,0 +1,104 @@
+# Broadcast Data Across Tensor-Parallel Ranks
```
thank you for adding this documentation! I believe it'd be useful to have this be part of a general data loading page, which can later be expanded to cover different datasets supported in Megatron Bridge and how to plug in custom datasets, rather than a page specifically dedicated to TP broadcast vs. replicated load, but I'll defer to @yaoyu-33 and @cuichenx for this
```diff
@@ -0,0 +1,173 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
```
is it expected for this benchmark to be run in the test suite? or do you want this as a standalone script?
@yaoyu-33 does it make more sense to have a top-level benchmarks / microbenchmarks directory for scripts like this?
you're right. I'll move it. I didn't intend for this to be a unit test
```diff
     is_first_pp_stage=is_first,
     is_last_pp_stage=is_last,
 )
+broadcast_data = getattr(cfg.dataset, "broadcast_data_across_tp", False)
```
since the dataloader config is shared, this setting is also exposed for vlm datasets. there needs to be handling here for broadcast_data_across_tp too: https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/src/megatron/bridge/training/vlm_step.py
… loading

- Added `bench_broadcast_tp.py` to measure the performance of direct tensor broadcasting versus `broadcast_object_list` for tensor parallel data loading.
- Updated `get_batch_on_this_tp_rank` to support broadcasting configurations and optimized the handling of standard and extra keys.
- Enhanced unit tests to validate broadcasting behavior across pipeline parallel stages.
src/megatron/bridge/data/utils.py
Outdated
```diff
-    return (is_pp_first_stage(pg_collection.pp) or is_pp_last_stage(pg_collection.pp)) and (
-        pg_collection.tp.rank() == 0
+    if pg_collection is not None:
+        return (is_pp_first_stage(pg_collection.pp) or is_pp_last_stage(pg_collection.pp)) and (
```
Check whether the rank contains an MTP layer; this is required to support placing MTP layers into standalone stages (not the last PP stage)
What does this PR do?
Adds a `broadcast_data_across_tp` configuration flag that restores the TP-rank-0 broadcast data loading path removed in PR #491, allowing users on high-latency storage backends (e.g. VAST) to avoid the severe I/O contention caused by replicated loading across all TP ranks.

Changelog
- Added `broadcast_data_across_tp: bool = False` to `DataloaderConfig`, inherited by all dataset configs.
- Updated `is_dataset_built_on_rank` to support an optional `pg_collection`, falling back to global parallel state so it can be passed directly as a zero-argument callable to `BlendedMegatronDatasetBuilder`.
- Updated `pretrain_train_valid_test_datasets_provider` to conditionally pass `is_dataset_built_on_rank` (broadcast) or `lambda: True` (replicated) based on the flag.
- Updated `get_batch` in `gpt_step.py` to conditionally use `get_batch_on_this_tp_rank` (broadcast) or `get_batch_from_iterator` (replicated) based on the flag.
- Extended `get_batch_on_this_tp_rank` via `broadcast_object_list` to forward any non-standard batch keys (e.g. packed-sequence metadata) without modifying the existing direct tensor broadcast path.

GitHub Actions CI
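The flag-driven selection between the two dataset-build predicates can be sketched in a few lines. `choose_build_predicate` is a hypothetical helper name for illustration (the real selection happens inline in `pretrain_train_valid_test_datasets_provider`), and `DataloaderConfig` is reduced to the one field under discussion:

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class DataloaderConfig:
    """Simplified stand-in for the real config; only the new flag is shown."""
    broadcast_data_across_tp: bool = False


def choose_build_predicate(
    cfg: DataloaderConfig,
    is_dataset_built_on_rank: Callable[[], bool],
) -> Callable[[], bool]:
    """Pick the dataset-build predicate handed to BlendedMegatronDatasetBuilder.

    Broadcast mode: only TP-rank-0 of the first/last PP stage builds the
    dataset (the predicate decides per rank).
    Replicated mode (default): every rank builds its own copy.
    """
    if cfg.broadcast_data_across_tp:
        return is_dataset_built_on_rank
    return lambda: True
```

The default keeps PR #491's replicated behavior, so existing Lustre-backed setups are unaffected unless the flag is explicitly enabled.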
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
Additional Information
Context: PR #491 replaced the TP-rank-0 broadcast data loading path with replicated loading on all ranks. This works correctly on Lustre-backed clusters but introduces severe data loading stalls on VAST (network-attached storage) at multi-node scale due to redundant I/O amplification (8x with TP=8). On 64 nodes / 512 GPUs, step times alternate between 22s (compute) and 250-400s (data stall).
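The I/O amplification described above follows directly from the rank counts; a back-of-the-envelope check for the cited 512-GPU, TP=8 configuration:

```python
# Concurrent dataset readers hitting the storage backend
# for the configuration cited above (64 nodes x 8 GPUs, TP=8).
gpus = 512
tp = 8

replicated_readers = gpus        # post-PR-#491 path: every rank reads
broadcast_readers = gpus // tp   # flag enabled: one reader per TP group

# Redundant-read amplification the flag removes (the "8x with TP=8").
amplification = replicated_readers // broadcast_readers
```

This is why the stall scales with cluster size on network-attached storage: the replicated path multiplies reader count by TP, while compute per rank stays the same.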
Approach: Rather than a hard revert, this PR adds a `broadcast_data_across_tp` flag (default `False`) so users can select the optimal strategy for their storage backend. When enabled, only TP-rank-0 loads data and broadcasts to the other TP ranks, reducing VAST readers by a factor of TP.

Validation: Benchmarked on 8x B200 over NVLink, the direct tensor broadcast adds <0.7 ms even at 256K sequence length. The extra-keys path (`broadcast_object_list`) adds negligible overhead for the common pretraining case where no extra keys are present.

Summary by CodeRabbit
- Added a `broadcast_data_across_tp` configuration flag to enable data broadcasting across tensor-parallel ranks. When activated, data is loaded on a single tensor-parallel rank and broadcast to others, providing an alternative to per-rank data loading.