\w+?)_multi_layer_encoder$")
- STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones")
-
for mle_fn in dir(enc):
match = MLE_FN_PATTERN.match(mle_fn)
if not match:
@@ -37,5 +36,5 @@
fn=lambda: (getattr(enc, mle_fn)(), None),
name=match.group("name"),
namespace="image/style_transfer",
- package="pystiche",
+ providers=_PYSTICHE,
)
diff --git a/flash/image/style_transfer/cli.py b/flash/image/style_transfer/cli.py
new file mode 100644
index 0000000000..0fec347021
--- /dev/null
+++ b/flash/image/style_transfer/cli.py
@@ -0,0 +1,57 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Optional
+
+import flash
+from flash.core.data.utils import download_data
+from flash.core.utilities.flash_cli import FlashCLI
+from flash.image import StyleTransfer, StyleTransferData
+
+__all__ = ["style_transfer"]
+
+
+def from_coco_128(
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ **preprocess_kwargs,
+) -> StyleTransferData:
+ """Downloads and loads the COCO 128 data set."""
+ download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")
+ return StyleTransferData.from_folders(
+ train_folder="data/coco128/images/train2017/",
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **preprocess_kwargs,
+ )
+
+
+def style_transfer():
+ """Image style transfer."""
+ cli = FlashCLI(
+ StyleTransfer,
+ StyleTransferData,
+ default_datamodule_builder=from_coco_128,
+ default_arguments={
+ "trainer.max_epochs": 3,
+ "model.style_image": os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"),
+ },
+ finetune=False,
+ )
+
+ cli.trainer.save_checkpoint("style_transfer_model.pt")
+
+
+if __name__ == "__main__":
+ style_transfer()
diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py
index 75ab6f9e7a..f9f63c5905 100644
--- a/flash/image/style_transfer/data.py
+++ b/flash/image/style_transfer/data.py
@@ -17,6 +17,7 @@
from torch import nn
+from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.data.transforms import ApplyToKeys
@@ -31,9 +32,9 @@
__all__ = ["StyleTransferPreprocess", "StyleTransferData"]
-def _apply_to_input(default_transforms_fn, keys: Union[Sequence[DefaultDataKeys],
- DefaultDataKeys]) -> Callable[..., Dict[str, ApplyToKeys]]:
-
+def _apply_to_input(
+ default_transforms_fn, keys: Union[Sequence[DefaultDataKeys], DefaultDataKeys]
+) -> Callable[..., Dict[str, ApplyToKeys]]:
@functools.wraps(default_transforms_fn)
def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]:
default_transforms = default_transforms_fn(*args, **kwargs)
@@ -46,7 +47,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Optional[Dict[str, ApplyToKeys]]:
class StyleTransferPreprocess(Preprocess):
-
def __init__(
self,
train_transform: Optional[Union[Dict[str, Callable]]] = None,
@@ -118,12 +118,12 @@ def from_folders(
predict_transform: Optional[Union[str, Dict]] = None,
preprocess: Optional[Preprocess] = None,
**kwargs: Any,
- ) -> "StyleTransferData":
+ ) -> "DataModule":
- if any(param in kwargs for param in ("val_folder", "val_transform")):
+ if any(param in kwargs and kwargs[param] is not None for param in ("val_folder", "val_transform")):
raise_not_supported("validation")
- if any(param in kwargs for param in ("test_folder", "test_transform")):
+ if any(param in kwargs and kwargs[param] is not None for param in ("test_folder", "test_transform")):
raise_not_supported("test")
preprocess = preprocess or cls.preprocess_cls(
diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py
index 1573a10612..86a6b723e5 100644
--- a/flash/image/style_transfer/model.py
+++ b/flash/image/style_transfer/model.py
@@ -26,7 +26,7 @@
if _IMAGE_AVAILABLE:
import pystiche.demo
- from pystiche import enc, loss, ops
+ from pystiche import enc, loss
from pystiche.image import read_image
else:
@@ -34,12 +34,9 @@ class enc:
Encoder = None
MultiLayerEncoder = None
- class ops:
- EncodingComparisonOperator = None
- FeatureReconstructionOperator = None
- MultiLayerEncodingOperator = None
-
class loss:
+ class GramLoss:
+ pass
class PerceptualLoss:
pass
@@ -80,7 +77,7 @@ def __init__(
backbone: str = "vgg16",
content_layer: str = "relu2_2",
content_weight: float = 1e5,
- style_layers: Union[Sequence[str], str] = ("relu1_2", "relu2_2", "relu3_3", "relu4_3"),
+ style_layers: Union[Sequence[str], str] = ["relu1_2", "relu2_2", "relu3_3", "relu4_3"],
style_weight: float = 1e10,
optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
@@ -100,7 +97,7 @@ def __init__(
model = pystiche.demo.transformer()
if not isinstance(style_layers, (List, Tuple)):
- style_layers = (style_layers, )
+ style_layers = (style_layers,)
perceptual_loss = self._get_perceptual_loss(
backbone=backbone,
@@ -129,12 +126,11 @@ def default_style_image() -> torch.Tensor:
return pystiche.demo.images()["paint"].read(size=256)
@staticmethod
- def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> ops.EncodingComparisonOperator:
+ def _modified_gram_loss(encoder: enc.Encoder, *, score_weight: float) -> loss.GramLoss:
# The official PyTorch examples as well as the reference implementation of the original author contain an
# oversight: they normalize the representation twice by the number of channels. To be compatible with them, we
# do the same here.
- class GramOperator(ops.GramOperator):
-
+ class GramOperator(loss.GramLoss):
def enc_to_repr(self, enc: torch.Tensor) -> torch.Tensor:
repr = super().enc_to_repr(enc)
num_channels = repr.size()[1]
@@ -152,10 +148,8 @@ def _get_perceptual_loss(
style_weight: float,
) -> loss.PerceptualLoss:
mle, _ = cast(enc.MultiLayerEncoder, self.backbones.get(backbone)())
- content_loss = ops.FeatureReconstructionOperator(
- mle.extract_encoder(content_layer), score_weight=content_weight
- )
- style_loss = ops.MultiLayerEncodingOperator(
+ content_loss = loss.FeatureReconstructionLoss(mle.extract_encoder(content_layer), score_weight=content_weight)
+ style_loss = loss.MultiLayerEncodingLoss(
mle,
style_layers,
lambda encoder, layer_weight: self._modified_gram_loss(encoder, score_weight=layer_weight),
diff --git a/flash/pointcloud/__init__.py b/flash/pointcloud/__init__.py
new file mode 100644
index 0000000000..766f2f2e89
--- /dev/null
+++ b/flash/pointcloud/__init__.py
@@ -0,0 +1,2 @@
+from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData # noqa: F401
+from flash.pointcloud.segmentation import PointCloudSegmentation, PointCloudSegmentationData # noqa: F401
diff --git a/flash/pointcloud/detection/__init__.py b/flash/pointcloud/detection/__init__.py
new file mode 100644
index 0000000000..cfe4c690f0
--- /dev/null
+++ b/flash/pointcloud/detection/__init__.py
@@ -0,0 +1,3 @@
+from flash.pointcloud.detection.data import PointCloudObjectDetectorData # noqa: F401
+from flash.pointcloud.detection.model import PointCloudObjectDetector # noqa: F401
+from flash.pointcloud.detection.open3d_ml.app import launch_app # noqa: F401
diff --git a/flash/pointcloud/detection/backbones.py b/flash/pointcloud/detection/backbones.py
new file mode 100644
index 0000000000..88268dd036
--- /dev/null
+++ b/flash/pointcloud/detection/backbones.py
@@ -0,0 +1,19 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from flash.core.registry import FlashRegistry
+from flash.pointcloud.detection.open3d_ml.backbones import register_open_3d_ml
+
+POINTCLOUD_OBJECT_DETECTION_BACKBONES = FlashRegistry("backbones")
+
+register_open_3d_ml(POINTCLOUD_OBJECT_DETECTION_BACKBONES)
diff --git a/flash/pointcloud/detection/cli.py b/flash/pointcloud/detection/cli.py
new file mode 100644
index 0000000000..01a4c329ce
--- /dev/null
+++ b/flash/pointcloud/detection/cli.py
@@ -0,0 +1,55 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from flash.core.data.utils import download_data
+from flash.core.utilities.flash_cli import FlashCLI
+from flash.pointcloud import PointCloudObjectDetector, PointCloudObjectDetectorData
+
+__all__ = ["pointcloud_detection"]
+
+
+def from_kitti(
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ **preprocess_kwargs,
+) -> PointCloudObjectDetectorData:
+ """Downloads and loads the KITTI data set."""
+ download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/")
+ return PointCloudObjectDetectorData.from_folders(
+ train_folder="data/KITTI_Tiny/Kitti/train",
+ val_folder="data/KITTI_Tiny/Kitti/val",
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **preprocess_kwargs,
+ )
+
+
+def pointcloud_detection():
+ """Detect objects in point clouds."""
+ cli = FlashCLI(
+ PointCloudObjectDetector,
+ PointCloudObjectDetectorData,
+ default_datamodule_builder=from_kitti,
+ default_arguments={
+ "trainer.max_epochs": 3,
+ },
+ finetune=False,
+ )
+
+ cli.trainer.save_checkpoint("pointcloud_detection_model.pt")
+
+
+if __name__ == "__main__":
+ pointcloud_detection()
diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py
new file mode 100644
index 0000000000..40349b8653
--- /dev/null
+++ b/flash/pointcloud/detection/data.py
@@ -0,0 +1,176 @@
+from typing import Any, Callable, Dict, Optional, Type
+
+from torch.utils.data import Sampler
+
+from flash.core.data.base_viz import BaseDataFetcher
+from flash.core.data.data_module import DataModule
+from flash.core.data.data_source import BaseDataFormat, DataSource, DefaultDataKeys, DefaultDataSources
+from flash.core.data.process import Deserializer, Preprocess
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, requires_extras
+
+if _POINTCLOUD_AVAILABLE:
+ from flash.pointcloud.detection.open3d_ml.data_sources import (
+ PointCloudObjectDetectionDataFormat,
+ PointCloudObjectDetectorFoldersDataSource,
+ )
+else:
+ PointCloudObjectDetectorFoldersDataSource = object
+
+ class PointCloudObjectDetectionDataFormat:
+ KITTI = None
+
+
+class PointCloudObjectDetectorDatasetDataSource(DataSource):
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ def load_data(
+ self,
+ data: Any,
+ dataset: Optional[Any] = None,
+ ) -> Any:
+
+ dataset.dataset = data
+
+ return range(len(data))
+
+ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any:
+ sample = dataset.dataset[index]
+
+ return {
+ DefaultDataKeys.INPUT: sample["data"],
+ DefaultDataKeys.METADATA: sample["attr"],
+ }
+
+
+class PointCloudObjectDetectorPreprocess(Preprocess):
+ @requires_extras("pointcloud")
+ def __init__(
+ self,
+ train_transform: Optional[Dict[str, Callable]] = None,
+ val_transform: Optional[Dict[str, Callable]] = None,
+ test_transform: Optional[Dict[str, Callable]] = None,
+ predict_transform: Optional[Dict[str, Callable]] = None,
+ deserializer: Optional[Deserializer] = None,
+ **data_source_kwargs,
+ ):
+
+ super().__init__(
+ train_transform=train_transform,
+ val_transform=val_transform,
+ test_transform=test_transform,
+ predict_transform=predict_transform,
+ data_sources={
+ DefaultDataSources.DATASETS: PointCloudObjectDetectorDatasetDataSource(**data_source_kwargs),
+ DefaultDataSources.FOLDERS: PointCloudObjectDetectorFoldersDataSource(**data_source_kwargs),
+ },
+ deserializer=deserializer,
+ default_data_source=DefaultDataSources.FOLDERS,
+ )
+
+ def get_state_dict(self):
+ return {}
+
+ def state_dict(self):
+ return {}
+
+ @classmethod
+ def load_state_dict(cls, state_dict, strict: bool = False):
+ pass
+
+
+class PointCloudObjectDetectorData(DataModule):
+
+ preprocess_cls = PointCloudObjectDetectorPreprocess
+
+ @classmethod
+ def from_folders(
+ cls,
+ train_folder: Optional[str] = None,
+ val_folder: Optional[str] = None,
+ test_folder: Optional[str] = None,
+ predict_folder: Optional[str] = None,
+ train_transform: Optional[Dict[str, Callable]] = None,
+ val_transform: Optional[Dict[str, Callable]] = None,
+ test_transform: Optional[Dict[str, Callable]] = None,
+ predict_transform: Optional[Dict[str, Callable]] = None,
+ data_fetcher: Optional[BaseDataFetcher] = None,
+ preprocess: Optional[Preprocess] = None,
+ val_split: Optional[float] = None,
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ sampler: Optional[Type[Sampler]] = None,
+ scans_folder_name: Optional[str] = "scans",
+ labels_folder_name: Optional[str] = "labels",
+ calibrations_folder_name: Optional[str] = "calibs",
+ data_format: Optional[BaseDataFormat] = PointCloudObjectDetectionDataFormat.KITTI,
+ **preprocess_kwargs: Any,
+ ) -> "DataModule":
+ """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the
+ :class:`~flash.core.data.data_source.DataSource` of name
+ :attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS`
+ from the passed or constructed :class:`~flash.core.data.process.Preprocess`.
+
+ Args:
+ train_folder: The folder containing the train data.
+ val_folder: The folder containing the validation data.
+ test_folder: The folder containing the test data.
+ predict_folder: The folder containing the predict data.
+ train_transform: The dictionary of transforms to use during training which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ val_transform: The dictionary of transforms to use during validation which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ test_transform: The dictionary of transforms to use during testing which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ predict_transform: The dictionary of transforms to use during predicting which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+ :class:`~flash.core.data.data_module.DataModule`.
+ preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
+ :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+ will be constructed and used.
+ val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ sampler: The ``sampler`` to use for the ``train_dataloader``.
+ preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
+ if ``preprocess = None``.
+ scans_folder_name: The name of the pointcloud scan folder
+ labels_folder_name: The name of the pointcloud scan labels folder
+ calibrations_folder_name: The name of the pointcloud scan calibration folder
+ data_format: Format in which the data are stored.
+
+ Returns:
+ The constructed data module.
+
+ Examples::
+
+ data_module = DataModule.from_folders(
+ train_folder="train_folder",
+ train_transform={
+ "to_tensor_transform": torch.as_tensor,
+ },
+ )
+ """
+ return cls.from_data_source(
+ DefaultDataSources.FOLDERS,
+ train_folder,
+ val_folder,
+ test_folder,
+ predict_folder,
+ train_transform=train_transform,
+ val_transform=val_transform,
+ test_transform=test_transform,
+ predict_transform=predict_transform,
+ data_fetcher=data_fetcher,
+ preprocess=preprocess,
+ val_split=val_split,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ sampler=sampler,
+ scans_folder_name=scans_folder_name,
+ labels_folder_name=labels_folder_name,
+ calibrations_folder_name=calibrations_folder_name,
+ data_format=data_format,
+ **preprocess_kwargs,
+ )
diff --git a/flash/pointcloud/detection/datasets.py b/flash/pointcloud/detection/datasets.py
new file mode 100644
index 0000000000..335f699757
--- /dev/null
+++ b/flash/pointcloud/detection/datasets.py
@@ -0,0 +1,41 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+from flash.pointcloud.segmentation.datasets import executor
+
+if _POINTCLOUD_AVAILABLE:
+ from open3d.ml.datasets import KITTI
+
+_OBJECT_DETECTION_DATASET = FlashRegistry("dataset")
+
+
+@_OBJECT_DETECTION_DATASET
+def kitti(dataset_path, download, **kwargs):
+ name = "KITTI"
+ download_path = os.path.join(dataset_path, name, "Kitti")
+ if not os.path.exists(download_path):
+ executor(
+ "https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_kitti.sh", # noqa E501
+ None,
+ dataset_path,
+ name,
+ )
+ return KITTI(download_path, **kwargs)
+
+
+def KITTIDataset(dataset_path, download: bool = True, **kwargs):
+ return _OBJECT_DETECTION_DATASET.get("kitti")(dataset_path, download, **kwargs)
diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py
new file mode 100644
index 0000000000..155126d785
--- /dev/null
+++ b/flash/pointcloud/detection/model.py
@@ -0,0 +1,183 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union
+
+import torch
+import torchmetrics
+from torch import nn
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader, Sampler
+
+from flash.core.data.auto_dataset import BaseAutoDataset
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.process import Serializer
+from flash.core.data.states import CollateFn
+from flash.core.model import Task
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.apply_func import get_callable_dict
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+from flash.pointcloud.detection.backbones import POINTCLOUD_OBJECT_DETECTION_BACKBONES
+
+__FILE_EXAMPLE__ = "pointcloud_detection"
+
+
+class PointCloudObjectDetectorSerializer(Serializer):
+ pass
+
+
+class PointCloudObjectDetector(Task):
+ """The ``PointCloudObjectDetector`` is a :class:`~flash.core.classification.ClassificationTask` that classifies
+ pointcloud data.
+
+ Args:
+ num_features: The number of features (elements) in the input data.
+ num_classes: The number of classes (outputs) for this :class:`~flash.core.model.Task`.
+ backbone: The backbone name (or a tuple of ``nn.Module``, output size) to use.
+ backbone_kwargs: Any additional kwargs to pass to the backbone constructor.
+ loss_fn: The loss function to use. If ``None``, a default will be selected by the
+ :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
+ optimizer: The optimizer or optimizer class to use.
+ optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+ scheduler: The scheduler or scheduler class to use.
+ scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
+ metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected
+ by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
+ learning_rate: The learning rate for the optimizer.
+ multi_label: If ``True``, this will be treated as a multi-label classification problem.
+ serializer: The :class:`~flash.core.data.process.Serializer` to use for prediction outputs.
+ lambda_loss_cls: The value to scale the loss classification.
+ lambda_loss_bbox: The value to scale the bounding boxes loss.
+ lambda_loss_dir: The value to scale the bounding boxes direction loss.
+ """
+
+ backbones: FlashRegistry = POINTCLOUD_OBJECT_DETECTION_BACKBONES
+ required_extras: str = "pointcloud"
+
+ def __init__(
+ self,
+ num_classes: int,
+ backbone: Union[str, Tuple[nn.Module, int]] = "pointpillars_kitti",
+ backbone_kwargs: Optional[Dict] = None,
+ head: Optional[nn.Module] = None,
+ loss_fn: Optional[Callable] = None,
+ optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+ scheduler_kwargs: Optional[Dict[str, Any]] = None,
+ metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
+ learning_rate: float = 1e-2,
+ serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudObjectDetectorSerializer(),
+ lambda_loss_cls: float = 1.0,
+ lambda_loss_bbox: float = 1.0,
+ lambda_loss_dir: float = 1.0,
+ ):
+
+ super().__init__(
+ model=None,
+ loss_fn=loss_fn,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
+ scheduler=scheduler,
+ scheduler_kwargs=scheduler_kwargs,
+ metrics=metrics,
+ learning_rate=learning_rate,
+ serializer=serializer,
+ )
+
+ self.save_hyperparameters()
+
+ if backbone_kwargs is None:
+ backbone_kwargs = {}
+
+ if isinstance(backbone, tuple):
+ self.backbone, out_features = backbone
+ else:
+ self.model, out_features, collate_fn = self.backbones.get(backbone)(**backbone_kwargs)
+ self.backbone = self.model.backbone
+ self.neck = self.model.neck
+ self.set_state(CollateFn(collate_fn))
+ self.set_state(CollateFn(collate_fn))
+ self.set_state(CollateFn(collate_fn))
+ self.loss_fn = get_callable_dict(self.model.loss)
+
+ if __FILE_EXAMPLE__ not in sys.argv[0]:
+ self.model.bbox_head.conv_cls = self.head = nn.Conv2d(
+ out_features, num_classes, kernel_size=(1, 1), stride=(1, 1)
+ )
+
+ def compute_loss(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+ losses = losses["loss"]
+ return (
+ self.hparams.lambda_loss_cls * losses["loss_cls"]
+ + self.hparams.lambda_loss_bbox * losses["loss_bbox"]
+ + self.hparams.lambda_loss_dir * losses["loss_dir"]
+ )
+
+ def compute_logs(self, logs: Dict[str, Any], losses: Dict[str, torch.Tensor]):
+ logs.update({"loss": self.compute_loss(losses)})
+ return logs
+
+ def training_step(self, batch: Any, batch_idx: int) -> Any:
+ return super().training_step((batch, batch), batch_idx)
+
+ def validation_step(self, batch: Any, batch_idx: int) -> Any:
+ super().validation_step((batch, batch), batch_idx)
+
+ def test_step(self, batch: Any, batch_idx: int) -> Any:
+ super().validation_step((batch, batch), batch_idx)
+
+ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+ results = self.model(batch)
+ boxes = self.model.inference_end(results, batch)
+ return {
+ DefaultDataKeys.INPUT: getattr(batch, "point", None),
+ DefaultDataKeys.PREDS: boxes,
+ DefaultDataKeys.METADATA: [a["name"] for a in batch.attr],
+ }
+
+ def forward(self, x) -> torch.Tensor:
+ """First call the backbone, then the model head."""
+ # hack to enable backbone to work properly.
+ self.model.device = self.device
+ return self.model(x)
+
+ def _process_dataset(
+ self,
+ dataset: BaseAutoDataset,
+ batch_size: int,
+ num_workers: int,
+ pin_memory: bool,
+ collate_fn: Callable,
+ shuffle: bool = False,
+ drop_last: bool = True,
+ sampler: Optional[Sampler] = None,
+ ) -> DataLoader:
+
+ if not _POINTCLOUD_AVAILABLE:
+ raise ModuleNotFoundError("Please, run `pip install flash[pointcloud]`.")
+
+ dataset.preprocess_fn = self.model.preprocess
+ dataset.transform_fn = self.model.transform
+
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ collate_fn=collate_fn,
+ shuffle=shuffle,
+ drop_last=drop_last,
+ sampler=sampler,
+ )
diff --git a/flash/pointcloud/detection/open3d_ml/app.py b/flash/pointcloud/detection/open3d_ml/app.py
new file mode 100644
index 0000000000..bddcfe7e41
--- /dev/null
+++ b/flash/pointcloud/detection/open3d_ml/app.py
@@ -0,0 +1,169 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import torch
+from torch.utils.data.dataset import Dataset
+
+import flash
+from flash.core.data.data_module import DataModule
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+
+ from open3d._ml3d.vis.visualizer import LabelLUT, Visualizer
+ from open3d.visualization import gui
+
+ class Visualizer(Visualizer):
+ def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768):
+ """Visualize a dataset.
+
+ Example:
+ Minimal example for visualizing a dataset::
+ import open3d.ml.torch as ml3d # or open3d.ml.tf as ml3d
+
+ dataset = ml3d.datasets.SemanticKITTI(dataset_path='/path/to/SemanticKITTI/')
+ vis = ml3d.vis.Visualizer()
+ vis.visualize_dataset(dataset, 'all', indices=range(100))
+
+ Args:
+ dataset: The dataset to use for visualization.
+ split: The dataset split to be used, such as 'training'
+ indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4].
+ width: The width of the visualization window.
+ height: The height of the visualization window.
+ """
+ # Setup the labels
+ lut = LabelLUT()
+ for id, color in dataset.color_map.items():
+ lut.add_label(id, id, color=color)
+ self.set_lut("label", lut)
+
+ self._consolidate_bounding_boxes = True
+ self._init_dataset(dataset, split, indices)
+
+ self._visualize("Open3D - " + dataset.name, width, height)
+
+ def _visualize(self, title, width, height):
+ gui.Application.instance.initialize()
+ self._init_user_interface(title, width, height)
+
+ # override just to set background color to back :)
+ bgcolor = gui.ColorEdit()
+ bgcolor.color_value = gui.Color(0, 0, 0)
+ self._on_bgcolor_changed(bgcolor.color_value)
+
+ self._3d.scene.downsample_threshold = 400000
+
+ # Turn all the objects off except the first one
+ for name, node in self._name2treenode.items():
+ node.checkbox.checked = False
+ self._3d.scene.show_geometry(name, False)
+ for name in [self._objects.data_names[0]]:
+ self._name2treenode[name].checkbox.checked = True
+ self._3d.scene.show_geometry(name, True)
+
+ def on_done_ui():
+ # Add bounding boxes here: bounding boxes belonging to the dataset
+ # will not be loaded until now.
+ self._update_bounding_boxes()
+
+ self._update_datasource_combobox()
+ self._update_shaders_combobox()
+
+ # Display "colors" by default if available, "points" if not
+ available_attrs = self._get_available_attrs()
+ self._set_shader(self.SOLID_NAME, force_update=True)
+ if "colors" in available_attrs:
+ self._datasource_combobox.selected_text = "colors"
+ elif "points" in available_attrs:
+ self._datasource_combobox.selected_text = "points"
+
+ self._dont_update_geometry = True
+ self._on_datasource_changed(
+ self._datasource_combobox.selected_text, self._datasource_combobox.selected_index
+ )
+ self._update_geometry_colors()
+ self._dont_update_geometry = False
+ # _datasource_combobox was empty, now isn't, re-layout.
+ self.window.set_needs_layout()
+
+ self._update_geometry()
+ self.setup_camera()
+
+ self._load_geometries(self._objects.data_names, on_done_ui)
+ gui.Application.instance.run()
+
+ class VizDataset(Dataset):
+
+ name = "VizDataset"
+
+ def __init__(self, dataset):
+ self.dataset = dataset
+ self.label_to_names = getattr(dataset, "label_to_names", {})
+ self.path_list = getattr(dataset, "path_list", [])
+ self.color_map = getattr(dataset, "color_map", {})
+
+ def get_data(self, index):
+ data = self.dataset[index]["data"]
+ data["bounding_boxes"] = data["bbox_objs"]
+ data["color"] = np.ones_like(data["point"])
+ return data
+
+ def get_attr(self, index):
+ return self.dataset[index]["attr"]
+
+ def get_split(self, *_) -> "VizDataset":
+ return self
+
+ def __len__(self) -> int:
+ return len(self.dataset)
+
+ class App:
+ def __init__(self, datamodule: DataModule):
+ self.datamodule = datamodule
+ self._enabled = not flash._IS_TESTING
+
+ def get_dataset(self, stage: str = "train"):
+ dataloader = getattr(self.datamodule, f"{stage}_dataloader")()
+ return VizDataset(dataloader.dataset)
+
+ def show_train_dataset(self, indices=None):
+ if self._enabled:
+ dataset = self.get_dataset("train")
+ viz = Visualizer()
+ viz.visualize_dataset(dataset, "all", indices=indices)
+
+ def show_predictions(self, predictions):
+ if self._enabled:
+ dataset = self.get_dataset("train")
+
+ viz = Visualizer()
+ lut = LabelLUT()
+ for id, color in dataset.color_map.items():
+ lut.add_label(id, id, color=color)
+ viz.set_lut("label", lut)
+
+ for pred in predictions:
+ data = {
+ "points": torch.stack(pred[DefaultDataKeys.INPUT])[:, :3],
+ "name": pred[DefaultDataKeys.METADATA],
+ }
+ bounding_box = pred[DefaultDataKeys.PREDS]
+
+ viz.visualize([data], bounding_boxes=bounding_box)
+
+
+def launch_app(datamodule: DataModule) -> "App":
+ return App(datamodule)
diff --git a/flash/pointcloud/detection/open3d_ml/backbones.py b/flash/pointcloud/detection/open3d_ml/backbones.py
new file mode 100644
index 0000000000..759b6bdb43
--- /dev/null
+++ b/flash/pointcloud/detection/open3d_ml/backbones.py
@@ -0,0 +1,83 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from abc import ABC
+from typing import Callable
+
+import torch
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+from flash.core.utilities.providers import _OPEN3D_ML
+
+ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/"
+
+if _POINTCLOUD_AVAILABLE:
+ import open3d
+ import open3d.ml as _ml3d
+ from open3d._ml3d.torch.dataloaders.concat_batcher import ConcatBatcher, ObjectDetectBatch
+ from open3d._ml3d.torch.models.point_pillars import PointPillars
+ from open3d.ml.torch.dataloaders import DefaultBatcher
+else:
+ ObjectDetectBatch = ABC
+ PointPillars = ABC
+
+
+class ObjectDetectBatchCollator(ObjectDetectBatch):
+ def __init__(self, batches):
+ self.num_batches = len(batches)
+ super().__init__(batches)
+
+ def to(self, device):
+ super().to(device)
+ return self
+
+ def __len__(self):
+ return self.num_batches
+
+
+def register_open_3d_ml(register: FlashRegistry):
+
+ if _POINTCLOUD_AVAILABLE:
+
+ CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs")
+
+ def get_collate_fn(model) -> Callable:
+ batcher_name = model.cfg.batcher
+ if batcher_name == "DefaultBatcher":
+ batcher = DefaultBatcher()
+ elif batcher_name == "ConcatBatcher":
+ batcher = ConcatBatcher(torch, model.__class__.__name__)
+ elif batcher_name == "ObjectDetectBatchCollator":
+ return ObjectDetectBatchCollator
+ return batcher.collate_fn
+
+ @register(parameters=PointPillars.__init__, providers=_OPEN3D_ML)
+ def pointpillars_kitti(*args, **kwargs) -> PointPillars:
+ cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "pointpillars_kitti.yml"))
+ cfg.model.device = "cpu"
+ model = PointPillars(**cfg.model)
+ weight_url = os.path.join(ROOT_URL, "pointpillars_kitti_202012221652utc.pth")
+ model.load_state_dict(
+ pl_load(weight_url, map_location="cpu")["model_state_dict"],
+ )
+ model.cfg.batcher = "ObjectDetectBatchCollator"
+ return model, 384, get_collate_fn(model)
+
+ @register(parameters=PointPillars.__init__, providers=_OPEN3D_ML)
+ def pointpillars(*args, **kwargs) -> PointPillars:
+ model = PointPillars(*args, **kwargs)
+ model.cfg.batcher = "ObjectDetectBatch"
+ return model, get_collate_fn(model)
diff --git a/flash/pointcloud/detection/open3d_ml/data_sources.py b/flash/pointcloud/detection/open3d_ml/data_sources.py
new file mode 100644
index 0000000000..0c4872c3b3
--- /dev/null
+++ b/flash/pointcloud/detection/open3d_ml/data_sources.py
@@ -0,0 +1,241 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from os.path import basename, dirname, exists, isdir, isfile, join
+from typing import Any, Dict, List, Optional, Union
+
+import yaml
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from flash.core.data.auto_dataset import BaseAutoDataset
+from flash.core.data.data_source import BaseDataFormat, DataSource
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+ from open3d._ml3d.datasets.kitti import DataProcessing, KITTI
+
+
+class PointCloudObjectDetectionDataFormat(BaseDataFormat):
+ KITTI = "kitti"
+
+
+class BasePointCloudObjectDetectorLoader:
+
+ pass
+
+
+class KITTIPointCloudObjectDetectorLoader(BasePointCloudObjectDetectorLoader):
+ def __init__(
+ self,
+ image_size: tuple = (375, 1242),
+ scans_folder_name: Optional[str] = "scans",
+ labels_folder_name: Optional[str] = "labels",
+ calibrations_folder_name: Optional[str] = "calibs",
+ **kwargs,
+ ):
+
+ self.image_size = image_size
+ self.scans_folder_name = scans_folder_name
+ self.labels_folder_name = labels_folder_name
+ self.calibrations_folder_name = calibrations_folder_name
+
+ def load_meta(self, root_dir, dataset: Optional[BaseAutoDataset]):
+ meta_file = join(root_dir, "meta.yaml")
+ if not exists(meta_file):
+ raise MisconfigurationException(f"The {root_dir} should contain a `meta.yaml` file about the classes.")
+
+ with open(meta_file) as f:
+ self.meta = yaml.safe_load(f)
+
+ if "label_to_names" not in self.meta:
+ raise MisconfigurationException(
+ f"The {root_dir} should contain a `meta.yaml` file about the classes with the field `label_to_names`."
+ )
+
+ dataset.num_classes = len(self.meta["label_to_names"])
+ dataset.label_to_names = self.meta["label_to_names"]
+ dataset.color_map = self.meta["color_map"]
+
+ def load_data(self, folder: str, dataset: Optional[BaseAutoDataset]):
+ sub_directories = os.listdir(folder)
+ if len(sub_directories) != 3:
+ raise MisconfigurationException(
+ f"Using KITTI Format, the {folder} should contains 3 directories "
+ "for ``calibrations``, ``labels`` and ``scans``."
+ )
+
+ assert self.scans_folder_name in sub_directories
+ assert self.labels_folder_name in sub_directories
+ assert self.calibrations_folder_name in sub_directories
+
+ scans_dir = join(folder, self.scans_folder_name)
+ labels_dir = join(folder, self.labels_folder_name)
+ calibrations_dir = join(folder, self.calibrations_folder_name)
+
+ scan_paths = [join(scans_dir, f) for f in os.listdir(scans_dir)]
+ label_paths = [join(labels_dir, f) for f in os.listdir(labels_dir)]
+ calibration_paths = [join(calibrations_dir, f) for f in os.listdir(calibrations_dir)]
+
+ assert len(scan_paths) == len(label_paths) == len(calibration_paths)
+
+ self.load_meta(dirname(folder), dataset)
+
+ dataset.path_list = scan_paths
+
+ return [
+ {"scan_path": scan_path, "label_path": label_path, "calibration_path": calibration_path}
+ for scan_path, label_path, calibration_path, in zip(scan_paths, label_paths, calibration_paths)
+ ]
+
+ def load_sample(
+ self, sample: Dict[str, str], dataset: Optional[BaseAutoDataset] = None, has_label: bool = True
+ ) -> Any:
+ pc = KITTI.read_lidar(sample["scan_path"])
+ calib = KITTI.read_calib(sample["calibration_path"])
+ label = None
+ if has_label:
+ label = KITTI.read_label(sample["label_path"], calib)
+
+ reduced_pc = DataProcessing.remove_outside_points(pc, calib["world_cam"], calib["cam_img"], self.image_size)
+
+ attr = {
+ "name": basename(sample["scan_path"]),
+ "path": sample["scan_path"],
+ "calibration_path": sample["calibration_path"],
+ "label_path": sample["label_path"] if has_label else None,
+ "split": "val",
+ }
+
+ data = {
+ "point": reduced_pc,
+ "full_point": pc,
+ "feat": None,
+ "calib": calib,
+ "bounding_boxes": label if has_label else None,
+ "attr": attr,
+ }
+ return data, attr
+
+ def load_files(self, scan_paths: Union[str, List[str]], dataset: Optional[BaseAutoDataset] = None):
+ if isinstance(scan_paths, str):
+ scan_paths = [scan_paths]
+
+ def clean_fn(path: str) -> str:
+ return path.replace(self.scans_folder_name, self.calibrations_folder_name).replace(".bin", ".txt")
+
+ dataset.path_list = scan_paths
+
+ return [{"scan_path": scan_path, "calibration_path": clean_fn(scan_path)} for scan_path in scan_paths]
+
+ def predict_load_data(self, data, dataset: Optional[BaseAutoDataset] = None):
+ if (isinstance(data, str) and isfile(data)) or (isinstance(data, list) and all(isfile(p) for p in data)):
+ return self.load_files(data, dataset)
+ elif isinstance(data, str) and isdir(data):
+ raise NotImplementedError
+
+ def predict_load_sample(self, data, dataset: Optional[BaseAutoDataset] = None):
+ data, attr = self.load_sample(data, dataset, has_label=False)
+ # hack to prevent manipulation of labels
+ attr["split"] = "test"
+ return data, attr
+
+
+class PointCloudObjectDetectorFoldersDataSource(DataSource):
+ def __init__(
+ self,
+ data_format: Optional[BaseDataFormat] = None,
+ image_size: tuple = (375, 1242),
+ **loader_kwargs,
+ ):
+ super().__init__()
+
+ self.loaders = {
+ PointCloudObjectDetectionDataFormat.KITTI: KITTIPointCloudObjectDetectorLoader(
+ **loader_kwargs, image_size=image_size
+ )
+ }
+
+ self.data_format = data_format or PointCloudObjectDetectionDataFormat.KITTI
+ self.loader = self.loaders[self.data_format]
+
+ def _validate_data(self, folder: str) -> None:
+ msg = f"The provided dataset for stage {self._running_stage} should be a folder. Found {folder}."
+ if not isinstance(folder, str):
+ raise MisconfigurationException(msg)
+
+ if isinstance(folder, str) and not isdir(folder):
+ raise MisconfigurationException(msg)
+
+ def load_data(
+ self,
+ data: Any,
+ dataset: Optional[BaseAutoDataset] = None,
+ ) -> Any:
+
+ self._validate_data(data)
+
+ return self.loader.load_data(data, dataset)
+
+ def load_sample(self, metadata: Dict[str, str], dataset: Optional[BaseAutoDataset] = None) -> Any:
+
+ data, metadata = self.loader.load_sample(metadata, dataset)
+
+ preprocess_fn = getattr(dataset, "preprocess_fn", None)
+ if preprocess_fn:
+ data = preprocess_fn(data, metadata)
+
+ transform_fn = getattr(dataset, "transform_fn", None)
+ if transform_fn:
+ data = transform_fn(data, metadata)
+
+ return {"data": data, "attr": metadata}
+
+ def _validate_predict_data(self, data: Union[str, List[str]]) -> None:
+ msg = f"The provided predict data should be a either a folder or a single/list of scan path(s). Found {data}."
+ if not isinstance(data, str) and not isinstance(data, list):
+ raise MisconfigurationException(msg)
+
+ if isinstance(data, str) and (not isfile(data) or not isdir(data)):
+ raise MisconfigurationException(msg)
+
+ if isinstance(data, list) and not all(isfile(p) for p in data):
+ raise MisconfigurationException(msg)
+
+ def predict_load_data(
+ self,
+ data: Any,
+ dataset: Optional[BaseAutoDataset] = None,
+ ) -> Any:
+
+ self._validate_predict_data(data)
+
+ return self.loader.predict_load_data(data, dataset)
+
+ def predict_load_sample(
+ self,
+ metadata: Any,
+ dataset: Optional[BaseAutoDataset] = None,
+ ) -> Any:
+
+ data, metadata = self.loader.predict_load_sample(metadata, dataset)
+
+ preprocess_fn = getattr(dataset, "preprocess_fn", None)
+ if preprocess_fn:
+ data = preprocess_fn(data, metadata)
+
+ transform_fn = getattr(dataset, "transform_fn", None)
+ if transform_fn:
+ data = transform_fn(data, metadata)
+
+ return {"data": data, "attr": metadata}
diff --git a/flash/pointcloud/segmentation/__init__.py b/flash/pointcloud/segmentation/__init__.py
new file mode 100644
index 0000000000..5d10606f79
--- /dev/null
+++ b/flash/pointcloud/segmentation/__init__.py
@@ -0,0 +1,3 @@
+from flash.pointcloud.segmentation.data import PointCloudSegmentationData # noqa: F401
+from flash.pointcloud.segmentation.model import PointCloudSegmentation # noqa: F401
+from flash.pointcloud.segmentation.open3d_ml.app import launch_app # noqa: F401
diff --git a/flash/pointcloud/segmentation/backbones.py b/flash/pointcloud/segmentation/backbones.py
new file mode 100644
index 0000000000..023daa9ac0
--- /dev/null
+++ b/flash/pointcloud/segmentation/backbones.py
@@ -0,0 +1,19 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from flash.core.registry import FlashRegistry
+from flash.pointcloud.segmentation.open3d_ml.backbones import register_open_3d_ml
+
+POINTCLOUD_SEGMENTATION_BACKBONES = FlashRegistry("backbones")
+
+register_open_3d_ml(POINTCLOUD_SEGMENTATION_BACKBONES)
diff --git a/flash/pointcloud/segmentation/cli.py b/flash/pointcloud/segmentation/cli.py
new file mode 100644
index 0000000000..57d1125f9b
--- /dev/null
+++ b/flash/pointcloud/segmentation/cli.py
@@ -0,0 +1,56 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from flash.core.data.utils import download_data
+from flash.core.utilities.flash_cli import FlashCLI
+from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData
+
+__all__ = ["pointcloud_segmentation"]
+
+
+def from_kitti(
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ **preprocess_kwargs,
+) -> PointCloudSegmentationData:
+ """Downloads and loads the semantic KITTI data set."""
+ download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/")
+ return PointCloudSegmentationData.from_folders(
+ train_folder="data/SemanticKittiTiny/train",
+ val_folder="data/SemanticKittiTiny/val",
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **preprocess_kwargs,
+ )
+
+
+def pointcloud_segmentation():
+ """Segment objects in point clouds."""
+ cli = FlashCLI(
+ PointCloudSegmentation,
+ PointCloudSegmentationData,
+ default_datamodule_builder=from_kitti,
+ default_arguments={
+ "trainer.max_epochs": 3,
+ "model.backbone": "randlanet_semantic_kitti",
+ },
+ finetune=False,
+ )
+
+ cli.trainer.save_checkpoint("pointcloud_segmentation_model.pt")
+
+
+if __name__ == "__main__":
+ pointcloud_segmentation()
diff --git a/flash/pointcloud/segmentation/data.py b/flash/pointcloud/segmentation/data.py
new file mode 100644
index 0000000000..92cd2cdbc2
--- /dev/null
+++ b/flash/pointcloud/segmentation/data.py
@@ -0,0 +1,93 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+
+from flash.core.data.data_module import DataModule
+from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources
+from flash.core.data.process import Deserializer, Preprocess
+from flash.core.utilities.imports import requires_extras
+from flash.pointcloud.segmentation.open3d_ml.sequences_dataset import SequencesDataset
+
+
+class PointCloudSegmentationDatasetDataSource(DataSource):
+ def load_data(
+ self,
+ data: Any,
+ dataset: Optional[Any] = None,
+ ) -> Any:
+ if self.training:
+ dataset.num_classes = len(data.dataset.label_to_names)
+
+ dataset.dataset = data
+
+ return range(len(data))
+
+ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any:
+ sample = dataset.dataset[index]
+
+ return {
+ DefaultDataKeys.INPUT: sample["data"],
+ DefaultDataKeys.METADATA: sample["attr"],
+ }
+
+
+class PointCloudSegmentationFoldersDataSource(DataSource):
+ @requires_extras("pointcloud")
+ def load_data(
+ self,
+ folder: Any,
+ dataset: Optional[Any] = None,
+ ) -> Any:
+ sequence_dataset = SequencesDataset(folder, use_cache=True, predicting=self.predicting)
+ dataset.dataset = sequence_dataset
+ if self.training:
+ dataset.num_classes = sequence_dataset.num_classes
+
+ return range(len(sequence_dataset))
+
+ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any:
+ sample = dataset.dataset[index]
+
+ return {
+ DefaultDataKeys.INPUT: sample["data"],
+ DefaultDataKeys.METADATA: sample["attr"],
+ }
+
+
+class PointCloudSegmentationPreprocess(Preprocess):
+ def __init__(
+ self,
+ train_transform: Optional[Dict[str, Callable]] = None,
+ val_transform: Optional[Dict[str, Callable]] = None,
+ test_transform: Optional[Dict[str, Callable]] = None,
+ predict_transform: Optional[Dict[str, Callable]] = None,
+ image_size: Tuple[int, int] = (196, 196),
+ deserializer: Optional[Deserializer] = None,
+ ):
+ self.image_size = image_size
+
+ super().__init__(
+ train_transform=train_transform,
+ val_transform=val_transform,
+ test_transform=test_transform,
+ predict_transform=predict_transform,
+ data_sources={
+ DefaultDataSources.DATASETS: PointCloudSegmentationDatasetDataSource(),
+ DefaultDataSources.FOLDERS: PointCloudSegmentationFoldersDataSource(),
+ },
+ deserializer=deserializer,
+ default_data_source=DefaultDataSources.FOLDERS,
+ )
+
+ def get_state_dict(self):
+ return {}
+
+ def state_dict(self):
+ return {}
+
+ @classmethod
+ def load_state_dict(cls, state_dict, strict: bool = False):
+ pass
+
+
+class PointCloudSegmentationData(DataModule):
+
+ preprocess_cls = PointCloudSegmentationPreprocess
diff --git a/flash/pointcloud/segmentation/datasets.py b/flash/pointcloud/segmentation/datasets.py
new file mode 100644
index 0000000000..ff792282a4
--- /dev/null
+++ b/flash/pointcloud/segmentation/datasets.py
@@ -0,0 +1,62 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+ from open3d.ml.datasets import Lyft, SemanticKITTI
+
+_SEGMENTATION_DATASET = FlashRegistry("dataset")
+
+
+def executor(download_script, preprocess_script, dataset_path, name):
+ if not os.path.exists(os.path.join(dataset_path, name)):
+ os.system(f'bash -c "bash <(curl -s {download_script}) {dataset_path}"')
+ if preprocess_script:
+ os.system(f'bash -c "bash <(curl -s {preprocess_script}) {dataset_path}"')
+
+
+@_SEGMENTATION_DATASET
+def lyft(dataset_path):
+ name = "Lyft"
+ executor(
+ "https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_lyft.sh",
+ "https://github.com/intel-isl/Open3D-ML/blob/master/scripts/preprocess_lyft.py",
+ dataset_path,
+ name,
+ )
+ return Lyft(os.path.join(dataset_path, name))
+
+
+def LyftDataset(dataset_path):
+ return _SEGMENTATION_DATASET.get("lyft")(dataset_path)
+
+
+@_SEGMENTATION_DATASET
+def semantickitti(dataset_path, download, **kwargs):
+ name = "SemanticKitti"
+ if download:
+ executor(
+ "https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_semantickitti.sh", # noqa E501
+ None,
+ dataset_path,
+ name,
+ )
+ return SemanticKITTI(os.path.join(dataset_path, name), **kwargs)
+
+
+def SemanticKITTIDataset(dataset_path, download: bool = True, **kwargs):
+ return _SEGMENTATION_DATASET.get("semantickitti")(dataset_path, download, **kwargs)
diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py
new file mode 100644
index 0000000000..9342a61758
--- /dev/null
+++ b/flash/pointcloud/segmentation/model.py
@@ -0,0 +1,221 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
+
+import torch
+import torchmetrics
+from pytorch_lightning import Callback, LightningModule
+from torch import nn
+from torch.nn import functional as F
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader, Sampler
+from torchmetrics import IoU
+
+from flash.core.classification import ClassificationTask
+from flash.core.data.auto_dataset import BaseAutoDataset
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.process import Serializer
+from flash.core.data.states import CollateFn
+from flash.core.finetuning import BaseFinetuning
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+from flash.pointcloud.segmentation.backbones import POINTCLOUD_SEGMENTATION_BACKBONES
+
+if _POINTCLOUD_AVAILABLE:
+ from open3d._ml3d.torch.modules.losses.semseg_loss import filter_valid_label
+ from open3d.ml.torch.dataloaders import TorchDataloader
+
+
+class PointCloudSegmentationFinetuning(BaseFinetuning):
+ def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: int = 1):
+ super().__init__()
+ self.num_layers = num_layers
+ self.train_bn = train_bn
+ self.unfreeze_epoch = unfreeze_epoch
+
+ def freeze_before_training(self, pl_module: LightningModule) -> None:
+ self.freeze(modules=list(pl_module.backbone.children())[: -self.num_layers], train_bn=self.train_bn)
+
+ def finetune_function(
+ self,
+ pl_module: LightningModule,
+ epoch: int,
+ optimizer: Optimizer,
+ opt_idx: int,
+ ) -> None:
+ if epoch != self.unfreeze_epoch:
+ return
+ self.unfreeze_and_add_param_group(
+ modules=list(pl_module.backbone.children())[-self.num_layers :],
+ optimizer=optimizer,
+ train_bn=self.train_bn,
+ )
+
+
+class PointCloudSegmentationSerializer(Serializer):
+ pass
+
+
+class PointCloudSegmentation(ClassificationTask):
+ """The ``PointCloudClassifier`` is a :class:`~flash.core.classification.ClassificationTask` that classifies
+ pointcloud data.
+
+ Args:
+ num_features: The number of features (elements) in the input data.
+ num_classes: The number of classes (outputs) for this :class:`~flash.core.model.Task`.
+ backbone: The backbone name (or a tuple of ``nn.Module``, output size) to use.
+ backbone_kwargs: Any additional kwargs to pass to the backbone constructor.
+ loss_fn: The loss function to use. If ``None``, a default will be selected by the
+ :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
+ optimizer: The optimizer or optimizer class to use.
+ optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+ scheduler: The scheduler or scheduler class to use.
+ scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
+ metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected
+ by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
+ learning_rate: The learning rate for the optimizer.
+ multi_label: If ``True``, this will be treated as a multi-label classification problem.
+ serializer: The :class:`~flash.core.data.process.Serializer` to use for prediction outputs.
+ """
+
+ backbones: FlashRegistry = POINTCLOUD_SEGMENTATION_BACKBONES
+
+ required_extras: str = "pointcloud"
+
+ def __init__(
+ self,
+ num_classes: int,
+ backbone: Union[str, Tuple[nn.Module, int]] = "RandLANet",
+ backbone_kwargs: Optional[Dict] = None,
+ head: Optional[nn.Module] = None,
+ loss_fn: Optional[Callable] = torch.nn.functional.cross_entropy,
+ optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+ scheduler_kwargs: Optional[Dict[str, Any]] = None,
+ metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
+ learning_rate: float = 1e-2,
+ multi_label: bool = False,
+ serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudSegmentationSerializer(),
+ ):
+ import flash
+
+ if metrics is None:
+ metrics = IoU(num_classes=num_classes)
+
+ super().__init__(
+ model=None,
+ loss_fn=loss_fn,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
+ scheduler=scheduler,
+ scheduler_kwargs=scheduler_kwargs,
+ metrics=metrics,
+ learning_rate=learning_rate,
+ multi_label=multi_label,
+ serializer=serializer,
+ )
+
+ self.save_hyperparameters()
+
+ if not backbone_kwargs:
+ backbone_kwargs = {"num_classes": num_classes}
+
+ if isinstance(backbone, tuple):
+ self.backbone, out_features = backbone
+ else:
+ self.backbone, out_features, collate_fn = self.backbones.get(backbone)(**backbone_kwargs)
+ # replace latest layer
+ if not flash._IS_TESTING:
+ self.backbone.fc = nn.Identity()
+ self.set_state(CollateFn(collate_fn))
+
+ self.head = nn.Identity() if flash._IS_TESTING else (head or nn.Linear(out_features, num_classes))
+
+ def apply_filtering(self, labels, scores):
+ scores, labels = filter_valid_label(scores, labels, self.hparams.num_classes, [0], self.device)
+ return labels, scores
+
+ def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
+ return F.softmax(self.to_loss_format(x), dim=-1)
+
+ def to_loss_format(self, x: torch.Tensor) -> torch.Tensor:
+ return x.reshape(-1, x.shape[-1])
+
+ def training_step(self, batch: Any, batch_idx: int) -> Any:
+ batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.INPUT]["labels"].view(-1))
+ return super().training_step(batch, batch_idx)
+
+ def validation_step(self, batch: Any, batch_idx: int) -> Any:
+ batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.INPUT]["labels"].view(-1))
+ return super().validation_step(batch, batch_idx)
+
+ def test_step(self, batch: Any, batch_idx: int) -> Any:
+ batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.INPUT]["labels"].view(-1))
+ return super().test_step(batch, batch_idx)
+
+ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+ batch[DefaultDataKeys.PREDS] = self(batch[DefaultDataKeys.INPUT])
+ batch[DefaultDataKeys.TARGET] = batch[DefaultDataKeys.INPUT]["labels"]
+ # drop sub-sampled pointclouds
+ batch[DefaultDataKeys.INPUT] = batch[DefaultDataKeys.INPUT]["xyz"][0]
+ return batch
+
+ def forward(self, x) -> torch.Tensor:
+ """First call the backbone, then the model head."""
+ # hack to enable backbone to work properly.
+ self.backbone.device = self.device
+ x = self.backbone(x)
+ if self.head is not None:
+ x = self.head(x)
+ return x
+
+ def _process_dataset(
+ self,
+ dataset: BaseAutoDataset,
+ batch_size: int,
+ num_workers: int,
+ pin_memory: bool,
+ collate_fn: Callable,
+ shuffle: bool = False,
+ drop_last: bool = True,
+ sampler: Optional[Sampler] = None,
+ ) -> DataLoader:
+
+ if not _POINTCLOUD_AVAILABLE:
+ raise ModuleNotFoundError("Please, run `pip install flash[pointcloud]`.")
+
+ if not isinstance(dataset.dataset, TorchDataloader):
+
+ dataset.dataset = TorchDataloader(
+ dataset.dataset,
+ preprocess=self.backbone.preprocess,
+ transform=self.backbone.transform,
+ use_cache=False,
+ )
+
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ collate_fn=collate_fn,
+ shuffle=shuffle,
+ drop_last=drop_last,
+ sampler=sampler,
+ )
+
+ def configure_finetune_callback(self) -> List[Callback]:
+ return [PointCloudSegmentationFinetuning()]
diff --git a/flash/pointcloud/segmentation/open3d_ml/__init__.py b/flash/pointcloud/segmentation/open3d_ml/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/flash/pointcloud/segmentation/open3d_ml/app.py b/flash/pointcloud/segmentation/open3d_ml/app.py
new file mode 100644
index 0000000000..b1145c53b5
--- /dev/null
+++ b/flash/pointcloud/segmentation/open3d_ml/app.py
@@ -0,0 +1,107 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from flash.core.data.data_module import DataModule
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+
+ from open3d._ml3d.torch.dataloaders import TorchDataloader
+ from open3d._ml3d.vis.visualizer import LabelLUT
+ from open3d._ml3d.vis.visualizer import Visualizer as Open3dVisualizer
+
+else:
+
+ Open3dVisualizer = object
+
+
+class Visualizer(Open3dVisualizer):
+ def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768):
+ """Visualize a dataset.
+
+ Example:
+ Minimal example for visualizing a dataset::
+ import open3d.ml.torch as ml3d # or open3d.ml.tf as ml3d
+
+ dataset = ml3d.datasets.SemanticKITTI(dataset_path='/path/to/SemanticKITTI/')
+ vis = ml3d.vis.Visualizer()
+ vis.visualize_dataset(dataset, 'all', indices=range(100))
+
+ Args:
+ dataset: The dataset to use for visualization.
+ split: The dataset split to be used, such as 'training'
+ indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4].
+ width: The width of the visualization window.
+ height: The height of the visualization window.
+ """
+ # Setup the labels
+ lut = LabelLUT()
+ color_map = dataset.color_map
+ for id, val in dataset.label_to_names.items():
+ lut.add_label(val, id, color=color_map[id])
+ self.set_lut("labels", lut)
+
+ self._consolidate_bounding_boxes = True
+ self._init_dataset(dataset, split, indices)
+ self._visualize("Open3D - " + dataset.name, width, height)
+
+
+class App:
+ def __init__(self, datamodule: DataModule):
+ self.datamodule = datamodule
+ self._enabled = True # not flash._IS_TESTING
+
+ def get_dataset(self, stage: str = "train"):
+ dataloader = getattr(self.datamodule, f"{stage}_dataloader")()
+ dataset = dataloader.dataset.dataset
+ if isinstance(dataset, TorchDataloader):
+ return dataset.dataset
+ return dataset
+
+ def show_train_dataset(self, indices=None):
+ if self._enabled:
+ dataset = self.get_dataset("train")
+ viz = Visualizer()
+ viz.visualize_dataset(dataset, "all", indices=indices)
+
+ def show_predictions(self, predictions):
+ if self._enabled:
+ dataset = self.get_dataset("train")
+ color_map = dataset.color_map
+
+ predictions_visualizations = []
+ for pred in predictions:
+ predictions_visualizations.append(
+ {
+ "points": torch.stack(pred[DefaultDataKeys.INPUT]),
+ "labels": torch.stack(pred[DefaultDataKeys.TARGET]),
+ "predictions": torch.argmax(torch.stack(pred[DefaultDataKeys.PREDS]), axis=-1) + 1,
+ "name": pred[DefaultDataKeys.METADATA]["name"],
+ }
+ )
+
+ viz = Visualizer()
+ lut = LabelLUT()
+ color_map = dataset.color_map
+ for id, val in dataset.label_to_names.items():
+ lut.add_label(val, id, color=color_map[id])
+ viz.set_lut("labels", lut)
+ viz.set_lut("predictions", lut)
+ viz.visualize(predictions_visualizations)
+
+
+def launch_app(datamodule: DataModule) -> "App":
+ return App(datamodule)
diff --git a/flash/pointcloud/segmentation/open3d_ml/backbones.py b/flash/pointcloud/segmentation/open3d_ml/backbones.py
new file mode 100644
index 0000000000..a326cbcdc5
--- /dev/null
+++ b/flash/pointcloud/segmentation/open3d_ml/backbones.py
@@ -0,0 +1,82 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Callable
+
+import torch
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+from flash.core.utilities.providers import _OPEN3D_ML
+
+ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/"
+
+
+def register_open_3d_ml(register: FlashRegistry):
+ if _POINTCLOUD_AVAILABLE:
+ import open3d
+ import open3d.ml as _ml3d
+ from open3d._ml3d.torch.dataloaders import ConcatBatcher, DefaultBatcher
+ from open3d._ml3d.torch.models import RandLANet
+
+ CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs")
+
+ def get_collate_fn(model) -> Callable:
+ batcher_name = model.cfg.batcher
+ if batcher_name == "DefaultBatcher":
+ batcher = DefaultBatcher()
+ elif batcher_name == "ConcatBatcher":
+ batcher = ConcatBatcher(torch, model.__class__.__name__)
+ else:
+ batcher = None
+ return batcher.collate_fn
+
+ @register(providers=_OPEN3D_ML)
+ def randlanet_s3dis(*args, use_fold_5: bool = True, **kwargs) -> RandLANet:
+ cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_s3dis.yml"))
+ model = RandLANet(**cfg.model)
+ if use_fold_5:
+ weight_url = os.path.join(ROOT_URL, "randlanet_s3dis_area5_202010091333utc.pth")
+ else:
+ weight_url = os.path.join(ROOT_URL, "randlanet_s3dis_202010091238.pth")
+ model.load_state_dict(pl_load(weight_url, map_location="cpu")["model_state_dict"])
+ return model, 32, get_collate_fn(model)
+
+ @register(providers=_OPEN3D_ML)
+ def randlanet_toronto3d(*args, **kwargs) -> RandLANet:
+ cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_toronto3d.yml"))
+ model = RandLANet(**cfg.model)
+ model.load_state_dict(
+ pl_load(os.path.join(ROOT_URL, "randlanet_toronto3d_202010091306utc.pth"), map_location="cpu")[
+ "model_state_dict"
+ ],
+ )
+ return model, 32, get_collate_fn(model)
+
+ @register(providers=_OPEN3D_ML)
+ def randlanet_semantic_kitti(*args, **kwargs) -> RandLANet:
+ cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_semantickitti.yml"))
+ model = RandLANet(**cfg.model)
+ model.load_state_dict(
+ pl_load(os.path.join(ROOT_URL, "randlanet_semantickitti_202009090354utc.pth"), map_location="cpu")[
+ "model_state_dict"
+ ],
+ )
+ return model, 32, get_collate_fn(model)
+
+ @register(providers=_OPEN3D_ML)
+ def randlanet(*args, **kwargs) -> RandLANet:
+ model = RandLANet(*args, **kwargs)
+ return model, 32, get_collate_fn(model)
diff --git a/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py b/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py
new file mode 100644
index 0000000000..966b224c78
--- /dev/null
+++ b/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py
@@ -0,0 +1,182 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from os.path import basename, dirname, exists, isdir, isfile, join, split
+
+import numpy as np
+import yaml
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from torch.utils.data import Dataset
+
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+
+ from open3d._ml3d.datasets.utils import DataProcessing
+ from open3d._ml3d.utils.config import Config
+
+
+class SequencesDataset(Dataset):
+ def __init__(
+ self,
+ data,
+ cache_dir="./logs/cache",
+ use_cache=False,
+ num_points=65536,
+ ignored_label_inds=[0],
+ predicting=False,
+ **kwargs,
+ ):
+
+ super().__init__()
+
+ self.name = "Dataset"
+ self.ignored_label_inds = ignored_label_inds
+
+ kwargs["cache_dir"] = cache_dir
+ kwargs["use_cache"] = use_cache
+ kwargs["num_points"] = num_points
+ kwargs["ignored_label_inds"] = ignored_label_inds
+
+ self.cfg = Config(kwargs)
+ self.predicting = predicting
+
+ if not predicting:
+ self.on_fit(data)
+ else:
+ self.on_predict(data)
+
+ @property
+ def color_map(self):
+ return self.meta["color_map"]
+
+ def on_fit(self, dataset_path):
+ self.split = basename(dataset_path)
+
+ self.load_meta(dirname(dataset_path))
+ self.dataset_path = dataset_path
+ self.label_to_names = self.get_label_to_names()
+ self.num_classes = len(self.label_to_names) - len(self.ignored_label_inds)
+ self.make_datasets()
+
+ def load_meta(self, root_dir):
+ meta_file = join(root_dir, "meta.yaml")
+ if not exists(meta_file):
+ raise MisconfigurationException(
+ f"The {root_dir} should contain a `meta.yaml` file about the pointcloud sequences."
+ )
+
+ with open(meta_file) as f:
+ self.meta = yaml.safe_load(f)
+
+ self.label_to_names = self.get_label_to_names()
+ self.num_classes = len(self.label_to_names)
+
+ with open(meta_file) as f:
+ self.meta = yaml.safe_load(f)
+
+ remap_dict_val = self.meta["learning_map"]
+ max_key = max(remap_dict_val.keys())
+ remap_lut_val = np.zeros((max_key + 100), dtype=np.int32)
+ remap_lut_val[list(remap_dict_val.keys())] = list(remap_dict_val.values())
+
+ self.remap_lut_val = remap_lut_val
+
+ def make_datasets(self):
+ self.path_list = []
+ for seq in os.listdir(self.dataset_path):
+ sequence_path = join(self.dataset_path, seq)
+ directories = [f for f in os.listdir(sequence_path) if isdir(join(sequence_path, f)) and f != "labels"]
+ assert len(directories) == 1
+ scan_dir = join(sequence_path, directories[0])
+ for scan_name in os.listdir(scan_dir):
+ self.path_list.append(join(scan_dir, scan_name))
+
+ def on_predict(self, data):
+ if isinstance(data, list):
+ if not all(isfile(p) for p in data):
+ raise MisconfigurationException("The predict input data takes only a list of paths or a directory.")
+ root_dir = split(data[0])[0]
+ elif isinstance(data, str):
+ if not isdir(data) and not isfile(data):
+ raise MisconfigurationException("The predict input data takes only a list of paths or a directory.")
+ if isdir(data):
+ root_dir = data
+ data = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if ".bin" in f]
+ elif isfile(data):
+ root_dir = dirname(data)
+ data = [data]
+ else:
+ raise MisconfigurationException("The predict input data takes only a list of paths or a directory.")
+ else:
+ raise MisconfigurationException("The predict input data takes only a list of paths or a directory.")
+
+ self.path_list = data
+ self.split = "predict"
+ self.load_meta(root_dir)
+
+ def get_label_to_names(self):
+ """Returns a label to names dictonary object.
+
+ Returns:
+ A dict where keys are label numbers and
+ values are the corresponding names.
+ """
+ return self.meta["label_to_names"]
+
+ def __getitem__(self, index):
+ data = self.get_data(index)
+ data["attr"] = self.get_attr(index)
+ return data
+
+ def get_data(self, idx):
+ pc_path = self.path_list[idx]
+ points = DataProcessing.load_pc_kitti(pc_path)
+
+ dir, file = split(pc_path)
+ if self.predicting:
+ label_path = join(dir, file[:-4] + ".label")
+ else:
+ label_path = join(dir, "../labels", file[:-4] + ".label")
+ if not exists(label_path):
+ labels = np.zeros(np.shape(points)[0], dtype=np.int32)
+ if self.split not in ["test", "all"]:
+ raise FileNotFoundError(f" Label file {label_path} not found")
+
+ else:
+ labels = DataProcessing.load_label_kitti(label_path, self.remap_lut_val).astype(np.int32)
+
+ data = {
+ "point": points[:, 0:3],
+ "feat": None,
+ "label": labels,
+ }
+
+ return data
+
+ def get_attr(self, idx):
+ pc_path = self.path_list[idx]
+ dir, file = split(pc_path)
+ _, seq = split(split(dir)[0])
+ name = f"{seq}_{file[:-4]}"
+
+ pc_path = str(pc_path)
+ attr = {"idx": idx, "name": name, "path": pc_path, "split": self.split}
+ return attr
+
+ def __len__(self):
+ return len(self.path_list)
+
+ def get_split(self, *_):
+ return self
diff --git a/flash/setup_tools.py b/flash/setup_tools.py
index 8e27bf2c1c..a7376eb940 100644
--- a/flash/setup_tools.py
+++ b/flash/setup_tools.py
@@ -19,17 +19,17 @@
_PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
-def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comment_chars: str = '#@') -> List[str]:
- with open(os.path.join(path_dir, file_name), 'r') as file:
+def _load_requirements(path_dir: str, file_name: str = "requirements.txt", comment_chars: str = "#@") -> List[str]:
+ with open(os.path.join(path_dir, file_name)) as file:
lines = [ln.strip() for ln in file.readlines()]
reqs = []
for ln in lines:
# filer all comments
found = [ln.index(ch) for ch in comment_chars if ch in ln]
if found:
- ln = ln[:min(found)].strip()
+ ln = ln[: min(found)].strip()
# skip directly installed dependencies
- if ln.startswith('http') or ln.startswith('git'):
+ if ln.startswith("http") or ln.startswith("git"):
continue
if ln: # if requirement is not empty
reqs.append(ln)
@@ -37,7 +37,7 @@ def _load_requirements(path_dir: str, file_name: str = 'requirements.txt', comme
def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str:
- """Load readme as decribtion
+ """Load readme as decribtion.
>>> _load_readme_description(_PROJECT_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'...'
@@ -46,7 +46,7 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str:
text = open(path_readme, encoding="utf-8").read()
# drop images from readme
- text = text.replace('![PT to PL](docs/source/_images/general/pl_quick_start_full_compressed.gif)', '')
+ text = text.replace("![PT to PL](docs/source/_images/general/pl_quick_start_full_compressed.gif)", "")
# https://github.com/PyTorchLightning/pytorch-lightning/raw/master/docs/source/_images/lightning_module/pt_to_pl.png
github_source_url = os.path.join(homepage, "raw", ver)
@@ -55,17 +55,17 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str:
text = text.replace("docs/source/_static/", f"{os.path.join(github_source_url, 'docs/source/_static/')}")
# readthedocs badge
- text = text.replace('badge/?version=stable', f'badge/?version={ver}')
- text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', f'pytorch-lightning.readthedocs.io/en/{ver}')
+ text = text.replace("badge/?version=stable", f"badge/?version={ver}")
+ text = text.replace("pytorch-lightning.readthedocs.io/en/stable/", f"pytorch-lightning.readthedocs.io/en/{ver}")
# codecov badge
- text = text.replace('/branch/master/graph/badge.svg', f'/release/{ver}/graph/badge.svg')
+ text = text.replace("/branch/master/graph/badge.svg", f"/release/{ver}/graph/badge.svg")
# replace github badges for release ones
- text = text.replace('badge.svg?branch=master&event=push', f'badge.svg?tag={ver}')
+ text = text.replace("badge.svg?branch=master&event=push", f"badge.svg?tag={ver}")
- skip_begin = r''
- skip_end = r''
+ skip_begin = r""
+ skip_end = r""
# todo: wrap content as commented description
- text = re.sub(rf"{skip_begin}.+?{skip_end}", '', text, flags=re.IGNORECASE + re.DOTALL)
+ text = re.sub(rf"{skip_begin}.+?{skip_end}", "", text, flags=re.IGNORECASE + re.DOTALL)
# # https://github.com/Borda/pytorch-lightning/releases/download/1.1.0a6/codecov_badge.png
# github_release_url = os.path.join(homepage, "releases", "download", ver)
diff --git a/flash/tabular/__init__.py b/flash/tabular/__init__.py
index a3b8e2ca2d..22698efc99 100644
--- a/flash/tabular/__init__.py
+++ b/flash/tabular/__init__.py
@@ -1 +1,3 @@
-from flash.tabular.classification import TabularClassifier, TabularData # noqa: F401
+from flash.tabular.classification import TabularClassificationData, TabularClassifier # noqa: F401
+from flash.tabular.data import TabularData # noqa: F401
+from flash.tabular.regression import TabularRegressionData # noqa: F401
diff --git a/flash/tabular/classification/__init__.py b/flash/tabular/classification/__init__.py
index 45724db27b..6134277abf 100644
--- a/flash/tabular/classification/__init__.py
+++ b/flash/tabular/classification/__init__.py
@@ -1,2 +1,2 @@
-from flash.tabular.classification.data import TabularData # noqa: F401
+from flash.tabular.classification.data import TabularClassificationData # noqa: F401
from flash.tabular.classification.model import TabularClassifier # noqa: F401
diff --git a/flash/tabular/classification/cli.py b/flash/tabular/classification/cli.py
new file mode 100644
index 0000000000..63eff2458f
--- /dev/null
+++ b/flash/tabular/classification/cli.py
@@ -0,0 +1,59 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from flash.core.data.utils import download_data
+from flash.core.utilities.flash_cli import FlashCLI
+from flash.tabular import TabularClassificationData, TabularClassifier
+
+__all__ = ["tabular_classification"]
+
+
+def from_titanic(
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ **preprocess_kwargs,
+) -> TabularClassificationData:
+ """Downloads and loads the Titanic data set."""
+ download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")
+ return TabularClassificationData.from_csv(
+ ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
+ "Fare",
+ target_fields="Survived",
+ train_file="data/titanic/titanic.csv",
+ val_split=0.1,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **preprocess_kwargs,
+ )
+
+
+def tabular_classification():
+ """Classify tabular data."""
+ cli = FlashCLI(
+ TabularClassifier,
+ TabularClassificationData,
+ default_datamodule_builder=from_titanic,
+ default_arguments={
+ "trainer.max_epochs": 3,
+ },
+ finetune=False,
+ datamodule_attributes={"num_features", "num_classes", "embedding_sizes"},
+ )
+
+ cli.trainer.save_checkpoint("tabular_classification_model.pt")
+
+
+if __name__ == "__main__":
+ tabular_classification()
diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py
index c2a60e24da..63cdda9ea2 100644
--- a/flash/tabular/classification/data.py
+++ b/flash/tabular/classification/data.py
@@ -11,505 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from io import StringIO
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from flash.tabular.data import TabularData
-import numpy as np
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from flash.core.classification import LabelsState
-from flash.core.data.callback import BaseDataFetcher
-from flash.core.data.data_module import DataModule
-from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources
-from flash.core.data.process import Deserializer, Postprocess, Preprocess
-from flash.core.utilities.imports import _PANDAS_AVAILABLE
-from flash.tabular.classification.utils import (
- _compute_normalization,
- _generate_codes,
- _pre_transform,
- _to_cat_vars_numpy,
- _to_num_vars_numpy,
-)
-
-if _PANDAS_AVAILABLE:
- import pandas as pd
- from pandas.core.frame import DataFrame
-else:
- DataFrame = object
-
-
-class TabularDataFrameDataSource(DataSource[DataFrame]):
-
- def __init__(
- self,
- cat_cols: Optional[List[str]] = None,
- num_cols: Optional[List[str]] = None,
- target_col: Optional[str] = None,
- mean: Optional[DataFrame] = None,
- std: Optional[DataFrame] = None,
- codes: Optional[Dict[str, Any]] = None,
- target_codes: Optional[Dict[str, Any]] = None,
- classes: Optional[List[str]] = None,
- is_regression: bool = True,
- ):
- super().__init__()
-
- self.cat_cols = cat_cols
- self.num_cols = num_cols
- self.target_col = target_col
- self.mean = mean
- self.std = std
- self.codes = codes
- self.target_codes = target_codes
- self.is_regression = is_regression
-
- self.set_state(LabelsState(classes))
- self.num_classes = len(classes)
-
- def common_load_data(
- self,
- df: DataFrame,
- dataset: Optional[Any] = None,
- ):
- # impute_data
- # compute train dataset stats
- dfs = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col,
- self.target_codes)
-
- df = dfs[0]
-
- if dataset is not None:
- dataset.num_samples = len(df)
-
- cat_vars = _to_cat_vars_numpy(df, self.cat_cols)
- num_vars = _to_num_vars_numpy(df, self.num_cols)
-
- cat_vars = np.stack(cat_vars, 1) # if len(cat_vars) else np.zeros((len(self), 0))
- num_vars = np.stack(num_vars, 1) # if len(num_vars) else np.zeros((len(self), 0))
- return df, cat_vars, num_vars
-
- def load_data(self, data: DataFrame, dataset: Optional[Any] = None):
- df, cat_vars, num_vars = self.common_load_data(data, dataset=dataset)
- target = df[self.target_col].to_numpy().astype(np.float32 if self.is_regression else np.int64)
- return [{
- DefaultDataKeys.INPUT: (c, n),
- DefaultDataKeys.TARGET: t
- } for c, n, t in zip(cat_vars, num_vars, target)]
-
- def predict_load_data(self, data: DataFrame, dataset: Optional[Any] = None):
- _, cat_vars, num_vars = self.common_load_data(data, dataset=dataset)
- return [{DefaultDataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)]
-
-
-class TabularCSVDataSource(TabularDataFrameDataSource):
-
- def load_data(self, data: str, dataset: Optional[Any] = None):
- return super().load_data(pd.read_csv(data), dataset=dataset)
-
- def predict_load_data(self, data: str, dataset: Optional[Any] = None):
- return super().predict_load_data(pd.read_csv(data), dataset=dataset)
-
-
-class TabularDeserializer(Deserializer):
-
- def __init__(
- self,
- cat_cols: Optional[List[str]] = None,
- num_cols: Optional[List[str]] = None,
- target_col: Optional[str] = None,
- mean: Optional[DataFrame] = None,
- std: Optional[DataFrame] = None,
- codes: Optional[Dict[str, Any]] = None,
- target_codes: Optional[Dict[str, Any]] = None,
- classes: Optional[List[str]] = None,
- is_regression: bool = True
- ):
- super().__init__()
- self.cat_cols = cat_cols
- self.num_cols = num_cols
- self.target_col = target_col
- self.mean = mean
- self.std = std
- self.codes = codes
- self.target_codes = target_codes
- self.classes = classes
- self.is_regression = is_regression
-
- def deserialize(self, data: str) -> Any:
- df = pd.read_csv(StringIO(data))
- df = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col,
- self.target_codes)[0]
-
- cat_vars = _to_cat_vars_numpy(df, self.cat_cols)
- num_vars = _to_num_vars_numpy(df, self.num_cols)
-
- cat_vars = np.stack(cat_vars, 1)
- num_vars = np.stack(num_vars, 1)
-
- return [{DefaultDataKeys.INPUT: [c, n]} for c, n in zip(cat_vars, num_vars)]
-
- @property
- def example_input(self) -> str:
- row = {}
- for cat_col in self.cat_cols:
- row[cat_col] = ["test"]
- for num_col in self.num_cols:
- row[num_col] = [0]
- return str(DataFrame.from_dict(row).to_csv())
-
-
-class TabularPreprocess(Preprocess):
-
- def __init__(
- self,
- train_transform: Optional[Dict[str, Callable]] = None,
- val_transform: Optional[Dict[str, Callable]] = None,
- test_transform: Optional[Dict[str, Callable]] = None,
- predict_transform: Optional[Dict[str, Callable]] = None,
- cat_cols: Optional[List[str]] = None,
- num_cols: Optional[List[str]] = None,
- target_col: Optional[str] = None,
- mean: Optional[DataFrame] = None,
- std: Optional[DataFrame] = None,
- codes: Optional[Dict[str, Any]] = None,
- target_codes: Optional[Dict[str, Any]] = None,
- classes: Optional[List[str]] = None,
- is_regression: bool = True,
- deserializer: Optional[Deserializer] = None
- ):
- self.cat_cols = cat_cols
- self.num_cols = num_cols
- self.target_col = target_col
- self.mean = mean
- self.std = std
- self.codes = codes
- self.target_codes = target_codes
- self.classes = classes
- self.is_regression = is_regression
-
- super().__init__(
- train_transform=train_transform,
- val_transform=val_transform,
- test_transform=test_transform,
- predict_transform=predict_transform,
- data_sources={
- DefaultDataSources.CSV: TabularCSVDataSource(
- cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression
- ),
- "data_frame": TabularDataFrameDataSource(
- cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression
- ),
- },
- default_data_source=DefaultDataSources.CSV,
- deserializer=deserializer or TabularDeserializer(
- cat_cols=cat_cols,
- num_cols=num_cols,
- target_col=target_col,
- mean=mean,
- std=std,
- codes=codes,
- target_codes=target_codes,
- classes=classes,
- is_regression=is_regression
- )
- )
-
- def get_state_dict(self, strict: bool = False) -> Dict[str, Any]:
- return {
- **self.transforms,
- "cat_cols": self.cat_cols,
- "num_cols": self.num_cols,
- "target_col": self.target_col,
- "mean": self.mean,
- "std": self.std,
- "codes": self.codes,
- "target_codes": self.target_codes,
- "classes": self.classes,
- "is_regression": self.is_regression,
- }
-
- @classmethod
- def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> 'Preprocess':
- return cls(**state_dict)
-
-
-class TabularPostprocess(Postprocess):
-
- def uncollate(self, batch: Any) -> Any:
- return batch
-
-
-class TabularData(DataModule):
- """Data module for tabular tasks"""
-
- preprocess_cls = TabularPreprocess
- postprocess_cls = TabularPostprocess
-
- @property
- def codes(self) -> Dict[str, str]:
- return self._data_source.codes
-
- @property
- def num_classes(self) -> int:
- return self._data_source.num_classes
-
- @property
- def cat_cols(self) -> Optional[List[str]]:
- return self._data_source.cat_cols
-
- @property
- def num_cols(self) -> Optional[List[str]]:
- return self._data_source.num_cols
-
- @property
- def num_features(self) -> int:
- return len(self.cat_cols) + len(self.num_cols)
-
- @property
- def emb_sizes(self) -> list:
- """Recommended embedding sizes."""
-
- # https://developers.googleblog.com/2017/11/introducing-tensorflow-feature-columns.html
- # The following "formula" provides a general rule of thumb about the number of embedding dimensions:
- # embedding_dimensions = number_of_categories**0.25
- num_classes = [len(self.codes[cat]) for cat in self.cat_cols]
- emb_dims = [max(int(n**0.25), 16) for n in num_classes]
- return list(zip(num_classes, emb_dims))
-
- @staticmethod
- def _sanetize_cols(cat_cols: Optional[Union[str, List[str]]], num_cols: Optional[Union[str, List[str]]]):
- if cat_cols is None and num_cols is None:
- raise RuntimeError('Both `cat_cols` and `num_cols` are None!')
-
- return cat_cols or [], num_cols or []
-
- @classmethod
- def compute_state(
- cls,
- train_data_frame: DataFrame,
- val_data_frame: Optional[DataFrame],
- test_data_frame: Optional[DataFrame],
- predict_data_frame: Optional[DataFrame],
- target_fields: str,
- numerical_fields: List[str],
- categorical_fields: List[str],
- ) -> Tuple[float, float, List[str], Dict[str, Any], Dict[str, Any]]:
-
- if train_data_frame is None:
- raise MisconfigurationException(
- "train_data_frame is required to instantiate the TabularDataFrameDataSource"
- )
-
- data_frames = [train_data_frame]
-
- if val_data_frame is not None:
- data_frames += [val_data_frame]
-
- if test_data_frame is not None:
- data_frames += [test_data_frame]
-
- if predict_data_frame is not None:
- data_frames += [predict_data_frame]
-
- mean, std = _compute_normalization(data_frames[0], numerical_fields)
-
- classes = list(data_frames[0][target_fields].unique())
-
- if data_frames[0][target_fields].dtype == object:
- # if the target_fields is a category, not an int
- target_codes = _generate_codes(data_frames, [target_fields])
- else:
- target_codes = None
- codes = _generate_codes(data_frames, categorical_fields)
-
- return mean, std, classes, codes, target_codes
-
- @classmethod
- def from_data_frame(
- cls,
- categorical_fields: Optional[Union[str, List[str]]],
- numerical_fields: Optional[Union[str, List[str]]],
- target_fields: Optional[str] = None,
- train_data_frame: Optional[DataFrame] = None,
- val_data_frame: Optional[DataFrame] = None,
- test_data_frame: Optional[DataFrame] = None,
- predict_data_frame: Optional[DataFrame] = None,
- train_transform: Optional[Dict[str, Callable]] = None,
- val_transform: Optional[Dict[str, Callable]] = None,
- test_transform: Optional[Dict[str, Callable]] = None,
- predict_transform: Optional[Dict[str, Callable]] = None,
- data_fetcher: Optional[BaseDataFetcher] = None,
- preprocess: Optional[Preprocess] = None,
- val_split: Optional[float] = None,
- batch_size: int = 4,
- num_workers: Optional[int] = None,
- is_regression: bool = False,
- **preprocess_kwargs: Any,
- ):
- """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames.
-
- Args:
- categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs.
- numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs.
- target_fields: The field or fields (columns) in the CSV file to use for the target.
- train_data_frame: The pandas ``DataFrame`` containing the training data.
- val_data_frame: The pandas ``DataFrame`` containing the validation data.
- test_data_frame: The pandas ``DataFrame`` containing the testing data.
- predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting.
- train_transform: The dictionary of transforms to use during training which maps
- :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
- val_transform: The dictionary of transforms to use during validation which maps
- :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
- test_transform: The dictionary of transforms to use during testing which maps
- :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
- predict_transform: The dictionary of transforms to use during predicting which maps
- :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
- data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
- :class:`~flash.core.data.data_module.DataModule`.
- preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
- :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
- will be constructed and used.
- val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
- batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
- num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
- is_regression: If ``True``, targets will be formatted as floating point. If ``False``, targets will be
- formatted as integers.
- preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
- if ``preprocess = None``.
-
- Returns:
- The constructed data module.
-
- Examples::
-
- data_module = TabularData.from_data_frame(
- "categorical_input",
- "numerical_input",
- "target",
- train_data_frame=train_data,
- )
- """
- categorical_fields, numerical_fields = cls._sanetize_cols(categorical_fields, numerical_fields)
-
- if not isinstance(categorical_fields, list):
- categorical_fields = [categorical_fields]
-
- if not isinstance(numerical_fields, list):
- numerical_fields = [numerical_fields]
-
- mean, std, classes, codes, target_codes = cls.compute_state(
- train_data_frame=train_data_frame,
- val_data_frame=val_data_frame,
- test_data_frame=test_data_frame,
- predict_data_frame=predict_data_frame,
- target_fields=target_fields,
- numerical_fields=numerical_fields,
- categorical_fields=categorical_fields,
- )
-
- return cls.from_data_source(
- "data_frame",
- train_data_frame,
- val_data_frame,
- test_data_frame,
- predict_data_frame,
- train_transform=train_transform,
- val_transform=val_transform,
- test_transform=test_transform,
- predict_transform=predict_transform,
- data_fetcher=data_fetcher,
- preprocess=preprocess,
- val_split=val_split,
- batch_size=batch_size,
- num_workers=num_workers,
- cat_cols=categorical_fields,
- num_cols=numerical_fields,
- target_col=target_fields,
- mean=mean,
- std=std,
- codes=codes,
- target_codes=target_codes,
- classes=classes,
- is_regression=is_regression,
- **preprocess_kwargs,
- )
-
- @classmethod
- def from_csv(
- cls,
- categorical_fields: Optional[Union[str, List[str]]],
- numerical_fields: Optional[Union[str, List[str]]],
- target_fields: Optional[str] = None,
- train_file: Optional[str] = None,
- val_file: Optional[str] = None,
- test_file: Optional[str] = None,
- predict_file: Optional[str] = None,
- train_transform: Optional[Dict[str, Callable]] = None,
- val_transform: Optional[Dict[str, Callable]] = None,
- test_transform: Optional[Dict[str, Callable]] = None,
- predict_transform: Optional[Dict[str, Callable]] = None,
- data_fetcher: Optional[BaseDataFetcher] = None,
- preprocess: Optional[Preprocess] = None,
- val_split: Optional[float] = None,
- batch_size: int = 4,
- num_workers: Optional[int] = None,
- is_regression: bool = False,
- **preprocess_kwargs: Any,
- ) -> 'DataModule':
- """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files.
-
- Args:
- categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs.
- numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs.
- target_fields: The field or fields (columns) in the CSV file to use for the target.
- train_file: The CSV file containing the training data.
- val_file: The CSV file containing the validation data.
- test_file: The CSV file containing the testing data.
- predict_file: The CSV file containing the data to use when predicting.
- train_transform: The dictionary of transforms to use during training which maps
- :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
- val_transform: The dictionary of transforms to use during validation which maps
- :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
- test_transform: The dictionary of transforms to use during testing which maps
- :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
- predict_transform: The dictionary of transforms to use during predicting which maps
- :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
- data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
- :class:`~flash.core.data.data_module.DataModule`.
- preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
- :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
- will be constructed and used.
- val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
- batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
- num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
- is_regression: If ``True``, targets will be formatted as floating point. If ``False``, targets will be
- formatted as integers.
- preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
- if ``preprocess = None``.
-
- Returns:
- The constructed data module.
-
- Examples::
-
- data_module = TabularData.from_csv(
- "categorical_input",
- "numerical_input",
- "target",
- train_file="train_data.csv",
- )
- """
- return cls.from_data_frame(
- categorical_fields=categorical_fields,
- numerical_fields=numerical_fields,
- target_fields=target_fields,
- train_data_frame=pd.read_csv(train_file) if train_file is not None else None,
- val_data_frame=pd.read_csv(val_file) if val_file is not None else None,
- test_data_frame=pd.read_csv(test_file) if test_file is not None else None,
- predict_data_frame=pd.read_csv(predict_file) if predict_file is not None else None,
- is_regression=is_regression,
- preprocess=preprocess,
- val_split=val_split,
- batch_size=batch_size,
- num_workers=num_workers,
- )
+class TabularClassificationData(TabularData):
+ is_regression = False
diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py
index 3106bd57c9..b01e99e4f6 100644
--- a/flash/tabular/classification/model.py
+++ b/flash/tabular/classification/model.py
@@ -53,7 +53,7 @@ def __init__(
self,
num_features: int,
num_classes: int,
- embedding_sizes: List[Tuple] = None,
+ embedding_sizes: List[Tuple[int, int]] = None,
loss_fn: Callable = F.cross_entropy,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
metrics: Union[Metric, Callable, Mapping, Sequence, None] = None,
@@ -71,7 +71,7 @@ def __init__(
cat_idxs=list(range(len(embedding_sizes))),
cat_dims=list(cat_dims),
cat_emb_dim=list(cat_emb_dim),
- **tabnet_kwargs
+ **tabnet_kwargs,
)
super().__init__(
@@ -108,17 +108,15 @@ def test_step(self, batch: Any, batch_idx: int) -> Any:
return super().test_step(batch, batch_idx)
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
- batch = (batch[DefaultDataKeys.INPUT])
+ batch = batch[DefaultDataKeys.INPUT]
return self(batch)
@classmethod
- def from_data(cls, datamodule, **kwargs) -> 'TabularClassifier':
- model = cls(datamodule.num_features, datamodule.num_classes, datamodule.emb_sizes, **kwargs)
+ def from_data(cls, datamodule, **kwargs) -> "TabularClassifier":
+ model = cls(datamodule.num_features, datamodule.num_classes, datamodule.embedding_sizes, **kwargs)
return model
@staticmethod
def _ci_benchmark_fn(history: List[Dict[str, Any]]):
- """
- This function is used only for debugging usage with CI
- """
- assert history[-1]["val_accuracy"] > 0.65
+ """This function is used only for debugging usage with CI."""
+ assert history[-1]["val_accuracy"] > 0.6, history[-1]["val_accuracy"]
diff --git a/flash/tabular/data.py b/flash/tabular/data.py
new file mode 100644
index 0000000000..da36d726ce
--- /dev/null
+++ b/flash/tabular/data.py
@@ -0,0 +1,509 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from io import StringIO
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from flash.core.classification import LabelsState
+from flash.core.data.callback import BaseDataFetcher
+from flash.core.data.data_module import DataModule
+from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources
+from flash.core.data.process import Deserializer, Postprocess, Preprocess
+from flash.core.utilities.imports import _PANDAS_AVAILABLE
+from flash.tabular.classification.utils import (
+ _compute_normalization,
+ _generate_codes,
+ _pre_transform,
+ _to_cat_vars_numpy,
+ _to_num_vars_numpy,
+)
+
+if _PANDAS_AVAILABLE:
+ import pandas as pd
+ from pandas.core.frame import DataFrame
+else:
+ DataFrame = object
+
+
+class TabularDataFrameDataSource(DataSource[DataFrame]):
+ def __init__(
+ self,
+ cat_cols: Optional[List[str]] = None,
+ num_cols: Optional[List[str]] = None,
+ target_col: Optional[str] = None,
+ mean: Optional[DataFrame] = None,
+ std: Optional[DataFrame] = None,
+ codes: Optional[Dict[str, Any]] = None,
+ target_codes: Optional[Dict[str, Any]] = None,
+ classes: Optional[List[str]] = None,
+ is_regression: bool = True,
+ ):
+ super().__init__()
+
+ self.cat_cols = cat_cols
+ self.num_cols = num_cols
+ self.target_col = target_col
+ self.mean = mean
+ self.std = std
+ self.codes = codes
+ self.target_codes = target_codes
+ self.is_regression = is_regression
+
+ self.set_state(LabelsState(classes))
+ self.num_classes = len(classes)
+
+ def common_load_data(
+ self,
+ df: DataFrame,
+ dataset: Optional[Any] = None,
+ ):
+ # impute_data
+ # compute train dataset stats
+ dfs = _pre_transform(
+ [df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, self.target_codes
+ )
+
+ df = dfs[0]
+
+ if dataset is not None:
+ dataset.num_samples = len(df)
+
+ cat_vars = _to_cat_vars_numpy(df, self.cat_cols)
+ num_vars = _to_num_vars_numpy(df, self.num_cols)
+
+ cat_vars = np.stack(cat_vars, 1) # if len(cat_vars) else np.zeros((len(self), 0))
+ num_vars = np.stack(num_vars, 1) # if len(num_vars) else np.zeros((len(self), 0))
+ return df, cat_vars, num_vars
+
+ def load_data(self, data: DataFrame, dataset: Optional[Any] = None):
+ df, cat_vars, num_vars = self.common_load_data(data, dataset=dataset)
+ target = df[self.target_col].to_numpy().astype(np.float32 if self.is_regression else np.int64)
+ return [
+ {DefaultDataKeys.INPUT: (c, n), DefaultDataKeys.TARGET: t} for c, n, t in zip(cat_vars, num_vars, target)
+ ]
+
+ def predict_load_data(self, data: DataFrame, dataset: Optional[Any] = None):
+ _, cat_vars, num_vars = self.common_load_data(data, dataset=dataset)
+ return [{DefaultDataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)]
+
+
+class TabularCSVDataSource(TabularDataFrameDataSource):
+ def load_data(self, data: str, dataset: Optional[Any] = None):
+ return super().load_data(pd.read_csv(data), dataset=dataset)
+
+ def predict_load_data(self, data: str, dataset: Optional[Any] = None):
+ return super().predict_load_data(pd.read_csv(data), dataset=dataset)
+
+
+class TabularDeserializer(Deserializer):
+ def __init__(
+ self,
+ cat_cols: Optional[List[str]] = None,
+ num_cols: Optional[List[str]] = None,
+ target_col: Optional[str] = None,
+ mean: Optional[DataFrame] = None,
+ std: Optional[DataFrame] = None,
+ codes: Optional[Dict[str, Any]] = None,
+ target_codes: Optional[Dict[str, Any]] = None,
+ classes: Optional[List[str]] = None,
+ is_regression: bool = True,
+ ):
+ super().__init__()
+ self.cat_cols = cat_cols
+ self.num_cols = num_cols
+ self.target_col = target_col
+ self.mean = mean
+ self.std = std
+ self.codes = codes
+ self.target_codes = target_codes
+ self.classes = classes
+ self.is_regression = is_regression
+
+ def deserialize(self, data: str) -> Any:
+ df = pd.read_csv(StringIO(data))
+ df = _pre_transform(
+ [df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, self.target_codes
+ )[0]
+
+ cat_vars = _to_cat_vars_numpy(df, self.cat_cols)
+ num_vars = _to_num_vars_numpy(df, self.num_cols)
+
+ cat_vars = np.stack(cat_vars, 1)
+ num_vars = np.stack(num_vars, 1)
+
+ return [{DefaultDataKeys.INPUT: [c, n]} for c, n in zip(cat_vars, num_vars)]
+
+ @property
+ def example_input(self) -> str:
+ row = {}
+ for cat_col in self.cat_cols:
+ row[cat_col] = ["test"]
+ for num_col in self.num_cols:
+ row[num_col] = [0]
+ return str(DataFrame.from_dict(row).to_csv())
+
+
+class TabularPreprocess(Preprocess):
+ def __init__(
+ self,
+ train_transform: Optional[Dict[str, Callable]] = None,
+ val_transform: Optional[Dict[str, Callable]] = None,
+ test_transform: Optional[Dict[str, Callable]] = None,
+ predict_transform: Optional[Dict[str, Callable]] = None,
+ cat_cols: Optional[List[str]] = None,
+ num_cols: Optional[List[str]] = None,
+ target_col: Optional[str] = None,
+ mean: Optional[DataFrame] = None,
+ std: Optional[DataFrame] = None,
+ codes: Optional[Dict[str, Any]] = None,
+ target_codes: Optional[Dict[str, Any]] = None,
+ classes: Optional[List[str]] = None,
+ is_regression: bool = True,
+ deserializer: Optional[Deserializer] = None,
+ ):
+ classes = classes or []
+
+ self.cat_cols = cat_cols
+ self.num_cols = num_cols
+ self.target_col = target_col
+ self.mean = mean
+ self.std = std
+ self.codes = codes
+ self.target_codes = target_codes
+ self.classes = classes
+ self.is_regression = is_regression
+
+ super().__init__(
+ train_transform=train_transform,
+ val_transform=val_transform,
+ test_transform=test_transform,
+ predict_transform=predict_transform,
+ data_sources={
+ DefaultDataSources.CSV: TabularCSVDataSource(
+ cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression
+ ),
+ "data_frame": TabularDataFrameDataSource(
+ cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression
+ ),
+ },
+ default_data_source=DefaultDataSources.CSV,
+ deserializer=deserializer
+ or TabularDeserializer(
+ cat_cols=cat_cols,
+ num_cols=num_cols,
+ target_col=target_col,
+ mean=mean,
+ std=std,
+ codes=codes,
+ target_codes=target_codes,
+ classes=classes,
+ is_regression=is_regression,
+ ),
+ )
+
+ def get_state_dict(self, strict: bool = False) -> Dict[str, Any]:
+ return {
+ **self.transforms,
+ "cat_cols": self.cat_cols,
+ "num_cols": self.num_cols,
+ "target_col": self.target_col,
+ "mean": self.mean,
+ "std": self.std,
+ "codes": self.codes,
+ "target_codes": self.target_codes,
+ "classes": self.classes,
+ "is_regression": self.is_regression,
+ }
+
+ @classmethod
+ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> "Preprocess":
+ return cls(**state_dict)
+
+
+class TabularPostprocess(Postprocess):
+ def uncollate(self, batch: Any) -> Any:
+ return batch
+
+
+class TabularData(DataModule):
+ """Data module for tabular tasks."""
+
+ preprocess_cls = TabularPreprocess
+ postprocess_cls = TabularPostprocess
+
+ is_regression: bool = False
+
+ @property
+ def codes(self) -> Dict[str, str]:
+ return self._data_source.codes
+
+ @property
+ def num_classes(self) -> int:
+ return self._data_source.num_classes
+
+ @property
+ def cat_cols(self) -> Optional[List[str]]:
+ return self._data_source.cat_cols
+
+ @property
+ def num_cols(self) -> Optional[List[str]]:
+ return self._data_source.num_cols
+
+ @property
+ def num_features(self) -> int:
+ return len(self.cat_cols) + len(self.num_cols)
+
+ @property
+ def embedding_sizes(self) -> list:
+ """Recommended embedding sizes."""
+
+ # https://developers.googleblog.com/2017/11/introducing-tensorflow-feature-columns.html
+ # The following "formula" provides a general rule of thumb about the number of embedding dimensions:
+ # embedding_dimensions = number_of_categories**0.25
+ num_classes = [len(self.codes[cat]) for cat in self.cat_cols]
+ emb_dims = [max(int(n ** 0.25), 16) for n in num_classes]
+ return list(zip(num_classes, emb_dims))
+
+ @staticmethod
+ def _sanetize_cols(cat_cols: Optional[Union[str, List[str]]], num_cols: Optional[Union[str, List[str]]]):
+ if cat_cols is None and num_cols is None:
+ raise RuntimeError("Both `cat_cols` and `num_cols` are None!")
+
+ return cat_cols or [], num_cols or []
+
+ @classmethod
+ def compute_state(
+ cls,
+ train_data_frame: DataFrame,
+ val_data_frame: Optional[DataFrame],
+ test_data_frame: Optional[DataFrame],
+ predict_data_frame: Optional[DataFrame],
+ target_fields: str,
+ numerical_fields: List[str],
+ categorical_fields: List[str],
+ ) -> Tuple[float, float, List[str], Dict[str, Any], Dict[str, Any]]:
+
+ if train_data_frame is None:
+ raise MisconfigurationException(
+ "train_data_frame is required to instantiate the TabularDataFrameDataSource"
+ )
+
+ data_frames = [train_data_frame]
+
+ if val_data_frame is not None:
+ data_frames += [val_data_frame]
+
+ if test_data_frame is not None:
+ data_frames += [test_data_frame]
+
+ if predict_data_frame is not None:
+ data_frames += [predict_data_frame]
+
+ mean, std = _compute_normalization(data_frames[0], numerical_fields)
+
+ classes = list(data_frames[0][target_fields].unique())
+
+ if data_frames[0][target_fields].dtype == object:
+ # if the target_fields is a category, not an int
+ target_codes = _generate_codes(data_frames, [target_fields])
+ else:
+ target_codes = None
+ codes = _generate_codes(data_frames, categorical_fields)
+
+ return mean, std, classes, codes, target_codes
+
+ @classmethod
+ def from_data_frame(
+ cls,
+ categorical_fields: Optional[Union[str, List[str]]],
+ numerical_fields: Optional[Union[str, List[str]]],
+ target_fields: Optional[str] = None,
+ train_data_frame: Optional[DataFrame] = None,
+ val_data_frame: Optional[DataFrame] = None,
+ test_data_frame: Optional[DataFrame] = None,
+ predict_data_frame: Optional[DataFrame] = None,
+ train_transform: Optional[Dict[str, Callable]] = None,
+ val_transform: Optional[Dict[str, Callable]] = None,
+ test_transform: Optional[Dict[str, Callable]] = None,
+ predict_transform: Optional[Dict[str, Callable]] = None,
+ data_fetcher: Optional[BaseDataFetcher] = None,
+ preprocess: Optional[Preprocess] = None,
+ val_split: Optional[float] = None,
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ **preprocess_kwargs: Any,
+ ):
+ """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames.
+
+ Args:
+ categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs.
+ numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs.
+ target_fields: The field or fields (columns) in the CSV file to use for the target.
+ train_data_frame: The pandas ``DataFrame`` containing the training data.
+ val_data_frame: The pandas ``DataFrame`` containing the validation data.
+ test_data_frame: The pandas ``DataFrame`` containing the testing data.
+ predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting.
+ train_transform: The dictionary of transforms to use during training which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ val_transform: The dictionary of transforms to use during validation which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ test_transform: The dictionary of transforms to use during testing which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ predict_transform: The dictionary of transforms to use during predicting which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+ :class:`~flash.core.data.data_module.DataModule`.
+ preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
+ :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+ will be constructed and used.
+ val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
+ if ``preprocess = None``.
+
+ Returns:
+ The constructed data module.
+
+ Examples::
+
+ data_module = TabularData.from_data_frame(
+ "categorical_input",
+ "numerical_input",
+ "target",
+ train_data_frame=train_data,
+ )
+ """
+ categorical_fields, numerical_fields = cls._sanetize_cols(categorical_fields, numerical_fields)
+
+ if not isinstance(categorical_fields, list):
+ categorical_fields = [categorical_fields]
+
+ if not isinstance(numerical_fields, list):
+ numerical_fields = [numerical_fields]
+
+ mean, std, classes, codes, target_codes = cls.compute_state(
+ train_data_frame=train_data_frame,
+ val_data_frame=val_data_frame,
+ test_data_frame=test_data_frame,
+ predict_data_frame=predict_data_frame,
+ target_fields=target_fields,
+ numerical_fields=numerical_fields,
+ categorical_fields=categorical_fields,
+ )
+
+ return cls.from_data_source(
+ "data_frame",
+ train_data_frame,
+ val_data_frame,
+ test_data_frame,
+ predict_data_frame,
+ train_transform=train_transform,
+ val_transform=val_transform,
+ test_transform=test_transform,
+ predict_transform=predict_transform,
+ data_fetcher=data_fetcher,
+ preprocess=preprocess,
+ val_split=val_split,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ cat_cols=categorical_fields,
+ num_cols=numerical_fields,
+ target_col=target_fields,
+ mean=mean,
+ std=std,
+ codes=codes,
+ target_codes=target_codes,
+ classes=classes,
+ is_regression=cls.is_regression,
+ **preprocess_kwargs,
+ )
+
+ @classmethod
+ def from_csv(
+ cls,
+ categorical_fields: Optional[Union[str, List[str]]],
+ numerical_fields: Optional[Union[str, List[str]]],
+ target_fields: Optional[str] = None,
+ train_file: Optional[str] = None,
+ val_file: Optional[str] = None,
+ test_file: Optional[str] = None,
+ predict_file: Optional[str] = None,
+ train_transform: Optional[Dict[str, Callable]] = None,
+ val_transform: Optional[Dict[str, Callable]] = None,
+ test_transform: Optional[Dict[str, Callable]] = None,
+ predict_transform: Optional[Dict[str, Callable]] = None,
+ data_fetcher: Optional[BaseDataFetcher] = None,
+ preprocess: Optional[Preprocess] = None,
+ val_split: Optional[float] = None,
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ **preprocess_kwargs: Any,
+ ) -> "DataModule":
+ """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files.
+
+ Args:
+ categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs.
+ numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs.
+ target_fields: The field or fields (columns) in the CSV file to use for the target.
+ train_file: The CSV file containing the training data.
+ val_file: The CSV file containing the validation data.
+ test_file: The CSV file containing the testing data.
+ predict_file: The CSV file containing the data to use when predicting.
+ train_transform: The dictionary of transforms to use during training which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ val_transform: The dictionary of transforms to use during validation which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ test_transform: The dictionary of transforms to use during testing which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ predict_transform: The dictionary of transforms to use during predicting which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+ :class:`~flash.core.data.data_module.DataModule`.
+ preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
+ :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+ will be constructed and used.
+ val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
+ if ``preprocess = None``.
+
+ Returns:
+ The constructed data module.
+
+ Examples::
+
+ data_module = TabularData.from_csv(
+ "categorical_input",
+ "numerical_input",
+ "target",
+ train_file="train_data.csv",
+ )
+ """
+ return cls.from_data_frame(
+ categorical_fields=categorical_fields,
+ numerical_fields=numerical_fields,
+ target_fields=target_fields,
+ train_data_frame=pd.read_csv(train_file) if train_file is not None else None,
+ val_data_frame=pd.read_csv(val_file) if val_file is not None else None,
+ test_data_frame=pd.read_csv(test_file) if test_file is not None else None,
+ predict_data_frame=pd.read_csv(predict_file) if predict_file is not None else None,
+ preprocess=preprocess,
+ val_split=val_split,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ )
diff --git a/flash/tabular/regression/__init__.py b/flash/tabular/regression/__init__.py
new file mode 100644
index 0000000000..a93e599ff0
--- /dev/null
+++ b/flash/tabular/regression/__init__.py
@@ -0,0 +1 @@
+from flash.tabular.regression.data import TabularRegressionData # noqa: F401
diff --git a/flash/tabular/regression/data.py b/flash/tabular/regression/data.py
new file mode 100644
index 0000000000..52bd44cd77
--- /dev/null
+++ b/flash/tabular/regression/data.py
@@ -0,0 +1,18 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from flash.tabular.data import TabularData
+
+
+class TabularRegressionData(TabularData):
+ is_regression = True
diff --git a/flash/template/classification/backbones.py b/flash/template/classification/backbones.py
index b36f6a398e..7ea8413003 100644
--- a/flash/template/classification/backbones.py
+++ b/flash/template/classification/backbones.py
@@ -21,21 +21,27 @@
@TEMPLATE_BACKBONES(name="mlp-128", namespace="template/classification")
def load_mlp_128(num_features, **_):
"""A simple MLP backbone with 128 hidden units."""
- return nn.Sequential(
- nn.Linear(num_features, 128),
- nn.ReLU(True),
- nn.BatchNorm1d(128),
- ), 128
+ return (
+ nn.Sequential(
+ nn.Linear(num_features, 128),
+ nn.ReLU(True),
+ nn.BatchNorm1d(128),
+ ),
+ 128,
+ )
@TEMPLATE_BACKBONES(name="mlp-128-256", namespace="template/classification")
def load_mlp_128_256(num_features, **_):
"""An two layer MLP backbone with 128 and 256 hidden units respectively."""
- return nn.Sequential(
- nn.Linear(num_features, 128),
- nn.ReLU(True),
- nn.BatchNorm1d(128),
- nn.Linear(128, 256),
- nn.ReLU(True),
- nn.BatchNorm1d(256),
- ), 256
+ return (
+ nn.Sequential(
+ nn.Linear(num_features, 128),
+ nn.ReLU(True),
+ nn.BatchNorm1d(128),
+ nn.Linear(128, 256),
+ nn.ReLU(True),
+ nn.BatchNorm1d(256),
+ ),
+ 256,
+ )
diff --git a/flash/template/classification/data.py b/flash/template/classification/data.py
index 2624f1c9f3..f81111bc3c 100644
--- a/flash/template/classification/data.py
+++ b/flash/template/classification/data.py
@@ -33,8 +33,11 @@
class TemplateNumpyDataSource(NumpyDataSource):
- """An example data source that records ``num_features`` on the dataset. We extend
- :class:`~flash.core.data.data_source.NumpyDataSource` so that we can use ``super().load_data``."""
+ """An example data source that records ``num_features`` on the dataset.
+
+ We extend
+ :class:`~flash.core.data.data_source.NumpyDataSource` so that we can use ``super().load_data``.
+ """
def load_data(self, data: Tuple[np.ndarray, Sequence[Any]], dataset: Any) -> Sequence[Mapping[str, Any]]:
"""Sets the ``num_features`` attribute and calls ``super().load_data``.
@@ -109,16 +112,18 @@ def __init__(
)
def get_state_dict(self) -> Dict[str, Any]:
- """For serialization, you have control over what to save with the ``get_state_dict`` method. It's usually a good
- idea to save the transforms. So we just return them here. If you had any other attributes you wanted to save,
- this is where you would return them.
+ """For serialization, you have control over what to save with the ``get_state_dict`` method.
+
+ It's usually a good idea to save the transforms. So we just return them here. If you had any other attributes
+ you wanted to save, this is where you would return them.
"""
return self.transforms
@classmethod
def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
- """This methods gets whatever we returned from ``get_state_dict`` as an input. Now we re-create the class with
- the transforms we saved.
+ """This methods gets whatever we returned from ``get_state_dict`` as an input.
+
+ Now we re-create the class with the transforms we saved.
"""
return cls(**state_dict)
@@ -147,8 +152,10 @@ def default_transforms(self) -> Optional[Dict[str, Callable]]:
class TemplateData(DataModule):
"""Creating our :class:`~flash.core.data.data_module.DataModule` is as easy as setting the ``preprocess_cls``
- attribute. We get the ``from_numpy`` method for free as we've configured a ``DefaultDataSources.NUMPY`` data source.
- We'll also add a ``from_sklearn`` method so that we can use our ``TemplateSKLearnDataSource. Finally, we define the
+ attribute.
+
+ We get the ``from_numpy`` method for free as we've configured a ``DefaultDataSources.NUMPY`` data source. We'll also
+ add a ``from_sklearn`` method so that we can use our ``TemplateSKLearnDataSource. Finally, we define the
``num_features`` property for convenience.
"""
@@ -232,13 +239,17 @@ def num_features(self) -> Optional[int]:
@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
- """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher`` method."""
+ """We can, *optionally*, provide a data visualization callback using the ``configure_data_fetcher``
+ method."""
return TemplateVisualization(*args, **kwargs)
class TemplateVisualization(BaseVisualization):
- """The ``TemplateVisualization`` class is a :class:`~flash.core.data.callbacks.BaseVisualization` that just prints
- the data. If you want to provide a visualization with your task, you can override these hooks."""
+ """The ``TemplateVisualization`` class is a :class:`~flash.core.data.callbacks.BaseVisualization` that just
+ prints the data.
+
+ If you want to provide a visualization with your task, you can override these hooks.
+ """
def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
print(samples)
diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py
index e52faf1274..e330fafdc8 100644
--- a/flash/template/classification/model.py
+++ b/flash/template/classification/model.py
@@ -26,8 +26,8 @@
class TemplateSKLearnClassifier(ClassificationTask):
- """The ``TemplateSKLearnClassifier`` is a :class:`~flash.core.classification.ClassificationTask` that classifies
- tabular data from scikit-learn.
+ """The ``TemplateSKLearnClassifier`` is a :class:`~flash.core.classification.ClassificationTask` that
+ classifies tabular data from scikit-learn.
Args:
num_features: The number of features (elements) in the input data.
@@ -112,9 +112,9 @@ def test_step(self, batch: Any, batch_idx: int) -> Any:
return super().test_step(batch, batch_idx)
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
- """For the predict step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` key from
- the input and forward it to the :meth:`~flash.core.model.Task.predict_step`."""
- batch = (batch[DefaultDataKeys.INPUT])
+ """For the predict step, we just extract the :attr:`~flash.core.data.data_source.DefaultDataKeys.INPUT` key
+ from the input and forward it to the :meth:`~flash.core.model.Task.predict_step`."""
+ batch = batch[DefaultDataKeys.INPUT]
return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
def forward(self, x) -> torch.Tensor:
diff --git a/flash/text/__init__.py b/flash/text/__init__.py
index 8ac71bdfb5..23786d11f3 100644
--- a/flash/text/__init__.py
+++ b/flash/text/__init__.py
@@ -1,5 +1,7 @@
from flash.text.classification import TextClassificationData, TextClassifier # noqa: F401
from flash.text.seq2seq import ( # noqa: F401
+ QuestionAnsweringData,
+ QuestionAnsweringTask,
Seq2SeqData,
Seq2SeqTask,
SummarizationData,
diff --git a/flash/text/classification/cli.py b/flash/text/classification/cli.py
new file mode 100644
index 0000000000..42499bb53f
--- /dev/null
+++ b/flash/text/classification/cli.py
@@ -0,0 +1,81 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from flash.core.data.utils import download_data
+from flash.core.utilities.flash_cli import FlashCLI
+from flash.text import TextClassificationData, TextClassifier
+
+__all__ = ["text_classification"]
+
+
+def from_imdb(
+ backbone: str = "prajjwal1/bert-medium",
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ **preprocess_kwargs,
+) -> TextClassificationData:
+ """Downloads and loads the IMDB sentiment classification data set."""
+ download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")
+ return TextClassificationData.from_csv(
+ "review",
+ "sentiment",
+ train_file="data/imdb/train.csv",
+ val_file="data/imdb/valid.csv",
+ backbone=backbone,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **preprocess_kwargs,
+ )
+
+
+def from_toxic(
+ backbone: str = "unitary/toxic-bert",
+ val_split: float = 0.1,
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ **preprocess_kwargs,
+) -> TextClassificationData:
+ """Downloads and loads the Jigsaw toxic comments data set."""
+ download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data")
+ return TextClassificationData.from_csv(
+ "comment_text",
+ ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
+ train_file="data/jigsaw_toxic_comments/train.csv",
+ backbone=backbone,
+ val_split=val_split,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **preprocess_kwargs,
+ )
+
+
+def text_classification():
+ """Classify text."""
+ cli = FlashCLI(
+ TextClassifier,
+ TextClassificationData,
+ default_datamodule_builder=from_imdb,
+ additional_datamodule_builders=[from_toxic],
+ default_arguments={
+ "trainer.max_epochs": 3,
+ },
+ datamodule_attributes={"num_classes", "multi_label", "backbone"},
+ )
+
+ cli.trainer.save_checkpoint("text_classification_model.pt")
+
+
+if __name__ == "__main__":
+ text_classification()
diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py
index 826ab87e3d..c7b130543d 100644
--- a/flash/text/classification/data.py
+++ b/flash/text/classification/data.py
@@ -22,7 +22,7 @@
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DataSource, DefaultDataSources, LabelsState
from flash.core.data.process import Deserializer, Postprocess, Preprocess
-from flash.core.utilities.imports import _requires_extras, _TEXT_AVAILABLE
+from flash.core.utilities.imports import _TEXT_AVAILABLE, requires_extras
if _TEXT_AVAILABLE:
from datasets import DatasetDict, load_dataset
@@ -31,9 +31,9 @@
from flash.core.data.data_source import LabelStudioTextDataSource
-class TextDeserializer(Deserializer):
- @_requires_extras("text")
+class TextDeserializer(Deserializer):
+ @requires_extras("text")
def __init__(self, backbone: str, max_length: int, use_fast: bool = True):
super().__init__()
self.backbone = backbone
@@ -58,8 +58,7 @@ def __setstate__(self, state):
class TextDataSource(DataSource):
-
- @_requires_extras("text")
+ @requires_extras("text")
def __init__(self, backbone: str, max_length: int = 128):
super().__init__()
@@ -93,7 +92,6 @@ def __setstate__(self, state):
class TextFileDataSource(TextDataSource):
-
def __init__(self, filetype: str, backbone: str, max_length: int = 128):
super().__init__(backbone, max_length=max_length)
@@ -111,7 +109,10 @@ def load_data(
dataset: Optional[Any] = None,
columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
) -> Union[Sequence[Mapping[str, Any]]]:
- file, input, target = data
+ if self.filetype == "json":
+ file, input, target, field = data
+ else:
+ file, input, target = data
data_files = {}
@@ -121,21 +122,38 @@ def load_data(
# FLASH_TESTING is set in the CI to run faster.
if flash._IS_TESTING and not torch.cuda.is_available():
try:
- dataset_dict = DatasetDict({
- stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
- })
+ if self.filetype == "json" and field is not None:
+ dataset_dict = DatasetDict(
+ {
+ stage: load_dataset(
+ self.filetype, data_files=data_files, split=[f"{stage}[:20]"], field=field
+ )[0]
+ }
+ )
+ else:
+ dataset_dict = DatasetDict(
+ {stage: load_dataset(self.filetype, data_files=data_files, split=[f"{stage}[:20]"])[0]}
+ )
except Exception:
- dataset_dict = load_dataset(self.filetype, data_files=data_files)
+ if self.filetype == "json" and field is not None:
+ dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field)
+ else:
+ dataset_dict = load_dataset(self.filetype, data_files=data_files)
else:
- dataset_dict = load_dataset(self.filetype, data_files=data_files)
+ if self.filetype == "json" and field is not None:
+ dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field)
+ else:
+ dataset_dict = load_dataset(self.filetype, data_files=data_files)
if not self.predicting:
if isinstance(target, List):
# multi-target
+ dataset.multi_label = True
dataset_dict = dataset_dict.map(partial(self._multilabel_target, target))
dataset.num_classes = len(target)
self.set_state(LabelsState(target))
else:
+ dataset.multi_label = False
if self.training:
labels = list(sorted(list(set(dataset_dict[stage][target]))))
dataset.num_classes = len(labels)
@@ -172,7 +190,6 @@ def __setstate__(self, state):
class TextCSVDataSource(TextFileDataSource):
-
def __init__(self, backbone: str, max_length: int = 128):
super().__init__("csv", backbone, max_length=max_length)
@@ -187,7 +204,6 @@ def __setstate__(self, state):
class TextJSONDataSource(TextFileDataSource):
-
def __init__(self, backbone: str, max_length: int = 128):
super().__init__("json", backbone, max_length=max_length)
@@ -202,7 +218,6 @@ def __setstate__(self, state):
class TextSentencesDataSource(TextDataSource):
-
def __init__(self, backbone: str, max_length: int = 128):
super().__init__(backbone, max_length=max_length)
@@ -214,7 +229,12 @@ def load_data(
if isinstance(data, str):
data = [data]
- return [self._tokenize_fn(s, ) for s in data]
+ return [
+ self._tokenize_fn(
+ s,
+ )
+ for s in data
+ ]
def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
@@ -227,8 +247,7 @@ def __setstate__(self, state):
class TextClassificationPreprocess(Preprocess):
-
- @_requires_extras("text")
+ @requires_extras("text")
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
@@ -250,7 +269,9 @@ def __init__(
DefaultDataSources.CSV: TextCSVDataSource(self.backbone, max_length=max_length),
DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length),
"sentences": TextSentencesDataSource(self.backbone, max_length=max_length),
- DefaultDataSources.LABELSTUDIO: LabelStudioTextDataSource(backbone=self.backbone, max_length=max_length)
+ DefaultDataSources.LABELSTUDIO: LabelStudioTextDataSource(
+ backbone=self.backbone, max_length=max_length
+ ),
},
default_data_source="sentences",
deserializer=TextDeserializer(backbone, max_length),
@@ -275,14 +296,13 @@ def per_batch_transform(self, batch: Any) -> Any:
return batch
def collate(self, samples: Any) -> Tensor:
- """Override to convert a set of samples to a batch"""
+ """Override to convert a set of samples to a batch."""
if isinstance(samples, dict):
samples = [samples]
return default_data_collator(samples)
class TextClassificationPostprocess(Postprocess):
-
def per_batch_transform(self, batch: Any) -> Any:
if isinstance(batch, SequenceClassifierOutput):
batch = batch.logits
@@ -290,7 +310,11 @@ def per_batch_transform(self, batch: Any) -> Any:
class TextClassificationData(DataModule):
- """Data Module for text classification tasks"""
+ """Data Module for text classification tasks."""
preprocess_cls = TextClassificationPreprocess
postprocess_cls = TextClassificationPostprocess
+
+ @property
+ def backbone(self) -> Optional[str]:
+ return getattr(self.preprocess, "backbone", None)
diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py
index e1da47be55..cf339153a0 100644
--- a/flash/text/classification/model.py
+++ b/flash/text/classification/model.py
@@ -16,15 +16,17 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
import torch
-from torchmetrics import Accuracy, F1, Metric
+from pytorch_lightning import Callback
+from torchmetrics import Metric
from flash.core.classification import ClassificationTask, Labels
from flash.core.data.process import Serializer
from flash.core.utilities.imports import _TEXT_AVAILABLE
+from flash.text.ort_callback import ORTCallback
if _TEXT_AVAILABLE:
- from transformers import BertForSequenceClassification
- from transformers.modeling_outputs import SequenceClassifierOutput
+ from transformers import AutoModelForSequenceClassification
+ from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput
class TextClassifier(ClassificationTask):
@@ -43,6 +45,7 @@ class TextClassifier(ClassificationTask):
learning_rate: Learning rate to use for training, defaults to `1e-3`
multi_label: Whether the targets are multi-label or not.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
+ enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""
required_extras: str = "text"
@@ -57,6 +60,7 @@ def __init__(
learning_rate: float = 1e-2,
multi_label: bool = False,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+ enable_ort: bool = False,
):
self.save_hyperparameters()
@@ -67,49 +71,53 @@ def __init__(
os.environ["PYTHONWARNINGS"] = "ignore"
super().__init__(
+ num_classes=num_classes,
model=None,
loss_fn=loss_fn,
optimizer=optimizer,
- metrics=metrics or (F1(num_classes) if multi_label else Accuracy()),
+ metrics=metrics,
learning_rate=learning_rate,
multi_label=multi_label,
serializer=serializer or Labels(multi_label=multi_label),
)
- self.model = BertForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)
-
+ self.enable_ort = enable_ort
+ self.model = AutoModelForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)
self.save_hyperparameters()
@property
def backbone(self):
- # see huggingface's BertForSequenceClassification
- return self.model.bert
+ return self.model.base_model
def forward(self, batch: Dict[str, torch.Tensor]):
return self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None))
def to_loss_format(self, x) -> torch.Tensor:
- if isinstance(x, SequenceClassifierOutput):
+ if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)):
x = x.logits
return super().to_loss_format(x)
def to_metrics_format(self, x) -> torch.Tensor:
- if isinstance(x, SequenceClassifierOutput):
+ if isinstance(x, (SequenceClassifierOutput, Seq2SeqSequenceClassifierOutput)):
x = x.logits
return super().to_metrics_format(x)
- def step(self, batch, batch_idx) -> dict:
+ def step(self, batch, batch_idx, metrics) -> dict:
target = batch.pop("labels")
batch = (batch, target)
- return super().step(batch, batch_idx)
+ return super().step(batch, batch_idx, metrics)
def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return self(batch)
def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
- """
- This function is used only for debugging usage with CI
- """
+ """This function is used only for debugging usage with CI."""
if self.hparams.multi_label:
- assert history[-1]["val_f1"] > 0.45
+ assert history[-1]["val_f1"] > 0.40, history[-1]["val_f1"]
else:
- assert history[-1]["val_accuracy"] > 0.73
+ assert history[-1]["val_accuracy"] > 0.70, history[-1]["val_accuracy"]
+
+ def configure_callbacks(self) -> List[Callback]:
+ callbacks = super().configure_callbacks() or []
+ if self.enable_ort:
+ callbacks.append(ORTCallback())
+ return callbacks
diff --git a/flash/text/ort_callback.py b/flash/text/ort_callback.py
new file mode 100644
index 0000000000..53b5bdf197
--- /dev/null
+++ b/flash/text/ort_callback.py
@@ -0,0 +1,51 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE
+
+if _TORCH_ORT_AVAILABLE:
+ from torch_ort import ORTModule
+
+
+class ORTCallback(Callback):
+ """Enables Torch ORT: Accelerate PyTorch models with ONNX Runtime.
+
+ Wraps a model with the ORT wrapper, lazily converting your module into an ONNX export, to optimize for
+ training and inference.
+
+ Usage:
+
+ # via Transformer Tasks
+ model = TextClassifier(backbone="facebook/bart-large", num_classes=datamodule.num_classes, enable_ort=True)
+
+ # or via the trainer
+ trainer = flash.Trainer(callbacks=ORTCallback())
+ """
+
+ def __init__(self):
+ if not _TORCH_ORT_AVAILABLE:
+ raise MisconfigurationException(
+ "Torch ORT is required to use ORT. See here for installation: https://github.com/pytorch/ort"
+ )
+
+ def on_before_accelerator_backend_setup(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ if not hasattr(pl_module, "model"):
+ raise MisconfigurationException(
+ "Torch ORT requires to wrap a single model that defines a forward function "
+ "assigned as `model` inside the `LightningModule`."
+ )
+ if not isinstance(pl_module.model, ORTModule):
+ pl_module.model = ORTModule(pl_module.model)
diff --git a/flash/text/seq2seq/__init__.py b/flash/text/seq2seq/__init__.py
index 1c30bc9d85..88adc2ab65 100644
--- a/flash/text/seq2seq/__init__.py
+++ b/flash/text/seq2seq/__init__.py
@@ -1,3 +1,4 @@
from flash.text.seq2seq.core import Seq2SeqData, Seq2SeqFreezeEmbeddings, Seq2SeqTask # noqa: F401
+from flash.text.seq2seq.question_answering import QuestionAnsweringData, QuestionAnsweringTask # noqa: F401
from flash.text.seq2seq.summarization import SummarizationData, SummarizationTask # noqa: F401
from flash.text.seq2seq.translation import TranslationData, TranslationTask # noqa: F401
diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py
index 1b29d7e2c2..60404a5b66 100644
--- a/flash/text/seq2seq/core/data.py
+++ b/flash/text/seq2seq/core/data.py
@@ -23,7 +23,7 @@
from flash.core.data.data_source import DataSource, DefaultDataSources
from flash.core.data.process import Postprocess, Preprocess
from flash.core.data.properties import ProcessState
-from flash.core.utilities.imports import _requires_extras, _TEXT_AVAILABLE
+from flash.core.utilities.imports import _TEXT_AVAILABLE, requires_extras
from flash.text.classification.data import TextDeserializer
if _TEXT_AVAILABLE:
@@ -33,14 +33,13 @@
class Seq2SeqDataSource(DataSource):
-
- @_requires_extras("text")
+ @requires_extras("text")
def __init__(
self,
backbone: str,
max_source_length: int = 128,
max_target_length: int = 128,
- padding: Union[str, bool] = 'max_length'
+ padding: Union[str, bool] = "max_length",
):
super().__init__()
@@ -82,23 +81,25 @@ def __setstate__(self, state):
class Seq2SeqFileDataSource(Seq2SeqDataSource):
-
def __init__(
self,
filetype: str,
backbone: str,
max_source_length: int = 128,
max_target_length: int = 128,
- padding: Union[str, bool] = 'max_length',
+ padding: Union[str, bool] = "max_length",
):
super().__init__(backbone, max_source_length, max_target_length, padding)
self.filetype = filetype
- def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset':
+ def load_data(self, data: Any, columns: List[str] = None) -> "datasets.Dataset":
if columns is None:
columns = ["input_ids", "attention_mask", "labels"]
- file, input, target = data
+ if self.filetype == "json":
+ file, input, target, field = data
+ else:
+ file, input, target = data
data_files = {}
stage = self._running_stage.value
data_files[stage] = str(file)
@@ -106,19 +107,34 @@ def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset':
# FLASH_TESTING is set in the CI to run faster.
if flash._IS_TESTING:
try:
- dataset_dict = DatasetDict({
- stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0]
- })
+ if self.filetype == "json" and field is not None:
+ dataset_dict = DatasetDict(
+ {
+ stage: load_dataset(
+ self.filetype, data_files=data_files, split=[f"{stage}[:20]"], field=field
+ )[0]
+ }
+ )
+ else:
+ dataset_dict = DatasetDict(
+ {stage: load_dataset(self.filetype, data_files=data_files, split=[f"{stage}[:20]"])[0]}
+ )
except Exception:
- dataset_dict = load_dataset(self.filetype, data_files=data_files)
+ if self.filetype == "json" and field is not None:
+ dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field)
+ else:
+ dataset_dict = load_dataset(self.filetype, data_files=data_files)
else:
- dataset_dict = load_dataset(self.filetype, data_files=data_files)
+ if self.filetype == "json" and field is not None:
+ dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field)
+ else:
+ dataset_dict = load_dataset(self.filetype, data_files=data_files)
dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input, target=target), batched=True)
dataset_dict.set_format(columns=columns)
return dataset_dict[stage]
- def predict_load_data(self, data: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]:
+ def predict_load_data(self, data: Any) -> Union["datasets.Dataset", List[Dict[str, torch.Tensor]]]:
return self.load_data(data, columns=["input_ids", "attention_mask"])
def __getstate__(self): # TODO: Find out why this is being pickled
@@ -132,13 +148,12 @@ def __setstate__(self, state):
class Seq2SeqCSVDataSource(Seq2SeqFileDataSource):
-
def __init__(
self,
backbone: str,
max_source_length: int = 128,
max_target_length: int = 128,
- padding: Union[str, bool] = 'max_length',
+ padding: Union[str, bool] = "max_length",
):
super().__init__(
"csv",
@@ -159,13 +174,12 @@ def __setstate__(self, state):
class Seq2SeqJSONDataSource(Seq2SeqFileDataSource):
-
def __init__(
self,
backbone: str,
max_source_length: int = 128,
max_target_length: int = 128,
- padding: Union[str, bool] = 'max_length',
+ padding: Union[str, bool] = "max_length",
):
super().__init__(
"json",
@@ -186,7 +200,6 @@ def __setstate__(self, state):
class Seq2SeqSentencesDataSource(Seq2SeqDataSource):
-
def load_data(
self,
data: Union[str, List[str]],
@@ -217,8 +230,7 @@ class Seq2SeqBackboneState(ProcessState):
class Seq2SeqPreprocess(Preprocess):
-
- @_requires_extras("text")
+ @requires_extras("text")
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
@@ -228,7 +240,7 @@ def __init__(
backbone: str = "sshleifer/tiny-mbart",
max_source_length: int = 128,
max_target_length: int = 128,
- padding: Union[str, bool] = 'max_length'
+ padding: Union[str, bool] = "max_length",
):
self.backbone = backbone
self.max_target_length = max_target_length
@@ -261,7 +273,7 @@ def __init__(
),
},
default_data_source="sentences",
- deserializer=TextDeserializer(backbone, max_source_length)
+ deserializer=TextDeserializer(backbone, max_source_length),
)
self.set_state(Seq2SeqBackboneState(self.backbone))
@@ -280,13 +292,12 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool):
return cls(**state_dict)
def collate(self, samples: Any) -> Tensor:
- """Override to convert a set of samples to a batch"""
+ """Override to convert a set of samples to a batch."""
return default_data_collator(samples)
class Seq2SeqPostprocess(Postprocess):
-
- @_requires_extras("text")
+ @requires_extras("text")
def __init__(self):
super().__init__()
diff --git a/flash/text/seq2seq/core/finetuning.py b/flash/text/seq2seq/core/finetuning.py
index 6d3ea3e512..f75ab65a54 100644
--- a/flash/text/seq2seq/core/finetuning.py
+++ b/flash/text/seq2seq/core/finetuning.py
@@ -17,9 +17,7 @@
class Seq2SeqFreezeEmbeddings(FlashBaseFinetuning):
- """
- Freezes the embedding layers during Seq2Seq training.
- """
+ """Freezes the embedding layers during Seq2Seq training."""
def __init__(self, model_type: str, train_bn: bool = True):
super().__init__("", train_bn)
diff --git a/flash/text/seq2seq/summarization/metric.py b/flash/text/seq2seq/core/metrics.py
similarity index 51%
rename from flash/text/seq2seq/summarization/metric.py
rename to flash/text/seq2seq/core/metrics.py
index 1e7e7dd3f0..a99c113122 100644
--- a/flash/text/seq2seq/summarization/metric.py
+++ b/flash/text/seq2seq/core/metrics.py
@@ -11,14 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+# referenced from
+# Library Name: torchtext
+# Authors: torchtext authors and @sluks
+# Date: 2020-07-18
+# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
+from collections import Counter
from typing import Dict, List, Tuple
import numpy as np
+import torch
from torch import tensor
from torchmetrics import Metric
-from flash.core.utilities.imports import _requires_extras, _TEXT_AVAILABLE
-from flash.text.seq2seq.summarization.utils import add_newline_to_end_of_each_sentence
+from flash.core.utilities.imports import _TEXT_AVAILABLE, requires_extras
+from flash.text.seq2seq.core.utils import add_newline_to_end_of_each_sentence
if _TEXT_AVAILABLE:
from rouge_score import rouge_scorer
@@ -27,9 +34,103 @@
AggregateScore, Score, BootstrapAggregator = None, None, object
-class RougeMetric(Metric):
+def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
+ """
+ Counting how many times each word appears in a given text with ngram
+ Args:
+ ngram_input_list: A list of translated text or reference texts
+ n_gram: gram value ranged 1 to 4
+
+ Return:
+ ngram_counter: a collections.Counter object of ngram
"""
- Metric used for automatic summarization. https://www.aclweb.org/anthology/W04-1013/
+
+ ngram_counter = Counter()
+
+ for i in range(1, n_gram + 1):
+ for j in range(len(ngram_input_list) - i + 1):
+ ngram_key = tuple(ngram_input_list[j : (i + j)])
+ ngram_counter[ngram_key] += 1
+
+ return ngram_counter
+
+
+class BLEUScore(Metric):
+ """Calculate BLEU score of machine translated text with one or more references.
+
+ Example:
+ >>> translate_corpus = ['the cat is on the mat'.split()]
+ >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
+ >>> metric = BLEUScore()
+ >>> metric(translate_corpus, reference_corpus)
+ tensor(0.7598)
+ """
+
+ def __init__(self, n_gram: int = 4, smooth: bool = False):
+ """
+ Args:
+ n_gram: Gram value ranged from 1 to 4 (Default 4)
+ smooth: Whether or not to apply smoothing – Lin et al. 2004
+ """
+ super().__init__()
+ self.n_gram = n_gram
+ self.smooth = smooth
+
+ self.add_state("c", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
+ self.add_state("r", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
+ self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum")
+ self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum")
+
+ def compute(self):
+
+ trans_len = self.c.clone().detach()
+ ref_len = self.r.clone().detach()
+
+ if min(self.numerator) == 0.0:
+ return tensor(0.0, device=self.r.device)
+
+ if self.smooth:
+ precision_scores = (self.numerator + 1.0) / (self.denominator + 1.0)
+ else:
+ precision_scores = self.numerator / self.denominator
+
+ log_precision_scores = tensor([1.0 / self.n_gram] * self.n_gram, device=self.r.device) * torch.log(
+ precision_scores
+ )
+ geometric_mean = torch.exp(torch.sum(log_precision_scores))
+ brevity_penalty = tensor(1.0, device=self.r.device) if self.c > self.r else torch.exp(1 - (ref_len / trans_len))
+ bleu = brevity_penalty * geometric_mean
+ return bleu
+
+ def update(self, translate_corpus, reference_corpus) -> None:
+ """
+ Actual metric computation
+ Args:
+ translate_corpus: An iterable of machine translated corpus
+ reference_corpus: An iterable of iterables of reference corpus
+ """
+ for (translation, references) in zip(translate_corpus, reference_corpus):
+ self.c += len(translation)
+ ref_len_list = [len(ref) for ref in references]
+ ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
+ self.r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
+ translation_counter = _count_ngram(translation, self.n_gram)
+ reference_counter = Counter()
+
+ for ref in references:
+ reference_counter |= _count_ngram(ref, self.n_gram)
+
+ ngram_counter_clip = translation_counter & reference_counter
+
+ for counter_clip in ngram_counter_clip:
+ self.numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
+
+ for counter in translation_counter:
+ self.denominator[len(counter) - 1] += translation_counter[counter]
+
+
+class RougeMetric(Metric):
+ """Metric used for automatic summarization. https://www.aclweb.org/anthology/W04-1013/
Example:
@@ -52,7 +153,7 @@ class RougeMetric(Metric):
'rougeLsum_recall': 0.25}
"""
- @_requires_extras("text")
+ @requires_extras("text")
def __init__(
self,
rouge_newline_sep: bool = False,
@@ -102,13 +203,11 @@ def __hash__(self):
class RougeBatchAggregator(BootstrapAggregator):
- """
- Aggregates rouge scores and provides confidence intervals.
- """
+ """Aggregates rouge scores and provides confidence intervals."""
def aggregate(self):
- """
- Override function to wrap the final results in `Score` objects.
+ """Override function to wrap the final results in `Score` objects.
+
This is due to the scores being replaced with a list of torch tensors.
"""
result = {}
@@ -118,7 +217,7 @@ def aggregate(self):
# Percentiles are returned as (interval, measure).
percentiles = self._bootstrap_resample(score_matrix)
# Extract the three intervals (low, mid, high).
- intervals = tuple((Score(*percentiles[j, :]) for j in range(3)))
+ intervals = tuple(Score(*percentiles[j, :]) for j in range(3))
result[score_type] = AggregateScore(low=intervals[0], mid=intervals[1], high=intervals[2])
return result
diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py
index d965c084ae..d79ca18a78 100644
--- a/flash/text/seq2seq/core/model.py
+++ b/flash/text/seq2seq/core/model.py
@@ -16,6 +16,7 @@
from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union
import torch
+from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_info
from torch import Tensor
from torchmetrics import Metric
@@ -23,6 +24,7 @@
from flash.core.finetuning import FlashBaseFinetuning
from flash.core.model import Task
from flash.core.utilities.imports import _TEXT_AVAILABLE
+from flash.text.ort_callback import ORTCallback
from flash.text.seq2seq.core.finetuning import Seq2SeqFreezeEmbeddings
if _TEXT_AVAILABLE:
@@ -40,7 +42,7 @@ def _pad_tensors_to_max_len(model_cfg, tensor, max_length):
)
padded_tensor = pad_token_id * torch.ones((tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device)
- padded_tensor[:, :tensor.shape[-1]] = tensor
+ padded_tensor[:, : tensor.shape[-1]] = tensor
return padded_tensor
@@ -54,19 +56,21 @@ class Seq2SeqTask(Task):
learning_rate: Learning rate to use for training, defaults to `3e-4`
val_target_max_length: Maximum length of targets in validation. Defaults to `128`
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
+ enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""
required_extras: str = "text"
def __init__(
self,
- backbone: str = 't5-small',
+ backbone: str = "t5-small",
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
metrics: Union[Metric, Callable, Mapping, Sequence, None] = None,
learning_rate: float = 5e-5,
val_target_max_length: Optional[int] = None,
num_beams: Optional[int] = None,
+ enable_ort: bool = False,
):
os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
# disable HF thousand warnings
@@ -75,6 +79,7 @@ def __init__(
os.environ["PYTHONWARNINGS"] = "ignore"
super().__init__(loss_fn=loss_fn, optimizer=optimizer, metrics=metrics, learning_rate=learning_rate)
self.model = AutoModelForSeq2SeqLM.from_pretrained(backbone)
+ self.enable_ort = enable_ort
self.val_target_max_length = val_target_max_length
self.num_beams = num_beams
self._initialize_model_specific_parameters()
@@ -83,7 +88,7 @@ def forward(self, x: Any) -> Any:
max_length = self.val_target_max_length if self.val_target_max_length else self.model.config.max_length
num_beams = self.num_beams if self.num_beams else self.model.config.num_beams
generated_tokens = self.model.generate(
- input_ids=x['input_ids'], attention_mask=x['attention_mask'], max_length=max_length, num_beams=num_beams
+ input_ids=x["input_ids"], attention_mask=x["attention_mask"], max_length=max_length, num_beams=num_beams
)
# in case the batch is shorter than max length, the output should be padded
if generated_tokens.shape[-1] < max_length:
@@ -113,9 +118,7 @@ def compute_metrics(self, generated_tokens, batch, prefix):
@property
def task(self) -> Optional[str]:
- """
- Override to define AutoConfig task specific parameters stored within the model.
- """
+ """Override to define AutoConfig task specific parameters stored within the model."""
return
def _initialize_model_specific_parameters(self):
@@ -127,7 +130,7 @@ def _initialize_model_specific_parameters(self):
self.model.config.update(pars)
@property
- def tokenizer(self) -> 'PreTrainedTokenizerBase':
+ def tokenizer(self) -> "PreTrainedTokenizerBase":
return self.data_pipeline.data_source.tokenizer
def tokenize_labels(self, labels: Tensor) -> List[str]:
@@ -136,3 +139,9 @@ def tokenize_labels(self, labels: Tensor) -> List[str]:
def configure_finetune_callback(self) -> List[FlashBaseFinetuning]:
return [Seq2SeqFreezeEmbeddings(self.model.config.model_type, train_bn=True)]
+
+ def configure_callbacks(self) -> List[Callback]:
+ callbacks = super().configure_callbacks() or []
+ if self.enable_ort:
+ callbacks.append(ORTCallback())
+ return callbacks
diff --git a/flash/text/seq2seq/summarization/utils.py b/flash/text/seq2seq/core/utils.py
similarity index 97%
rename from flash/text/seq2seq/summarization/utils.py
rename to flash/text/seq2seq/core/utils.py
index 02647f7264..e48248754c 100644
--- a/flash/text/seq2seq/summarization/utils.py
+++ b/flash/text/seq2seq/core/utils.py
@@ -16,8 +16,9 @@
from pytorch_lightning.utilities import _module_available
nltk = None
-if _module_available('nltk'):
+if _module_available("nltk"):
import nltk
+
nltk.download("punkt", quiet=True)
diff --git a/flash/text/seq2seq/question_answering/__init__.py b/flash/text/seq2seq/question_answering/__init__.py
new file mode 100644
index 0000000000..83330ccb4b
--- /dev/null
+++ b/flash/text/seq2seq/question_answering/__init__.py
@@ -0,0 +1,2 @@
+from flash.text.seq2seq.question_answering.data import QuestionAnsweringData # noqa: F401
+from flash.text.seq2seq.question_answering.model import QuestionAnsweringTask # noqa: F401
diff --git a/flash/text/seq2seq/question_answering/data.py b/flash/text/seq2seq/question_answering/data.py
new file mode 100644
index 0000000000..ad3f028f20
--- /dev/null
+++ b/flash/text/seq2seq/question_answering/data.py
@@ -0,0 +1,46 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Dict, Optional, Union
+
+from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess
+
+
+class QuestionAnsweringPreprocess(Seq2SeqPreprocess):
+ def __init__(
+ self,
+ train_transform: Optional[Dict[str, Callable]] = None,
+ val_transform: Optional[Dict[str, Callable]] = None,
+ test_transform: Optional[Dict[str, Callable]] = None,
+ predict_transform: Optional[Dict[str, Callable]] = None,
+ backbone: str = "t5-small",
+ max_source_length: int = 128,
+ max_target_length: int = 128,
+ padding: Union[str, bool] = "max_length",
+ ):
+ super().__init__(
+ train_transform=train_transform,
+ val_transform=val_transform,
+ test_transform=test_transform,
+ predict_transform=predict_transform,
+ backbone=backbone,
+ max_source_length=max_source_length,
+ max_target_length=max_target_length,
+ padding=padding,
+ )
+
+
+class QuestionAnsweringData(Seq2SeqData):
+
+ preprocess_cls = QuestionAnsweringPreprocess
+ postprocess_cls = Seq2SeqPostprocess
diff --git a/flash/text/seq2seq/question_answering/model.py b/flash/text/seq2seq/question_answering/model.py
new file mode 100644
index 0000000000..0ebec8aed3
--- /dev/null
+++ b/flash/text/seq2seq/question_answering/model.py
@@ -0,0 +1,85 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
+
+import torch
+from torchmetrics import Metric
+
+from flash.text.seq2seq.core.metrics import RougeMetric
+from flash.text.seq2seq.core.model import Seq2SeqTask
+
+
+class QuestionAnsweringTask(Seq2SeqTask):
+ """The ``QuestionAnsweringTask`` is a :class:`~flash.Task` for Seq2Seq text question answering. For more
+ details, see `question_answering`.
+
+ You can change the backbone to any question answering model from `HuggingFace/transformers
+ `_ using the ``backbone`` argument.
+
+ .. note:: When changing the backbone, make sure you pass in the same backbone to the :class:`~flash.Task` and the
+ :class:`~flash.core.data.data_module.DataModule` object! Since this is a Seq2Seq task, make sure you use a
+ Seq2Seq model.
+
+ Args:
+ backbone: backbone model to use for the task.
+ loss_fn: Loss function for training.
+ optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
+ metrics: Metrics to compute for training and evaluation. Defauls to calculating the ROUGE metric.
+ Changing this argument currently has no effect.
+ learning_rate: Learning rate to use for training, defaults to `3e-4`
+ val_target_max_length: Maximum length of targets in validation. Defaults to `128`
+ num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
+ use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching.
+ rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation.
+ enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
+ """
+
+ def __init__(
+ self,
+ backbone: str = "t5-small",
+ loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
+ optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
+ metrics: Union[Metric, Callable, Mapping, Sequence, None] = None,
+ learning_rate: float = 1e-5,
+ val_target_max_length: Optional[int] = None,
+ num_beams: Optional[int] = 4,
+ use_stemmer: bool = True,
+ rouge_newline_sep: bool = True,
+ enable_ort: bool = False,
+ ):
+ self.save_hyperparameters()
+ super().__init__(
+ backbone=backbone,
+ loss_fn=loss_fn,
+ optimizer=optimizer,
+ metrics=metrics,
+ learning_rate=learning_rate,
+ val_target_max_length=val_target_max_length,
+ num_beams=num_beams,
+ enable_ort=enable_ort,
+ )
+ self.rouge = RougeMetric(
+ rouge_newline_sep=rouge_newline_sep,
+ use_stemmer=use_stemmer,
+ )
+
+ def compute_metrics(self, generated_tokens: torch.Tensor, batch: Dict, prefix: str) -> None:
+ tgt_lns = self.tokenize_labels(batch["labels"])
+ result = self.rouge(self._postprocess.uncollate(generated_tokens), tgt_lns)
+ self.log_dict(result, on_step=False, on_epoch=True, prog_bar=True)
+
+ @staticmethod
+ def _ci_benchmark_fn(history: List[Dict[str, Any]]):
+ """This function is used only for debugging usage with CI."""
+ assert history[-1]["rouge1_recall"] > 0.2
diff --git a/flash/text/seq2seq/summarization/cli.py b/flash/text/seq2seq/summarization/cli.py
new file mode 100644
index 0000000000..666dd87f40
--- /dev/null
+++ b/flash/text/seq2seq/summarization/cli.py
@@ -0,0 +1,59 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from flash.core.data.utils import download_data
+from flash.core.utilities.flash_cli import FlashCLI
+from flash.text import SummarizationData, SummarizationTask
+
+__all__ = ["summarization"]
+
+
+def from_xsum(
+ backbone: str = "sshleifer/distilbart-xsum-1-1",
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ **preprocess_kwargs,
+) -> SummarizationData:
+ """Downloads and loads the XSum data set."""
+ download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "./data/")
+ return SummarizationData.from_csv(
+ "input",
+ "target",
+ train_file="data/xsum/train.csv",
+ val_file="data/xsum/valid.csv",
+ backbone=backbone,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **preprocess_kwargs,
+ )
+
+
+def summarization():
+ """Summarize text."""
+ cli = FlashCLI(
+ SummarizationTask,
+ SummarizationData,
+ default_datamodule_builder=from_xsum,
+ default_arguments={
+ "trainer.max_epochs": 3,
+ "model.backbone": "sshleifer/distilbart-xsum-1-1",
+ },
+ )
+
+ cli.trainer.save_checkpoint("summarization_model_xsum.pt")
+
+
+if __name__ == "__main__":
+ summarization()
diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py
index c2a29df52c..3797d97f92 100644
--- a/flash/text/seq2seq/summarization/data.py
+++ b/flash/text/seq2seq/summarization/data.py
@@ -17,7 +17,6 @@
class SummarizationPreprocess(Seq2SeqPreprocess):
-
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
@@ -27,7 +26,7 @@ def __init__(
backbone: str = "sshleifer/distilbart-xsum-1-1",
max_source_length: int = 128,
max_target_length: int = 128,
- padding: Union[str, bool] = 'max_length'
+ padding: Union[str, bool] = "max_length",
):
super().__init__(
train_transform=train_transform,
diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py
index d547972f3f..19e812baf1 100644
--- a/flash/text/seq2seq/summarization/model.py
+++ b/flash/text/seq2seq/summarization/model.py
@@ -16,8 +16,8 @@
import torch
from torchmetrics import Metric
+from flash.text.seq2seq.core.metrics import RougeMetric
from flash.text.seq2seq.core.model import Seq2SeqTask
-from flash.text.seq2seq.summarization.metric import RougeMetric
class SummarizationTask(Seq2SeqTask):
@@ -42,6 +42,7 @@ class SummarizationTask(Seq2SeqTask):
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching.
rouge_newline_sep: Add a new line at the beginning of each sentence in Rouge Metric calculation.
+ enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""
def __init__(
@@ -54,7 +55,8 @@ def __init__(
val_target_max_length: Optional[int] = None,
num_beams: Optional[int] = 4,
use_stemmer: bool = True,
- rouge_newline_sep: bool = True
+ rouge_newline_sep: bool = True,
+ enable_ort: bool = False,
):
self.save_hyperparameters()
super().__init__(
@@ -64,7 +66,8 @@ def __init__(
metrics=metrics,
learning_rate=learning_rate,
val_target_max_length=val_target_max_length,
- num_beams=num_beams
+ num_beams=num_beams,
+ enable_ort=enable_ort,
)
self.rouge = RougeMetric(
rouge_newline_sep=rouge_newline_sep,
@@ -82,7 +85,5 @@ def compute_metrics(self, generated_tokens: torch.Tensor, batch: Dict, prefix: s
@staticmethod
def _ci_benchmark_fn(history: List[Dict[str, Any]]):
- """
- This function is used only for debugging usage with CI
- """
+ """This function is used only for debugging usage with CI."""
assert history[-1]["rouge1_recall"] > 0.2
diff --git a/flash/text/seq2seq/translation/cli.py b/flash/text/seq2seq/translation/cli.py
new file mode 100644
index 0000000000..1609cb4de0
--- /dev/null
+++ b/flash/text/seq2seq/translation/cli.py
@@ -0,0 +1,59 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+from flash.core.data.utils import download_data
+from flash.core.utilities.flash_cli import FlashCLI
+from flash.text import TranslationData, TranslationTask
+
+__all__ = ["translation"]
+
+
+def from_wmt_en_ro(
+ backbone: str = "Helsinki-NLP/opus-mt-en-ro",
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ **preprocess_kwargs,
+) -> TranslationData:
+ """Downloads and loads the WMT EN RO data set."""
+ download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "./data")
+ return TranslationData.from_csv(
+ "input",
+ "target",
+ train_file="data/wmt_en_ro/train.csv",
+ val_file="data/wmt_en_ro/valid.csv",
+ backbone=backbone,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **preprocess_kwargs,
+ )
+
+
+def translation():
+ """Translate text."""
+ cli = FlashCLI(
+ TranslationTask,
+ TranslationData,
+ default_datamodule_builder=from_wmt_en_ro,
+ default_arguments={
+ "trainer.max_epochs": 3,
+ "model.backbone": "Helsinki-NLP/opus-mt-en-ro",
+ },
+ )
+
+ cli.trainer.save_checkpoint("translation_model_en_ro.pt")
+
+
+if __name__ == "__main__":
+ translation()
diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py
index 0b9e7a3ce7..5485be1003 100644
--- a/flash/text/seq2seq/translation/data.py
+++ b/flash/text/seq2seq/translation/data.py
@@ -17,7 +17,6 @@
class TranslationPreprocess(Seq2SeqPreprocess):
-
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
@@ -27,7 +26,7 @@ def __init__(
backbone: str = "t5-small",
max_source_length: int = 128,
max_target_length: int = 128,
- padding: Union[str, bool] = 'max_length'
+ padding: Union[str, bool] = "max_length",
):
super().__init__(
train_transform=train_transform,
diff --git a/flash/text/seq2seq/translation/metric.py b/flash/text/seq2seq/translation/metric.py
deleted file mode 100644
index bd3e4fe872..0000000000
--- a/flash/text/seq2seq/translation/metric.py
+++ /dev/null
@@ -1,121 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# referenced from
-# Library Name: torchtext
-# Authors: torchtext authors and @sluks
-# Date: 2020-07-18
-# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
-from collections import Counter
-from typing import List
-
-import torch
-from torch import tensor
-from torchmetrics import Metric
-
-
-def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
- """
- Counting how many times each word appears in a given text with ngram
- Args:
- ngram_input_list: A list of translated text or reference texts
- n_gram: gram value ranged 1 to 4
-
- Return:
- ngram_counter: a collections.Counter object of ngram
- """
-
- ngram_counter = Counter()
-
- for i in range(1, n_gram + 1):
- for j in range(len(ngram_input_list) - i + 1):
- ngram_key = tuple(ngram_input_list[j:(i + j)])
- ngram_counter[ngram_key] += 1
-
- return ngram_counter
-
-
-class BLEUScore(Metric):
- """
- Calculate BLEU score of machine translated text with one or more references.
-
- Example:
- >>> translate_corpus = ['the cat is on the mat'.split()]
- >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
- >>> metric = BLEUScore()
- >>> metric(translate_corpus, reference_corpus)
- tensor(0.7598)
- """
-
- def __init__(self, n_gram: int = 4, smooth: bool = False):
- """
- Args:
- n_gram: Gram value ranged from 1 to 4 (Default 4)
- smooth: Whether or not to apply smoothing – Lin et al. 2004
- """
- super().__init__()
- self.n_gram = n_gram
- self.smooth = smooth
-
- self.add_state("c", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
- self.add_state("r", tensor(0, dtype=torch.float), dist_reduce_fx="sum")
- self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum")
- self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum")
-
- def compute(self):
-
- trans_len = self.c.clone().detach()
- ref_len = self.r.clone().detach()
-
- if min(self.numerator) == 0.0:
- return tensor(0.0, device=self.r.device)
-
- if self.smooth:
- precision_scores = (self.numerator + 1.0) / (self.denominator + 1.0)
- else:
- precision_scores = self.numerator / self.denominator
-
- log_precision_scores = tensor([1.0 / self.n_gram] * self.n_gram,
- device=self.r.device) * torch.log(precision_scores)
- geometric_mean = torch.exp(torch.sum(log_precision_scores))
- brevity_penalty = (
- tensor(1.0, device=self.r.device) if self.c > self.r else torch.exp(1 - (ref_len / trans_len))
- )
- bleu = brevity_penalty * geometric_mean
- return bleu
-
- def update(self, translate_corpus, reference_corpus) -> None:
- """
- Actual metric computation
- Args:
- translate_corpus: An iterable of machine translated corpus
- reference_corpus: An iterable of iterables of reference corpus
- """
- for (translation, references) in zip(translate_corpus, reference_corpus):
- self.c += len(translation)
- ref_len_list = [len(ref) for ref in references]
- ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
- self.r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
- translation_counter = _count_ngram(translation, self.n_gram)
- reference_counter = Counter()
-
- for ref in references:
- reference_counter |= _count_ngram(ref, self.n_gram)
-
- ngram_counter_clip = translation_counter & reference_counter
-
- for counter_clip in ngram_counter_clip:
- self.numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
-
- for counter in translation_counter:
- self.denominator[len(counter) - 1] += translation_counter[counter]
diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py
index a9ac0a6a31..c70089e8d6 100644
--- a/flash/text/seq2seq/translation/model.py
+++ b/flash/text/seq2seq/translation/model.py
@@ -16,8 +16,8 @@
import torch
from torchmetrics import Metric
+from flash.text.seq2seq.core.metrics import BLEUScore
from flash.text.seq2seq.core.model import Seq2SeqTask
-from flash.text.seq2seq.translation.metric import BLEUScore
class TranslationTask(Seq2SeqTask):
@@ -42,6 +42,7 @@ class TranslationTask(Seq2SeqTask):
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
n_gram: Maximum n_grams to use in metric calculation. Defaults to `4`
smooth: Apply smoothing in BLEU calculation. Defaults to `True`
+ enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
"""
def __init__(
@@ -55,6 +56,7 @@ def __init__(
num_beams: Optional[int] = 4,
n_gram: bool = 4,
smooth: bool = True,
+ enable_ort: bool = False,
):
self.save_hyperparameters()
super().__init__(
@@ -65,6 +67,7 @@ def __init__(
learning_rate=learning_rate,
val_target_max_length=val_target_max_length,
num_beams=num_beams,
+ enable_ort=enable_ort,
)
self.bleu = BLEUScore(
n_gram=n_gram,
@@ -84,7 +87,5 @@ def compute_metrics(self, generated_tokens, batch, prefix):
@staticmethod
def _ci_benchmark_fn(history: List[Dict[str, Any]]):
- """
- This function is used only for debugging usage with CI
- """
+ """This function is used only for debugging usage with CI."""
assert history[-1]["val_bleu_score"] > 0.6
diff --git a/flash/video/classification/cli.py b/flash/video/classification/cli.py
new file mode 100644
index 0000000000..840386506b
--- /dev/null
+++ b/flash/video/classification/cli.py
@@ -0,0 +1,61 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Optional
+
+from flash.core.data.utils import download_data
+from flash.core.utilities.flash_cli import FlashCLI
+from flash.video import VideoClassificationData, VideoClassifier
+
+__all__ = ["video_classification"]
+
+
+def from_kinetics(
+ clip_sampler: str = "uniform",
+ clip_duration: int = 1,
+ decode_audio: bool = False,
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ **preprocess_kwargs,
+) -> VideoClassificationData:
+ """Downloads and loads the Kinetics data set."""
+ download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip", "./data")
+ return VideoClassificationData.from_folders(
+ train_folder=os.path.join(os.getcwd(), "data/kinetics/train"),
+ val_folder=os.path.join(os.getcwd(), "data/kinetics/val"),
+ clip_sampler=clip_sampler,
+ clip_duration=clip_duration,
+ decode_audio=decode_audio,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **preprocess_kwargs,
+ )
+
+
+def video_classification():
+ """Classify videos."""
+ cli = FlashCLI(
+ VideoClassifier,
+ VideoClassificationData,
+ default_datamodule_builder=from_kinetics,
+ default_arguments={
+ "trainer.max_epochs": 3,
+ },
+ )
+
+ cli.trainer.save_checkpoint("video_classification.pt")
+
+
+if __name__ == "__main__":
+ video_classification()
diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py
index db9cddc6f8..a57de670d7 100644
--- a/flash/video/classification/data.py
+++ b/flash/video/classification/data.py
@@ -25,8 +25,8 @@
DefaultDataSources,
FiftyOneDataSource,
LabelsState,
+ LabelStudioVideoDataSource,
PathsDataSource,
- LabelStudioVideoDataSource
)
from flash.core.data.process import Preprocess
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, lazy_import
@@ -45,21 +45,20 @@
if _PYTORCHVIDEO_AVAILABLE:
from pytorchvideo.data.clip_sampling import ClipSampler, make_clip_sampler
from pytorchvideo.data.encoded_video import EncodedVideo
- from pytorchvideo.data.encoded_video_dataset import EncodedVideoDataset, labeled_encoded_video_dataset
+ from pytorchvideo.data.labeled_video_dataset import labeled_video_dataset, LabeledVideoDataset
from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths
from pytorchvideo.transforms import ApplyTransformToKey, UniformTemporalSubsample
from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip
else:
- ClipSampler, EncodedVideoDataset, EncodedVideo, ApplyTransformToKey = None, None, None, None
+ ClipSampler, LabeledVideoDataset, EncodedVideo, ApplyTransformToKey = None, None, None, None
_PYTORCHVIDEO_DATA = Dict[str, Union[str, torch.Tensor, int, float, List]]
-class BaseVideoClassification(object):
-
+class BaseVideoClassification:
def __init__(
self,
- clip_sampler: 'ClipSampler',
+ clip_sampler: "ClipSampler",
video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
decode_audio: bool = True,
decoder: str = "pyav",
@@ -69,28 +68,31 @@ def __init__(
self.decode_audio = decode_audio
self.decoder = decoder
- def load_data(self, data: str, dataset: Optional[Any] = None) -> 'EncodedVideoDataset':
+ def load_data(self, data: str, dataset: Optional[Any] = None) -> "LabeledVideoDataset":
ds = self._make_encoded_video_dataset(data)
if self.training:
label_to_class_mapping = {p[1]: p[0].split("/")[-2] for p in ds._labeled_videos._paths_and_labels}
self.set_state(LabelsState(label_to_class_mapping))
- dataset.num_classes = len(np.unique([s[1]['label'] for s in ds._labeled_videos]))
+ dataset.num_classes = len(np.unique([s[1]["label"] for s in ds._labeled_videos]))
return ds
+ def load_sample(self, sample):
+ return sample
+
def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
video_path = sample[DefaultDataKeys.INPUT]
sample.update(self._encoded_video_to_dict(EncodedVideo.from_path(video_path)))
sample[DefaultDataKeys.METADATA] = {"filepath": video_path}
return sample
- def _encoded_video_to_dict(self, video) -> Dict[str, Any]:
+ def _encoded_video_to_dict(self, video, annotation: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
(
clip_start,
clip_end,
clip_index,
aug_index,
is_last_clip,
- ) = self.clip_sampler(0.0, video.duration)
+ ) = self.clip_sampler(0.0, video.duration, annotation)
loaded_clip = video.get_clip(clip_start, clip_end)
@@ -111,20 +113,17 @@ def _encoded_video_to_dict(self, video) -> Dict[str, Any]:
"video_index": 0,
"clip_index": clip_index,
"aug_index": aug_index,
- **({
- "audio": audio_samples
- } if audio_samples is not None else {}),
+ **({"audio": audio_samples} if audio_samples is not None else {}),
}
- def _make_encoded_video_dataset(self, data) -> 'EncodedVideoDataset':
+ def _make_encoded_video_dataset(self, data) -> "LabeledVideoDataset":
raise NotImplementedError("Subclass must implement _make_encoded_video_dataset()")
class VideoClassificationPathsDataSource(BaseVideoClassification, PathsDataSource):
-
def __init__(
self,
- clip_sampler: 'ClipSampler',
+ clip_sampler: "ClipSampler",
video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
decode_audio: bool = True,
decoder: str = "pyav",
@@ -140,8 +139,8 @@ def __init__(
extensions=("mp4", "avi"),
)
- def _make_encoded_video_dataset(self, data) -> 'EncodedVideoDataset':
- ds: EncodedVideoDataset = labeled_encoded_video_dataset(
+ def _make_encoded_video_dataset(self, data) -> "LabeledVideoDataset":
+ ds: LabeledVideoDataset = labeled_video_dataset(
pathlib.Path(data),
self.clip_sampler,
video_sampler=self.video_sampler,
@@ -155,10 +154,9 @@ class VideoClassificationFiftyOneDataSource(
BaseVideoClassification,
FiftyOneDataSource,
):
-
def __init__(
self,
- clip_sampler: 'ClipSampler',
+ clip_sampler: "ClipSampler",
video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
decode_audio: bool = True,
decoder: str = "pyav",
@@ -179,7 +177,7 @@ def __init__(
def label_cls(self):
return fol.Classification
- def _make_encoded_video_dataset(self, data: SampleCollection) -> 'EncodedVideoDataset':
+ def _make_encoded_video_dataset(self, data: SampleCollection) -> "LabeledVideoDataset":
classes = self._get_classes(data)
label_to_class_mapping = dict(enumerate(classes))
class_to_label_mapping = {c: lab for lab, c in label_to_class_mapping.items()}
@@ -189,7 +187,7 @@ def _make_encoded_video_dataset(self, data: SampleCollection) -> 'EncodedVideoDa
targets = [class_to_label_mapping[lab] for lab in labels]
labeled_video_paths = LabeledVideoPaths(list(zip(filepaths, targets)))
- ds: EncodedVideoDataset = EncodedVideoDataset(
+ ds: LabeledVideoDataset = LabeledVideoDataset(
labeled_video_paths,
self.clip_sampler,
video_sampler=self.video_sampler,
@@ -200,14 +198,13 @@ def _make_encoded_video_dataset(self, data: SampleCollection) -> 'EncodedVideoDa
class VideoClassificationPreprocess(Preprocess):
-
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
- clip_sampler: Union[str, 'ClipSampler'] = "random",
+ clip_sampler: Union[str, "ClipSampler"] = "random",
clip_duration: float = 2,
clip_sampler_kwargs: Dict[str, Any] = None,
video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
@@ -266,7 +263,7 @@ def __init__(
decode_audio=decode_audio,
decoder=decoder,
**data_source_kwargs,
- )
+ ),
},
default_data_source=DefaultDataSources.FILES,
)
@@ -283,7 +280,7 @@ def get_state_dict(self) -> Dict[str, Any]:
}
@classmethod
- def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> 'VideoClassificationPreprocess':
+ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> "VideoClassificationPreprocess":
return cls(**state_dict)
def default_transforms(self) -> Dict[str, Callable]:
@@ -298,22 +295,26 @@ def default_transforms(self) -> Dict[str, Callable]:
]
return {
- "post_tensor_transform": Compose([
- ApplyTransformToKey(
- key="video",
- transform=Compose([UniformTemporalSubsample(8)] + post_tensor_transform),
- ),
- ]),
- "per_batch_transform_on_device": Compose([
- ApplyTransformToKey(
- key="video",
- transform=K.VideoSequential(
- K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])),
- data_format="BCTHW",
- same_on_frame=False
- )
- ),
- ]),
+ "post_tensor_transform": Compose(
+ [
+ ApplyTransformToKey(
+ key="video",
+ transform=Compose([UniformTemporalSubsample(8)] + post_tensor_transform),
+ ),
+ ]
+ ),
+ "per_batch_transform_on_device": Compose(
+ [
+ ApplyTransformToKey(
+ key="video",
+ transform=K.VideoSequential(
+ K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])),
+ data_format="BCTHW",
+ same_on_frame=False,
+ ),
+ ),
+ ]
+ ),
}
diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py
index 8e05069a2b..9345b7b19b 100644
--- a/flash/video/classification/model.py
+++ b/flash/video/classification/model.py
@@ -31,20 +31,21 @@
from flash.core.data.process import Serializer
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE
+from flash.core.utilities.providers import _PYTORCHVIDEO
_VIDEO_CLASSIFIER_BACKBONES = FlashRegistry("backbones")
if _PYTORCHVIDEO_AVAILABLE:
from pytorchvideo.models import hub
+
for fn_name in dir(hub):
if "__" not in fn_name:
fn = getattr(hub, fn_name)
if isinstance(fn, FunctionType):
- _VIDEO_CLASSIFIER_BACKBONES(fn=fn)
+ _VIDEO_CLASSIFIER_BACKBONES(fn=fn, providers=_PYTORCHVIDEO)
class VideoClassifierFinetuning(BaseFinetuning):
-
def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: int = 1):
super().__init__()
self.num_layers = num_layers
@@ -52,7 +53,7 @@ def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: i
self.unfreeze_epoch = unfreeze_epoch
def freeze_before_training(self, pl_module: LightningModule) -> None:
- self.freeze(modules=list(pl_module.backbone.children())[:-self.num_layers], train_bn=self.train_bn)
+ self.freeze(modules=list(pl_module.backbone.children())[: -self.num_layers], train_bn=self.train_bn)
def finetune_function(
self,
@@ -64,7 +65,7 @@ def finetune_function(
if epoch != self.unfreeze_epoch:
return
self.unfreeze_and_add_param_group(
- modules=list(pl_module.backbone.children())[-self.num_layers:],
+ modules=list(pl_module.backbone.children())[-self.num_layers :],
optimizer=optimizer,
train_bn=self.train_bn,
)
@@ -94,7 +95,7 @@ class VideoClassifier(ClassificationTask):
def __init__(
self,
num_classes: int,
- backbone: Union[str, nn.Module] = "slow_r50",
+ backbone: Union[str, nn.Module] = "x3d_xs",
backbone_kwargs: Optional[Dict] = None,
pretrained: bool = True,
loss_fn: Callable = F.cross_entropy,
@@ -110,7 +111,7 @@ def __init__(
optimizer=optimizer,
metrics=metrics,
learning_rate=learning_rate,
- serializer=serializer or Labels()
+ serializer=serializer or Labels(),
)
self.save_hyperparameters()
@@ -146,8 +147,8 @@ def on_train_epoch_start(self) -> None:
encoded_dataset._video_sampler.set_epoch(self.trainer.current_epoch)
super().on_train_epoch_start()
- def step(self, batch: Any, batch_idx: int) -> Any:
- return super().step((batch["video"], batch["label"]), batch_idx)
+ def step(self, batch: Any, batch_idx: int, metrics) -> Any:
+ return super().step((batch["video"], batch["label"]), batch_idx, metrics)
def forward(self, x: Any) -> Any:
x = self.backbone(x)
@@ -165,7 +166,5 @@ def configure_finetune_callback(self) -> List[Callback]:
@staticmethod
def _ci_benchmark_fn(history: List[Dict[str, Any]]):
- """
- This function is used only for debugging usage with CI
- """
- assert history[-1]["val_accuracy"] > 0.80
+ """This function is used only for debugging usage with CI."""
+ assert history[-1]["val_accuracy"] > 0.70
diff --git a/flash_examples/audio_classification.py b/flash_examples/audio_classification.py
new file mode 100644
index 0000000000..6dea056c18
--- /dev/null
+++ b/flash_examples/audio_classification.py
@@ -0,0 +1,49 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+import flash
+from flash.audio import AudioClassificationData
+from flash.core.data.utils import download_data
+from flash.core.finetuning import FreezeUnfreeze
+from flash.image import ImageClassifier
+
+# 1. Create the DataModule
+download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data")
+
+datamodule = AudioClassificationData.from_folders(
+ train_folder="data/urban8k_images/train",
+ val_folder="data/urban8k_images/val",
+ spectrogram_size=(64, 64),
+)
+
+# 2. Build the model.
+model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
+trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))
+
+# 4. Predict what's on few images! air_conditioner, children_playing, siren e.t.c
+predictions = model.predict(
+ [
+ "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg",
+ "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg",
+ "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg",
+ ]
+)
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("audio_classification_model.pt")
diff --git a/flash_examples/custom_task.py b/flash_examples/custom_task.py
index 2ab29f6526..15cc3b9fc7 100644
--- a/flash_examples/custom_task.py
+++ b/flash_examples/custom_task.py
@@ -35,7 +35,6 @@
class RegressionTask(flash.Task):
-
def __init__(self, num_inputs, learning_rate=0.2, metrics=None):
# what kind of model do we want?
model = nn.Linear(num_inputs, 1)
@@ -85,7 +84,6 @@ def forward(self, x):
class NumpyDataSource(DataSource[Tuple[ND, ND]]):
-
def load_data(self, data: Tuple[ND, ND], dataset: Optional[Any] = None) -> List[Dict[str, Any]]:
if self.training:
dataset.num_inputs = data[0].shape[1]
@@ -97,7 +95,6 @@ def predict_load_data(data: ND) -> List[Dict[str, Any]]:
class NumpyPreprocess(Preprocess):
-
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
@@ -160,16 +157,20 @@ class NumpyDataModule(flash.DataModule):
datamodule = NumpyDataModule.from_numpy(x, y)
model = RegressionTask(num_inputs=datamodule.train_dataset.num_inputs)
-trainer = flash.Trainer(max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False)
+trainer = flash.Trainer(
+ max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False, gpus=torch.cuda.device_count()
+)
trainer.fit(model, datamodule=datamodule)
-predict_data = np.array([
- [0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403],
- [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072],
- [0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072],
- [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384],
- [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094],
-])
+predict_data = np.array(
+ [
+ [0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403],
+ [-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072],
+ [0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072],
+ [-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384],
+ [-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094],
+ ]
+)
predictions = model.predict(predict_data)
print(predictions)
diff --git a/flash_examples/graph_classification.py b/flash_examples/graph_classification.py
new file mode 100644
index 0000000000..4519f70c33
--- /dev/null
+++ b/flash_examples/graph_classification.py
@@ -0,0 +1,44 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+import flash
+from flash.core.utilities.imports import example_requires
+from flash.graph import GraphClassificationData, GraphClassifier
+
+example_requires("graph")
+
+from torch_geometric.datasets import TUDataset # noqa: E402
+
+# 1. Create the DataModule
+dataset = TUDataset(root="data", name="KKI")
+
+datamodule = GraphClassificationData.from_datasets(
+ train_dataset=dataset,
+ val_split=0.1,
+)
+
+# 2. Build the task
+model = GraphClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and fit the model
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
+trainer.fit(model, datamodule=datamodule)
+
+# 4. Classify some graphs!
+predictions = model.predict(dataset[:3])
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("graph_classification.pt")
diff --git a/flash_examples/image_classification.py b/flash_examples/image_classification.py
index a675938c57..3b9413a629 100644
--- a/flash_examples/image_classification.py
+++ b/flash_examples/image_classification.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import torch
+
import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
@@ -27,15 +29,17 @@
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
# 3. Create the trainer and finetune the model
-trainer = flash.Trainer(max_epochs=3)
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Predict what's on a few images! ants or bees?
-predictions = model.predict([
- "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
- "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
- "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
-])
+predictions = model.predict(
+ [
+ "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
+ "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
+ "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
+ ]
+)
print(predictions)
# 5. Save the model!
diff --git a/flash_examples/image_classification_multi_label.py b/flash_examples/image_classification_multi_label.py
index 00e86d7f0b..947446a9c0 100644
--- a/flash_examples/image_classification_multi_label.py
+++ b/flash_examples/image_classification_multi_label.py
@@ -11,10 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os.path as osp
-from typing import List, Tuple
-
-import pandas as pd
+import torch
import flash
from flash.core.data.utils import download_data
@@ -24,37 +21,31 @@
# Data set from the paper “Movie Genre Classification based on Poster Images with Deep Neural Networks”.
# More info here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/
download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip")
-genres = ["Action", "Romance", "Crime", "Thriller", "Adventure"]
-
-
-def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], List[List[int]]]:
- metadata = pd.read_csv(osp.join(root, data, "metadata.csv"))
- return ([osp.join(root, data, row['Id'] + ".jpg") for _, row in metadata.iterrows()],
- [[int(row[genre]) for genre in genres] for _, row in metadata.iterrows()])
-
-train_files, train_targets = load_data('train')
-datamodule = ImageClassificationData.from_files(
- train_files=train_files,
- train_targets=train_targets,
- val_split=0.1,
+datamodule = ImageClassificationData.from_csv(
+ "Id",
+ ["Action", "Romance", "Crime", "Thriller", "Adventure"],
+ train_file="data/movie_posters/train/metadata.csv",
+ val_file="data/movie_posters/val/metadata.csv",
image_size=(128, 128),
)
# 2. Build the task
-model = ImageClassifier(backbone="resnet18", num_classes=len(genres), multi_label=True)
+model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, multi_label=True)
# 3. Create the trainer and finetune the model
-trainer = flash.Trainer(max_epochs=3)
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Predict the genre of a few movies!
-predictions = model.predict([
- "data/movie_posters/predict/tt0085318.jpg",
- "data/movie_posters/predict/tt0089461.jpg",
- "data/movie_posters/predict/tt0097179.jpg",
-])
+predictions = model.predict(
+ [
+ "data/movie_posters/predict/tt0085318.jpg",
+ "data/movie_posters/predict/tt0089461.jpg",
+ "data/movie_posters/predict/tt0097179.jpg",
+ ]
+)
print(predictions)
-# 7. Save the model!
+# 5. Save the model!
trainer.save_checkpoint("image_classification_multi_label_model.pt")
diff --git a/flash_examples/image_embedder.py b/flash_examples/image_embedder.py
index cd786472c3..5a4de94fcf 100644
--- a/flash_examples/image_embedder.py
+++ b/flash_examples/image_embedder.py
@@ -22,3 +22,4 @@
# 3. Generate an embedding from an image path.
embeddings = embedder.predict(["data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg"])
+print(embeddings)
diff --git a/flash_examples/instance_segmentation.py b/flash_examples/instance_segmentation.py
new file mode 100644
index 0000000000..3fdc4e8a4b
--- /dev/null
+++ b/flash_examples/instance_segmentation.py
@@ -0,0 +1,55 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+
+import flash
+from flash.core.utilities.imports import example_requires
+from flash.image import InstanceSegmentation, InstanceSegmentationData
+
+example_requires("image")
+
+import icedata # noqa: E402
+
+# 1. Create the DataModule
+data_dir = icedata.pets.load_data()
+
+datamodule = InstanceSegmentationData.from_folders(
+ train_folder=data_dir,
+ val_split=0.1,
+ parser=partial(icedata.pets.parser, mask=True),
+)
+
+# 2. Build the task
+model = InstanceSegmentation(
+ head="mask_rcnn",
+ backbone="resnet18_fpn",
+ num_classes=datamodule.num_classes,
+)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=1)
+trainer.finetune(model, datamodule=datamodule, strategy="freeze")
+
+# 4. Detect objects in a few images!
+predictions = model.predict(
+ [
+ str(data_dir / "images/yorkshire_terrier_9.jpg"),
+ str(data_dir / "images/english_cocker_spaniel_1.jpg"),
+ str(data_dir / "images/scottish_terrier_1.jpg"),
+ ]
+)
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("instance_segmentation_model.pt")
diff --git a/flash_examples/integrations/fiftyone/image_classification.py b/flash_examples/integrations/fiftyone/image_classification.py
index ebf40df56c..b1f5fb56cf 100644
--- a/flash_examples/integrations/fiftyone/image_classification.py
+++ b/flash_examples/integrations/fiftyone/image_classification.py
@@ -13,6 +13,8 @@
# limitations under the License.
from itertools import chain
+import torch
+
import flash
from flash.core.classification import FiftyOneLabels, Labels
from flash.core.data.utils import download_data
@@ -39,6 +41,7 @@
)
trainer = flash.Trainer(
max_epochs=1,
+ gpus=torch.cuda.device_count(),
limit_train_batches=1,
limit_val_batches=1,
)
diff --git a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py
index 5ec81bdf6f..9ef31609d5 100644
--- a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py
+++ b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py
@@ -14,6 +14,7 @@
from itertools import chain
import fiftyone as fo
+import torch
import flash
from flash.core.classification import FiftyOneLabels, Labels
@@ -53,6 +54,7 @@
)
trainer = flash.Trainer(
max_epochs=1,
+ gpus=torch.cuda.device_count(),
limit_train_batches=1,
limit_val_batches=1,
)
diff --git a/flash_examples/integrations/fiftyone/image_embedding.py b/flash_examples/integrations/fiftyone/image_embedding.py
index b9d1651ceb..019bd9cffe 100644
--- a/flash_examples/integrations/fiftyone/image_embedding.py
+++ b/flash_examples/integrations/fiftyone/image_embedding.py
@@ -28,7 +28,7 @@
)
# 3 Load model
-embedder = ImageEmbedder(backbone="resnet101", embedding_dim=128)
+embedder = ImageEmbedder(backbone="resnet101")
# 4 Generate embeddings
filepaths = dataset.values("filepath")
diff --git a/flash_examples/integrations/labelstudio/image_classification.py b/flash_examples/integrations/labelstudio/image_classification.py
index 1ef31c6b2e..12d9df7952 100644
--- a/flash_examples/integrations/labelstudio/image_classification.py
+++ b/flash_examples/integrations/labelstudio/image_classification.py
@@ -9,8 +9,8 @@
# 1. Load export data
datamodule = ImageClassificationData.from_labelstudio(
- export_json='data/project.json',
- data_folder='data/upload/',
+ export_json="data/project.json",
+ data_folder="data/upload/",
val_split=0.8,
)
diff --git a/flash_examples/integrations/labelstudio/text_classification.py b/flash_examples/integrations/labelstudio/text_classification.py
index 4d4f260991..88a315d535 100644
--- a/flash_examples/integrations/labelstudio/text_classification.py
+++ b/flash_examples/integrations/labelstudio/text_classification.py
@@ -8,7 +8,7 @@
backbone = "prajjwal1/bert-medium"
datamodule = TextClassificationData.from_labelstudio(
- export_json='data/project.json',
+ export_json="data/project.json",
val_split=0.8,
backbone=backbone,
)
diff --git a/flash_examples/integrations/labelstudio/video_classification.py b/flash_examples/integrations/labelstudio/video_classification.py
index 26315fe4a9..af4c206590 100644
--- a/flash_examples/integrations/labelstudio/video_classification.py
+++ b/flash_examples/integrations/labelstudio/video_classification.py
@@ -9,8 +9,8 @@
# 1. Load export data
datamodule = VideoClassificationData.from_labelstudio(
- export_json='data/project.json',
- data_folder='data/upload/',
+ export_json="data/project.json",
+ data_folder="data/upload/",
val_split=0.8,
clip_sampler="uniform",
clip_duration=1,
diff --git a/flash_examples/keypoint_detection.py b/flash_examples/keypoint_detection.py
new file mode 100644
index 0000000000..b1fa29cc02
--- /dev/null
+++ b/flash_examples/keypoint_detection.py
@@ -0,0 +1,54 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flash
+from flash.core.utilities.imports import example_requires
+from flash.image import KeypointDetectionData, KeypointDetector
+
+example_requires("image")
+
+import icedata # noqa: E402
+
+# 1. Create the DataModule
+data_dir = icedata.biwi.load_data()
+
+datamodule = KeypointDetectionData.from_folders(
+ train_folder=data_dir,
+ val_split=0.1,
+ parser=icedata.biwi.parser,
+)
+
+# 2. Build the task
+model = KeypointDetector(
+ head="keypoint_rcnn",
+ backbone="resnet18_fpn",
+ num_keypoints=1,
+ num_classes=datamodule.num_classes,
+)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=1)
+trainer.finetune(model, datamodule=datamodule, strategy="freeze")
+
+# 4. Detect objects in a few images!
+predictions = model.predict(
+ [
+ str(data_dir / "biwi_sample/images/0.jpg"),
+ str(data_dir / "biwi_sample/images/1.jpg"),
+ str(data_dir / "biwi_sample/images/10.jpg"),
+ ]
+)
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("keypoint_detection_model.pt")
diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py
index 4f488e1e11..1a5dddbce9 100644
--- a/flash_examples/object_detection.py
+++ b/flash_examples/object_detection.py
@@ -17,27 +17,30 @@
# 1. Create the DataModule
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
-download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "finetuning/data/")
+download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")
datamodule = ObjectDetectionData.from_coco(
train_folder="data/coco128/images/train2017/",
- train_ann_file="finetuning/data/coco128/annotations/instances_train2017.json",
+ train_ann_file="data/coco128/annotations/instances_train2017.json",
val_split=0.1,
+ image_size=128,
)
# 2. Build the task
-model = ObjectDetector(model="retinanet", num_classes=datamodule.num_classes)
+model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=128)
# 3. Create the trainer and finetune the model
-trainer = flash.Trainer(max_epochs=3)
-trainer.finetune(model, datamodule=datamodule)
+trainer = flash.Trainer(max_epochs=1)
+trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Detect objects in a few images!
-predictions = model.predict([
- "data/coco128/images/train2017/000000000625.jpg",
- "data/coco128/images/train2017/000000000626.jpg",
- "data/coco128/images/train2017/000000000629.jpg",
-])
+predictions = model.predict(
+ [
+ "data/coco128/images/train2017/000000000625.jpg",
+ "data/coco128/images/train2017/000000000626.jpg",
+ "data/coco128/images/train2017/000000000629.jpg",
+ ]
+)
print(predictions)
# 5. Save the model!
diff --git a/flash_examples/pointcloud_detection.py b/flash_examples/pointcloud_detection.py
new file mode 100644
index 0000000000..ff29265355
--- /dev/null
+++ b/flash_examples/pointcloud_detection.py
@@ -0,0 +1,47 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+import flash
+from flash.core.data.utils import download_data
+from flash.pointcloud import PointCloudObjectDetector, PointCloudObjectDetectorData
+
+# 1. Create the DataModule
+# Dataset Credit: http://www.semantic-kitti.org/
+download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/")
+
+datamodule = PointCloudObjectDetectorData.from_folders(
+ train_folder="data/KITTI_Tiny/Kitti/train",
+ val_folder="data/KITTI_Tiny/Kitti/val",
+)
+
+# 2. Build the task
+model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(
+ max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
+)
+trainer.fit(model, datamodule)
+
+# 4. Predict what's within a few PointClouds?
+predictions = model.predict(
+ [
+ "data/KITTI_Tiny/Kitti/predict/scans/000000.bin",
+ "data/KITTI_Tiny/Kitti/predict/scans/000001.bin",
+ ]
+)
+
+# 5. Save the model!
+trainer.save_checkpoint("pointcloud_detection_model.pt")
diff --git a/flash_examples/pointcloud_segmentation.py b/flash_examples/pointcloud_segmentation.py
new file mode 100644
index 0000000000..7d1a0eb538
--- /dev/null
+++ b/flash_examples/pointcloud_segmentation.py
@@ -0,0 +1,47 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+import flash
+from flash.core.data.utils import download_data
+from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData
+
+# 1. Create the DataModule
+# Dataset Credit: http://www.semantic-kitti.org/
+download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/")
+
+datamodule = PointCloudSegmentationData.from_folders(
+ train_folder="data/SemanticKittiTiny/train",
+ val_folder="data/SemanticKittiTiny/val",
+)
+
+# 2. Build the task
+model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(
+ max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
+)
+trainer.fit(model, datamodule)
+
+# 4. Predict what's within a few PointClouds?
+predictions = model.predict(
+ [
+ "data/SemanticKittiTiny/predict/000000.bin",
+ "data/SemanticKittiTiny/predict/000001.bin",
+ ]
+)
+
+# 5. Save the model!
+trainer.save_checkpoint("pointcloud_segmentation_model.pt")
diff --git a/flash_examples/semantic_segmentation.py b/flash_examples/semantic_segmentation.py
index 83aa617c62..a3800f2508 100644
--- a/flash_examples/semantic_segmentation.py
+++ b/flash_examples/semantic_segmentation.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import torch
+
import flash
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData
@@ -20,34 +22,36 @@
# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
download_data(
"https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
- "./data"
+ "./data",
)
datamodule = SemanticSegmentationData.from_folders(
train_folder="data/CameraRGB",
train_target_folder="data/CameraSeg",
val_split=0.1,
- image_size=(200, 200),
+ image_size=(256, 256),
num_classes=21,
)
# 2. Build the task
model = SemanticSegmentation(
- backbone="mobilenet_v3_large",
- head="fcn",
+ backbone="mobilenetv3_large_100",
+ head="fpn",
num_classes=datamodule.num_classes,
)
# 3. Create the trainer and finetune the model
-trainer = flash.Trainer(max_epochs=3)
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Segment a few images!
-predictions = model.predict([
- "data/CameraRGB/F61-1.png",
- "data/CameraRGB/F62-1.png",
- "data/CameraRGB/F63-1.png",
-])
+predictions = model.predict(
+ [
+ "data/CameraRGB/F61-1.png",
+ "data/CameraRGB/F62-1.png",
+ "data/CameraRGB/F63-1.png",
+ ]
+)
print(predictions)
# 5. Save the model!
diff --git a/flash_examples/serve/generic/boston_prediction/inference_server.py b/flash_examples/serve/generic/boston_prediction/inference_server.py
index 1e1d958e9f..acd1735ae9 100644
--- a/flash_examples/serve/generic/boston_prediction/inference_server.py
+++ b/flash_examples/serve/generic/boston_prediction/inference_server.py
@@ -35,7 +35,6 @@
class PricePrediction(ModelComponent):
-
def __init__(self, model): # skipcq: PYL-W0621
self.model = model
diff --git a/flash_examples/serve/generic/detection/inference.py b/flash_examples/serve/generic/detection/inference.py
index 0971fb380c..813359a6dc 100644
--- a/flash_examples/serve/generic/detection/inference.py
+++ b/flash_examples/serve/generic/detection/inference.py
@@ -18,16 +18,12 @@
class ObjectDetection(ModelComponent):
-
def __init__(self, model):
self.model = model
@expose(
inputs={"img": Image()},
- outputs={
- "boxes": Repeated(BBox()),
- "labels": Repeated(Label("classes.txt"))
- },
+ outputs={"boxes": Repeated(BBox()), "labels": Repeated(Label("classes.txt"))},
)
def detect(self, img):
img = img.permute(0, 3, 2, 1).float() / 255
diff --git a/flash/image/detection/finetuning.py b/flash_examples/serve/speech_recognition/client.py
similarity index 55%
rename from flash/image/detection/finetuning.py
rename to flash_examples/serve/speech_recognition/client.py
index c1ca20072d..c855a37204 100644
--- a/flash/image/detection/finetuning.py
+++ b/flash_examples/serve/speech_recognition/client.py
@@ -11,19 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import pytorch_lightning as pl
+import base64
+from pathlib import Path
-from flash.core.finetuning import FlashBaseFinetuning
+import requests
+import flash
-class ObjectDetectionFineTuning(FlashBaseFinetuning):
- """
- Freezes the backbone during Detector training.
- """
+with (Path(flash.ASSETS_ROOT) / "example.wav").open("rb") as f:
+ audio_str = base64.b64encode(f.read()).decode("UTF-8")
- def __init__(self, train_bn: bool = True) -> None:
- super().__init__(train_bn=train_bn)
+body = {"session": "UUID", "payload": {"inputs": {"data": audio_str}}}
+resp = requests.post("http://127.0.0.1:8000/predict", json=body)
- def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
- model = pl_module.model
- self.freeze(modules=model.backbone, train_bn=self.train_bn)
+print(resp.json())
diff --git a/flash_examples/serve/speech_recognition/inference_server.py b/flash_examples/serve/speech_recognition/inference_server.py
new file mode 100644
index 0000000000..bbc4479624
--- /dev/null
+++ b/flash_examples/serve/speech_recognition/inference_server.py
@@ -0,0 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from flash.audio import SpeechRecognition
+
+model = SpeechRecognition.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/speech_recognition_model.pt")
+model.serve()
diff --git a/flash_examples/serve/tabular_classification/inference_server.py b/flash_examples/serve/tabular_classification/inference_server.py
index f6aac866e2..4b58b8f691 100644
--- a/flash_examples/serve/tabular_classification/inference_server.py
+++ b/flash_examples/serve/tabular_classification/inference_server.py
@@ -15,5 +15,5 @@
from flash.tabular import TabularClassifier
model = TabularClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/tabular_classification_model.pt")
-model.serializer = Labels(['Did not survive', 'Survived'])
+model.serializer = Labels(["Did not survive", "Survived"])
model.serve()
diff --git a/flash_examples/speech_recognition.py b/flash_examples/speech_recognition.py
new file mode 100644
index 0000000000..1672dbe1fe
--- /dev/null
+++ b/flash_examples/speech_recognition.py
@@ -0,0 +1,42 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+import flash
+from flash.audio import SpeechRecognition, SpeechRecognitionData
+from flash.core.data.utils import download_data
+
+# # 1. Create the DataModule
+download_data("https://pl-flash-data.s3.amazonaws.com/timit_data.zip", "./data")
+
+datamodule = SpeechRecognitionData.from_json(
+ input_fields="file",
+ target_fields="text",
+ train_file="data/timit/train.json",
+ test_file="data/timit/test.json",
+)
+
+# 2. Build the task
+model = SpeechRecognition(backbone="facebook/wav2vec2-base-960h")
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
+trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")
+
+# 4. Predict on audio files!
+predictions = model.predict(["data/timit/example.wav"])
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("speech_recognition_model.pt")
diff --git a/flash_examples/style_transfer.py b/flash_examples/style_transfer.py
index 37500e9358..607f5ad0f6 100644
--- a/flash_examples/style_transfer.py
+++ b/flash_examples/style_transfer.py
@@ -13,6 +13,8 @@
# limitations under the License.
import os
+import torch
+
import flash
from flash.core.data.utils import download_data
from flash.image.style_transfer import StyleTransfer, StyleTransferData
@@ -26,15 +28,17 @@
model = StyleTransfer(os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"))
# 3. Create the trainer and train the model
-trainer = flash.Trainer(max_epochs=3)
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)
# 4. Apply style transfer to a few images!
-predictions = model.predict([
- "data/coco128/images/train2017/000000000625.jpg",
- "data/coco128/images/train2017/000000000626.jpg",
- "data/coco128/images/train2017/000000000629.jpg",
-])
+predictions = model.predict(
+ [
+ "data/coco128/images/train2017/000000000625.jpg",
+ "data/coco128/images/train2017/000000000626.jpg",
+ "data/coco128/images/train2017/000000000629.jpg",
+ ]
+)
print(predictions)
# 5. Save the model!
diff --git a/flash_examples/tabular_classification.py b/flash_examples/tabular_classification.py
index fa3a2cc23e..ef80723afa 100644
--- a/flash_examples/tabular_classification.py
+++ b/flash_examples/tabular_classification.py
@@ -11,14 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import torch
+
import flash
from flash.core.data.utils import download_data
-from flash.tabular import TabularClassifier, TabularData
+from flash.tabular import TabularClassificationData, TabularClassifier
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")
-datamodule = TabularData.from_csv(
+datamodule = TabularClassificationData.from_csv(
["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
"Fare",
target_fields="Survived",
@@ -30,7 +32,7 @@
model = TabularClassifier.from_data(datamodule)
# 3. Create the trainer and train the model
-trainer = flash.Trainer(max_epochs=3)
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)
# 4. Generate predictions from a CSV
diff --git a/flash_examples/template.py b/flash_examples/template.py
index 66ce579a83..0d8c7016ed 100644
--- a/flash_examples/template.py
+++ b/flash_examples/template.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
+import torch
from sklearn import datasets
import flash
@@ -27,15 +28,17 @@
model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
# 3. Create the trainer and train the model
-trainer = flash.Trainer(max_epochs=3)
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)
# 4. Classify a few examples
-predictions = model.predict([
- np.array([4.9, 3.0, 1.4, 0.2]),
- np.array([6.9, 3.2, 5.7, 2.3]),
- np.array([7.2, 3.0, 5.8, 1.6]),
-])
+predictions = model.predict(
+ [
+ np.array([4.9, 3.0, 1.4, 0.2]),
+ np.array([6.9, 3.2, 5.7, 2.3]),
+ np.array([7.2, 3.0, 5.8, 1.6]),
+ ]
+)
print(predictions)
# 5. Save the model!
diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py
index 1924d408de..3d62dbb0dc 100644
--- a/flash_examples/text_classification.py
+++ b/flash_examples/text_classification.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import torch
+
import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
@@ -30,15 +32,17 @@
model = TextClassifier(backbone="prajjwal1/bert-medium", num_classes=datamodule.num_classes)
# 3. Create the trainer and finetune the model
-trainer = flash.Trainer(max_epochs=3)
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Classify a few sentences! How was the movie?
-predictions = model.predict([
- "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
- "The worst movie in the history of cinema.",
- "I come from Bulgaria where it 's almost impossible to have a tornado.",
-])
+predictions = model.predict(
+ [
+ "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
+ "The worst movie in the history of cinema.",
+ "I come from Bulgaria where it 's almost impossible to have a tornado.",
+ ]
+)
print(predictions)
# 5. Save the model!
diff --git a/flash_examples/text_classification_multi_label.py b/flash_examples/text_classification_multi_label.py
index 57222bf560..72f87b7c81 100644
--- a/flash_examples/text_classification_multi_label.py
+++ b/flash_examples/text_classification_multi_label.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import torch
+
import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier
@@ -36,15 +38,17 @@
)
# 3. Create the trainer and finetune the model
-trainer = flash.Trainer(max_epochs=3)
+trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Generate predictions for a few comments!
-predictions = model.predict([
- "No, he is an arrogant, self serving, immature idiot. Get it right.",
- "U SUCK HANNAH MONTANA",
- "Would you care to vote? Thx.",
-])
+predictions = model.predict(
+ [
+ "No, he is an arrogant, self serving, immature idiot. Get it right.",
+ "U SUCK HANNAH MONTANA",
+ "Would you care to vote? Thx.",
+ ]
+)
print(predictions)
# 5. Save the model!
diff --git a/flash_examples/translation.py b/flash_examples/translation.py
index 2a0d7889f2..fc82bb767a 100644
--- a/flash_examples/translation.py
+++ b/flash_examples/translation.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import torch
+
import flash
from flash.core.data.utils import download_data
from flash.text import TranslationData, TranslationTask
@@ -30,15 +32,17 @@
model = TranslationTask(backbone="Helsinki-NLP/opus-mt-en-ro")
# 3. Create the trainer and finetune the model
-trainer = flash.Trainer(max_epochs=3)
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule)
# 4. Translate something!
-predictions = model.predict([
- "BBC News went to meet one of the project's first graduates.",
- "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
- "Of course, it's still early in the election cycle.",
-])
+predictions = model.predict(
+ [
+ "BBC News went to meet one of the project's first graduates.",
+ "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
+ "Of course, it's still early in the election cycle.",
+ ]
+)
print(predictions)
# 5. Save the model!
diff --git a/flash_examples/video_classification.py b/flash_examples/video_classification.py
index 1ecfd25959..99c7422dcd 100644
--- a/flash_examples/video_classification.py
+++ b/flash_examples/video_classification.py
@@ -13,6 +13,8 @@
# limitations under the License.
import os
+import torch
+
import flash
from flash.core.data.utils import download_data
from flash.video import VideoClassificationData, VideoClassifier
@@ -33,7 +35,7 @@
model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes, pretrained=False)
# 3. Create the trainer and finetune the model
-trainer = flash.Trainer(max_epochs=3)
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Make a prediction
diff --git a/flash_examples/visualizations/pointcloud_detection.py b/flash_examples/visualizations/pointcloud_detection.py
new file mode 100644
index 0000000000..899e30a3aa
--- /dev/null
+++ b/flash_examples/visualizations/pointcloud_detection.py
@@ -0,0 +1,47 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+import flash
+from flash.core.data.utils import download_data
+from flash.pointcloud.detection import launch_app, PointCloudObjectDetector, PointCloudObjectDetectorData
+
+# 1. Create the DataModule
+# Dataset Credit: http://www.semantic-kitti.org/
+download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/")
+
+datamodule = PointCloudObjectDetectorData.from_folders(
+ train_folder="data/KITTI_Tiny/Kitti/train",
+ val_folder="data/KITTI_Tiny/Kitti/val",
+)
+
+# 2. Build the task
+model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(
+ max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
+)
+trainer.fit(model, datamodule)
+
+# 4. Predict what's within a few PointClouds?
+predictions = model.predict(["data/KITTI_Tiny/Kitti/predict/scans/000000.bin"])
+
+# 5. Save the model!
+trainer.save_checkpoint("pointcloud_segmentation_model.pt")
+
+# 6. Optional Visualize
+app = launch_app(datamodule)
+# app.show_train_dataset()
+app.show_predictions(predictions)
diff --git a/flash_examples/visualizations/pointcloud_segmentation.py b/flash_examples/visualizations/pointcloud_segmentation.py
new file mode 100644
index 0000000000..c50ea7b958
--- /dev/null
+++ b/flash_examples/visualizations/pointcloud_segmentation.py
@@ -0,0 +1,52 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+import flash
+from flash.core.data.utils import download_data
+from flash.pointcloud.segmentation import launch_app, PointCloudSegmentation, PointCloudSegmentationData
+
+# 1. Create the DataModule
+# Dataset Credit: http://www.semantic-kitti.org/
+download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/")
+
+datamodule = PointCloudSegmentationData.from_folders(
+ train_folder="data/SemanticKittiTiny/train",
+ val_folder="data/SemanticKittiTiny/val",
+)
+
+# 2. Build the task
+model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(
+ max_epochs=1, limit_train_batches=0, limit_val_batches=0, num_sanity_val_steps=0, gpus=torch.cuda.device_count()
+)
+trainer.fit(model, datamodule)
+
+# 4. Predict what's within a few PointClouds?
+predictions = model.predict(
+ [
+ "data/SemanticKittiTiny/predict/000000.bin",
+ "data/SemanticKittiTiny/predict/000001.bin",
+ ]
+)
+
+# 5. Save the model!
+trainer.save_checkpoint("pointcloud_segmentation_model.pt")
+
+# 6. Optional Visualize
+app = launch_app(datamodule)
+# app.show_train_dataset()
+app.show_predictions(predictions)
diff --git a/pyproject.toml b/pyproject.toml
index cbfacb0aeb..e18a6fbac5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,2 +1,6 @@
[tool.autopep8]
ignore = ["E731"]
+
+
+[tool.black]
+line-length = 120
diff --git a/requirements.txt b/requirements.txt
index 01330917d4..e367ff1793 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,8 @@
-torch>=1.8
+packaging
+torch
torchmetrics
-pytorch-lightning>=1.3.1
+pytorch-lightning>=1.4.0
pyDeprecate
-PyYAML>=5.1
-numpy
pandas<1.3.0
-packaging
-tqdm
+jsonargparse[signatures]>=3.17.0
+click>=7.1.2
diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt
new file mode 100644
index 0000000000..4c198da250
--- /dev/null
+++ b/requirements/datatype_audio.txt
@@ -0,0 +1,4 @@
+torchaudio
+soundfile>=0.10.2
+transformers>=4.5
+datasets>=1.8
diff --git a/requirements/datatype_graph.txt b/requirements/datatype_graph.txt
new file mode 100644
index 0000000000..9109e2167f
--- /dev/null
+++ b/requirements/datatype_graph.txt
@@ -0,0 +1,3 @@
+torch-scatter
+torch-sparse
+torch-geometric
diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt
index ab91d28d57..aa9fe14c15 100644
--- a/requirements/datatype_image.txt
+++ b/requirements/datatype_image.txt
@@ -3,7 +3,8 @@ timm>=0.4.5
lightning-bolts>=0.3.3
Pillow>=7.2
kornia>=0.5.1,<0.5.4
-matplotlib
-pycocotools>=2.0.2 ; python_version >= "3.7"
-fiftyone
-pystiche>=0.7.2
+pystiche==1.*
+segmentation-models-pytorch
+icevision>=0.8
+icedata
+effdet
diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt
new file mode 100644
index 0000000000..f61e3f9c25
--- /dev/null
+++ b/requirements/datatype_image_extras.txt
@@ -0,0 +1,2 @@
+matplotlib
+fiftyone
diff --git a/requirements/datatype_pointcloud.txt b/requirements/datatype_pointcloud.txt
new file mode 100644
index 0000000000..cc6437f44c
--- /dev/null
+++ b/requirements/datatype_pointcloud.txt
@@ -0,0 +1,4 @@
+open3d==0.13
+torch==1.7.1
+torchvision
+tensorboard
diff --git a/requirements/datatype_video.txt b/requirements/datatype_video.txt
index 85bc82a5df..28279e2293 100644
--- a/requirements/datatype_video.txt
+++ b/requirements/datatype_video.txt
@@ -1,5 +1,4 @@
torchvision
Pillow>=7.2
kornia>=0.5.1,<0.5.4
-pytorchvideo==0.1.0
-fiftyone
+pytorchvideo==0.1.2
diff --git a/requirements/datatype_video_extras.txt b/requirements/datatype_video_extras.txt
new file mode 100644
index 0000000000..00de5ca1d2
--- /dev/null
+++ b/requirements/datatype_video_extras.txt
@@ -0,0 +1 @@
+fiftyone
diff --git a/requirements/docs.txt b/requirements/docs.txt
index a126cd5db3..5a6057f8e8 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -1,4 +1,4 @@
-sphinx>=4.0
+sphinx>=4.0,<4.1
recommonmark # fails with badges
m2r # fails with multi-line text
nbsphinx>=0.8
diff --git a/requirements/test.txt b/requirements/test.txt
index 6a4674f7d9..3fecfe24d9 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -11,7 +11,6 @@ twine==3.2
# formatting
pre-commit
isort
-yapf
#mypy
scikit-learn
pytest_mock
diff --git a/setup.cfg b/setup.cfg
index 73aff69cad..8ed86d15f0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -72,18 +72,6 @@ ignore =
.circleci
-[yapf]
-based_on_style = pep8
-spaces_before_comment = 2
-split_before_logical_operator = true
-COLUMN_LIMIT = 120
-COALESCE_BRACKETS = true
-DEDENT_CLOSING_BRACKETS = true
-ALLOW_SPLIT_BEFORE_DICT_VALUE = false
-BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
-NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = false
-
-
[mypy]
# Typing tests is low priority, but enabling type checking on the
# untyped test functions (using `--check-untyped-defs`) is still
diff --git a/setup.py b/setup.py
index d581ce9275..96fb1a6164 100644
--- a/setup.py
+++ b/setup.py
@@ -33,8 +33,8 @@ def _load_py_module(fname, pkg="flash"):
return py
-about = _load_py_module('__about__.py')
-setup_tools = _load_py_module('setup_tools.py')
+about = _load_py_module("__about__.py")
+setup_tools = _load_py_module("setup_tools.py")
long_description = setup_tools._load_readme_description(
_PATH_ROOT,
@@ -49,13 +49,19 @@ def _load_py_module(fname, pkg="flash"):
"text": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_text.txt"),
"tabular": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_tabular.txt"),
"image": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_image.txt"),
+ "image_extras": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_image_extras.txt"),
"video": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_video.txt"),
+ "pointcloud": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_pointcloud.txt"),
+ "video_extras": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_video_extras.txt"),
"serve": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="serve.txt"),
+ "audio": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_audio.txt"),
+ "graph": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_graph.txt"),
}
-# remove possible duplicate.
extras["vision"] = list(set(extras["image"] + extras["video"]))
-extras["all"] = list(set(extras["vision"] + extras["tabular"] + extras["text"]))
+extras["all"] = list(
+ set(extras["vision"] + extras["tabular"] + extras["text"])
+) # + extras["pointcloud"] dependencies conflicts
extras["dev"] = list(set(extras["all"] + extras["test"] + extras["docs"]))
# https://packaging.python.org/discussions/install-requires-vs-requirements /
@@ -77,10 +83,13 @@ def _load_py_module(fname, pkg="flash"):
long_description_content_type="text/markdown",
include_package_data=True,
extras_require=extras,
+ entry_points={
+ "console_scripts": ["flash=flash.__main__:main"],
+ },
zip_safe=False,
keywords=["deep learning", "pytorch", "AI"],
python_requires=">=3.6",
- install_requires=setup_tools._load_requirements(_PATH_ROOT, file_name='requirements.txt'),
+ install_requires=setup_tools._load_requirements(_PATH_ROOT, file_name="requirements.txt"),
project_urls={
"Bug Tracker": "https://github.com/PyTorchLightning/lightning-flash/issues",
"Documentation": "https://lightning-flash.rtfd.io/en/latest/",
diff --git a/tests/__init__.py b/tests/__init__.py
index c64310c910..2be74bcdc7 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -2,5 +2,5 @@
# TorchVision hotfix https://github.com/pytorch/vision/issues/1938
opener = urllib.request.build_opener()
-opener.addheaders = [('User-agent', 'Mozilla/5.0')]
+opener.addheaders = [("User-agent", "Mozilla/5.0")]
urllib.request.install_opener(opener)
diff --git a/tests/audio/__init__.py b/tests/audio/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/audio/classification/__init__.py b/tests/audio/classification/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py
new file mode 100644
index 0000000000..626ca12b93
--- /dev/null
+++ b/tests/audio/classification/test_data.py
@@ -0,0 +1,340 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pathlib import Path
+from typing import Any, List, Tuple
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+
+from flash.audio import AudioClassificationData
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.transforms import ApplyToKeys
+from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
+from tests.helpers.utils import _AUDIO_TESTING
+
+if _TORCHVISION_AVAILABLE:
+ import torchvision
+
+if _PIL_AVAILABLE:
+ from PIL import Image
+
+
+def _rand_image(size: Tuple[int, int] = None):
+ if size is None:
+ _size = np.random.choice([196, 244])
+ size = (_size, _size)
+ return Image.fromarray(np.random.randint(0, 255, (*size, 3), dtype="uint8"))
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_filepaths_smoke(tmpdir):
+ tmpdir = Path(tmpdir)
+
+ (tmpdir / "a").mkdir()
+ (tmpdir / "b").mkdir()
+ _rand_image().save(tmpdir / "a_1.png")
+ _rand_image().save(tmpdir / "b_1.png")
+
+ train_images = [
+ str(tmpdir / "a_1.png"),
+ str(tmpdir / "b_1.png"),
+ ]
+
+ spectrograms_data = AudioClassificationData.from_files(
+ train_files=train_images,
+ train_targets=[1, 2],
+ batch_size=2,
+ num_workers=0,
+ )
+ assert spectrograms_data.train_dataloader() is not None
+ assert spectrograms_data.val_dataloader() is None
+ assert spectrograms_data.test_dataloader() is None
+
+ data = next(iter(spectrograms_data.train_dataloader()))
+ imgs, labels = data["input"], data["target"]
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2,)
+ assert sorted(list(labels.numpy())) == [1, 2]
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_filepaths_list_image_paths(tmpdir):
+ tmpdir = Path(tmpdir)
+
+ (tmpdir / "e").mkdir()
+ _rand_image().save(tmpdir / "e_1.png")
+
+ train_images = [
+ str(tmpdir / "e_1.png"),
+ str(tmpdir / "e_1.png"),
+ str(tmpdir / "e_1.png"),
+ ]
+
+ spectrograms_data = AudioClassificationData.from_files(
+ train_files=train_images,
+ train_targets=[0, 3, 6],
+ val_files=train_images,
+ val_targets=[1, 4, 7],
+ test_files=train_images,
+ test_targets=[2, 5, 8],
+ batch_size=2,
+ num_workers=0,
+ )
+
+ # check training data
+ data = next(iter(spectrograms_data.train_dataloader()))
+ imgs, labels = data["input"], data["target"]
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2,)
+ assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here
+ assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here
+
+ # check validation data
+ data = next(iter(spectrograms_data.val_dataloader()))
+ imgs, labels = data["input"], data["target"]
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2,)
+ assert list(labels.numpy()) == [1, 4]
+
+ # check test data
+ data = next(iter(spectrograms_data.test_dataloader()))
+ imgs, labels = data["input"], data["target"]
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2,)
+ assert list(labels.numpy()) == [2, 5]
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
+def test_from_filepaths_visualise(tmpdir):
+ tmpdir = Path(tmpdir)
+
+ (tmpdir / "e").mkdir()
+ _rand_image().save(tmpdir / "e_1.png")
+
+ train_images = [
+ str(tmpdir / "e_1.png"),
+ str(tmpdir / "e_1.png"),
+ str(tmpdir / "e_1.png"),
+ ]
+
+ dm = AudioClassificationData.from_files(
+ train_files=train_images,
+ train_targets=[0, 3, 6],
+ val_files=train_images,
+ val_targets=[1, 4, 7],
+ test_files=train_images,
+ test_targets=[2, 5, 8],
+ batch_size=2,
+ num_workers=0,
+ )
+
+ # disable visualisation for testing
+ assert dm.data_fetcher.block_viz_window is True
+ dm.set_block_viz_window(False)
+ assert dm.data_fetcher.block_viz_window is False
+
+ # call show functions
+ # dm.show_train_batch()
+ dm.show_train_batch("pre_tensor_transform")
+ dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
+def test_from_filepaths_visualise_multilabel(tmpdir):
+ tmpdir = Path(tmpdir)
+
+ (tmpdir / "a").mkdir()
+ (tmpdir / "b").mkdir()
+
+ image_a = str(tmpdir / "a" / "a_1.png")
+ image_b = str(tmpdir / "b" / "b_1.png")
+
+ _rand_image().save(image_a)
+ _rand_image().save(image_b)
+
+ dm = AudioClassificationData.from_files(
+ train_files=[image_a, image_b],
+ train_targets=[[0, 1, 0], [0, 1, 1]],
+ val_files=[image_b, image_a],
+ val_targets=[[1, 1, 0], [0, 0, 1]],
+ test_files=[image_b, image_b],
+ test_targets=[[0, 0, 1], [1, 1, 0]],
+ batch_size=2,
+ spectrogram_size=(64, 64),
+ )
+ # disable visualisation for testing
+ assert dm.data_fetcher.block_viz_window is True
+ dm.set_block_viz_window(False)
+ assert dm.data_fetcher.block_viz_window is False
+
+ # call show functions
+ dm.show_train_batch()
+ dm.show_train_batch("pre_tensor_transform")
+ dm.show_train_batch("to_tensor_transform")
+ dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])
+ dm.show_val_batch("per_batch_transform")
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_filepaths_splits(tmpdir):
+ tmpdir = Path(tmpdir)
+
+ B, _, H, W = 2, 3, 224, 224
+ img_size: Tuple[int, int] = (H, W)
+
+ (tmpdir / "splits").mkdir()
+ _rand_image(img_size).save(tmpdir / "s.png")
+
+ num_samples: int = 10
+ val_split: float = 0.3
+
+ train_filepaths: List[str] = [str(tmpdir / "s.png") for _ in range(num_samples)]
+
+ train_labels: List[int] = list(range(num_samples))
+
+ assert len(train_filepaths) == len(train_labels)
+
+ _to_tensor = {
+ "to_tensor_transform": nn.Sequential(
+ ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
+ ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
+ ),
+ }
+
+ def run(transform: Any = None):
+ dm = AudioClassificationData.from_files(
+ train_files=train_filepaths,
+ train_targets=train_labels,
+ train_transform=transform,
+ val_transform=transform,
+ batch_size=B,
+ num_workers=0,
+ val_split=val_split,
+ spectrogram_size=img_size,
+ )
+ data = next(iter(dm.train_dataloader()))
+ imgs, labels = data["input"], data["target"]
+ assert imgs.shape == (B, 3, H, W)
+ assert labels.shape == (B,)
+
+ run(_to_tensor)
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_folders_only_train(tmpdir):
+ train_dir = Path(tmpdir / "train")
+ train_dir.mkdir()
+
+ (train_dir / "a").mkdir()
+ _rand_image().save(train_dir / "a" / "1.png")
+ _rand_image().save(train_dir / "a" / "2.png")
+
+ (train_dir / "b").mkdir()
+ _rand_image().save(train_dir / "b" / "1.png")
+ _rand_image().save(train_dir / "b" / "2.png")
+
+ spectrograms_data = AudioClassificationData.from_folders(train_dir, train_transform=None, batch_size=1)
+
+ data = next(iter(spectrograms_data.train_dataloader()))
+ imgs, labels = data["input"], data["target"]
+ assert imgs.shape == (1, 3, 128, 128)
+ assert labels.shape == (1,)
+
+ assert spectrograms_data.val_dataloader() is None
+ assert spectrograms_data.test_dataloader() is None
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_folders_train_val(tmpdir):
+
+ train_dir = Path(tmpdir / "train")
+ train_dir.mkdir()
+
+ (train_dir / "a").mkdir()
+ _rand_image().save(train_dir / "a" / "1.png")
+ _rand_image().save(train_dir / "a" / "2.png")
+
+ (train_dir / "b").mkdir()
+ _rand_image().save(train_dir / "b" / "1.png")
+ _rand_image().save(train_dir / "b" / "2.png")
+ spectrograms_data = AudioClassificationData.from_folders(
+ train_dir,
+ val_folder=train_dir,
+ test_folder=train_dir,
+ batch_size=2,
+ num_workers=0,
+ )
+
+ data = next(iter(spectrograms_data.train_dataloader()))
+ imgs, labels = data["input"], data["target"]
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2,)
+
+ data = next(iter(spectrograms_data.val_dataloader()))
+ imgs, labels = data["input"], data["target"]
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2,)
+ assert list(labels.numpy()) == [0, 0]
+
+ data = next(iter(spectrograms_data.test_dataloader()))
+ imgs, labels = data["input"], data["target"]
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2,)
+ assert list(labels.numpy()) == [0, 0]
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_filepaths_multilabel(tmpdir):
+ tmpdir = Path(tmpdir)
+
+ (tmpdir / "a").mkdir()
+ _rand_image().save(tmpdir / "a1.png")
+ _rand_image().save(tmpdir / "a2.png")
+
+ train_images = [str(tmpdir / "a1.png"), str(tmpdir / "a2.png")]
+ train_labels = [[1, 0, 1, 0], [0, 0, 1, 1]]
+ valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]]
+ test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]]
+
+ dm = AudioClassificationData.from_files(
+ train_files=train_images,
+ train_targets=train_labels,
+ val_files=train_images,
+ val_targets=valid_labels,
+ test_files=train_images,
+ test_targets=test_labels,
+ batch_size=2,
+ num_workers=0,
+ )
+
+ data = next(iter(dm.train_dataloader()))
+ imgs, labels = data["input"], data["target"]
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 4)
+
+ data = next(iter(dm.val_dataloader()))
+ imgs, labels = data["input"], data["target"]
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 4)
+ torch.testing.assert_allclose(labels, torch.tensor(valid_labels))
+
+ data = next(iter(dm.test_dataloader()))
+ imgs, labels = data["input"], data["target"]
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 4)
+ torch.testing.assert_allclose(labels, torch.tensor(test_labels))
diff --git a/tests/audio/classification/test_model.py b/tests/audio/classification/test_model.py
new file mode 100644
index 0000000000..0e5a4fa3fc
--- /dev/null
+++ b/tests/audio/classification/test_model.py
@@ -0,0 +1,31 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest import mock
+
+import pytest
+
+from flash.__main__ import main
+from flash.core.utilities.imports import _IMAGE_AVAILABLE
+from tests.helpers.utils import _AUDIO_TESTING
+
+
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_cli():
+ cli_args = ["flash", "audio_classification", "--trainer.fast_dev_run", "True"]
+ with mock.patch("sys.argv", cli_args):
+ try:
+ main()
+ except SystemExit:
+ pass
diff --git a/tests/audio/speech_recognition/__init__.py b/tests/audio/speech_recognition/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/audio/speech_recognition/test_data.py b/tests/audio/speech_recognition/test_data.py
new file mode 100644
index 0000000000..6205da309d
--- /dev/null
+++ b/tests/audio/speech_recognition/test_data.py
@@ -0,0 +1,89 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+from pathlib import Path
+
+import pytest
+
+import flash
+from flash.audio import SpeechRecognitionData
+from flash.core.data.data_source import DefaultDataKeys
+from tests.helpers.utils import _AUDIO_TESTING
+
+path = str(Path(flash.ASSETS_ROOT) / "example.wav")
+sample = {"file": path, "text": "example input."}
+
+TEST_CSV_DATA = f"""file,text
+{path},example input.
+{path},example input.
+{path},example input.
+{path},example input.
+{path},example input.
+"""
+
+
+def csv_data(tmpdir):
+ path = Path(tmpdir) / "data.csv"
+ path.write_text(TEST_CSV_DATA)
+ return path
+
+
+def json_data(tmpdir, n_samples=5):
+ path = Path(tmpdir) / "data.json"
+ with path.open("w") as f:
+ f.write("\n".join([json.dumps(sample) for x in range(n_samples)]))
+ return path
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.")
+def test_from_csv(tmpdir):
+ csv_path = csv_data(tmpdir)
+ dm = SpeechRecognitionData.from_csv("file", "text", train_file=csv_path, batch_size=1, num_workers=0)
+ batch = next(iter(dm.train_dataloader()))
+ assert DefaultDataKeys.INPUT in batch
+ assert DefaultDataKeys.TARGET in batch
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.")
+def test_stage_test_and_valid(tmpdir):
+ csv_path = csv_data(tmpdir)
+ dm = SpeechRecognitionData.from_csv(
+ "file", "text", train_file=csv_path, val_file=csv_path, test_file=csv_path, batch_size=1, num_workers=0
+ )
+ batch = next(iter(dm.val_dataloader()))
+ assert DefaultDataKeys.INPUT in batch
+ assert DefaultDataKeys.TARGET in batch
+
+ batch = next(iter(dm.test_dataloader()))
+ assert DefaultDataKeys.INPUT in batch
+ assert DefaultDataKeys.TARGET in batch
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="speech libraries aren't installed.")
+def test_from_json(tmpdir):
+ json_path = json_data(tmpdir)
+ dm = SpeechRecognitionData.from_json("file", "text", train_file=json_path, batch_size=1, num_workers=0)
+ batch = next(iter(dm.train_dataloader()))
+ assert DefaultDataKeys.INPUT in batch
+ assert DefaultDataKeys.TARGET in batch
+
+
+@pytest.mark.skipif(_AUDIO_TESTING, reason="audio libraries are installed.")
+def test_audio_module_not_found_error():
+ with pytest.raises(ModuleNotFoundError, match="[audio]"):
+ SpeechRecognitionData.from_json("file", "text", train_file="", batch_size=1, num_workers=0)
diff --git a/tests/audio/speech_recognition/test_data_model_integration.py b/tests/audio/speech_recognition/test_data_model_integration.py
new file mode 100644
index 0000000000..eda3ac86b3
--- /dev/null
+++ b/tests/audio/speech_recognition/test_data_model_integration.py
@@ -0,0 +1,83 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import os
+from pathlib import Path
+
+import pytest
+from pytorch_lightning import Trainer
+
+import flash
+from flash.audio import SpeechRecognition, SpeechRecognitionData
+from tests.helpers.utils import _AUDIO_TESTING
+
+TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # super small model for testing
+
+path = str(Path(flash.ASSETS_ROOT) / "example.wav")
+sample = {"file": path, "text": "example input."}
+
+TEST_CSV_DATA = f"""file,text
+{path},example input.
+{path},example input.
+{path},example input.
+{path},example input.
+{path},example input.
+"""
+
+
+def csv_data(tmpdir):
+ path = Path(tmpdir) / "data.csv"
+ path.write_text(TEST_CSV_DATA)
+ return path
+
+
+def json_data(tmpdir, n_samples=5):
+ path = Path(tmpdir) / "data.json"
+ with path.open("w") as f:
+ f.write("\n".join([json.dumps(sample) for x in range(n_samples)]))
+ return path
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_classification_csv(tmpdir):
+ csv_path = csv_data(tmpdir)
+
+ data = SpeechRecognitionData.from_csv(
+ "file",
+ "text",
+ train_file=csv_path,
+ num_workers=0,
+ batch_size=2,
+ )
+ model = SpeechRecognition(backbone=TEST_BACKBONE)
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+ trainer.fit(model, datamodule=data)
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_classification_json(tmpdir):
+ json_path = json_data(tmpdir)
+
+ data = SpeechRecognitionData.from_json(
+ "file",
+ "text",
+ train_file=json_path,
+ num_workers=0,
+ batch_size=2,
+ )
+ model = SpeechRecognition(backbone=TEST_BACKBONE)
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+ trainer.fit(model, datamodule=data)
diff --git a/tests/audio/speech_recognition/test_model.py b/tests/audio/speech_recognition/test_model.py
new file mode 100644
index 0000000000..5ce932cd4d
--- /dev/null
+++ b/tests/audio/speech_recognition/test_model.py
@@ -0,0 +1,102 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import re
+from unittest import mock
+
+import numpy as np
+import pytest
+import torch
+
+from flash import Trainer
+from flash.__main__ import main
+from flash.audio import SpeechRecognition
+from flash.audio.speech_recognition.data import SpeechRecognitionPostprocess, SpeechRecognitionPreprocess
+from flash.core.data.data_source import DefaultDataKeys
+from tests.helpers.utils import _AUDIO_TESTING, _SERVE_TESTING
+
+# ======== Mock functions ========
+
+
+class DummyDataset(torch.utils.data.Dataset):
+ def __getitem__(self, index):
+ return {
+ DefaultDataKeys.INPUT: np.random.randn(86631),
+ DefaultDataKeys.TARGET: "some target text",
+ DefaultDataKeys.METADATA: {"sampling_rate": 16000},
+ }
+
+ def __len__(self) -> int:
+ return 100
+
+
+# ==============================
+
+TEST_BACKBONE = "patrickvonplaten/wav2vec2_tiny_random_robust" # super small model for testing
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_init_train(tmpdir):
+ model = SpeechRecognition(backbone=TEST_BACKBONE)
+ train_dl = torch.utils.data.DataLoader(DummyDataset())
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+ trainer.fit(model, train_dl)
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_jit(tmpdir):
+ sample_input = {"input_values": torch.randn(size=torch.Size([1, 86631])).float()}
+ path = os.path.join(tmpdir, "test.pt")
+
+ model = SpeechRecognition(backbone=TEST_BACKBONE)
+ model.eval()
+
+ # Huggingface model only supports `torch.jit.trace` with `strict=False`
+ model = torch.jit.trace(model, sample_input, strict=False)
+
+ torch.jit.save(model, path)
+ model = torch.jit.load(path)
+
+ out = model(sample_input)["logits"]
+ assert isinstance(out, torch.Tensor)
+ assert out.shape == torch.Size([1, 95, 12])
+
+
+@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+@mock.patch("flash._IS_TESTING", True)
+def test_serve():
+ model = SpeechRecognition(backbone=TEST_BACKBONE)
+ # TODO: Currently only servable once a preprocess and postprocess have been attached
+ model._preprocess = SpeechRecognitionPreprocess()
+ model._postprocess = SpeechRecognitionPostprocess()
+ model.eval()
+ model.serve()
+
+
+@pytest.mark.skipif(_AUDIO_TESTING, reason="audio libraries are installed.")
+def test_load_from_checkpoint_dependency_error():
+ with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[audio]'")):
+ SpeechRecognition.load_from_checkpoint("not_a_real_checkpoint.pt")
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_cli():
+ cli_args = ["flash", "speech_recognition", "--trainer.fast_dev_run", "True"]
+ with mock.patch("sys.argv", cli_args):
+ try:
+ main()
+ except SystemExit:
+ pass
diff --git a/tests/conftest.py b/tests/conftest.py
index f2f67cc829..43fd8dc824 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -15,7 +15,7 @@
class UUID_String(str):
- """Class to replace UUID object with str instance and hex attribute"""
+ """Class to replace UUID object with str instance and hex attribute."""
@property
def hex(self):
@@ -80,7 +80,7 @@ def lightning_squeezenet1_1_obj():
def squeezenet_servable(squeezenet1_1_model, session_global_datadir):
from flash.core.serve import Servable
- trace = torch.jit.trace(squeezenet1_1_model.eval(), (torch.rand(1, 3, 224, 224), ))
+ trace = torch.jit.trace(squeezenet1_1_model.eval(), (torch.rand(1, 3, 224, 224),))
fpth = str(session_global_datadir / "squeezenet_jit_trace.pt")
torch.jit.save(trace, fpth)
diff --git a/tests/core/data/test_auto_dataset.py b/tests/core/data/test_auto_dataset.py
index 7acbffe671..8571363a0a 100644
--- a/tests/core/data/test_auto_dataset.py
+++ b/tests/core/data/test_auto_dataset.py
@@ -22,7 +22,6 @@
class _AutoDatasetTestDataSource(DataSource):
-
def __init__(self, with_dset: bool):
self._callbacks: List[FlashCallback] = []
self.load_data_count = 0
diff --git a/tests/core/data/test_base_viz.py b/tests/core/data/test_base_viz.py
index 20d2084b9b..9af754eb1c 100644
--- a/tests/core/data/test_base_viz.py
+++ b/tests/core/data/test_base_viz.py
@@ -37,7 +37,6 @@ def _rand_image():
class CustomBaseVisualization(BaseVisualization):
-
def __init__(self):
super().__init__()
@@ -77,7 +76,6 @@ def check_reset(self):
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
class TestBaseViz:
-
def test_base_viz(self, tmpdir):
seed_everything(42)
@@ -89,7 +87,6 @@ def test_base_viz(self, tmpdir):
_rand_image().save(train_images[1])
class CustomImageClassificationData(ImageClassificationData):
-
@staticmethod
def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:
return CustomBaseVisualization(*args, **kwargs)
@@ -154,7 +151,7 @@ def _get_result(function_name: str):
if not is_predict:
res = _get_result("per_batch_transform")
- assert res[0][DefaultDataKeys.TARGET].shape == (B, )
+ assert res[0][DefaultDataKeys.TARGET].shape == (B,)
assert dm.data_fetcher.show_load_sample_called
assert dm.data_fetcher.show_pre_tensor_transform_called
@@ -165,12 +162,13 @@ def _get_result(function_name: str):
dm.data_fetcher.check_reset()
@pytest.mark.parametrize(
- "func_names, valid", [
+ "func_names, valid",
+ [
(["load_sample"], True),
(["not_a_hook"], False),
(["load_sample", "pre_tensor_transform"], True),
(["load_sample", "not_a_hook"], True),
- ]
+ ],
)
def test_show(self, func_names, valid):
base_viz = CustomBaseVisualization()
diff --git a/tests/core/data/test_batch.py b/tests/core/data/test_batch.py
index caba5cf4a0..a03457ed77 100644
--- a/tests/core/data/test_batch.py
+++ b/tests/core/data/test_batch.py
@@ -102,9 +102,9 @@ def test_tensor_batch():
def test_sequence(self):
batch = {
- 'a': torch.rand(self.BATCH_SIZE, 4),
- 'b': torch.rand(self.BATCH_SIZE, 2),
- 'c': torch.rand(self.BATCH_SIZE)
+ "a": torch.rand(self.BATCH_SIZE, 4),
+ "b": torch.rand(self.BATCH_SIZE, 2),
+ "c": torch.rand(self.BATCH_SIZE),
}
output = default_uncollate(batch)
@@ -112,13 +112,13 @@ def test_sequence(self):
assert len(batch) == self.BATCH_SIZE
for sample in output:
- assert list(sample.keys()) == ['a', 'b', 'c']
- assert isinstance(sample['a'], list)
- assert len(sample['a']) == 4
- assert isinstance(sample['b'], list)
- assert len(sample['b']) == 2
- assert isinstance(sample['c'], torch.Tensor)
- assert len(sample['c'].shape) == 0
+ assert list(sample.keys()) == ["a", "b", "c"]
+ assert isinstance(sample["a"], list)
+ assert len(sample["a"]) == 4
+ assert isinstance(sample["b"], list)
+ assert len(sample["b"]) == 2
+ assert isinstance(sample["c"], torch.Tensor)
+ assert len(sample["c"].shape) == 0
def test_named_tuple(self):
Batch = namedtuple("Batch", ["x", "y"])
diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py
index e11591f33a..5db55dee08 100644
--- a/tests/core/data/test_callback.py
+++ b/tests/core/data/test_callback.py
@@ -23,8 +23,9 @@
from flash.core.trainer import Trainer
+@mock.patch("pickle.dumps") # need to mock pickle or we get pickle error
@mock.patch("torch.save") # need to mock torch.save or we get pickle error
-def test_flash_callback(_, tmpdir):
+def test_flash_callback(_, __, tmpdir):
"""Test the callback hook system for fit."""
callback_mock = MagicMock()
@@ -47,7 +48,6 @@ def test_flash_callback(_, tmpdir):
]
class CustomModel(Task):
-
def __init__(self):
super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
@@ -91,5 +91,5 @@ def __init__(self):
call.on_post_tensor_transform(ANY, RunningStage.VALIDATING),
call.on_collate(ANY, RunningStage.VALIDATING),
call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
- call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING)
+ call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
]
diff --git a/tests/core/data/test_callbacks.py b/tests/core/data/test_callbacks.py
index 284de09b02..07e89fec16 100644
--- a/tests/core/data/test_callbacks.py
+++ b/tests/core/data/test_callbacks.py
@@ -23,9 +23,7 @@
def test_base_data_fetcher(tmpdir):
-
class CheckData(BaseDataFetcher):
-
def check(self):
assert self.batches["val"]["load_sample"] == [0, 1, 2, 3, 4]
assert self.batches["val"]["pre_tensor_transform"] == [0, 1, 2, 3, 4]
@@ -38,7 +36,6 @@ def check(self):
assert self.batches["predict"] == {}
class CustomDataModule(DataModule):
-
@staticmethod
def configure_data_fetcher():
return CheckData()
@@ -70,13 +67,11 @@ def from_inputs(cls, train_data: Any, val_data: Any, test_data: Any, predict_dat
data_fetcher.check()
data_fetcher.reset()
- assert data_fetcher.batches == {'train': {}, 'test': {}, 'val': {}, 'predict': {}}
+ assert data_fetcher.batches == {"train": {}, "test": {}, "val": {}, "predict": {}}
def test_data_loaders_num_workers_to_0(tmpdir):
- """
- num_workers should be set to `0` internally for visualization and not for training.
- """
+ """num_workers should be set to `0` internally for visualization and not for training."""
datamodule = DataModule(train_dataset=range(10), num_workers=3)
iterator = datamodule._reset_iterator(RunningStage.TRAINING)
diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py
index 2b593cdd9e..7124675f30 100644
--- a/tests/core/data/test_data_pipeline.py
+++ b/tests/core/data/test_data_pipeline.py
@@ -44,7 +44,6 @@
class DummyDataset(torch.utils.data.Dataset):
-
def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
return torch.rand(1), torch.rand(1)
@@ -53,7 +52,6 @@ def __len__(self) -> int:
class TestDataPipelineState:
-
@staticmethod
def test_str():
state = DataPipelineState()
@@ -95,9 +93,7 @@ def test_data_pipeline_str():
@pytest.mark.parametrize("use_preprocess", [False, True])
@pytest.mark.parametrize("use_postprocess", [False, True])
def test_data_pipeline_init_and_assignement(use_preprocess, use_postprocess, tmpdir):
-
class CustomModel(Task):
-
def __init__(self, postprocess: Optional[Postprocess] = None):
super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
self._postprocess = postprocess
@@ -135,9 +131,7 @@ class SubPostprocess(Postprocess):
def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir):
-
class CustomPreprocess(DefaultPreprocess):
-
def val_pre_tensor_transform(self, *_, **__):
pass
@@ -258,7 +252,6 @@ def test_per_batch_transform_on_device(self, *_, **__):
class CustomPreprocess(DefaultPreprocess):
-
def train_per_sample_transform(self, *_, **__):
pass
@@ -307,9 +300,7 @@ def test_data_pipeline_predict_worker_preprocessor_and_device_preprocessor():
def test_detach_preprocessing_from_model(tmpdir):
-
class CustomModel(Task):
-
def __init__(self, postprocess: Optional[Postprocess] = None):
super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
self._postprocess = postprocess
@@ -333,7 +324,6 @@ def train_dataloader(self) -> Any:
class TestPreprocess(DefaultPreprocess):
-
def train_per_sample_transform(self, *_, **__):
pass
@@ -363,7 +353,6 @@ def predict_per_batch_transform_on_device(self, *_, **__):
def test_attaching_datapipeline_to_model(tmpdir):
-
class SubPreprocess(DefaultPreprocess):
pass
@@ -371,7 +360,6 @@ class SubPreprocess(DefaultPreprocess):
data_pipeline = DataPipeline(preprocess=preprocess)
class CustomModel(Task):
-
def __init__(self):
super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
self._postprocess = Postprocess()
@@ -513,8 +501,7 @@ def test_stage_orchestrator_state_attach_detach(tmpdir):
_original_predict_step = model.predict_step
class CustomDataPipeline(DataPipeline):
-
- def _attach_postprocess_to_model(self, model: 'Task', _postprocesssor: _Postprocessor) -> 'Task':
+ def _attach_postprocess_to_model(self, model: "Task", _postprocesssor: _Postprocessor) -> "Task":
model.predict_step = self._model_predict_step_wrapper(model.predict_step, _postprocesssor, model)
return model
@@ -528,7 +515,6 @@ def _attach_postprocess_to_model(self, model: 'Task', _postprocesssor: _Postproc
class LamdaDummyDataset(torch.utils.data.Dataset):
-
def __init__(self, fx: Callable):
self.fx = fx
@@ -540,7 +526,6 @@ def __len__(self) -> int:
class TestPreprocessTransformationsDataSource(DataSource):
-
def __init__(self):
super().__init__()
@@ -589,7 +574,7 @@ def test_load_data(self, sample) -> LamdaDummyDataset:
@staticmethod
def fn_predict_load_data() -> List[str]:
- return (["a", "b"])
+ return ["a", "b"]
def predict_load_data(self, sample) -> LamdaDummyDataset:
assert self.predicting
@@ -599,7 +584,6 @@ def predict_load_data(self, sample) -> LamdaDummyDataset:
class TestPreprocessTransformations(DefaultPreprocess):
-
def __init__(self):
super().__init__(data_sources={"default": TestPreprocessTransformationsDataSource()})
@@ -616,7 +600,7 @@ def train_pre_tensor_transform(self, sample: Any) -> Any:
assert self.training
assert self.current_fn == "pre_tensor_transform"
self.train_pre_tensor_transform_called = True
- return sample + (5, )
+ return sample + (5,)
def train_collate(self, samples) -> Tensor:
assert self.training
@@ -640,9 +624,9 @@ def val_collate(self, samples) -> Dict[str, Tensor]:
assert self.validating
assert self.current_fn == "collate"
self.val_collate_called = True
- _count = samples[0]['a']
- assert samples == [{'a': _count, 'b': _count + 1}, {'a': _count + 1, 'b': _count + 2}]
- return {'a': tensor([0, 1]), 'b': tensor([1, 2])}
+ _count = samples[0]["a"]
+ assert samples == [{"a": _count, "b": _count + 1}, {"a": _count + 1, "b": _count + 2}]
+ return {"a": tensor([0, 1]), "b": tensor([1, 2])}
def val_per_batch_transform_on_device(self, batch: Any) -> Any:
assert self.validating
@@ -668,14 +652,12 @@ def test_post_tensor_transform(self, sample: Tensor) -> Tensor:
class TestPreprocessTransformations2(TestPreprocessTransformations):
-
def val_to_tensor_transform(self, sample: Any) -> Tensor:
self.val_to_tensor_transform_called = True
return {"a": tensor(sample["a"]), "b": tensor(sample["b"])}
class CustomModel(Task):
-
def __init__(self):
super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
@@ -691,11 +673,11 @@ def test_step(self, batch, batch_idx):
assert len(batch) == 2
assert batch[0].shape == torch.Size([2, 1])
- def predict_step(self, batch, batch_idx, dataloader_idx):
- assert batch[0][0] == 'a'
- assert batch[0][1] == 'a'
- assert batch[1][0] == 'b'
- assert batch[1][1] == 'b'
+ def predict_step(self, batch, batch_idx, dataloader_idx=None):
+ assert batch[0][0] == "a"
+ assert batch[0][1] == "a"
+ assert batch[1][0] == "b"
+ assert batch[1][1] == "b"
return tensor([0, 0, 0])
@@ -709,8 +691,8 @@ def test_datapipeline_transformations(tmpdir):
batch = next(iter(datamodule.train_dataloader()))
assert torch.equal(batch, tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]]))
- assert datamodule.val_dataloader().dataset[0] == {'a': 0, 'b': 1}
- assert datamodule.val_dataloader().dataset[1] == {'a': 1, 'b': 2}
+ assert datamodule.val_dataloader().dataset[0] == {"a": 0, "b": 1}
+ assert datamodule.val_dataloader().dataset[1] == {"a": 1, "b": 2}
with pytest.raises(MisconfigurationException, match="When ``to_tensor_transform``"):
batch = next(iter(datamodule.val_dataloader()))
@@ -728,7 +710,7 @@ def test_datapipeline_transformations(tmpdir):
limit_val_batches=1,
limit_test_batches=2,
limit_predict_batches=2,
- num_sanity_val_steps=1
+ num_sanity_val_steps=1,
)
trainer.fit(model, datamodule=datamodule)
trainer.test(model)
@@ -752,9 +734,7 @@ def test_datapipeline_transformations(tmpdir):
def test_is_overriden_recursive(tmpdir):
-
class TestPreprocess(DefaultPreprocess):
-
def collate(self, *_):
pass
@@ -775,9 +755,7 @@ def val_collate(self, *_):
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
@patch("torch.save") # need to mock torch.save or we get pickle error
def test_dummy_example(tmpdir):
-
class ImageDataSource(DataSource):
-
def load_data(self, folder: str):
# from folder -> return files paths
return ["a.jpg", "b.jpg"]
@@ -788,7 +766,6 @@ def load_sample(self, path: str) -> Image.Image:
return Image.fromarray(img8Bit)
class ImageClassificationPreprocess(DefaultPreprocess):
-
def __init__(
self,
train_transform=None,
@@ -817,7 +794,6 @@ def train_per_sample_transform_on_device(self, sample: Any) -> Any:
return self._train_per_sample_transform_on_device(sample)
class CustomModel(Task):
-
def __init__(self):
super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
@@ -856,17 +832,15 @@ class CustomDataModule(DataModule):
limit_val_batches=1,
limit_test_batches=2,
limit_predict_batches=2,
- num_sanity_val_steps=1
+ num_sanity_val_steps=1,
)
trainer.fit(model, datamodule=datamodule)
trainer.test(model)
def test_preprocess_transforms(tmpdir):
- """
- This test makes sure that when a preprocess is being provided transforms as dictionaries,
- checking is done properly, and collate_in_worker_from_transform is properly extracted.
- """
+ """This test makes sure that when a preprocess is being provided transforms as dictionaries, checking is done
+ properly, and collate_in_worker_from_transform is properly extracted."""
with pytest.raises(MisconfigurationException, match="Transform should be a dict."):
DefaultPreprocess(train_transform="choco")
@@ -885,13 +859,13 @@ def test_preprocess_transforms(tmpdir):
preprocess = DefaultPreprocess(
train_transform={
"per_batch_transform": torch.nn.Linear(1, 1),
- "per_sample_transform_on_device": torch.nn.Linear(1, 1)
+ "per_sample_transform_on_device": torch.nn.Linear(1, 1),
}
)
preprocess = DefaultPreprocess(
train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
- predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}
+ predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)},
)
# keep is None
assert preprocess._train_collate_in_worker_from_transform is True
@@ -910,7 +884,6 @@ def test_preprocess_transforms(tmpdir):
assert predict_preprocessor.collate_fn.func == DataPipeline._identity
class CustomPreprocess(DefaultPreprocess):
-
def per_sample_transform_on_device(self, sample: Any) -> Any:
return super().per_sample_transform_on_device(sample)
@@ -919,7 +892,7 @@ def per_batch_transform(self, batch: Any) -> Any:
preprocess = CustomPreprocess(
train_transform={"per_batch_transform": torch.nn.Linear(1, 1)},
- predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)}
+ predict_transform={"per_sample_transform_on_device": torch.nn.Linear(1, 1)},
)
# keep is None
assert preprocess._train_collate_in_worker_from_transform is True
@@ -941,9 +914,7 @@ def per_batch_transform(self, batch: Any) -> Any:
def test_iterable_auto_dataset(tmpdir):
-
class CustomDataSource(DataSource):
-
def load_sample(self, index: int) -> Dict[str, int]:
return {"index": index}
@@ -954,7 +925,6 @@ def load_sample(self, index: int) -> Dict[str, int]:
class CustomPreprocessHyperparameters(DefaultPreprocess):
-
def __init__(self, token: str, *args, **kwargs):
self.token = token
super().__init__(*args, **kwargs)
diff --git a/tests/core/data/test_data_source.py b/tests/core/data/test_data_source.py
index 77dbb173be..24a0b875fc 100644
--- a/tests/core/data/test_data_source.py
+++ b/tests/core/data/test_data_source.py
@@ -17,7 +17,7 @@
def test_dataset_data_source():
data_source = DatasetDataSource()
- input, target = 'test', 3
+ input, target = "test", 3
assert data_source.load_sample((input, target)) == {DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target}
assert data_source.load_sample(input) == {DefaultDataKeys.INPUT: input}
diff --git a/tests/core/data/test_process.py b/tests/core/data/test_process.py
index 2e834fd666..509bbce3f8 100644
--- a/tests/core/data/test_process.py
+++ b/tests/core/data/test_process.py
@@ -33,41 +33,43 @@ def test_serializer():
my_serializer = Serializer()
- assert my_serializer.serialize('test') == 'test'
+ assert my_serializer.serialize("test") == "test"
my_serializer.serialize = Mock()
my_serializer.disable()
- assert my_serializer('test') == 'test'
+ assert my_serializer("test") == "test"
my_serializer.serialize.assert_not_called()
my_serializer.enable()
- my_serializer('test')
+ my_serializer("test")
my_serializer.serialize.assert_called_once()
def test_serializer_mapping():
- """Tests that ``SerializerMapping`` correctly passes its inputs to the underlying serializers. Also checks that
- state is retrieved / loaded correctly."""
+ """Tests that ``SerializerMapping`` correctly passes its inputs to the underlying serializers.
+
+ Also checks that state is retrieved / loaded correctly.
+ """
serializer1 = Serializer()
- serializer1.serialize = Mock(return_value='test1')
+ serializer1.serialize = Mock(return_value="test1")
class Serializer1State(ProcessState):
pass
serializer2 = Serializer()
- serializer2.serialize = Mock(return_value='test2')
+ serializer2.serialize = Mock(return_value="test2")
class Serializer2State(ProcessState):
pass
- serializer_mapping = SerializerMapping({'key1': serializer1, 'key2': serializer2})
- assert serializer_mapping({'key1': 'serializer1', 'key2': 'serializer2'}) == {'key1': 'test1', 'key2': 'test2'}
- serializer1.serialize.assert_called_once_with('serializer1')
- serializer2.serialize.assert_called_once_with('serializer2')
+ serializer_mapping = SerializerMapping({"key1": serializer1, "key2": serializer2})
+ assert serializer_mapping({"key1": "serializer1", "key2": "serializer2"}) == {"key1": "test1", "key2": "test2"}
+ serializer1.serialize.assert_called_once_with("serializer1")
+ serializer2.serialize.assert_called_once_with("serializer2")
- with pytest.raises(ValueError, match='output must be a mapping'):
- serializer_mapping('not a mapping')
+ with pytest.raises(ValueError, match="output must be a mapping"):
+ serializer_mapping("not a mapping")
serializer1_state = Serializer1State()
serializer2_state = Serializer2State()
@@ -87,10 +89,9 @@ class Serializer2State(ProcessState):
def test_saving_with_serializers(tmpdir):
- checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt')
+ checkpoint_file = os.path.join(tmpdir, "tmp.ckpt")
class CustomModel(Task):
-
def __init__(self):
super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
@@ -110,7 +111,6 @@ def __init__(self):
class CustomPreprocess(DefaultPreprocess):
-
def __init__(self):
super().__init__(
data_sources={
diff --git a/tests/core/data/test_sampler.py b/tests/core/data/test_sampler.py
index 9ee9ace3a1..fd114d64f2 100644
--- a/tests/core/data/test_sampler.py
+++ b/tests/core/data/test_sampler.py
@@ -19,14 +19,14 @@
@mock.patch("flash.core.data.data_module.DataLoader")
def test_dataloaders_with_sampler(mock_dataloader):
- train_ds = val_ds = test_ds = 'dataset'
- mock_sampler = 'sampler'
+ train_ds = val_ds = test_ds = "dataset"
+ mock_sampler = mock.MagicMock()
dm = DataModule(train_ds, val_ds, test_ds, num_workers=0, sampler=mock_sampler)
assert dm.sampler is mock_sampler
dl = dm.train_dataloader()
kwargs = mock_dataloader.call_args[1]
- assert 'sampler' in kwargs
- assert kwargs['sampler'] is mock_sampler
+ assert "sampler" in kwargs
+ assert kwargs["sampler"] is mock_sampler.return_value
for dl in [dm.val_dataloader(), dm.test_dataloader()]:
kwargs = mock_dataloader.call_args[1]
- assert 'sampler' not in kwargs
+ assert "sampler" not in kwargs
diff --git a/tests/core/data/test_serialization.py b/tests/core/data/test_serialization.py
index 5c368bb0b9..948f6bee13 100644
--- a/tests/core/data/test_serialization.py
+++ b/tests/core/data/test_serialization.py
@@ -25,13 +25,11 @@
class CustomModel(Task):
-
def __init__(self):
super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
class CustomPreprocess(DefaultPreprocess):
-
@classmethod
def load_data(cls, data):
return data
@@ -40,8 +38,8 @@ def load_data(cls, data):
def test_serialization_data_pipeline(tmpdir):
model = CustomModel()
- checkpoint_file = os.path.join(tmpdir, 'tmp.ckpt')
- checkpoint = ModelCheckpoint(tmpdir, 'test.ckpt')
+ checkpoint_file = os.path.join(tmpdir, "tmp.ckpt")
+ checkpoint = ModelCheckpoint(tmpdir, "test.ckpt")
trainer = Trainer(callbacks=[checkpoint], max_epochs=1)
dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float))))
trainer.fit(model, dummy_data)
@@ -69,5 +67,5 @@ def fn(*args, **kwargs):
assert loaded_model.data_pipeline
assert isinstance(loaded_model.preprocess, CustomPreprocess)
for file in os.listdir(tmpdir):
- if file.endswith('.ckpt'):
+ if file.endswith(".ckpt"):
os.remove(os.path.join(tmpdir, file))
diff --git a/tests/core/data/test_splits.py b/tests/core/data/test_splits.py
index 14e7f12993..0d58ed2228 100644
--- a/tests/core/data/test_splits.py
+++ b/tests/core/data/test_splits.py
@@ -28,7 +28,6 @@ def test_split_dataset():
assert len(np.unique(train_ds.indices)) == len(train_ds.indices)
class Dataset:
-
def __init__(self):
self.data = [0, 1, 2]
self.name = "something"
diff --git a/tests/core/data/test_transforms.py b/tests/core/data/test_transforms.py
index f9239aa654..b66bd41cc8 100644
--- a/tests/core/data/test_transforms.py
+++ b/tests/core/data/test_transforms.py
@@ -23,40 +23,21 @@
class TestApplyToKeys:
-
@pytest.mark.parametrize(
- "sample, keys, expected", [
- ({
- DefaultDataKeys.INPUT: "test"
- }, DefaultDataKeys.INPUT, "test"),
+ "sample, keys, expected",
+ [
+ ({DefaultDataKeys.INPUT: "test"}, DefaultDataKeys.INPUT, "test"),
(
- {
- DefaultDataKeys.INPUT: "test_a",
- DefaultDataKeys.TARGET: "test_b"
- },
+ {DefaultDataKeys.INPUT: "test_a", DefaultDataKeys.TARGET: "test_b"},
[DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
["test_a", "test_b"],
),
- ({
- "input": "test"
- }, "input", "test"),
- ({
- "input": "test_a",
- "target": "test_b"
- }, ["input", "target"], ["test_a", "test_b"]),
- ({
- "input": "test_a",
- "target": "test_b",
- "extra": "..."
- }, ["input", "target"], ["test_a", "test_b"]),
- ({
- "input": "test_a",
- "target": "test_b"
- }, ["input", "target", "extra"], ["test_a", "test_b"]),
- ({
- "target": "..."
- }, "input", None),
- ]
+ ({"input": "test"}, "input", "test"),
+ ({"input": "test_a", "target": "test_b"}, ["input", "target"], ["test_a", "test_b"]),
+ ({"input": "test_a", "target": "test_b", "extra": "..."}, ["input", "target"], ["test_a", "test_b"]),
+ ({"input": "test_a", "target": "test_b"}, ["input", "target", "extra"], ["test_a", "test_b"]),
+ ({"target": "..."}, "input", None),
+ ],
)
def test_forward(self, sample, keys, expected):
transform = Mock(return_value=["out"] * len(keys))
@@ -67,7 +48,8 @@ def test_forward(self, sample, keys, expected):
transform.assert_not_called()
@pytest.mark.parametrize(
- "transform, expected", [
+ "transform, expected",
+ [
(
ApplyToKeys(DefaultDataKeys.INPUT, torch.nn.ReLU()),
"ApplyToKeys(keys=, transform=ReLU())",
@@ -82,7 +64,7 @@ def test_forward(self, sample, keys, expected):
ApplyToKeys(["input", "target"], torch.nn.ReLU()),
"ApplyToKeys(keys=['input', 'target'], transform=ReLU())",
),
- ]
+ ],
)
def test_repr(self, transform, expected):
assert repr(transform) == expected
@@ -118,18 +100,9 @@ def test_kornia_parallel_transforms(with_params):
def test_kornia_collate():
samples = [
- {
- DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10),
- DefaultDataKeys.TARGET: 1
- },
- {
- DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10),
- DefaultDataKeys.TARGET: 2
- },
- {
- DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10),
- DefaultDataKeys.TARGET: 3
- },
+ {DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), DefaultDataKeys.TARGET: 1},
+ {DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), DefaultDataKeys.TARGET: 2},
+ {DefaultDataKeys.INPUT: torch.zeros(1, 3, 10, 10), DefaultDataKeys.TARGET: 3},
]
result = kornia_collate(samples)
@@ -145,24 +118,13 @@ def test_kornia_collate():
"base_transforms, additional_transforms, expected_result",
[
(
- {
- "to_tensor_transform": _MOCK_TRANSFORM
- },
- {
- "post_tensor_transform": _MOCK_TRANSFORM
- },
- {
- "to_tensor_transform": _MOCK_TRANSFORM,
- "post_tensor_transform": _MOCK_TRANSFORM
- },
+ {"to_tensor_transform": _MOCK_TRANSFORM},
+ {"post_tensor_transform": _MOCK_TRANSFORM},
+ {"to_tensor_transform": _MOCK_TRANSFORM, "post_tensor_transform": _MOCK_TRANSFORM},
),
(
- {
- "to_tensor_transform": _MOCK_TRANSFORM
- },
- {
- "to_tensor_transform": _MOCK_TRANSFORM
- },
+ {"to_tensor_transform": _MOCK_TRANSFORM},
+ {"to_tensor_transform": _MOCK_TRANSFORM},
{
"to_tensor_transform": nn.Sequential(
convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM)
@@ -170,33 +132,23 @@ def test_kornia_collate():
},
),
(
- {
- "to_tensor_transform": _MOCK_TRANSFORM
- },
- {
- "to_tensor_transform": _MOCK_TRANSFORM,
- "post_tensor_transform": _MOCK_TRANSFORM
- },
+ {"to_tensor_transform": _MOCK_TRANSFORM},
+ {"to_tensor_transform": _MOCK_TRANSFORM, "post_tensor_transform": _MOCK_TRANSFORM},
{
"to_tensor_transform": nn.Sequential(
convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM)
),
- "post_tensor_transform": _MOCK_TRANSFORM
+ "post_tensor_transform": _MOCK_TRANSFORM,
},
),
(
- {
- "to_tensor_transform": _MOCK_TRANSFORM,
- "post_tensor_transform": _MOCK_TRANSFORM
- },
- {
- "to_tensor_transform": _MOCK_TRANSFORM
- },
+ {"to_tensor_transform": _MOCK_TRANSFORM, "post_tensor_transform": _MOCK_TRANSFORM},
+ {"to_tensor_transform": _MOCK_TRANSFORM},
{
"to_tensor_transform": nn.Sequential(
convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM)
),
- "post_tensor_transform": _MOCK_TRANSFORM
+ "post_tensor_transform": _MOCK_TRANSFORM,
},
),
],
diff --git a/tests/core/optimizers/__init__.py b/tests/core/optimizers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/core/optimizers/test_lr_shceduler.py b/tests/core/optimizers/test_lr_shceduler.py
new file mode 100644
index 0000000000..922978b014
--- /dev/null
+++ b/tests/core/optimizers/test_lr_shceduler.py
@@ -0,0 +1,64 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+
+import pytest
+from torch import nn
+from torch.optim import Adam
+
+from flash.core.optimizers import LinearWarmupCosineAnnealingLR
+
+
+@pytest.mark.parametrize(
+ "lr, warmup_epochs, max_epochs, warmup_start_lr, eta_min",
+ [
+ (1, 10, 3200, 0.001, 0.0),
+ (1e-4, 40, 300, 1e-6, 1e-5),
+ (0.01, 1, 10, 0.0, 0.0),
+ (0.01, 0, 10, 0.0, 0.0), # only cosine decay
+ (0.01, 10, 10, 0.0, 0.0), # only linear warmup
+ ],
+)
+def test_linear_warmup_cosine_annealing_lr(tmpdir, lr, warmup_epochs, max_epochs, warmup_start_lr, eta_min):
+ layer1 = nn.Linear(10, 1)
+ layer2 = nn.Linear(10, 1)
+ optimizer1 = Adam(layer1.parameters(), lr=lr)
+ optimizer2 = Adam(layer2.parameters(), lr=lr)
+
+ scheduler1 = LinearWarmupCosineAnnealingLR(
+ optimizer1,
+ warmup_epochs=warmup_epochs,
+ max_epochs=max_epochs,
+ warmup_start_lr=warmup_start_lr,
+ eta_min=eta_min,
+ )
+
+ scheduler2 = LinearWarmupCosineAnnealingLR(
+ optimizer2,
+ warmup_epochs=warmup_epochs,
+ max_epochs=max_epochs,
+ warmup_start_lr=warmup_start_lr,
+ eta_min=eta_min,
+ )
+
+ # compares closed form lr values against values of get_lr function
+ for epoch in range(max_epochs):
+ scheduler1.step(epoch)
+ expected_lr = scheduler1.get_last_lr()[0]
+ current_lr = scheduler2.get_last_lr()[0]
+
+ assert math.isclose(expected_lr, current_lr, rel_tol=1e-12)
+ optimizer1.step()
+ optimizer2.step()
+ scheduler2.step()
diff --git a/tests/core/optimizers/test_optimizers.py b/tests/core/optimizers/test_optimizers.py
new file mode 100644
index 0000000000..1413b762bc
--- /dev/null
+++ b/tests/core/optimizers/test_optimizers.py
@@ -0,0 +1,57 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+from torch import nn
+
+from flash.core.optimizers import LAMB, LARS, LinearWarmupCosineAnnealingLR
+
+
+@pytest.mark.parametrize(
+ "optim_fn, lr, kwargs",
+ [
+ (LARS, 0.1, {}),
+ (LARS, 0.1, {"weight_decay": 0.001}),
+ (LARS, 0.1, {"momentum": 0.9}),
+ (LAMB, 1e-3, {}),
+ (LAMB, 1e-3, {"amsgrad": True}),
+ (LAMB, 1e-3, {"weight_decay": 0.001}),
+ ],
+)
+def test_optim_call(tmpdir, optim_fn, lr, kwargs):
+ layer = nn.Linear(10, 1)
+ optimizer = optim_fn(layer.parameters(), lr=lr, **kwargs)
+
+ for _ in range(10):
+ dummy_input = torch.rand(1, 10)
+ dummy_input.requires_grad = True
+ result = layer(dummy_input)
+ result.backward()
+ optimizer.step()
+
+
+@pytest.mark.parametrize("optim_fn, lr", [(LARS, 0.1), (LAMB, 1e-3)])
+def test_optim_with_scheduler(tmpdir, optim_fn, lr):
+ max_epochs = 10
+ layer = nn.Linear(10, 1)
+ optimizer = optim_fn(layer.parameters(), lr=lr)
+ scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=2, max_epochs=max_epochs)
+
+ for _ in range(max_epochs):
+ dummy_input = torch.rand(1, 10)
+ dummy_input.requires_grad = True
+ result = layer(dummy_input)
+ result.backward()
+ optimizer.step()
+ scheduler.step()
diff --git a/tests/core/serve/models.py b/tests/core/serve/models.py
index 63f99327f7..9e0e914c41 100644
--- a/tests/core/serve/models.py
+++ b/tests/core/serve/models.py
@@ -14,7 +14,6 @@
class LightningSqueezenet(pl.LightningModule):
-
def __init__(self):
super().__init__()
self.model = squeezenet1_1(pretrained=True).eval()
@@ -24,7 +23,6 @@ def forward(self, x):
class LightningSqueezenetServable(pl.LightningModule):
-
def __init__(self, model):
super().__init__()
self.model = model
@@ -38,7 +36,6 @@ def _func_from_exposed(arg):
class ClassificationInference(ModelComponent):
-
def __init__(self, model): # skipcq: PYL-W0621
self.model = model
@@ -73,7 +70,6 @@ def method_from_exposed(arg):
try:
class ClassificationInferenceRepeated(ModelComponent):
-
def __init__(self, model):
self.model = model
@@ -92,13 +88,14 @@ def classify(self, img):
img = img.permute(0, 3, 2, 1)
out = self.model(img)
return ([out.argmax(), out.argmax()], torch.Tensor([21]))
+
+
except TypeError:
ClassificationInferenceRepeated = None
try:
class ClassificationInferenceModelSequence(ModelComponent):
-
def __init__(self, model):
self.model1 = model[0]
self.model2 = model[1]
@@ -117,13 +114,14 @@ def classify(self, img):
out2 = self.model2(img)
assert out.argmax() == out2.argmax()
return out.argmax()
+
+
except TypeError:
ClassificationInferenceRepeated = None
try:
class ClassificationInferenceModelMapping(ModelComponent):
-
def __init__(self, model):
self.model1 = model["model_one"]
self.model2 = model["model_two"]
@@ -142,13 +140,14 @@ def classify(self, img):
out2 = self.model2(img)
assert out.argmax() == out2.argmax()
return out.argmax()
+
+
except TypeError:
ClassificationInferenceModelMapping = None
try:
class ClassificationInferenceComposable(ModelComponent):
-
def __init__(self, model):
self.model = model
@@ -171,13 +170,14 @@ def classify(self, img, tag):
out = self.model(img_new)
return out.argmax(), img
+
+
except TypeError:
ClassificationInferenceComposable = None
try:
class SeatClassifier(ModelComponent):
-
def __init__(self, model, config):
self.sport = config["sport"]
@@ -197,5 +197,7 @@ def predict(self, section, isle, row, stadium):
seat_num = section.item() * isle.item() * row.item() * stadium * len(self.sport)
stadium_idx = torch.tensor(1000)
return torch.Tensor([seat_num]), stadium_idx
+
+
except TypeError:
SeatClassifier = None
diff --git a/tests/core/serve/test_compat/test_cached_property.py b/tests/core/serve/test_compat/test_cached_property.py
index c6c909bdf8..b708fa8189 100644
--- a/tests/core/serve/test_compat/test_cached_property.py
+++ b/tests/core/serve/test_compat/test_cached_property.py
@@ -79,7 +79,6 @@ def cost(self):
# noinspection PyStatementEffect
@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Python 3.8+ uses standard library implementation.")
class TestCachedProperty:
-
@staticmethod
def test_cached():
item = CachedCostItem()
@@ -125,7 +124,6 @@ def test_object_with_slots():
@staticmethod
def test_immutable_dict():
-
class MyMeta(type):
"""Test metaclass."""
@@ -214,7 +212,6 @@ def test_doc():
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Validate, that python 3.8 uses standard implementation")
class TestPy38Plus:
-
@staticmethod
def test_is():
import functools
diff --git a/tests/core/serve/test_components.py b/tests/core/serve/test_components.py
index a32773726f..f31f89c84a 100644
--- a/tests/core/serve/test_components.py
+++ b/tests/core/serve/test_components.py
@@ -21,12 +21,14 @@ def test_model_compute_dependencies(lightning_squeezenet1_1_obj):
comp2 = ClassificationInferenceComposable(lightning_squeezenet1_1_obj)
comp1.inputs.tag << comp2.outputs.predicted_tag
- res = [{
- "source_component": "callnum_2",
- "source_key": "predicted_tag",
- "target_component": "callnum_1",
- "target_key": "tag",
- }]
+ res = [
+ {
+ "source_component": "callnum_2",
+ "source_key": "predicted_tag",
+ "target_component": "callnum_1",
+ "target_key": "tag",
+ }
+ ]
assert list(map(lambda x: x._asdict(), comp1._flashserve_meta_.connections)) == res
assert list(comp2._flashserve_meta_.connections) == []
@@ -38,12 +40,14 @@ def test_inverse_model_compute_component_dependencies(lightning_squeezenet1_1_ob
comp2.outputs.predicted_tag >> comp1.inputs.tag
- res = [{
- "source_component": "callnum_2",
- "source_key": "predicted_tag",
- "target_component": "callnum_1",
- "target_key": "tag",
- }]
+ res = [
+ {
+ "source_component": "callnum_2",
+ "source_key": "predicted_tag",
+ "target_component": "callnum_1",
+ "target_key": "tag",
+ }
+ ]
assert list(map(lambda x: x._asdict(), comp2._flashserve_meta_.connections)) == res
assert list(comp1._flashserve_meta_.connections) == []
@@ -74,7 +78,6 @@ def test_two_component_invalid_dependencies_fail(lightning_squeezenet1_1_obj):
comp2.outputs.predicted_tag >> comp1.outputs.predicted_tag
class Foo:
-
def __init__(self):
pass
@@ -128,7 +131,6 @@ def test_invalid_expose_inputs():
with pytest.raises(SyntaxError, match="must be valid python attribute"):
class ComposeClassInvalidExposeNameKeyword(ModelComponent):
-
def __init__(self, model):
pass
@@ -142,7 +144,6 @@ def predict(param):
with pytest.raises(AttributeError, match="object has no attribute"):
class ComposeClassInvalidExposeNameType(ModelComponent):
-
def __init__(self, model):
pass
@@ -156,7 +157,6 @@ def predict(param):
with pytest.raises(TypeError, match="`expose` values must be"):
class ComposeClassInvalidExposeInputsType(ModelComponent):
-
def __init__(self, model):
pass
@@ -170,7 +170,6 @@ def predict(param):
with pytest.raises(ValueError, match="cannot set dict of length < 1"):
class ComposeClassEmptyExposeInputsType(ModelComponent):
-
def __init__(self, model):
pass
@@ -206,7 +205,6 @@ def test_invalid_name(lightning_squeezenet1_1_obj):
with pytest.raises(SyntaxError):
class FailedExposedOutputsKeyworkName(ModelComponent):
-
def __init__(self, model):
self.model = model
@@ -222,7 +220,6 @@ def test_invalid_config_args(lightning_squeezenet1_1_obj):
from flash.core.serve.types import Number
class SomeComponent(ModelComponent):
-
def __init__(self, model, config=None):
self.model = model
self.config = config
@@ -250,7 +247,6 @@ def test_invalid_model_args(lightning_squeezenet1_1_obj):
from flash.core.serve.types import Number
class SomeComponent(ModelComponent):
-
def __init__(self, model):
self.model = model
diff --git a/tests/core/serve/test_composition.py b/tests/core/serve/test_composition.py
index 5679859ee2..c354e64f2f 100644
--- a/tests/core/serve/test_composition.py
+++ b/tests/core/serve/test_composition.py
@@ -23,10 +23,7 @@ def test_composit_endpoint_data(lightning_squeezenet1_1_obj):
actual_endpoints = {k: asdict(v) for k, v in composit.endpoints.items()}
assert actual_endpoints == {
"classify_ENDPOINT": {
- "inputs": {
- "img": "callnum_1.inputs.img",
- "tag": "callnum_1.inputs.tag"
- },
+ "inputs": {"img": "callnum_1.inputs.img", "tag": "callnum_1.inputs.tag"},
"outputs": {
"cropped_img": "callnum_1.outputs.cropped_img",
"predicted_tag": "callnum_1.outputs.predicted_tag",
@@ -50,10 +47,7 @@ def test_composit_endpoint_data(lightning_squeezenet1_1_obj):
actual_endpoints = {k: asdict(v) for k, v in composit.endpoints.items()}
assert actual_endpoints == {
"predict_ep": {
- "inputs": {
- "label_1": "callnum_1.inputs.img",
- "tag_1": "callnum_1.inputs.tag"
- },
+ "inputs": {"label_1": "callnum_1.inputs.img", "tag_1": "callnum_1.inputs.tag"},
"outputs": {
"cropped": "callnum_1.outputs.cropped_img",
"prediction": "callnum_1.outputs.predicted_tag",
@@ -381,21 +375,13 @@ def test_start_server_from_composition(tmp_path, squeezenet_servable, session_gl
data = {
"session": "session_uuid",
"payload": {
- "img_1": {
- "data": cat_imgstr
- },
- "img_2": {
- "data": fish_imgstr
- },
- "tag_1": {
- "label": "stingray"
- },
+ "img_1": {"data": cat_imgstr},
+ "img_2": {"data": fish_imgstr},
+ "tag_1": {"label": "stingray"},
},
}
expected_response = {
- "result": {
- "prediction": "goldfish, Carassius auratus"
- },
+ "result": {"prediction": "goldfish, Carassius auratus"},
"session": "session_uuid",
}
diff --git a/tests/core/serve/test_dag/test_optimization.py b/tests/core/serve/test_dag/test_optimization.py
index 238adcfa3c..673dce8106 100644
--- a/tests/core/serve/test_dag/test_optimization.py
+++ b/tests/core/serve/test_dag/test_optimization.py
@@ -36,7 +36,7 @@ def test_cull():
def fuse2(*args, **kwargs):
- """Run both ``fuse`` and ``fuse_linear`` and compare results"""
+ """Run both ``fuse`` and ``fuse_linear`` and compare results."""
rv1 = fuse_linear(*args, **kwargs)
if kwargs.get("rename_keys") is not False:
return rv1
@@ -60,12 +60,14 @@ def test_fuse():
"b": 2,
}
assert fuse(d, rename_keys=False) == with_deps({"w": (inc, (inc, (inc, (add, "a", "b")))), "a": 1, "b": 2})
- assert fuse(d, rename_keys=True) == with_deps({
- "z-y-x-w": (inc, (inc, (inc, (add, "a", "b")))),
- "a": 1,
- "b": 2,
- "w": "z-y-x-w",
- })
+ assert fuse(d, rename_keys=True) == with_deps(
+ {
+ "z-y-x-w": (inc, (inc, (inc, (add, "a", "b")))),
+ "a": 1,
+ "b": 2,
+ "w": "z-y-x-w",
+ }
+ )
d = {
"NEW": (inc, "y"),
@@ -76,22 +78,26 @@ def test_fuse():
"a": 1,
"b": 2,
}
- assert fuse(d, rename_keys=False) == with_deps({
- "NEW": (inc, "y"),
- "w": (inc, (inc, "y")),
- "y": (inc, (add, "a", "b")),
- "a": 1,
- "b": 2,
- })
- assert fuse(d, rename_keys=True) == with_deps({
- "NEW": (inc, "z-y"),
- "x-w": (inc, (inc, "z-y")),
- "z-y": (inc, (add, "a", "b")),
- "a": 1,
- "b": 2,
- "w": "x-w",
- "y": "z-y",
- })
+ assert fuse(d, rename_keys=False) == with_deps(
+ {
+ "NEW": (inc, "y"),
+ "w": (inc, (inc, "y")),
+ "y": (inc, (add, "a", "b")),
+ "a": 1,
+ "b": 2,
+ }
+ )
+ assert fuse(d, rename_keys=True) == with_deps(
+ {
+ "NEW": (inc, "z-y"),
+ "x-w": (inc, (inc, "z-y")),
+ "z-y": (inc, (add, "a", "b")),
+ "a": 1,
+ "b": 2,
+ "w": "x-w",
+ "y": "z-y",
+ }
+ )
d = {
"v": (inc, "y"),
@@ -105,24 +111,28 @@ def test_fuse():
"c": 1,
"d": 2,
}
- assert fuse(d, rename_keys=False) == with_deps({
- "u": (inc, (inc, (inc, "y"))),
- "v": (inc, "y"),
- "y": (inc, (add, "a", "b")),
- "a": (inc, 1),
- "b": (inc, 2),
- })
- assert fuse(d, rename_keys=True) == with_deps({
- "x-w-u": (inc, (inc, (inc, "z-y"))),
- "v": (inc, "z-y"),
- "z-y": (inc, (add, "c-a", "d-b")),
- "c-a": (inc, 1),
- "d-b": (inc, 2),
- "a": "c-a",
- "b": "d-b",
- "u": "x-w-u",
- "y": "z-y",
- })
+ assert fuse(d, rename_keys=False) == with_deps(
+ {
+ "u": (inc, (inc, (inc, "y"))),
+ "v": (inc, "y"),
+ "y": (inc, (add, "a", "b")),
+ "a": (inc, 1),
+ "b": (inc, 2),
+ }
+ )
+ assert fuse(d, rename_keys=True) == with_deps(
+ {
+ "x-w-u": (inc, (inc, (inc, "z-y"))),
+ "v": (inc, "z-y"),
+ "z-y": (inc, (add, "c-a", "d-b")),
+ "c-a": (inc, 1),
+ "d-b": (inc, 2),
+ "a": "c-a",
+ "b": "d-b",
+ "u": "x-w-u",
+ "y": "z-y",
+ }
+ )
d = {
"a": (inc, "x"),
@@ -132,20 +142,19 @@ def test_fuse():
"x": (inc, "y"),
"y": 0,
}
- assert fuse(d, rename_keys=False) == with_deps({
- "a": (inc, "x"),
- "b": (inc, "x"),
- "d": (inc, (inc, "x")),
- "x": (inc, 0)
- })
- assert fuse(d, rename_keys=True) == with_deps({
- "a": (inc, "y-x"),
- "b": (inc, "y-x"),
- "c-d": (inc, (inc, "y-x")),
- "y-x": (inc, 0),
- "d": "c-d",
- "x": "y-x",
- })
+ assert fuse(d, rename_keys=False) == with_deps(
+ {"a": (inc, "x"), "b": (inc, "x"), "d": (inc, (inc, "x")), "x": (inc, 0)}
+ )
+ assert fuse(d, rename_keys=True) == with_deps(
+ {
+ "a": (inc, "y-x"),
+ "b": (inc, "y-x"),
+ "c-d": (inc, (inc, "y-x")),
+ "y-x": (inc, 0),
+ "d": "c-d",
+ "x": "y-x",
+ }
+ )
d = {"a": 1, "b": (inc, "a"), "c": (add, "b", "b")}
assert fuse(d, rename_keys=False) == with_deps({"b": (inc, 1), "c": (add, "b", "b")})
@@ -168,21 +177,19 @@ def test_fuse_keys():
"b": 2,
}
keys = ["x", "z"]
- assert fuse(d, keys, rename_keys=False) == with_deps({
- "w": (inc, "x"),
- "x": (inc, (inc, "z")),
- "z": (add, "a", "b"),
- "a": 1,
- "b": 2
- })
- assert fuse(d, keys, rename_keys=True) == with_deps({
- "w": (inc, "y-x"),
- "y-x": (inc, (inc, "z")),
- "z": (add, "a", "b"),
- "a": 1,
- "b": 2,
- "x": "y-x",
- })
+ assert fuse(d, keys, rename_keys=False) == with_deps(
+ {"w": (inc, "x"), "x": (inc, (inc, "z")), "z": (add, "a", "b"), "a": 1, "b": 2}
+ )
+ assert fuse(d, keys, rename_keys=True) == with_deps(
+ {
+ "w": (inc, "y-x"),
+ "y-x": (inc, (inc, "z")),
+ "z": (add, "a", "b"),
+ "a": 1,
+ "b": 2,
+ "x": "y-x",
+ }
+ )
def test_inline():
@@ -238,9 +245,7 @@ def test_inline_ignores_curries_and_partials():
def test_inline_functions_non_hashable():
-
class NonHashableCallable:
-
def __call__(self, a):
return a + 1
@@ -277,7 +282,6 @@ def test_inline_functions_protects_output_keys():
def test_functions_of():
-
def a(x):
return x
@@ -290,7 +294,7 @@ def b(x):
assert functions_of((a, [[[(b, 1)]]])) == {a, b}
assert functions_of(1) == set()
assert functions_of(a) == set()
- assert functions_of((a, )) == {a}
+ assert functions_of((a,)) == {a}
def test_inline_cull_dependencies():
@@ -301,7 +305,6 @@ def test_inline_cull_dependencies():
def test_fuse_reductions_single_input():
-
def f(*args):
return args
@@ -309,11 +312,9 @@ def f(*args):
assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d)
assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d)
assert fuse(d, ave_width=2, rename_keys=False) == with_deps({"a": 1, "c": (f, (f, "a"), (f, "a", "a"))})
- assert fuse(d, ave_width=2, rename_keys=True) == with_deps({
- "a": 1,
- "b1-b2-c": (f, (f, "a"), (f, "a", "a")),
- "c": "b1-b2-c"
- })
+ assert fuse(d, ave_width=2, rename_keys=True) == with_deps(
+ {"a": 1, "b1-b2-c": (f, (f, "a"), (f, "a", "a")), "c": "b1-b2-c"}
+ )
d = {
"a": 1,
@@ -324,25 +325,24 @@ def f(*args):
}
assert fuse(d, ave_width=2.9, rename_keys=False) == with_deps(d)
assert fuse(d, ave_width=2.9, rename_keys=True) == with_deps(d)
- assert fuse(d, ave_width=3, rename_keys=False) == with_deps({
- "a": 1,
- "c": (f, (f, "a"), (f, "a", "a"), (f, "a", "a", "a"))
- })
- assert fuse(d, ave_width=3, rename_keys=True) == with_deps({
- "a": 1,
- "b1-b2-b3-c": (f, (f, "a"), (f, "a", "a"), (f, "a", "a", "a")),
- "c": "b1-b2-b3-c",
- })
+ assert fuse(d, ave_width=3, rename_keys=False) == with_deps(
+ {"a": 1, "c": (f, (f, "a"), (f, "a", "a"), (f, "a", "a", "a"))}
+ )
+ assert fuse(d, ave_width=3, rename_keys=True) == with_deps(
+ {
+ "a": 1,
+ "b1-b2-b3-c": (f, (f, "a"), (f, "a", "a"), (f, "a", "a", "a")),
+ "c": "b1-b2-b3-c",
+ }
+ )
d = {"a": 1, "b1": (f, "a"), "b2": (f, "a"), "c": (f, "a", "b1", "b2")}
assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d)
assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d)
assert fuse(d, ave_width=2, rename_keys=False) == with_deps({"a": 1, "c": (f, "a", (f, "a"), (f, "a"))})
- assert fuse(d, ave_width=2, rename_keys=True) == with_deps({
- "a": 1,
- "b1-b2-c": (f, "a", (f, "a"), (f, "a")),
- "c": "b1-b2-c"
- })
+ assert fuse(d, ave_width=2, rename_keys=True) == with_deps(
+ {"a": 1, "b1-b2-c": (f, "a", (f, "a"), (f, "a")), "c": "b1-b2-c"}
+ )
d = {
"a": 1,
@@ -355,18 +355,18 @@ def f(*args):
}
assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d)
assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d)
- assert fuse(d, ave_width=2, rename_keys=False) == with_deps({
- "a": 1,
- "c": (f, (f, "a"), (f, "a")),
- "e": (f, (f, "c"), (f, "c"))
- })
- assert fuse(d, ave_width=2, rename_keys=True) == with_deps({
- "a": 1,
- "b1-b2-c": (f, (f, "a"), (f, "a")),
- "d1-d2-e": (f, (f, "c"), (f, "c")),
- "c": "b1-b2-c",
- "e": "d1-d2-e",
- })
+ assert fuse(d, ave_width=2, rename_keys=False) == with_deps(
+ {"a": 1, "c": (f, (f, "a"), (f, "a")), "e": (f, (f, "c"), (f, "c"))}
+ )
+ assert fuse(d, ave_width=2, rename_keys=True) == with_deps(
+ {
+ "a": 1,
+ "b1-b2-c": (f, (f, "a"), (f, "a")),
+ "d1-d2-e": (f, (f, "c"), (f, "c")),
+ "c": "b1-b2-c",
+ "e": "d1-d2-e",
+ }
+ )
d = {
"a": 1,
@@ -380,37 +380,42 @@ def f(*args):
}
assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d)
assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d)
- expected = with_deps({
- "a": 1,
- "c1": (f, (f, "a"), (f, "a")),
- "c2": (f, (f, "a"), (f, "a")),
- "d": (f, "c1", "c2"),
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "c1": (f, (f, "a"), (f, "a")),
+ "c2": (f, (f, "a"), (f, "a")),
+ "d": (f, "c1", "c2"),
+ }
+ )
assert fuse(d, ave_width=2, rename_keys=False) == expected
assert fuse(d, ave_width=2.9, rename_keys=False) == expected
- expected = with_deps({
- "a": 1,
- "b1-b2-c1": (f, (f, "a"), (f, "a")),
- "b3-b4-c2": (f, (f, "a"), (f, "a")),
- "d": (f, "c1", "c2"),
- "c1": "b1-b2-c1",
- "c2": "b3-b4-c2",
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "b1-b2-c1": (f, (f, "a"), (f, "a")),
+ "b3-b4-c2": (f, (f, "a"), (f, "a")),
+ "d": (f, "c1", "c2"),
+ "c1": "b1-b2-c1",
+ "c2": "b3-b4-c2",
+ }
+ )
assert fuse(d, ave_width=2, rename_keys=True) == expected
assert fuse(d, ave_width=2.9, rename_keys=True) == expected
- assert fuse(d, ave_width=3, rename_keys=False) == with_deps({
- "a": 1,
- "d": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a")))
- })
- assert fuse(d, ave_width=3, rename_keys=True) == with_deps({
- "a": 1,
- "b1-b2-b3-b4-c1-c2-d": (
- f,
- (f, (f, "a"), (f, "a")),
- (f, (f, "a"), (f, "a")),
- ),
- "d": "b1-b2-b3-b4-c1-c2-d",
- })
+ assert fuse(d, ave_width=3, rename_keys=False) == with_deps(
+ {"a": 1, "d": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a")))}
+ )
+ assert fuse(d, ave_width=3, rename_keys=True) == with_deps(
+ {
+ "a": 1,
+ "b1-b2-b3-b4-c1-c2-d": (
+ f,
+ (f, (f, "a"), (f, "a")),
+ (f, (f, "a"), (f, "a")),
+ ),
+ "d": "b1-b2-b3-b4-c1-c2-d",
+ }
+ )
d = {
"a": 1,
@@ -432,77 +437,89 @@ def f(*args):
}
assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d)
assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d)
- expected = with_deps({
- "a": 1,
- "c1": (f, (f, "a"), (f, "a")),
- "c2": (f, (f, "a"), (f, "a")),
- "c3": (f, (f, "a"), (f, "a")),
- "c4": (f, (f, "a"), (f, "a")),
- "d1": (f, "c1", "c2"),
- "d2": (f, "c3", "c4"),
- "e": (f, "d1", "d2"),
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "c1": (f, (f, "a"), (f, "a")),
+ "c2": (f, (f, "a"), (f, "a")),
+ "c3": (f, (f, "a"), (f, "a")),
+ "c4": (f, (f, "a"), (f, "a")),
+ "d1": (f, "c1", "c2"),
+ "d2": (f, "c3", "c4"),
+ "e": (f, "d1", "d2"),
+ }
+ )
assert fuse(d, ave_width=2, rename_keys=False) == expected
assert fuse(d, ave_width=2.9, rename_keys=False) == expected
- expected = with_deps({
- "a": 1,
- "b1-b2-c1": (f, (f, "a"), (f, "a")),
- "b3-b4-c2": (f, (f, "a"), (f, "a")),
- "b5-b6-c3": (f, (f, "a"), (f, "a")),
- "b7-b8-c4": (f, (f, "a"), (f, "a")),
- "d1": (f, "c1", "c2"),
- "d2": (f, "c3", "c4"),
- "e": (f, "d1", "d2"),
- "c1": "b1-b2-c1",
- "c2": "b3-b4-c2",
- "c3": "b5-b6-c3",
- "c4": "b7-b8-c4",
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "b1-b2-c1": (f, (f, "a"), (f, "a")),
+ "b3-b4-c2": (f, (f, "a"), (f, "a")),
+ "b5-b6-c3": (f, (f, "a"), (f, "a")),
+ "b7-b8-c4": (f, (f, "a"), (f, "a")),
+ "d1": (f, "c1", "c2"),
+ "d2": (f, "c3", "c4"),
+ "e": (f, "d1", "d2"),
+ "c1": "b1-b2-c1",
+ "c2": "b3-b4-c2",
+ "c3": "b5-b6-c3",
+ "c4": "b7-b8-c4",
+ }
+ )
assert fuse(d, ave_width=2, rename_keys=True) == expected
assert fuse(d, ave_width=2.9, rename_keys=True) == expected
- expected = with_deps({
- "a": 1,
- "d1": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- "d2": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- "e": (f, "d1", "d2"),
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "d1": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ "d2": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ "e": (f, "d1", "d2"),
+ }
+ )
assert fuse(d, ave_width=3, rename_keys=False) == expected
assert fuse(d, ave_width=4.6, rename_keys=False) == expected
- expected = with_deps({
- "a": 1,
- "b1-b2-b3-b4-c1-c2-d1": (
- f,
- (f, (f, "a"), (f, "a")),
- (f, (f, "a"), (f, "a")),
- ),
- "b5-b6-b7-b8-c3-c4-d2": (
- f,
- (f, (f, "a"), (f, "a")),
- (f, (f, "a"), (f, "a")),
- ),
- "e": (f, "d1", "d2"),
- "d1": "b1-b2-b3-b4-c1-c2-d1",
- "d2": "b5-b6-b7-b8-c3-c4-d2",
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "b1-b2-b3-b4-c1-c2-d1": (
+ f,
+ (f, (f, "a"), (f, "a")),
+ (f, (f, "a"), (f, "a")),
+ ),
+ "b5-b6-b7-b8-c3-c4-d2": (
+ f,
+ (f, (f, "a"), (f, "a")),
+ (f, (f, "a"), (f, "a")),
+ ),
+ "e": (f, "d1", "d2"),
+ "d1": "b1-b2-b3-b4-c1-c2-d1",
+ "d2": "b5-b6-b7-b8-c3-c4-d2",
+ }
+ )
assert fuse(d, ave_width=3, rename_keys=True) == expected
assert fuse(d, ave_width=4.6, rename_keys=True) == expected
- assert fuse(d, ave_width=4.7, rename_keys=False) == with_deps({
- "a": 1,
- "e": (
- f,
- (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- ),
- })
- assert fuse(d, ave_width=4.7, rename_keys=True) == with_deps({
- "a": 1,
- "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e": (
- f,
- (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- ),
- "e": "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e",
- })
+ assert fuse(d, ave_width=4.7, rename_keys=False) == with_deps(
+ {
+ "a": 1,
+ "e": (
+ f,
+ (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ ),
+ }
+ )
+ assert fuse(d, ave_width=4.7, rename_keys=True) == with_deps(
+ {
+ "a": 1,
+ "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e": (
+ f,
+ (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ ),
+ "e": "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e",
+ }
+ )
d = {
"a": 1,
@@ -540,165 +557,181 @@ def f(*args):
}
assert fuse(d, ave_width=1.9, rename_keys=False) == with_deps(d)
assert fuse(d, ave_width=1.9, rename_keys=True) == with_deps(d)
- expected = with_deps({
- "a": 1,
- "c1": (f, (f, "a"), (f, "a")),
- "c2": (f, (f, "a"), (f, "a")),
- "c3": (f, (f, "a"), (f, "a")),
- "c4": (f, (f, "a"), (f, "a")),
- "c5": (f, (f, "a"), (f, "a")),
- "c6": (f, (f, "a"), (f, "a")),
- "c7": (f, (f, "a"), (f, "a")),
- "c8": (f, (f, "a"), (f, "a")),
- "d1": (f, "c1", "c2"),
- "d2": (f, "c3", "c4"),
- "d3": (f, "c5", "c6"),
- "d4": (f, "c7", "c8"),
- "e1": (f, "d1", "d2"),
- "e2": (f, "d3", "d4"),
- "f": (f, "e1", "e2"),
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "c1": (f, (f, "a"), (f, "a")),
+ "c2": (f, (f, "a"), (f, "a")),
+ "c3": (f, (f, "a"), (f, "a")),
+ "c4": (f, (f, "a"), (f, "a")),
+ "c5": (f, (f, "a"), (f, "a")),
+ "c6": (f, (f, "a"), (f, "a")),
+ "c7": (f, (f, "a"), (f, "a")),
+ "c8": (f, (f, "a"), (f, "a")),
+ "d1": (f, "c1", "c2"),
+ "d2": (f, "c3", "c4"),
+ "d3": (f, "c5", "c6"),
+ "d4": (f, "c7", "c8"),
+ "e1": (f, "d1", "d2"),
+ "e2": (f, "d3", "d4"),
+ "f": (f, "e1", "e2"),
+ }
+ )
assert fuse(d, ave_width=2, rename_keys=False) == expected
assert fuse(d, ave_width=2.9, rename_keys=False) == expected
- expected = with_deps({
- "a": 1,
- "b1-b2-c1": (f, (f, "a"), (f, "a")),
- "b3-b4-c2": (f, (f, "a"), (f, "a")),
- "b5-b6-c3": (f, (f, "a"), (f, "a")),
- "b7-b8-c4": (f, (f, "a"), (f, "a")),
- "b10-b9-c5": (f, (f, "a"), (f, "a")),
- "b11-b12-c6": (f, (f, "a"), (f, "a")),
- "b13-b14-c7": (f, (f, "a"), (f, "a")),
- "b15-b16-c8": (f, (f, "a"), (f, "a")),
- "d1": (f, "c1", "c2"),
- "d2": (f, "c3", "c4"),
- "d3": (f, "c5", "c6"),
- "d4": (f, "c7", "c8"),
- "e1": (f, "d1", "d2"),
- "e2": (f, "d3", "d4"),
- "f": (f, "e1", "e2"),
- "c1": "b1-b2-c1",
- "c2": "b3-b4-c2",
- "c3": "b5-b6-c3",
- "c4": "b7-b8-c4",
- "c5": "b10-b9-c5",
- "c6": "b11-b12-c6",
- "c7": "b13-b14-c7",
- "c8": "b15-b16-c8",
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "b1-b2-c1": (f, (f, "a"), (f, "a")),
+ "b3-b4-c2": (f, (f, "a"), (f, "a")),
+ "b5-b6-c3": (f, (f, "a"), (f, "a")),
+ "b7-b8-c4": (f, (f, "a"), (f, "a")),
+ "b10-b9-c5": (f, (f, "a"), (f, "a")),
+ "b11-b12-c6": (f, (f, "a"), (f, "a")),
+ "b13-b14-c7": (f, (f, "a"), (f, "a")),
+ "b15-b16-c8": (f, (f, "a"), (f, "a")),
+ "d1": (f, "c1", "c2"),
+ "d2": (f, "c3", "c4"),
+ "d3": (f, "c5", "c6"),
+ "d4": (f, "c7", "c8"),
+ "e1": (f, "d1", "d2"),
+ "e2": (f, "d3", "d4"),
+ "f": (f, "e1", "e2"),
+ "c1": "b1-b2-c1",
+ "c2": "b3-b4-c2",
+ "c3": "b5-b6-c3",
+ "c4": "b7-b8-c4",
+ "c5": "b10-b9-c5",
+ "c6": "b11-b12-c6",
+ "c7": "b13-b14-c7",
+ "c8": "b15-b16-c8",
+ }
+ )
assert fuse(d, ave_width=2, rename_keys=True) == expected
assert fuse(d, ave_width=2.9, rename_keys=True) == expected
- expected = with_deps({
- "a": 1,
- "d1": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- "d2": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- "d3": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- "d4": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- "e1": (f, "d1", "d2"),
- "e2": (f, "d3", "d4"),
- "f": (f, "e1", "e2"),
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "d1": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ "d2": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ "d3": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ "d4": (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ "e1": (f, "d1", "d2"),
+ "e2": (f, "d3", "d4"),
+ "f": (f, "e1", "e2"),
+ }
+ )
assert fuse(d, ave_width=3, rename_keys=False) == expected
assert fuse(d, ave_width=4.6, rename_keys=False) == expected
- expected = with_deps({
- "a": 1,
- "b1-b2-b3-b4-c1-c2-d1": (
- f,
- (f, (f, "a"), (f, "a")),
- (f, (f, "a"), (f, "a")),
- ),
- "b5-b6-b7-b8-c3-c4-d2": (
- f,
- (f, (f, "a"), (f, "a")),
- (f, (f, "a"), (f, "a")),
- ),
- "b10-b11-b12-b9-c5-c6-d3": (
- f,
- (f, (f, "a"), (f, "a")),
- (f, (f, "a"), (f, "a")),
- ),
- "b13-b14-b15-b16-c7-c8-d4": (
- f,
- (f, (f, "a"), (f, "a")),
- (f, (f, "a"), (f, "a")),
- ),
- "e1": (f, "d1", "d2"),
- "e2": (f, "d3", "d4"),
- "f": (f, "e1", "e2"),
- "d1": "b1-b2-b3-b4-c1-c2-d1",
- "d2": "b5-b6-b7-b8-c3-c4-d2",
- "d3": "b10-b11-b12-b9-c5-c6-d3",
- "d4": "b13-b14-b15-b16-c7-c8-d4",
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "b1-b2-b3-b4-c1-c2-d1": (
+ f,
+ (f, (f, "a"), (f, "a")),
+ (f, (f, "a"), (f, "a")),
+ ),
+ "b5-b6-b7-b8-c3-c4-d2": (
+ f,
+ (f, (f, "a"), (f, "a")),
+ (f, (f, "a"), (f, "a")),
+ ),
+ "b10-b11-b12-b9-c5-c6-d3": (
+ f,
+ (f, (f, "a"), (f, "a")),
+ (f, (f, "a"), (f, "a")),
+ ),
+ "b13-b14-b15-b16-c7-c8-d4": (
+ f,
+ (f, (f, "a"), (f, "a")),
+ (f, (f, "a"), (f, "a")),
+ ),
+ "e1": (f, "d1", "d2"),
+ "e2": (f, "d3", "d4"),
+ "f": (f, "e1", "e2"),
+ "d1": "b1-b2-b3-b4-c1-c2-d1",
+ "d2": "b5-b6-b7-b8-c3-c4-d2",
+ "d3": "b10-b11-b12-b9-c5-c6-d3",
+ "d4": "b13-b14-b15-b16-c7-c8-d4",
+ }
+ )
assert fuse(d, ave_width=3, rename_keys=True) == expected
assert fuse(d, ave_width=4.6, rename_keys=True) == expected
- expected = with_deps({
- "a": 1,
- "e1": (
- f,
- (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- ),
- "e2": (
- f,
- (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- ),
- "f": (f, "e1", "e2"),
- })
- assert fuse(d, ave_width=4.7, rename_keys=False) == expected
- assert fuse(d, ave_width=7.4, rename_keys=False) == expected
- expected = with_deps({
- "a": 1,
- "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e1": (
- f,
- (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- ),
- "b10-b11-b12-b13-b14-b15-b16-b9-c5-c6-c7-c8-d3-d4-e2": (
- f,
- (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
- ),
- "f": (f, "e1", "e2"),
- "e1": "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e1",
- "e2": "b10-b11-b12-b13-b14-b15-b16-b9-c5-c6-c7-c8-d3-d4-e2",
- })
- assert fuse(d, ave_width=4.7, rename_keys=True) == expected
- assert fuse(d, ave_width=7.4, rename_keys=True) == expected
- assert fuse(d, ave_width=7.5, rename_keys=False) == with_deps({
- "a": 1,
- "f": (
- f,
- (
+ expected = with_deps(
+ {
+ "a": 1,
+ "e1": (
f,
(f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
(f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
),
- (
+ "e2": (
f,
(f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
(f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
),
- ),
- })
- assert fuse(d, ave_width=7.5, rename_keys=True) == with_deps({
- "a": 1,
- "b1-b10-b11-b12-b13-b14-b15-b16-b2-b3-b4-b5-b6-b7-b8-b9-c1-c2-c3-c4-c5-c6-c7-c8-d1-d2-d3-d4-e1-e2-f": (
- f,
- (
+ "f": (f, "e1", "e2"),
+ }
+ )
+ assert fuse(d, ave_width=4.7, rename_keys=False) == expected
+ assert fuse(d, ave_width=7.4, rename_keys=False) == expected
+ expected = with_deps(
+ {
+ "a": 1,
+ "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e1": (
f,
(f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
(f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
),
- (
+ "b10-b11-b12-b13-b14-b15-b16-b9-c5-c6-c7-c8-d3-d4-e2": (
f,
(f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
(f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
),
- ),
- "f": "b1-b10-b11-b12-b13-b14-b15-b16-b2-b3-b4-b5-b6-b7-b8-b9-c1-c2-c3-c4-c5-c6-c7-c8-d1-d2-d3-d4-e1-e2-f",
- })
+ "f": (f, "e1", "e2"),
+ "e1": "b1-b2-b3-b4-b5-b6-b7-b8-c1-c2-c3-c4-d1-d2-e1",
+ "e2": "b10-b11-b12-b13-b14-b15-b16-b9-c5-c6-c7-c8-d3-d4-e2",
+ }
+ )
+ assert fuse(d, ave_width=4.7, rename_keys=True) == expected
+ assert fuse(d, ave_width=7.4, rename_keys=True) == expected
+ assert fuse(d, ave_width=7.5, rename_keys=False) == with_deps(
+ {
+ "a": 1,
+ "f": (
+ f,
+ (
+ f,
+ (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ ),
+ (
+ f,
+ (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ ),
+ ),
+ }
+ )
+ assert fuse(d, ave_width=7.5, rename_keys=True) == with_deps(
+ {
+ "a": 1,
+ "b1-b10-b11-b12-b13-b14-b15-b16-b2-b3-b4-b5-b6-b7-b8-b9-c1-c2-c3-c4-c5-c6-c7-c8-d1-d2-d3-d4-e1-e2-f": (
+ f,
+ (
+ f,
+ (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ ),
+ (
+ f,
+ (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ (f, (f, (f, "a"), (f, "a")), (f, (f, "a"), (f, "a"))),
+ ),
+ ),
+ "f": "b1-b10-b11-b12-b13-b14-b15-b16-b2-b3-b4-b5-b6-b7-b8-b9-c1-c2-c3-c4-c5-c6-c7-c8-d1-d2-d3-d4-e1-e2-f",
+ }
+ )
d = {"a": 1, "b": (f, "a")}
assert fuse(d, ave_width=1, rename_keys=False) == with_deps({"b": (f, 1)})
@@ -710,11 +743,9 @@ def f(*args):
d = {"a": 1, "b": (f, "a"), "c": (f, "a", "b"), "d": (f, "a", "c")}
assert fuse(d, ave_width=1, rename_keys=False) == with_deps({"a": 1, "d": (f, "a", (f, "a", (f, "a")))})
- assert fuse(d, ave_width=1, rename_keys=True) == with_deps({
- "a": 1,
- "b-c-d": (f, "a", (f, "a", (f, "a"))),
- "d": "b-c-d"
- })
+ assert fuse(d, ave_width=1, rename_keys=True) == with_deps(
+ {"a": 1, "b-c-d": (f, "a", (f, "a", (f, "a"))), "d": "b-c-d"}
+ )
d = {
"a": 1,
@@ -728,21 +759,25 @@ def f(*args):
expected = with_deps({"a": 1, "b2": (f, "a"), "e1": (f, (f, (f, (f, "a")))), "f": (f, "e1", "b2")})
assert fuse(d, ave_width=1, rename_keys=False) == expected
assert fuse(d, ave_width=1.9, rename_keys=False) == expected
- expected = with_deps({
- "a": 1,
- "b2": (f, "a"),
- "b1-c1-d1-e1": (f, (f, (f, (f, "a")))),
- "f": (f, "e1", "b2"),
- "e1": "b1-c1-d1-e1",
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "b2": (f, "a"),
+ "b1-c1-d1-e1": (f, (f, (f, (f, "a")))),
+ "f": (f, "e1", "b2"),
+ "e1": "b1-c1-d1-e1",
+ }
+ )
assert fuse(d, ave_width=1, rename_keys=True) == expected
assert fuse(d, ave_width=1.9, rename_keys=True) == expected
assert fuse(d, ave_width=2, rename_keys=False) == with_deps({"a": 1, "f": (f, (f, (f, (f, (f, "a")))), (f, "a"))})
- assert fuse(d, ave_width=2, rename_keys=True) == with_deps({
- "a": 1,
- "b1-b2-c1-d1-e1-f": (f, (f, (f, (f, (f, "a")))), (f, "a")),
- "f": "b1-b2-c1-d1-e1-f",
- })
+ assert fuse(d, ave_width=2, rename_keys=True) == with_deps(
+ {
+ "a": 1,
+ "b1-b2-c1-d1-e1-f": (f, (f, (f, (f, (f, "a")))), (f, "a")),
+ "f": "b1-b2-c1-d1-e1-f",
+ }
+ )
d = {
"a": 1,
@@ -753,37 +788,42 @@ def f(*args):
"e1": (f, "a", "d1"),
"f": (f, "a", "e1", "b2"),
}
- expected = with_deps({
- "a": 1,
- "b2": (f, "a"),
- "e1": (f, "a", (f, "a", (f, "a", (f, "a")))),
- "f": (f, "a", "e1", "b2"),
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "b2": (f, "a"),
+ "e1": (f, "a", (f, "a", (f, "a", (f, "a")))),
+ "f": (f, "a", "e1", "b2"),
+ }
+ )
assert fuse(d, ave_width=1, rename_keys=False) == expected
assert fuse(d, ave_width=1.9, rename_keys=False) == expected
- expected = with_deps({
- "a": 1,
- "b2": (f, "a"),
- "b1-c1-d1-e1": (f, "a", (f, "a", (f, "a", (f, "a")))),
- "f": (f, "a", "e1", "b2"),
- "e1": "b1-c1-d1-e1",
- })
+ expected = with_deps(
+ {
+ "a": 1,
+ "b2": (f, "a"),
+ "b1-c1-d1-e1": (f, "a", (f, "a", (f, "a", (f, "a")))),
+ "f": (f, "a", "e1", "b2"),
+ "e1": "b1-c1-d1-e1",
+ }
+ )
assert fuse(d, ave_width=1, rename_keys=True) == expected
assert fuse(d, ave_width=1.9, rename_keys=True) == expected
- assert fuse(d, ave_width=2, rename_keys=False) == with_deps({
- "a": 1,
- "f": (f, "a", (f, "a", (f, "a", (f, "a", (f, "a")))), (f, "a"))
- })
- assert fuse(d, ave_width=2, rename_keys=True) == with_deps({
- "a": 1,
- "b1-b2-c1-d1-e1-f": (
- f,
- "a",
- (f, "a", (f, "a", (f, "a", (f, "a")))),
- (f, "a"),
- ),
- "f": "b1-b2-c1-d1-e1-f",
- })
+ assert fuse(d, ave_width=2, rename_keys=False) == with_deps(
+ {"a": 1, "f": (f, "a", (f, "a", (f, "a", (f, "a", (f, "a")))), (f, "a"))}
+ )
+ assert fuse(d, ave_width=2, rename_keys=True) == with_deps(
+ {
+ "a": 1,
+ "b1-b2-c1-d1-e1-f": (
+ f,
+ "a",
+ (f, "a", (f, "a", (f, "a", (f, "a")))),
+ (f, "a"),
+ ),
+ "f": "b1-b2-c1-d1-e1-f",
+ }
+ )
d = {
"a": 1,
@@ -800,24 +840,28 @@ def f(*args):
"f": (f, "e"),
"g": (f, "f"),
}
- assert fuse(d, ave_width=1, rename_keys=False) == with_deps({
- "a": 1,
- "d1": (f, (f, (f, "a"))),
- "d2": (f, (f, (f, "a"))),
- "d3": (f, (f, (f, "a"))),
- "g": (f, (f, (f, "d1", "d2", "d3"))),
- })
- assert fuse(d, ave_width=1, rename_keys=True) == with_deps({
- "a": 1,
- "b1-c1-d1": (f, (f, (f, "a"))),
- "b2-c2-d2": (f, (f, (f, "a"))),
- "b3-c3-d3": (f, (f, (f, "a"))),
- "e-f-g": (f, (f, (f, "d1", "d2", "d3"))),
- "d1": "b1-c1-d1",
- "d2": "b2-c2-d2",
- "d3": "b3-c3-d3",
- "g": "e-f-g",
- })
+ assert fuse(d, ave_width=1, rename_keys=False) == with_deps(
+ {
+ "a": 1,
+ "d1": (f, (f, (f, "a"))),
+ "d2": (f, (f, (f, "a"))),
+ "d3": (f, (f, (f, "a"))),
+ "g": (f, (f, (f, "d1", "d2", "d3"))),
+ }
+ )
+ assert fuse(d, ave_width=1, rename_keys=True) == with_deps(
+ {
+ "a": 1,
+ "b1-c1-d1": (f, (f, (f, "a"))),
+ "b2-c2-d2": (f, (f, (f, "a"))),
+ "b3-c3-d3": (f, (f, (f, "a"))),
+ "e-f-g": (f, (f, (f, "d1", "d2", "d3"))),
+ "d1": "b1-c1-d1",
+ "d2": "b2-c2-d2",
+ "d3": "b3-c3-d3",
+ "g": "e-f-g",
+ }
+ )
d = {
"a": 1,
@@ -828,23 +872,22 @@ def f(*args):
"f": (f, "e"),
"g": (f, "d", "f"),
}
- assert fuse(d, ave_width=1, rename_keys=False) == with_deps({
- "b": (f, 1),
- "d": (f, "b", (f, "b")),
- "g": (f, "d", (f, (f, "d")))
- })
- assert fuse(d, ave_width=1, rename_keys=True) == with_deps({
- "a-b": (f, 1),
- "c-d": (f, "b", (f, "b")),
- "e-f-g": (f, "d", (f, (f, "d"))),
- "b": "a-b",
- "d": "c-d",
- "g": "e-f-g",
- })
+ assert fuse(d, ave_width=1, rename_keys=False) == with_deps(
+ {"b": (f, 1), "d": (f, "b", (f, "b")), "g": (f, "d", (f, (f, "d")))}
+ )
+ assert fuse(d, ave_width=1, rename_keys=True) == with_deps(
+ {
+ "a-b": (f, 1),
+ "c-d": (f, "b", (f, "b")),
+ "e-f-g": (f, "d", (f, (f, "d"))),
+ "b": "a-b",
+ "d": "c-d",
+ "g": "e-f-g",
+ }
+ )
def test_fuse_stressed():
-
def f(*args):
return args
@@ -917,7 +960,6 @@ def f(*args):
def test_fuse_reductions_multiple_input():
-
def f(*args):
return args
@@ -925,12 +967,9 @@ def f(*args):
assert fuse(d, ave_width=2, rename_keys=False) == with_deps({"c": (f, (f, 1, 2))})
assert fuse(d, ave_width=2, rename_keys=True) == with_deps({"a1-a2-b-c": (f, (f, 1, 2)), "c": "a1-a2-b-c"})
assert fuse(d, ave_width=1, rename_keys=False) == with_deps({"a1": 1, "a2": 2, "c": (f, (f, "a1", "a2"))})
- assert fuse(d, ave_width=1, rename_keys=True) == with_deps({
- "a1": 1,
- "a2": 2,
- "b-c": (f, (f, "a1", "a2")),
- "c": "b-c"
- })
+ assert fuse(d, ave_width=1, rename_keys=True) == with_deps(
+ {"a1": 1, "a2": 2, "b-c": (f, (f, "a1", "a2")), "c": "b-c"}
+ )
d = {
"a1": 1,
@@ -945,17 +984,17 @@ def f(*args):
assert fuse(d, ave_width=2.9, rename_keys=False) == expected
assert fuse(d, ave_width=1, rename_keys=True) == expected
assert fuse(d, ave_width=2.9, rename_keys=True) == expected
- assert fuse(d, ave_width=3, rename_keys=False) == with_deps({
- "a1": 1,
- "a2": 2,
- "c": (f, (f, "a1"), (f, "a1", "a2"), (f, "a2"))
- })
- assert fuse(d, ave_width=3, rename_keys=True) == with_deps({
- "a1": 1,
- "a2": 2,
- "b1-b2-b3-c": (f, (f, "a1"), (f, "a1", "a2"), (f, "a2")),
- "c": "b1-b2-b3-c",
- })
+ assert fuse(d, ave_width=3, rename_keys=False) == with_deps(
+ {"a1": 1, "a2": 2, "c": (f, (f, "a1"), (f, "a1", "a2"), (f, "a2"))}
+ )
+ assert fuse(d, ave_width=3, rename_keys=True) == with_deps(
+ {
+ "a1": 1,
+ "a2": 2,
+ "b1-b2-b3-c": (f, (f, "a1"), (f, "a1", "a2"), (f, "a2")),
+ "c": "b1-b2-b3-c",
+ }
+ )
d = {
"a1": 1,
@@ -968,22 +1007,26 @@ def f(*args):
}
assert fuse(d, ave_width=1, rename_keys=False) == with_deps(d)
assert fuse(d, ave_width=1, rename_keys=True) == with_deps(d)
- assert fuse(d, ave_width=2, rename_keys=False) == with_deps({
- "a1": 1,
- "a2": 2,
- "b2": (f, "a1", "a2"),
- "c1": (f, (f, "a1"), "b2"),
- "c2": (f, "b2", (f, "a2")),
- })
- assert fuse(d, ave_width=2, rename_keys=True) == with_deps({
- "a1": 1,
- "a2": 2,
- "b2": (f, "a1", "a2"),
- "b1-c1": (f, (f, "a1"), "b2"),
- "b3-c2": (f, "b2", (f, "a2")),
- "c1": "b1-c1",
- "c2": "b3-c2",
- })
+ assert fuse(d, ave_width=2, rename_keys=False) == with_deps(
+ {
+ "a1": 1,
+ "a2": 2,
+ "b2": (f, "a1", "a2"),
+ "c1": (f, (f, "a1"), "b2"),
+ "c2": (f, "b2", (f, "a2")),
+ }
+ )
+ assert fuse(d, ave_width=2, rename_keys=True) == with_deps(
+ {
+ "a1": 1,
+ "a2": 2,
+ "b2": (f, "a1", "a2"),
+ "b1-c1": (f, (f, "a1"), "b2"),
+ "b3-c2": (f, "b2", (f, "a2")),
+ "c1": "b1-c1",
+ "c2": "b3-c2",
+ }
+ )
d = {
"a1": 1,
@@ -1000,19 +1043,23 @@ def f(*args):
# A more aggressive heuristic could do this at `ave_width=2`. Perhaps
# we can improve this. Nevertheless, this is behaving as intended.
- assert fuse(d, ave_width=3, rename_keys=False) == with_deps({
- "a1": 1,
- "a2": 2,
- "b2": (f, "a1", "a2"),
- "d": (f, (f, (f, "a1"), "b2"), (f, "b2", (f, "a2"))),
- })
- assert fuse(d, ave_width=3, rename_keys=True) == with_deps({
- "a1": 1,
- "a2": 2,
- "b2": (f, "a1", "a2"),
- "b1-b3-c1-c2-d": (f, (f, (f, "a1"), "b2"), (f, "b2", (f, "a2"))),
- "d": "b1-b3-c1-c2-d",
- })
+ assert fuse(d, ave_width=3, rename_keys=False) == with_deps(
+ {
+ "a1": 1,
+ "a2": 2,
+ "b2": (f, "a1", "a2"),
+ "d": (f, (f, (f, "a1"), "b2"), (f, "b2", (f, "a2"))),
+ }
+ )
+ assert fuse(d, ave_width=3, rename_keys=True) == with_deps(
+ {
+ "a1": 1,
+ "a2": 2,
+ "b2": (f, "a1", "a2"),
+ "b1-b3-c1-c2-d": (f, (f, (f, "a1"), "b2"), (f, "b2", (f, "a2"))),
+ "d": "b1-b3-c1-c2-d",
+ }
+ )
def func_with_kwargs(a, b, c=2):
@@ -1028,20 +1075,13 @@ def test_SubgraphCallable():
apply,
partial_by_order,
["in2"],
- {
- "function": func_with_kwargs,
- "other": [(1, 20)],
- "c": 4
- },
+ {"function": func_with_kwargs, "other": [(1, 20)], "c": 4},
),
"c": (
apply,
partial_by_order,
["in2", "in1"],
- {
- "function": func_with_kwargs,
- "other": [(1, 20)]
- },
+ {"function": func_with_kwargs, "other": [(1, 20)]},
),
"d": (inc, "a"),
"e": (add, "c", "d"),
@@ -1105,54 +1145,60 @@ def test_fuse_subgraphs():
}
res = fuse(dsk, "inc-6", fuse_subgraphs=True)
- sol = with_deps({
- "inc-6": "add-inc-x-1",
- "add-inc-x-1": (
- SubgraphCallable(
- {
- "x-1": 1,
- "add-1": (add, "x-1", (inc, (inc, "x-1"))),
- "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))),
- },
- "inc-6",
- (),
+ sol = with_deps(
+ {
+ "inc-6": "add-inc-x-1",
+ "add-inc-x-1": (
+ SubgraphCallable(
+ {
+ "x-1": 1,
+ "add-1": (add, "x-1", (inc, (inc, "x-1"))),
+ "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))),
+ },
+ "inc-6",
+ (),
+ ),
),
- ),
- })
+ }
+ )
assert res == sol
res = fuse(dsk, "inc-6", fuse_subgraphs=True, rename_keys=False)
- sol = with_deps({
- "inc-6": (
- SubgraphCallable(
- {
- "x-1": 1,
- "add-1": (add, "x-1", (inc, (inc, "x-1"))),
- "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))),
- },
- "inc-6",
- (),
- ),
- )
- })
+ sol = with_deps(
+ {
+ "inc-6": (
+ SubgraphCallable(
+ {
+ "x-1": 1,
+ "add-1": (add, "x-1", (inc, (inc, "x-1"))),
+ "inc-6": (inc, (inc, (add, "add-1", (inc, (inc, "add-1"))))),
+ },
+ "inc-6",
+ (),
+ ),
+ )
+ }
+ )
assert res == sol
res = fuse(dsk, "add-2", fuse_subgraphs=True)
- sol = with_deps({
- "add-inc-x-1": (
- SubgraphCallable(
- {
- "x-1": 1,
- "add-1": (add, "x-1", (inc, (inc, "x-1"))),
- "add-2": (add, "add-1", (inc, (inc, "add-1"))),
- },
- "add-2",
- (),
+ sol = with_deps(
+ {
+ "add-inc-x-1": (
+ SubgraphCallable(
+ {
+ "x-1": 1,
+ "add-1": (add, "x-1", (inc, (inc, "x-1"))),
+ "add-2": (add, "add-1", (inc, (inc, "add-1"))),
+ },
+ "add-2",
+ (),
+ ),
),
- ),
- "add-2": "add-inc-x-1",
- "inc-6": (inc, (inc, "add-2")),
- })
+ "add-2": "add-inc-x-1",
+ "inc-6": (inc, (inc, "add-2")),
+ }
+ )
assert res == sol
res = fuse(dsk, "inc-2", fuse_subgraphs=True)
@@ -1160,24 +1206,27 @@ def test_fuse_subgraphs():
sols = []
for inkeys in itertools.permutations(("x-1", "inc-2")):
sols.append(
- with_deps({
- "x-1": 1,
- "inc-2": (inc, (inc, "x-1")),
- "inc-6": "inc-add-1",
- "inc-add-1": (
- SubgraphCallable(
- {
- "add-1": (add, "x-1", "inc-2"),
- "inc-6": (
- inc,
- (inc, (add, "add-1", (inc, (inc, "add-1")))),
- ),
- },
- "inc-6",
- inkeys,
- ),
- ) + inkeys,
- })
+ with_deps(
+ {
+ "x-1": 1,
+ "inc-2": (inc, (inc, "x-1")),
+ "inc-6": "inc-add-1",
+ "inc-add-1": (
+ SubgraphCallable(
+ {
+ "add-1": (add, "x-1", "inc-2"),
+ "inc-6": (
+ inc,
+ (inc, (add, "add-1", (inc, (inc, "add-1")))),
+ ),
+ },
+ "inc-6",
+ inkeys,
+ ),
+ )
+ + inkeys,
+ }
+ )
)
assert res in sols
@@ -1186,22 +1235,25 @@ def test_fuse_subgraphs():
sols = []
for inkeys in itertools.permutations(("x-1", "inc-2")):
sols.append(
- with_deps({
- "x-1": 1,
- "inc-2": (inc, (inc, "x-1")),
- "inc-add-1": (
- SubgraphCallable(
- {
- "add-1": (add, "x-1", "inc-2"),
- "add-2": (add, "add-1", (inc, (inc, "add-1"))),
- },
- "add-2",
- inkeys,
- ),
- ) + inkeys,
- "add-2": "inc-add-1",
- "inc-6": (inc, (inc, "add-2")),
- })
+ with_deps(
+ {
+ "x-1": 1,
+ "inc-2": (inc, (inc, "x-1")),
+ "inc-add-1": (
+ SubgraphCallable(
+ {
+ "add-1": (add, "x-1", "inc-2"),
+ "add-2": (add, "add-1", (inc, (inc, "add-1"))),
+ },
+ "add-2",
+ inkeys,
+ ),
+ )
+ + inkeys,
+ "add-2": "inc-add-1",
+ "inc-6": (inc, (inc, "add-2")),
+ }
+ )
)
assert res in sols
@@ -1217,31 +1269,30 @@ def test_fuse_subgraphs_linear_chains_of_duplicate_deps():
}
res = fuse(dsk, "add-5", fuse_subgraphs=True)
- sol = with_deps({
- "add-x-1": (
- SubgraphCallable(
- {
- "x-1": 1,
- "add-1": (add, "x-1", "x-1"),
- "add-2": (add, "add-1", "add-1"),
- "add-3": (add, "add-2", "add-2"),
- "add-4": (add, "add-3", "add-3"),
- "add-5": (add, "add-4", "add-4"),
- },
- "add-5",
- (),
+ sol = with_deps(
+ {
+ "add-x-1": (
+ SubgraphCallable(
+ {
+ "x-1": 1,
+ "add-1": (add, "x-1", "x-1"),
+ "add-2": (add, "add-1", "add-1"),
+ "add-3": (add, "add-2", "add-2"),
+ "add-4": (add, "add-3", "add-3"),
+ "add-5": (add, "add-4", "add-4"),
+ },
+ "add-5",
+ (),
+ ),
),
- ),
- "add-5": "add-x-1",
- })
+ "add-5": "add-x-1",
+ }
+ )
assert res == sol
def test_dont_fuse_numpy_arrays():
- """
- Some types should stay in the graph bare
- This helps with things like serialization
- """
+ """Some types should stay in the graph bare This helps with things like serialization."""
np = pytest.importorskip("numpy")
dsk = {"x": np.arange(5), "y": (inc, "x")}
diff --git a/tests/core/serve/test_dag/test_order.py b/tests/core/serve/test_dag/test_order.py
index c332eb4860..50cfebdb67 100644
--- a/tests/core/serve/test_dag/test_order.py
+++ b/tests/core/serve/test_dag/test_order.py
@@ -20,14 +20,14 @@ def f(*args):
def test_ordering_keeps_groups_together(abcde):
a, b, c, d, e = abcde
- d = dict(((a, i), (f, )) for i in range(4))
+ d = {(a, i): (f,) for i in range(4)}
d.update({(b, 0): (f, (a, 0), (a, 1)), (b, 1): (f, (a, 2), (a, 3))})
o = order(d)
assert abs(o[(a, 0)] - o[(a, 1)]) == 1
assert abs(o[(a, 2)] - o[(a, 3)]) == 1
- d = dict(((a, i), (f, )) for i in range(4))
+ d = {(a, i): (f,) for i in range(4)}
d.update({(b, 0): (f, (a, 0), (a, 2)), (b, 1): (f, (a, 1), (a, 3))})
o = order(d)
@@ -46,8 +46,8 @@ def test_avoid_broker_nodes(abcde):
"""
a, b, c, d, e = abcde
dsk = {
- (a, 0): (f, ),
- (a, 1): (f, ),
+ (a, 0): (f,),
+ (a, 1): (f,),
(b, 0): (f, (a, 0)),
(b, 1): (f, (a, 1)),
(b, 2): (f, (a, 1)),
@@ -57,8 +57,8 @@ def test_avoid_broker_nodes(abcde):
# Switch name of 0, 1 to ensure that this isn't due to string comparison
dsk = {
- (a, 1): (f, ),
- (a, 0): (f, ),
+ (a, 1): (f,),
+ (a, 0): (f,),
(b, 0): (f, (a, 1)),
(b, 1): (f, (a, 0)),
(b, 2): (f, (a, 0)),
@@ -68,8 +68,8 @@ def test_avoid_broker_nodes(abcde):
# Switch name of 0, 1 for "b"s too
dsk = {
- (a, 0): (f, ),
- (a, 1): (f, ),
+ (a, 0): (f,),
+ (a, 1): (f,),
(b, 1): (f, (a, 0)),
(b, 0): (f, (a, 1)),
(b, 2): (f, (a, 1)),
@@ -161,10 +161,10 @@ def test_avoid_upwards_branching_complex(abcde):
(a, 2): (f, (a, 3)),
(a, 3): (f, (b, 1), (c, 1)),
(b, 1): (f, (b, 2)),
- (b, 2): (f, ),
+ (b, 2): (f,),
(c, 1): (f, (c, 2)),
(c, 2): (f, (c, 3)),
- (c, 3): (f, ),
+ (c, 3): (f,),
(d, 1): (f, (c, 1)),
(d, 2): (f, (d, 1)),
(d, 3): (f, (d, 1)),
@@ -220,7 +220,7 @@ def test_prefer_deep(abcde):
def test_stacklimit(abcde):
- dsk = dict(("x%s" % (i + 1), (inc, "x%s" % i)) for i in range(10000))
+ dsk = {"x%s" % (i + 1): (inc, "x%s" % i) for i in range(10000)}
dependencies, dependents = get_deps(dsk)
ndependencies(dependencies, dependents)
@@ -261,7 +261,7 @@ def test_prefer_short_dependents(abcde):
during the long computations.
"""
a, b, c, d, e = abcde
- dsk = {c: (f, ), d: (f, c), e: (f, c), b: (f, c), a: (f, b)}
+ dsk = {c: (f,), d: (f, c), e: (f, c), b: (f, c), a: (f, b)}
o = order(dsk)
assert o[d] < o[b]
@@ -280,24 +280,23 @@ def test_run_smaller_sections(abcde):
Prefer to run acb first because then we can get that out of the way
"""
a, b, c, d, e = abcde
- aa, bb, cc, dd = [x * 2 for x in [a, b, c, d]]
+ aa, bb, cc, dd = (x * 2 for x in [a, b, c, d])
expected = [a, c, b, e, d, cc, bb, aa, dd]
log = []
def f(x):
-
def _(*args):
log.append(x)
return _
dsk = {
- a: (f(a), ),
- c: (f(c), ),
- e: (f(e), ),
- cc: (f(cc), ),
+ a: (f(a),),
+ c: (f(c),),
+ e: (f(e),),
+ cc: (f(cc),),
b: (f(b), a, c),
d: (f(d), c, e),
bb: (f(bb), cc),
@@ -326,29 +325,28 @@ def test_local_parents_of_reduction(abcde):
Prefer to finish a1 stack before proceeding to b2
"""
a, b, c, d, e = abcde
- a1, a2, a3 = [a + i for i in "123"]
- b1, b2, b3 = [b + i for i in "123"]
- c1, c2, c3 = [c + i for i in "123"]
+ a1, a2, a3 = (a + i for i in "123")
+ b1, b2, b3 = (b + i for i in "123")
+ c1, c2, c3 = (c + i for i in "123")
expected = [a3, a2, a1, b3, b2, b1, c3, c2, c1]
log = []
def f(x):
-
def _(*args):
log.append(x)
return _
dsk = {
- a3: (f(a3), ),
+ a3: (f(a3),),
a2: (f(a2), a3),
a1: (f(a1), a2),
- b3: (f(b3), ),
+ b3: (f(b3),),
b2: (f(b2), b3, a2),
b1: (f(b1), b2),
- c3: (f(c3), ),
+ c3: (f(c3),),
c2: (f(c2), c3, b2),
c1: (f(c1), c2),
}
@@ -370,14 +368,14 @@ def test_nearest_neighbor(abcde):
This is difficult because all groups are connected.
"""
a, b, c, _, _ = abcde
- a1, a2, a3, a4, a5, a6, a7, a8, a9 = [a + i for i in "123456789"]
- b1, b2, b3, b4 = [b + i for i in "1234"]
+ a1, a2, a3, a4, a5, a6, a7, a8, a9 = (a + i for i in "123456789")
+ b1, b2, b3, b4 = (b + i for i in "1234")
dsk = {
- b1: (f, ),
- b2: (f, ),
- b3: (f, ),
- b4: (f, ),
+ b1: (f,),
+ b2: (f,),
+ b3: (f,),
+ b4: (f,),
a1: (f, b1),
a2: (f, b1),
a3: (f, b1, b2),
@@ -397,15 +395,15 @@ def test_nearest_neighbor(abcde):
def test_string_ordering():
- """ Prefer ordering tasks by name first """
- dsk = {("a", 1): (f, ), ("a", 2): (f, ), ("a", 3): (f, )}
+ """Prefer ordering tasks by name first."""
+ dsk = {("a", 1): (f,), ("a", 2): (f,), ("a", 3): (f,)}
o = order(dsk)
assert o == {("a", 1): 0, ("a", 2): 1, ("a", 3): 2}
def test_string_ordering_dependents():
- """ Prefer ordering tasks by name first even when in dependencies """
- dsk = {("a", 1): (f, "b"), ("a", 2): (f, "b"), ("a", 3): (f, "b"), "b": (f, )}
+ """Prefer ordering tasks by name first even when in dependencies."""
+ dsk = {("a", 1): (f, "b"), ("a", 2): (f, "b"), ("a", 3): (f, "b"), "b": (f,)}
o = order(dsk)
assert o == {"b": 0, ("a", 1): 1, ("a", 2): 2, ("a", 3): 3}
@@ -502,19 +500,19 @@ def test_map_overlap(abcde):
"""
a, b, c, d, e = abcde
dsk = {
- (e, 1): (f, ),
+ (e, 1): (f,),
(d, 1): (f, (e, 1)),
(c, 1): (f, (d, 1)),
(b, 1): (f, (c, 1), (c, 2)),
- (d, 2): (f, ),
+ (d, 2): (f,),
(c, 2): (f, (d, 1), (d, 2), (d, 3)),
- (e, 3): (f, ),
+ (e, 3): (f,),
(d, 3): (f, (e, 3)),
(c, 3): (f, (d, 3)),
(b, 3): (f, (c, 2), (c, 3), (c, 4)),
- (d, 4): (f, ),
+ (d, 4): (f,),
(c, 4): (f, (d, 3), (d, 4), (d, 5)),
- (e, 5): (f, ),
+ (e, 5): (f,),
(d, 5): (f, (e, 5)),
(c, 5): (f, (d, 5)),
(b, 5): (f, (c, 4), (c, 5)),
@@ -526,22 +524,22 @@ def test_map_overlap(abcde):
def test_use_structure_not_keys(abcde):
- """See https://github.com/dask/dask/issues/5584#issuecomment-554963958
+ """See https://github.com/dask/dask/issues/5584#issuecomment-554963958.
We were using key names to infer structure, which could result in funny behavior.
"""
a, b, _, _, _ = abcde
dsk = {
- (a, 0): (f, ),
- (a, 1): (f, ),
- (a, 2): (f, ),
- (a, 3): (f, ),
- (a, 4): (f, ),
- (a, 5): (f, ),
- (a, 6): (f, ),
- (a, 7): (f, ),
- (a, 8): (f, ),
- (a, 9): (f, ),
+ (a, 0): (f,),
+ (a, 1): (f,),
+ (a, 2): (f,),
+ (a, 3): (f,),
+ (a, 4): (f,),
+ (a, 5): (f,),
+ (a, 6): (f,),
+ (a, 7): (f,),
+ (a, 8): (f,),
+ (a, 9): (f,),
(b, 5): (f, (a, 2)),
(b, 7): (f, (a, 0), (a, 2)),
(b, 9): (f, (a, 7), (a, 0), (a, 2)),
@@ -566,7 +564,7 @@ def test_use_structure_not_keys(abcde):
def test_dont_run_all_dependents_too_early(abcde):
- """ From https://github.com/dask/dask-ml/issues/206#issuecomment-395873372 """
+ """From https://github.com/dask/dask-ml/issues/206#issuecomment-395873372."""
a, b, c, d, e = abcde
depth = 10
dsk = {(a, 0): 0, (b, 0): 1, (c, 0): 2, (d, 0): (f, (a, 0), (b, 0), (c, 0))}
@@ -581,13 +579,10 @@ def test_dont_run_all_dependents_too_early(abcde):
def test_many_branches_use_ndependencies(abcde):
- """From https://github.com/dask/dask/pull/5646#issuecomment-562700533
-
- Sometimes we need larger or wider DAGs to test behavior. This test
- ensures we choose the branch with more work twice in successtion.
- This is important, because ``order`` may search along dependencies
- and then along dependents.
+ """From https://github.com/dask/dask/pull/5646#issuecomment-562700533.
+ Sometimes we need larger or wider DAGs to test behavior. This test ensures we choose the branch with more work
+ twice in successtion. This is important, because ``order`` may search along dependencies and then along dependents.
"""
a, b, c, d, e = abcde
dd = d + d
@@ -694,32 +689,35 @@ def test_switching_dependents(abcde):
def test_order_with_equal_dependents(abcde):
- """From https://github.com/dask/dask/issues/5859#issuecomment-608422198
+ """From https://github.com/dask/dask/issues/5859#issuecomment-608422198.
See the visualization of `(maxima, argmax)` example from the above comment.
This DAG has enough structure to exercise more parts of `order`
-
"""
a, b, c, d, e = abcde
dsk = {}
abc = [a, b, c, d]
for x in abc:
- dsk.update({
- (x, 0): 0,
- (x, 1): (f, (x, 0)),
- (x, 2, 0): (f, (x, 0)),
- (x, 2, 1): (f, (x, 1)),
- })
+ dsk.update(
+ {
+ (x, 0): 0,
+ (x, 1): (f, (x, 0)),
+ (x, 2, 0): (f, (x, 0)),
+ (x, 2, 1): (f, (x, 1)),
+ }
+ )
for i, y in enumerate(abc):
- dsk.update({
- (x, 3, i): (f, (x, 2, 0), (y, 2, 1)), # cross x and y
- (x, 4, i): (f, (x, 3, i)),
- (x, 5, i, 0): (f, (x, 4, i)),
- (x, 5, i, 1): (f, (x, 4, i)),
- (x, 6, i, 0): (f, (x, 5, i, 0)),
- (x, 6, i, 1): (f, (x, 5, i, 1)),
- })
+ dsk.update(
+ {
+ (x, 3, i): (f, (x, 2, 0), (y, 2, 1)), # cross x and y
+ (x, 4, i): (f, (x, 3, i)),
+ (x, 5, i, 0): (f, (x, 4, i)),
+ (x, 5, i, 1): (f, (x, 4, i)),
+ (x, 6, i, 0): (f, (x, 5, i, 0)),
+ (x, 6, i, 1): (f, (x, 5, i, 1)),
+ }
+ )
o = order(dsk)
total = 0
for x in abc:
diff --git a/tests/core/serve/test_dag/test_rewrite.py b/tests/core/serve/test_dag/test_rewrite.py
index 64055f7211..97fbaf25f3 100644
--- a/tests/core/serve/test_dag/test_rewrite.py
+++ b/tests/core/serve/test_dag/test_rewrite.py
@@ -21,7 +21,7 @@ def test_head():
def test_args():
- assert args((inc, 1)) == (1, )
+ assert args((inc, 1)) == (1,)
assert args((add, 1, 2)) == (1, 2)
assert args(1) == ()
assert args([1, 2, 3]) == [1, 2, 3]
@@ -65,16 +65,16 @@ def repl_list(sd):
return (list, x)
-rule6 = RewriteRule((list, "x"), repl_list, ("x", ))
+rule6 = RewriteRule((list, "x"), repl_list, ("x",))
def test_RewriteRule():
# Test extraneous vars are removed, varlist is correct
- assert rule1.vars == ("a", )
+ assert rule1.vars == ("a",)
assert rule1._varlist == ["a"]
- assert rule2.vars == ("a", )
+ assert rule2.vars == ("a",)
assert rule2._varlist == ["a", "a"]
- assert rule3.vars == ("a", )
+ assert rule3.vars == ("a",)
assert rule3._varlist == ["a", "a"]
assert rule4.vars == ("a", "b")
assert rule4._varlist == ["b", "a"]
@@ -97,32 +97,13 @@ def test_RuleSet():
{
add: (
{
- VAR: ({
- VAR: ({}, [1]),
- 1: ({}, [0])
- }, []),
- inc: ({
- VAR: ({
- inc: ({
- VAR: ({}, [2, 3])
- }, [])
- }, [])
- }, []),
+ VAR: ({VAR: ({}, [1]), 1: ({}, [0])}, []),
+ inc: ({VAR: ({inc: ({VAR: ({}, [2, 3])}, [])}, [])}, []),
},
[],
),
- list: ({
- VAR: ({}, [5])
- }, []),
- sum: ({
- list: ({
- VAR: ({
- VAR: ({
- VAR: ({}, [4])
- }, [])
- }, [])
- }, [])
- }, []),
+ list: ({VAR: ({}, [5])}, []),
+ sum: ({list: ({VAR: ({VAR: ({VAR: ({}, [4])}, [])}, [])}, [])}, []),
},
[],
)
diff --git a/tests/core/serve/test_dag/test_task.py b/tests/core/serve/test_dag/test_task.py
index cd7479f5d5..260bc72d0b 100644
--- a/tests/core/serve/test_dag/test_task.py
+++ b/tests/core/serve/test_dag/test_task.py
@@ -52,7 +52,7 @@ def test_get_dependencies_nested():
def test_get_dependencies_empty():
- dsk = {"x": (inc, )}
+ dsk = {"x": (inc,)}
assert get_dependencies(dsk, "x") == set()
assert get_dependencies(dsk, "x", as_list=True) == []
@@ -181,7 +181,6 @@ class MyException(Exception):
pass
class F:
-
def __eq__(self, other):
raise MyException()
@@ -200,9 +199,7 @@ def test_subs_with_surprisingly_friendly_eq():
def test_subs_unexpected_hashable_key():
-
class UnexpectedButHashable:
-
def __init__(self):
self.name = "a"
diff --git a/tests/core/serve/test_dag/test_utils.py b/tests/core/serve/test_dag/test_utils.py
index 17315b5f29..7ce379d006 100644
--- a/tests/core/serve/test_dag/test_utils.py
+++ b/tests/core/serve/test_dag/test_utils.py
@@ -12,7 +12,6 @@
def test_funcname_long():
-
def a_long_function_name_11111111111111111111111111111111111111111111111():
pass
@@ -23,7 +22,6 @@ def a_long_function_name_11111111111111111111111111111111111111111111111():
@pytest.mark.skipif(not _CYTOOLZ_AVAILABLE, reason="the library `cytoolz` is not installed.")
def test_funcname_cytoolz():
-
@curry
def foo(a, b, c):
pass
@@ -45,14 +43,13 @@ def test_partial_by_order():
def test_funcname():
assert funcname(np.floor_divide) == "floor_divide"
assert funcname(partial(bool)) == "bool"
- assert (funcname(operator.methodcaller("__getitem__")) == "operator.methodcaller('__getitem__')")
+ assert funcname(operator.methodcaller("__getitem__")) == "operator.methodcaller('__getitem__')"
assert funcname(lambda x: x) == "lambda"
def test_numpy_vectorize_funcname():
-
def myfunc(a, b):
- "Return a-b if a>b, otherwise return a+b"
+ """Return a-b if a>b, otherwise return a+b."""
if a > b:
return a - b
return a + b
diff --git a/tests/core/serve/test_gridbase_validations.py b/tests/core/serve/test_gridbase_validations.py
index 29c61aa688..17e094dd83 100644
--- a/tests/core/serve/test_gridbase_validations.py
+++ b/tests/core/serve/test_gridbase_validations.py
@@ -12,7 +12,6 @@ def test_metaclass_raises_if_expose_decorator_not_applied_to_method():
with pytest.raises(SyntaxError, match=r"expose.* decorator"):
class FailedNoExposed(ModelComponent):
-
def __init__(self, model):
pass
@@ -23,7 +22,6 @@ def test_metaclass_raises_if_more_than_one_expose_decorator_applied():
with pytest.raises(SyntaxError, match=r"decorator must be applied to one"):
class FailedTwoExposed(ModelComponent):
-
def __init__(self, model):
pass
@@ -44,7 +42,6 @@ def test_metaclass_raises_if_first_arg_in_init_is_not_model():
with pytest.raises(SyntaxError, match="__init__ must set 'model' as first"):
class FailedModelArg(ModelComponent):
-
def __init__(self, foo):
pass
@@ -60,7 +57,6 @@ def test_metaclass_raises_if_second_arg_is_not_config():
with pytest.raises(SyntaxError, match="__init__ can only set 'config'"):
class FailedConfig(ModelComponent):
-
def __init__(self, model, OTHER):
pass
@@ -76,7 +72,6 @@ def test_metaclass_raises_if_random_parameters_in_init():
with pytest.raises(SyntaxError, match="__init__ can only have 1 or 2 parameters"):
class FailedInit(ModelComponent):
-
def __init__(self, model, config, FOO):
pass
@@ -93,7 +88,6 @@ def test_metaclass_raises_uses_restricted_method_name():
with pytest.raises(TypeError, match="bound methods/attrs named"):
class FailedMethod_Inputs(ModelComponent):
-
def __init__(self, model):
pass
@@ -109,7 +103,6 @@ def inputs(self):
with pytest.raises(TypeError, match="bound methods/attrs named"):
class FailedMethod_Outputs(ModelComponent):
-
def __init__(self, model):
pass
@@ -125,7 +118,6 @@ def outputs(self):
with pytest.raises(TypeError, match="bound methods/attrs named"):
class FailedMethod_Name(ModelComponent):
-
def __init__(self, model):
pass
@@ -136,11 +128,12 @@ def predict(param):
@property
def uid(self):
- return f'{self.uid}_SHOULD_NOT_RETURN'
+ return f"{self.uid}_SHOULD_NOT_RETURN"
# Ensure that if we add more restricted names in the future,
# there is a test for them as well.
from flash.core.serve.component import _FLASH_SERVE_RESERVED_NAMES
+
assert set(_FLASH_SERVE_RESERVED_NAMES).difference({"inputs", "outputs", "uid"}) == set()
@@ -149,7 +142,6 @@ def test_metaclass_raises_if_argument_values_of_expose_arent_subclasses_of_baset
with pytest.raises(TypeError, match="must be subclass of"):
class FailedExposedDecoratorInputs(ModelComponent):
-
def __init__(self, model):
self.model = model
@@ -162,7 +154,6 @@ def predict(param):
with pytest.raises(TypeError, match="must be subclass of"):
class FailedExposedDecoratorOutputs(ModelComponent):
-
def __init__(self, model):
self.model = model
@@ -175,7 +166,6 @@ def predict(param):
with pytest.raises(TypeError, match="must be subclass of"):
class FailedExposedDecoratorClass(ModelComponent):
-
def __init__(self, model):
self.model = model
@@ -191,13 +181,12 @@ def test_ModelComponent_raises_if_exposed_input_keys_differ_from_decorated_metho
):
"""This occurs when the instance is being initialized.
- This is noted because it differes from some of the other metaclass validations
- which will raise an exception at class defiition time.
+ This is noted because it differes from some of the other metaclass validations which will raise an exception at
+ class defiition time.
"""
from tests.core.serve.models import ClassificationInference
class FailedExposedDecorator(ModelComponent):
-
def __init__(self, model):
self.model = model
@@ -215,12 +204,11 @@ def predict(self, param):
def test_ModelComponent_raises_if_config_is_empty_dict(lightning_squeezenet1_1_obj):
"""This occurs when the instance is being initialized.
- This is noted because it differes from some of the other metaclass validations
- which will raise an exception at class defiition time.
+ This is noted because it differes from some of the other metaclass validations which will raise an exception at
+ class defiition time.
"""
class ConfigComponent(ModelComponent):
-
def __init__(self, model, config):
pass
@@ -236,12 +224,11 @@ def predict(self, param):
def test_ModelComponent_raises_if_model_is_empty_iterable():
"""This occurs when the instance is being initialized.
- This is noted because it differes from some of the other metaclass validations
- which will raise an exception at class defiition time.
+ This is noted because it differes from some of the other metaclass validations which will raise an exception at
+ class defiition time.
"""
class ConfigComponent(ModelComponent):
-
def __init__(self, model):
pass
diff --git a/tests/core/serve/test_integration.py b/tests/core/serve/test_integration.py
index 2d3cebef27..4efafb548c 100644
--- a/tests/core/serve/test_integration.py
+++ b/tests/core/serve/test_integration.py
@@ -89,35 +89,21 @@ def test_serving_single_component_and_endpoint_no_composition(session_global_dat
assert meta.json() == {
"definitions": {
"Ep_Ep_In_Image": {
- "properties": {
- "data": {
- "title": "Data",
- "type": "string"
- }
- },
+ "properties": {"data": {"title": "Data", "type": "string"}},
"required": ["data"],
"title": "Ep_Ep_In_Image",
"type": "object",
},
"Ep_Payload": {
- "properties": {
- "ep_in_image": {
- "$ref": "#/definitions/Ep_Ep_In_Image"
- }
- },
+ "properties": {"ep_in_image": {"$ref": "#/definitions/Ep_Ep_In_Image"}},
"required": ["ep_in_image"],
"title": "Ep_Payload",
"type": "object",
},
},
"properties": {
- "payload": {
- "$ref": "#/definitions/Ep_Payload"
- },
- "session": {
- "title": "Session",
- "type": "string"
- },
+ "payload": {"$ref": "#/definitions/Ep_Payload"},
+ "session": {"title": "Session", "type": "string"},
},
"required": ["payload"],
"title": "Ep_RequestModel",
@@ -134,9 +120,7 @@ def test_serving_single_component_and_endpoint_no_composition(session_global_dat
assert "result" in success.json()
expected = {
"session": "UUID",
- "result": {
- "ep_out_prediction": "goldfish, Carassius auratus"
- },
+ "result": {"ep_out_prediction": "goldfish, Carassius auratus"},
}
assert expected == success.json()
@@ -209,26 +193,15 @@ def test_serving_composed(session_global_datadir, lightning_squeezenet1_1_obj):
body = {
"session": "UUID",
"payload": {
- "image": {
- "data": imgstr
- },
- "section": {
- "num": 10
- },
- "isle": {
- "num": 4
- },
- "row": {
- "num": 53
- },
+ "image": {"data": imgstr},
+ "section": {"num": 10},
+ "isle": {"num": 4},
+ "row": {"num": 53},
},
}
success = tc.post("http://127.0.0.1:8000/predict_seat", json=body)
assert success.json() == {
- "result": {
- "seat_number": 4799680,
- "team": "buffalo bills, the ralph"
- },
+ "result": {"seat_number": 4799680, "team": "buffalo bills, the ralph"},
"session": "UUID",
}
resp = tc.get("http://127.0.0.1:8000/predict_seat/dag")
@@ -295,26 +268,15 @@ def test_composed_does_not_eliminate_endpoint_serialization(session_global_datad
body = {
"session": "UUID",
"payload": {
- "image": {
- "data": imgstr
- },
- "section": {
- "num": 10
- },
- "isle": {
- "num": 4
- },
- "row": {
- "num": 53
- },
+ "image": {"data": imgstr},
+ "section": {"num": 10},
+ "isle": {"num": 4},
+ "row": {"num": 53},
},
}
success = tc.post("http://127.0.0.1:8000/predict_seat", json=body)
assert success.json() == {
- "result": {
- "seat_number_out": 4799680,
- "team_out": "buffalo bills, the ralph"
- },
+ "result": {"seat_number_out": 4799680, "team_out": "buffalo bills, the ralph"},
"session": "UUID",
}
resp = tc.get("http://127.0.0.1:8000/predict_seat/dag")
@@ -339,10 +301,7 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ
"section": seat_comp.inputs.section,
"row": seat_comp.inputs.row,
},
- outputs={
- "seat_number": seat_comp.outputs.seat_number,
- "team": seat_comp.outputs.team
- },
+ outputs={"seat_number": seat_comp.outputs.seat_number, "team": seat_comp.outputs.team},
)
ep2 = Endpoint(
route="/predict_seat_img",
@@ -366,10 +325,7 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ
"section": seat_comp.inputs.section,
"row": seat_comp.inputs.row,
},
- outputs={
- "seat_number": seat_comp.outputs.seat_number,
- "team": seat_comp.outputs.team
- },
+ outputs={"seat_number": seat_comp.outputs.seat_number, "team": seat_comp.outputs.team},
)
composit = Composition(
@@ -402,26 +358,15 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ
body = {
"session": "UUID",
"payload": {
- "image": {
- "data": imgstr
- },
- "section": {
- "num": 10
- },
- "isle": {
- "num": 4
- },
- "row": {
- "num": 53
- },
+ "image": {"data": imgstr},
+ "section": {"num": 10},
+ "isle": {"num": 4},
+ "row": {"num": 53},
},
}
success = tc.post("http://127.0.0.1:8000/predict_seat", json=body)
assert success.json() == {
- "result": {
- "seat_number": 4799680,
- "team": "buffalo bills, the ralph"
- },
+ "result": {"seat_number": 4799680, "team": "buffalo bills, the ralph"},
"session": "UUID",
}
@@ -438,26 +383,15 @@ def test_endpoint_overwrite_connection_dag(session_global_datadir, lightning_squ
body = {
"session": "UUID",
"payload": {
- "stadium": {
- "label": "buffalo bills, the ralph"
- },
- "section": {
- "num": 10
- },
- "isle": {
- "num": 4
- },
- "row": {
- "num": 53
- },
+ "stadium": {"label": "buffalo bills, the ralph"},
+ "section": {"num": 10},
+ "isle": {"num": 4},
+ "row": {"num": 53},
},
}
success = tc.post("http://127.0.0.1:8000/predict_seat_img_two", json=body)
assert success.json() == {
- "result": {
- "seat_number": 16960000,
- "team": "buffalo bills, the ralph"
- },
+ "result": {"seat_number": 16960000, "team": "buffalo bills, the ralph"},
"session": "UUID",
}
@@ -476,6 +410,7 @@ def test_cycle_in_connection_fails(session_global_datadir, lightning_squeezenet1
def test_composition_from_url_torchscript_servable(tmp_path):
from flash.core.serve import expose, ModelComponent, Servable
from flash.core.serve.types import Number
+
"""
# Tensor x Tensor
class MyModule(torch.nn.Module):
@@ -494,7 +429,6 @@ def forward(self, a, b):
TORCHSCRIPT_DOWNLOAD_URL = "https://github.com/pytorch/pytorch/raw/95489b590f00801bdee7f41783f30874883cf6bb/test/jit/fixtures/test_versioned_div_tensor_inplace_v3.pt" # noqa E501
class ComponentTwoModels(ModelComponent):
-
def __init__(self, model):
self.encoder = model["encoder"]
self.decoder = model["decoder"]
@@ -523,15 +457,11 @@ def do_my_predict(self, inp):
body = {
"session": "UUID",
"payload": {
- "ep_in": {
- "num": 10
- },
+ "ep_in": {"num": 10},
},
}
success = tc.post("http://127.0.0.1:8000/predictr", json=body)
assert success.json() == {
- "result": {
- "ep_out": 1.0
- },
+ "result": {"ep_out": 1.0},
"session": "UUID",
}
diff --git a/tests/core/serve/test_types/test_bbox.py b/tests/core/serve/test_types/test_bbox.py
index fb4fbe26c0..ca58a8f2a9 100644
--- a/tests/core/serve/test_types/test_bbox.py
+++ b/tests/core/serve/test_types/test_bbox.py
@@ -6,7 +6,7 @@
def test_deserialize():
bbox = BBox()
- assert torch.allclose(bbox.deserialize((0, 0, 0, 0)), torch.zeros((4, )))
+ assert torch.allclose(bbox.deserialize((0, 0, 0, 0)), torch.zeros((4,)))
assert bbox.deserialize((0, 0, 0, 0)).shape == torch.Size([4])
with pytest.raises(ValueError):
# only three elements, need four
@@ -19,15 +19,17 @@ def test_deserialize():
bbox.deserialize({1: 1, 2: 2, 3: 3, 4: 4})
with pytest.raises(ValueError):
# tuple instead of float
- bbox.deserialize((
+ bbox.deserialize(
(
- 0,
- 0,
- ),
- (0, 0),
- (0, 0),
- (0, 0),
- ))
+ (
+ 0,
+ 0,
+ ),
+ (0, 0),
+ (0, 0),
+ (0, 0),
+ )
+ )
def test_serialize():
diff --git a/tests/core/serve/test_types/test_repeated.py b/tests/core/serve/test_types/test_repeated.py
index b8fa64ef7e..2038dd29ec 100644
--- a/tests/core/serve/test_types/test_repeated.py
+++ b/tests/core/serve/test_types/test_repeated.py
@@ -12,11 +12,7 @@ def test_repeated_deserialize():
def test_repeated_serialize(session_global_datadir):
repeated = Repeated(dtype=Label(path=str(session_global_datadir / "imagenet_labels.txt")))
- assert repeated.deserialize(*({
- "label": "chickadee"
- }, {
- "label": "stingray"
- })) == (
+ assert repeated.deserialize(*({"label": "chickadee"}, {"label": "stingray"})) == (
torch.tensor(19),
torch.tensor(6),
)
@@ -29,11 +25,7 @@ def test_repeated_max_len():
with pytest.raises(ValueError):
repeated.deserialize(*({"label": "classA"}, {"label": "classA"}, {"label": "classB"}))
- assert repeated.deserialize(*({
- "label": "classA"
- }, {
- "label": "classB"
- })) == (
+ assert repeated.deserialize(*({"label": "classA"}, {"label": "classB"})) == (
torch.tensor(0),
torch.tensor(1),
)
@@ -52,7 +44,6 @@ def test_repeated_max_len():
def test_repeated_non_serve_dtype():
-
class NonServeDtype:
pass
diff --git a/tests/core/serve/test_types/test_table.py b/tests/core/serve/test_types/test_table.py
index c1da29b703..5bccc64892 100644
--- a/tests/core/serve/test_types/test_table.py
+++ b/tests/core/serve/test_types/test_table.py
@@ -65,14 +65,7 @@ def test_deserialize():
with pytest.raises(RuntimeError):
table.deserialize({"title1": {0: 100}, "title2": {0: 200}})
assert torch.allclose(
- table.deserialize({
- "t1": {
- 0: 100.0
- },
- "t2": {
- 1: 200.0
- }
- }),
+ table.deserialize({"t1": {0: 100.0}, "t2": {1: 200.0}}),
torch.tensor([[100.0, float("nan")], [float("nan"), 200.0]], dtype=torch.float64),
equal_nan=True,
)
diff --git a/tests/core/test_classification.py b/tests/core/test_classification.py
index 9281c36ab4..6cfa7a2c50 100644
--- a/tests/core/test_classification.py
+++ b/tests/core/test_classification.py
@@ -16,22 +16,22 @@
from flash.core.classification import Classes, FiftyOneLabels, Labels, Logits, Probabilities
from flash.core.data.data_source import DefaultDataKeys
-from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE
def test_classification_serializers():
example_output = torch.tensor([-0.1, 0.2, 0.3]) # 3 classes
- labels = ['class_1', 'class_2', 'class_3']
+ labels = ["class_1", "class_2", "class_3"]
assert torch.allclose(torch.tensor(Logits().serialize(example_output)), example_output)
assert torch.allclose(torch.tensor(Probabilities().serialize(example_output)), torch.softmax(example_output, -1))
assert Classes().serialize(example_output) == 2
- assert Labels(labels).serialize(example_output) == 'class_3'
+ assert Labels(labels).serialize(example_output) == "class_3"
def test_classification_serializers_multi_label():
example_output = torch.tensor([-0.1, 0.2, 0.3]) # 3 classes
- labels = ['class_1', 'class_2', 'class_3']
+ labels = ["class_1", "class_2", "class_3"]
assert torch.allclose(torch.tensor(Logits(multi_label=True).serialize(example_output)), example_output)
assert torch.allclose(
@@ -39,32 +39,33 @@ def test_classification_serializers_multi_label():
torch.sigmoid(example_output),
)
assert Classes(multi_label=True).serialize(example_output) == [1, 2]
- assert Labels(labels, multi_label=True).serialize(example_output) == ['class_2', 'class_3']
+ assert Labels(labels, multi_label=True).serialize(example_output) == ["class_2", "class_3"]
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
def test_classification_serializers_fiftyone():
logits = torch.tensor([-0.1, 0.2, 0.3])
example_output = {DefaultDataKeys.PREDS: logits, DefaultDataKeys.METADATA: {"filepath": "something"}} # 3 classes
- labels = ['class_1', 'class_2', 'class_3']
+ labels = ["class_1", "class_2", "class_3"]
predictions = FiftyOneLabels(return_filepath=True).serialize(example_output)
- assert predictions["predictions"].label == '2'
+ assert predictions["predictions"].label == "2"
assert predictions["filepath"] == "something"
predictions = FiftyOneLabels(labels, return_filepath=True).serialize(example_output)
- assert predictions["predictions"].label == 'class_3'
+ assert predictions["predictions"].label == "class_3"
assert predictions["filepath"] == "something"
predictions = FiftyOneLabels(store_logits=True).serialize(example_output)
assert torch.allclose(torch.tensor(predictions.logits), logits)
assert torch.allclose(torch.tensor(predictions.confidence), torch.softmax(logits, -1)[-1])
- assert predictions.label == '2'
+ assert predictions.label == "2"
predictions = FiftyOneLabels(labels, store_logits=True).serialize(example_output)
- assert predictions.label == 'class_3'
+ assert predictions.label == "class_3"
predictions = FiftyOneLabels(store_logits=True, multi_label=True).serialize(example_output)
assert torch.allclose(torch.tensor(predictions.logits), logits)
- assert [c.label for c in predictions.classifications] == ['1', '2']
+ assert [c.label for c in predictions.classifications] == ["1", "2"]
predictions = FiftyOneLabels(labels, multi_label=True).serialize(example_output)
- assert [c.label for c in predictions.classifications] == ['class_2', 'class_3']
+ assert [c.label for c in predictions.classifications] == ["class_2", "class_3"]
diff --git a/tests/core/test_data.py b/tests/core/test_data.py
index a51d8756e2..156669a657 100644
--- a/tests/core/test_data.py
+++ b/tests/core/test_data.py
@@ -21,9 +21,8 @@
class DummyDataset(torch.utils.data.Dataset):
-
def __getitem__(self, index):
- return torch.rand(1, 28, 28), torch.randint(10, size=(1, )).item()
+ return torch.rand(1, 28, 28), torch.randint(10, size=(1,)).item()
def __len__(self) -> int:
return 10
@@ -49,7 +48,7 @@ def test_dataloaders():
dm.test_dataloader(),
]:
x, y = next(iter(dl))
- assert x.shape == (1, 1, 28, 28)
+ assert x.shape == (4, 1, 28, 28)
def test_cpu_count_none():
diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py
index ad44cc7dbf..809bfb41ab 100644
--- a/tests/core/test_finetuning.py
+++ b/tests/core/test_finetuning.py
@@ -24,9 +24,8 @@
class DummyDataset(torch.utils.data.Dataset):
-
def __getitem__(self, index: int) -> Any:
- return {"input": torch.rand(3, 64, 64), "target": torch.randint(10, size=(1, )).item()}
+ return {"input": torch.rand(3, 64, 64), "target": torch.randint(10, size=(1,)).item()}
def __len__(self) -> int:
return 100
@@ -34,7 +33,7 @@ def __len__(self) -> int:
@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.parametrize(
- "strategy", ['no_freeze', 'freeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat']
+ "strategy", ["no_freeze", "freeze", "freeze_unfreeze", "unfreeze_milestones", None, "cls", "chocolat"]
)
def test_finetuning(tmpdir: str, strategy):
train_dl = torch.utils.data.DataLoader(DummyDataset())
@@ -43,7 +42,7 @@ def test_finetuning(tmpdir: str, strategy):
trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
if strategy == "cls":
strategy = NoFreeze()
- if strategy == 'chocolat' or strategy is None:
+ if strategy == "chocolat" or strategy is None:
with pytest.raises(MisconfigurationException, match="strategy should be provided"):
trainer.finetune(task, train_dl, val_dl, strategy=strategy)
else:
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index 6336bdfb06..3d3b53b111 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import math
+from itertools import chain
from numbers import Number
from pathlib import Path
from typing import Any, Tuple
@@ -20,15 +22,17 @@
import pytest
import pytorch_lightning as pl
import torch
+from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn, Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader
import flash
+from flash.core.adapter import Adapter
from flash.core.classification import ClassificationTask
from flash.core.data.process import DefaultPreprocess, Postprocess
-from flash.core.utilities.imports import _PIL_AVAILABLE, _TABULAR_AVAILABLE, _TEXT_AVAILABLE
+from flash.core.utilities.imports import _TABULAR_AVAILABLE, _TEXT_AVAILABLE, Image
from flash.image import ImageClassificationData, ImageClassifier
from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING
@@ -37,27 +41,24 @@
else:
TabularClassifier = None
-if _PIL_AVAILABLE:
- from PIL import Image
-else:
-
- class Image:
- Image = None
-
# ======== Mock functions ========
class DummyDataset(torch.utils.data.Dataset):
+ def __init__(self, num_samples: int = 9):
+ self.num_samples = num_samples
def __getitem__(self, index: int) -> Tuple[Tensor, Number]:
- return torch.rand(1, 28, 28), torch.randint(10, size=(1, )).item()
+ return torch.rand(1, 28, 28), torch.randint(10, size=(1,)).item()
def __len__(self) -> int:
- return 9
+ return self.num_samples
class PredictDummyDataset(DummyDataset):
+ def __init__(self, num_samples: int):
+ super().__init__(num_samples)
def __getitem__(self, index: int) -> Tensor:
return torch.rand(1, 28, 28)
@@ -68,6 +69,80 @@ class DummyPostprocess(Postprocess):
pass
+class FixedDataset(torch.utils.data.Dataset):
+ def __init__(self, targets):
+ super().__init__()
+
+ self.targets = targets
+
+ def __getitem__(self, index: int) -> Tuple[Tensor, Number]:
+ return torch.rand(1), self.targets[index]
+
+ def __len__(self) -> int:
+ return len(self.targets)
+
+
+class OnesModel(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self.layer = nn.Linear(1, 2)
+ self.register_buffer("zeros", torch.zeros(2))
+ self.register_buffer("zero_one", torch.tensor([0.0, 1.0]))
+
+ def forward(self, x):
+ x = self.layer(x)
+ return x * self.zeros + self.zero_one
+
+
+class Parent(ClassificationTask):
+ def __init__(self, child):
+ super().__init__()
+
+ self.child = child
+
+ def training_step(self, batch, batch_idx):
+ return self.child.training_step(batch, batch_idx)
+
+ def validation_step(self, batch, batch_idx):
+ return self.child.validation_step(batch, batch_idx)
+
+ def test_step(self, batch, batch_idx):
+ return self.child.test_step(batch, batch_idx)
+
+ def forward(self, x):
+ return self.child(x)
+
+
+class GrandParent(Parent):
+ def __init__(self, child):
+ super().__init__(Parent(child))
+
+
+class BasicAdapter(Adapter):
+ def __init__(self, child):
+ super().__init__()
+
+ self.child = child
+
+ def training_step(self, batch, batch_idx):
+ return self.child.training_step(batch, batch_idx)
+
+ def validation_step(self, batch, batch_idx):
+ return self.child.validation_step(batch, batch_idx)
+
+ def test_step(self, batch, batch_idx):
+ return self.child.test_step(batch, batch_idx)
+
+ def forward(self, x):
+ return self.child(x)
+
+
+class AdapterParent(Parent):
+ def __init__(self, child):
+ super().__init__(BasicAdapter(child))
+
+
# ================================
@@ -83,6 +158,21 @@ def test_classificationtask_train(tmpdir: str, metrics: Any):
assert "test_nll_loss" in result[0]
+@pytest.mark.parametrize("task", [Parent, GrandParent, AdapterParent])
+def test_nested_tasks(tmpdir, task):
+ model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
+ train_dl = torch.utils.data.DataLoader(DummyDataset())
+ val_dl = torch.utils.data.DataLoader(DummyDataset())
+ child_task = ClassificationTask(model, loss_fn=F.nll_loss)
+
+ parent_task = task(child_task)
+
+ trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
+ trainer.fit(parent_task, train_dl, val_dl)
+ result = trainer.test(parent_task, val_dl)
+ assert "test_nll_loss" in result[0]
+
+
def test_classificationtask_task_predict():
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax())
task = ClassificationTask(model, preprocess=DefaultPreprocess())
@@ -121,15 +211,12 @@ def _rand_image():
def test_classification_task_trainer_predict(tmpdir):
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
task = ClassificationTask(model)
- ds = PredictDummyDataset()
- batch_size = 3
- predict_dl = torch.utils.data.DataLoader(ds, batch_size=batch_size)
+ ds = PredictDummyDataset(10)
+ batch_size = 6
+ predict_dl = task.process_predict_dataset(ds, batch_size=batch_size)
trainer = pl.Trainer(default_root_dir=tmpdir)
predictions = trainer.predict(task, predict_dl)
- assert len(predictions) == len(ds) // batch_size
- for batch_pred in predictions:
- assert len(batch_pred) == batch_size
- assert all(y < 10 for y in batch_pred)
+ assert len(list(chain.from_iterable(predictions))) == 10
def test_task_datapipeline_save(tmpdir):
@@ -158,24 +245,27 @@ def test_task_datapipeline_save(tmpdir):
assert task.postprocess.test
-@pytest.mark.parametrize(["cls", "filename"], [
- pytest.param(
- ImageClassifier,
- "image_classification_model.pt",
- marks=pytest.mark.skipif(
- not _IMAGE_TESTING,
- reason="image packages aren't installed",
- )
- ),
- pytest.param(
- TabularClassifier,
- "tabular_classification_model.pt",
- marks=pytest.mark.skipif(
- not _TABULAR_TESTING,
- reason="tabular packages aren't installed",
- )
- ),
-])
+@pytest.mark.parametrize(
+ ["cls", "filename"],
+ [
+ pytest.param(
+ ImageClassifier,
+ "image_classification_model.pt",
+ marks=pytest.mark.skipif(
+ not _IMAGE_TESTING,
+ reason="image packages aren't installed",
+ ),
+ ),
+ pytest.param(
+ TabularClassifier,
+ "tabular_classification_model.pt",
+ marks=pytest.mark.skipif(
+ not _TABULAR_TESTING,
+ reason="tabular packages aren't installed",
+ ),
+ ),
+ ],
+)
def test_model_download(tmpdir, cls, filename):
url = "https://flash-weights.s3.amazonaws.com/"
with tmpdir.as_cwd():
@@ -191,7 +281,7 @@ def test_available_backbones():
class Foo(ImageClassifier):
backbones = None
- assert Foo.available_backbones() == []
+ assert Foo.available_backbones() == {}
def test_optimization(tmpdir):
@@ -212,7 +302,7 @@ def test_optimization(tmpdir):
model,
optimizer=torch.optim.Adadelta,
scheduler=torch.optim.lr_scheduler.StepLR,
- scheduler_kwargs={"step_size": 1}
+ scheduler_kwargs={"step_size": 1},
)
optimizer, scheduler = task.configure_optimizers()
assert isinstance(optimizer[0], torch.optim.Adadelta)
@@ -241,11 +331,26 @@ def test_optimization(tmpdir):
scheduler_kwargs={"num_warmup_steps": 0.1},
loss_fn=F.nll_loss,
)
- trainer = flash.Trainer(max_epochs=1, limit_train_batches=2)
+ trainer = flash.Trainer(max_epochs=1, limit_train_batches=2, gpus=torch.cuda.device_count())
ds = DummyDataset()
trainer.fit(task, train_dataloader=DataLoader(ds))
optimizer, scheduler = task.configure_optimizers()
assert isinstance(optimizer[0], torch.optim.Adadelta)
assert isinstance(scheduler[0], torch.optim.lr_scheduler.LambdaLR)
expected = get_linear_schedule_with_warmup.__name__
- assert scheduler[0].lr_lambdas[0].__qualname__.split('.')[0] == expected
+ assert scheduler[0].lr_lambdas[0].__qualname__.split(".")[0] == expected
+
+
+def test_classification_task_metrics():
+ train_dataset = FixedDataset([0, 1])
+ val_dataset = FixedDataset([1, 1])
+
+ model = OnesModel()
+
+ class CheckAccuracy(Callback):
+ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+ assert math.isclose(trainer.callback_metrics["train_accuracy_epoch"], 0.5)
+
+ task = ClassificationTask(model)
+ trainer = flash.Trainer(max_epochs=1, callbacks=CheckAccuracy(), gpus=torch.cuda.device_count())
+ trainer.fit(task, train_dataloader=DataLoader(train_dataset), val_dataloaders=DataLoader(val_dataset))
diff --git a/tests/core/test_registry.py b/tests/core/test_registry.py
index 061c6f4504..a230b869c0 100644
--- a/tests/core/test_registry.py
+++ b/tests/core/test_registry.py
@@ -27,20 +27,20 @@ def test_registry_raises():
def my_model(nc_input=5, nc_output=6):
return nn.Linear(nc_input, nc_output), nc_input, nc_output
- with pytest.raises(MisconfigurationException, match="You can only register a function, found: Linear"):
- backbones(nn.Linear(1, 1), name="cho")
+ with pytest.raises(MisconfigurationException, match="You can only register a callable, found: 3"):
+ backbones(3, name="foo")
- backbones(my_model, name="cho", override=True)
+ backbones(my_model, name="foo", override=True)
- with pytest.raises(MisconfigurationException, match="Function with name: cho and metadata: {}"):
- backbones(my_model, name="cho", override=False)
+ with pytest.raises(MisconfigurationException, match="Function with name: foo and metadata: {}"):
+ backbones(my_model, name="foo", override=False)
with pytest.raises(KeyError, match="Found no matches"):
- backbones.get("cho", foo="bar")
+ backbones.get("foo", baz="bar")
- backbones.remove("cho")
- with pytest.raises(KeyError, match="Key: cho is not in FlashRegistry"):
- backbones.get("cho")
+ backbones.remove("foo")
+ with pytest.raises(KeyError, match="Key: foo is not in FlashRegistry"):
+ backbones.get("foo")
with pytest.raises(TypeError, match="name` must be a str"):
backbones(name=float) # noqa
@@ -59,30 +59,30 @@ def my_model(nc_input=5, nc_output=6):
assert mlp.weight.shape == (7, 5)
# basic get
- backbones(my_model, name="cho")
- assert backbones.get("cho")
+ backbones(my_model, name="foo")
+ assert backbones.get("foo")
# test override
- backbones(my_model, name="cho", override=True)
- functions = backbones.get("cho", strict=False)
+ backbones(my_model, name="foo", override=True)
+ functions = backbones.get("foo", strict=False)
assert len(functions) == 1
# test metadata filtering
- backbones(my_model, name="cho", namespace="timm", type="resnet")
- backbones(my_model, name="cho", namespace="torchvision", type="resnet")
- backbones(my_model, name="cho", namespace="timm", type="densenet")
- backbones(my_model, name="cho", namespace="timm", type="alexnet")
- function = backbones.get("cho", with_metadata=True, type="resnet", namespace="timm")
- assert function["name"] == "cho"
+ backbones(my_model, name="foo", namespace="timm", type="resnet")
+ backbones(my_model, name="foo", namespace="torchvision", type="resnet")
+ backbones(my_model, name="foo", namespace="timm", type="densenet")
+ backbones(my_model, name="foo", namespace="timm", type="alexnet")
+ function = backbones.get("foo", with_metadata=True, type="resnet", namespace="timm")
+ assert function["name"] == "foo"
assert function["metadata"] == {"namespace": "timm", "type": "resnet"}
# test strict=False and with_metadata=False
- functions = backbones.get("cho", namespace="timm", strict=False)
+ functions = backbones.get("foo", namespace="timm", strict=False)
assert len(functions) == 3
assert all(callable(f) for f in functions)
# test available keys
- assert backbones.available_keys() == ['cho', 'cho', 'cho', 'cho', 'cho', 'my_model']
+ assert backbones.available_keys() == ["foo", "foo", "foo", "foo", "foo", "my_model"]
# todo (tchaton) Debug this test.
@@ -100,8 +100,8 @@ def my_model():
assert caplog.messages == [
"Registering: my_model function with name: bar and metadata: {'foobar': True}",
- 'Registering: my_model function with name: foo and metadata: {}',
- 'Registering: my_model function with name: my_model and metadata: {}'
+ "Registering: my_model function with name: foo and metadata: {}",
+ "Registering: my_model function with name: my_model and metadata: {}",
]
assert len(backbones) == 3
diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py
index 7bd330d83a..436bb48a2e 100644
--- a/tests/core/test_trainer.py
+++ b/tests/core/test_trainer.py
@@ -27,7 +27,6 @@
class DummyDataset(torch.utils.data.Dataset):
-
def __init__(self, predict: bool = False):
self._predict = predict
@@ -35,14 +34,13 @@ def __getitem__(self, index: int) -> Any:
sample = torch.rand(1, 28, 28)
if self._predict:
return sample
- return sample, torch.randint(10, size=(1, )).item()
+ return sample, torch.randint(10, size=(1,)).item()
def __len__(self) -> int:
return 100
class DummyClassifier(nn.Module):
-
def __init__(self):
super().__init__()
self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
@@ -85,7 +83,6 @@ def test_resolve_callbacks_invalid_strategy(tmpdir):
class MultiFinetuneClassificationTask(ClassificationTask):
-
def configure_finetune_callback(self):
return [NoFreeze(), NoFreeze()]
@@ -99,7 +96,6 @@ def test_resolve_callbacks_multi_error(tmpdir):
class FinetuneClassificationTask(ClassificationTask):
-
def configure_finetune_callback(self):
return [NoFreeze()]
@@ -115,14 +111,14 @@ def test_resolve_callbacks_override_warning(tmpdir):
def test_add_argparse_args():
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
- args = parser.parse_args(['--gpus=1'])
+ args = parser.parse_args(["--gpus=1"])
assert args.gpus == 1
def test_from_argparse_args():
parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
- args = parser.parse_args(['--max_epochs=200'])
+ args = parser.parse_args(["--max_epochs=200"])
trainer = Trainer.from_argparse_args(args)
assert trainer.max_epochs == 200
assert isinstance(trainer, Trainer)
diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py
index 250aba1122..49d24bf7ab 100644
--- a/tests/core/test_utils.py
+++ b/tests/core/test_utils.py
@@ -20,7 +20,6 @@
class A:
-
def __call__(self, x):
return True
@@ -54,4 +53,4 @@ def test_get_callable_dict():
def test_download_data(tmpdir):
path = os.path.join(tmpdir, "data")
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", path)
- assert set(os.listdir(path)) == {'titanic', 'titanic.zip'}
+ assert set(os.listdir(path)) == {"titanic", "titanic.zip"}
diff --git a/tests/core/utilities/__init__.py b/tests/core/utilities/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/core/utilities/test_lightning_cli.py b/tests/core/utilities/test_lightning_cli.py
new file mode 100644
index 0000000000..1b664a02e5
--- /dev/null
+++ b/tests/core/utilities/test_lightning_cli.py
@@ -0,0 +1,721 @@
+# Adapted from the Lightning CLI:
+# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/utilities/test_cli.py
+import inspect
+import json
+import os
+import pickle
+import sys
+from argparse import Namespace
+from contextlib import redirect_stdout
+from io import StringIO
+from typing import List, Optional, Union
+from unittest import mock
+
+import pytest
+import torch
+import yaml
+from packaging import version
+from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
+from pytorch_lightning.plugins.environments import SLURMEnvironment
+
+from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
+from flash.core.utilities.lightning_cli import (
+ instantiate_class,
+ LightningArgumentParser,
+ LightningCLI,
+ SaveConfigCallback,
+)
+from tests.helpers.boring_model import BoringDataModule, BoringModel
+
+torchvision_version = version.parse("0")
+if _TORCHVISION_AVAILABLE:
+ torchvision_version = version.parse(__import__("torchvision").__version__)
+
+
+@mock.patch("argparse.ArgumentParser.parse_args")
+def test_default_args(mock_argparse, tmpdir):
+ """Tests default argument parser for Trainer."""
+ mock_argparse.return_value = Namespace(**Trainer.default_attributes())
+
+ parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ args = parser.parse_args([])
+
+ args.max_epochs = 5
+ trainer = Trainer.from_argparse_args(args)
+
+ assert isinstance(trainer, Trainer)
+ assert trainer.max_epochs == 5
+
+
+@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--weights_save_path=./"], []])
+def test_add_argparse_args_redefined(cli_args):
+ """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
+ parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser.add_lightning_class_args(Trainer, None)
+
+ args = parser.parse_args(cli_args)
+
+ # make sure we can pickle args
+ pickle.dumps(args)
+
+ # Check few deprecated args are not in namespace:
+ for depr_name in ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs"):
+ assert depr_name not in args
+
+ trainer = Trainer.from_argparse_args(args=args)
+ pickle.dumps(trainer)
+
+ assert isinstance(trainer, Trainer)
+
+
+@pytest.mark.parametrize(
+ ["cli_args", "expected"],
+ [
+ ("--auto_lr_find=True --auto_scale_batch_size=power", dict(auto_lr_find=True, auto_scale_batch_size="power")),
+ (
+ "--auto_lr_find any_string --auto_scale_batch_size ON",
+ dict(auto_lr_find="any_string", auto_scale_batch_size=True),
+ ),
+ ("--auto_lr_find=Yes --auto_scale_batch_size=On", dict(auto_lr_find=True, auto_scale_batch_size=True)),
+ ("--auto_lr_find Off --auto_scale_batch_size No", dict(auto_lr_find=False, auto_scale_batch_size=False)),
+ ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", dict(auto_lr_find=True, auto_scale_batch_size=False)),
+ ("--limit_train_batches=100", dict(limit_train_batches=100)),
+ ("--limit_train_batches 0.8", dict(limit_train_batches=0.8)),
+ ("--weights_summary=null", dict(weights_summary=None)),
+ (
+ "",
+ dict(
+ # These parameters are marked as Optional[...] in Trainer.__init__,
+ # with None as default. They should not be changed by the argparse
+ # interface.
+ min_steps=None,
+ max_steps=None,
+ log_gpu_memory=None,
+ distributed_backend=None,
+ weights_save_path=None,
+ truncated_bptt_steps=None,
+ resume_from_checkpoint=None,
+ profiler=None,
+ ),
+ ),
+ ],
+)
+def test_parse_args_parsing(cli_args, expected):
+ """Test parsing simple types and None optionals not modified."""
+ cli_args = cli_args.split(" ") if cli_args else []
+ parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser.add_lightning_class_args(Trainer, None)
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ args = parser.parse_args()
+
+ for k, v in expected.items():
+ assert getattr(args, k) == v
+ assert Trainer.from_argparse_args(args)
+
+
+@pytest.mark.parametrize(
+ ["cli_args", "expected", "instantiate"],
+ [
+ (["--gpus", "[0, 2]"], dict(gpus=[0, 2]), False),
+ (["--tpu_cores=[1,3]"], dict(tpu_cores=[1, 3]), False),
+ (['--accumulate_grad_batches={"5":3,"10":20}'], dict(accumulate_grad_batches={5: 3, 10: 20}), True),
+ ],
+)
+def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
+ """Test parsing complex types."""
+ parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser.add_lightning_class_args(Trainer, None)
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ args = parser.parse_args()
+
+ for k, v in expected.items():
+ assert getattr(args, k) == v
+ if instantiate:
+ assert Trainer.from_argparse_args(args)
+
+
+@pytest.mark.parametrize(
+ ["cli_args", "expected_gpu"],
+ [
+ ("--gpus 1", [0]),
+ ("--gpus 0,", [0]),
+ ("--gpus 0,1", [0, 1]),
+ ],
+)
+def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu):
+ """Test parsing of gpus and instantiation of Trainer."""
+ monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
+ cli_args = cli_args.split(" ") if cli_args else []
+ parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser.add_lightning_class_args(Trainer, None)
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ args = parser.parse_args()
+
+ trainer = Trainer.from_argparse_args(args)
+ assert trainer.data_parallel_device_ids == expected_gpu
+
+
+@pytest.mark.skipif(
+ sys.version_info < (3, 7),
+ reason="signature inspection while mocking is not working in Python < 3.7 despite autospec",
+)
+@pytest.mark.parametrize(
+ ["cli_args", "extra_args"],
+ [
+ ({}, {}),
+ (dict(logger=False), {}),
+ (dict(logger=False), dict(logger=True)),
+ (dict(logger=False), dict(checkpoint_callback=True)),
+ ],
+)
+def test_init_from_argparse_args(cli_args, extra_args):
+ unknown_args = dict(unknown_arg=0)
+
+ # unkown args in the argparser/namespace should be ignored
+ with mock.patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as init:
+ trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args)
+ expected = dict(cli_args)
+ expected.update(extra_args) # extra args should override any cli arg
+ init.assert_called_with(trainer, **expected)
+
+ # passing in unknown manual args should throw an error
+ with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
+ Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args)
+
+
+class Model(LightningModule):
+ def __init__(self, model_param: int):
+ super().__init__()
+ self.model_param = model_param
+
+
+def model_builder(model_param: int) -> Model:
+ return Model(model_param)
+
+
+def trainer_builder(
+ limit_train_batches: int, fast_dev_run: bool = False, callbacks: Optional[Union[List[Callback], Callback]] = None
+) -> Trainer:
+ return Trainer(limit_train_batches=limit_train_batches, fast_dev_run=fast_dev_run, callbacks=callbacks)
+
+
+@pytest.mark.parametrize(["trainer_class", "model_class"], [(Trainer, Model), (trainer_builder, model_builder)])
+def test_lightning_cli(trainer_class, model_class, monkeypatch):
+ """Test that LightningCLI correctly instantiates model, trainer and calls fit."""
+
+ expected_model = dict(model_param=7)
+ expected_trainer = dict(limit_train_batches=100)
+
+ def fit(trainer, model):
+ for k, v in expected_model.items():
+ assert getattr(model, k) == v
+ for k, v in expected_trainer.items():
+ assert getattr(trainer, k) == v
+ save_callback = [x for x in trainer.callbacks if isinstance(x, SaveConfigCallback)]
+ assert len(save_callback) == 1
+ save_callback[0].on_train_start(trainer, model)
+
+ def on_train_start(callback, trainer, _):
+ config_dump = callback.parser.dump(callback.config, skip_none=False)
+ for k, v in expected_model.items():
+ assert f" {k}: {v}" in config_dump
+ for k, v in expected_trainer.items():
+ assert f" {k}: {v}" in config_dump
+ trainer.ran_asserts = True
+
+ monkeypatch.setattr(Trainer, "fit", fit)
+ monkeypatch.setattr(SaveConfigCallback, "on_train_start", on_train_start)
+
+ with mock.patch("sys.argv", ["any.py", "--model.model_param=7", "--trainer.limit_train_batches=100"]):
+ cli = LightningCLI(model_class, trainer_class=trainer_class, save_config_callback=SaveConfigCallback)
+ assert hasattr(cli.trainer, "ran_asserts") and cli.trainer.ran_asserts
+
+
+def test_lightning_cli_args_callbacks(tmpdir):
+
+ callbacks = [
+ dict(
+ class_path="pytorch_lightning.callbacks.LearningRateMonitor",
+ init_args=dict(logging_interval="epoch", log_momentum=True),
+ ),
+ dict(class_path="pytorch_lightning.callbacks.ModelCheckpoint", init_args=dict(monitor="NAME")),
+ ]
+
+ class TestModel(BoringModel):
+ def on_fit_start(self):
+ callback = [c for c in self.trainer.callbacks if isinstance(c, LearningRateMonitor)]
+ assert len(callback) == 1
+ assert callback[0].logging_interval == "epoch"
+ assert callback[0].log_momentum is True
+ callback = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]
+ assert len(callback) == 1
+ assert callback[0].monitor == "NAME"
+ self.trainer.ran_asserts = True
+
+ with mock.patch("sys.argv", ["any.py", f"--trainer.callbacks={json.dumps(callbacks)}"]):
+ cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))
+
+ assert cli.trainer.ran_asserts
+
+
+def test_lightning_cli_configurable_callbacks(tmpdir):
+ class MyLightningCLI(LightningCLI):
+ def add_arguments_to_parser(self, parser):
+ parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor")
+
+ cli_args = [
+ f"--trainer.default_root_dir={tmpdir}",
+ "--trainer.max_epochs=1",
+ "--learning_rate_monitor.logging_interval=epoch",
+ ]
+
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ cli = MyLightningCLI(BoringModel)
+
+ callback = [c for c in cli.trainer.callbacks if isinstance(c, LearningRateMonitor)]
+ assert len(callback) == 1
+ assert callback[0].logging_interval == "epoch"
+
+
+def test_lightning_cli_args_cluster_environments(tmpdir):
+ plugins = [dict(class_path="pytorch_lightning.plugins.environments.SLURMEnvironment")]
+
+ class TestModel(BoringModel):
+ def on_fit_start(self):
+ # Ensure SLURMEnvironment is set, instead of default LightningEnvironment
+ assert isinstance(self.trainer.accelerator_connector._cluster_environment, SLURMEnvironment)
+ self.trainer.ran_asserts = True
+
+ with mock.patch("sys.argv", ["any.py", f"--trainer.plugins={json.dumps(plugins)}"]):
+ cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))
+
+ assert cli.trainer.ran_asserts
+
+
+def test_lightning_cli_args(tmpdir):
+
+ cli_args = [
+ f"--data.data_dir={tmpdir}",
+ f"--trainer.default_root_dir={tmpdir}",
+ "--trainer.max_epochs=1",
+ "--trainer.weights_summary=null",
+ "--seed_everything=1234",
+ ]
+
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={"callbacks": [LearningRateMonitor()]})
+
+ assert cli.config["seed_everything"] == 1234
+ config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
+ assert os.path.isfile(config_path)
+ with open(config_path) as f:
+ config = yaml.safe_load(f.read())
+ assert "model" not in config and "model" not in cli.config # no arguments to include
+ assert config["data"] == cli.config["data"]
+ assert config["trainer"] == cli.config["trainer"]
+
+
+def test_lightning_cli_save_config_cases(tmpdir):
+
+ config_path = tmpdir / "config.yaml"
+ cli_args = [
+ f"--trainer.default_root_dir={tmpdir}",
+ "--trainer.logger=False",
+ "--trainer.fast_dev_run=1",
+ ]
+
+ # With fast_dev_run!=False config should not be saved
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ LightningCLI(BoringModel)
+ assert not os.path.isfile(config_path)
+
+ # With fast_dev_run==False config should be saved
+ cli_args[-1] = "--trainer.max_epochs=1"
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ LightningCLI(BoringModel)
+ assert os.path.isfile(config_path)
+
+ # If run again on same directory exception should be raised since config file already exists
+ with mock.patch("sys.argv", ["any.py"] + cli_args), pytest.raises(RuntimeError):
+ LightningCLI(BoringModel)
+
+
+def test_lightning_cli_config_and_subclass_mode(tmpdir):
+
+ config = dict(
+ model=dict(class_path="tests.helpers.boring_model.BoringModel"),
+ data=dict(class_path="tests.helpers.boring_model.BoringDataModule", init_args=dict(data_dir=str(tmpdir))),
+ trainer=dict(default_root_dir=str(tmpdir), max_epochs=1, weights_summary=None),
+ )
+ config_path = tmpdir / "config.yaml"
+ with open(config_path, "w") as f:
+ f.write(yaml.dump(config))
+
+ with mock.patch("sys.argv", ["any.py", "--config", str(config_path)]):
+ cli = LightningCLI(
+ BoringModel,
+ BoringDataModule,
+ subclass_mode_model=True,
+ subclass_mode_data=True,
+ trainer_defaults={"callbacks": LearningRateMonitor()},
+ )
+
+ config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
+ assert os.path.isfile(config_path)
+ with open(config_path) as f:
+ config = yaml.safe_load(f.read())
+ assert config["model"] == cli.config["model"]
+ assert config["data"] == cli.config["data"]
+ assert config["trainer"] == cli.config["trainer"]
+
+
+def any_model_any_data_cli():
+ LightningCLI(
+ LightningModule,
+ LightningDataModule,
+ subclass_mode_model=True,
+ subclass_mode_data=True,
+ )
+
+
+def test_lightning_cli_help():
+
+ cli_args = ["any.py", "--help"]
+ out = StringIO()
+ with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
+ any_model_any_data_cli()
+
+ assert "--print_config" in out.getvalue()
+ assert "--config" in out.getvalue()
+ assert "--seed_everything" in out.getvalue()
+ assert "--model.help" in out.getvalue()
+ assert "--data.help" in out.getvalue()
+
+ skip_params = {"self"}
+ for param in inspect.signature(Trainer.__init__).parameters.keys():
+ if param not in skip_params:
+ assert f"--trainer.{param}" in out.getvalue()
+
+ cli_args = ["any.py", "--data.help=tests.helpers.boring_model.BoringDataModule"]
+ out = StringIO()
+ with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
+ any_model_any_data_cli()
+
+ assert "--data.init_args.data_dir" in out.getvalue()
+
+
+def test_lightning_cli_print_config():
+
+ cli_args = [
+ "any.py",
+ "--seed_everything=1234",
+ "--model=tests.helpers.boring_model.BoringModel",
+ "--data=tests.helpers.boring_model.BoringDataModule",
+ "--print_config",
+ ]
+
+ out = StringIO()
+ with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
+ any_model_any_data_cli()
+
+ outval = yaml.safe_load(out.getvalue())
+ assert outval["seed_everything"] == 1234
+ assert outval["model"]["class_path"] == "tests.helpers.boring_model.BoringModel"
+ assert outval["data"]["class_path"] == "tests.helpers.boring_model.BoringDataModule"
+
+
+def test_lightning_cli_submodules(tmpdir):
+ class MainModule(BoringModel):
+ def __init__(
+ self,
+ submodule1: LightningModule,
+ submodule2: LightningModule,
+ main_param: int = 1,
+ ):
+ super().__init__()
+ self.submodule1 = submodule1
+ self.submodule2 = submodule2
+
+ config = """model:
+ main_param: 2
+ submodule1:
+ class_path: tests.helpers.boring_model.BoringModel
+ submodule2:
+ class_path: tests.helpers.boring_model.BoringModel
+ """
+ config_path = tmpdir / "config.yaml"
+ with open(config_path, "w") as f:
+ f.write(config)
+
+ cli_args = [
+ f"--trainer.default_root_dir={tmpdir}",
+ "--trainer.max_epochs=1",
+ f"--config={str(config_path)}",
+ ]
+
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ cli = LightningCLI(MainModule)
+
+ assert cli.config["model"]["main_param"] == 2
+ assert isinstance(cli.model.submodule1, BoringModel)
+ assert isinstance(cli.model.submodule2, BoringModel)
+
+
+@pytest.mark.skipif(torchvision_version < version.parse("0.8.0"), reason="torchvision>=0.8.0 is required")
+def test_lightning_cli_torch_modules(tmpdir):
+ class TestModule(BoringModel):
+ def __init__(
+ self,
+ activation: torch.nn.Module = None,
+ transform: Optional[List[torch.nn.Module]] = None,
+ ):
+ super().__init__()
+ self.activation = activation
+ self.transform = transform
+
+ config = """model:
+ activation:
+ class_path: torch.nn.LeakyReLU
+ init_args:
+ negative_slope: 0.2
+ transform:
+ - class_path: torchvision.transforms.Resize
+ init_args:
+ size: 64
+ - class_path: torchvision.transforms.CenterCrop
+ init_args:
+ size: 64
+ """
+ config_path = tmpdir / "config.yaml"
+ with open(config_path, "w") as f:
+ f.write(config)
+
+ cli_args = [
+ f"--trainer.default_root_dir={tmpdir}",
+ "--trainer.max_epochs=1",
+ f"--config={str(config_path)}",
+ ]
+
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ cli = LightningCLI(TestModule)
+
+ assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
+ assert cli.model.activation.negative_slope == 0.2
+ assert len(cli.model.transform) == 2
+ assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)
+
+
+class BoringModelRequiredClasses(BoringModel):
+ def __init__(
+ self,
+ num_classes: int,
+ batch_size: int = 8,
+ ):
+ super().__init__()
+ self.num_classes = num_classes
+ self.batch_size = batch_size
+
+
+class BoringDataModuleBatchSizeAndClasses(BoringDataModule):
+ def __init__(
+ self,
+ batch_size: int = 8,
+ ):
+ super().__init__()
+ self.batch_size = batch_size
+ self.num_classes = 5 # only available after instantiation
+
+
+def test_lightning_cli_link_arguments(tmpdir):
+ class MyLightningCLI(LightningCLI):
+ def add_arguments_to_parser(self, parser):
+ parser.link_arguments("data.batch_size", "model.batch_size")
+ parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")
+
+ cli_args = [
+ f"--trainer.default_root_dir={tmpdir}",
+ "--trainer.max_epochs=1",
+ "--data.batch_size=12",
+ ]
+
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses)
+
+ assert cli.model.batch_size == 12
+ assert cli.model.num_classes == 5
+
+ class MyLightningCLI(LightningCLI):
+ def add_arguments_to_parser(self, parser):
+ parser.link_arguments("data.batch_size", "model.init_args.batch_size")
+ parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")
+
+ cli_args[-1] = "--model=tests.core.utilities.test_lightning_cli.BoringModelRequiredClasses"
+
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ cli = MyLightningCLI(
+ BoringModelRequiredClasses,
+ BoringDataModuleBatchSizeAndClasses,
+ subclass_mode_model=True,
+ )
+
+ assert cli.model.batch_size == 8
+ assert cli.model.num_classes == 5
+
+
+class EarlyExitTestModel(BoringModel):
+ def on_fit_start(self):
+ raise KeyboardInterrupt()
+
+
+@pytest.mark.parametrize("logger", (False, True))
+@pytest.mark.parametrize(
+ "trainer_kwargs",
+ (
+ dict(accelerator="ddp_cpu"),
+ dict(accelerator="ddp_cpu", plugins="ddp_find_unused_parameters_false"),
+ ),
+)
+def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs):
+ with mock.patch("sys.argv", ["any.py"]), pytest.raises(KeyboardInterrupt):
+ LightningCLI(
+ EarlyExitTestModel,
+ trainer_defaults={
+ "default_root_dir": str(tmpdir),
+ "logger": logger,
+ "max_steps": 1,
+ "max_epochs": 1,
+ **trainer_kwargs,
+ },
+ )
+ if logger:
+ config_dir = tmpdir / "lightning_logs"
+ # no more version dirs should get created
+ assert os.listdir(config_dir) == ["version_0"]
+ config_path = config_dir / "version_0" / "config.yaml"
+ else:
+ config_path = tmpdir / "config.yaml"
+ assert os.path.isfile(config_path)
+
+
+def test_cli_config_overwrite(tmpdir):
+ trainer_defaults = {"default_root_dir": str(tmpdir), "logger": False, "max_steps": 1, "max_epochs": 1}
+
+ with mock.patch("sys.argv", ["any.py"]):
+ LightningCLI(BoringModel, trainer_defaults=trainer_defaults)
+ with mock.patch("sys.argv", ["any.py"]), pytest.raises(RuntimeError, match="Aborting to avoid overwriting"):
+ LightningCLI(BoringModel, trainer_defaults=trainer_defaults)
+ with mock.patch("sys.argv", ["any.py"]):
+ LightningCLI(BoringModel, save_config_overwrite=True, trainer_defaults=trainer_defaults)
+
+
+def test_lightning_cli_optimizer(tmpdir):
+ class MyLightningCLI(LightningCLI):
+ def add_arguments_to_parser(self, parser):
+ parser.add_optimizer_args(torch.optim.Adam)
+
+ cli_args = [
+ f"--trainer.default_root_dir={tmpdir}",
+ "--trainer.max_epochs=1",
+ ]
+
+ match = (
+ "BoringModel.configure_optimizers` will be overridden by "
+ "`MyLightningCLI.add_configure_optimizers_method_to_model`"
+ )
+ with mock.patch("sys.argv", ["any.py"] + cli_args), pytest.warns(UserWarning, match=match):
+ cli = MyLightningCLI(BoringModel)
+
+ assert cli.model.configure_optimizers is not BoringModel.configure_optimizers
+ assert len(cli.trainer.optimizers) == 1
+ assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
+ assert len(cli.trainer.lr_schedulers) == 0
+
+
+def test_lightning_cli_optimizer_and_lr_scheduler(tmpdir):
+ class MyLightningCLI(LightningCLI):
+ def add_arguments_to_parser(self, parser):
+ parser.add_optimizer_args(torch.optim.Adam)
+ parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)
+
+ cli_args = [
+ f"--trainer.default_root_dir={tmpdir}",
+ "--trainer.max_epochs=1",
+ "--lr_scheduler.gamma=0.8",
+ ]
+
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ cli = MyLightningCLI(BoringModel)
+
+ assert cli.model.configure_optimizers is not BoringModel.configure_optimizers
+ assert len(cli.trainer.optimizers) == 1
+ assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
+ assert len(cli.trainer.lr_schedulers) == 1
+ assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.ExponentialLR)
+ assert cli.trainer.lr_schedulers[0]["scheduler"].gamma == 0.8
+
+
+def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir):
+ class MyLightningCLI(LightningCLI):
+ def add_arguments_to_parser(self, parser):
+ parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
+ parser.add_lr_scheduler_args((torch.optim.lr_scheduler.StepLR, torch.optim.lr_scheduler.ExponentialLR))
+
+ optimizer_arg = dict(
+ class_path="torch.optim.Adam",
+ init_args=dict(lr=0.01),
+ )
+ lr_scheduler_arg = dict(
+ class_path="torch.optim.lr_scheduler.StepLR",
+ init_args=dict(step_size=50),
+ )
+ cli_args = [
+ f"--trainer.default_root_dir={tmpdir}",
+ "--trainer.max_epochs=1",
+ f"--optimizer={json.dumps(optimizer_arg)}",
+ f"--lr_scheduler={json.dumps(lr_scheduler_arg)}",
+ ]
+
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ cli = MyLightningCLI(BoringModel)
+
+ assert len(cli.trainer.optimizers) == 1
+ assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)
+ assert len(cli.trainer.lr_schedulers) == 1
+ assert isinstance(cli.trainer.lr_schedulers[0]["scheduler"], torch.optim.lr_scheduler.StepLR)
+ assert cli.trainer.lr_schedulers[0]["scheduler"].step_size == 50
+
+
+def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(tmpdir):
+ class MyLightningCLI(LightningCLI):
+ def add_arguments_to_parser(self, parser):
+ parser.add_optimizer_args(torch.optim.Adam, nested_key="optim1", link_to="model.optim1")
+ parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2")
+ parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR, link_to="model.scheduler")
+
+ class TestModel(BoringModel):
+ def __init__(
+ self,
+ optim1: dict,
+ optim2: dict,
+ scheduler: dict,
+ ):
+ super().__init__()
+ self.optim1 = instantiate_class(self.parameters(), optim1)
+ self.optim2 = instantiate_class(self.parameters(), optim2)
+ self.scheduler = instantiate_class(self.optim1, scheduler)
+
+ cli_args = [
+ f"--trainer.default_root_dir={tmpdir}",
+ "--trainer.max_epochs=1",
+ "--optim2.class_path=torch.optim.SGD",
+ "--optim2.init_args.lr=0.01",
+ "--lr_scheduler.gamma=0.2",
+ ]
+
+ with mock.patch("sys.argv", ["any.py"] + cli_args):
+ cli = MyLightningCLI(TestModel)
+
+ assert isinstance(cli.model.optim1, torch.optim.Adam)
+ assert isinstance(cli.model.optim2, torch.optim.SGD)
+ assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
diff --git a/tests/examples/test_integrations.py b/tests/examples/test_integrations.py
index 3ba73cc309..5fe061c678 100644
--- a/tests/examples/test_integrations.py
+++ b/tests/examples/test_integrations.py
@@ -17,21 +17,24 @@
import pytest
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE
from tests.examples.utils import run_test
-from tests.helpers.utils import _IMAGE_TESTING
root = Path(__file__).parent.parent.parent
@mock.patch.dict(os.environ, {"FLASH_TESTING": "1"})
@pytest.mark.parametrize(
- "folder, file", [
+ "folder, file",
+ [
pytest.param(
"fiftyone",
"image_classification.py",
- marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="fiftyone library isn't installed")
+ marks=pytest.mark.skipif(
+ not (_IMAGE_AVAILABLE and _FIFTYONE_AVAILABLE), reason="fiftyone library isn't installed"
+ ),
),
- ]
+ ],
)
def test_integrations(tmpdir, folder, file):
run_test(str(root / "flash_examples" / "integrations" / folder / file))
diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index 1decf2943b..75a5d7cd5f 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -20,7 +20,15 @@
import flash
from flash.core.utilities.imports import _SKLEARN_AVAILABLE
from tests.examples.utils import run_test
-from tests.helpers.utils import _IMAGE_TESTING, _TABULAR_TESTING, _TEXT_TESTING, _VIDEO_TESTING
+from tests.helpers.utils import (
+ _AUDIO_TESTING,
+ _GRAPH_TESTING,
+ _IMAGE_TESTING,
+ _POINTCLOUD_TESTING,
+ _TABULAR_TESTING,
+ _TEXT_TESTING,
+ _VIDEO_TESTING,
+)
@mock.patch.dict(os.environ, {"FLASH_TESTING": "1"})
@@ -30,47 +38,80 @@
pytest.param(
"custom_task.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),
+ pytest.param(
+ "audio_classification.py",
+ marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed"),
+ ),
+ pytest.param(
+ "speech_recognition.py",
+ marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed"),
+ ),
pytest.param(
"image_classification.py",
- marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
+ marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
),
pytest.param(
"image_classification_multi_label.py",
- marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
+ marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
),
# pytest.param("finetuning", "object_detection.py"), # TODO: takes too long.
pytest.param(
"semantic_segmentation.py",
- marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
+ marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed"),
),
pytest.param(
- "style_transfer.py",
- marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
+ "style_transfer.py", marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
),
pytest.param(
"summarization.py", marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed")
),
pytest.param(
"tabular_classification.py",
- marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed")
+ marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"),
),
pytest.param("template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")),
pytest.param(
"text_classification.py",
- marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed")
- ),
- pytest.param(
- "text_classification_multi_label.py",
- marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed")
+ marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed"),
),
+ # pytest.param(
+ # "text_classification_multi_label.py",
+ # marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed")
+ # ),
pytest.param(
"translation.py", marks=pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed")
),
pytest.param(
"video_classification.py",
- marks=pytest.mark.skipif(not _VIDEO_TESTING, reason="video libraries aren't installed")
+ marks=pytest.mark.skipif(not _VIDEO_TESTING, reason="video libraries aren't installed"),
+ ),
+ pytest.param(
+ "pointcloud_segmentation.py",
+ marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed"),
),
- ]
+ pytest.param(
+ "pointcloud_detection.py",
+ marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed"),
+ ),
+ pytest.param(
+ "graph_classification.py",
+ marks=pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed"),
+ ),
+ ],
)
def test_example(tmpdir, file):
run_test(str(Path(flash.PROJECT_ROOT) / "flash_examples" / file))
+
+
+@mock.patch.dict(os.environ, {"FLASH_TESTING": "1"})
+@pytest.mark.parametrize(
+ "file",
+ [
+ pytest.param(
+ "pointcloud_detection.py",
+ marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed"),
+ ),
+ ],
+)
+def test_example_2(tmpdir, file):
+ run_test(str(Path(flash.PROJECT_ROOT) / "flash_examples" / file))
diff --git a/tests/examples/utils.py b/tests/examples/utils.py
index aeeacacd0d..cf713fcbd1 100644
--- a/tests/examples/utils.py
+++ b/tests/examples/utils.py
@@ -19,12 +19,12 @@
def call_script(
filepath: str,
args: Optional[List[str]] = None,
- timeout: Optional[int] = 60 * 5,
+ timeout: Optional[int] = 60 * 10,
) -> Tuple[int, str, str]:
- with open(filepath, 'r') as original:
+ with open(filepath) as original:
data = original.read()
- with open(filepath, 'w') as modified:
+ with open(filepath, "w") as modified:
modified.write("import pytorch_lightning as pl\npl.seed_everything(42)\n" + data)
if args is None:
@@ -41,7 +41,7 @@ def call_script(
stdout = stdout.decode("utf-8")
stderr = stderr.decode("utf-8")
- with open(filepath, 'w') as modified:
+ with open(filepath, "w") as modified:
modified.write(data)
return p.returncode, stdout, stderr
diff --git a/tests/graph/__init__.py b/tests/graph/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/graph/classification/__init__.py b/tests/graph/classification/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/graph/classification/test_data.py b/tests/graph/classification/test_data.py
new file mode 100644
index 0000000000..de4d08ff72
--- /dev/null
+++ b/tests/graph/classification/test_data.py
@@ -0,0 +1,132 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+from flash.core.data.transforms import merge_transforms
+from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE
+from flash.graph.classification.data import GraphClassificationData, GraphClassificationPreprocess
+from tests.helpers.utils import _GRAPH_TESTING
+
+if _TORCH_GEOMETRIC_AVAILABLE:
+ from torch_geometric.datasets import TUDataset
+ from torch_geometric.transforms import OneHotDegree
+
+
+@pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed.")
+class TestGraphClassificationPreprocess:
+ """Tests ``GraphClassificationPreprocess``."""
+
+ def test_smoke(self):
+ """A simple test that the class can be instantiated."""
+ prep = GraphClassificationPreprocess()
+ assert prep is not None
+
+
+@pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed.")
+class TestGraphClassificationData:
+ """Tests ``GraphClassificationData``."""
+
+ def test_smoke(self):
+ dm = GraphClassificationData()
+ assert dm is not None
+
+ def test_from_datasets(self, tmpdir):
+ tudataset = TUDataset(root=tmpdir, name="KKI")
+ train_dataset = tudataset
+ val_dataset = tudataset
+ test_dataset = tudataset
+ predict_dataset = tudataset
+
+ # instantiate the data module
+ dm = GraphClassificationData.from_datasets(
+ train_dataset=train_dataset,
+ val_dataset=val_dataset,
+ test_dataset=test_dataset,
+ predict_dataset=predict_dataset,
+ train_transform=None,
+ val_transform=None,
+ test_transform=None,
+ predict_transform=None,
+ batch_size=2,
+ )
+ assert dm is not None
+ assert dm.train_dataloader() is not None
+ assert dm.val_dataloader() is not None
+ assert dm.test_dataloader() is not None
+
+ # check training data
+ data = next(iter(dm.train_dataloader()))
+ assert list(data.x.size())[1] == tudataset.num_features
+ assert list(data.y.size()) == [2]
+
+ # check val data
+ data = next(iter(dm.val_dataloader()))
+ assert list(data.x.size())[1] == tudataset.num_features
+ assert list(data.y.size()) == [2]
+
+ # check test data
+ data = next(iter(dm.test_dataloader()))
+ assert list(data.x.size())[1] == tudataset.num_features
+ assert list(data.y.size()) == [2]
+
+ def test_transforms(self, tmpdir):
+ tudataset = TUDataset(root=tmpdir, name="KKI")
+ train_dataset = tudataset
+ val_dataset = tudataset
+ test_dataset = tudataset
+ predict_dataset = tudataset
+
+ # instantiate the data module
+ dm = GraphClassificationData.from_datasets(
+ train_dataset=train_dataset,
+ val_dataset=val_dataset,
+ test_dataset=test_dataset,
+ predict_dataset=predict_dataset,
+ train_transform=merge_transforms(
+ GraphClassificationPreprocess.default_transforms(),
+ {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)},
+ ),
+ val_transform=merge_transforms(
+ GraphClassificationPreprocess.default_transforms(),
+ {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)},
+ ),
+ test_transform=merge_transforms(
+ GraphClassificationPreprocess.default_transforms(),
+ {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)},
+ ),
+ predict_transform=merge_transforms(
+ GraphClassificationPreprocess.default_transforms(),
+ {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)},
+ ),
+ batch_size=2,
+ )
+ assert dm is not None
+ assert dm.train_dataloader() is not None
+ assert dm.val_dataloader() is not None
+ assert dm.test_dataloader() is not None
+
+ # check training data
+ data = next(iter(dm.train_dataloader()))
+ assert list(data.x.size())[1] == tudataset.num_features * 2
+ assert list(data.y.size()) == [2]
+
+ # check val data
+ data = next(iter(dm.val_dataloader()))
+ assert list(data.x.size())[1] == tudataset.num_features * 2
+ assert list(data.y.size()) == [2]
+
+ # check test data
+ data = next(iter(dm.test_dataloader()))
+ assert list(data.x.size())[1] == tudataset.num_features * 2
+ assert list(data.y.size()) == [2]
diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py
new file mode 100644
index 0000000000..0813a6fb3a
--- /dev/null
+++ b/tests/graph/classification/test_model.py
@@ -0,0 +1,88 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest import mock
+
+import pytest
+import torch
+
+from flash import Trainer
+from flash.__main__ import main
+from flash.core.data.data_pipeline import DataPipeline
+from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE
+from flash.graph.classification import GraphClassifier
+from flash.graph.classification.data import GraphClassificationPreprocess
+from tests.helpers.utils import _GRAPH_TESTING
+
+if _TORCH_GEOMETRIC_AVAILABLE:
+ from torch_geometric import datasets
+
+
+@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed")
+def test_smoke():
+ """A simple test that the class can be instantiated."""
+ model = GraphClassifier(num_features=1, num_classes=1)
+ assert model is not None
+
+
+@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed")
+def test_train(tmpdir):
+ """Tests that the model can be trained on a pytorch geometric dataset."""
+ tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
+ model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes)
+ model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess())
+ train_dl = torch.utils.data.DataLoader(tudataset, batch_size=4)
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+ trainer.fit(model, train_dl)
+
+
+@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed")
+def test_val(tmpdir):
+ """Tests that the model can be validated on a pytorch geometric dataset."""
+ tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
+ model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes)
+ model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess())
+ val_dl = torch.utils.data.DataLoader(tudataset, batch_size=4)
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+ trainer.validate(model, val_dl)
+
+
+@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed")
+def test_test(tmpdir):
+ """Tests that the model can be tested on a pytorch geometric dataset."""
+ tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
+ model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes)
+ model.data_pipeline = DataPipeline(preprocess=GraphClassificationPreprocess())
+ test_dl = torch.utils.data.DataLoader(tudataset, batch_size=4)
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+ trainer.test(model, test_dl)
+
+
+@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed")
+def test_predict_dataset(tmpdir):
+ """Tests that we can generate predictions from a pytorch geometric dataset."""
+ tudataset = datasets.TUDataset(root=tmpdir, name="KKI")
+ model = GraphClassifier(num_features=tudataset.num_features, num_classes=tudataset.num_classes)
+ data_pipe = DataPipeline(preprocess=GraphClassificationPreprocess())
+ out = model.predict(tudataset, data_source="datasets", data_pipeline=data_pipe)
+ assert isinstance(out[0], int)
+
+
+@pytest.mark.skipif(not _GRAPH_TESTING, reason="pytorch geometric isn't installed")
+def test_cli():
+ cli_args = ["flash", "graph_classification", "--trainer.fast_dev_run", "True"]
+ with mock.patch("sys.argv", cli_args):
+ try:
+ main()
+ except SystemExit:
+ pass
diff --git a/tests/helpers/boring_model.py b/tests/helpers/boring_model.py
new file mode 100644
index 0000000000..e7ece2c0b8
--- /dev/null
+++ b/tests/helpers/boring_model.py
@@ -0,0 +1,135 @@
+# Adapted from:
+# https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/boring_model.py
+from typing import Optional
+
+import torch
+from pytorch_lightning import LightningDataModule, LightningModule
+from torch.utils.data import DataLoader, Dataset, Subset
+
+
+class RandomDataset(Dataset):
+ def __init__(self, size, length):
+ self.len = length
+ self.data = torch.randn(length, size)
+
+ def __getitem__(self, index):
+ return self.data[index]
+
+ def __len__(self):
+ return self.len
+
+
+class BoringModel(LightningModule):
+ def __init__(self):
+ """Testing PL Module.
+
+ Use as follows:
+ - subclass
+ - modify the behavior for what you want
+
+ class TestModel(BaseTestModel):
+ def training_step(...):
+ # do your own thing
+
+ or:
+
+ model = BaseTestModel()
+ model.training_epoch_end = None
+ """
+ super().__init__()
+ self.layer = torch.nn.Linear(32, 2)
+
+ def forward(self, x):
+ return self.layer(x)
+
+ def loss(self, batch, prediction):
+ # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
+ return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
+
+ def step(self, x):
+ x = self(x)
+ out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
+ return out
+
+ def training_step(self, batch, batch_idx):
+ output = self(batch)
+ loss = self.loss(batch, output)
+ return {"loss": loss}
+
+ def training_step_end(self, training_step_outputs):
+ return training_step_outputs
+
+ def training_epoch_end(self, outputs) -> None:
+ torch.stack([x["loss"] for x in outputs]).mean()
+
+ def validation_step(self, batch, batch_idx):
+ output = self(batch)
+ loss = self.loss(batch, output)
+ return {"x": loss}
+
+ def validation_epoch_end(self, outputs) -> None:
+ torch.stack([x["x"] for x in outputs]).mean()
+
+ def test_step(self, batch, batch_idx):
+ output = self(batch)
+ loss = self.loss(batch, output)
+ return {"y": loss}
+
+ def test_epoch_end(self, outputs) -> None:
+ torch.stack([x["y"] for x in outputs]).mean()
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+ return [optimizer], [lr_scheduler]
+
+ def train_dataloader(self):
+ return DataLoader(RandomDataset(32, 64))
+
+ def val_dataloader(self):
+ return DataLoader(RandomDataset(32, 64))
+
+ def test_dataloader(self):
+ return DataLoader(RandomDataset(32, 64))
+
+ def predict_dataloader(self):
+ return DataLoader(RandomDataset(32, 64))
+
+
+class BoringDataModule(LightningDataModule):
+ def __init__(self, data_dir: str = "./"):
+ super().__init__()
+ self.data_dir = data_dir
+ self.non_picklable = None
+ self.checkpoint_state: Optional[str] = None
+
+ def prepare_data(self):
+ self.random_full = RandomDataset(32, 64 * 4)
+
+ def setup(self, stage: Optional[str] = None):
+ if stage == "fit" or stage is None:
+ self.random_train = Subset(self.random_full, indices=range(64))
+ self.dims = self.random_train[0].shape
+
+ if stage in ("fit", "validate") or stage is None:
+ self.random_val = Subset(self.random_full, indices=range(64, 64 * 2))
+
+ if stage == "test" or stage is None:
+ self.random_test = Subset(self.random_full, indices=range(64 * 2, 64 * 3))
+ self.dims = getattr(self, "dims", self.random_test[0].shape)
+
+ if stage == "predict" or stage is None:
+ self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))
+ self.dims = getattr(self, "dims", self.random_predict[0].shape)
+
+ def train_dataloader(self):
+ return DataLoader(self.random_train)
+
+ def val_dataloader(self):
+ return DataLoader(self.random_val)
+
+ def test_dataloader(self):
+ return DataLoader(self.random_test)
+
+ def predict_dataloader(self):
+ return DataLoader(self.random_predict)
diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index 0fa1815db8..bd57cf570d 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -14,7 +14,10 @@
import os
from flash.core.utilities.imports import (
+ _AUDIO_AVAILABLE,
+ _GRAPH_AVAILABLE,
_IMAGE_AVAILABLE,
+ _POINTCLOUD_AVAILABLE,
_SERVE_AVAILABLE,
_TABULAR_AVAILABLE,
_TEXT_AVAILABLE,
@@ -26,6 +29,9 @@
_TABULAR_TESTING = _TABULAR_AVAILABLE
_TEXT_TESTING = _TEXT_AVAILABLE
_SERVE_TESTING = _SERVE_AVAILABLE
+_POINTCLOUD_TESTING = _POINTCLOUD_AVAILABLE
+_GRAPH_TESTING = _GRAPH_AVAILABLE
+_AUDIO_TESTING = _AUDIO_AVAILABLE
if "FLASH_TEST_TOPIC" in os.environ:
topic = os.environ["FLASH_TEST_TOPIC"]
@@ -34,3 +40,6 @@
_TABULAR_TESTING = topic == "tabular"
_TEXT_TESTING = topic == "text"
_SERVE_TESTING = topic == "serve"
+ _POINTCLOUD_TESTING = topic == "pointcloud"
+ _GRAPH_TESTING = topic == "graph"
+ _AUDIO_TESTING = topic == "audio"
diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py
index 183f3427a4..99bf240646 100644
--- a/tests/image/classification/test_data.py
+++ b/tests/image/classification/test_data.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import csv
from pathlib import Path
from typing import Any, List, Tuple
@@ -21,7 +22,13 @@
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys
-from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
+from flash.core.utilities.imports import (
+ _FIFTYONE_AVAILABLE,
+ _IMAGE_AVAILABLE,
+ _MATPLOTLIB_AVAILABLE,
+ _PIL_AVAILABLE,
+ _TORCHVISION_AVAILABLE,
+)
from flash.image import ImageClassificationData
from tests.helpers.utils import _IMAGE_TESTING
@@ -72,9 +79,9 @@ def test_from_filepaths_smoke(tmpdir):
assert img_data.test_dataloader() is None
data = next(iter(img_data.train_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
assert sorted(list(labels.numpy())) == [1, 2]
@@ -104,28 +111,29 @@ def test_from_filepaths_list_image_paths(tmpdir):
# check training data
data = next(iter(img_data.train_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here
assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here
# check validation data
data = next(iter(img_data.val_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
assert list(labels.numpy()) == [1, 4]
# check test data
data = next(iter(img_data.test_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
assert list(labels.numpy()) == [2, 5]
-@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
+@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
def test_from_filepaths_visualise(tmpdir):
tmpdir = Path(tmpdir)
@@ -160,7 +168,8 @@ def test_from_filepaths_visualise(tmpdir):
dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])
-@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
+@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
def test_from_filepaths_visualise_multilabel(tmpdir):
tmpdir = Path(tmpdir)
@@ -207,7 +216,7 @@ def test_from_filepaths_splits(tmpdir):
_rand_image(img_size).save(tmpdir / "s.png")
num_samples: int = 10
- val_split: float = .3
+ val_split: float = 0.3
train_filepaths: List[str] = [str(tmpdir / "s.png") for _ in range(num_samples)]
@@ -218,7 +227,7 @@ def test_from_filepaths_splits(tmpdir):
_to_tensor = {
"to_tensor_transform": nn.Sequential(
ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
- ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor)
+ ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
),
}
@@ -234,9 +243,9 @@ def run(transform: Any = None):
image_size=img_size,
)
data = next(iter(dm.train_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (B, 3, H, W)
- assert labels.shape == (B, )
+ assert labels.shape == (B,)
run(_to_tensor)
@@ -257,9 +266,9 @@ def test_from_folders_only_train(tmpdir):
img_data = ImageClassificationData.from_folders(train_dir, train_transform=None, batch_size=1)
data = next(iter(img_data.train_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (1, 3, 196, 196)
- assert labels.shape == (1, )
+ assert labels.shape == (1,)
assert img_data.val_dataloader() is None
assert img_data.test_dataloader() is None
@@ -287,20 +296,20 @@ def test_from_folders_train_val(tmpdir):
)
data = next(iter(img_data.train_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
data = next(iter(img_data.val_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
assert list(labels.numpy()) == [0, 0]
data = next(iter(img_data.test_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
assert list(labels.numpy()) == [0, 0]
@@ -329,18 +338,18 @@ def test_from_filepaths_multilabel(tmpdir):
)
data = next(iter(dm.train_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 4)
data = next(iter(dm.val_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 4)
torch.testing.assert_allclose(labels, torch.tensor(valid_labels))
data = next(iter(dm.test_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
assert labels.shape == (2, 4)
torch.testing.assert_allclose(labels, torch.tensor(test_labels))
@@ -368,28 +377,28 @@ def test_from_data(data, from_function):
# check training data
data = next(iter(img_data.train_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here
assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here
# check validation data
data = next(iter(img_data.val_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
assert list(labels.numpy()) == [1, 4]
# check test data
data = next(iter(img_data.test_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
assert list(labels.numpy()) == [2, 5]
-@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.")
def test_from_fiftyone(tmpdir):
tmpdir = Path(tmpdir)
@@ -426,23 +435,23 @@ def test_from_fiftyone(tmpdir):
# check train data
data = next(iter(img_data.train_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
assert sorted(list(labels.numpy())) == [0, 1]
# check val data
data = next(iter(img_data.val_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
assert sorted(list(labels.numpy())) == [0, 1]
# check test data
data = next(iter(img_data.test_dataloader()))
- imgs, labels = data['input'], data['target']
+ imgs, labels = data["input"], data["target"]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
assert sorted(list(labels.numpy())) == [0, 1]
@@ -460,16 +469,103 @@ def test_from_datasets():
data = next(iter(img_data.train_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
# check validation data
data = next(iter(img_data.val_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
# check test data
data = next(iter(img_data.test_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, )
+ assert labels.shape == (2,)
+
+
+@pytest.fixture
+def image_tmpdir(tmpdir):
+ (tmpdir / "train").mkdir()
+ Image.new("RGB", (128, 128)).save(str(tmpdir / "train" / "image_1.png"))
+ Image.new("RGB", (128, 128)).save(str(tmpdir / "train" / "image_2.png"))
+ return tmpdir / "train"
+
+
+@pytest.fixture
+def single_target_csv(image_tmpdir):
+ with open(image_tmpdir / "metadata.csv", "w") as csvfile:
+ fieldnames = ["image", "target"]
+ writer = csv.DictWriter(csvfile, fieldnames)
+ writer.writeheader()
+ writer.writerow({"image": "image_1", "target": "Ants"})
+ writer.writerow({"image": "image_2", "target": "Bees"})
+ return str(image_tmpdir / "metadata.csv")
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_from_csv_single_target(single_target_csv):
+ img_data = ImageClassificationData.from_csv(
+ "image",
+ "target",
+ train_file=single_target_csv,
+ batch_size=2,
+ num_workers=0,
+ )
+
+ # check training data
+ data = next(iter(img_data.train_dataloader()))
+ imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2,)
+
+
+@pytest.fixture
+def multi_target_csv(image_tmpdir):
+ with open(image_tmpdir / "metadata.csv", "w") as csvfile:
+ fieldnames = ["image", "target_1", "target_2"]
+ writer = csv.DictWriter(csvfile, fieldnames)
+ writer.writeheader()
+ writer.writerow({"image": "image_1", "target_1": 1, "target_2": 0})
+ writer.writerow({"image": "image_2", "target_1": 1, "target_2": 1})
+ return str(image_tmpdir / "metadata.csv")
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_from_csv_multi_target(multi_target_csv):
+ img_data = ImageClassificationData.from_csv(
+ "image",
+ ["target_1", "target_2"],
+ train_file=multi_target_csv,
+ batch_size=2,
+ num_workers=0,
+ )
+
+ # check training data
+ data = next(iter(img_data.train_dataloader()))
+ imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, 2)
+
+
+@pytest.fixture
+def bad_csv_no_image(image_tmpdir):
+ with open(image_tmpdir / "metadata.csv", "w") as csvfile:
+ fieldnames = ["image", "target"]
+ writer = csv.DictWriter(csvfile, fieldnames)
+ writer.writeheader()
+ writer.writerow({"image": "image_3", "target": "Ants"})
+ return str(image_tmpdir / "metadata.csv")
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_from_bad_csv_no_image(bad_csv_no_image):
+ with pytest.raises(ValueError, match="Found no matches"):
+ img_data = ImageClassificationData.from_csv(
+ "image",
+ ["target"],
+ train_file=bad_csv_no_image,
+ batch_size=1,
+ num_workers=0,
+ )
+ _ = next(iter(img_data.train_dataloader()))
diff --git a/tests/image/classification/test_data_model_integration.py b/tests/image/classification/test_data_model_integration.py
index c15aca96ea..ba53d68637 100644
--- a/tests/image/classification/test_data_model_integration.py
+++ b/tests/image/classification/test_data_model_integration.py
@@ -18,7 +18,7 @@
import torch
from flash import Trainer
-from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PIL_AVAILABLE
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE
from flash.image import ImageClassificationData, ImageClassifier
from tests.helpers.utils import _IMAGE_TESTING
@@ -62,7 +62,7 @@ def test_classification(tmpdir):
trainer.finetune(model, datamodule=data, strategy="freeze")
-@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.")
def test_classification_fiftyone(tmpdir):
tmpdir = Path(tmpdir)
diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py
index 1cbaf589e2..7dc49a3abc 100644
--- a/tests/image/classification/test_model.py
+++ b/tests/image/classification/test_model.py
@@ -19,6 +19,7 @@
import torch
from flash import Trainer
+from flash.__main__ import main
from flash.core.classification import Probabilities
from flash.core.data.data_source import DefaultDataKeys
from flash.core.utilities.imports import _IMAGE_AVAILABLE
@@ -30,11 +31,10 @@
class DummyDataset(torch.utils.data.Dataset):
-
def __getitem__(self, index):
return {
DefaultDataKeys.INPUT: torch.rand(3, 224, 224),
- DefaultDataKeys.TARGET: torch.randint(10, size=(1, )).item(),
+ DefaultDataKeys.TARGET: torch.randint(10, size=(1,)).item(),
}
def __len__(self) -> int:
@@ -42,14 +42,13 @@ def __len__(self) -> int:
class DummyMultiLabelDataset(torch.utils.data.Dataset):
-
def __init__(self, num_classes: int):
self.num_classes = num_classes
def __getitem__(self, index):
return {
DefaultDataKeys.INPUT: torch.rand(3, 224, 224),
- DefaultDataKeys.TARGET: torch.randint(0, 2, (self.num_classes, )),
+ DefaultDataKeys.TARGET: torch.randint(0, 2, (self.num_classes,)),
}
def __len__(self) -> int:
@@ -61,17 +60,18 @@ def __len__(self) -> int:
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
@pytest.mark.parametrize(
- "backbone",
+ "backbone,metrics",
[
- "resnet18",
+ ("resnet18", None),
+ ("resnet18", []),
# "resnet34",
# "resnet50",
# "resnet101",
# "resnet152",
],
)
-def test_init_train(tmpdir, backbone):
- model = ImageClassifier(10, backbone=backbone)
+def test_init_train(tmpdir, backbone, metrics):
+ model = ImageClassifier(10, backbone=backbone, metrics=metrics)
train_dl = torch.utils.data.DataLoader(DummyDataset())
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
@@ -117,7 +117,7 @@ def test_multilabel(tmpdir):
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
-@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))])
+@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32),))])
def test_jit(tmpdir, jitter, args):
path = os.path.join(tmpdir, "test.pt")
@@ -148,3 +148,13 @@ def test_serve():
def test_load_from_checkpoint_dependency_error():
with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")):
ImageClassifier.load_from_checkpoint("not_a_real_checkpoint.pt")
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_cli():
+ cli_args = ["flash", "image_classification", "--trainer.fast_dev_run", "True"]
+ with mock.patch("sys.argv", cli_args):
+ try:
+ main()
+ except SystemExit:
+ pass
diff --git a/tests/image/detection/test_data.py b/tests/image/detection/test_data.py
index 18e2efa1da..50ce9fb196 100644
--- a/tests/image/detection/test_data.py
+++ b/tests/image/detection/test_data.py
@@ -1,3 +1,16 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import json
import os
from pathlib import Path
@@ -5,9 +18,8 @@
import pytest
from flash.core.data.data_source import DefaultDataKeys
-from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PIL_AVAILABLE
+from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE
from flash.image.detection.data import ObjectDetectionData
-from tests.helpers.utils import _IMAGE_TESTING
if _PIL_AVAILABLE:
from PIL import Image
@@ -19,44 +31,53 @@
def _create_dummy_coco_json(dummy_json_path):
dummy_json = {
- "images": [{
- "id": 0,
- 'width': 1920,
- 'height': 1080,
- 'file_name': 'sample_one.png',
- }, {
- "id": 1,
- "width": 1920,
- "height": 1080,
- "file_name": "sample_two.png",
- }],
- "annotations": [{
- "id": 1,
- "image_id": 0,
- "category_id": 0,
- "area": 150,
- "bbox": [30, 40, 20, 20],
- "iscrowd": 0,
- }, {
- "id": 2,
- "image_id": 1,
- "category_id": 0,
- "area": 240,
- "bbox": [50, 100, 280, 15],
- "iscrowd": 0,
- }, {
- "id": 3,
- "image_id": 1,
- "category_id": 0,
- "area": 170,
- "bbox": [230, 130, 90, 180],
- "iscrowd": 0,
- }],
- "categories": [{
- "id": 0,
- "name": "person",
- "supercategory": "person",
- }]
+ "images": [
+ {
+ "id": 0,
+ "width": 1920,
+ "height": 1080,
+ "file_name": "sample_one.png",
+ },
+ {
+ "id": 1,
+ "width": 1920,
+ "height": 1080,
+ "file_name": "sample_two.png",
+ },
+ ],
+ "annotations": [
+ {
+ "id": 1,
+ "image_id": 0,
+ "category_id": 0,
+ "area": 150,
+ "bbox": [30, 40, 20, 20],
+ "iscrowd": 0,
+ },
+ {
+ "id": 2,
+ "image_id": 1,
+ "category_id": 0,
+ "area": 240,
+ "bbox": [50, 100, 280, 15],
+ "iscrowd": 0,
+ },
+ {
+ "id": 3,
+ "image_id": 1,
+ "category_id": 0,
+ "area": 170,
+ "bbox": [230, 130, 90, 180],
+ "iscrowd": 0,
+ },
+ ],
+ "categories": [
+ {
+ "id": 0,
+ "name": "person",
+ "supercategory": "person",
+ }
+ ],
}
with open(dummy_json_path, "w") as fp:
@@ -68,8 +89,8 @@ def _create_synth_coco_dataset(tmpdir):
train_dir.mkdir()
(train_dir / "images").mkdir()
- Image.new('RGB', (1920, 1080)).save(train_dir / "images" / "sample_one.png")
- Image.new('RGB', (1920, 1080)).save(train_dir / "images" / "sample_two.png")
+ Image.new("RGB", (1920, 1080)).save(train_dir / "images" / "sample_one.png")
+ Image.new("RGB", (1920, 1080)).save(train_dir / "images" / "sample_two.png")
(train_dir / "annotations").mkdir()
dummy_json = train_dir / "annotations" / "sample.json"
@@ -85,8 +106,8 @@ def _create_synth_fiftyone_dataset(tmpdir):
img_dir = Path(tmpdir / "fo_imgs")
img_dir.mkdir()
- Image.new('RGB', (1920, 1080)).save(img_dir / "sample_one.png")
- Image.new('RGB', (1920, 1080)).save(img_dir / "sample_two.png")
+ Image.new("RGB", (1920, 1080)).save(img_dir / "sample_one.png")
+ Image.new("RGB", (1920, 1080)).save(img_dir / "sample_two.png")
dataset = fo.Dataset.from_dir(
img_dir,
@@ -121,20 +142,19 @@ def _create_synth_fiftyone_dataset(tmpdir):
return dataset
-@pytest.mark.skipif(not _IMAGE_TESTING, reason="pycocotools is not installed for testing")
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
+@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
def test_image_detector_data_from_coco(tmpdir):
train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir)
- datamodule = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1)
+ datamodule = ObjectDetectionData.from_coco(
+ train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1, image_size=128
+ )
data = next(iter(datamodule.train_dataloader()))
- imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
-
- assert len(imgs) == 1
- assert imgs[0].shape == (3, 1080, 1920)
- assert len(labels) == 1
- assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd']
+ sample = data[0]
+ assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3)
assert datamodule.val_dataloader() is None
assert datamodule.test_dataloader() is None
@@ -148,40 +168,30 @@ def test_image_detector_data_from_coco(tmpdir):
test_ann_file=coco_ann_path,
batch_size=1,
num_workers=0,
+ image_size=128,
)
data = next(iter(datamodule.val_dataloader()))
- imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
- assert len(imgs) == 1
- assert imgs[0].shape == (3, 1080, 1920)
- assert len(labels) == 1
- assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd']
+ sample = data[0]
+ assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3)
data = next(iter(datamodule.test_dataloader()))
- imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
-
- assert len(imgs) == 1
- assert imgs[0].shape == (3, 1080, 1920)
- assert len(labels) == 1
- assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd']
+ sample = data[0]
+ assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3)
-@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
def test_image_detector_data_from_fiftyone(tmpdir):
train_dataset = _create_synth_fiftyone_dataset(tmpdir)
- datamodule = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1)
+ datamodule = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1, image_size=128)
data = next(iter(datamodule.train_dataloader()))
- imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
-
- assert len(imgs) == 1
- assert imgs[0].shape == (3, 1080, 1920)
- assert len(labels) == 1
- assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd']
+ sample = data[0]
+ assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3)
assert datamodule.val_dataloader() is None
assert datamodule.test_dataloader() is None
@@ -192,20 +202,13 @@ def test_image_detector_data_from_fiftyone(tmpdir):
test_dataset=train_dataset,
batch_size=1,
num_workers=0,
+ image_size=128,
)
data = next(iter(datamodule.val_dataloader()))
- imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
-
- assert len(imgs) == 1
- assert imgs[0].shape == (3, 1080, 1920)
- assert len(labels) == 1
- assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd']
+ sample = data[0]
+ assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3)
data = next(iter(datamodule.test_dataloader()))
- imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
-
- assert len(imgs) == 1
- assert imgs[0].shape == (3, 1080, 1920)
- assert len(labels) == 1
- assert list(labels[0].keys()) == ['boxes', 'labels', 'image_id', 'area', 'iscrowd']
+ sample = data[0]
+ assert sample[DefaultDataKeys.INPUT].shape == (128, 128, 3)
diff --git a/tests/image/detection/test_data_model_integration.py b/tests/image/detection/test_data_model_integration.py
index 4c9ce93209..1a9d47b9f0 100644
--- a/tests/image/detection/test_data_model_integration.py
+++ b/tests/image/detection/test_data_model_integration.py
@@ -14,9 +14,10 @@
import os
import pytest
+import torch
import flash
-from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _PIL_AVAILABLE
+from flash.core.utilities.imports import _COCO_AVAILABLE, _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _PIL_AVAILABLE
from flash.image import ObjectDetector
from flash.image.detection import ObjectDetectionData
from tests.helpers.utils import _IMAGE_TESTING
@@ -33,49 +34,48 @@
from tests.image.detection.test_data import _create_synth_fiftyone_dataset
-@pytest.mark.skipif(not _IMAGE_TESTING, reason="pycocotools is not installed for testing")
-@pytest.mark.skipif(not _COCO_AVAILABLE, reason="pycocotools is not installed for testing")
-@pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", "resnet18")])
-def test_detection(tmpdir, model, backbone):
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+@pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn")])
+def test_detection(tmpdir, head, backbone):
train_folder, coco_ann_path = _create_synth_coco_dataset(tmpdir)
data = ObjectDetectionData.from_coco(train_folder=train_folder, train_ann_file=coco_ann_path, batch_size=1)
- model = ObjectDetector(model=model, backbone=backbone, num_classes=data.num_classes)
+ model = ObjectDetector(head=head, backbone=backbone, num_classes=data.num_classes)
- trainer = flash.Trainer(fast_dev_run=True)
+ trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count())
- trainer.finetune(model, data)
+ trainer.finetune(model, data, strategy="freeze")
test_image_one = os.fspath(tmpdir / "test_one.png")
test_image_two = os.fspath(tmpdir / "test_two.png")
- Image.new('RGB', (512, 512)).save(test_image_one)
- Image.new('RGB', (512, 512)).save(test_image_two)
+ Image.new("RGB", (512, 512)).save(test_image_one)
+ Image.new("RGB", (512, 512)).save(test_image_two)
test_images = [str(test_image_one), str(test_image_two)]
model.predict(test_images)
-@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed for testing")
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
-@pytest.mark.parametrize(["model", "backbone"], [("fasterrcnn", "resnet18")])
-def test_detection_fiftyone(tmpdir, model, backbone):
+@pytest.mark.parametrize(["head", "backbone"], [("retinanet", "resnet18_fpn")])
+def test_detection_fiftyone(tmpdir, head, backbone):
train_dataset = _create_synth_fiftyone_dataset(tmpdir)
data = ObjectDetectionData.from_fiftyone(train_dataset=train_dataset, batch_size=1)
- model = ObjectDetector(model=model, backbone=backbone, num_classes=data.num_classes)
+ model = ObjectDetector(head=head, backbone=backbone, num_classes=data.num_classes)
- trainer = flash.Trainer(fast_dev_run=True)
+ trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count())
- trainer.finetune(model, data)
+ trainer.finetune(model, data, strategy="freeze")
test_image_one = os.fspath(tmpdir / "test_one.png")
test_image_two = os.fspath(tmpdir / "test_two.png")
- Image.new('RGB', (512, 512)).save(test_image_one)
- Image.new('RGB', (512, 512)).save(test_image_two)
+ Image.new("RGB", (512, 512)).save(test_image_one)
+ Image.new("RGB", (512, 512)).save(test_image_two)
test_images = [str(test_image_one), str(test_image_two)]
model.predict(test_images)
diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py
index a610122783..f5fd1fba85 100644
--- a/tests/image/detection/test_model.py
+++ b/tests/image/detection/test_model.py
@@ -11,26 +11,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
+import random
import re
+from unittest import mock
+import numpy as np
import pytest
import torch
from pytorch_lightning import Trainer
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import Dataset
+from flash.__main__ import main
from flash.core.data.data_source import DefaultDataKeys
-from flash.core.utilities.imports import _IMAGE_AVAILABLE
+from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE
from flash.image import ObjectDetector
from tests.helpers.utils import _IMAGE_TESTING
+if _ICEVISION_AVAILABLE:
+ from icevision.data import Prediction
+
def collate_fn(samples):
return {key: [sample[key] for sample in samples] for key in samples[0]}
class DummyDetectionDataset(Dataset):
-
def __init__(self, img_shape, num_boxes, num_classes, length):
super().__init__()
self.img_shape = img_shape
@@ -43,15 +48,27 @@ def __len__(self) -> int:
def _random_bbox(self):
c, h, w = self.img_shape
- xs = torch.randint(w - 1, (2, ))
- ys = torch.randint(h - 1, (2, ))
- return [min(xs), min(ys), max(xs) + 1, max(ys) + 1]
+ xs = torch.randint(w - 1, (2,))
+ ys = torch.randint(h - 1, (2,))
+ return {"xmin": min(xs), "ymin": min(ys), "width": max(xs) - min(xs) + 1, "height": max(ys) - min(ys) + 1}
def __getitem__(self, idx):
- img = torch.rand(self.img_shape)
- boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)])
- labels = torch.randint(self.num_classes, (self.num_boxes, ))
- return {DefaultDataKeys.INPUT: img, DefaultDataKeys.TARGET: {"boxes": boxes, "labels": labels}}
+ sample = {}
+
+ img = np.random.rand(*self.img_shape).astype(np.float32)
+
+ sample[DefaultDataKeys.INPUT] = img
+
+ sample[DefaultDataKeys.TARGET] = {
+ "bboxes": [],
+ "labels": [],
+ }
+
+ for i in range(self.num_boxes):
+ sample[DefaultDataKeys.TARGET]["bboxes"].append(self._random_bbox())
+ sample[DefaultDataKeys.TARGET]["labels"].append(random.randint(0, self.num_classes - 1))
+
+ return sample
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
@@ -60,48 +77,58 @@ def test_init():
model.eval()
batch_size = 2
- ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10)
- dl = DataLoader(ds, collate_fn=collate_fn, batch_size=batch_size)
+ ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10)
+ dl = model.process_predict_dataset(ds, batch_size=batch_size)
data = next(iter(dl))
- img = data[DefaultDataKeys.INPUT]
- out = model(img)
+ out = model(data)
assert len(out) == batch_size
- assert {"boxes", "labels", "scores"} <= out[0].keys()
+ assert all(isinstance(res, Prediction) for res in out)
-@pytest.mark.parametrize("model", ["fasterrcnn", "retinanet"])
+@pytest.mark.parametrize("head", ["faster_rcnn", "retinanet"])
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
-def test_training(tmpdir, model):
- model = ObjectDetector(num_classes=2, model=model, pretrained=False, pretrained_backbone=False)
- ds = DummyDetectionDataset((3, 224, 224), 1, 2, 10)
- dl = DataLoader(ds, collate_fn=collate_fn)
+def test_training(tmpdir, head):
+ model = ObjectDetector(num_classes=2, head=head, pretrained=False)
+ ds = DummyDetectionDataset((128, 128, 3), 1, 2, 10)
+ dl = model.process_train_dataset(ds, 2, 0, False, None)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model, dl)
-@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
-def test_jit(tmpdir):
- path = os.path.join(tmpdir, "test.pt")
-
- model = ObjectDetector(2)
- model.eval()
-
- model = torch.jit.script(model) # torch.jit.trace doesn't work with torchvision RCNN
-
- torch.jit.save(model, path)
- model = torch.jit.load(path)
-
- out = model([torch.rand(3, 32, 32)])
-
- # torchvision RCNN always returns a (Losses, Detections) tuple in scripting
- out = out[1]
-
- assert {"boxes", "labels", "scores"} <= out[0].keys()
+# TODO: resolve JIT issues
+# @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+# def test_jit(tmpdir):
+# path = os.path.join(tmpdir, "test.pt")
+#
+# model = ObjectDetector(2)
+# model.eval()
+#
+# model = torch.jit.script(model) # torch.jit.trace doesn't work with torchvision RCNN
+#
+# torch.jit.save(model, path)
+# model = torch.jit.load(path)
+#
+# out = model([torch.rand(3, 32, 32)])
+#
+# # torchvision RCNN always returns a (Losses, Detections) tuple in scripting
+# out = out[1]
+#
+# assert {"boxes", "labels", "scores"} <= out[0].keys()
@pytest.mark.skipif(_IMAGE_AVAILABLE, reason="image libraries are installed.")
def test_load_from_checkpoint_dependency_error():
with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")):
ObjectDetector.load_from_checkpoint("not_a_real_checkpoint.pt")
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_cli():
+ cli_args = ["flash", "object_detection", "--trainer.fast_dev_run", "True"]
+ with mock.patch("sys.argv", cli_args):
+ try:
+ main()
+ except SystemExit:
+ pass
diff --git a/tests/image/detection/test_serialization.py b/tests/image/detection/test_serialization.py
index 93b6a3756b..8f707a229a 100644
--- a/tests/image/detection/test_serialization.py
+++ b/tests/image/detection/test_serialization.py
@@ -2,13 +2,13 @@
import torch
from flash.core.data.data_source import DefaultDataKeys
-from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE
from flash.image.detection.serialization import FiftyOneDetectionLabels
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
class TestFiftyOneDetectionLabels:
-
@staticmethod
def test_smoke():
serial = FiftyOneDetectionLabels()
@@ -16,7 +16,7 @@ def test_smoke():
@staticmethod
def test_serialize_fiftyone():
- labels = ['class_1', 'class_2', 'class_3']
+ labels = ["class_1", "class_2", "class_3"]
serial = FiftyOneDetectionLabels()
filepath_serial = FiftyOneDetectionLabels(return_filepath=True)
threshold_serial = FiftyOneDetectionLabels(threshold=0.9)
@@ -25,8 +25,7 @@ def test_serialize_fiftyone():
sample = {
DefaultDataKeys.PREDS: [
{
- "boxes": [torch.tensor(20), torch.tensor(30),
- torch.tensor(40), torch.tensor(50)],
+ "boxes": [torch.tensor(20), torch.tensor(30), torch.tensor(40), torch.tensor(50)],
"labels": torch.tensor(0),
"scores": torch.tensor(0.5),
},
diff --git a/tests/image/embedding/test_model.py b/tests/image/embedding/test_model.py
index 2700c3a37e..e823212ef7 100644
--- a/tests/image/embedding/test_model.py
+++ b/tests/image/embedding/test_model.py
@@ -23,7 +23,7 @@
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
-@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))])
+@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32),))])
def test_jit(tmpdir, jitter, args):
path = os.path.join(tmpdir, "test.pt")
diff --git a/tests/image/instance_segmentation/__init__.py b/tests/image/instance_segmentation/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/text/seq2seq/translation/test_metric.py b/tests/image/instance_segmentation/test_model.py
similarity index 55%
rename from tests/text/seq2seq/translation/test_metric.py
rename to tests/image/instance_segmentation/test_model.py
index 86b5784745..8f54742d24 100644
--- a/tests/text/seq2seq/translation/test_metric.py
+++ b/tests/image/instance_segmentation/test_model.py
@@ -11,15 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from unittest import mock
+
import pytest
-import torch
-from flash.text.seq2seq.translation.metric import BLEUScore
+from flash.__main__ import main
+from tests.helpers.utils import _IMAGE_TESTING
-@pytest.mark.parametrize("smooth, expected", [(False, 0.7598), (True, 0.8091)])
-def test_bleu_score(smooth, expected):
- translate_corpus = ['the cat is on the mat'.split()]
- reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
- metric = BLEUScore(smooth=smooth)
- assert torch.allclose(metric(translate_corpus, reference_corpus), torch.tensor(expected), 1e-4)
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_cli():
+ cli_args = ["flash", "instance_segmentation", "--trainer.fast_dev_run", "True"]
+ with mock.patch("sys.argv", cli_args):
+ try:
+ main()
+ except SystemExit:
+ pass
diff --git a/tests/image/keypoint_detection/__init__.py b/tests/image/keypoint_detection/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/image/keypoint_detection/test_model.py b/tests/image/keypoint_detection/test_model.py
new file mode 100644
index 0000000000..215ea9a71f
--- /dev/null
+++ b/tests/image/keypoint_detection/test_model.py
@@ -0,0 +1,29 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest import mock
+
+import pytest
+
+from flash.__main__ import main
+from tests.helpers.utils import _IMAGE_TESTING
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_cli():
+ cli_args = ["flash", "keypoint_detection", "--trainer.fast_dev_run", "True"]
+ with mock.patch("sys.argv", cli_args):
+ try:
+ main()
+ except SystemExit:
+ pass
diff --git a/tests/image/segmentation/test_backbones.py b/tests/image/segmentation/test_backbones.py
index 0b2b452e17..4b8fb7a7a7 100644
--- a/tests/image/segmentation/test_backbones.py
+++ b/tests/image/segmentation/test_backbones.py
@@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
-import torch
-from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE
+from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
-@pytest.mark.parametrize(["backbone"], [
- pytest.param("resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
- pytest.param("mobilenet_v3_large", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
-])
+@pytest.mark.parametrize(
+ ["backbone"],
+ [
+ pytest.param("resnet50", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")),
+ pytest.param("dpn131", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")),
+ ],
+)
def test_semantic_segmentation_backbones_registry(backbone):
- img = torch.rand(1, 3, 32, 32)
- backbone = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)(pretrained=False)
+ backbone = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)()
assert backbone
- backbone.eval()
- assert backbone(img) is not None
+ assert isinstance(backbone, str)
diff --git a/tests/image/segmentation/test_data.py b/tests/image/segmentation/test_data.py
index be898bdff3..b44a68da0d 100644
--- a/tests/image/segmentation/test_data.py
+++ b/tests/image/segmentation/test_data.py
@@ -9,7 +9,7 @@
from flash import Trainer
from flash.core.data.data_source import DefaultDataKeys
-from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PIL_AVAILABLE
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE, _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE
from flash.image import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess
from tests.helpers.utils import _IMAGE_TESTING
@@ -22,8 +22,8 @@
def build_checkboard(n, m, k=8):
x = np.zeros((n, m))
- x[k::k * 2, ::k] = 1
- x[::k * 2, k::k * 2] = 1
+ x[k :: k * 2, ::k] = 1
+ x[:: k * 2, k :: k * 2] = 1
return x
@@ -48,23 +48,22 @@ def create_random_data(image_files: List[str], label_files: List[str], size: Tup
class TestSemanticSegmentationPreprocess:
-
- @pytest.mark.xfail(reaspn="parameters are marked as optional but it returns Misconficg error.")
@staticmethod
+ @pytest.mark.xfail(reaspn="parameters are marked as optional but it returns Misconficg error.")
def test_smoke():
prep = SemanticSegmentationPreprocess(num_classes=1)
assert prep is not None
-@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
class TestSemanticSegmentationData:
-
@staticmethod
+ @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_smoke():
dm = SemanticSegmentationData()
assert dm is not None
@staticmethod
+ @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_folders(tmpdir):
tmp_dir = Path(tmpdir)
@@ -86,7 +85,7 @@ def test_from_folders(tmpdir):
]
num_classes: int = 2
- img_size: Tuple[int, int] = (196, 196)
+ img_size: Tuple[int, int] = (128, 128)
create_random_data(images, targets, img_size, num_classes)
# instantiate the data module
@@ -110,22 +109,23 @@ def test_from_folders(tmpdir):
# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
- assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, 196, 196)
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 128, 128)
# check val data
data = next(iter(dm.val_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
- assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, 196, 196)
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 128, 128)
# check test data
data = next(iter(dm.test_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
- assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, 196, 196)
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 128, 128)
@staticmethod
+ @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_folders_warning(tmpdir):
tmp_dir = Path(tmpdir)
@@ -145,7 +145,7 @@ def test_from_folders_warning(tmpdir):
]
num_classes: int = 2
- img_size: Tuple[int, int] = (196, 196)
+ img_size: Tuple[int, int] = (128, 128)
create_random_data(images, targets, img_size, num_classes)
# instantiate the data module
@@ -164,10 +164,11 @@ def test_from_folders_warning(tmpdir):
# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
- assert imgs.shape == (1, 3, 196, 196)
- assert labels.shape == (1, 196, 196)
+ assert imgs.shape == (1, 3, 128, 128)
+ assert labels.shape == (1, 128, 128)
@staticmethod
+ @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_files(tmpdir):
tmp_dir = Path(tmpdir)
@@ -186,7 +187,7 @@ def test_from_files(tmpdir):
]
num_classes: int = 2
- img_size: Tuple[int, int] = (196, 196)
+ img_size: Tuple[int, int] = (128, 128)
create_random_data(images, targets, img_size, num_classes)
# instantiate the data module
@@ -200,7 +201,7 @@ def test_from_files(tmpdir):
test_targets=targets,
batch_size=2,
num_workers=0,
- num_classes=num_classes
+ num_classes=num_classes,
)
assert dm is not None
assert dm.train_dataloader() is not None
@@ -210,22 +211,23 @@ def test_from_files(tmpdir):
# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
- assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, 196, 196)
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 128, 128)
# check val data
data = next(iter(dm.val_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
- assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, 196, 196)
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 128, 128)
# check test data
data = next(iter(dm.test_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
- assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, 196, 196)
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 128, 128)
@staticmethod
+ @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_from_files_warning(tmpdir):
tmp_dir = Path(tmpdir)
@@ -244,7 +246,7 @@ def test_from_files_warning(tmpdir):
]
num_classes: int = 2
- img_size: Tuple[int, int] = (196, 196)
+ img_size: Tuple[int, int] = (128, 128)
create_random_data(images, targets, img_size, num_classes)
# instantiate the data module
@@ -255,11 +257,12 @@ def test_from_files_warning(tmpdir):
train_targets=targets + [str(tmp_dir / "labels_img4.png")],
batch_size=2,
num_workers=0,
- num_classes=num_classes
+ num_classes=num_classes,
)
- @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
@staticmethod
+ @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
+ @pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
def test_from_fiftyone(tmpdir):
tmp_dir = Path(tmpdir)
@@ -272,7 +275,7 @@ def test_from_fiftyone(tmpdir):
]
num_classes: int = 2
- img_size: Tuple[int, int] = (196, 196)
+ img_size: Tuple[int, int] = (128, 128)
for img_file in images:
_rand_image(img_size).save(img_file)
@@ -307,27 +310,29 @@ def test_from_fiftyone(tmpdir):
# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
- assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, 196, 196)
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 128, 128)
# check val data
data = next(iter(dm.val_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
- assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, 196, 196)
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 128, 128)
# check test data
data = next(iter(dm.test_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
- assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, 196, 196)
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 128, 128)
# check predict data
data = next(iter(dm.predict_dataloader()))
imgs = data[DefaultDataKeys.INPUT]
- assert imgs.shape == (2, 3, 196, 196)
+ assert imgs.shape == (2, 3, 128, 128)
@staticmethod
+ @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
+ @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
def test_map_labels(tmpdir):
tmp_dir = Path(tmpdir)
@@ -351,7 +356,7 @@ def test_map_labels(tmpdir):
}
num_classes: int = len(labels_map.keys())
- img_size: Tuple[int, int] = (196, 196)
+ img_size: Tuple[int, int] = (128, 128)
create_random_data(images, targets, img_size, num_classes)
# instantiate the data module
@@ -363,7 +368,7 @@ def test_map_labels(tmpdir):
val_targets=targets,
batch_size=2,
num_workers=0,
- num_classes=num_classes
+ num_classes=num_classes,
)
assert dm is not None
assert dm.train_dataloader() is not None
@@ -379,13 +384,13 @@ def test_map_labels(tmpdir):
# check training data
data = next(iter(dm.train_dataloader()))
imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
- assert imgs.shape == (2, 3, 196, 196)
- assert labels.shape == (2, 196, 196)
+ assert imgs.shape == (2, 3, 128, 128)
+ assert labels.shape == (2, 128, 128)
assert labels.min().item() == 0
assert labels.max().item() == 1
assert labels.dtype == torch.int64
# now train with `fast_dev_run`
- model = SemanticSegmentation(num_classes=2, backbone="resnet50", head="fcn")
+ model = SemanticSegmentation(num_classes=2, backbone="resnet50", head="fpn")
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.finetune(model, dm, strategy="freeze_unfreeze")
diff --git a/tests/image/segmentation/test_heads.py b/tests/image/segmentation/test_heads.py
index ec90b03670..dbc4b3b38e 100644
--- a/tests/image/segmentation/test_heads.py
+++ b/tests/image/segmentation/test_heads.py
@@ -11,26 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import unittest.mock
+
import pytest
import torch
-from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE
+from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE
+from flash.image.segmentation import SemanticSegmentation
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS
+from tests.helpers.utils import _IMAGE_TESTING
@pytest.mark.parametrize(
- "head", [
- pytest.param("fcn", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
- pytest.param("deeplabv3", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
- pytest.param("lraspp", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
- pytest.param("unet", marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")),
- ]
+ "head",
+ [
+ pytest.param("fpn", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")),
+ pytest.param("deeplabv3", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")),
+ pytest.param("unet", marks=pytest.mark.skipif(not _SEGMENTATION_MODELS_AVAILABLE, reason="No SMP")),
+ ],
)
def test_semantic_segmentation_heads_registry(head):
img = torch.rand(1, 3, 32, 32)
backbone = SEMANTIC_SEGMENTATION_BACKBONES.get("resnet50")(pretrained=False)
- head = SEMANTIC_SEGMENTATION_HEADS.get(head)(backbone, 10)
+ head = SEMANTIC_SEGMENTATION_HEADS.get(head)(backbone=backbone, num_classes=10)
assert backbone
assert head
head.eval()
@@ -38,3 +42,26 @@ def test_semantic_segmentation_heads_registry(head):
if isinstance(res, dict):
res = res["out"]
assert res.shape[1] == 10
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+@unittest.mock.patch("flash.image.segmentation.heads.smp")
+def test_pretrained_weights(mock_smp):
+ mock_smp.create_model = unittest.mock.MagicMock()
+ available_weights = SemanticSegmentation.available_pretrained_weights("resnet18")
+ backbone = SEMANTIC_SEGMENTATION_BACKBONES.get("resnet18")()
+ SEMANTIC_SEGMENTATION_HEADS.get("unet")(backbone=backbone, num_classes=10, pretrained=True)
+
+ kwargs = {
+ "arch": "unet",
+ "classes": 10,
+ "encoder_name": "resnet18",
+ "in_channels": 3,
+ "encoder_weights": "imagenet",
+ }
+ mock_smp.create_model.assert_called_with(**kwargs)
+
+ for weight in available_weights:
+ SEMANTIC_SEGMENTATION_HEADS.get("unet")(backbone=backbone, num_classes=10, pretrained=weight)
+ kwargs["encoder_weights"] = weight
+ mock_smp.create_model.assert_called_with(**kwargs)
diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py
index c16b54b951..6715ebfc50 100644
--- a/tests/image/segmentation/test_model.py
+++ b/tests/image/segmentation/test_model.py
@@ -21,6 +21,7 @@
import torch
from flash import Trainer
+from flash.__main__ import main
from flash.core.data.data_pipeline import DataPipeline
from flash.core.data.data_source import DefaultDataKeys
from flash.core.utilities.imports import _IMAGE_AVAILABLE
@@ -56,12 +57,12 @@ def test_smoke():
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
@pytest.mark.parametrize("num_classes", [8, 256])
-@pytest.mark.parametrize("img_shape", [(1, 3, 224, 192), (2, 3, 127, 212)])
+@pytest.mark.parametrize("img_shape", [(1, 3, 224, 192), (2, 3, 128, 256)])
def test_forward(num_classes, img_shape):
model = SemanticSegmentation(
num_classes=num_classes,
backbone="resnet50",
- head="fcn",
+ head="fpn",
)
B, C, H, W = img_shape
@@ -103,28 +104,28 @@ def test_unfreeze():
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_predict_tensor():
- img = torch.rand(1, 3, 10, 20)
- model = SemanticSegmentation(2)
+ img = torch.rand(1, 3, 64, 64)
+ model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
out = model.predict(img, data_source="tensors", data_pipeline=data_pipe)
assert isinstance(out[0], list)
- assert len(out[0]) == 10
- assert len(out[0][0]) == 20
+ assert len(out[0]) == 64
+ assert len(out[0][0]) == 64
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_predict_numpy():
- img = np.ones((1, 3, 10, 20))
- model = SemanticSegmentation(2)
+ img = np.ones((1, 3, 64, 64))
+ model = SemanticSegmentation(2, backbone="mobilenetv3_large_100")
data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1))
out = model.predict(img, data_source="numpy", data_pipeline=data_pipe)
assert isinstance(out[0], list)
- assert len(out[0]) == 10
- assert len(out[0][0]) == 20
+ assert len(out[0]) == 64
+ assert len(out[0][0]) == 64
@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
-@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 3, 32, 32), ))])
+@pytest.mark.parametrize("jitter, args", [(torch.jit.trace, (torch.rand(1, 3, 32, 32),))])
def test_jit(tmpdir, jitter, args):
path = os.path.join(tmpdir, "test.pt")
@@ -155,3 +156,18 @@ def test_serve():
def test_load_from_checkpoint_dependency_error():
with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")):
SemanticSegmentation.load_from_checkpoint("not_a_real_checkpoint.pt")
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_available_pretrained_weights():
+ assert SemanticSegmentation.available_pretrained_weights("resnet18") == ["imagenet", "ssl", "swsl"]
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_cli():
+ cli_args = ["flash", "semantic_segmentation", "--trainer.fast_dev_run", "True"]
+ with mock.patch("sys.argv", cli_args):
+ try:
+ main()
+ except SystemExit:
+ pass
diff --git a/tests/image/segmentation/test_serialization.py b/tests/image/segmentation/test_serialization.py
index 09a03ad75c..0e7477348a 100644
--- a/tests/image/segmentation/test_serialization.py
+++ b/tests/image/segmentation/test_serialization.py
@@ -1,13 +1,27 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import pytest
import torch
from flash.core.data.data_source import DefaultDataKeys
-from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_AVAILABLE
from flash.image.segmentation.serialization import FiftyOneSegmentationLabels, SegmentationLabels
+from tests.helpers.utils import _IMAGE_TESTING
class TestSemanticSegmentationLabels:
-
+ @pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.")
@staticmethod
def test_smoke():
serial = SegmentationLabels()
@@ -15,6 +29,7 @@ def test_smoke():
assert serial.labels_map is None
assert serial.visualize is False
+ @pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.")
@staticmethod
def test_exception():
serial = SegmentationLabels()
@@ -27,6 +42,7 @@ def test_exception():
sample = torch.zeros(2, 3)
serial.serialize(sample)
+ @pytest.mark.skipif(not _IMAGE_TESTING, "image libraries aren't installed.")
@staticmethod
def test_serialize():
serial = SegmentationLabels()
@@ -39,6 +55,7 @@ def test_serialize():
assert torch.tensor(classes)[1, 2] == 1
assert torch.tensor(classes)[0, 1] == 3
+ @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone is not installed for testing")
@staticmethod
def test_serialize_fiftyone():
@@ -51,9 +68,7 @@ def test_serialize_fiftyone():
sample = {
DefaultDataKeys.PREDS: preds,
- DefaultDataKeys.METADATA: {
- "filepath": "something"
- },
+ DefaultDataKeys.METADATA: {"filepath": "something"},
}
segmentation = serial.serialize(sample)
diff --git a/tests/image/style_transfer/test_model.py b/tests/image/style_transfer/test_model.py
index d054986978..8573b70784 100644
--- a/tests/image/style_transfer/test_model.py
+++ b/tests/image/style_transfer/test_model.py
@@ -1,9 +1,24 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
import os
import re
+from unittest import mock
import pytest
import torch
+from flash.__main__ import main
from flash.core.utilities.imports import _IMAGE_AVAILABLE
from flash.image.style_transfer import StyleTransfer
from tests.helpers.utils import _IMAGE_TESTING
@@ -34,6 +49,9 @@ def test_jit(tmpdir):
model = StyleTransfer()
model.eval()
+ model.loss_fn = None
+ model.perceptual_loss = None # TODO: Document this
+
model = torch.jit.trace(model, torch.rand(1, 3, 32, 32)) # torch.jit.script doesn't work with pystiche
torch.jit.save(model, path)
@@ -48,3 +66,13 @@ def test_jit(tmpdir):
def test_load_from_checkpoint_dependency_error():
with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")):
StyleTransfer.load_from_checkpoint("not_a_real_checkpoint.pt")
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_cli():
+ cli_args = ["flash", "style_transfer", "--trainer.fast_dev_run", "True"]
+ with mock.patch("sys.argv", cli_args):
+ try:
+ main()
+ except SystemExit:
+ pass
diff --git a/tests/image/test_backbones.py b/tests/image/test_backbones.py
index 6036927555..c751426c76 100644
--- a/tests/image/test_backbones.py
+++ b/tests/image/test_backbones.py
@@ -14,19 +14,20 @@
import urllib.error
import pytest
-from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE
-from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TIMM_AVAILABLE
-from flash.image.backbones import catch_url_error, IMAGE_CLASSIFIER_BACKBONES
+from flash.core.utilities.url_error import catch_url_error
+from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES
+from tests.helpers.utils import _IMAGE_TESTING
-@pytest.mark.parametrize(["backbone", "expected_num_features"], [
- pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
- pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _TIMM_AVAILABLE, reason="No timm")),
- pytest.param("simclr-imagenet", 2048, marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")),
- pytest.param("swav-imagenet", 2048, marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")),
- pytest.param("mobilenet_v2", 1280, marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
-])
+@pytest.mark.parametrize(
+ ["backbone", "expected_num_features"],
+ [
+ pytest.param("resnet34", 512, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")),
+ pytest.param("mobilenetv2_100", 1280, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No timm")),
+ pytest.param("mobilenet_v2", 1280, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")),
+ ],
+)
def test_image_classifier_backbones_registry(backbone, expected_num_features):
backbone_fn = IMAGE_CLASSIFIER_BACKBONES.get(backbone)
backbone_model, num_features = backbone_fn(pretrained=False)
@@ -34,11 +35,41 @@ def test_image_classifier_backbones_registry(backbone, expected_num_features):
assert num_features == expected_num_features
-def test_pretrained_backbones_catch_url_error():
+@pytest.mark.parametrize(
+ ["backbone", "pretrained", "expected_num_features"],
+ [
+ pytest.param(
+ "resnet50",
+ "supervised",
+ 2048,
+ marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision"),
+ ),
+ pytest.param("resnet50", "simclr", 2048, marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="No torchvision")),
+ ],
+)
+def test_pretrained_weights_registry(backbone, pretrained, expected_num_features):
+ backbone_fn = IMAGE_CLASSIFIER_BACKBONES.get(backbone)
+ backbone_model, num_features = backbone_fn(pretrained=pretrained)
+ assert backbone_model
+ assert num_features == expected_num_features
+
+@pytest.mark.parametrize(
+ ["backbone", "pretrained"],
+ [
+ pytest.param("resnet50w2", True),
+ pytest.param("resnet50w4", "supervised"),
+ ],
+)
+def test_wide_resnets(backbone, pretrained):
+ with pytest.raises(KeyError, match=f"Supervised pretrained weights not available for {backbone}"):
+ IMAGE_CLASSIFIER_BACKBONES.get(backbone)(pretrained=pretrained)
+
+
+def test_pretrained_backbones_catch_url_error():
def raise_error_if_pretrained(pretrained=False):
if pretrained:
- raise urllib.error.URLError('Test error')
+ raise urllib.error.URLError("Test error")
with pytest.warns(UserWarning, match="Failed to download pretrained weights"):
catch_url_error(raise_error_if_pretrained)(pretrained=True)
diff --git a/tests/pointcloud/__init__.py b/tests/pointcloud/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pointcloud/detection/__init__.py b/tests/pointcloud/detection/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pointcloud/detection/test_data.py b/tests/pointcloud/detection/test_data.py
new file mode 100644
index 0000000000..b337fa28da
--- /dev/null
+++ b/tests/pointcloud/detection/test_data.py
@@ -0,0 +1,58 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from os.path import join
+
+import pytest
+import torch
+from pytorch_lightning import seed_everything
+
+from flash import Trainer
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.utils import download_data
+from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData
+from tests.helpers.utils import _POINTCLOUD_TESTING
+
+if _POINTCLOUD_TESTING:
+ from flash.pointcloud.detection.open3d_ml.backbones import ObjectDetectBatchCollator
+
+
+@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+def test_pointcloud_object_detection_data(tmpdir):
+
+ seed_everything(52)
+
+ download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_micro.zip", tmpdir)
+
+ dm = PointCloudObjectDetectorData.from_folders(train_folder=join(tmpdir, "KITTI_Micro", "Kitti", "train"))
+
+ class MockModel(PointCloudObjectDetector):
+ def training_step(self, batch, batch_idx: int):
+ assert isinstance(batch, ObjectDetectBatchCollator)
+ assert len(batch.point) == 2
+ assert batch.point[0][1].shape == torch.Size([4])
+ assert len(batch.bboxes) > 1
+ assert batch.attr[0]["name"] in ("000000.bin", "000001.bin")
+ assert batch.attr[1]["name"] in ("000000.bin", "000001.bin")
+
+ num_classes = 19
+ model = MockModel(backbone="pointpillars_kitti", num_classes=num_classes)
+ trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0)
+ trainer.fit(model, dm)
+
+ predict_path = join(tmpdir, "KITTI_Micro", "Kitti", "predict")
+ model.eval()
+
+ predictions = model.predict([join(predict_path, "scans/000000.bin")])
+ assert torch.stack(predictions[0][DefaultDataKeys.INPUT]).shape[1] == 4
+ assert len(predictions[0][DefaultDataKeys.PREDS]) == 158
diff --git a/tests/pointcloud/detection/test_model.py b/tests/pointcloud/detection/test_model.py
new file mode 100644
index 0000000000..deafc06faf
--- /dev/null
+++ b/tests/pointcloud/detection/test_model.py
@@ -0,0 +1,24 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+from flash.pointcloud.detection import PointCloudObjectDetector
+from tests.helpers.utils import _POINTCLOUD_TESTING
+
+
+@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+def test_backbones():
+
+ backbones = PointCloudObjectDetector.available_backbones()
+ assert backbones == ["pointpillars", "pointpillars_kitti"]
diff --git a/tests/pointcloud/segmentation/__init__.py b/tests/pointcloud/segmentation/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pointcloud/segmentation/test_data.py b/tests/pointcloud/segmentation/test_data.py
new file mode 100644
index 0000000000..a4c808fff2
--- /dev/null
+++ b/tests/pointcloud/segmentation/test_data.py
@@ -0,0 +1,56 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from os.path import join
+
+import pytest
+import torch
+from pytorch_lightning import seed_everything
+
+from flash import Trainer
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.utils import download_data
+from flash.pointcloud.segmentation import PointCloudSegmentation, PointCloudSegmentationData
+from tests.helpers.utils import _POINTCLOUD_TESTING
+
+
+@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+def test_pointcloud_segmentation_data(tmpdir):
+
+ seed_everything(52)
+
+ download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiMicro.zip", tmpdir)
+
+ dm = PointCloudSegmentationData.from_folders(train_folder=join(tmpdir, "SemanticKittiMicro", "train"))
+
+ class MockModel(PointCloudSegmentation):
+ def training_step(self, batch, batch_idx: int):
+ assert batch[DefaultDataKeys.INPUT]["xyz"][0].shape == torch.Size([2, 45056, 3])
+ assert batch[DefaultDataKeys.INPUT]["xyz"][1].shape == torch.Size([2, 11264, 3])
+ assert batch[DefaultDataKeys.INPUT]["xyz"][2].shape == torch.Size([2, 2816, 3])
+ assert batch[DefaultDataKeys.INPUT]["xyz"][3].shape == torch.Size([2, 704, 3])
+ assert batch[DefaultDataKeys.INPUT]["labels"].shape == torch.Size([2, 45056])
+ assert batch[DefaultDataKeys.INPUT]["labels"].max() == 19
+ assert batch[DefaultDataKeys.INPUT]["labels"].min() == 0
+ assert batch[DefaultDataKeys.METADATA][0]["name"] in ("00_000000", "00_000001")
+ assert batch[DefaultDataKeys.METADATA][1]["name"] in ("00_000000", "00_000001")
+
+ num_classes = 19
+ model = MockModel(backbone="randlanet", num_classes=num_classes)
+ trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0)
+ trainer.fit(model, dm)
+
+ predictions = model.predict(join(tmpdir, "SemanticKittiMicro", "predict"))
+ assert torch.stack(predictions[0][DefaultDataKeys.INPUT]).shape == torch.Size([45056, 3])
+ assert torch.stack(predictions[0][DefaultDataKeys.PREDS]).shape == torch.Size([45056, 19])
+ assert torch.stack(predictions[0][DefaultDataKeys.TARGET]).shape == torch.Size([45056])
diff --git a/tests/pointcloud/segmentation/test_datasets.py b/tests/pointcloud/segmentation/test_datasets.py
new file mode 100644
index 0000000000..fa36606a26
--- /dev/null
+++ b/tests/pointcloud/segmentation/test_datasets.py
@@ -0,0 +1,37 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import patch
+
+import pytest
+
+from flash.pointcloud.segmentation.datasets import LyftDataset, SemanticKITTIDataset
+from tests.helpers.utils import _POINTCLOUD_TESTING
+
+
+@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+@patch("flash.pointcloud.segmentation.datasets.os.system")
+def test_datasets(mock_system):
+
+ LyftDataset("data")
+ assert mock_system.call_count == 2
+ assert "lyft" in mock_system.call_args_list[0][0][0]
+ assert "data" in mock_system.call_args_list[0][0][0]
+ assert "lyft" in mock_system.call_args_list[1][0][0]
+ assert "data" in mock_system.call_args_list[1][0][0]
+
+ mock_system.reset_mock()
+ SemanticKITTIDataset("data")
+ assert mock_system.call_count == 1
+ assert "semantickitti" in mock_system.call_args_list[0][0][0]
+ assert "data" in mock_system.call_args_list[0][0][0]
diff --git a/tests/pointcloud/segmentation/test_model.py b/tests/pointcloud/segmentation/test_model.py
new file mode 100644
index 0000000000..234f867e64
--- /dev/null
+++ b/tests/pointcloud/segmentation/test_model.py
@@ -0,0 +1,41 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+
+from flash.pointcloud.segmentation import PointCloudSegmentation
+from tests.helpers.utils import _POINTCLOUD_TESTING
+
+
+@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+def test_backbones():
+
+ backbones = PointCloudSegmentation.available_backbones()
+ assert backbones == ["randlanet", "randlanet_s3dis", "randlanet_semantic_kitti", "randlanet_toronto3d"]
+
+
+@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+@pytest.mark.parametrize(
+ "backbone",
+ [
+ "randlanet",
+ "randlanet_s3dis",
+ "randlanet_toronto3d",
+ "randlanet_semantic_kitti",
+ ],
+)
+def test_models(backbone):
+ num_classes = 13
+ model = PointCloudSegmentation(backbone=backbone, num_classes=num_classes)
+ assert model.head.weight.shape == torch.Size([13, 32])
diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py
index baa87b3451..b1e9ef3f25 100644
--- a/tests/tabular/classification/test_data.py
+++ b/tests/tabular/classification/test_data.py
@@ -23,7 +23,7 @@
if _PANDAS_AVAILABLE:
import pandas as pd
- from flash.tabular import TabularData
+ from flash.tabular import TabularClassificationData
from flash.tabular.classification.utils import _categorize, _normalize
TEST_DF_1 = pd.DataFrame(
@@ -68,24 +68,24 @@ def test_normalize():
@pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required")
-def test_emb_sizes():
+def test_embedding_sizes():
self = Mock()
self.codes = {"category": [None, "a", "b", "c"]}
self.cat_cols = ["category"]
# use __get__ to test property with mocked self
- es = TabularData.emb_sizes.__get__(self) # pylint: disable=E1101
+ es = TabularClassificationData.embedding_sizes.__get__(self) # pylint: disable=E1101
assert es == [(4, 16)]
self.codes = {}
self.cat_cols = []
# use __get__ to test property with mocked self
- es = TabularData.emb_sizes.__get__(self) # pylint: disable=E1101
+ es = TabularClassificationData.embedding_sizes.__get__(self) # pylint: disable=E1101
assert es == []
self.codes = {"large": ["a"] * 100_000, "larger": ["b"] * 1_000_000}
self.cat_cols = ["large", "larger"]
# use __get__ to test property with mocked self
- es = TabularData.emb_sizes.__get__(self) # pylint: disable=E1101
+ es = TabularClassificationData.embedding_sizes.__get__(self) # pylint: disable=E1101
assert es == [(100_000, 17), (1_000_000, 31)]
@@ -94,7 +94,7 @@ def test_tabular_data(tmpdir):
train_data_frame = TEST_DF_1.copy()
val_data_frame = TEST_DF_2.copy()
test_data_frame = TEST_DF_2.copy()
- dm = TabularData.from_data_frame(
+ dm = TabularClassificationData.from_data_frame(
categorical_fields=["category"],
numerical_fields=["scalar_a", "scalar_b"],
target_fields="label",
@@ -110,7 +110,7 @@ def test_tabular_data(tmpdir):
target = data[DefaultDataKeys.TARGET]
assert cat.shape == (1, 1)
assert num.shape == (1, 2)
- assert target.shape == (1, )
+ assert target.shape == (1,)
@pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required")
@@ -122,7 +122,7 @@ def test_categorical_target(tmpdir):
# change int label to string
df["label"] = df["label"].astype(str)
- dm = TabularData.from_data_frame(
+ dm = TabularClassificationData.from_data_frame(
categorical_fields=["category"],
numerical_fields=["scalar_a", "scalar_b"],
target_fields="label",
@@ -138,7 +138,7 @@ def test_categorical_target(tmpdir):
target = data[DefaultDataKeys.TARGET]
assert cat.shape == (1, 1)
assert num.shape == (1, 2)
- assert target.shape == (1, )
+ assert target.shape == (1,)
@pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required")
@@ -146,7 +146,7 @@ def test_from_data_frame(tmpdir):
train_data_frame = TEST_DF_1.copy()
val_data_frame = TEST_DF_2.copy()
test_data_frame = TEST_DF_2.copy()
- dm = TabularData.from_data_frame(
+ dm = TabularClassificationData.from_data_frame(
categorical_fields=["category"],
numerical_fields=["scalar_a", "scalar_b"],
target_fields="label",
@@ -154,7 +154,7 @@ def test_from_data_frame(tmpdir):
val_data_frame=val_data_frame,
test_data_frame=test_data_frame,
num_workers=0,
- batch_size=1
+ batch_size=1,
)
for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]:
data = next(iter(dl))
@@ -162,7 +162,7 @@ def test_from_data_frame(tmpdir):
target = data[DefaultDataKeys.TARGET]
assert cat.shape == (1, 1)
assert num.shape == (1, 2)
- assert target.shape == (1, )
+ assert target.shape == (1,)
@pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required")
@@ -173,7 +173,7 @@ def test_from_csv(tmpdir):
TEST_DF_2.to_csv(val_csv)
TEST_DF_2.to_csv(test_csv)
- dm = TabularData.from_csv(
+ dm = TabularClassificationData.from_csv(
categorical_fields=["category"],
numerical_fields=["scalar_a", "scalar_b"],
target_fields="label",
@@ -181,7 +181,7 @@ def test_from_csv(tmpdir):
val_file=str(val_csv),
test_file=str(test_csv),
num_workers=0,
- batch_size=1
+ batch_size=1,
)
for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]:
data = next(iter(dl))
@@ -189,14 +189,14 @@ def test_from_csv(tmpdir):
target = data[DefaultDataKeys.TARGET]
assert cat.shape == (1, 1)
assert num.shape == (1, 2)
- assert target.shape == (1, )
+ assert target.shape == (1,)
@pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required")
def test_empty_inputs():
train_data_frame = TEST_DF_1.copy()
with pytest.raises(RuntimeError):
- TabularData.from_data_frame(
+ TabularClassificationData.from_data_frame(
numerical_fields=None,
categorical_fields=None,
target_fields="label",
diff --git a/tests/tabular/classification/test_data_model_integration.py b/tests/tabular/classification/test_data_model_integration.py
index 349aeeaaba..3d4875f1dd 100644
--- a/tests/tabular/classification/test_data_model_integration.py
+++ b/tests/tabular/classification/test_data_model_integration.py
@@ -15,7 +15,7 @@
import pytorch_lightning as pl
from flash.core.utilities.imports import _TABULAR_AVAILABLE
-from flash.tabular import TabularClassifier, TabularData
+from flash.tabular import TabularClassificationData, TabularClassifier
from tests.helpers.utils import _TABULAR_TESTING
if _TABULAR_AVAILABLE:
@@ -37,7 +37,7 @@ def test_classification(tmpdir):
train_data_frame = TEST_DF_1.copy()
val_data_frame = TEST_DF_1.copy()
test_data_frame = TEST_DF_1.copy()
- data = TabularData.from_data_frame(
+ data = TabularClassificationData.from_data_frame(
categorical_fields=["category"],
numerical_fields=["scalar_a", "scalar_b"],
target_fields="label",
@@ -47,6 +47,6 @@ def test_classification(tmpdir):
num_workers=0,
batch_size=2,
)
- model = TabularClassifier(num_features=3, num_classes=2, embedding_sizes=data.emb_sizes)
+ model = TabularClassifier(num_features=3, num_classes=2, embedding_sizes=data.embedding_sizes)
trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
trainer.fit(model, data)
diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py
index d3cc3db332..e7ee5e9f5d 100644
--- a/tests/tabular/classification/test_model.py
+++ b/tests/tabular/classification/test_model.py
@@ -21,23 +21,21 @@
from flash.core.data.data_source import DefaultDataKeys
from flash.core.utilities.imports import _TABULAR_AVAILABLE
-from flash.tabular import TabularClassifier
-from flash.tabular.classification.data import TabularData
+from flash.tabular import TabularClassificationData, TabularClassifier
from tests.helpers.utils import _SERVE_TESTING, _TABULAR_TESTING
# ======== Mock functions ========
class DummyDataset(torch.utils.data.Dataset):
-
def __init__(self, num_num=16, num_cat=16):
super().__init__()
self.num_num = num_num
self.num_cat = num_cat
def __getitem__(self, index):
- target = torch.randint(0, 10, size=(1, )).item()
- cat_vars = torch.randint(0, 10, size=(self.num_cat, ))
+ target = torch.randint(0, 10, size=(1,)).item()
+ cat_vars = torch.randint(0, 10, size=(self.num_cat,))
num_vars = torch.rand(self.num_num)
return {DefaultDataKeys.INPUT: (cat_vars, num_vars), DefaultDataKeys.TARGET: target}
@@ -84,7 +82,7 @@ def test_jit(tmpdir):
model.eval()
# torch.jit.script doesn't work with tabnet
- model = torch.jit.trace(model, ((torch.randint(0, 10, size=(1, 4)), torch.rand(1, 4)), ))
+ model = torch.jit.trace(model, ((torch.randint(0, 10, size=(1, 4)), torch.rand(1, 4)),))
# TODO: torch.jit.save doesn't work with tabnet
# path = os.path.join(tmpdir, "test.pt")
@@ -100,7 +98,7 @@ def test_jit(tmpdir):
@mock.patch("flash._IS_TESTING", True)
def test_serve():
train_data = {"num_col": [1.4, 2.5], "cat_col": ["positive", "negative"], "target": [1, 2]}
- datamodule = TabularData.from_data_frame(
+ datamodule = TabularClassificationData.from_data_frame(
"cat_col",
"num_col",
"target",
diff --git a/tests/template/__init__.py b/tests/template/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/template/classification/test_data.py b/tests/template/classification/test_data.py
index 6bdec2f2ef..b793849e08 100644
--- a/tests/template/classification/test_data.py
+++ b/tests/template/classification/test_data.py
@@ -49,7 +49,7 @@ def test_smoke():
def test_from_numpy(self):
"""Tests that ``TemplateData`` is properly created when using the ``from_numpy`` method."""
data = np.random.rand(10, self.num_features)
- targets = np.random.randint(0, self.num_classes, (10, ))
+ targets = np.random.randint(0, self.num_classes, (10,))
# instantiate the data module
dm = TemplateData.from_numpy(
@@ -71,19 +71,19 @@ def test_from_numpy(self):
data = next(iter(dm.train_dataloader()))
rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert rows.shape == (2, self.num_features)
- assert targets.shape == (2, )
+ assert targets.shape == (2,)
# check val data
data = next(iter(dm.val_dataloader()))
rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert rows.shape == (2, self.num_features)
- assert targets.shape == (2, )
+ assert targets.shape == (2,)
# check test data
data = next(iter(dm.test_dataloader()))
rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert rows.shape == (2, self.num_features)
- assert targets.shape == (2, )
+ assert targets.shape == (2,)
@staticmethod
def test_from_sklearn():
@@ -107,16 +107,16 @@ def test_from_sklearn():
data = next(iter(dm.train_dataloader()))
rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert rows.shape == (2, dm.num_features)
- assert targets.shape == (2, )
+ assert targets.shape == (2,)
# check val data
data = next(iter(dm.val_dataloader()))
rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert rows.shape == (2, dm.num_features)
- assert targets.shape == (2, )
+ assert targets.shape == (2,)
# check test data
data = next(iter(dm.test_dataloader()))
rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
assert rows.shape == (2, dm.num_features)
- assert targets.shape == (2, )
+ assert targets.shape == (2,)
diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py
index 9fa57b80b9..cfd0f77f39 100644
--- a/tests/template/classification/test_model.py
+++ b/tests/template/classification/test_model.py
@@ -39,7 +39,7 @@ class DummyDataset(torch.utils.data.Dataset):
def __getitem__(self, index):
return {
DefaultDataKeys.INPUT: torch.randn(self.num_features),
- DefaultDataKeys.TARGET: torch.randint(self.num_classes - 1, (1, ))[0],
+ DefaultDataKeys.TARGET: torch.randint(self.num_classes - 1, (1,))[0],
}
def __len__(self) -> int:
@@ -121,7 +121,7 @@ def test_predict_sklearn():
@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
-@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 16), ))])
+@pytest.mark.parametrize("jitter, args", [(torch.jit.script, ()), (torch.jit.trace, (torch.rand(1, 16),))])
def test_jit(tmpdir, jitter, args):
path = os.path.join(tmpdir, "test.pt")
diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py
index d5a3b680f9..4c42909b35 100644
--- a/tests/text/classification/test_data.py
+++ b/tests/text/classification/test_data.py
@@ -44,6 +44,12 @@
{"sentence": "this is a sentence three","lab":0}
"""
+TEST_JSON_DATA_FIELD = """{"data": [
+{"sentence": "this is a sentence one","lab":0},
+{"sentence": "this is a sentence two","lab":1},
+{"sentence": "this is a sentence three","lab":0}]}
+"""
+
def csv_data(tmpdir):
path = Path(tmpdir) / "data.csv"
@@ -57,6 +63,12 @@ def json_data(tmpdir):
return path
+def json_data_with_field(tmpdir):
+ path = Path(tmpdir) / "data.json"
+ path.write_text(TEST_JSON_DATA_FIELD)
+ return path
+
+
@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_csv(tmpdir):
@@ -78,7 +90,7 @@ def test_test_valid(tmpdir):
train_file=csv_path,
val_file=csv_path,
test_file=csv_path,
- batch_size=1
+ batch_size=1,
)
batch = next(iter(dm.val_dataloader()))
assert batch["labels"].item() in [0, 1]
@@ -99,6 +111,18 @@ def test_from_json(tmpdir):
assert "input_ids" in batch
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_from_json_with_field(tmpdir):
+ json_path = json_data_with_field(tmpdir)
+ dm = TextClassificationData.from_json(
+ "sentence", "lab", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
+ )
+ batch = next(iter(dm.train_dataloader()))
+ assert batch["labels"].item() in [0, 1]
+ assert "input_ids" in batch
+
+
@pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.")
def test_text_module_not_found_error():
with pytest.raises(ModuleNotFoundError, match="[text]"):
@@ -111,9 +135,7 @@ def test_text_module_not_found_error():
"cls, kwargs",
[
(TextDataSource, {}),
- (TextFileDataSource, {
- "filetype": "csv"
- }),
+ (TextFileDataSource, {"filetype": "csv"}),
(TextCSVDataSource, {}),
(TextJSONDataSource, {}),
(TextSentencesDataSource, {}),
diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py
index 431b8f4cb8..7ca20d92c7 100644
--- a/tests/text/classification/test_model.py
+++ b/tests/text/classification/test_model.py
@@ -19,6 +19,7 @@
import torch
from flash import Trainer
+from flash.__main__ import main
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.text import TextClassifier
from flash.text.classification.data import TextClassificationPostprocess, TextClassificationPreprocess
@@ -28,11 +29,10 @@
class DummyDataset(torch.utils.data.Dataset):
-
def __getitem__(self, index):
return {
- "input_ids": torch.randint(1000, size=(100, )),
- "labels": torch.randint(2, size=(1, )).item(),
+ "input_ids": torch.randint(1000, size=(100,)),
+ "labels": torch.randint(2, size=(1,)).item(),
}
def __len__(self) -> int:
@@ -87,3 +87,19 @@ def test_serve():
def test_load_from_checkpoint_dependency_error():
with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[text]'")):
TextClassifier.load_from_checkpoint("not_a_real_checkpoint.pt")
+
+
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+@pytest.mark.parametrize(
+ "cli_args",
+ (
+ ["flash", "text_classification", "--trainer.fast_dev_run", "True"],
+ ["flash", "text_classification", "--trainer.fast_dev_run", "True", "from_toxic"],
+ ),
+)
+def test_cli(cli_args):
+ with mock.patch("sys.argv", cli_args):
+ try:
+ main()
+ except SystemExit:
+ pass
diff --git a/tests/text/classification/test_ort.py b/tests/text/classification/test_ort.py
new file mode 100644
index 0000000000..01d987e092
--- /dev/null
+++ b/tests/text/classification/test_ort.py
@@ -0,0 +1,62 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import pytest
+import torch
+from pytorch_lightning import Callback
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from flash import Trainer
+from flash.core.utilities.imports import _TORCH_ORT_AVAILABLE
+from flash.text import TextClassifier
+from flash.text.ort_callback import ORTCallback
+from tests.helpers.boring_model import BoringModel
+from tests.helpers.utils import _TEXT_TESTING
+from tests.text.classification.test_model import DummyDataset, TEST_BACKBONE
+
+if _TORCH_ORT_AVAILABLE:
+ from torch_ort import ORTModule
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
+def test_init_train_enable_ort(tmpdir):
+ class TestCallback(Callback):
+ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
+ assert isinstance(pl_module.model, ORTModule)
+
+ model = TextClassifier(2, TEST_BACKBONE, enable_ort=True)
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=TestCallback())
+ trainer.fit(
+ model,
+ train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
+ val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
+ )
+ trainer.test(model, test_dataloaders=torch.utils.data.DataLoader(DummyDataset()))
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TORCH_ORT_AVAILABLE, reason="ORT Module aren't installed.")
+def test_ort_callback_fails_no_model(tmpdir):
+ model = BoringModel()
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, callbacks=ORTCallback())
+ with pytest.raises(MisconfigurationException, match="Torch ORT requires to wrap a single model"):
+ trainer.fit(
+ model,
+ train_dataloader=torch.utils.data.DataLoader(DummyDataset()),
+ val_dataloaders=torch.utils.data.DataLoader(DummyDataset()),
+ )
diff --git a/tests/text/seq2seq/__init__.py b/tests/text/seq2seq/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/text/seq2seq/core/test_data.py b/tests/text/seq2seq/core/test_data.py
index 4f2144aa90..d52bd9132a 100644
--- a/tests/text/seq2seq/core/test_data.py
+++ b/tests/text/seq2seq/core/test_data.py
@@ -36,22 +36,11 @@
@pytest.mark.parametrize(
"cls, kwargs",
[
- (Seq2SeqDataSource, {
- "backbone": "sshleifer/tiny-mbart"
- }),
- (Seq2SeqFileDataSource, {
- "backbone": "sshleifer/tiny-mbart",
- "filetype": "csv"
- }),
- (Seq2SeqCSVDataSource, {
- "backbone": "sshleifer/tiny-mbart"
- }),
- (Seq2SeqJSONDataSource, {
- "backbone": "sshleifer/tiny-mbart"
- }),
- (Seq2SeqSentencesDataSource, {
- "backbone": "sshleifer/tiny-mbart"
- }),
+ (Seq2SeqDataSource, {"backbone": "sshleifer/tiny-mbart"}),
+ (Seq2SeqFileDataSource, {"backbone": "sshleifer/tiny-mbart", "filetype": "csv"}),
+ (Seq2SeqCSVDataSource, {"backbone": "sshleifer/tiny-mbart"}),
+ (Seq2SeqJSONDataSource, {"backbone": "sshleifer/tiny-mbart"}),
+ (Seq2SeqSentencesDataSource, {"backbone": "sshleifer/tiny-mbart"}),
(Seq2SeqPostprocess, {}),
],
)
diff --git a/tests/text/seq2seq/summarization/test_metric.py b/tests/text/seq2seq/core/test_metrics.py
similarity index 67%
rename from tests/text/seq2seq/summarization/test_metric.py
rename to tests/text/seq2seq/core/test_metrics.py
index 9f17397b02..c16f828c37 100644
--- a/tests/text/seq2seq/summarization/test_metric.py
+++ b/tests/text/seq2seq/core/test_metrics.py
@@ -14,7 +14,7 @@
import pytest
import torch
-from flash.text.seq2seq.summarization.metric import RougeMetric
+from flash.text.seq2seq.core.metrics import BLEUScore, RougeMetric
from tests.helpers.utils import _TEXT_TESTING
@@ -24,3 +24,11 @@ def test_rouge():
target = "Is your name John".split()
metric = RougeMetric()
assert torch.allclose(torch.tensor(metric(preds, target)["rouge1_recall"]).float(), torch.tensor(0.25), 1e-4)
+
+
+@pytest.mark.parametrize("smooth, expected", [(False, 0.7598), (True, 0.8091)])
+def test_bleu_score(smooth, expected):
+ translate_corpus = ["the cat is on the mat".split()]
+ reference_corpus = [["there is a cat on the mat".split(), "a cat is on the mat".split()]]
+ metric = BLEUScore(smooth=smooth)
+ assert torch.allclose(metric(translate_corpus, reference_corpus), torch.tensor(expected), 1e-4)
diff --git a/tests/text/seq2seq/question_answering/__init__.py b/tests/text/seq2seq/question_answering/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/text/seq2seq/question_answering/test_data.py b/tests/text/seq2seq/question_answering/test_data.py
new file mode 100644
index 0000000000..8879282bba
--- /dev/null
+++ b/tests/text/seq2seq/question_answering/test_data.py
@@ -0,0 +1,131 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from pathlib import Path
+
+import pytest
+
+from flash.text import QuestionAnsweringData
+from tests.helpers.utils import _TEXT_TESTING
+
+TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing
+
+TEST_CSV_DATA = """input,target
+this is a question one,this is an answer one
+this is a question two,this is an answer two
+this is a question three,this is an answer three
+"""
+
+TEST_JSON_DATA = """
+{"input": "this is a question one","target":"this is an answer one"}
+{"input": "this is a question two","target":"this is an answer two"}
+{"input": "this is a question three","target":"this is an answer three"}
+"""
+
+TEST_JSON_DATA_FIELD = """{"data": [
+{"input": "this is a question one","target":"this is an answer one"},
+{"input": "this is a question two","target":"this is an answer two"},
+{"input": "this is a question three","target":"this is an answer three"}]}
+"""
+
+
+def csv_data(tmpdir):
+ path = Path(tmpdir) / "data.csv"
+ path.write_text(TEST_CSV_DATA)
+ return path
+
+
+def json_data(tmpdir):
+ path = Path(tmpdir) / "data.json"
+ path.write_text(TEST_JSON_DATA)
+ return path
+
+
+def json_data_with_field(tmpdir):
+ path = Path(tmpdir) / "data.json"
+ path.write_text(TEST_JSON_DATA_FIELD)
+ return path
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_from_csv(tmpdir):
+ csv_path = csv_data(tmpdir)
+ dm = QuestionAnsweringData.from_csv("input", "target", backbone=TEST_BACKBONE, train_file=csv_path, batch_size=1)
+ batch = next(iter(dm.train_dataloader()))
+ assert "labels" in batch
+ assert "input_ids" in batch
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_from_files(tmpdir):
+ csv_path = csv_data(tmpdir)
+ dm = QuestionAnsweringData.from_csv(
+ "input",
+ "target",
+ backbone=TEST_BACKBONE,
+ train_file=csv_path,
+ val_file=csv_path,
+ test_file=csv_path,
+ batch_size=1,
+ )
+ batch = next(iter(dm.val_dataloader()))
+ assert "labels" in batch
+ assert "input_ids" in batch
+
+ batch = next(iter(dm.test_dataloader()))
+ assert "labels" in batch
+ assert "input_ids" in batch
+
+
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_postprocess_tokenizer(tmpdir):
+ """Tests that the tokenizer property in ``SummarizationPostprocess`` resolves correctly when a different
+ backbone is used."""
+ backbone = "sshleifer/bart-tiny-random"
+ csv_path = csv_data(tmpdir)
+ dm = QuestionAnsweringData.from_csv(
+ "input",
+ "target",
+ backbone=backbone,
+ train_file=csv_path,
+ batch_size=1,
+ )
+ pipeline = dm.data_pipeline
+ pipeline.initialize()
+ assert pipeline._postprocess_pipeline.backbone == backbone
+ assert pipeline._postprocess_pipeline.tokenizer is not None
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_from_json(tmpdir):
+ json_path = json_data(tmpdir)
+ dm = QuestionAnsweringData.from_json("input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1)
+ batch = next(iter(dm.train_dataloader()))
+ assert "labels" in batch
+ assert "input_ids" in batch
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_from_json_with_field(tmpdir):
+ json_path = json_data_with_field(tmpdir)
+ dm = QuestionAnsweringData.from_json(
+ "input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
+ )
+ batch = next(iter(dm.train_dataloader()))
+ assert "labels" in batch
+ assert "input_ids" in batch
diff --git a/tests/text/seq2seq/question_answering/test_model.py b/tests/text/seq2seq/question_answering/test_model.py
new file mode 100644
index 0000000000..ad4389b768
--- /dev/null
+++ b/tests/text/seq2seq/question_answering/test_model.py
@@ -0,0 +1,91 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import re
+from unittest import mock
+
+import pytest
+import torch
+
+from flash import Trainer
+from flash.core.utilities.imports import _TEXT_AVAILABLE
+from flash.text import QuestionAnsweringTask
+from flash.text.seq2seq.core.data import Seq2SeqPostprocess
+from flash.text.seq2seq.question_answering.data import QuestionAnsweringPreprocess
+from tests.helpers.utils import _SERVE_TESTING, _TEXT_TESTING
+
+# ======== Mock functions ========
+
+
+class DummyDataset(torch.utils.data.Dataset):
+ def __getitem__(self, index):
+ return {
+ "input_ids": torch.randint(1000, size=(128,)),
+ "labels": torch.randint(1000, size=(128,)),
+ }
+
+ def __len__(self) -> int:
+ return 100
+
+
+# ==============================
+
+TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_init_train(tmpdir):
+ model = QuestionAnsweringTask(TEST_BACKBONE)
+ train_dl = torch.utils.data.DataLoader(DummyDataset())
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+ trainer.fit(model, train_dl)
+
+
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_jit(tmpdir):
+ sample_input = {
+ "input_ids": torch.randint(1000, size=(1, 32)),
+ "attention_mask": torch.randint(1, size=(1, 32)),
+ }
+ path = os.path.join(tmpdir, "test.pt")
+
+ model = QuestionAnsweringTask(TEST_BACKBONE)
+ model.eval()
+
+ # Huggingface only supports `torch.jit.trace`
+ model = torch.jit.trace(model, [sample_input])
+
+ torch.jit.save(model, path)
+ model = torch.jit.load(path)
+
+ out = model(sample_input)
+ assert isinstance(out, torch.Tensor)
+
+
+@pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")
+@mock.patch("flash._IS_TESTING", True)
+def test_serve():
+ model = QuestionAnsweringTask(TEST_BACKBONE)
+ # TODO: Currently only servable once a preprocess and postprocess have been attached
+ model._preprocess = QuestionAnsweringPreprocess(backbone=TEST_BACKBONE)
+ model._postprocess = Seq2SeqPostprocess()
+ model.eval()
+ model.serve()
+
+
+@pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.")
+def test_load_from_checkpoint_dependency_error():
+ with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[text]'")):
+ QuestionAnsweringTask.load_from_checkpoint("not_a_real_checkpoint.pt")
diff --git a/tests/text/seq2seq/summarization/test_data.py b/tests/text/seq2seq/summarization/test_data.py
index 2ab09f3636..ff359dcdf0 100644
--- a/tests/text/seq2seq/summarization/test_data.py
+++ b/tests/text/seq2seq/summarization/test_data.py
@@ -22,15 +22,21 @@
TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing
TEST_CSV_DATA = """input,target
-this is a sentence one,this is a translated sentence one
-this is a sentence two,this is a translated sentence two
-this is a sentence three,this is a translated sentence three
+this is a sentence one,this is a summarized sentence one
+this is a sentence two,this is a summarized sentence two
+this is a sentence three,this is a summarized sentence three
"""
TEST_JSON_DATA = """
-{"input": "this is a sentence one","target":"this is a translated sentence one"}
-{"input": "this is a sentence two","target":"this is a translated sentence two"}
-{"input": "this is a sentence three","target":"this is a translated sentence three"}
+{"input": "this is a sentence one","target":"this is a summarized sentence one"}
+{"input": "this is a sentence two","target":"this is a summarized sentence two"}
+{"input": "this is a sentence three","target":"this is a summarized sentence three"}
+"""
+
+TEST_JSON_DATA_FIELD = """{"data": [
+{"input": "this is a sentence one","target":"this is a summarized sentence one"},
+{"input": "this is a sentence two","target":"this is a summarized sentence two"},
+{"input": "this is a sentence three","target":"this is a summarized sentence three"}]}
"""
@@ -46,6 +52,12 @@ def json_data(tmpdir):
return path
+def json_data_with_field(tmpdir):
+ path = Path(tmpdir) / "data.json"
+ path.write_text(TEST_JSON_DATA_FIELD)
+ return path
+
+
@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_csv(tmpdir):
@@ -80,9 +92,8 @@ def test_from_files(tmpdir):
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_postprocess_tokenizer(tmpdir):
- """Tests that the tokenizer property in ``SummarizationPostprocess`` resolves correctly when a different backbone is
- used.
- """
+ """Tests that the tokenizer property in ``SummarizationPostprocess`` resolves correctly when a different
+ backbone is used."""
backbone = "sshleifer/bart-tiny-random"
csv_path = csv_data(tmpdir)
dm = SummarizationData.from_csv(
@@ -106,3 +117,15 @@ def test_from_json(tmpdir):
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_from_json_with_field(tmpdir):
+ json_path = json_data_with_field(tmpdir)
+ dm = SummarizationData.from_json(
+ "input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
+ )
+ batch = next(iter(dm.train_dataloader()))
+ assert "labels" in batch
+ assert "input_ids" in batch
diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py
index ccff5e6d85..c6adf69fdc 100644
--- a/tests/text/seq2seq/summarization/test_model.py
+++ b/tests/text/seq2seq/summarization/test_model.py
@@ -29,11 +29,10 @@
class DummyDataset(torch.utils.data.Dataset):
-
def __getitem__(self, index):
return {
- "input_ids": torch.randint(1000, size=(128, )),
- "labels": torch.randint(1000, size=(128, )),
+ "input_ids": torch.randint(1000, size=(128,)),
+ "labels": torch.randint(1000, size=(128,)),
}
def __len__(self) -> int:
diff --git a/tests/text/seq2seq/translation/test_data.py b/tests/text/seq2seq/translation/test_data.py
index 244cb27d4a..f87a51fdcd 100644
--- a/tests/text/seq2seq/translation/test_data.py
+++ b/tests/text/seq2seq/translation/test_data.py
@@ -33,6 +33,12 @@
{"input": "this is a sentence three","target":"this is a translated sentence three"}
"""
+TEST_JSON_DATA_FIELD = """{"data": [
+{"input": "this is a sentence one","target":"this is a translated sentence one"},
+{"input": "this is a sentence two","target":"this is a translated sentence two"},
+{"input": "this is a sentence three","target":"this is a translated sentence three"}]}
+"""
+
def csv_data(tmpdir):
path = Path(tmpdir) / "data.csv"
@@ -46,6 +52,12 @@ def json_data(tmpdir):
return path
+def json_data_with_field(tmpdir):
+ path = Path(tmpdir) / "data.json"
+ path.write_text(TEST_JSON_DATA_FIELD)
+ return path
+
+
@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_csv(tmpdir):
@@ -67,7 +79,7 @@ def test_from_files(tmpdir):
train_file=csv_path,
val_file=csv_path,
test_file=csv_path,
- batch_size=1
+ batch_size=1,
)
batch = next(iter(dm.val_dataloader()))
assert "labels" in batch
@@ -86,3 +98,15 @@ def test_from_json(tmpdir):
batch = next(iter(dm.train_dataloader()))
assert "labels" in batch
assert "input_ids" in batch
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_from_json_with_field(tmpdir):
+ json_path = json_data_with_field(tmpdir)
+ dm = TranslationData.from_json(
+ "input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data"
+ )
+ batch = next(iter(dm.train_dataloader()))
+ assert "labels" in batch
+ assert "input_ids" in batch
diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py
index c49ccd4c24..237fa3bb5a 100644
--- a/tests/text/seq2seq/translation/test_model.py
+++ b/tests/text/seq2seq/translation/test_model.py
@@ -29,11 +29,10 @@
class DummyDataset(torch.utils.data.Dataset):
-
def __getitem__(self, index):
return {
- "input_ids": torch.randint(1000, size=(128, )),
- "labels": torch.randint(1000, size=(128, )),
+ "input_ids": torch.randint(1000, size=(128,)),
+ "labels": torch.randint(1000, size=(128,)),
}
def __len__(self) -> int:
diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py
index 27ad049411..d7d45aa69f 100644
--- a/tests/video/classification/test_model.py
+++ b/tests/video/classification/test_model.py
@@ -16,12 +16,14 @@
import re
import tempfile
from pathlib import Path
+from unittest import mock
import pytest
import torch
from torch.utils.data import SequentialSampler
import flash
+from flash.__main__ import main
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _VIDEO_AVAILABLE
from flash.video import VideoClassificationData, VideoClassifier
from tests.helpers.utils import _VIDEO_TESTING
@@ -43,7 +45,7 @@ def create_dummy_video_frames(num_frames: int, height: int, width: int):
for i in range(num_frames):
xc = float(i) / num_frames
yc = 1 - float(i) / (2 * num_frames)
- d = torch.exp(-((x - xc)**2 + (y - yc)**2) / 2) * 255
+ d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())
return torch.stack(data, 0)
@@ -51,9 +53,9 @@ def create_dummy_video_frames(num_frames: int, height: int, width: int):
# https://github.com/facebookresearch/pytorchvideo/blob/4feccb607d7a16933d485495f91d067f177dd8db/tests/utils.py#L33
@contextlib.contextmanager
def temp_encoded_video(num_frames: int, fps: int, height=10, width=10, prefix=None, directory=None):
- """
- Creates a temporary lossless, mp4 video with synthetic content. Uses a context which
- deletes the video after exit.
+ """Creates a temporary lossless, mp4 video with synthetic content.
+
+ Uses a context which deletes the video after exit.
"""
# Lossless options.
video_codec = "libx264rgb"
@@ -101,8 +103,8 @@ def mock_encoded_video_dataset_file():
@contextlib.contextmanager
def mock_encoded_video_dataset_folder(tmpdir):
- """
- Creates a temporary mock encoded video directory tree with 2 videos labeled 1, 2.
+ """Creates a temporary mock encoded video directory tree with 2 videos labeled 1, 2.
+
Returns a directory that to this mock encoded video dataset and the video duration in seconds.
"""
num_frames = 10
@@ -150,28 +152,34 @@ def test_video_classifier_finetune(tmpdir):
assert len(VideoClassifier.available_backbones()) > 5
train_transform = {
- "post_tensor_transform": Compose([
- ApplyTransformToKey(
- key="video",
- transform=Compose([
- UniformTemporalSubsample(8),
- RandomShortSideScale(min_size=256, max_size=320),
- RandomCrop(244),
- RandomHorizontalFlip(p=0.5),
- ]),
- ),
- ]),
- "per_batch_transform_on_device": Compose([
- ApplyTransformToKey(
- key="video",
- transform=K.VideoSequential(
- K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])),
- K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
- data_format="BCTHW",
- same_on_frame=False
- )
- ),
- ]),
+ "post_tensor_transform": Compose(
+ [
+ ApplyTransformToKey(
+ key="video",
+ transform=Compose(
+ [
+ UniformTemporalSubsample(8),
+ RandomShortSideScale(min_size=256, max_size=320),
+ RandomCrop(244),
+ RandomHorizontalFlip(p=0.5),
+ ]
+ ),
+ ),
+ ]
+ ),
+ "per_batch_transform_on_device": Compose(
+ [
+ ApplyTransformToKey(
+ key="video",
+ transform=K.VideoSequential(
+ K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])),
+ K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
+ data_format="BCTHW",
+ same_on_frame=False,
+ ),
+ ),
+ ]
+ ),
}
datamodule = VideoClassificationData.from_folders(
@@ -180,17 +188,17 @@ def test_video_classifier_finetune(tmpdir):
clip_duration=half_duration,
video_sampler=SequentialSampler,
decode_audio=False,
- train_transform=train_transform
+ train_transform=train_transform,
)
- model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False)
+ model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50")
- trainer = flash.Trainer(fast_dev_run=True)
+ trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule)
-@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.")
+@pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="PyTorchVideo isn't installed.")
@pytest.mark.skipif(not _FIFTYONE_AVAILABLE, reason="fiftyone isn't installed.")
def test_video_classifier_finetune_fiftyone(tmpdir):
@@ -220,28 +228,34 @@ def test_video_classifier_finetune_fiftyone(tmpdir):
assert len(VideoClassifier.available_backbones()) > 5
train_transform = {
- "post_tensor_transform": Compose([
- ApplyTransformToKey(
- key="video",
- transform=Compose([
- UniformTemporalSubsample(8),
- RandomShortSideScale(min_size=256, max_size=320),
- RandomCrop(244),
- RandomHorizontalFlip(p=0.5),
- ]),
- ),
- ]),
- "per_batch_transform_on_device": Compose([
- ApplyTransformToKey(
- key="video",
- transform=K.VideoSequential(
- K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])),
- K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
- data_format="BCTHW",
- same_on_frame=False
- )
- ),
- ]),
+ "post_tensor_transform": Compose(
+ [
+ ApplyTransformToKey(
+ key="video",
+ transform=Compose(
+ [
+ UniformTemporalSubsample(8),
+ RandomShortSideScale(min_size=256, max_size=320),
+ RandomCrop(244),
+ RandomHorizontalFlip(p=0.5),
+ ]
+ ),
+ ),
+ ]
+ ),
+ "per_batch_transform_on_device": Compose(
+ [
+ ApplyTransformToKey(
+ key="video",
+ transform=K.VideoSequential(
+ K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])),
+ K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
+ data_format="BCTHW",
+ same_on_frame=False,
+ ),
+ ),
+ ]
+ ),
}
datamodule = VideoClassificationData.from_fiftyone(
@@ -250,12 +264,12 @@ def test_video_classifier_finetune_fiftyone(tmpdir):
clip_duration=half_duration,
video_sampler=SequentialSampler,
decode_audio=False,
- train_transform=train_transform
+ train_transform=train_transform,
)
- model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False)
+ model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False, backbone="slow_r50")
- trainer = flash.Trainer(fast_dev_run=True)
+ trainer = flash.Trainer(fast_dev_run=True, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule)
@@ -265,7 +279,7 @@ def test_jit(tmpdir):
sample_input = torch.rand(1, 3, 32, 256, 256)
path = os.path.join(tmpdir, "test.pt")
- model = VideoClassifier(2, pretrained=False)
+ model = VideoClassifier(2, pretrained=False, backbone="slow_r50")
model.eval()
# pytorchvideo only works with `torch.jit.trace`
@@ -283,3 +297,13 @@ def test_jit(tmpdir):
def test_load_from_checkpoint_dependency_error():
with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[video]'")):
VideoClassifier.load_from_checkpoint("not_a_real_checkpoint.pt")
+
+
+@pytest.mark.skipif(not _VIDEO_TESTING, reason="PyTorchVideo isn't installed.")
+def test_cli():
+ cli_args = ["flash", "video_classification", "--trainer.fast_dev_run", "True", "num_workers", "0"]
+ with mock.patch("sys.argv", cli_args):
+ try:
+ main()
+ except SystemExit:
+ pass