From a38ea321c7c94bcae4c5703b40b3cfaee345b7ef Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 2 Dec 2021 13:24:57 +0100 Subject: [PATCH] update --- tests/image/classification/test_data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 3670e586bd..9328908c91 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -19,6 +19,7 @@ import pytest import torch import torch.nn as nn +from pytorch_lightning import seed_everything from flash.core.data.io.input import DataKeys from flash.core.data.transforms import ApplyToKeys, merge_transforms @@ -274,6 +275,7 @@ def test_from_folders_only_train(tmpdir): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_folders_train_val(tmpdir): + seed_everything(42) train_dir = Path(tmpdir / "train") train_dir.mkdir() @@ -297,6 +299,7 @@ def test_from_folders_train_val(tmpdir): imgs, labels = data["input"], data["target"] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2,) + assert list(labels.numpy()) == [0, 1] data = next(iter(img_data.val_dataloader())) imgs, labels = data["input"], data["target"]