Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

add crashing test in multi-label #233

Merged
merged 4 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,25 +110,30 @@ def load_data(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable
return cls._load_data_files_labels(data=data, dataset=dataset)

@staticmethod
def load_sample(sample) -> Union[Image.Image]:
def load_sample(sample) -> Union[Image.Image, torch.Tensor, Tuple[Image.Image, torch.Tensor]]:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
if isinstance(sample, torch.Tensor):
return sample
out: torch.Tensor = sample
return out

path: str = ""
if isinstance(sample, (tuple, list)):
path = sample[0]
sample = list(sample)
else:
path = sample

with open(path, "rb") as f, Image.open(f) as img:
img = img.convert("RGB")
img_out: Image.Image = img.convert("RGB")

if isinstance(sample, list):
sample[0] = img
return sample
# return a tuple with the PIL image and tensor with the labels.
# returning the tensor helps later to easily collate the batch
# for single/multi label at the same time.
out: Tuple[Image.Image, torch.Tensor] = (img_out, torch.as_tensor(sample[1]))
return out

return img
return img_out

@classmethod
def predict_load_data(cls, samples: Any) -> Iterable:
Expand Down
58 changes: 35 additions & 23 deletions tests/data/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,10 @@ def test_base_viz(tmpdir):
seed_everything(42)
tmpdir = Path(tmpdir)

(tmpdir / "a").mkdir()
(tmpdir / "b").mkdir()
_rand_image().save(tmpdir / "a" / "a_1.png")
_rand_image().save(tmpdir / "a" / "a_2.png")
train_images = [str(tmpdir / "a1.png"), str(tmpdir / "b1.png")]

_rand_image().save(tmpdir / "b" / "a_1.png")
_rand_image().save(tmpdir / "b" / "a_2.png")
_rand_image().save(train_images[0])
_rand_image().save(train_images[1])

class CustomBaseVisualization(BaseVisualization):

Expand Down Expand Up @@ -134,45 +131,60 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:
return CustomBaseVisualization(*args, **kwargs)

dm = CustomImageClassificationData.from_filepaths(
train_filepaths=[tmpdir / "a", tmpdir / "b"],
train_filepaths=train_images,
train_labels=[0, 1],
val_filepaths=[tmpdir / "a", tmpdir / "b"],
val_labels=[0, 1],
test_filepaths=[tmpdir / "a", tmpdir / "b"],
test_labels=[0, 1],
predict_filepaths=[tmpdir / "a", tmpdir / "b"],
val_filepaths=train_images,
val_labels=[2, 3],
test_filepaths=train_images,
test_labels=[4, 5],
predict_filepaths=train_images,
batch_size=2,
num_workers=0,
)

for stage in _STAGES_PREFIX.values():

for _ in range(10):
getattr(dm, f"show_{stage}_batch")(reset=False)
fcn = getattr(dm, f"show_{stage}_batch")
fcn(reset=False)

is_predict = stage == "predict"

def extract_data(data):
def _extract_data(data):
if not is_predict:
return data[0][0]
return data[0]

assert isinstance(extract_data(dm.data_fetcher.batches[stage]["load_sample"]), Image.Image)
def _get_result(function_name: str):
return dm.data_fetcher.batches[stage][function_name]

res = _get_result("load_sample")
assert isinstance(_extract_data(res), Image.Image)

if not is_predict:
assert isinstance(dm.data_fetcher.batches[stage]["load_sample"][0][1], int)
res = _get_result("load_sample")
assert isinstance(res[0][1], torch.Tensor)

res = _get_result("to_tensor_transform")
assert isinstance(_extract_data(res), torch.Tensor)

assert isinstance(extract_data(dm.data_fetcher.batches[stage]["to_tensor_transform"]), torch.Tensor)
if not is_predict:
assert isinstance(dm.data_fetcher.batches[stage]["to_tensor_transform"][0][1], int)
res = _get_result("to_tensor_transform")
assert isinstance(res[0][1], torch.Tensor)

res = _get_result("collate")
assert _extract_data(res).shape == (2, 3, 196, 196)

assert extract_data(dm.data_fetcher.batches[stage]["collate"]).shape == torch.Size([2, 3, 196, 196])
if not is_predict:
assert dm.data_fetcher.batches[stage]["collate"][0][1].shape == torch.Size([2])
res = _get_result("collate")
assert res[0][1].shape == torch.Size([2])

res = _get_result("per_batch_transform")
assert _extract_data(res).shape == (2, 3, 196, 196)

generated = extract_data(dm.data_fetcher.batches[stage]["per_batch_transform"]).shape
assert generated == torch.Size([2, 3, 196, 196])
if not is_predict:
assert dm.data_fetcher.batches[stage]["per_batch_transform"][0][1].shape == torch.Size([2])
res = _get_result("per_batch_transform")
assert res[0][1].shape == (2, )

assert dm.data_fetcher.show_load_sample_called
assert dm.data_fetcher.show_pre_tensor_transform_called
Expand Down
41 changes: 41 additions & 0 deletions tests/vision/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,44 @@ def test_from_folders(tmpdir):
imgs, labels = data
assert imgs.shape == (1, 3, 196, 196)
assert labels.shape == (1, )


def test_from_filepaths_multilabel(tmpdir):
tmpdir = Path(tmpdir)

(tmpdir / "a").mkdir()
_rand_image().save(tmpdir / "a1.png")
_rand_image().save(tmpdir / "a2.png")

train_images = [str(tmpdir / "a1.png"), str(tmpdir / "a2.png")]
train_labels = [[1, 0, 1, 0], [0, 0, 1, 1]]
valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]]
test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]]

dm = ImageClassificationData.from_filepaths(
train_filepaths=train_images,
train_labels=train_labels,
val_filepaths=train_images,
val_labels=valid_labels,
test_filepaths=train_images,
test_labels=test_labels,
batch_size=2,
num_workers=0,
)

data = next(iter(dm.train_dataloader()))
imgs, labels = data
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 4)

data = next(iter(dm.val_dataloader()))
imgs, labels = data
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 4)
torch.testing.assert_allclose(labels, torch.tensor(valid_labels))

data = next(iter(dm.test_dataloader()))
imgs, labels = data
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 4)
torch.testing.assert_allclose(labels, torch.tensor(test_labels))