Merge f435f8f into 592b580
williamFalcon authored Feb 14, 2021
2 parents 592b580 + f435f8f commit 41b4a5d
Showing 3 changed files with 62 additions and 19 deletions.
19 changes: 15 additions & 4 deletions flash/data/data_utils.py
@@ -1,9 +1,15 @@
-from typing import Dict, List, Union
+from typing import Any, Dict, List, Union
 
 import pandas as pd
 
 
-def labels_from_categorical_csv(csv: str, index_col: str, return_dict: bool = True) -> Union[Dict, List]:
+def labels_from_categorical_csv(
+    csv: str,
+    index_col: str,
+    feature_cols: List,
+    return_dict: bool = True,
+    index_col_collate_fn: Any = None
+) -> Union[Dict, List]:
     """
     Returns a dictionary with {index_col: label} for each entry in the csv.
@@ -17,10 +23,15 @@ def labels_from_categorical_csv(
     df = pd.read_csv(csv)
     # get names
     names = df[index_col].to_list()
-    del df[index_col]
 
+    # apply colate fn to index_col
+    if index_col_collate_fn:
+        for i in range(len(names)):
+            names[i] = index_col_collate_fn(names[i])
+
     # everything else is binary
-    labels = df.to_numpy().argmax(1).tolist()
+    feature_df = df[feature_cols]
+    labels = feature_df.to_numpy().argmax(1).tolist()
 
     if return_dict:
         labels = {name: label for name, label in zip(names, labels)}
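For context, a minimal usage sketch of the updated helper; the CSV path and the lambda below are illustrative (not from this commit), while the signature, the argmax over feature_cols, and the {index: label} return shape follow the diff above:

```python
import os

from flash.data.data_utils import labels_from_categorical_csv

# Hypothetical CSV with header: my_id,label_a,label_b,label_c
labels = labels_from_categorical_csv(
    csv="train.csv",
    index_col="my_id",
    feature_cols=["label_a", "label_b", "label_c"],
    index_col_collate_fn=lambda name: os.path.splitext(name)[0],  # "train_1.png" -> "train_1"
)
# With return_dict=True (the default), each collated id maps to the argmax over
# the one-hot feature columns, e.g. {"train_1": 1, "train_2": 2}
```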
34 changes: 25 additions & 9 deletions flash/vision/classification/data.py
@@ -77,7 +77,8 @@ def __getitem__(self, index: int) -> Tuple[Any, Optional[int]]:
             img = self.transform(img)
         label = None
         if self.has_dict_labels:
-            name = os.path.basename(filename)
+            name = os.path.splitext(filename)[0]
+            name = os.path.basename(name)
             label = self.labels[name]
 
         elif self.has_labels:
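The two added lines change the key used for dict-label lookup: the extension is stripped before the directory prefix, so labels are keyed by bare ids. A small sketch with a hypothetical path:

```python
import os

filename = "/data/some_dataset/train/train_1.png"  # hypothetical path
name = os.path.splitext(filename)[0]  # "/data/some_dataset/train/train_1"
name = os.path.basename(name)         # "train_1"
```

This is what the index_col_collate_fn in the updated test below lines up with: ids read from the CSV ("train_1.png") and file paths on disk both resolve to the same key.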
@@ -256,6 +257,7 @@ def from_filepaths(
         train_filepaths: Union[str, Optional[Sequence[Union[str, pathlib.Path]]]] = None,
         train_labels: Optional[Sequence] = None,
         train_transform: Optional[Callable] = _default_train_transforms,
+        valid_split: Union[None, float] = None,
         valid_filepaths: Union[str, Optional[Sequence[Union[str, pathlib.Path]]]] = None,
         valid_labels: Optional[Sequence] = None,
         valid_transform: Optional[Callable] = _default_valid_transforms,
@@ -264,6 +266,7 @@ def from_filepaths(
         loader: Callable = _pil_loader,
         batch_size: int = 64,
         num_workers: Optional[int] = None,
+        seed: int = 1234,
         **kwargs
     ):
         """Creates a ImageClassificationData object from lists of image filepaths and labels
@@ -272,6 +275,7 @@ def from_filepaths(
             train_filepaths: string or sequence of file paths for training dataset. Defaults to ``None``.
             train_labels: sequence of labels for training dataset. Defaults to ``None``.
             train_transform: transforms for training dataset. Defaults to ``None``.
+            valid_split: if not None, generates val split from train dataloader using this value.
             valid_filepaths: string or sequence of file paths for validation dataset. Defaults to ``None``.
             valid_labels: sequence of labels for validation dataset. Defaults to ``None``.
             valid_transform: transforms for validation and testing dataset. Defaults to ``None``.
@@ -281,6 +285,7 @@ def from_filepaths(
             batch_size: the batchsize to use for parallel loading. Defaults to ``64``.
             num_workers: The number of workers to use for parallelized loading.
                 Defaults to ``None`` which equals the number of available CPU threads.
+            seed: Used for the train/val splits when valid_split is not None
 
         Returns:
             ImageClassificationData: The constructed data module.
@@ -319,14 +324,25 @@ def from_filepaths(
             loader=loader,
             transform=train_transform,
         )
-        valid_ds = (
-            FilepathDataset(
-                filepaths=valid_filepaths,
-                labels=valid_labels,
-                loader=loader,
-                transform=valid_transform,
-            ) if valid_filepaths is not None else None
-        )
+
+        if valid_split:
+            full_length = len(train_ds)
+            train_split = int((1.0 - valid_split) * full_length)
+            valid_split = full_length - train_split
+            train_ds, valid_ds = torch.utils.data.random_split(
+                train_ds,
+                [train_split, valid_split],
+                generator=torch.Generator().manual_seed(seed)
+            )
+        else:
+            valid_ds = (
+                FilepathDataset(
+                    filepaths=valid_filepaths,
+                    labels=valid_labels,
+                    loader=loader,
+                    transform=valid_transform,
+                ) if valid_filepaths is not None else None
+            )
 
         test_ds = (
             FilepathDataset(
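The valid_split branch computes two integer lengths and passes them to torch.utils.data.random_split with a seeded generator, so the split is reproducible for a fixed seed. A standalone sketch of the same arithmetic (the 10-item dataset is a stand-in for the FilepathDataset built above):

```python
import torch
from torch.utils.data import TensorDataset, random_split

full_ds = TensorDataset(torch.arange(10))  # stand-in dataset with 10 items

valid_split, seed = 0.25, 1234
full_length = len(full_ds)
train_length = int((1.0 - valid_split) * full_length)  # 7
valid_length = full_length - train_length               # 3
train_ds, valid_ds = random_split(
    full_ds,
    [train_length, valid_length],
    generator=torch.Generator().manual_seed(seed),  # same seed -> same split
)
assert len(train_ds) == 7 and len(valid_ds) == 3
```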
28 changes: 22 additions & 6 deletions tests/vision/classification/test_data.py
@@ -93,27 +93,33 @@ def test_categorical_csv_labels(tmpdir):
     train_csv = os.path.join(tmpdir, 'some_dataset', 'train.csv')
     text_file = open(train_csv, 'w')
     text_file.write(
-        'my_id, label_a, label_b, label_c\n"train_1.png", 0, 1, 0\n"train_2.png", 0, 0, 1\n"train_2.png", 1, 0, 0\n'
+        'my_id,label_a,label_b,label_c\n"train_1.png", 0, 1, 0\n"train_2.png", 0, 0, 1\n"train_2.png", 1, 0, 0\n'
     )
     text_file.close()
 
     valid_csv = os.path.join(tmpdir, 'some_dataset', 'valid.csv')
     text_file = open(valid_csv, 'w')
     text_file.write(
-        'my_id, label_a, label_b, label_c\n"valid_1.png", 0, 1, 0\n"valid_2.png", 0, 0, 1\n"valid_3.png", 1, 0, 0\n'
+        'my_id,label_a,label_b,label_c\n"valid_1.png", 0, 1, 0\n"valid_2.png", 0, 0, 1\n"valid_3.png", 1, 0, 0\n'
     )
     text_file.close()
 
     test_csv = os.path.join(tmpdir, 'some_dataset', 'test.csv')
     text_file = open(test_csv, 'w')
     text_file.write(
-        'my_id, label_a, label_b, label_c\n"test_1.png", 0, 1, 0\n"test_2.png", 0, 0, 1\n"test_3.png", 1, 0, 0\n'
+        'my_id,label_a,label_b,label_c\n"test_1.png", 0, 1, 0\n"test_2.png", 0, 0, 1\n"test_3.png", 1, 0, 0\n'
     )
     text_file.close()
 
-    train_labels = labels_from_categorical_csv(train_csv, 'my_id')
-    valid_labels = labels_from_categorical_csv(valid_csv, 'my_id')
-    test_labels = labels_from_categorical_csv(test_csv, 'my_id')
+    def index_col_collate_fn(x):
+        return os.path.splitext(x)[0]
+
+    train_labels = labels_from_categorical_csv(
+        train_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn)
+    valid_labels = labels_from_categorical_csv(
+        valid_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn)
+    test_labels = labels_from_categorical_csv(
+        test_csv, 'my_id', feature_cols=['label_a', 'label_b', 'label_c'], index_col_collate_fn=index_col_collate_fn)
 
     data = ImageClassificationData.from_filepaths(
         batch_size=2,
@@ -134,6 +140,16 @@ def test_categorical_csv_labels(tmpdir):
     for (x, y) in data.test_dataloader():
         assert len(x) == 2
 
+    data = ImageClassificationData.from_filepaths(
+        batch_size=2,
+        train_filepaths=os.path.join(tmpdir, 'some_dataset', 'train'),
+        train_labels=train_labels,
+        valid_split=0.5
+    )
+
+    for (x, y) in data.val_dataloader():
+        assert len(x) == 1
+
 
 def test_from_folders(tmpdir):
     train_dir = Path(tmpdir / "train")
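One easy-to-miss fix in the CSV fixtures above is dropping the space after each comma in the header row. pandas.read_csv keeps leading whitespace in column names by default, so a header written as 'my_id, label_a, ...' would yield a column named ' label_a' that never matches the feature_cols list passed to labels_from_categorical_csv. A quick in-memory illustration (not part of the repo):

```python
import io

import pandas as pd

with_spaces = pd.read_csv(io.StringIO('my_id, label_a, label_b\n"x", 0, 1\n'))
no_spaces = pd.read_csv(io.StringIO('my_id,label_a,label_b\n"x", 0, 1\n'))

print(list(with_spaces.columns))  # ['my_id', ' label_a', ' label_b'] -- leading spaces kept
print(list(no_spaces.columns))    # ['my_id', 'label_a', 'label_b']
assert 'label_a' in no_spaces.columns and 'label_a' not in with_spaces.columns
```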
