
Error in "on_advance_start" when data-loader's sampler is a NumPy array #13320

LucaButera opened this issue Jun 17, 2022 · 5 comments

LucaButera commented Jun 17, 2022

🐛 Bug

When using a NumPy array as the sampler for a PyTorch data loader, the check

if (
    dataloader is not None
    and getattr(dataloader, "sampler", None)
    and callable(getattr(dataloader.sampler, "set_epoch", None))
):

in "on_advance_start" raises the following exception:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
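The failure is easy to reproduce in isolation: using a multi-element NumPy array in a boolean context (as the bare `and getattr(...)` clause does) forces a `bool()` coercion that NumPy refuses. A minimal sketch:

```python
import numpy as np

sampler = np.array([1, 2, 3, 4])

# The `and sampler` pattern calls bool(sampler), which NumPy rejects
# for arrays with more than one element:
try:
    bool(sampler)
except ValueError as e:
    print(e)  # The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

# An explicit identity check avoids the coercion entirely:
print(sampler is not None)  # True
```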

To Reproduce

import os
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
import numpy as np

class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run():
    # RandomDataset.__init__ takes (size, num_samples), so only two arguments
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2, sampler=np.array([1, 2, 3, 4]))
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)

run()

Expected behavior

The error is not raised.

Environment

  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 11.3
  • Packages:
    • numpy: 1.21.6
    • pyTorch_debug: False
    • pyTorch_version: 1.11.0+cu113
    • pytorch-lightning: 1.6.4
    • tqdm: 4.64.0
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.13
    • version: #1 SMP Sun Apr 24 10:03:06 PDT 2022

Additional context

An easy solution is to change the code that generates the error to

if (
    dataloader is not None
    and getattr(dataloader, "sampler", None) is not None
    and callable(getattr(dataloader.sampler, "set_epoch", None))
):

since the only thing to verify is that the sampler attribute exists and is not None.
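To illustrate the difference, here is a standalone sketch of both checks against a loader-like object whose sampler is a NumPy array. The `FakeLoader` class is hypothetical, introduced only to exercise the condition outside of Lightning:

```python
import numpy as np

class FakeLoader:
    """Hypothetical stand-in for a DataLoader carrying an array sampler."""
    def __init__(self, sampler):
        self.sampler = sampler

dataloader = FakeLoader(np.array([1, 2, 3, 4]))

# Original check: the bare `and getattr(...)` clause coerces the array
# to bool and raises ValueError.
try:
    if (
        dataloader is not None
        and getattr(dataloader, "sampler", None)
        and callable(getattr(dataloader.sampler, "set_epoch", None))
    ):
        pass
except ValueError as e:
    print("original check raised:", e)

# Fixed check: the explicit `is not None` never coerces the sampler,
# and the array simply fails the `set_epoch` callable test.
if (
    dataloader is not None
    and getattr(dataloader, "sampler", None) is not None
    and callable(getattr(dataloader.sampler, "set_epoch", None))
):
    print("would call set_epoch")
else:
    print("sampler has no set_epoch")  # reached: arrays lack set_epoch
```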

cc @Borda @justusschock @awaelchli @ninginthecloud @rohitgr7 @otaj

@LucaButera added the "needs triage" label Jun 17, 2022
@carmocca added the "bug" and "data handling" labels and removed the "needs triage" label Jun 21, 2022
@carmocca (Contributor)

Would you like to open a PR with your proposed fix?

@carmocca carmocca added this to the pl:1.6.x milestone Jun 21, 2022
@LucaButera (Author)

> Would you like to open a PR with your proposed fix?

Sure, I think I will have time by the end of next week.

@carmocca added the "good first issue" label Jun 21, 2022
BaruchG commented Jul 12, 2022

@LucaButera Is this something you are still working on? I'd be happy to take over if you didn't have the bandwidth.

@LucaButera (Author)

@BaruchG Thanks for asking; however, I think this was already fixed in 1.6.5, as I am no longer able to reproduce the bug.
@carmocca can you confirm this?
If that is the case I think the issue can be marked as solved, maybe linking the PR that solved it.

@carmocca (Contributor)

Thanks for the heads up! This was fixed by #13396
