From 70c865d869debffc0d3f658749064a587e1f55c2 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Tue, 20 Apr 2021 11:51:36 +0100
Subject: [PATCH 1/6] add multilabel

---
 flash/core/classification.py              | 16 +++++++++++----
 flash/core/model.py                       | 15 +++++++-------
 flash/data/data_pipeline.py               |  2 +-
 flash/vision/classification/model.py      | 23 ++++++++++++++++++---
 flash/vision/embedding/model.py           |  2 +-
 tests/vision/classification/test_model.py | 25 +++++++++++++++++++++++
 6 files changed, 66 insertions(+), 17 deletions(-)

diff --git a/flash/core/classification.py b/flash/core/classification.py
index 970466dbf4..17dcc946dc 100644
--- a/flash/core/classification.py
+++ b/flash/core/classification.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
@@ -22,14 +22,22 @@
 
 class ClassificationPostprocess(Postprocess):
 
+    def __init__(self, multi_label: bool = False, save_path: Optional[str] = None):
+        super().__init__(save_path=save_path)
+        self.multi_label = multi_label
+
     def per_sample_transform(self, samples: Any) -> Any:
-        return torch.argmax(samples, -1).tolist()
+        if self.multi_label:
+            return F.sigmoid(samples).tolist()
+        else:
+            return torch.argmax(samples, -1).tolist()
 
 
 class ClassificationTask(Task):
 
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, default_postprocess=ClassificationPostprocess(), **kwargs)
+    postprocess_cls = ClassificationPostprocess
 
     def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
+        if self.hparams.multi_label:
+            return F.sigmoid(x).int()
         return F.softmax(x, -1)
diff --git a/flash/core/model.py b/flash/core/model.py
index eeaf268b75..5206babf48 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -60,8 +60,8 @@ class Task(LightningModule):
         optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
         metrics: Metrics to compute for training and evaluation.
         learning_rate: Learning rate to use for training, defaults to `5e-5`.
-        default_preprocess: :class:`.Preprocess` to use as the default for this task.
-        default_postprocess: :class:`.Postprocess` to use as the default for this task.
+        preprocess: :class:`.Preprocess` to use as the default for this task.
+        postprocess: :class:`.Postprocess` to use as the default for this task.
     """
 
     def __init__(
@@ -71,8 +71,8 @@ def __init__(
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
         metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
         learning_rate: float = 5e-5,
-        default_preprocess: Preprocess = None,
-        default_postprocess: Postprocess = None,
+        preprocess: Preprocess = None,
+        postprocess: Postprocess = None,
     ):
         super().__init__()
         if model is not None:
@@ -84,8 +84,8 @@
         # TODO: should we save more? Saving metrics currently triggers a YAML serialization bug for some of them.
         self.save_hyperparameters("learning_rate", "optimizer")
 
-        self._preprocess = default_preprocess
-        self._postprocess = default_postprocess
+        self._preprocess = preprocess
+        self._postprocess = postprocess
 
     def step(self, batch: Any, batch_idx: int) -> Any:
         """
@@ -99,8 +99,7 @@ def step(self, batch: Any, batch_idx: int) -> Any:
             y_hat = self.to_metrics_format(y_hat)
             for name, metric in self.metrics.items():
                 if isinstance(metric, torchmetrics.metric.Metric):
-                    metric(y_hat, y)
-                    logs[name] = metric # log the metric itself if it is of type Metric
+                    logs[name] = metric(y_hat, y) # log the metric itself if it is of type Metric
                 else:
                     logs[name] = metric(y_hat, y)
         logs.update(losses)
diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py
index 2125b79909..46be3c823c 100644
--- a/flash/data/data_pipeline.py
+++ b/flash/data/data_pipeline.py
@@ -359,7 +359,7 @@ def _create_uncollate_postprocessors(self, stage: RunningStage) -> _PostProcessor:
         # since postprocessing is exclusive for prediction, we don't have to check the resolution hierarchy here.
         if postprocess._save_path:
             save_per_sample: bool = self._is_overriden_recursive(
-                "save_sample", postprocess, object_type=Postprocess, prefix=_STAGES_PREFIX[stage]
+                "save_sample", postprocess, Postprocess, prefix=_STAGES_PREFIX[stage]
             )
 
             if save_per_sample:
diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py
index 7a046f52e5..1d7be4e9e6 100644
--- a/flash/vision/classification/model.py
+++ b/flash/vision/classification/model.py
@@ -24,6 +24,10 @@
 from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES
 
 
+def binary_cross_entropy_with_logits(x, y):
+    return F.binary_cross_entropy_with_logits(x, y.float())
+
+
 class ImageClassifier(ClassificationTask):
     """Task that classifies images.
 
@@ -57,6 +61,7 @@ class ImageClassifier(ClassificationTask):
         metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.Accuracy`.
         learning_rate: Learning rate to use for training, defaults to ``1e-3``.
+        multi_label: Whether the targets are multi-label or not.
     """
 
     backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES
 
@@ -68,17 +73,26 @@ def __init__(
         backbone_kwargs: Optional[Dict] = None,
         head: Optional[Union[FunctionType, nn.Module]] = None,
         pretrained: bool = True,
-        loss_fn: Callable = F.cross_entropy,
+        loss_fn: Callable = None,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
-        metrics: Union[Callable, Mapping, Sequence, None] = Accuracy(),
+        metrics: Union[Callable, Mapping, Sequence, None] = None,
         learning_rate: float = 1e-3,
+        multi_label: bool = False,
     ):
+
+        if metrics is None:
+            metrics = Accuracy(subset_accuracy=multi_label)
+
+        if loss_fn is None:
+            loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
+
         super().__init__(
             model=None,
             loss_fn=loss_fn,
             optimizer=optimizer,
             metrics=metrics,
             learning_rate=learning_rate,
+            postprocess=self.postprocess_cls(multi_label)
         )
 
         self.save_hyperparameters()
@@ -100,4 +114,7 @@ def __init__(
 
     def forward(self, x) -> Any:
         x = self.backbone(x)
-        return torch.softmax(self.head(x), -1)
+        if self.hparams.multi_label:
+            return self.head(x)
+        else:
+            return torch.softmax(self.head(x), -1)
diff --git a/flash/vision/embedding/model.py b/flash/vision/embedding/model.py
index d1ddce5c84..ee1019aa51 100644
--- a/flash/vision/embedding/model.py
+++ b/flash/vision/embedding/model.py
@@ -66,7 +66,7 @@ def __init__(
             optimizer=optimizer,
             metrics=metrics,
             learning_rate=learning_rate,
-            default_preprocess=ImageClassificationPreprocess(
+            preprocess=ImageClassificationPreprocess(
                 predict_transform=ImageClassificationData.default_val_transforms(),
             )
         )
diff --git a/tests/vision/classification/test_model.py b/tests/vision/classification/test_model.py
index 5079c332d8..a31dd3434e 100644
--- a/tests/vision/classification/test_model.py
+++ b/tests/vision/classification/test_model.py
@@ -30,6 +30,18 @@ def __len__(self) -> int:
         return 100
 
 
+class DummyMultiLabelDataset(torch.utils.data.Dataset):
+
+    def __init__(self, num_classes: int):
+        self.num_classes = num_classes
+
+    def __getitem__(self, index):
+        return torch.rand(3, 224, 224), torch.randint(0, 2, (self.num_classes, ))
+
+    def __len__(self) -> int:
+        return 100
+
+
 # ==============================
@@ -67,3 +79,16 @@ def test_unfreeze():
     model.unfreeze()
     for p in model.backbone.parameters():
         assert p.requires_grad is True
+
+
+def test_multilabel(tmpdir):
+
+    num_classes = 4
+    ds = DummyMultiLabelDataset(num_classes)
+    model = ImageClassifier(num_classes, multi_label=True)
+    train_dl = torch.utils.data.DataLoader(ds, batch_size=2)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
+    image, _ = ds[0]
+    predictions = model.predict(image.unsqueeze(0))
+    assert len(predictions[0]) == num_classes
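
The heart of this first patch is the new branch in ``per_sample_transform``: multi-label predictions keep one sigmoid probability per class instead of collapsing to a single ``argmax`` index. A minimal standalone sketch of the two paths (plain PyTorch, independent of the Flash classes above):

    import torch

    logits = torch.tensor([[1.2, -0.4, 3.1, 0.05]])  # one sample, four classes

    # single-label path: argmax collapses each sample to one class index
    print(torch.argmax(logits, -1).tolist())  # [2]

    # multi-label path: an independent sigmoid probability per class
    print(torch.sigmoid(logits).tolist())  # [[0.769, 0.401, 0.957, 0.512]] (rounded)
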
""" backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES @@ -68,17 +73,26 @@ def __init__( backbone_kwargs: Optional[Dict] = None, head: Optional[Union[FunctionType, nn.Module]] = None, pretrained: bool = True, - loss_fn: Callable = F.cross_entropy, + loss_fn: Callable = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, - metrics: Union[Callable, Mapping, Sequence, None] = Accuracy(), + metrics: Union[Callable, Mapping, Sequence, None] = None, learning_rate: float = 1e-3, + multi_label: bool = False, ): + + if metrics is None: + metrics = Accuracy(subset_accuracy=multi_label) + + if loss_fn is None: + loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy + super().__init__( model=None, loss_fn=loss_fn, optimizer=optimizer, metrics=metrics, learning_rate=learning_rate, + postprocess=self.postprocess_cls(multi_label) ) self.save_hyperparameters() @@ -100,4 +114,7 @@ def __init__( def forward(self, x) -> Any: x = self.backbone(x) - return torch.softmax(self.head(x), -1) + if self.hparams.multi_label: + return self.head(x) + else: + return torch.softmax(self.head(x), -1) diff --git a/flash/vision/embedding/model.py b/flash/vision/embedding/model.py index d1ddce5c84..ee1019aa51 100644 --- a/flash/vision/embedding/model.py +++ b/flash/vision/embedding/model.py @@ -66,7 +66,7 @@ def __init__( optimizer=optimizer, metrics=metrics, learning_rate=learning_rate, - default_preprocess=ImageClassificationPreprocess( + preprocess=ImageClassificationPreprocess( predict_transform=ImageClassificationData.default_val_transforms(), ) ) diff --git a/tests/vision/classification/test_model.py b/tests/vision/classification/test_model.py index 5079c332d8..a31dd3434e 100644 --- a/tests/vision/classification/test_model.py +++ b/tests/vision/classification/test_model.py @@ -30,6 +30,18 @@ def __len__(self) -> int: return 100 +class DummyMultiLabelDataset(torch.utils.data.Dataset): + + def __init__(self, num_classes: int): + self.num_classes = num_classes + + def __getitem__(self, index): + return torch.rand(3, 224, 224), torch.randint(0, 2, (self.num_classes, )) + + def __len__(self) -> int: + return 100 + + # ============================== @@ -67,3 +79,16 @@ def test_unfreeze(): model.unfreeze() for p in model.backbone.parameters(): assert p.requires_grad is True + + +def test_multilabel(tmpdir): + + num_classes = 4 + ds = DummyMultiLabelDataset(num_classes) + model = ImageClassifier(num_classes, multi_label=True) + train_dl = torch.utils.data.DataLoader(ds, batch_size=2) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.finetune(model, train_dl, strategy="freeze_unfreeze") + image, _ = ds[0] + predictions = model.predict(image.unsqueeze(0)) + assert len(predictions[0]) == num_classes From ca8022e897723024f9062f9b68b7923e0cdcb990 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 20 Apr 2021 11:52:26 +0100 Subject: [PATCH 2/6] change types --- flash/vision/classification/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 1d7be4e9e6..7fecddbb8b 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -73,9 +73,9 @@ def __init__( backbone_kwargs: Optional[Dict] = None, head: Optional[Union[FunctionType, nn.Module]] = None, pretrained: bool = True, - loss_fn: Callable = None, + loss_fn: Optional[Callable] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, - metrics: Union[Callable, Mapping, 
From 310c5f91734d48ae7b43c6e1e094e08a018c7f9d Mon Sep 17 00:00:00 2001
From: tchaton
Date: Tue, 20 Apr 2021 11:55:17 +0100
Subject: [PATCH 3/6] add a check

---
 flash/core/classification.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flash/core/classification.py b/flash/core/classification.py
index 17dcc946dc..56551390b9 100644
--- a/flash/core/classification.py
+++ b/flash/core/classification.py
@@ -38,6 +38,6 @@ class ClassificationTask(Task):
 
     postprocess_cls = ClassificationPostprocess
 
     def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
-        if self.hparams.multi_label:
+        if getattr(self.hparams, "multi_label", False):
             return F.sigmoid(x).int()
         return F.softmax(x, -1)

From 3ac9f8e9c723c9eb4dc906ab452dbe9976107e0f Mon Sep 17 00:00:00 2001
From: tchaton
Date: Tue, 20 Apr 2021 12:13:29 +0100
Subject: [PATCH 4/6] resolve a bug

---
 flash/core/classification.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/flash/core/classification.py b/flash/core/classification.py
index 56551390b9..ddc5ab1a52 100644
--- a/flash/core/classification.py
+++ b/flash/core/classification.py
@@ -17,7 +17,7 @@
 import torch.nn.functional as F
 
 from flash.core.model import Task
-from flash.data.process import Postprocess
+from flash.data.process import Postprocess, Preprocess
 
 
 class ClassificationPostprocess(Postprocess):
@@ -37,6 +37,9 @@ class ClassificationTask(Task):
 
     postprocess_cls = ClassificationPostprocess
 
+    def __init__(self, *args, postprocess: Optional[Postprocess] = None, **kwargs):
+        super().__init__(*args, postprocess=postprocess or self.postprocess_cls(), **kwargs)
+
     def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
         if getattr(self.hparams, "multi_label", False):
             return F.sigmoid(x).int()
         return F.softmax(x, -1)
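
Patch 3 guards the ``multi_label`` lookup with ``getattr``, presumably because task subclasses that never pass ``multi_label`` to ``save_hyperparameters()`` would otherwise raise an ``AttributeError`` on ``self.hparams.multi_label``. Patch 4 then restores the default postprocess that patch 1's refactor dropped: an explicitly passed ``postprocess`` wins, otherwise the class-level ``postprocess_cls`` is instantiated. A hedged sketch of the resulting behaviour (``my_model`` and ``MyPostprocess`` are hypothetical, not part of this PR):

    task = ClassificationTask(my_model)                               # falls back to ClassificationPostprocess()
    task = ClassificationTask(my_model, postprocess=MyPostprocess())  # explicit override wins

Subclasses pick a different default simply by setting their own ``postprocess_cls``, as ``ImageClassifier`` does in patch 1 with ``self.postprocess_cls(multi_label)``.
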
From 64b90e7b7db00aee5515a6019d96028a63f4ae05 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Tue, 20 Apr 2021 12:28:33 +0100
Subject: [PATCH 5/6] update on comments

---
 flash/core/classification.py              |  6 +++---
 flash/core/model.py                       | 25 ++++++++++++-----------
 flash/vision/classification/model.py      |  5 +++--
 flash/vision/embedding/model.py           |  2 +-
 tests/vision/classification/test_model.py |  7 +++++--
 5 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/flash/core/classification.py b/flash/core/classification.py
index ddc5ab1a52..9650eadc34 100644
--- a/flash/core/classification.py
+++ b/flash/core/classification.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -26,7 +26,7 @@ def __init__(self, multi_label: bool = False, save_path: Optional[str] = None):
         super().__init__(save_path=save_path)
         self.multi_label = multi_label
 
-    def per_sample_transform(self, samples: Any) -> Any:
+    def per_sample_transform(self, samples: Any) -> List[Any]:
         if self.multi_label:
             return F.sigmoid(samples).tolist()
         else:
             return torch.argmax(samples, -1).tolist()
@@ -42,5 +42,5 @@ def __init__(self, *args, postprocess: Optional[Postprocess] = None, **kwargs):
         super().__init__(*args, postprocess=postprocess or self.postprocess_cls(), **kwargs)
 
     def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
         if getattr(self.hparams, "multi_label", False):
-            return F.sigmoid(x).int()
+            return F.sigmoid(x)
         return F.softmax(x, -1)
diff --git a/flash/core/model.py b/flash/core/model.py
index 5206babf48..eafaea4af6 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -60,8 +60,8 @@ class Task(LightningModule):
         optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
         metrics: Metrics to compute for training and evaluation.
         learning_rate: Learning rate to use for training, defaults to `5e-5`.
-        preprocess: :class:`.Preprocess` to use as the default for this task.
-        postprocess: :class:`.Postprocess` to use as the default for this task.
+        preprocess: :class:`~flash.data.process.Preprocess` to use as the default for this task.
+        postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task.
     """
 
     def __init__(
@@ -99,7 +99,8 @@ def step(self, batch: Any, batch_idx: int) -> Any:
             y_hat = self.to_metrics_format(y_hat)
             for name, metric in self.metrics.items():
                 if isinstance(metric, torchmetrics.metric.Metric):
-                    logs[name] = metric(y_hat, y) # log the metric itself if it is of type Metric
+                    metric(y_hat, y)
+                    logs[name] = metric # log the metric itself if it is of type Metric
                 else:
                     logs[name] = metric(y_hat, y)
         logs.update(losses)
@@ -180,17 +181,17 @@ def _resolve(
         new_preprocess: Optional[Preprocess],
         new_postprocess: Optional[Postprocess],
     ) -> Tuple[Optional[Preprocess], Optional[Postprocess]]:
-        """Resolves the correct :class:`.Preprocess` and :class:`.Postprocess` to use, choosing ``new_*`` if it is not
-        None or a base class (:class:`.Preprocess` or :class:`.Postprocess`) and ``old_*`` otherwise.
+        """Resolves the correct :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` to use, choosing ``new_*`` if it is not
+        None or a base class (:class:`~flash.data.process.Preprocess` or :class:`~flash.data.process.Postprocess`) and ``old_*`` otherwise.
 
         Args:
-            old_preprocess: :class:`.Preprocess` to be overridden.
-            old_postprocess: :class:`.Postprocess` to be overridden.
-            new_preprocess: :class:`.Preprocess` to override with.
-            new_postprocess: :class:`.Postprocess` to override with.
+            old_preprocess: :class:`~flash.data.process.Preprocess` to be overridden.
+            old_postprocess: :class:`~flash.data.process.Postprocess` to be overridden.
+            new_preprocess: :class:`~flash.data.process.Preprocess` to override with.
+            new_postprocess: :class:`~flash.data.process.Postprocess` to override with.
 
         Returns:
-            The resolved :class:`.Preprocess` and :class:`.Postprocess`.
+            The resolved :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.
         """
         preprocess = old_preprocess
         if new_preprocess is not None and type(new_preprocess) != Preprocess:
@@ -203,7 +204,7 @@
             return preprocess, postprocess
 
     def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
-        """Build a :class:`.DataPipeline` incorporating available :class:`.Preprocess` and :class:`.Postprocess`
+        """Build a :class:`.DataPipeline` incorporating available :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
         objects. These will be overridden in the following resolution order (lowest priority first):
 
         - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
@@ -212,7 +213,7 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O
 
         Args:
-            data_pipeline: Optional highest priority source of :class:`.Preprocess` and :class:`.Postprocess`.
+            data_pipeline: Optional highest priority source of :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.
 
         Returns:
             The fully resolved :class:`.DataPipeline`.
diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py
index 7fecddbb8b..528ce99063 100644
--- a/flash/vision/classification/model.py
+++ b/flash/vision/classification/model.py
@@ -24,7 +24,8 @@
 from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES
 
 
-def binary_cross_entropy_with_logits(x, y):
+def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Calls BCE with logits and casts the one-hot encoded target ``y`` to floating point precision."""
     return F.binary_cross_entropy_with_logits(x, y.float())
 
 
@@ -112,7 +113,7 @@ def __init__(
             nn.Linear(num_features, num_classes),
         )
 
-    def forward(self, x) -> Any:
+    def forward(self, x) -> torch.Tensor:
         x = self.backbone(x)
         if self.hparams.multi_label:
             return self.head(x)
diff --git a/flash/vision/embedding/model.py b/flash/vision/embedding/model.py
index ee1019aa51..f9fc3d85e2 100644
--- a/flash/vision/embedding/model.py
+++ b/flash/vision/embedding/model.py
@@ -98,7 +98,7 @@ def apply_pool(self, x):
         x = self.pooling_fn(x, dim=-1)
         return x
 
-    def forward(self, x) -> Any:
+    def forward(self, x) -> torch.Tensor:
         x = self.backbone(x)
 
         # bolts ssl models return lists
diff --git a/tests/vision/classification/test_model.py b/tests/vision/classification/test_model.py
index a31dd3434e..067fa994b8 100644
--- a/tests/vision/classification/test_model.py
+++ b/tests/vision/classification/test_model.py
@@ -89,6 +89,9 @@ def test_multilabel(tmpdir):
     train_dl = torch.utils.data.DataLoader(ds, batch_size=2)
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
-    image, _ = ds[0]
+    image, label = ds[0]
     predictions = model.predict(image.unsqueeze(0))
-    assert len(predictions[0]) == num_classes
+    assert (torch.tensor(predictions) > 1).sum() == 0
+    assert (torch.tensor(predictions) < 0).sum() == 0
+    assert len(predictions[0]) == num_classes == len(label)
+    assert len(torch.unique(label)) <= 2
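
Patch 5 stops casting the metric input to integers: ``to_metrics_format`` now hands raw sigmoid probabilities to ``Accuracy``, which applies its own 0.5 threshold internally, and the updated test asserts that every predicted value stays within ``[0, 1]``. Turning the predicted probabilities into hard 0/1 labels is left to the caller; a minimal sketch, assuming a conventional 0.5 cut-off that is not part of this PR:

    probabilities = model.predict(image.unsqueeze(0))       # as in the test above
    hard_labels = [int(p > 0.5) for p in probabilities[0]]  # hypothetical thresholding
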
""" preprocess = old_preprocess if new_preprocess is not None and type(new_preprocess) != Preprocess: @@ -203,7 +204,7 @@ def _resolve( return preprocess, postprocess def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]: - """Build a :class:`.DataPipeline` incorporating available :class:`.Preprocess` and :class:`.Postprocess` + """Build a :class:`.DataPipeline` incorporating available :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` objects. These will be overridden in the following resolution order (lowest priority first): - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`. @@ -212,7 +213,7 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O - :class:`.DataPipeline` passed to this method. Args: - data_pipeline: Optional highest priority source of :class:`.Preprocess` and :class:`.Postprocess`. + data_pipeline: Optional highest priority source of :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`. Returns: The fully resolved :class:`.DataPipeline`. diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 7fecddbb8b..528ce99063 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -24,7 +24,8 @@ from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES -def binary_cross_entropy_with_logits(x, y): +def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Calls BCE with logits and cast the target one_hot (y) encoding to floating point precision.""" return F.binary_cross_entropy_with_logits(x, y.float()) @@ -112,7 +113,7 @@ def __init__( nn.Linear(num_features, num_classes), ) - def forward(self, x) -> Any: + def forward(self, x) -> torch.Tensor: x = self.backbone(x) if self.hparams.multi_label: return self.head(x) diff --git a/flash/vision/embedding/model.py b/flash/vision/embedding/model.py index ee1019aa51..f9fc3d85e2 100644 --- a/flash/vision/embedding/model.py +++ b/flash/vision/embedding/model.py @@ -98,7 +98,7 @@ def apply_pool(self, x): x = self.pooling_fn(x, dim=-1) return x - def forward(self, x) -> Any: + def forward(self, x) -> torch.Tensor: x = self.backbone(x) # bolts ssl models return lists diff --git a/tests/vision/classification/test_model.py b/tests/vision/classification/test_model.py index a31dd3434e..067fa994b8 100644 --- a/tests/vision/classification/test_model.py +++ b/tests/vision/classification/test_model.py @@ -89,6 +89,9 @@ def test_multilabel(tmpdir): train_dl = torch.utils.data.DataLoader(ds, batch_size=2) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.finetune(model, train_dl, strategy="freeze_unfreeze") - image, _ = ds[0] + image, label = ds[0] predictions = model.predict(image.unsqueeze(0)) - assert len(predictions[0]) == num_classes + assert (torch.tensor(predictions) > 1).sum() == 0 + assert (torch.tensor(predictions) < 0).sum() == 0 + assert len(predictions[0]) == num_classes == len(label) + assert len(torch.unique(label)) <= 2 From 7a807ba714b7a4065876690f1faa1bca8cb799ce Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 20 Apr 2021 12:39:01 +0100 Subject: [PATCH 6/6] update --- flash/core/model.py | 12 ++++++++---- flash/vision/classification/model.py | 5 +---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index eafaea4af6..9914b4cb61 100644 --- a/flash/core/model.py +++ 
@@ -181,8 +181,10 @@ def _resolve(
         new_preprocess: Optional[Preprocess],
         new_postprocess: Optional[Postprocess],
     ) -> Tuple[Optional[Preprocess], Optional[Postprocess]]:
-        """Resolves the correct :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` to use, choosing ``new_*`` if it is not
-        None or a base class (:class:`~flash.data.process.Preprocess` or :class:`~flash.data.process.Postprocess`) and ``old_*`` otherwise.
+        """Resolves the correct :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` to use,
+        choosing ``new_*`` if it is not None or a base class
+        (:class:`~flash.data.process.Preprocess` or :class:`~flash.data.process.Postprocess`)
+        and ``old_*`` otherwise.
 
         Args:
             old_preprocess: :class:`~flash.data.process.Preprocess` to be overridden.
@@ -204,7 +206,8 @@
             return preprocess, postprocess
 
     def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
-        """Build a :class:`.DataPipeline` incorporating available :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
+        """Build a :class:`.DataPipeline` incorporating available
+        :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
         objects. These will be overridden in the following resolution order (lowest priority first):
 
         - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
@@ -213,7 +216,8 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O
 
         Args:
-            data_pipeline: Optional highest priority source of :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.
+            data_pipeline: Optional highest priority source of
+                :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.
 
         Returns:
             The fully resolved :class:`.DataPipeline`.
diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py
index 528ce99063..5b6d9dca30 100644
--- a/flash/vision/classification/model.py
+++ b/flash/vision/classification/model.py
@@ -115,7 +115,4 @@ def __init__(
 
     def forward(self, x) -> torch.Tensor:
         x = self.backbone(x)
-        if self.hparams.multi_label:
-            return self.head(x)
-        else:
-            return torch.softmax(self.head(x), -1)
+        return self.head(x)
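
Taken together, the series lets an ``ImageClassifier`` be trained and queried in multi-label mode end to end. A hedged usage sketch (the import path and the 0.5 threshold are assumptions, not spelled out in this PR):

    import torch

    from flash.vision import ImageClassifier  # assumed import path

    model = ImageClassifier(num_classes=4, multi_label=True)  # BCE-with-logits loss, subset accuracy
    batch = torch.rand(1, 3, 224, 224)                        # one dummy RGB image, as in the tests
    probabilities = model.predict(batch)                      # [[p0, p1, p2, p3]], each value in [0, 1]
    labels = [int(p > 0.5) for p in probabilities[0]]         # hypothetical hard labels
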