Fix: Enforce dataloader params to have shuffle=True (#259)
### Description

- **What**: In the `CAREamics` `TrainDataModule`, the training dataloader does
not set `shuffle=True`.
- **Why**: Not shuffling the data during training can degrade results, e.g. by
making overfitting more likely.
- **How**: Allow users to explicitly pass `shuffle=False` (with a warning);
otherwise `{"shuffle": True}` is added to the parameter dictionary, provided
the dataset is not a subclass of `IterableDataset`.

### Changes Made

- **Modified**: `TrainDataModule.train_dataloader`

### Additional Notes and Examples

See the discussion in #258 for details.
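The shuffle-enforcement logic can be sketched as a standalone function. Note that `ensure_shuffle` is a hypothetical helper for illustration only; in the actual change, the check lives inline in `TrainDataModule.train_dataloader`:

```python
import warnings


def ensure_shuffle(dataloader_params: dict, dataset_is_iterable: bool) -> dict:
    """Sketch of the shuffle-enforcement check (hypothetical helper).

    Iterable datasets cannot be shuffled by the dataloader, so the
    ``shuffle`` key is only touched for map-style datasets.
    """
    params = dict(dataloader_params)  # avoid mutating the caller's dict
    if not dataset_is_iterable:
        if ("shuffle" in params) and (not params["shuffle"]):
            # Respect an explicit shuffle=False, but warn about it
            warnings.warn(
                "Dataloader parameters include `shuffle=False`; this will be "
                "passed to the training dataloader and may degrade training."
            )
        else:
            # Default to shuffling the training data
            params["shuffle"] = True
    return params
```

With no user-supplied value, `{"shuffle": True}` is added; an explicit `shuffle=False` is kept but triggers the warning, and iterable datasets are left untouched.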

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [x] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: Joran Deschamps <[email protected]>
melisande-c and jdeschamps authored Oct 29, 2024
1 parent 5fde99f commit a416c37
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion src/careamics/lightning/train_data_module.py
```diff
@@ -2,11 +2,12 @@

 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union
+from warnings import warn

 import numpy as np
 import pytorch_lightning as L
 from numpy.typing import NDArray
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, IterableDataset

 from careamics.config import DataConfig
 from careamics.config.support import SupportedData
@@ -446,6 +447,19 @@ def train_dataloader(self) -> Any:
         Any
             Training dataloader.
         """
+        # check because iterable dataset cannot be shuffled
+        if not isinstance(self.train_dataset, IterableDataset):
+            if ("shuffle" in self.dataloader_params) and (
+                not self.dataloader_params["shuffle"]
+            ):
+                warn(
+                    "Dataloader parameters include `shuffle=False`, this will be "
+                    "passed to the training dataloader and may result in bad results.",
+                    stacklevel=1,
+                )
+            else:
+                self.dataloader_params["shuffle"] = True
+
         return DataLoader(
             self.train_dataset, batch_size=self.batch_size, **self.dataloader_params
         )
```
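For context on why the check skips iterable datasets: PyTorch's `DataLoader` raises a `ValueError` when `shuffle=True` is combined with an `IterableDataset`, since iterable datasets have no indexable order to shuffle. A quick illustration (assumes `torch` is installed; `Stream` is a throwaway dataset defined here, not part of CAREamics):

```python
from torch.utils.data import DataLoader, IterableDataset


class Stream(IterableDataset):
    """Throwaway iterable dataset yielding a fixed range."""

    def __iter__(self):
        return iter(range(4))


# DataLoader rejects shuffle=True for iterable datasets
try:
    DataLoader(Stream(), shuffle=True)
except ValueError as err:
    print(f"rejected: {err}")
```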
