This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Add multi label support #230

Merged · 6 commits · Apr 20, 2021
Changes from 4 commits
21 changes: 16 additions & 5 deletions flash/core/classification.py
@@ -11,25 +11,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Optional

 import torch
+import torch.nn.functional as F

 from flash.core.model import Task
-from flash.data.process import Postprocess
+from flash.data.process import Postprocess, Preprocess


 class ClassificationPostprocess(Postprocess):

+    def __init__(self, multi_label: bool = False, save_path: Optional[str] = None):
+        super().__init__(save_path=save_path)
+        self.multi_label = multi_label
+
     def per_sample_transform(self, samples: Any) -> Any:
-        return torch.argmax(samples, -1).tolist()
+        if self.multi_label:
+            return F.sigmoid(samples).tolist()
+        else:
+            return torch.argmax(samples, -1).tolist()


 class ClassificationTask(Task):

-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, default_postprocess=ClassificationPostprocess(), **kwargs)
+    postprocess_cls = ClassificationPostprocess
+
+    def __init__(self, *args, postprocess: Optional[Preprocess] = None, **kwargs):
+        super().__init__(*args, postprocess=postprocess or self.postprocess_cls(), **kwargs)
+
+    def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
+        if getattr(self.hparams, "multi_label", False):
+            return F.sigmoid(x).int()
[Review thread on the `.int()` line above, marked as resolved]
Contributor: why int? .round() might work better
tchaton (author): My bad, it was a bug.

+        return F.softmax(x, -1)
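
For context on the thread above: `.int()` truncates toward zero, so every sigmoid output below 1.0 becomes 0, while `.round()` thresholds at 0.5. A minimal standalone sketch (not part of the diff):

import torch

logits = torch.tensor([2.0, -1.0, 0.5, -3.0])  # hypothetical logits, four labels
probs = torch.sigmoid(logits)                  # tensor([0.8808, 0.2689, 0.6225, 0.0474])

probs.int()           # tensor([0, 0, 0, 0], dtype=torch.int32) -- everything truncated to 0
probs.round().int()   # tensor([1, 0, 1, 0], dtype=torch.int32) -- thresholded at 0.5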
15 changes: 7 additions & 8 deletions flash/core/model.py
@@ -60,8 +60,8 @@ class Task(LightningModule):
         optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
         metrics: Metrics to compute for training and evaluation.
         learning_rate: Learning rate to use for training, defaults to `5e-5`.
-        default_preprocess: :class:`.Preprocess` to use as the default for this task.
-        default_postprocess: :class:`.Postprocess` to use as the default for this task.
+        preprocess: :class:`.Preprocess` to use as the default for this task.
+        postprocess: :class:`.Postprocess` to use as the default for this task.
    """

    def __init__(
@@ -71,8 +71,8 @@ def __init__(
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
        learning_rate: float = 5e-5,
-        default_preprocess: Preprocess = None,
-        default_postprocess: Postprocess = None,
+        preprocess: Preprocess = None,
+        postprocess: Postprocess = None,
    ):
        super().__init__()
        if model is not None:
@@ -84,8 +84,8 @@ def __init__(
        # TODO: should we save more? Bug on some regarding yaml if we save metrics
        self.save_hyperparameters("learning_rate", "optimizer")

-        self._preprocess = default_preprocess
-        self._postprocess = default_postprocess
+        self._preprocess = preprocess
+        self._postprocess = postprocess

    def step(self, batch: Any, batch_idx: int) -> Any:
        """
@@ -99,8 +99,7 @@ def step(self, batch: Any, batch_idx: int) -> Any:
        y_hat = self.to_metrics_format(y_hat)
        for name, metric in self.metrics.items():
            if isinstance(metric, torchmetrics.metric.Metric):
-                metric(y_hat, y)
-                logs[name] = metric  # log the metric itself if it is of type Metric
+                logs[name] = metric(y_hat, y)  # log the metric itself if it is of type Metric
            else:
                logs[name] = metric(y_hat, y)
        logs.update(losses)
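
For context on this change, a minimal sketch of the torchmetrics calling convention it relies on (API as of the torchmetrics releases current for this PR): calling a Metric object both updates its accumulated state and returns the value for the current batch, so the update and the log entry collapse into one expression.

import torch
import torchmetrics

acc = torchmetrics.Accuracy()
preds = torch.tensor([0, 1, 1])
target = torch.tensor([0, 1, 0])

# One call both accumulates state across batches and returns this
# batch's value (2 of 3 correct here).
batch_acc = acc(preds, target)  # tensor(0.6667)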
2 changes: 1 addition & 1 deletion flash/data/data_pipeline.py
@@ -359,7 +359,7 @@ def _create_uncollate_postprocessors(self, stage: RunningStage) -> _PostProcessor:
        # since postprocessing is exclusive for prediction, we don't have to check the resolution hierarchy here.
        if postprocess._save_path:
            save_per_sample: bool = self._is_overriden_recursive(
-                "save_sample", postprocess, object_type=Postprocess, prefix=_STAGES_PREFIX[stage]
+                "save_sample", postprocess, Postprocess, prefix=_STAGES_PREFIX[stage]
            )

            if save_per_sample:
23 changes: 20 additions & 3 deletions flash/vision/classification/model.py
@@ -24,6 +24,10 @@
 from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES


+def binary_cross_entropy_with_logits(x, y):
+    return F.binary_cross_entropy_with_logits(x, y.float())
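
A short note on why this wrapper casts the targets: F.binary_cross_entropy_with_logits requires float targets, while multi-hot targets built with torch.randint (as in the test below) arrive as int64. A minimal sketch:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 4)       # raw head outputs: 2 samples, 4 labels
y = torch.randint(0, 2, (2, 4))  # int64 multi-hot targets

# Passing y directly raises a dtype error; casting to float fixes it.
loss = F.binary_cross_entropy_with_logits(logits, y.float())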


 class ImageClassifier(ClassificationTask):
     """Task that classifies images.

@@ -57,6 +61,7 @@ class ImageClassifier(ClassificationTask):
         metrics: Metrics to compute for training and evaluation,
             defaults to :class:`torchmetrics.Accuracy`.
         learning_rate: Learning rate to use for training, defaults to ``1e-3``.
+        multi_label: Whether the targets are multi-label or not.
     """

     backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES
@@ -68,17 +73,26 @@ def __init__(
        backbone_kwargs: Optional[Dict] = None,
        head: Optional[Union[FunctionType, nn.Module]] = None,
        pretrained: bool = True,
-        loss_fn: Callable = F.cross_entropy,
+        loss_fn: Optional[Callable] = None,
        optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
-        metrics: Union[Callable, Mapping, Sequence, None] = Accuracy(),
+        metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None,
        learning_rate: float = 1e-3,
+        multi_label: bool = False,
    ):

+        if metrics is None:
+            metrics = Accuracy(subset_accuracy=multi_label)
+
+        if loss_fn is None:
+            loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
+
        super().__init__(
            model=None,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
+            postprocess=self.postprocess_cls(multi_label)
        )

        self.save_hyperparameters()
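
A hypothetical usage sketch (import path assumed from this repository's layout): leaving loss_fn and metrics as None and passing multi_label=True picks up the new defaults, binary cross-entropy with logits and subset accuracy, with no further configuration.

from flash.vision import ImageClassifier

model = ImageClassifier(num_classes=4, multi_label=True)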
@@ -100,4 +114,7 @@ def __init__(

    def forward(self, x) -> Any:
        x = self.backbone(x)
-        return torch.softmax(self.head(x), -1)
+        if self.hparams.multi_label:
+            return self.head(x)
+        else:
+            return torch.softmax(self.head(x), -1)
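
The asymmetry here is deliberate: the multi-label branch returns raw logits because binary_cross_entropy_with_logits applies the sigmoid itself, and a softmax would wrongly make independent labels compete for probability mass. A minimal sketch of the two output conventions:

import torch

head_out = torch.tensor([[2.0, -1.0, 0.5, -3.0]])  # hypothetical head output

multi = head_out                      # multi-label: independent per-label logits;
                                      # sigmoid is applied later (loss / postprocess)
single = torch.softmax(head_out, -1)  # single-label: probabilities over mutually
                                      # exclusive classes, summing to 1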
2 changes: 1 addition & 1 deletion flash/vision/embedding/model.py
@@ -66,7 +66,7 @@ def __init__(
            optimizer=optimizer,
            metrics=metrics,
            learning_rate=learning_rate,
-            default_preprocess=ImageClassificationPreprocess(
+            preprocess=ImageClassificationPreprocess(
                predict_transform=ImageClassificationData.default_val_transforms(),
            )
        )
25 changes: 25 additions & 0 deletions tests/vision/classification/test_model.py
@@ -30,6 +30,18 @@ def __len__(self) -> int:
        return 100


+class DummyMultiLabelDataset(torch.utils.data.Dataset):
+
+    def __init__(self, num_classes: int):
+        self.num_classes = num_classes
+
+    def __getitem__(self, index):
+        return torch.rand(3, 224, 224), torch.randint(0, 2, (self.num_classes, ))
+
+    def __len__(self) -> int:
+        return 100
+
+
 # ==============================
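
Each item from this dataset pairs a random image with a multi-hot target: one independent 0/1 entry per class, not a single class index, which is the shape the binary cross-entropy path expects. For example:

ds = DummyMultiLabelDataset(num_classes=4)
image, target = ds[0]
image.shape   # torch.Size([3, 224, 224])
target        # e.g. tensor([1, 0, 0, 1]) -- multi-hot, not one-hot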


@@ -67,3 +79,16 @@ def test_unfreeze():
    model.unfreeze()
    for p in model.backbone.parameters():
        assert p.requires_grad is True
+
+
+def test_multilabel(tmpdir):
+
+    num_classes = 4
+    ds = DummyMultiLabelDataset(num_classes)
+    model = ImageClassifier(num_classes, multi_label=True)
+    train_dl = torch.utils.data.DataLoader(ds, batch_size=2)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
+    image, _ = ds[0]
+    predictions = model.predict(image.unsqueeze(0))
+    assert len(predictions[0]) == num_classes
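
A hypothetical follow-up to the assertion above: since the multi-label postprocess returns one sigmoid probability per class, predicted labels can be recovered by thresholding, commonly at 0.5.

probs = predictions[0]                                # e.g. [0.71, 0.12, 0.55, 0.33]
labels = [i for i, p in enumerate(probs) if p > 0.5]  # e.g. [0, 2]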