diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 168ad3c521..4080909c14 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -56,7 +56,7 @@ def __init__( self.transform = transform self.loader = loader if self.has_labels: - self.label_to_class_mapping = {v: k for k, v in enumerate(list(sorted(list(set(self.fnames)))))} + self.label_to_class_mapping = dict(map(reversed, enumerate(sorted(set(self.labels))))) @property def has_labels(self) -> bool: @@ -70,7 +70,8 @@ def __getitem__(self, index: int) -> Tuple[Any, Optional[int]]: img = self.loader(filename) label = None if self.has_labels: - label = self.label_to_class_mapping[filename] + label = self.labels[index] + label = self.label_to_class_mapping[label] return img, label