From b766cc383612c75dc5543293628d2e4cc209e73e Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Mon, 16 Aug 2021 14:02:31 +0100
Subject: [PATCH] Fix bug when passing metrics as empty list (#660)

---
 CHANGELOG.md                             | 2 ++
 flash/core/classification.py             | 3 ++-
 flash/image/classification/model.py      | 5 +++--
 flash/text/classification/model.py       | 5 +++--
 tests/image/classification/test_model.py | 9 +++++----
 5 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 812b64f5f5..a27635e797 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed a bug where custom samplers would not be properly forwarded to the data loader ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))
 
+- Fixed a bug where it was not possible to pass no metrics to the `ImageClassifier` or `TextClassifier` ([#660](https://github.com/PyTorchLightning/lightning-flash/pull/660))
+
 ## [0.4.0] - 2021-06-22
 
 ### Added
diff --git a/flash/core/classification.py b/flash/core/classification.py
index ba10162abc..b11e714528 100644
--- a/flash/core/classification.py
+++ b/flash/core/classification.py
@@ -41,6 +41,7 @@ class ClassificationTask(Task):
     def __init__(
         self,
         *args,
+        num_classes: Optional[int] = None,
         loss_fn: Optional[Callable] = None,
         metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
         multi_label: bool = False,
@@ -48,7 +49,7 @@ def __init__(
         **kwargs,
     ) -> None:
         if metrics is None:
-            metrics = torchmetrics.Accuracy(subset_accuracy=multi_label)
+            metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy()
 
         if loss_fn is None:
             loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py
index 40ba82d5c9..ba70b6988c 100644
--- a/flash/image/classification/model.py
+++ b/flash/image/classification/model.py
@@ -17,7 +17,7 @@
 import torch
 from torch import nn
 from torch.optim.lr_scheduler import _LRScheduler
-from torchmetrics import Accuracy, F1, Metric
+from torchmetrics import Metric
 
 from flash.core.classification import ClassificationTask, Labels
 from flash.core.data.data_source import DefaultDataKeys
@@ -86,13 +86,14 @@ def __init__(
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
     ):
         super().__init__(
+            num_classes=num_classes,
             model=None,
             loss_fn=loss_fn,
             optimizer=optimizer,
             optimizer_kwargs=optimizer_kwargs,
             scheduler=scheduler,
             scheduler_kwargs=scheduler_kwargs,
-            metrics=metrics or (F1(num_classes) if multi_label else Accuracy()),
+            metrics=metrics,
             learning_rate=learning_rate,
             multi_label=multi_label,
             serializer=serializer or Labels(multi_label=multi_label),
diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py
index 3a0d78e1ff..c9ba5fa0a1 100644
--- a/flash/text/classification/model.py
+++ b/flash/text/classification/model.py
@@ -16,7 +16,7 @@
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union
 
 import torch
-from torchmetrics import Accuracy, F1, Metric
+from torchmetrics import Metric
 
 from flash.core.classification import ClassificationTask, Labels
 from flash.core.data.process import Serializer
@@ -67,10 +67,11 @@ def __init__(
         os.environ["PYTHONWARNINGS"] = "ignore"
 
         super().__init__(
+            num_classes=num_classes,
            model=None,
            loss_fn=loss_fn,
            optimizer=optimizer,
-            metrics=metrics or (F1(num_classes) if multi_label else Accuracy()),
+            metrics=metrics,
            learning_rate=learning_rate,
            multi_label=multi_label,
            serializer=serializer or Labels(multi_label=multi_label),
diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py
index d9014464eb..7dc49a3abc 100644
--- a/tests/image/classification/test_model.py
+++ b/tests/image/classification/test_model.py
@@ -60,17 +60,18 @@ def __len__(self) -> int:
 
 @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
 @pytest.mark.parametrize(
-    "backbone",
+    "backbone,metrics",
     [
-        "resnet18",
+        ("resnet18", None),
+        ("resnet18", []),
         # "resnet34",
         # "resnet50",
         # "resnet101",
         # "resnet152",
     ],
 )
-def test_init_train(tmpdir, backbone):
-    model = ImageClassifier(10, backbone=backbone)
+def test_init_train(tmpdir, backbone, metrics):
+    model = ImageClassifier(10, backbone=backbone, metrics=metrics)
     train_dl = torch.utils.data.DataLoader(DummyDataset())
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
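
Note: after this patch the default-metric selection lives in ClassificationTask, so both
classifiers accept metrics=None (fall back to the task default) or an explicit empty list
(train with no metrics). A minimal usage sketch mirroring the new test case, assuming a
lightning-flash install with the image extras (torchvision) available:

    from flash.image import ImageClassifier

    # None -> ClassificationTask picks the default metric (Accuracy, or F1 when
    # multi_label=True and num_classes is given); [] -> no metrics are attached.
    model_default = ImageClassifier(10, backbone="resnet18", metrics=None)
    model_no_metrics = ImageClassifier(10, backbone="resnet18", metrics=[])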