Skip to content

Commit 72159ae

Browse files
authored
Merge pull request #614 from allenai/shanea/pass-include-instance-metadata
Make include_instance_metadata a kwarg of build_train_dataloader
2 parents c2cedbc + 6aad427 commit 72159ae

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

olmo/data/__init__.py

+13 -2
Original file line number | Diff line number | Diff line change
@@ -81,12 +81,21 @@ def build_eval_dataloader(
81 81
)
82 82

83 83

84-
def build_train_dataloader(train_config: TrainConfig, world_size: Optional[int] = None) -> DataLoader:
84+
def build_train_dataloader(
85+
train_config: TrainConfig,
86+
*,
87+
world_size: Optional[int] = None,
88+
rank: Optional[int] = None,
89+
fs_local_rank: Optional[int] = None,
90+
include_instance_metadata: bool = False,
91+
) -> DataLoader:
85 92
assert train_config.device_train_batch_size is not None
86 93
collator = DataCollator(
87 94
pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id
88 95
)
89-
dataset = build_memmap_dataset(train_config, train_config.data, include_instance_metadata=False)
96+
dataset = build_memmap_dataset(
97+
train_config, train_config.data, include_instance_metadata=include_instance_metadata
98+
)
90 99
work_dir = Path(train_config.save_folder) / "train_data"
91 100
if get_global_rank() == 0:
92 101
if work_dir.is_dir() and not train_config.save_overwrite:
@@ -105,6 +114,8 @@ def build_train_dataloader(train_config: TrainConfig, world_size: Optional[int]
105 114
shuffle=True,
106 115
drop_last=train_config.data.drop_last,
107 116
world_size=world_size,
117+
rank=rank,
118+
fs_local_rank=fs_local_rank,
108 119
work_dir=work_dir,
109 120
),
110 121
batch_size=train_config.device_train_batch_size,

0 commit comments

Comments
 (0)