This repository was archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 210
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
56bce9c
commit 2b22d87
Showing
5 changed files
with
374 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from flash.image.face_detection.model import FaceDetector # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple | ||
|
||
from torch.utils.data import Dataset | ||
|
||
from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources | ||
from flash.core.data.process import Preprocess | ||
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE | ||
from flash.image.data import ImagePathsDataSource | ||
from flash.image.detection.transforms import default_transforms | ||
|
||
if _TORCHVISION_AVAILABLE: | ||
from torchvision.datasets.folder import default_loader | ||
|
||
|
||
class FastFaceDataSource(DataSource[Tuple[str, str]]): | ||
|
||
def load_data(self, data: Dataset, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]: | ||
new_data = [] | ||
for img_file_path, targets in zip(data.ids, data.targets): | ||
new_data.append( | ||
dict( | ||
input=img_file_path, | ||
target=dict( | ||
boxes=targets["target_boxes"], | ||
labels=[1 for _ in range(targets["target_boxes"].shape[0])], | ||
) | ||
) | ||
) | ||
return new_data | ||
|
||
def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: | ||
filepath = sample[DefaultDataKeys.INPUT] | ||
img = default_loader(filepath) | ||
sample[DefaultDataKeys.INPUT] = img | ||
w, h = img.size # WxH | ||
sample[DefaultDataKeys.METADATA] = { | ||
"filepath": filepath, | ||
"size": (h, w), | ||
} | ||
return sample | ||
|
||
|
||
class FaceDetectionPreprocess(Preprocess): | ||
|
||
def __init__( | ||
self, | ||
train_transform: Optional[Dict[str, Callable]] = None, | ||
val_transform: Optional[Dict[str, Callable]] = None, | ||
test_transform: Optional[Dict[str, Callable]] = None, | ||
predict_transform: Optional[Dict[str, Callable]] = None | ||
): | ||
super().__init__( | ||
train_transform=train_transform, | ||
val_transform=val_transform, | ||
test_transform=test_transform, | ||
predict_transform=predict_transform, | ||
data_sources={ | ||
DefaultDataSources.FILES: ImagePathsDataSource(), | ||
DefaultDataSources.FOLDERS: ImagePathsDataSource(), | ||
"fastface": FastFaceDataSource() | ||
}, | ||
default_data_source=DefaultDataSources.FILES, | ||
) | ||
|
||
def get_state_dict(self) -> Dict[str, Any]: | ||
return {**self.transforms} | ||
|
||
@classmethod | ||
def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): | ||
return cls(**state_dict) | ||
|
||
def default_transforms(self) -> Optional[Dict[str, Callable]]: | ||
return default_transforms() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union | ||
|
||
import torch | ||
from torch import nn | ||
from torch.optim import Optimizer | ||
|
||
from flash.core.data.data_source import DefaultDataKeys | ||
from flash.core.data.process import Preprocess, Serializer | ||
from flash.core.model import Task | ||
from flash.core.utilities.imports import _FASTFACE_AVAILABLE | ||
from flash.image.detection.finetuning import ObjectDetectionFineTuning | ||
from flash.image.detection.serialization import DetectionLabels | ||
from flash.image.face_detection.data import FaceDetectionPreprocess | ||
|
||
if _FASTFACE_AVAILABLE: | ||
import fastface as ff | ||
|
||
|
||
class FaceDetector(Task): | ||
"""The ``FaceDetector`` is a :class:`~flash.Task` for detecting faces in images. For more details, see | ||
:ref:`face_detection`. | ||
Args: | ||
model: a string of :attr`_models`. Defaults to 'lffd_slim'. | ||
pretrained: Whether the model from fastface should be loaded with it's pretrained weights. | ||
loss: the function(s) to update the model with. Has no effect for fastface models. | ||
metrics: The provided metrics. All metrics here will be logged to progress bar and the respective logger. | ||
Changing this argument currently has no effect. | ||
optimizer: The optimizer to use for training. Can either be the actual class or the class name. | ||
learning_rate: The learning rate to use for training | ||
""" | ||
|
||
required_extras: str = "image" | ||
|
||
def __init__( | ||
self, | ||
model: str = "lffd_slim", | ||
pretrained: bool = True, | ||
loss=None, | ||
metrics: Union[Callable, nn.Module, Mapping, Sequence, None] = None, | ||
optimizer: Type[Optimizer] = torch.optim.AdamW, | ||
learning_rate: float = 1e-4, | ||
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, | ||
preprocess: Optional[Preprocess] = None, | ||
**kwargs: Any, | ||
): | ||
self.save_hyperparameters() | ||
|
||
if model in ff.list_pretrained_models(): | ||
model = FaceDetector.get_model(model, pretrained, **kwargs) | ||
else: | ||
ValueError(f"{model} is not supported yet.") | ||
|
||
super().__init__( | ||
model=model, | ||
loss_fn=loss, | ||
metrics=metrics or {"AP": ff.metric.AveragePrecision()}, | ||
learning_rate=learning_rate, | ||
optimizer=optimizer, | ||
serializer=serializer or DetectionLabels(), | ||
preprocess=preprocess or FaceDetectionPreprocess(), | ||
) | ||
|
||
@staticmethod | ||
def get_model( | ||
model_name, | ||
pretrained, | ||
**kwargs, | ||
): | ||
|
||
if pretrained: | ||
pl_model = ff.FaceDetector.from_pretrained(model_name, **kwargs) | ||
else: | ||
arch, config = model_name.split("_") | ||
pl_model = ff.FaceDetector.build(arch, config, **kwargs) | ||
|
||
# get torch.nn.Module | ||
model = getattr(pl_model, "arch") | ||
|
||
# set preprocess params | ||
model.register_buffer("normalizer", getattr(pl_model, "normalizer")) | ||
model.register_buffer("mean", getattr(pl_model, "mean")) | ||
model.register_buffer("std", getattr(pl_model, "std")) | ||
|
||
# set postprocess function | ||
setattr(model, "_postprocess", getattr(pl_model, "_postprocess")) | ||
|
||
return model | ||
|
||
def forward(self, x: List[torch.Tensor]) -> Any: | ||
|
||
batch, scales, paddings = ff.utils.preprocess.prepare_batch(x, None, adaptive_batch=True) | ||
# batch: torch.Tensor(B,C,T,T) | ||
# scales: torch.Tensor(B,) | ||
# paddings: torch.Tensor(B,4) as pad (left, top, right, bottom) | ||
|
||
# apply preprocess | ||
batch = (((batch * 255) / self.model.normalizer) - self.model.mean) / self.model.std | ||
|
||
# get logits | ||
logits = self.model(batch) | ||
# logits, any | ||
|
||
preds = self.model.logits_to_preds(logits) | ||
# preds: torch.Tensor(B, N, 5) | ||
|
||
preds = self.model._postprocess(preds) | ||
# preds: torch.Tensor(N, 6) as x1,y1,x2,y2,score,batch_idx | ||
|
||
preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(batch.size(0))] | ||
# preds: list of torch.Tensor(N, 5) as x1,y1,x2,y2,score | ||
|
||
preds = ff.utils.preprocess.adjust_results(preds, scales, paddings) | ||
# preds: list of torch.Tensor(N, 5) as x1,y1,x2,y2,score | ||
|
||
return preds | ||
|
||
def _prepare_batch(self, batch): | ||
images, targets = batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET] | ||
|
||
targets = [{"target_boxes": target["boxes"]} for target in targets] | ||
|
||
batch, scales, paddings = ff.utils.preprocess.prepare_batch(images, None, adaptive_batch=True) | ||
# batch: torch.Tensor(B,C,T,T) | ||
# scales: torch.Tensor(B,) | ||
# paddings: torch.Tensor(B,4) as pad (left, top, right, bottom) | ||
|
||
# apply preprocess | ||
batch = (((batch * 255) / self.model.normalizer) - self.model.mean) / self.model.std | ||
|
||
# adjust targets | ||
for i, (target, scale, padding) in enumerate(zip(targets, scales, paddings)): | ||
target["target_boxes"] *= scale | ||
target["target_boxes"][:, [0, 2]] += padding[0] | ||
target["target_boxes"][:, [1, 3]] += padding[1] | ||
targets[i]["target_boxes"] = target["target_boxes"] | ||
|
||
return batch, targets | ||
|
||
def _compute_metrics(self, logits, targets): | ||
preds = self.model.logits_to_preds(logits) | ||
# preds: torch.Tensor(B, N, 5) | ||
|
||
preds = self.model._postprocess(preds) | ||
# preds: torch.Tensor(N, 6) as x1,y1,x2,y2,score,batch_idx | ||
|
||
target_boxes = [target["target_boxes"] for target in targets] | ||
pred_boxes = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(len(targets))] | ||
|
||
for metric in self.val_metrics.values(): | ||
metric.update(pred_boxes, target_boxes) | ||
|
||
def training_step(self, batch, batch_idx) -> Any: | ||
"""The training step. Overrides ``Task.training_step`` | ||
""" | ||
|
||
batch, targets = self._prepare_batch(batch) | ||
|
||
# get logits | ||
logits = self.model(batch) | ||
# logits, any | ||
|
||
# compute loss | ||
loss = self.model.compute_loss(logits, targets) | ||
# loss: dict of losses or loss | ||
|
||
self.log_dict({f"train_{k}": v for k, v in loss.items()}, on_step=True, on_epoch=True, prog_bar=True) | ||
return loss | ||
|
||
def on_validation_epoch_start(self) -> None: | ||
for metric in self.val_metrics.values(): | ||
metric.reset() | ||
|
||
def validation_step(self, batch, batch_idx): | ||
batch, targets = self._prepare_batch(batch) | ||
|
||
# get logits | ||
logits = self.model(batch) | ||
# logits, any | ||
|
||
# compute loss | ||
loss = self.model.compute_loss(logits, targets) | ||
# loss: dict of losses or loss | ||
|
||
self._compute_metrics(logits, targets) | ||
|
||
self.log_dict({f"val_{k}": v for k, v in loss.items()}, on_step=True, on_epoch=True, prog_bar=True) | ||
return loss | ||
|
||
def validation_epoch_end(self, outputs) -> None: | ||
metric_results = {name: metric.compute() for name, metric in self.val_metrics.items()} | ||
self.log_dict({f"val_{k}": v for k, v in metric_results.items()}, on_epoch=True) | ||
|
||
def on_test_epoch_start(self) -> None: | ||
for metric in self.val_metrics.values(): | ||
metric.reset() | ||
|
||
def test_step(self, batch, batch_idx): | ||
batch, targets = self._prepare_batch(batch) | ||
|
||
# get logits | ||
logits = self.model(batch) | ||
# logits, any | ||
|
||
# compute loss | ||
loss = self.model.compute_loss(logits, targets) | ||
# loss: dict of losses or loss | ||
|
||
self._compute_metrics(logits, targets) | ||
|
||
self.log_dict({f"test_{k}": v for k, v in loss.items()}, on_step=True, on_epoch=True, prog_bar=True) | ||
return loss | ||
|
||
def test_epoch_end(self, outputs) -> None: | ||
metric_results = {name: metric.compute() for name, metric in self.val_metrics.items()} | ||
self.log_dict({f"test_{k}": v for k, v in metric_results.items()}, on_epoch=True) | ||
|
||
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: | ||
images = batch[DefaultDataKeys.INPUT] | ||
batch[DefaultDataKeys.PREDS] = self(images) | ||
return batch | ||
|
||
def configure_finetune_callback(self): | ||
return [ObjectDetectionFineTuning(train_bn=True)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import flash | ||
from flash.core.data.data_module import DataModule | ||
from flash.core.utilities.imports import _FASTFACE_AVAILABLE | ||
from flash.image import FaceDetector | ||
from flash.image.face_detection.data import FaceDetectionPreprocess | ||
|
||
if _FASTFACE_AVAILABLE: | ||
import fastface as ff | ||
else: | ||
raise ModuleNotFoundError("Please, pip install -e '.[image]'") | ||
|
||
# 1. Create the DataModule | ||
train_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="train") | ||
val_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="val") | ||
|
||
datamodule = DataModule.from_data_source( | ||
"fastface", train_data=train_dataset, val_data=val_dataset, preprocess=FaceDetectionPreprocess() | ||
) | ||
|
||
# 2. Build the task | ||
model = FaceDetector(model="lffd_slim") | ||
|
||
# 3. Create the trainer and finetune the model | ||
trainer = flash.Trainer(max_epochs=3, limit_train_batches=0.1, limit_val_batches=0.1) | ||
|
||
trainer.finetune(model, datamodule=datamodule, strategy="freeze") | ||
|
||
# 4. Detect faces in a few images! | ||
predictions = model.predict([ | ||
"data/2002/07/19/big/img_18.jpg", | ||
"data/2002/07/19/big/img_65.jpg", | ||
"data/2002/07/19/big/img_255.jpg", | ||
]) | ||
print(predictions) | ||
|
||
# 5. Save the model! | ||
trainer.save_checkpoint("face_detection_model.pt") |