Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Dec 2, 2021
1 parent fa8688c commit a38ea32
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions tests/image/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest
import torch
import torch.nn as nn
from pytorch_lightning import seed_everything

from flash.core.data.io.input import DataKeys
from flash.core.data.transforms import ApplyToKeys, merge_transforms
Expand Down Expand Up @@ -274,6 +275,7 @@ def test_from_folders_only_train(tmpdir):

@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_folders_train_val(tmpdir):
seed_everything(42)

train_dir = Path(tmpdir / "train")
train_dir.mkdir()
Expand All @@ -297,6 +299,7 @@ def test_from_folders_train_val(tmpdir):
imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2,)
assert list(labels.numpy()) == [0, 1]

data = next(iter(img_data.val_dataloader()))
imgs, labels = data["input"], data["target"]
Expand Down

0 comments on commit a38ea32

Please sign in to comment.