set_epoch not called for BatchSampler #13316

Closed
samgelman opened this issue Jun 16, 2022 · 0 comments · Fixed by #13396
Labels: bug (Something isn't working) · data handling (Generic data-related topic)
Milestone: pl:1.6.x

samgelman commented Jun 16, 2022

🐛 Bug

Lightning takes care of calling set_epoch for a custom Sampler attached to a DataLoader. However, a DataLoader might use a custom BatchSampler instead of a Sampler, and Lightning does not call set_epoch for custom BatchSamplers. Calling set_epoch is important for proper seeding in distributed environments, as the sketch below illustrates.
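
For context, a minimal sketch of the kind of batch sampler affected (the `ShuffledBatchSampler` name and implementation are hypothetical, not from the issue): its shuffle order is seeded by the epoch, so if set_epoch is never called, every epoch reproduces the same batch order on every rank.

```python
import torch
from torch.utils.data import DataLoader, Sampler


class ShuffledBatchSampler(Sampler):
    """Hypothetical batch sampler whose shuffle order depends on the epoch."""

    def __init__(self, dataset_size: int, batch_size: int) -> None:
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Lightning calls this hook on Samplers, but not on BatchSamplers.
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)  # epoch-dependent, rank-consistent seed
        indices = torch.randperm(self.dataset_size, generator=g).tolist()
        for i in range(0, self.dataset_size, self.batch_size):
            yield indices[i : i + self.batch_size]

    def __len__(self) -> int:
        return (self.dataset_size + self.batch_size - 1) // self.batch_size


dataset = list(range(100))  # any map-style dataset works here
loader = DataLoader(dataset, batch_sampler=ShuffledBatchSampler(len(dataset), batch_size=32))
```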

Expected behavior

Lightning should call set_epoch for BatchSamplers to match its behavior for Samplers.
Lightning could use the DataLoader's index_sampler property to retrieve the Sampler or BatchSampler that is actually in use, or, more simply, call set_epoch on both the Sampler and the BatchSampler, as sketched below.
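
A minimal sketch of what that could look like (hypothetical helper; the actual fix landed in #13396):

```python
from torch.utils.data import DataLoader


def _set_sampler_epochs(dataloader: DataLoader, epoch: int) -> None:
    # hypothetical helper: treat the Sampler and the BatchSampler uniformly
    for attr in ("sampler", "batch_sampler"):
        candidate = getattr(dataloader, attr, None)
        if candidate is not None and callable(getattr(candidate, "set_epoch", None)):
            candidate.set_epoch(epoch)
```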

Additional context

A workaround is to use a Callback such as the following:

import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import apply_to_collection
from torch.utils.data import DataLoader


class SetBatchSamplerEpoch(Callback):
    """Sets the epoch on the batch sampler before the dataloader iterator is
    initialized at the start of every training epoch."""

    @staticmethod
    def set_batch_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
        if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
            dataloader.batch_sampler.set_epoch(epoch)

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.train_dataloader is None:
            return

        if isinstance(trainer.train_dataloader, CombinedLoader):
            # is the train_dataloader always wrapped in a CombinedLoader at this point?
            apply_to_collection(
                data=trainer.train_dataloader.loaders,
                dtype=DataLoader,
                function=self.set_batch_sampler_epoch,
                epoch=trainer.current_epoch,
            )
        elif isinstance(trainer.train_dataloader, DataLoader):
            # just in case the train_dataloader is not wrapped in a CombinedLoader
            self.set_batch_sampler_epoch(trainer.train_dataloader, trainer.current_epoch)
        else:
            raise TypeError(f"Unexpected type of trainer.train_dataloader: {type(trainer.train_dataloader)}")
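
The callback is then registered like any other:

```python
trainer = pl.Trainer(callbacks=[SetBatchSamplerEpoch()])
```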

cc @justusschock @awaelchli @ninginthecloud @rohitgr7 @otaj

@samgelman samgelman added the needs triage Waiting to be triaged by maintainers label Jun 16, 2022
@carmocca carmocca added bug Something isn't working data handling Generic data-related topic and removed needs triage Waiting to be triaged by maintainers labels Jun 21, 2022
@carmocca carmocca added this to the pl:1.6.x milestone Jun 21, 2022