Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fixes wrong label in FilePathDataset #94

Merged
merged 3 commits into from
Feb 9, 2021
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
self.transform = transform
self.loader = loader
if self.has_labels:
self.label_to_class_mapping = {v: k for k, v in enumerate(list(sorted(list(set(self.fnames)))))}
self.label_to_class_mapping = {v: k for k, v in enumerate(list(sorted(list(set(self.labels)))))}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.label_to_class_mapping = {v: k for k, v in enumerate(list(sorted(list(set(self.labels)))))}
self.label_to_class_mapping = {v: k for k, v in enumerate(sorted(set(self.labels)))}

maybe we can just zip

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have updated the code. I think it now looks cleaner.


@property
def has_labels(self) -> bool:
Expand All @@ -70,7 +70,8 @@ def __getitem__(self, index: int) -> Tuple[Any, Optional[int]]:
img = self.loader(filename)
label = None
if self.has_labels:
label = self.label_to_class_mapping[filename]
label = self.labels[index]
label = self.label_to_class_mapping[label]
return img, label


Expand Down