Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix bug when passing metrics as empty list (#660)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Aug 16, 2021
1 parent 9061d4b commit b766cc3
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where custom samplers would not be properly forwarded to the data loader ([#651](https://github.com/PyTorchLightning/lightning-flash/pull/651))

- Fixed a bug where it was not possible to pass no metrics to the `ImageClassifier` or `TestClassifier` ([#660](https://github.com/PyTorchLightning/lightning-flash/pull/660))

## [0.4.0] - 2021-06-22

### Added
Expand Down
3 changes: 2 additions & 1 deletion flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ class ClassificationTask(Task):
def __init__(
self,
*args,
num_classes: Optional[int] = None,
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
multi_label: bool = False,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
**kwargs,
) -> None:
if metrics is None:
metrics = torchmetrics.Accuracy(subset_accuracy=multi_label)
metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy()

if loss_fn is None:
loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
Expand Down
5 changes: 3 additions & 2 deletions flash/image/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import torch
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler
from torchmetrics import Accuracy, F1, Metric
from torchmetrics import Metric

from flash.core.classification import ClassificationTask, Labels
from flash.core.data.data_source import DefaultDataKeys
Expand Down Expand Up @@ -86,13 +86,14 @@ def __init__(
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
):
super().__init__(
num_classes=num_classes,
model=None,
loss_fn=loss_fn,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
metrics=metrics or (F1(num_classes) if multi_label else Accuracy()),
metrics=metrics,
learning_rate=learning_rate,
multi_label=multi_label,
serializer=serializer or Labels(multi_label=multi_label),
Expand Down
5 changes: 3 additions & 2 deletions flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Type, Union

import torch
from torchmetrics import Accuracy, F1, Metric
from torchmetrics import Metric

from flash.core.classification import ClassificationTask, Labels
from flash.core.data.process import Serializer
Expand Down Expand Up @@ -67,10 +67,11 @@ def __init__(
os.environ["PYTHONWARNINGS"] = "ignore"

super().__init__(
num_classes=num_classes,
model=None,
loss_fn=loss_fn,
optimizer=optimizer,
metrics=metrics or (F1(num_classes) if multi_label else Accuracy()),
metrics=metrics,
learning_rate=learning_rate,
multi_label=multi_label,
serializer=serializer or Labels(multi_label=multi_label),
Expand Down
9 changes: 5 additions & 4 deletions tests/image/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,18 @@ def __len__(self) -> int:

@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
@pytest.mark.parametrize(
"backbone",
"backbone,metrics",
[
"resnet18",
("resnet18", None),
("resnet18", []),
# "resnet34",
# "resnet50",
# "resnet101",
# "resnet152",
],
)
def test_init_train(tmpdir, backbone):
model = ImageClassifier(10, backbone=backbone)
def test_init_train(tmpdir, backbone, metrics):
model = ImageClassifier(10, backbone=backbone, metrics=metrics)
train_dl = torch.utils.data.DataLoader(DummyDataset())
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
Expand Down

0 comments on commit b766cc3

Please sign in to comment.