Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
pep8
Browse files Browse the repository at this point in the history
  • Loading branch information
ananyahjha93 committed Sep 24, 2021
1 parent 6c624ad commit 13f0c27
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 47 deletions.
2 changes: 1 addition & 1 deletion flash/core/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def extract_tarfile(file_path: str, extract_path: str, mode: str):
zip_ref.extractall(path)
elif local_filename.endswith(".tar.gz") or local_filename.endswith(".tgz"):
extract_tarfile(local_filename, path, "r:gz")
elif local_filename.endswith('.tar.bz2') or local_filename.endswith('.tbz'):
elif local_filename.endswith(".tar.bz2") or local_filename.endswith(".tbz"):
extract_tarfile(local_filename, path, "r:bz2")


Expand Down
2 changes: 1 addition & 1 deletion flash/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES # noqa: F401
from flash.image.detection import ObjectDetectionData, ObjectDetector # noqa: F401
from flash.image.embedding import ImageEmbedder # noqa: F401
from flash.image.face_detection import FaceDetector, FaceDetectionData # noqa: F401
from flash.image.face_detection import FaceDetectionData, FaceDetector # noqa: F401
from flash.image.instance_segmentation import InstanceSegmentation, InstanceSegmentationData # noqa: F401
from flash.image.keypoint_detection import KeypointDetectionData, KeypointDetector # noqa: F401
from flash.image.segmentation import ( # noqa: F401
Expand Down
2 changes: 1 addition & 1 deletion flash/image/face_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from flash.image.face_detection.model import FaceDetector # noqa: F401
from flash.image.face_detection.data import FaceDetectionData # noqa: F401
from flash.image.face_detection.model import FaceDetector # noqa: F401
35 changes: 13 additions & 22 deletions flash/image/face_detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Mapping
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple

import torch
import torchvision
import torch.nn as nn

import torchvision
from torch.utils.data import Dataset

from flash.core.data.transforms import ApplyToKeys
from flash.core.data.data_source import DatasetDataSource, DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Preprocess, Postprocess
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE, _ICEVISION_AVAILABLE, _FASTFACE_AVAILABLE
from flash.core.data.process import Postprocess, Preprocess
from flash.core.data.transforms import ApplyToKeys
from flash.core.utilities.imports import _FASTFACE_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.image.data import ImagePathsDataSource
from flash.core.integrations.icevision.data import IceVisionParserDataSource
from flash.core.integrations.icevision.transforms import default_transforms
from flash.image.detection import ObjectDetectionData

if _TORCHVISION_AVAILABLE:
from torchvision.datasets.folder import default_loader

if _ICEVISION_AVAILABLE:
from icevision.parsers import COCOBBoxParser
else:
COCOBBoxParser = object

if _FASTFACE_AVAILABLE:
import fastface as ff

Expand Down Expand Up @@ -77,7 +69,7 @@ def load_data(self, data: Dataset, dataset: Any = None) -> Dataset:
dict(
boxes=targets["target_boxes"],
labels=[1 for _ in range(targets["target_boxes"].shape[0])],
)
),
)
)
)
Expand All @@ -99,7 +91,6 @@ def load_sample(self, sample: Any, dataset: Optional[Any] = None) -> Mapping[str


class FaceDetectionPreprocess(Preprocess):

def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
Expand Down Expand Up @@ -137,9 +128,9 @@ def default_transforms(self) -> Dict[str, Callable]:
ApplyToKeys(
DefaultDataKeys.TARGET,
nn.Sequential(
ApplyToKeys('boxes', torch.as_tensor),
ApplyToKeys('labels', torch.as_tensor),
)
ApplyToKeys("boxes", torch.as_tensor),
ApplyToKeys("labels", torch.as_tensor),
),
),
),
"collate": fastface_collate_fn,
Expand All @@ -149,11 +140,11 @@ def default_transforms(self) -> Dict[str, Callable]:
class FaceDetectionPostProcess(Postprocess):
@staticmethod
def per_batch_transform(batch: Any) -> Any:
scales = batch['scales']
paddings = batch['paddings']
scales = batch["scales"]
paddings = batch["paddings"]

batch.pop('scales', None)
batch.pop('paddings', None)
batch.pop("scales", None)
batch.pop("paddings", None)

preds = batch[DefaultDataKeys.PREDS]

Expand Down
16 changes: 8 additions & 8 deletions flash/image/face_detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union, Dict
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

import torch
import pytorch_lightning as pl

import torch
from torch import nn
from torch.optim import Optimizer

from flash.core.model import Task
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.process import Postprocess
from flash.core.data.process import Preprocess, Serializer
from flash.core.utilities.imports import _FASTFACE_AVAILABLE
from flash.core.finetuning import FlashBaseFinetuning
from flash.image.face_detection.data import FaceDetectionPreprocess, FaceDetectionPostProcess
from flash.core.model import Task
from flash.core.utilities.imports import _FASTFACE_AVAILABLE
from flash.image.face_detection.data import FaceDetectionPreprocess

if _FASTFACE_AVAILABLE:
import fastface as ff
Expand All @@ -48,7 +46,9 @@ def serialize(self, sample: Any) -> Dict[str, Any]:


class FaceDetector(Task):
"""The ``FaceDetector`` is a :class:`~flash.Task` for detecting faces in images. For more details, see
"""The ``FaceDetector`` is a :class:`~flash.Task` for detecting faces in images.
For more details, see
:ref:`face_detection`.
Args:
model: a string of :attr`_models`. Defaults to 'lffd_slim'.
Expand Down
23 changes: 9 additions & 14 deletions flash_examples/face_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import flash

from flash.core.data.utils import download_data
from flash.core.data.data_module import DataModule
from flash.core.utilities.imports import _FASTFACE_AVAILABLE
from flash.image import FaceDetector
from flash.image.face_detection.data import FaceDetectionPreprocess, FaceDetectionPostProcess
from flash.image import FaceDetectionData
from flash.image import FaceDetectionData, FaceDetector

if _FASTFACE_AVAILABLE:
import fastface as ff
Expand All @@ -29,9 +24,7 @@
train_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="train")
val_dataset = ff.dataset.FDDBDataset(source_dir="data/", phase="val")

datamodule = FaceDetectionData.from_datasets(
train_dataset=train_dataset, val_dataset=val_dataset, batch_size=2
)
datamodule = FaceDetectionData.from_datasets(train_dataset=train_dataset, val_dataset=val_dataset, batch_size=2)

# # 2. Build the task
model = FaceDetector(model="lffd_slim")
Expand All @@ -41,11 +34,13 @@
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Detect faces in a few images!
predictions = model.predict([
"data/2002/07/19/big/img_18.jpg",
"data/2002/07/19/big/img_65.jpg",
"data/2002/07/19/big/img_255.jpg",
])
predictions = model.predict(
[
"data/2002/07/19/big/img_18.jpg",
"data/2002/07/19/big/img_65.jpg",
"data/2002/07/19/big/img_255.jpg",
]
)
print(predictions)

# # 5. Save the model!
Expand Down

0 comments on commit 13f0c27

Please sign in to comment.