Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
code cleanup and add fixes to multilabel tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
edgarriba committed Apr 21, 2021
1 parent d529b78 commit 95c5562
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 34 deletions.
18 changes: 11 additions & 7 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,26 +110,30 @@ def load_data(cls, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable
return cls._load_data_files_labels(data=data, dataset=dataset)

@staticmethod
def load_sample(sample) -> Union[Image.Image, Tuple[Image.Image, torch.Tensor]]:
def load_sample(sample) -> Union[Image.Image, torch.Tensor, Tuple[Image.Image, torch.Tensor]]:
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
if isinstance(sample, torch.Tensor):
return sample
out: torch.Tensor = sample
return out

path: str = ""
if isinstance(sample, (tuple, list)):
path = sample[0]
sample = list(sample)
else:
path = sample

with open(path, "rb") as f, Image.open(f) as img:
img = img.convert("RGB")
img_out: Image.Image = img.convert("RGB")

if isinstance(sample, list):
sample[0] = img
sample[1] = torch.tensor(sample[1])
return sample
# return a tuple with the PIL image and tensor with the labels.
# returning the tensor helps later to easily collate the batch
# for single/multi label at the same time.
out: Tuple[Image.Image, torch.Tensor] = (img_out, torch.as_tensor(sample[1]))
return out

return img
return img_out

@classmethod
def predict_load_data(cls, samples: Any) -> Iterable:
Expand Down
58 changes: 35 additions & 23 deletions tests/data/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,10 @@ def test_base_viz(tmpdir):
seed_everything(42)
tmpdir = Path(tmpdir)

(tmpdir / "a").mkdir()
(tmpdir / "b").mkdir()
_rand_image().save(tmpdir / "a" / "a_1.png")
_rand_image().save(tmpdir / "a" / "a_2.png")
train_images = [str(tmpdir / "a1.png"), str(tmpdir / "b1.png")]

_rand_image().save(tmpdir / "b" / "a_1.png")
_rand_image().save(tmpdir / "b" / "a_2.png")
_rand_image().save(train_images[0])
_rand_image().save(train_images[1])

class CustomBaseVisualization(BaseVisualization):

Expand Down Expand Up @@ -134,45 +131,60 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:
return CustomBaseVisualization(*args, **kwargs)

dm = CustomImageClassificationData.from_filepaths(
train_filepaths=[tmpdir / "a", tmpdir / "b"],
train_filepaths=train_images,
train_labels=[0, 1],
val_filepaths=[tmpdir / "a", tmpdir / "b"],
val_labels=[0, 1],
test_filepaths=[tmpdir / "a", tmpdir / "b"],
test_labels=[0, 1],
predict_filepaths=[tmpdir / "a", tmpdir / "b"],
val_filepaths=train_images,
val_labels=[2, 3],
test_filepaths=train_images,
test_labels=[4, 5],
predict_filepaths=train_images,
batch_size=2,
num_workers=0,
)

for stage in _STAGES_PREFIX.values():

for _ in range(10):
getattr(dm, f"show_{stage}_batch")(reset=False)
fcn = getattr(dm, f"show_{stage}_batch")
fcn(reset=False)

is_predict = stage == "predict"

def extract_data(data):
def _extract_data(data):
if not is_predict:
return data[0][0]
return data[0]

assert isinstance(extract_data(dm.data_fetcher.batches[stage]["load_sample"]), Image.Image)
def _get_result(function_name: str):
return dm.data_fetcher.batches[stage][function_name]

res = _get_result("load_sample")
assert isinstance(_extract_data(res), Image.Image)

if not is_predict:
assert isinstance(dm.data_fetcher.batches[stage]["load_sample"][0][1], int)
res = _get_result("load_sample")
assert isinstance(res[0][1], torch.Tensor)

res = _get_result("to_tensor_transform")
assert isinstance(_extract_data(res), torch.Tensor)

assert isinstance(extract_data(dm.data_fetcher.batches[stage]["to_tensor_transform"]), torch.Tensor)
if not is_predict:
assert isinstance(dm.data_fetcher.batches[stage]["to_tensor_transform"][0][1], int)
res = _get_result("to_tensor_transform")
assert isinstance(res[0][1], torch.Tensor)

res = _get_result("collate")
assert _extract_data(res).shape == (2, 3, 196, 196)

assert extract_data(dm.data_fetcher.batches[stage]["collate"]).shape == torch.Size([2, 3, 196, 196])
if not is_predict:
assert dm.data_fetcher.batches[stage]["collate"][0][1].shape == torch.Size([2])
res = _get_result("collate")
assert res[0][1].shape == torch.Size([2])

res = _get_result("per_batch_transform")
assert _extract_data(res).shape == (2, 3, 196, 196)

generated = extract_data(dm.data_fetcher.batches[stage]["per_batch_transform"]).shape
assert generated == torch.Size([2, 3, 196, 196])
if not is_predict:
assert dm.data_fetcher.batches[stage]["per_batch_transform"][0][1].shape == torch.Size([2])
res = _get_result("per_batch_transform")
assert res[0][1].shape == (2, )

assert dm.data_fetcher.show_load_sample_called
assert dm.data_fetcher.show_pre_tensor_transform_called
Expand Down
10 changes: 6 additions & 4 deletions tests/vision/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,16 @@ def test_from_filepaths_multilabel(tmpdir):

train_images = [str(tmpdir / "a1.png"), str(tmpdir / "a2.png")]
train_labels = [[1, 0, 1, 0], [0, 0, 1, 1]]
valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]]
test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]]

dm = ImageClassificationData.from_filepaths(
train_filepaths=train_images,
train_labels=train_labels,
val_filepaths=train_images,
val_labels=train_labels,
val_labels=valid_labels,
test_filepaths=train_images,
test_labels=train_labels,
test_labels=test_labels,
batch_size=2,
num_workers=0,
)
Expand All @@ -236,10 +238,10 @@ def test_from_filepaths_multilabel(tmpdir):
imgs, labels = data
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 4)
torch.testing.assert_allclose(labels, torch.tensor(train_labels))
torch.testing.assert_allclose(labels, torch.tensor(valid_labels))

data = next(iter(dm.test_dataloader()))
imgs, labels = data
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 4)
torch.testing.assert_allclose(labels, torch.tensor(train_labels))
torch.testing.assert_allclose(labels, torch.tensor(test_labels))

0 comments on commit 95c5562

Please sign in to comment.