diff --git a/olmo/data/__init__.py b/olmo/data/__init__.py index b81e99f79..45b94e3f5 100644 --- a/olmo/data/__init__.py +++ b/olmo/data/__init__.py @@ -81,12 +81,21 @@ def build_eval_dataloader( ) -def build_train_dataloader(train_config: TrainConfig, world_size: Optional[int] = None) -> DataLoader: +def build_train_dataloader( + train_config: TrainConfig, + *, + world_size: Optional[int] = None, + rank: Optional[int] = None, + fs_local_rank: Optional[int] = None, + include_instance_metadata: bool = False, +) -> DataLoader: assert train_config.device_train_batch_size is not None collator = DataCollator( pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id ) - dataset = build_memmap_dataset(train_config, train_config.data, include_instance_metadata=False) + dataset = build_memmap_dataset( + train_config, train_config.data, include_instance_metadata=include_instance_metadata + ) work_dir = Path(train_config.save_folder) / "train_data" if get_global_rank() == 0: if work_dir.is_dir() and not train_config.save_overwrite: @@ -105,6 +114,8 @@ def build_train_dataloader(train_config: TrainConfig, world_size: Optional[int] shuffle=True, drop_last=train_config.data.drop_last, world_size=world_size, + rank=rank, + fs_local_rank=fs_local_rank, work_dir=work_dir, ), batch_size=train_config.device_train_batch_size,