diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3845a894fc..e0982d3946 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added support for `from_data_frame` to `TextClassificationData` ([#785](https://github.com/PyTorchLightning/lightning-flash/pull/785))

+- Added `FastFace` integration ([#606](https://github.com/PyTorchLightning/lightning-flash/pull/606))
+
 - Added support for `from_lists` to `TextClassificationData` ([#805](https://github.com/PyTorchLightning/lightning-flash/pull/805))

 ### Changed
diff --git a/flash/core/data/utils.py b/flash/core/data/utils.py
index 3779b7426e..c1bc18f698 100644
--- a/flash/core/data/utils.py
+++ b/flash/core/data/utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import os.path
+import tarfile
 import zipfile
 from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Set, Type
@@ -148,10 +149,23 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
         ):
             fp.write(chunk)  # type: ignore

+    def extract_tarfile(file_path: str, extract_path: str, mode: str):
+        if os.path.exists(file_path):
+            with tarfile.open(file_path, mode=mode) as tar_ref:
+                for member in tar_ref.getmembers():
+                    try:
+                        tar_ref.extract(member, path=extract_path, set_attrs=False)
+                    except PermissionError:
+                        raise PermissionError(f"Could not extract tar file {file_path}")
+
     if ".zip" in local_filename:
         if os.path.exists(local_filename):
             with zipfile.ZipFile(local_filename, "r") as zip_ref:
                 zip_ref.extractall(path)
+    elif local_filename.endswith(".tar.gz") or local_filename.endswith(".tgz"):
+        extract_tarfile(local_filename, path, "r:gz")
+    elif local_filename.endswith(".tar.bz2") or local_filename.endswith(".tbz"):
+        extract_tarfile(local_filename, path, "r:bz2")


 def _contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool:
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index f138eaf37e..95a6272072 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -87,6 +87,7 @@ def _compare_version(package: str, op, version) -> bool:
 _PIL_AVAILABLE = _module_available("PIL")
 _OPEN3D_AVAILABLE = _module_available("open3d")
 _SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
+_FASTFACE_AVAILABLE = _module_available("fastface")
 _LIBROSA_AVAILABLE = _module_available("librosa")
 _TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
 _TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
diff --git a/flash/image/__init__.py b/flash/image/__init__.py
index b3ac7f10b6..788a15ca40 100644
--- a/flash/image/__init__.py
+++ b/flash/image/__init__.py
@@ -6,6 +6,7 @@
 from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES  # noqa: F401
 from flash.image.detection import ObjectDetectionData, ObjectDetector  # noqa: F401
 from flash.image.embedding import ImageEmbedder  # noqa: F401
+from flash.image.face_detection import FaceDetectionData, FaceDetector  # noqa: F401
 from flash.image.instance_segmentation import InstanceSegmentation, InstanceSegmentationData  # noqa: F401
 from flash.image.keypoint_detection import KeypointDetectionData, KeypointDetector  # noqa: F401
 from flash.image.segmentation import (  # noqa: F401
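For context on the `download_data` change above, a minimal usage sketch; the tarball URL mirrors the one exercised by the updated test near the end of this diff, and the extraction mode is picked from the file suffix:

    # Sketch of the extended download_data: .zip archives still go through
    # zipfile, while .tar.gz/.tgz and .tar.bz2/.tbz now go through the new
    # extract_tarfile helper with the matching tarfile mode.
    from flash.core.data.utils import download_data

    # Downloads the archive into data/ and extracts it alongside the download.
    download_data("https://pl-flash-data.s3.amazonaws.com/titanic.tar.gz", "data/")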
diff --git a/flash/image/face_detection/__init__.py b/flash/image/face_detection/__init__.py
new file mode 100644
index 0000000000..b8a100f085
--- /dev/null
+++ b/flash/image/face_detection/__init__.py
@@ -0,0 +1,2 @@
+from flash.image.face_detection.data import FaceDetectionData  # noqa: F401
+from flash.image.face_detection.model import FaceDetector  # noqa: F401
diff --git a/flash/image/face_detection/backbones/__init__.py b/flash/image/face_detection/backbones/__init__.py
new file mode 100644
index 0000000000..fdf1185754
--- /dev/null
+++ b/flash/image/face_detection/backbones/__init__.py
@@ -0,0 +1,5 @@
+from flash.core.registry import FlashRegistry  # noqa: F401
+from flash.image.face_detection.backbones.fastface_backbones import register_ff_backbones  # noqa: F401
+
+FACE_DETECTION_BACKBONES = FlashRegistry("face_detection_backbones")
+register_ff_backbones(FACE_DETECTION_BACKBONES)
diff --git a/flash/image/face_detection/backbones/fastface_backbones.py b/flash/image/face_detection/backbones/fastface_backbones.py
new file mode 100644
index 0000000000..a8829d516b
--- /dev/null
+++ b/flash/image/face_detection/backbones/fastface_backbones.py
@@ -0,0 +1,44 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _FASTFACE_AVAILABLE
+
+if _FASTFACE_AVAILABLE:
+    import fastface as ff
+
+    _MODEL_NAMES = ff.list_pretrained_models()
+else:
+    _MODEL_NAMES = []
+
+
+def fastface_backbone(model_name: str, pretrained: bool, **kwargs):
+    if pretrained:
+        pl_model = ff.FaceDetector.from_pretrained(model_name, **kwargs)
+    else:
+        arch, config = model_name.split("_")
+        pl_model = ff.FaceDetector.build(arch, config, **kwargs)
+
+    backbone = getattr(pl_model, "arch")
+
+    return backbone, pl_model
+
+
+def register_ff_backbones(register: FlashRegistry):
+    if _FASTFACE_AVAILABLE:
+        backbones = [partial(fastface_backbone, model_name=name) for name in _MODEL_NAMES]
+
+        for idx, backbone in enumerate(backbones):
+            register(backbone, name=_MODEL_NAMES[idx])
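A short sketch of how the registry populated above is consumed (mirroring the lookup done in the tests at the end of this diff):

    # Each registered entry is a partial of `fastface_backbone`; calling it
    # returns the bare architecture plus the wrapping fastface pl module.
    # With pretrained=False the architecture is built from the name, e.g.
    # "lffd_slim" -> arch "lffd", config "slim".
    from flash.image.face_detection.backbones import FACE_DETECTION_BACKBONES

    backbone, pl_model = FACE_DETECTION_BACKBONES.get("lffd_slim")(pretrained=False)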
diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py
new file mode 100644
index 0000000000..f926c24538
--- /dev/null
+++ b/flash/image/face_detection/data.py
@@ -0,0 +1,172 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple
+
+import torch
+import torch.nn as nn
+from torch.utils.data import Dataset
+
+from flash.core.data.data_source import DatasetDataSource, DefaultDataKeys, DefaultDataSources
+from flash.core.data.process import Postprocess, Preprocess
+from flash.core.data.transforms import ApplyToKeys
+from flash.core.utilities.imports import _FASTFACE_AVAILABLE, _TORCHVISION_AVAILABLE
+from flash.image.data import ImagePathsDataSource
+from flash.image.detection import ObjectDetectionData
+
+if _TORCHVISION_AVAILABLE:
+    import torchvision
+    from torchvision.datasets.folder import default_loader
+
+if _FASTFACE_AVAILABLE:
+    import fastface as ff
+
+
+def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence[Any]]:
+    """Collate function from fastface.
+
+    Organizes individual elements in a batch, calls ``prepare_batch`` from fastface, and prepares the targets.
+    """
+    samples = {key: [sample[key] for sample in samples] for key in samples[0]}
+
+    images, scales, paddings = ff.utils.preprocess.prepare_batch(
+        samples[DefaultDataKeys.INPUT], None, adaptive_batch=True
+    )
+
+    samples["scales"] = scales
+    samples["paddings"] = paddings
+
+    if DefaultDataKeys.TARGET in samples.keys():
+        targets = samples[DefaultDataKeys.TARGET]
+        targets = [{"target_boxes": target["boxes"]} for target in targets]
+
+        for i, (target, scale, padding) in enumerate(zip(targets, scales, paddings)):
+            target["target_boxes"] *= scale
+            target["target_boxes"][:, [0, 2]] += padding[0]
+            target["target_boxes"][:, [1, 3]] += padding[1]
+            targets[i]["target_boxes"] = target["target_boxes"]
+
+        samples[DefaultDataKeys.TARGET] = targets
+    samples[DefaultDataKeys.INPUT] = images
+
+    return samples
+
+
+class FastFaceDataSource(DatasetDataSource):
+    """Logic for loading from ``fastface.dataset.FDDBDataset``."""
+
+    def load_data(self, data: Dataset, dataset: Any = None) -> Dataset:
+        new_data = []
+        for img_file_path, targets in zip(data.ids, data.targets):
+            new_data.append(
+                super().load_sample(
+                    (
+                        img_file_path,
+                        dict(
+                            boxes=targets["target_boxes"],
+                            # label `1` indicates a positive (face) sample
+                            labels=[1 for _ in range(targets["target_boxes"].shape[0])],
+                        ),
+                    )
+                )
+            )
+
+        return new_data
+
+    def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]:
+        filepath = sample[DefaultDataKeys.INPUT]
+        img = default_loader(filepath)
+        sample[DefaultDataKeys.INPUT] = img
+
+        w, h = img.size  # PIL reports size as (W, H)
+        sample[DefaultDataKeys.METADATA] = {
+            "filepath": filepath,
+            "size": (h, w),
+        }
+
+        return sample
+
+
+class FaceDetectionPreprocess(Preprocess):
+    """Applies the default transforms and ``fastface_collate_fn`` to the face detection data sources."""
+
+    def __init__(
+        self,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        image_size: Tuple[int, int] = (128, 128),
+    ):
+        self.image_size = image_size
+
+        super().__init__(
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_sources={
+                DefaultDataSources.FILES: ImagePathsDataSource(),
+                DefaultDataSources.FOLDERS: ImagePathsDataSource(),
+                DefaultDataSources.DATASETS: FastFaceDataSource(),
+            },
+            default_data_source=DefaultDataSources.FILES,
+        )
+
+    def get_state_dict(self) -> Dict[str, Any]:
+        return {**self.transforms}
+
+    @classmethod
+    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
+        return cls(**state_dict)
+
+    def default_transforms(self) -> Dict[str, Callable]:
+        return {
+            "to_tensor_transform": nn.Sequential(
+                ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
+                ApplyToKeys(
+                    DefaultDataKeys.TARGET,
+                    nn.Sequential(
+                        ApplyToKeys("boxes", torch.as_tensor),
+                        ApplyToKeys("labels", torch.as_tensor),
+                    ),
+                ),
+            ),
+            "collate": fastface_collate_fn,
+        }
+
+
+class FaceDetectionPostProcess(Postprocess):
+    """Generates predictions from the model output."""
+
+    @staticmethod
+    def per_batch_transform(batch: Any) -> Any:
+        scales = batch["scales"]
+        paddings = batch["paddings"]
+
+        batch.pop("scales", None)
+        batch.pop("paddings", None)
+
+        preds = batch[DefaultDataKeys.PREDS]
+
+        # preds: torch.Tensor(N, 6) as x1, y1, x2, y2, score, batch_idx;
+        # group by batch_idx (one entry per batch element, i.e. per scale)
+        # and drop the index column
+        preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(len(scales))]
+        preds = ff.utils.preprocess.adjust_results(preds, scales, paddings)
+        batch[DefaultDataKeys.PREDS] = preds
+
+        return batch
+
+
+class FaceDetectionData(ObjectDetectionData):
+    preprocess_cls = FaceDetectionPreprocess
+    postprocess_cls = FaceDetectionPostProcess
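To make the box adjustment in `fastface_collate_fn` concrete, a small worked sketch with illustrative numbers (assuming, as the collate function does, that `padding[0]` is the horizontal offset and `padding[1]` the vertical one):

    import torch

    boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])  # one box as x1, y1, x2, y2
    scale, padding = 2.0, (4, 6)  # illustrative outputs of prepare_batch

    boxes = boxes * scale           # boxes now live in the resized image
    boxes[:, [0, 2]] += padding[0]  # shift x1 and x2 by the horizontal padding
    boxes[:, [1, 3]] += padding[1]  # shift y1 and y2 by the vertical padding
    # boxes == [[24., 46., 104., 166.]], aligned with the padded batch tensor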
diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py
new file mode 100644
index 0000000000..5d7c6e0445
--- /dev/null
+++ b/flash/image/face_detection/model.py
@@ -0,0 +1,187 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
+
+import pytorch_lightning as pl
+import torch
+from torch import nn
+from torch.optim import Optimizer
+
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.process import Preprocess, Serializer
+from flash.core.finetuning import FlashBaseFinetuning
+from flash.core.model import Task
+from flash.core.utilities.imports import _FASTFACE_AVAILABLE
+from flash.image.face_detection.backbones import FACE_DETECTION_BACKBONES
+from flash.image.face_detection.data import FaceDetectionPreprocess
+
+if _FASTFACE_AVAILABLE:
+    import fastface as ff
+
+
+class FaceDetectionFineTuning(FlashBaseFinetuning):
+    def __init__(self, train_bn: bool = True) -> None:
+        super().__init__(train_bn=train_bn)
+
+    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
+        self.freeze(modules=pl_module.model.backbone, train_bn=self.train_bn)
+
+
+class DetectionLabels(Serializer):
+    """A :class:`.Serializer` which extracts predictions from the sample dict."""
+
+    def serialize(self, sample: Any) -> Dict[str, Any]:
+        return sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
+
+
+class FaceDetector(Task):
+    """The ``FaceDetector`` is a :class:`~flash.Task` for detecting faces in images.
+
+    For more details, see :ref:`face_detection`.
+
+    Args:
+        model: The name of the fastface model to use, one of ``ff.list_pretrained_models()``.
+            Defaults to ``"lffd_slim"``.
+        pretrained: Whether the model from fastface should be loaded with its pretrained weights.
+        loss: The function(s) to update the model with. Has no effect for fastface models.
+        metrics: The provided metrics. All metrics here will be logged to the progress bar and the
+            respective logger. Changing this argument currently has no effect.
+        optimizer: The optimizer to use for training. Can either be the actual class or the class name.
+        learning_rate: The learning rate to use for training.
+    """
+
+    required_extras: str = "image"
+
+    def __init__(
+        self,
+        model: str = "lffd_slim",
+        pretrained: bool = True,
+        loss=None,
+        metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None,
+        optimizer: Type[Optimizer] = torch.optim.AdamW,
+        learning_rate: float = 1e-4,
+        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+        preprocess: Optional[Preprocess] = None,
+        **kwargs: Any,
+    ):
+        self.save_hyperparameters()
+
+        if model in ff.list_pretrained_models():
+            model = FaceDetector.get_model(model, pretrained, **kwargs)
+        else:
+            raise ValueError(model + f" is not supported yet, please select one from {ff.list_pretrained_models()}")
+
+        super().__init__(
+            model=model,
+            loss_fn=loss,
+            metrics=metrics or {"AP": ff.metric.AveragePrecision()},  # TODO: replace with torchmetrics mAP
+            learning_rate=learning_rate,
+            optimizer=optimizer,
+            serializer=serializer or DetectionLabels(),
+            preprocess=preprocess or FaceDetectionPreprocess(),
+        )
+
+    @staticmethod
+    def get_model(
+        model_name: str,
+        pretrained: bool,
+        **kwargs,
+    ):
+        model, pl_model = FACE_DETECTION_BACKBONES.get(model_name)(pretrained=pretrained, **kwargs)
+
+        # the following steps are required since `get_model` needs to return a `torch.nn.Module`:
+        # move the required preprocessing parameters from `fastface.FaceDetector` onto the module
+        model.register_buffer("normalizer", getattr(pl_model, "normalizer"))
+        model.register_buffer("mean", getattr(pl_model, "mean"))
+        model.register_buffer("std", getattr(pl_model, "std"))
+
+        # copy the `_postprocess` function from `fastface.FaceDetector` onto the module;
+        # it is called from the `FaceDetector` lightning module from fastface itself
+        # https://github.com/borhanMorphy/fastface/blob/master/fastface/module.py#L200
+        setattr(model, "_postprocess", getattr(pl_model, "_postprocess"))
+
+        return model
+
+    def forward(self, x: List[torch.Tensor]) -> Any:
+        images = self._prepare_batch(x)
+        logits = self.model(images)
+
+        # preds: torch.Tensor(B, N, 5)
+        preds = self.model.logits_to_preds(logits)
+
+        # preds: torch.Tensor(N, 6) as x1, y1, x2, y2, score, batch_idx
+        preds = self.model._postprocess(preds)
+
+        return preds
+
+    def _prepare_batch(self, batch):
+        batch = (((batch * 255) / self.model.normalizer) - self.model.mean) / self.model.std
+        return batch
+
+    def _compute_metrics(self, logits, targets):
+        # preds: torch.Tensor(B, N, 5)
+        preds = self.model.logits_to_preds(logits)
+
+        # preds: torch.Tensor(N, 6) as x1, y1, x2, y2, score, batch_idx
+        preds = self.model._postprocess(preds)
+
+        target_boxes = [target["target_boxes"] for target in targets]
+        pred_boxes = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(len(targets))]
+
+        for metric in self.val_metrics.values():
+            metric.update(pred_boxes, target_boxes)
+
+    def __shared_step(self, batch, train=False) -> Any:
+        images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]
+        images = self._prepare_batch(images)
+        logits = self.model(images)
+        loss = self.model.compute_loss(logits, targets)
+
+        self._compute_metrics(logits, targets)
+
+        return loss
+
+    def training_step(self, batch, batch_idx) -> Any:
+        loss = self.__shared_step(batch)
+
+        self.log_dict({f"train_{k}": v for k, v in loss.items()}, on_step=True, on_epoch=True, prog_bar=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        loss = self.__shared_step(batch)
+
+        self.log_dict({f"val_{k}": v for k, v in loss.items()}, on_step=True, on_epoch=True, prog_bar=True)
+        return loss
+
+    def validation_epoch_end(self, outputs) -> None:
+        metric_results = {name: metric.compute() for name, metric in self.val_metrics.items()}
+        self.log_dict({f"val_{k}": v for k, v in metric_results.items()}, on_epoch=True)
+
+    def test_step(self, batch, batch_idx):
+        loss = self.__shared_step(batch)
+
+        self.log_dict({f"test_{k}": v for k, v in loss.items()}, on_step=True, on_epoch=True, prog_bar=True)
+        return loss
+
+    def test_epoch_end(self, outputs) -> None:
+        metric_results = {name: metric.compute() for name, metric in self.val_metrics.items()}
+        self.log_dict({f"test_{k}": v for k, v in metric_results.items()}, on_epoch=True)
+
+    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+        images = batch[DefaultDataKeys.INPUT]
+        batch[DefaultDataKeys.PREDS] = self(images)
+        return batch
+
+    def configure_finetune_callback(self):
+        return [FaceDetectionFineTuning()]
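As a quick sanity sketch of the forward path (mirroring `test_fastface_forward` below): a raw batch goes through `_prepare_batch`, the backbone, `logits_to_preds`, and the `_postprocess` copied onto the module in `get_model`, ending up as `(N, 6)` rows of x1, y1, x2, y2, score, batch_idx. Requires `fastface` to be installed:

    import torch

    from flash.image import FaceDetector

    # pretrained=False avoids a weight download; "lffd_slim" is one of the
    # registered fastface model names
    model = FaceDetector(model="lffd_slim", pretrained=False)
    preds = model(torch.randn(2, 3, 256, 256))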
diff --git a/flash_examples/face_detection.py b/flash_examples/face_detection.py
new file mode 100644
index 0000000000..9762cadb01
--- /dev/null
+++ b/flash_examples/face_detection.py
@@ -0,0 +1,47 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+import flash
+from flash.core.utilities.imports import example_requires
+from flash.image import FaceDetectionData, FaceDetector
+
+example_requires("fastface")
+import fastface as ff  # noqa: E402
+
+# 1. Create the DataModule
+train_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="train")
+val_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="val")
+
+datamodule = FaceDetectionData.from_datasets(train_dataset=train_dataset, val_dataset=val_dataset, batch_size=2)
+
+# 2. Build the task
+model = FaceDetector(model="lffd_slim")
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
+trainer.finetune(model, datamodule=datamodule, strategy="freeze")
+
+# 4. Detect faces in a few images!
+predictions = model.predict(
+    [
+        "data/2002/07/19/big/img_18.jpg",
+        "data/2002/07/19/big/img_65.jpg",
+        "data/2002/07/19/big/img_255.jpg",
+    ]
+)
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("face_detection_model.pt") diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt index 7704ebb7a2..8e976d2ade 100644 --- a/requirements/datatype_image_extras.txt +++ b/requirements/datatype_image_extras.txt @@ -8,3 +8,4 @@ effdet albumentations learn2learn baal +fastface diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index 49d24bf7ab..61666b70ca 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -13,6 +13,8 @@ # limitations under the License. import os +import pytest + from flash.core.data.utils import download_data from flash.core.utilities.apply_func import get_callable_dict, get_callable_name @@ -50,7 +52,10 @@ def test_get_callable_dict(): assert d["two"] == b -def test_download_data(tmpdir): +@pytest.mark.parametrize("file", ["titanic.zip", "titanic.tar.gz", "titanic.tar.bz2"]) +def test_download_data(tmpdir, file): + download_path = "https://pl-flash-data.s3.amazonaws.com/" path = os.path.join(tmpdir, "data") - download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", path) - assert set(os.listdir(path)) == {"titanic", "titanic.zip"} + download_data(download_path + file, path) + assert "titanic" in set(os.listdir(path)) + assert file in set(os.listdir(path)) diff --git a/tests/image/face_detection/__init__.py b/tests/image/face_detection/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/image/face_detection/test_model.py b/tests/image/face_detection/test_model.py new file mode 100644 index 0000000000..05228c6586 --- /dev/null +++ b/tests/image/face_detection/test_model.py @@ -0,0 +1,60 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +import torch + +import flash +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _FASTFACE_AVAILABLE +from flash.image import FaceDetectionData, FaceDetector + +if _FASTFACE_AVAILABLE: + import fastface as ff + from fastface.arch.lffd import LFFD + + from flash.image.face_detection.backbones import FACE_DETECTION_BACKBONES +else: + FACE_DETECTION_BACKBONES = FlashRegistry("face_detection_backbones") + LFFD = object + + +@pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.") +def test_fastface_training(): + dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="val") + datamodule = FaceDetectionData.from_datasets(train_dataset=dataset, batch_size=2) + + model = FaceDetector(model="lffd_slim") + + # test fit + trainer = flash.Trainer(max_steps=2, num_sanity_val_steps=0) + trainer.finetune(model, datamodule=datamodule, strategy="freeze") + + +@pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.") +def test_fastface_forward(): + model = FaceDetector(model="lffd_slim") + mock_batch = torch.randn(2, 3, 256, 256) + + # test model forward (tests: _prepare_batch, logits_to_preds, _postprocess from ff) + model(mock_batch) + + +@pytest.mark.skipif(not _FASTFACE_AVAILABLE, reason="fastface not installed.") +def test_fastface_backbones_registry(): + backbones = FACE_DETECTION_BACKBONES.available_keys() + assert "lffd_slim" in backbones + assert "lffd_original" in backbones + + backbone, _ = FACE_DETECTION_BACKBONES.get("lffd_original")(pretrained=False) + assert isinstance(backbone, LFFD)
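A final sketch: reloading the finetuned checkpoint saved at the end of the example script above and running inference on new files. `load_from_checkpoint` comes from `LightningModule`, and the image path below is illustrative:

    from flash.image import FaceDetector

    model = FaceDetector.load_from_checkpoint("face_detection_model.pt")
    predictions = model.predict(["path/to/face.jpg"])  # illustrative path
    print(predictions)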