This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add pre_metric_transform hook #219

Status: Closed (wants to merge 4 commits)
CHANGELOG.md (4 additions & 1 deletion)

@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-
+- Added a `pre_metric_transform` hook to `Task` for converting model outputs to the correct format for metric computation ([#219](https://github.com/PyTorchLightning/lightning-flash/pull/219))
 
 ### Changed
 
@@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed classification softmax ([#169](https://github.com/PyTorchLightning/lightning-flash/pull/169))
 
+- Fixed a bug where `ClassificationTask` would sometimes call `softmax` followed by `log_softmax` during training ([#219](https://github.com/PyTorchLightning/lightning-flash/pull/219))
+
 ### Removed
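
To make the fixed bug concrete: `F.cross_entropy` applies `log_softmax` internally, so a `forward` that already returns softmax probabilities ends up computing `log_softmax(softmax(logits))`, which distorts both the loss value and its gradients. A minimal standalone sketch (illustrative values only, not code from this PR):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10)
    target = torch.randint(0, 10, (4,))

    # Correct: cross_entropy applies log_softmax to raw logits exactly once.
    loss_ok = F.cross_entropy(logits, target)

    # Buggy: softmax in forward(), then cross_entropy's internal log_softmax,
    # i.e. log_softmax(softmax(logits)).
    loss_bug = F.cross_entropy(torch.softmax(logits, dim=1), target)

    print(loss_ok.item(), loss_bug.item())  # the two values differ
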
flash/core/classification.py (4 additions & 0 deletions)

@@ -14,6 +14,7 @@
 from typing import Any
 
 import torch
+from torch import Tensor
 
 from flash.core.model import Task
 from flash.data.process import Postprocess
@@ -29,3 +30,6 @@ class ClassificationTask(Task):
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, default_postprocess=ClassificationPostprocess(), **kwargs)
+
+    def pre_metric_transform(self, y_hat: Tensor) -> Tensor:
+        return torch.softmax(y_hat, 1)
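
With this override, every metric attached to a `ClassificationTask` receives a proper probability distribution rather than raw logits; argmax-based metrics such as accuracy are unaffected, but probability-based metrics rely on it. A quick sketch of the transform itself (standalone, not code from the PR):

    import torch

    logits = torch.randn(8, 5)        # raw model output (y_hat), any real values
    probs = torch.softmax(logits, 1)  # what pre_metric_transform returns

    assert torch.all(probs >= 0)
    assert torch.allclose(probs.sum(dim=1), torch.ones(8))  # each row sums to 1
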
flash/core/model.py (7 additions & 0 deletions)

@@ -17,6 +17,8 @@
 from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 import torch
+from torch import Tensor
+
 import torchmetrics
 from pytorch_lightning import LightningModule
 from pytorch_lightning.callbacks import Callback
@@ -96,6 +98,7 @@ def step(self, batch: Any, batch_idx: int) -> Any:
         output = {"y_hat": y_hat}
         losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
         logs = {}
+        y_hat = self.pre_metric_transform(y_hat)
         for name, metric in self.metrics.items():
             if isinstance(metric, torchmetrics.metric.Metric):
                 metric(y_hat, y)
@@ -256,6 +259,10 @@ def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None:
             getattr(data_pipeline, '_postprocess_pipeline', None),
         )
 
+    def pre_metric_transform(self, y_hat: Tensor) -> Tensor:
+        """Transform to apply to the model output (``y_hat``) before forwarding to the metrics."""
+        return y_hat
+
     def on_train_dataloader(self) -> None:
         if self.data_pipeline is not None:
             self.data_pipeline._detach_from_model(self, RunningStage.TRAINING)
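
Because the base implementation is the identity, existing tasks are unaffected; a task overrides the hook only when its metrics expect a different format than its loss. A hypothetical subclass (illustration only, not part of this PR) showing the intended usage:

    from torch import Tensor

    from flash.core.model import Task

    class SequenceTaggingTask(Task):
        """Hypothetical task whose model emits (N, seq_len, C) logits."""

        def pre_metric_transform(self, y_hat: Tensor) -> Tensor:
            # The loss still runs on raw logits; metrics instead receive
            # flattened (N * seq_len, C) probabilities.
            return y_hat.reshape(-1, y_hat.shape[-1]).softmax(dim=-1)
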
flash/tabular/classification/model.py (1 addition & 1 deletion)

@@ -73,7 +73,7 @@ def __init__(
     def forward(self, x_in) -> torch.Tensor:
         # TabNet takes single input, x_in is composed of (categorical, numerical)
         x = torch.cat([x for x in x_in if x.numel()], dim=1)
-        return F.softmax(self.model(x)[0], -1)
+        return self.model(x)[0]
 
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
         return self(batch)
flash/text/classification/model.py (1 addition & 1 deletion)

@@ -75,6 +75,6 @@ def step(self, batch, batch_idx) -> dict:
         output["y_hat"] = logits
         if isinstance(logits, SequenceClassifierOutput):
             logits = logits.logits
-        probs = torch.softmax(logits, 1)
+        probs = self.pre_metric_transform(logits)
         output["logs"] = {name: metric(probs, batch["labels"]) for name, metric in self.metrics.items()}
         return output
flash/vision/classification/model.py (1 addition & 1 deletion)

@@ -100,4 +100,4 @@ def __init__(
 
     def forward(self, x) -> Any:
         x = self.backbone(x)
-        return torch.softmax(self.head(x), -1)
+        return self.head(x)
flash_examples/generic_task.py (0 additions & 1 deletion)

@@ -32,7 +32,6 @@
     nn.Linear(28 * 28, 128),
     nn.ReLU(),
     nn.Linear(128, 10),
-    nn.Softmax(),
 )
 
 # 3. Load a dataset
tests/core/test_trainer.py (1 addition & 1 deletion)

@@ -43,7 +43,7 @@ class DummyClassifier(nn.Module):
     def __init__(self):
         super().__init__()
         self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
-        self.head = nn.LogSoftmax()
+        self.head = nn.LogSoftmax(dim=1)
 
     def forward(self, x):
         return self.head(self.backbone(x))
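
For reference, `nn.LogSoftmax()` without an explicit `dim` falls back to deprecated implicit-dimension selection and emits a `UserWarning`; `dim=1` states the intent, normalizing over the class dimension of an `(N, C)` batch. A quick standalone check (not from the test suite):

    import torch
    from torch import nn

    x = torch.randn(4, 10)

    old_head = nn.LogSoftmax()        # UserWarning: implicit dimension choice is deprecated
    new_head = nn.LogSoftmax(dim=1)   # explicit: normalize across the 10 classes

    assert torch.allclose(old_head(x), new_head(x))  # identical for 2-D input
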