@@ -81,12 +81,21 @@ def build_eval_dataloader(
     )
 
 
-def build_train_dataloader(train_config: TrainConfig, world_size: Optional[int] = None) -> DataLoader:
+def build_train_dataloader(
+    train_config: TrainConfig,
+    *,
+    world_size: Optional[int] = None,
+    rank: Optional[int] = None,
+    fs_local_rank: Optional[int] = None,
+    include_instance_metadata: bool = False,
+) -> DataLoader:
     assert train_config.device_train_batch_size is not None
     collator = DataCollator(
         pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id
     )
-    dataset = build_memmap_dataset(train_config, train_config.data, include_instance_metadata=False)
+    dataset = build_memmap_dataset(
+        train_config, train_config.data, include_instance_metadata=include_instance_metadata
+    )
     work_dir = Path(train_config.save_folder) / "train_data"
     if get_global_rank() == 0:
         if work_dir.is_dir() and not train_config.save_overwrite:
@@ -105,6 +114,8 @@ def build_train_dataloader(train_config: TrainConfig, world_size: Optional[int]
             shuffle=True,
             drop_last=train_config.data.drop_last,
             world_size=world_size,
+            rank=rank,
+            fs_local_rank=fs_local_rank,
             work_dir=work_dir,
         ),
         batch_size=train_config.device_train_batch_size,
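
For reference, a minimal sketch of how a caller might use the updated keyword-only signature. The import paths, config file, and `LOCAL_RANK` lookup below are illustrative assumptions, not part of this change; only `build_train_dataloader` and its parameters come from the diff itself.

import os

import torch.distributed as dist

from olmo.config import TrainConfig  # assumed import path
from olmo.data import build_train_dataloader  # assumed import path

cfg = TrainConfig.load("configs/train.yaml")  # hypothetical config file

# Existing callers are unaffected: the new parameters are keyword-only and
# default to None / False, matching the previous hard-coded behavior.
loader = build_train_dataloader(cfg)

# Distributed callers can now pass ranks explicitly rather than relying on
# process-global state, and can opt in to per-instance metadata.
loader = build_train_dataloader(
    cfg,
    world_size=dist.get_world_size(),
    rank=dist.get_rank(),
    fs_local_rank=int(os.environ.get("LOCAL_RANK", "0")),  # assumption: rank local to this node's filesystem
    include_instance_metadata=True,  # previously always False
)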