From 538400370f720fcc82d897f112585a3ba5704a8c Mon Sep 17 00:00:00 2001 From: ananyahjha93 Date: Thu, 30 Sep 2021 10:38:34 -0400 Subject: [PATCH] . --- tests/image/face_detection/test_model.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/tests/image/face_detection/test_model.py b/tests/image/face_detection/test_model.py index e048cd353b..29f7206ac5 100644 --- a/tests/image/face_detection/test_model.py +++ b/tests/image/face_detection/test_model.py @@ -13,7 +13,9 @@ # limitations under the License. import pytest +import torch import flash + from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FASTFACE_AVAILABLE from flash.image import FaceDetectionData, FaceDetector @@ -29,22 +31,31 @@ @pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.") -@pytest.mark.parametrize("model_name", ["lffd_slim", "lffd_original"]) -def test_fastface_training(tmpdir, model_name): +def test_fastface_training(): dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="val") - datamodule = FaceDetectionData.from_datasets(train_dataset=dataset, batch_size=1) + datamodule = FaceDetectionData.from_datasets(train_dataset=dataset, batch_size=2) - model = FaceDetector(model=model_name) + model = FaceDetector(model="lffd_slim") - trainer = flash.Trainer(max_steps=1, num_sanity_val_steps=0) + # test fit + trainer = flash.Trainer(max_steps=2, num_sanity_val_steps=0) trainer.finetune(model, datamodule=datamodule, strategy="freeze") +@pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.") +def test_fastface_forward(): + model = FaceDetector(model="lffd_slim") + mock_batch = torch.randn(2, 3, 256, 256) + + # test model forward (tests: _prepare_batch, logits_to_preds, _postprocess from ff) + model(mock_batch) + + @pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.") def test_fastface_backbones_registry(): backbones = FACE_DETECTION_BACKBONES.available_keys() assert "lffd_slim" in backbones assert "lffd_original" in backbones - backbone, _ = FACE_DETECTION_BACKBONES.get("lffd_original")(pretrained=True) + backbone, _ = FACE_DETECTION_BACKBONES.get("lffd_original")(pretrained=False) assert isinstance(backbone, LFFD)