Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
added .csv image loading utils (#116)
Browse files Browse the repository at this point in the history
* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils
  • Loading branch information
williamFalcon authored Feb 14, 2021
1 parent 43b1aa0 commit 592b580
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 9 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,5 @@ docs/api/
titanic.csv
.vscode
data_folder
data
*.pt
*.zip
1 change: 1 addition & 0 deletions docs/source/reference/image_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ Available backbones:
* densenet121
* densenet169
* densenet161
* swav-imagenet

------

Expand Down
1 change: 1 addition & 0 deletions flash/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from flash.data.data_utils import labels_from_categorical_csv
28 changes: 28 additions & 0 deletions flash/data/data_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Dict, List, Union

import pandas as pd


def labels_from_categorical_csv(csv: str, index_col: str, return_dict: bool = True) -> Union[Dict, List]:
"""
Returns a dictionary with {index_col: label} for each entry in the csv.
Expects a csv of this form:
index_col, b, c, d
some_name, 0 0 1
some_name_b, 1 0 0
"""
df = pd.read_csv(csv)
# get names
names = df[index_col].to_list()
del df[index_col]

# everything else is binary
labels = df.to_numpy().argmax(1).tolist()

if return_dict:
labels = {name: label for name, label in zip(names, labels)}

return labels
50 changes: 42 additions & 8 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pathlib
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union

import pandas as pd
import torch
from PIL import Image
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -55,9 +56,13 @@ def __init__(
self.labels = labels or []
self.transform = transform
self.loader = loader
if self.has_labels:
if not self.has_dict_labels and self.has_labels:
self.label_to_class_mapping = dict(map(reversed, enumerate(sorted(set(self.labels)))))

@property
def has_dict_labels(self) -> bool:
return isinstance(self.labels, dict)

@property
def has_labels(self) -> bool:
return self.labels is not None
Expand All @@ -71,7 +76,11 @@ def __getitem__(self, index: int) -> Tuple[Any, Optional[int]]:
if self.transform is not None:
img = self.transform(img)
label = None
if self.has_labels:
if self.has_dict_labels:
name = os.path.basename(filename)
label = self.labels[name]

elif self.has_labels:
label = self.labels[index]
label = self.label_to_class_mapping[label]
return img, label
Expand Down Expand Up @@ -244,13 +253,13 @@ class ImageClassificationData(DataModule):
@classmethod
def from_filepaths(
cls,
train_filepaths: Optional[Sequence[Union[str, pathlib.Path]]] = None,
train_filepaths: Union[str, Optional[Sequence[Union[str, pathlib.Path]]]] = None,
train_labels: Optional[Sequence] = None,
train_transform: Optional[Callable] = _default_train_transforms,
valid_filepaths: Optional[Sequence[Union[str, pathlib.Path]]] = None,
valid_filepaths: Union[str, Optional[Sequence[Union[str, pathlib.Path]]]] = None,
valid_labels: Optional[Sequence] = None,
valid_transform: Optional[Callable] = _default_valid_transforms,
test_filepaths: Optional[Sequence[Union[str, pathlib.Path]]] = None,
test_filepaths: Union[str, Optional[Sequence[Union[str, pathlib.Path]]]] = None,
test_labels: Optional[Sequence] = None,
loader: Callable = _pil_loader,
batch_size: int = 64,
Expand All @@ -260,13 +269,13 @@ def from_filepaths(
"""Creates a ImageClassificationData object from lists of image filepaths and labels
Args:
train_filepaths: sequence of file paths for training dataset. Defaults to ``None``.
train_filepaths: string or sequence of file paths for training dataset. Defaults to ``None``.
train_labels: sequence of labels for training dataset. Defaults to ``None``.
train_transform: transforms for training dataset. Defaults to ``None``.
valid_filepaths: sequence of file paths for validation dataset. Defaults to ``None``.
valid_filepaths: string or sequence of file paths for validation dataset. Defaults to ``None``.
valid_labels: sequence of labels for validation dataset. Defaults to ``None``.
valid_transform: transforms for validation and testing dataset. Defaults to ``None``.
test_filepaths: sequence of file paths for test dataset. Defaults to ``None``.
test_filepaths: string or sequence of file paths for test dataset. Defaults to ``None``.
test_labels: sequence of labels for test dataset. Defaults to ``None``.
loader: function to load an image file. Defaults to ``None``.
batch_size: the batchsize to use for parallel loading. Defaults to ``64``.
Expand All @@ -278,7 +287,32 @@ def from_filepaths(
Examples:
>>> img_data = ImageClassificationData.from_filepaths(["a.png", "b.png"], [0, 1]) # doctest: +SKIP
Example when labels are in .csv file::
train_labels = labels_from_categorical_csv('path/to/train.csv', 'my_id')
valid_labels = labels_from_categorical_csv(path/to/valid.csv', 'my_id')
test_labels = labels_from_categorical_csv(path/to/tests.csv', 'my_id')
data = ImageClassificationData.from_filepaths(
batch_size=2,
train_filepaths='path/to/train',
train_labels=train_labels,
valid_filepaths='path/to/valid',
valid_labels=valid_labels,
test_filepaths='path/to/test',
test_labels=test_labels,
)
"""
# enable passing in a string which loads all files in that folder as a list
if isinstance(train_filepaths, str):
train_filepaths = [os.path.join(train_filepaths, x) for x in os.listdir(train_filepaths)]
if isinstance(valid_filepaths, str):
valid_filepaths = [os.path.join(valid_filepaths, x) for x in os.listdir(valid_filepaths)]
if isinstance(test_filepaths, str):
test_filepaths = [os.path.join(test_filepaths, x) for x in os.listdir(test_filepaths)]

train_ds = FilepathDataset(
filepaths=train_filepaths,
labels=train_labels,
Expand Down
63 changes: 63 additions & 0 deletions tests/vision/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torchvision import transforms as T

from flash.data.data_utils import labels_from_categorical_csv
from flash.vision import ImageClassificationData


Expand Down Expand Up @@ -72,6 +74,67 @@ def test_from_filepaths(tmpdir):
assert labels.shape == (1, )


def test_categorical_csv_labels(tmpdir):
train_dir = Path(tmpdir / "some_dataset")
train_dir.mkdir()

(train_dir / "train").mkdir()
_rand_image().save(train_dir / "train" / "train_1.png")
_rand_image().save(train_dir / "train" / "train_2.png")

(train_dir / "valid").mkdir()
_rand_image().save(train_dir / "valid" / "valid_1.png")
_rand_image().save(train_dir / "valid" / "valid_2.png")

(train_dir / "test").mkdir()
_rand_image().save(train_dir / "test" / "test_1.png")
_rand_image().save(train_dir / "test" / "test_2.png")

train_csv = os.path.join(tmpdir, 'some_dataset', 'train.csv')
text_file = open(train_csv, 'w')
text_file.write(
'my_id, label_a, label_b, label_c\n"train_1.png", 0, 1, 0\n"train_2.png", 0, 0, 1\n"train_2.png", 1, 0, 0\n'
)
text_file.close()

valid_csv = os.path.join(tmpdir, 'some_dataset', 'valid.csv')
text_file = open(valid_csv, 'w')
text_file.write(
'my_id, label_a, label_b, label_c\n"valid_1.png", 0, 1, 0\n"valid_2.png", 0, 0, 1\n"valid_3.png", 1, 0, 0\n'
)
text_file.close()

test_csv = os.path.join(tmpdir, 'some_dataset', 'test.csv')
text_file = open(test_csv, 'w')
text_file.write(
'my_id, label_a, label_b, label_c\n"test_1.png", 0, 1, 0\n"test_2.png", 0, 0, 1\n"test_3.png", 1, 0, 0\n'
)
text_file.close()

train_labels = labels_from_categorical_csv(train_csv, 'my_id')
valid_labels = labels_from_categorical_csv(valid_csv, 'my_id')
test_labels = labels_from_categorical_csv(test_csv, 'my_id')

data = ImageClassificationData.from_filepaths(
batch_size=2,
train_filepaths=os.path.join(tmpdir, 'some_dataset', 'train'),
train_labels=train_labels,
valid_filepaths=os.path.join(tmpdir, 'some_dataset', 'valid'),
valid_labels=valid_labels,
test_filepaths=os.path.join(tmpdir, 'some_dataset', 'test'),
test_labels=test_labels,
)

for (x, y) in data.train_dataloader():
assert len(x) == 2

for (x, y) in data.val_dataloader():
assert len(x) == 2

for (x, y) in data.test_dataloader():
assert len(x) == 2


def test_from_folders(tmpdir):
train_dir = Path(tmpdir / "train")
train_dir.mkdir()
Expand Down

0 comments on commit 592b580

Please sign in to comment.