🐛 Bug
Lightning takes care of calling set_epoch for custom DataLoader Samplers here. However, a DataLoader might use a custom BatchSampler instead of a Sampler, and Lightning does not call set_epoch for custom BatchSamplers. Calling set_epoch is important for proper seeding in distributed environments.
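For context, a custom BatchSampler typically reseeds its shuffling from the epoch passed to set_epoch, so if set_epoch is never called, every epoch reproduces the same batch order. The sketch below is a hypothetical minimal example (ShuffledBatchSampler is illustrative and not part of Lightning or this issue):

import torch
from torch.utils.data import BatchSampler, Sampler


class ShuffledBatchSampler(BatchSampler):
    """Hypothetical BatchSampler whose shuffling depends on the epoch set via set_epoch."""

    def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool):
        super().__init__(sampler, batch_size, drop_last)
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Lightning calls this on Samplers each epoch, but (per this issue) not on BatchSamplers.
        self.epoch = epoch

    def __iter__(self):
        # Epoch-dependent seed -> a different permutation of batches each epoch.
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(len(self.sampler), generator=g).tolist()
        batch = []
        for idx in indices:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch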
Expected behavior
Lightning should call set_epoch for BatchSamplers to match its behavior for Samplers.
Lightning could use the DataLoader's index_sampler property to retrieve the Sampler or BatchSampler that is actually in use, or, more simply, call set_epoch on both the Sampler and the BatchSampler.
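For illustration, the simpler variant could look like the sketch below (the helper name _call_set_epoch is hypothetical, not Lightning code; dataloader.sampler and dataloader.batch_sampler are the standard torch.utils.data.DataLoader attributes):

from torch.utils.data import DataLoader


def _call_set_epoch(dataloader: DataLoader, epoch: int) -> None:
    # Call set_epoch on both the sampler and the batch_sampler,
    # whichever of them defines it.
    for obj in (dataloader.sampler, dataloader.batch_sampler):
        set_epoch = getattr(obj, "set_epoch", None)
        if callable(set_epoch):
            set_epoch(epoch)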
Additional context
A workaround is to use a Callback such as the following:
import os  # only needed if the commented-out debug print below is enabled

import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.trainer.supporters import CombinedLoader  # import paths may differ between Lightning versions
from pytorch_lightning.utilities import apply_to_collection
from torch.utils.data import DataLoader


class SetBatchSamplerEpoch(Callback):
    """Set the epoch on the batch sampler before the dataloader iterator is created, at the start of every training epoch."""

    @staticmethod
    def set_batch_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
        batch_sampler = dataloader.batch_sampler
        if callable(getattr(batch_sampler, "set_epoch", None)):
            # print(f"[RANK {os.getenv('LOCAL_RANK', '0')}] Setting batch_sampler epoch to {epoch}")
            batch_sampler.set_epoch(epoch)

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.train_dataloader is None:
            return
        if isinstance(trainer.train_dataloader, CombinedLoader):
            # is the train_dataloader always wrapped in a CombinedLoader at this point?
            apply_to_collection(
                data=trainer.train_dataloader.loaders,
                dtype=DataLoader,
                function=self.set_batch_sampler_epoch,
                epoch=trainer.current_epoch,
            )
        elif isinstance(trainer.train_dataloader, DataLoader):
            # just in case the train_dataloader is not wrapped in a CombinedLoader
            self.set_batch_sampler_epoch(trainer.train_dataloader, trainer.current_epoch)
        else:
            raise TypeError(f"Unexpected type of trainer.train_dataloader: {type(trainer.train_dataloader)}")
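The callback can then be registered with the Trainer like any other callback, for example (model and train_loader are placeholders defined elsewhere):

import pytorch_lightning as pl

trainer = pl.Trainer(callbacks=[SetBatchSamplerEpoch()])
trainer.fit(model, train_loader)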
cc @justusschock @awaelchli @ninginthecloud @rohitgr7 @otaj