With TorchRec version 0.8.0 or later, training cannot run for more than one epoch when persistent_workers=True is set in the dataloader. #2327

Open
tiankongdeguiji opened this issue Aug 21, 2024 · 2 comments

Comments

@tiankongdeguiji
Contributor

In the _next_batch method of TrainPipelineSparseDist, we check whether the dataloader_iter passed in is the same object as the one stored from the previous call, and we reset _dataloader_exhausted only if it is a different object. However, when persistent_workers=True is set in the dataloader, the dataloader_iter is the same instance in every epoch. As a result, _dataloader_exhausted stays True after the first epoch, and we cannot get any data from the second epoch onward.

https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/train_pipeline/train_pipelines.py#L578

    def _next_batch(self, dataloader_iter: Iterator[In]) -> Optional[In]:
        if dataloader_iter is not self._dataloader_iter:
            self._dataloader_iter = dataloader_iter
            self._dataloader_exhausted = False

        if self._dataloader_exhausted:
            batch = None
        else:
            with record_function("## next_batch ##"):
                batch = next(dataloader_iter, None)
            if batch is None:
                self._dataloader_exhausted = True
        return batch
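
The behavior can be reproduced without TorchRec or PyTorch. The following is a minimal sketch, not the library's code: ReusedIterator, FakeLoader, PipelineSketch, and run_epochs are hypothetical names, and FakeLoader only assumes what is described above, namely that the same iterator object is reset and handed out again each epoch. PipelineSketch keeps only the bookkeeping from the _next_batch snippet quoted above.

```python
from typing import List, Optional


class ReusedIterator:
    """Hypothetical stand-in for an iterator that is reset in place and reused."""

    def __init__(self, data: List[int]) -> None:
        self._data = data
        self._pos = 0

    def reset(self) -> None:
        self._pos = 0

    def __iter__(self) -> "ReusedIterator":
        return self

    def __next__(self) -> int:
        if self._pos >= len(self._data):
            raise StopIteration
        item = self._data[self._pos]
        self._pos += 1
        return item


class FakeLoader:
    """Hypothetical loader whose __iter__ returns the same object every epoch."""

    def __init__(self, data: List[int]) -> None:
        self._iterator = ReusedIterator(data)

    def __iter__(self) -> ReusedIterator:
        self._iterator.reset()
        return self._iterator


class PipelineSketch:
    """Only the exhaustion bookkeeping from the quoted _next_batch."""

    def __init__(self) -> None:
        self._dataloader_iter = None
        self._dataloader_exhausted = False

    def _next_batch(self, dataloader_iter) -> Optional[int]:
        # The flag is cleared only when a *different* iterator object arrives.
        if dataloader_iter is not self._dataloader_iter:
            self._dataloader_iter = dataloader_iter
            self._dataloader_exhausted = False

        if self._dataloader_exhausted:
            batch = None
        else:
            batch = next(dataloader_iter, None)
            if batch is None:
                self._dataloader_exhausted = True
        return batch


def run_epochs(num_epochs: int) -> List[List[int]]:
    loader = FakeLoader([1, 2, 3])
    pipeline = PipelineSketch()
    seen = []
    for _ in range(num_epochs):
        it = iter(loader)  # same object as in the previous epoch
        batches = []
        while (batch := pipeline._next_batch(it)) is not None:
            batches.append(batch)
        seen.append(batches)
    return seen


print(run_epochs(2))  # → [[1, 2, 3], []]
```

The second epoch comes back empty even though the iterator was reset, because the identity check never clears the exhausted flag.
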
@tiankongdeguiji
Contributor Author

Hi @henrylhtsang @IvanKobzarev @joshuadeng @PaulZhang12, could you take a look at this problem?

@gouchangjiang

Try setting num_workers=0 and see if that solves your problem. It worked in my case.
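
Setting num_workers=0 gives up the worker processes, and persistent_workers only has an effect when num_workers is greater than 0. If the workers are still needed, one possible alternative is to wrap the dataloader in a fresh generator each epoch, so that the object passed to the pipeline has a new identity. The sketch below is untested against TorchRec itself; fresh_iter and SameIterLoader are hypothetical names, and SameIterLoader only imitates a loader that returns the same iterator object every epoch.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def fresh_iter(loader: Iterable[T]) -> Iterator[T]:
    """Return a new generator object on every call, delegating to the loader."""
    yield from loader


class SameIterLoader:
    """Hypothetical loader that resets and returns itself as the iterator."""

    def __init__(self, data: List[int]) -> None:
        self._data = data
        self._pos = 0

    def __iter__(self) -> "SameIterLoader":
        self._pos = 0
        return self

    def __next__(self) -> int:
        if self._pos >= len(self._data):
            raise StopIteration
        item = self._data[self._pos]
        self._pos += 1
        return item


loader = SameIterLoader([1, 2, 3])

# Without the wrapper, every epoch sees the same iterator object.
same_identity = iter(loader) is iter(loader)

# With the wrapper, each epoch gets a distinct object that yields the same data.
first = fresh_iter(loader)
epoch1 = list(first)
second = fresh_iter(loader)
epoch2 = list(second)
distinct_identity = first is not second
```

In a training loop this would mean passing fresh_iter(dataloader) to the pipeline once per epoch instead of iter(dataloader), so that the identity check in _next_batch sees a new object and clears the exhausted flag.
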
