Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix bug when passing metrics as empty list #660

Merged
merged 2 commits into from
Aug 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where custom samplers would not be properly forwarded to the data loader ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

- Fixed a bug where it was not possible to pass no metrics to the `ImageClassifier` or `TestClassifier` ([#660](https://github.com/PyTorchLightning/lightning-flash/pull/660))

## [0.4.0] - 2021-06-22

### Added
Expand Down
3 changes: 2 additions & 1 deletion flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ class ClassificationTask(Task):
def __init__(
self,
*args,
num_classes: Optional[int] = None,
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
multi_label: bool = False,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
**kwargs,
) -> None:
if metrics is None:
metrics = torchmetrics.Accuracy(subset_accuracy=multi_label)
metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy()

if loss_fn is None:
loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
Expand Down
5 changes: 3 additions & 2 deletions flash/image/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler
from torchmetrics import Accuracy, F1, Metric
from torchmetrics import Metric

from flash.core.classification import ClassificationTask, Labels
from flash.core.data.data_source import DefaultDataKeys
Expand Down Expand Up @@ -86,13 +86,14 @@ def __init__(
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
):
super().__init__(
num_classes=num_classes,
model=None,
loss_fn=loss_fn,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
metrics=metrics or (F1(num_classes) if multi_label else Accuracy()),
metrics=metrics,
learning_rate=learning_rate,
multi_label=multi_label,
serializer=serializer or Labels(multi_label=multi_label),
Expand Down
5 changes: 3 additions & 2 deletions flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

import torch
from torchmetrics import Accuracy, F1, Metric
from torchmetrics import Metric

from flash.core.classification import ClassificationTask, Labels
from flash.core.data.process import Serializer
Expand Down Expand Up @@ -67,10 +67,11 @@ def __init__(
os.environ["PYTHONWARNINGS"] = "ignore"

super().__init__(
num_classes=num_classes,
model=None,
loss_fn=loss_fn,
optimizer=optimizer,
metrics=metrics or (F1(num_classes) if multi_label else Accuracy()),
metrics=metrics,
learning_rate=learning_rate,
multi_label=multi_label,
serializer=serializer or Labels(multi_label=multi_label),
Expand Down
9 changes: 5 additions & 4 deletions tests/image/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,18 @@ def __len__(self) -> int:

@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
@pytest.mark.parametrize(
"backbone",
"backbone,metrics",
[
"resnet18",
("resnet18", None),
("resnet18", []),
# "resnet34",
# "resnet50",
# "resnet101",
# "resnet152",
],
)
def test_init_train(tmpdir, backbone):
model = ImageClassifier(10, backbone=backbone)
def test_init_train(tmpdir, backbone, metrics):
model = ImageClassifier(10, backbone=backbone, metrics=metrics)
train_dl = torch.utils.data.DataLoader(DummyDataset())
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
Expand Down