Commit 76a2940

FEAT: add a config option to turn on pin_memory in dataloader (default to True)

Ming Du committed Oct 11, 2024
1 parent ea615f3 commit 76a2940
Showing 2 changed files with 6 additions and 3 deletions.
generic_trainer/configs.py: 3 additions, 0 deletions
@@ -420,6 +420,9 @@ class TrainingConfig(Config):

     loss_tracker_params: LossTrackerParameters = dataclasses.field(default_factory=LossTrackerParameters)
     """Arguments of the loss tracker."""
 
+    pin_memory_for_dataloader: bool = True
+    """If True, dataloader will put fetched data tensor in pinned memory, which accelerates training."""
+
     automatic_mixed_precision: bool = False
     """Automatic mixed precision and gradient scaling are enabled if True."""
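A minimal usage sketch of the new field, assuming TrainingConfig behaves like
an ordinary dataclass constructible by keyword (any other required fields are
omitted here for brevity):

from generic_trainer.configs import TrainingConfig

# Hypothetical usage: disable pinned memory, e.g. for CPU-only runs where
# page-locked host buffers provide no benefit. The field defaults to True.
config = TrainingConfig(pin_memory_for_dataloader=False)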
generic_trainer/trainer.py: 3 additions, 3 deletions
@@ -512,18 +512,18 @@ def build_dataloaders(self):
                                               batch_size=self.all_proc_batch_size,
                                               collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
                                               generator=self.get_dataloader_generator(), num_workers=0,
-                                              drop_last=False)
+                                              drop_last=False, pin_memory=self.configs.pin_memory_for_dataloader)
         self.validation_dataloader = DataLoader(self.validation_dataset, shuffle=True,
                                                 batch_size=self.all_proc_batch_size,
                                                 collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
                                                 generator=self.get_dataloader_generator(), num_workers=0,
-                                                drop_last=False)
+                                                drop_last=False, pin_memory=self.configs.pin_memory_for_dataloader)
         if self.test_dataset is not None:
             self.test_dataloader = DataLoader(self.test_dataset, shuffle=True,
                                               batch_size=self.all_proc_batch_size,
                                               collate_fn=lambda x: x, worker_init_fn=self.get_worker_seed_func(),
                                               generator=self.get_dataloader_generator(), num_workers=0,
-                                              drop_last=False)
+                                              drop_last=False, pin_memory=self.configs.pin_memory_for_dataloader)
 
     def run_training(self):
         for self.current_epoch in range(self.current_epoch, self.num_epochs):
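Background on why this helps: pinned (page-locked) host memory enables
asynchronous host-to-device copies in CUDA. A self-contained sketch of the
pattern this option enables, independent of this repository (the dataset and
training loop below are illustrative, not taken from this codebase):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative toy dataset, not from this repository.
dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 2, (1024,)))

# pin_memory=True asks the DataLoader to place fetched batches in
# page-locked host memory, which is what this commit makes configurable.
loader = DataLoader(dataset, batch_size=32, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for x, y in loader:
    # Copies from pinned memory can run asynchronously when non_blocking=True,
    # overlapping the transfer with GPU computation.
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)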
