Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for IterableDatasets everywhere #1104

Merged

Conversation

ethanwharris
Copy link
Member

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you write any new necessary tests?

What does this PR do?

Fixes #948

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@ethanwharris ethanwharris requested a review from a team March 9, 2020 16:14
pytorch_lightning/trainer/data_loading.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/data_loading.py Outdated Show resolved Hide resolved
tests/models/mixins.py Show resolved Hide resolved
@@ -213,6 +213,48 @@ def test_dataloader(self):
return self._dataloader(train=False)


class CustomInfDataloader:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather create complet Dataloader so it is easier to undestand... what about?

class CustomInfDataloader:

    def __init__(self, dataset, batch_size, shuffle):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        idxs = []
        while True:
            if len(idxs) < self.batch_size:
                idxs = range(len(self.dataset))
                if self.shuffle:
                    np.random.shuffle(idxs)
            batch = [self.dataset[idx] for idx in idxs[:self.batch_size]]
            yield batch
            idxs = idxs[len(batch):]

Copy link
Member Author

@ethanwharris ethanwharris Mar 10, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.DataLoader does quite a bit more than this (e.g. collate functions, samplers, etc.) so it is probably better to wrap it rather than rewrite it - also we don't really have access to the dataset when this is created, only the dataloader

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Borda we generally want to avoid duplicating torch functionality. Otherwise the project scope will blow up quickly,

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do agree, I just found this construction quite difficult to follow...

pytorch_lightning/trainer/data_loading.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/data_loading.py Outdated Show resolved Hide resolved
pytorch_lightning/trainer/data_loading.py Outdated Show resolved Hide resolved
@Borda Borda added this to the 0.7.2 milestone Mar 9, 2020
@Borda Borda added the feature Is an improvement or enhancement label Mar 9, 2020
@Borda
Copy link
Member

Borda commented Mar 11, 2020

hey there, we have added GPU CI test, so could we kindly ask to rebase/merge master which will trigger these tests so we do not need to test it manually... Thx for your understanding 🤖

@ethanwharris
Copy link
Member Author

@Borda Done :)

@ethanwharris ethanwharris added the ready PRs ready to be merged label Mar 12, 2020
@williamFalcon williamFalcon merged commit 2b3f443 into Lightning-AI:master Mar 12, 2020
@ethanwharris ethanwharris deleted the feature/iterable_datasets branch March 12, 2020 17:40
tullie pushed a commit to tullie/pytorch-lightning that referenced this pull request Apr 3, 2020
* Add support for IterableDatasets everywhere

* Added type hints, simplified code and improved coverage in data_loading.py

* Update CHANGELOG.md
@Borda Borda modified the milestones: v0.7., v0.7.x Apr 18, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
feature Is an improvement or enhancement ready PRs ready to be merged
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support IterableDatasets for validation and test, not just train set [blocked by #953]
3 participants