diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 4080909c14..d5bde69ef1 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -68,6 +68,8 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> Tuple[Any, Optional[int]]: filename = self.fnames[index] img = self.loader(filename) + if self.transform is not None: + img = self.transform(img) label = None if self.has_labels: label = self.labels[index]