From 95c5562882e76a587e916b8fc458d35e4a257818 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Wed, 21 Apr 2021 13:28:40 +0200 Subject: [PATCH] code cleanup and add fixes to multilabel tests --- flash/vision/classification/data.py | 18 +++++--- tests/data/test_callbacks.py | 58 ++++++++++++++---------- tests/vision/classification/test_data.py | 10 ++-- 3 files changed, 52 insertions(+), 34 deletions(-) diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index ea0bf203ae..db912b4aa6 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -110,11 +110,13 @@ def load_data(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable return cls._load_data_files_labels(data=data, dataset=dataset) @staticmethod - def load_sample(sample) -> Union[Image.Image, Tuple[Image.Image, torch.Tensor]]: + def load_sample(sample) -> Union[Image.Image, torch.Tensor, Tuple[Image.Image, torch.Tensor]]: # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) if isinstance(sample, torch.Tensor): - return sample + out: torch.Tensor = sample + return out + path: str = "" if isinstance(sample, (tuple, list)): path = sample[0] sample = list(sample) @@ -122,14 +124,16 @@ def load_sample(sample) -> Union[Image.Image, Tuple[Image.Image, torch.Tensor]]: path = sample with open(path, "rb") as f, Image.open(f) as img: - img = img.convert("RGB") + img_out: Image.Image = img.convert("RGB") if isinstance(sample, list): - sample[0] = img - sample[1] = torch.tensor(sample[1]) - return sample + # return a tuple with the PIL image and tensor with the labels. + # returning the tensor helps later to easily collate the batch + # for single/multi label at the same time. + out: Tuple[Image.Image, torch.Tensor] = (img_out, torch.as_tensor(sample[1])) + return out - return img + return img_out @classmethod def predict_load_data(cls, samples: Any) -> Iterable: diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index 43988ff2d6..df1bba9939 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -84,13 +84,10 @@ def test_base_viz(tmpdir): seed_everything(42) tmpdir = Path(tmpdir) - (tmpdir / "a").mkdir() - (tmpdir / "b").mkdir() - _rand_image().save(tmpdir / "a" / "a_1.png") - _rand_image().save(tmpdir / "a" / "a_2.png") + train_images = [str(tmpdir / "a1.png"), str(tmpdir / "b1.png")] - _rand_image().save(tmpdir / "b" / "a_1.png") - _rand_image().save(tmpdir / "b" / "a_2.png") + _rand_image().save(train_images[0]) + _rand_image().save(train_images[1]) class CustomBaseVisualization(BaseVisualization): @@ -134,13 +131,13 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: return CustomBaseVisualization(*args, **kwargs) dm = CustomImageClassificationData.from_filepaths( - train_filepaths=[tmpdir / "a", tmpdir / "b"], + train_filepaths=train_images, train_labels=[0, 1], - val_filepaths=[tmpdir / "a", tmpdir / "b"], - val_labels=[0, 1], - test_filepaths=[tmpdir / "a", tmpdir / "b"], - test_labels=[0, 1], - predict_filepaths=[tmpdir / "a", tmpdir / "b"], + val_filepaths=train_images, + val_labels=[2, 3], + test_filepaths=train_images, + test_labels=[4, 5], + predict_filepaths=train_images, batch_size=2, num_workers=0, ) @@ -148,31 +145,46 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: for stage in _STAGES_PREFIX.values(): for _ in range(10): - getattr(dm, f"show_{stage}_batch")(reset=False) + fcn = getattr(dm, f"show_{stage}_batch") + fcn(reset=False) is_predict = stage == "predict" - def extract_data(data): + def _extract_data(data): if not is_predict: return data[0][0] return data[0] - assert isinstance(extract_data(dm.data_fetcher.batches[stage]["load_sample"]), Image.Image) + def _get_result(function_name: str): + return dm.data_fetcher.batches[stage][function_name] + + res = _get_result("load_sample") + assert isinstance(_extract_data(res), Image.Image) + if not is_predict: - assert isinstance(dm.data_fetcher.batches[stage]["load_sample"][0][1], int) + res = _get_result("load_sample") + assert isinstance(res[0][1], torch.Tensor) + + res = _get_result("to_tensor_transform") + assert isinstance(_extract_data(res), torch.Tensor) - assert isinstance(extract_data(dm.data_fetcher.batches[stage]["to_tensor_transform"]), torch.Tensor) if not is_predict: - assert isinstance(dm.data_fetcher.batches[stage]["to_tensor_transform"][0][1], int) + res = _get_result("to_tensor_transform") + assert isinstance(res[0][1], torch.Tensor) + + res = _get_result("collate") + assert _extract_data(res).shape == (2, 3, 196, 196) - assert extract_data(dm.data_fetcher.batches[stage]["collate"]).shape == torch.Size([2, 3, 196, 196]) if not is_predict: - assert dm.data_fetcher.batches[stage]["collate"][0][1].shape == torch.Size([2]) + res = _get_result("collate") + assert res[0][1].shape == torch.Size([2]) + + res = _get_result("per_batch_transform") + assert _extract_data(res).shape == (2, 3, 196, 196) - generated = extract_data(dm.data_fetcher.batches[stage]["per_batch_transform"]).shape - assert generated == torch.Size([2, 3, 196, 196]) if not is_predict: - assert dm.data_fetcher.batches[stage]["per_batch_transform"][0][1].shape == torch.Size([2]) + res = _get_result("per_batch_transform") + assert res[0][1].shape == (2, ) assert dm.data_fetcher.show_load_sample_called assert dm.data_fetcher.show_pre_tensor_transform_called diff --git a/tests/vision/classification/test_data.py b/tests/vision/classification/test_data.py index bf3bc40476..24d30bfd8a 100644 --- a/tests/vision/classification/test_data.py +++ b/tests/vision/classification/test_data.py @@ -215,14 +215,16 @@ def test_from_filepaths_multilabel(tmpdir): train_images = [str(tmpdir / "a1.png"), str(tmpdir / "a2.png")] train_labels = [[1, 0, 1, 0], [0, 0, 1, 1]] + valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]] + test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]] dm = ImageClassificationData.from_filepaths( train_filepaths=train_images, train_labels=train_labels, val_filepaths=train_images, - val_labels=train_labels, + val_labels=valid_labels, test_filepaths=train_images, - test_labels=train_labels, + test_labels=test_labels, batch_size=2, num_workers=0, ) @@ -236,10 +238,10 @@ def test_from_filepaths_multilabel(tmpdir): imgs, labels = data assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) - torch.testing.assert_allclose(labels, torch.tensor(train_labels)) + torch.testing.assert_allclose(labels, torch.tensor(valid_labels)) data = next(iter(dm.test_dataloader())) imgs, labels = data assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 4) - torch.testing.assert_allclose(labels, torch.tensor(train_labels)) + torch.testing.assert_allclose(labels, torch.tensor(test_labels))