diff --git a/src/megatron/bridge/data/utils.py b/src/megatron/bridge/data/utils.py
index 2973553c02..fb5d2b8383 100644
--- a/src/megatron/bridge/data/utils.py
+++ b/src/megatron/bridge/data/utils.py
@@ -69,8 +69,9 @@ def pretrain_train_valid_test_datasets_provider(

     print_rank_0("> building train, validation, and test datasets for GPT ...")

+    # Build the dataset on all ranks for TP-replicated loading
     train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
-        dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, dataset_config
+        dataset_type, train_val_test_num_samples, lambda: True, dataset_config
     ).build()

     print_rank_0("> finished creating GPT datasets ...")
diff --git a/src/megatron/bridge/training/gpt_step.py b/src/megatron/bridge/training/gpt_step.py
index 2cdacdee5c..5d0db05e48 100644
--- a/src/megatron/bridge/training/gpt_step.py
+++ b/src/megatron/bridge/training/gpt_step.py
@@ -269,14 +269,7 @@ def get_batch(
     if (not parallel_state.is_pipeline_first_stage()) and (not parallel_state.is_pipeline_last_stage()):
         return None, None, None, None, None, None, None, None

-    if isinstance(cfg.dataset, FinetuningDatasetConfig):
-        batch = get_batch_from_iterator(data_iterator, use_mtp)
-    else:
-        # get batches based on the TP rank you are on
-        batch = get_batch_on_this_tp_rank(data_iterator, cfg, use_mtp)
-        batch["cu_seqlens"] = None
-        batch["cu_seqlens_argmin"] = None
-        batch["max_seqlen"] = None
+    batch = get_batch_from_iterator(data_iterator, use_mtp)

     # slice batch along sequence dimension for context parallelism
     batch = get_batch_on_this_cp_rank(batch)
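
The two hunks work together: the third argument to BlendedMegatronDatasetBuilder acts as a per-rank "should this rank build the dataset?" predicate, so swapping is_dataset_built_on_rank for lambda: True makes every rank construct its own copy of the train/valid/test datasets. Once every tensor-parallel rank holds the data locally, get_batch no longer needs the get_batch_on_this_tp_rank broadcast path (or its cu_seqlens placeholders) and can always pull batches via get_batch_from_iterator. Below is a minimal, self-contained sketch of that predicate pattern; ToyDatasetBuilder is a hypothetical stand-in for illustration only, not the Megatron Bridge API.

    # Sketch of a builder gated by a per-rank predicate (hypothetical ToyDatasetBuilder,
    # not the Megatron Bridge implementation).
    from typing import Callable, List, Optional


    class ToyDatasetBuilder:
        """Builds datasets only on ranks where the predicate returns True."""

        def __init__(self, sizes: List[int], is_built_on_rank: Callable[[], bool]):
            self.sizes = sizes
            self.is_built_on_rank = is_built_on_rank

        def build(self) -> Optional[List[range]]:
            if not self.is_built_on_rank():
                # Ranks that skip the build must get batches some other way,
                # e.g. a broadcast from a rank that did build.
                return None
            return [range(n) for n in self.sizes]


    if __name__ == "__main__":
        sizes = [10, 2, 2]  # illustrative train / valid / test sample counts

        # Old-style gating: pretend this is a rank whose predicate is False, so nothing is built here.
        print(ToyDatasetBuilder(sizes, lambda: False).build())  # -> None

        # Behaviour after the diff: `lambda: True` builds on every rank, so each
        # rank reads batches straight from its own data iterator.
        print(ToyDatasetBuilder(sizes, lambda: True).build())   # -> [range(0, 10), range(0, 2), range(0, 2)]

The presumable trade-off is that each rank now pays the cost of building its own dataset indices, in exchange for dropping the TP-rank-0 broadcast and the FinetuningDatasetConfig special case in get_batch.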