Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Face Detection Task (task-a-thon) #606

Merged
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
56bce9c
.
ananyahjha93 Sep 23, 2021
2b22d87
merging taskathon PR code
ananyahjha93 Sep 23, 2021
6c624ad
working
ananyahjha93 Sep 24, 2021
13f0c27
pep8
ananyahjha93 Sep 24, 2021
845ec6f
Merge branch 'master' into feature/face_detection
ananyahjha93 Sep 24, 2021
29384d0
imports
ananyahjha93 Sep 24, 2021
39cdf6c
erge branch 'feature/face_detection' of https://github.com/borhanMorp…
ananyahjha93 Sep 24, 2021
e99adad
backbones registry
ananyahjha93 Sep 24, 2021
fc7ff44
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2021
159bf03
tests
ananyahjha93 Sep 24, 2021
1175b41
tests
ananyahjha93 Sep 24, 2021
ebfafcd
Merge branch 'feature/face_detection' of https://github.com/borhanMor…
ananyahjha93 Sep 24, 2021
021b12e
more coverage
ananyahjha93 Sep 24, 2021
fdff774
final
ananyahjha93 Sep 24, 2021
9779567
.
ananyahjha93 Sep 27, 2021
b4dbb96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2021
b7af234
.
ananyahjha93 Sep 27, 2021
5fdc2a7
.
ananyahjha93 Sep 27, 2021
70514f0
.
ananyahjha93 Sep 27, 2021
fceeb2d
.
ananyahjha93 Sep 27, 2021
24fcdbd
.
ananyahjha93 Sep 27, 2021
3e2eeb0
Update flash/image/face_detection/model.py
ananyahjha93 Sep 27, 2021
f1bff2b
comments
ananyahjha93 Sep 27, 2021
6dd8b10
Merge branch 'feature/face_detection' of https://github.com/borhanMor…
ananyahjha93 Sep 27, 2021
5ef6812
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2021
087a317
.
ananyahjha93 Sep 27, 2021
7048430
Merge branch 'feature/face_detection' of https://github.com/borhanMor…
ananyahjha93 Sep 27, 2021
d570cb3
.
ananyahjha93 Sep 27, 2021
bbda294
.
ananyahjha93 Sep 27, 2021
2eb53d6
Merge branch 'master' into feature/face_detection
ananyahjha93 Sep 28, 2021
dc2d102
Merge branch 'master' into feature/face_detection
tchaton Sep 28, 2021
7fba71d
Merge branch 'feature/face_detection' of https://github.com/PyTorchLi…
borhanMorphy Sep 28, 2021
5db8aca
added comments to clearfy some steps in the face detection task
borhanMorphy Sep 28, 2021
8117900
conflicts are resolved for face detection task
borhanMorphy Sep 28, 2021
e2ca442
imports
ananyahjha93 Sep 28, 2021
4deaf9c
.
ananyahjha93 Sep 29, 2021
2b7c326
.
ananyahjha93 Sep 29, 2021
144e063
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2021
2085764
.
ananyahjha93 Sep 29, 2021
1d53299
.
ananyahjha93 Sep 29, 2021
0e4c862
.
ananyahjha93 Sep 29, 2021
b94415c
Merge branch 'master' into feature/face_detection
ananyahjha93 Sep 29, 2021
6a9a9a1
tests
ananyahjha93 Sep 30, 2021
934d44b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2021
164de7c
.
ananyahjha93 Sep 30, 2021
895863d
Merge branch 'feature/face_detection' of https://github.com/borhanMor…
ananyahjha93 Sep 30, 2021
5384003
.
ananyahjha93 Sep 30, 2021
a94b107
.
ananyahjha93 Sep 30, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `from_data_frame` to `TextClassificationData` ([#785](https://github.com/PyTorchLightning/lightning-flash/pull/785))

- Added `FastFace` integration ([#606](https://github.com/PyTorchLightning/lightning-flash/pull/606))

- Added support for `from_lists` to `TextClassificationData` ([#805](https://github.com/PyTorchLightning/lightning-flash/pull/805))

### Changed
Expand Down
14 changes: 14 additions & 0 deletions flash/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os.path
import tarfile
import zipfile
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Set, Type

Expand Down Expand Up @@ -148,10 +149,23 @@ def download_data(url: str, path: str = "data/", verbose: bool = False) -> None:
):
fp.write(chunk) # type: ignore

def extract_tarfile(file_path: str, extract_path: str, mode: str):
if os.path.exists(file_path):
with tarfile.open(file_path, mode=mode) as tar_ref:
for member in tar_ref.getmembers():
try:
tar_ref.extract(member, path=extract_path, set_attrs=False)
except PermissionError:
pass
ananyahjha93 marked this conversation as resolved.
Show resolved Hide resolved

if ".zip" in local_filename:
if os.path.exists(local_filename):
ananyahjha93 marked this conversation as resolved.
Show resolved Hide resolved
with zipfile.ZipFile(local_filename, "r") as zip_ref:
zip_ref.extractall(path)
elif local_filename.endswith(".tar.gz") or local_filename.endswith(".tgz"):
extract_tarfile(local_filename, path, "r:gz")
elif local_filename.endswith(".tar.bz2") or local_filename.endswith(".tbz"):
extract_tarfile(local_filename, path, "r:bz2")


def _contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool:
Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def _compare_version(package: str, op, version) -> bool:
_PIL_AVAILABLE = _module_available("PIL")
_OPEN3D_AVAILABLE = _module_available("open3d")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")
_FASTFACE_AVAILABLE = _module_available("fastface")
_LIBROSA_AVAILABLE = _module_available("librosa")
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
Expand Down
1 change: 1 addition & 0 deletions flash/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES # noqa: F401
from flash.image.detection import ObjectDetectionData, ObjectDetector # noqa: F401
from flash.image.embedding import ImageEmbedder # noqa: F401
from flash.image.face_detection import FaceDetectionData, FaceDetector # noqa: F401
from flash.image.instance_segmentation import InstanceSegmentation, InstanceSegmentationData # noqa: F401
from flash.image.keypoint_detection import KeypointDetectionData, KeypointDetector # noqa: F401
from flash.image.segmentation import ( # noqa: F401
Expand Down
2 changes: 2 additions & 0 deletions flash/image/face_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from flash.image.face_detection.data import FaceDetectionData # noqa: F401
from flash.image.face_detection.model import FaceDetector # noqa: F401
5 changes: 5 additions & 0 deletions flash/image/face_detection/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from flash.core.registry import FlashRegistry # noqa: F401
from flash.image.face_detection.backbones.fastface_backbones import register_ff_backbones # noqa: F401

FACE_DETECTION_BACKBONES = FlashRegistry("face_detection_backbones")
register_ff_backbones(FACE_DETECTION_BACKBONES)
44 changes: 44 additions & 0 deletions flash/image/face_detection/backbones/fastface_backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FASTFACE_AVAILABLE

if _FASTFACE_AVAILABLE:
import fastface as ff

_MODEL_NAMES = ff.list_pretrained_models()
else:
_MODEL_NAMES = []


def fastface_backbone(model_name, pretrained, **kwargs):
ananyahjha93 marked this conversation as resolved.
Show resolved Hide resolved
if pretrained:
pl_model = ff.FaceDetector.from_pretrained(model_name, **kwargs)
else:
arch, config = model_name.split("_")
pl_model = ff.FaceDetector.build(arch, config, **kwargs)

backbone = getattr(pl_model, "arch")

return backbone, pl_model


def register_ff_backbones(register: FlashRegistry):
if _FASTFACE_AVAILABLE:
backbones = [partial(fastface_backbone, model_name=name) for name in _MODEL_NAMES]

for idx, backbone in enumerate(backbones):
register(backbone, name=_MODEL_NAMES[idx])
172 changes: 172 additions & 0 deletions flash/image/face_detection/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from flash.core.data.data_source import DatasetDataSource, DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Postprocess, Preprocess
from flash.core.data.transforms import ApplyToKeys
from flash.core.utilities.imports import _FASTFACE_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.image.data import ImagePathsDataSource
from flash.image.detection import ObjectDetectionData

if _TORCHVISION_AVAILABLE:
import torchvision
from torchvision.datasets.folder import default_loader

if _FASTFACE_AVAILABLE:
import fastface as ff


def fastface_collate_fn(samples: Sequence[Dict[str, Any]]) -> Dict[str, Sequence[Any]]:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"""Collate function from fastface.

Organizes individual elements in a batch, calls prepare_batch from fastface and prepares the targets.
"""
samples = {key: [sample[key] for sample in samples] for key in samples[0]}

images, scales, paddings = ff.utils.preprocess.prepare_batch(
samples[DefaultDataKeys.INPUT], None, adaptive_batch=True
)

samples["scales"] = scales
samples["paddings"] = paddings

if DefaultDataKeys.TARGET in samples.keys():
targets = samples[DefaultDataKeys.TARGET]
targets = [{"target_boxes": target["boxes"]} for target in targets]

for i, (target, scale, padding) in enumerate(zip(targets, scales, paddings)):
target["target_boxes"] *= scale
target["target_boxes"][:, [0, 2]] += padding[0]
target["target_boxes"][:, [1, 3]] += padding[1]
targets[i]["target_boxes"] = target["target_boxes"]

samples[DefaultDataKeys.TARGET] = targets
samples[DefaultDataKeys.INPUT] = images

return samples


class FastFaceDataSource(DatasetDataSource):
"""Logic for loading from FDDBDataset."""
ananyahjha93 marked this conversation as resolved.
Show resolved Hide resolved

def load_data(self, data: Dataset, dataset: Any = None) -> Dataset:
new_data = []
for img_file_path, targets in zip(data.ids, data.targets):
new_data.append(
super().load_sample(
(
img_file_path,
dict(
boxes=targets["target_boxes"],
# label `1` indicates positive sample
labels=[1 for _ in range(targets["target_boxes"].shape[0])],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

noob question. Why are 1 hardcoded there ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 indicates that this is a positive sample because the task Is similar to binary classification at this point

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add this as a comment please?

),
)
)
)

return new_data

def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str, Any]:
filepath = sample[DefaultDataKeys.INPUT]
img = default_loader(filepath)
sample[DefaultDataKeys.INPUT] = img

w, h = img.size # WxH
sample[DefaultDataKeys.METADATA] = {
"filepath": filepath,
"size": (h, w),
}

return sample


class FaceDetectionPreprocess(Preprocess):
ananyahjha93 marked this conversation as resolved.
Show resolved Hide resolved
"""Applies default transform and collate_fn for fastface on FastFaceDataSource."""

def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
image_size: Tuple[int, int] = (128, 128),
):
self.image_size = image_size

super().__init__(
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_sources={
DefaultDataSources.FILES: ImagePathsDataSource(),
DefaultDataSources.FOLDERS: ImagePathsDataSource(),
DefaultDataSources.DATASETS: FastFaceDataSource(),
},
default_data_source=DefaultDataSources.FILES,
)

def get_state_dict(self) -> Dict[str, Any]:
return {**self.transforms}

@classmethod
def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
return cls(**state_dict)

def default_transforms(self) -> Dict[str, Callable]:
return {
"to_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
ApplyToKeys(
DefaultDataKeys.TARGET,
nn.Sequential(
ApplyToKeys("boxes", torch.as_tensor),
ApplyToKeys("labels", torch.as_tensor),
),
),
),
"collate": fastface_collate_fn,
}


class FaceDetectionPostProcess(Postprocess):
"""Generates preds from model output."""

@staticmethod
def per_batch_transform(batch: Any) -> Any:
ananyahjha93 marked this conversation as resolved.
Show resolved Hide resolved
scales = batch["scales"]
paddings = batch["paddings"]

batch.pop("scales", None)
batch.pop("paddings", None)

preds = batch[DefaultDataKeys.PREDS]

# preds: list of torch.Tensor(N, 5) as x1, y1, x2, y2, score
preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(len(preds))]
preds = ff.utils.preprocess.adjust_results(preds, scales, paddings)
batch[DefaultDataKeys.PREDS] = preds

return batch


class FaceDetectionData(ObjectDetectionData):
preprocess_cls = FaceDetectionPreprocess
postprocess_cls = FaceDetectionPostProcess
Loading