Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Add multi label support #230

Merged
merged 6 commits into from
Apr 20, 2021
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Any, Optional

import torch
import torch.nn.functional as F
@@ -22,14 +22,22 @@

class ClassificationPostprocess(Postprocess):

def __init__(self, multi_label: bool = False, save_path: Optional[str] = None):
super().__init__(save_path=save_path)
self.multi_label = multi_label

def per_sample_transform(self, samples: Any) -> Any:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
return torch.argmax(samples, -1).tolist()
if self.multi_label:
return F.sigmoid(samples).tolist()
tchaton marked this conversation as resolved.
Show resolved Hide resolved
else:
return torch.argmax(samples, -1).tolist()


class ClassificationTask(Task):

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, default_postprocess=ClassificationPostprocess(), **kwargs)
postprocess_cls = ClassificationPostprocess

def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
if getattr(self.hparams, "multi_label", False):
return F.sigmoid(x).int()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why int ? .round() might work better

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad, it was a bug.

tchaton marked this conversation as resolved.
Show resolved Hide resolved
return F.softmax(x, -1)
15 changes: 7 additions & 8 deletions flash/core/model.py
Original file line number Diff line number Diff line change
@@ -60,8 +60,8 @@ class Task(LightningModule):
optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
metrics: Metrics to compute for training and evaluation.
learning_rate: Learning rate to use for training, defaults to `5e-5`.
default_preprocess: :class:`.Preprocess` to use as the default for this task.
default_postprocess: :class:`.Postprocess` to use as the default for this task.
preprocess: :class:`.Preprocess` to use as the default for this task.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
postprocess: :class:`.Postprocess` to use as the default for this task.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
@@ -71,8 +71,8 @@ def __init__(
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 5e-5,
default_preprocess: Preprocess = None,
default_postprocess: Postprocess = None,
preprocess: Preprocess = None,
postprocess: Postprocess = None,
):
super().__init__()
if model is not None:
@@ -84,8 +84,8 @@ def __init__(
# TODO: should we save more? Bug on some regarding yaml if we save metrics
self.save_hyperparameters("learning_rate", "optimizer")

self._preprocess = default_preprocess
self._postprocess = default_postprocess
self._preprocess = preprocess
self._postprocess = postprocess

def step(self, batch: Any, batch_idx: int) -> Any:
"""
@@ -99,8 +99,7 @@ def step(self, batch: Any, batch_idx: int) -> Any:
y_hat = self.to_metrics_format(y_hat)
for name, metric in self.metrics.items():
if isinstance(metric, torchmetrics.metric.Metric):
metric(y_hat, y)
logs[name] = metric # log the metric itself if it is of type Metric
logs[name] = metric(y_hat, y) # log the metric itself if it is of type Metric
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
else:
logs[name] = metric(y_hat, y)
logs.update(losses)
2 changes: 1 addition & 1 deletion flash/data/data_pipeline.py
Original file line number Diff line number Diff line change
@@ -359,7 +359,7 @@ def _create_uncollate_postprocessors(self, stage: RunningStage) -> _PostProcesso
# since postprocessing is exclusive for prediction, we don't have to check the resolution hierarchy here.
if postprocess._save_path:
save_per_sample: bool = self._is_overriden_recursive(
"save_sample", postprocess, object_type=Postprocess, prefix=_STAGES_PREFIX[stage]
"save_sample", postprocess, Postprocess, prefix=_STAGES_PREFIX[stage]
)

if save_per_sample:
23 changes: 20 additions & 3 deletions flash/vision/classification/model.py
Original file line number Diff line number Diff line change
@@ -24,6 +24,10 @@
from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES


def binary_cross_entropy_with_logits(x, y):
tchaton marked this conversation as resolved.
Show resolved Hide resolved
return F.binary_cross_entropy_with_logits(x, y.float())


class ImageClassifier(ClassificationTask):
"""Task that classifies images.
@@ -57,6 +61,7 @@ class ImageClassifier(ClassificationTask):
metrics: Metrics to compute for training and evaluation,
defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
multi_label: Whether the labels are multi labels or not.
"""

backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES
@@ -68,17 +73,26 @@ def __init__(
backbone_kwargs: Optional[Dict] = None,
head: Optional[Union[FunctionType, nn.Module]] = None,
pretrained: bool = True,
loss_fn: Callable = F.cross_entropy,
loss_fn: Optional[Callable] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
metrics: Union[Callable, Mapping, Sequence, None] = Accuracy(),
metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None,
learning_rate: float = 1e-3,
multi_label: bool = False,
):

if metrics is None:
metrics = Accuracy(subset_accuracy=multi_label)

if loss_fn is None:
loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy

super().__init__(
model=None,
loss_fn=loss_fn,
optimizer=optimizer,
metrics=metrics,
learning_rate=learning_rate,
postprocess=self.postprocess_cls(multi_label)
)

self.save_hyperparameters()
@@ -100,4 +114,7 @@ def __init__(

def forward(self, x) -> Any:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
x = self.backbone(x)
return torch.softmax(self.head(x), -1)
if self.hparams.multi_label:
return self.head(x)
else:
return torch.softmax(self.head(x), -1)
2 changes: 1 addition & 1 deletion flash/vision/embedding/model.py
Original file line number Diff line number Diff line change
@@ -66,7 +66,7 @@ def __init__(
optimizer=optimizer,
metrics=metrics,
learning_rate=learning_rate,
default_preprocess=ImageClassificationPreprocess(
preprocess=ImageClassificationPreprocess(
predict_transform=ImageClassificationData.default_val_transforms(),
)
)
25 changes: 25 additions & 0 deletions tests/vision/classification/test_model.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,18 @@ def __len__(self) -> int:
return 100


class DummyMultiLabelDataset(torch.utils.data.Dataset):

def __init__(self, num_classes: int):
self.num_classes = num_classes

def __getitem__(self, index):
return torch.rand(3, 224, 224), torch.randint(0, 2, (self.num_classes, ))

def __len__(self) -> int:
return 100


# ==============================


@@ -67,3 +79,16 @@ def test_unfreeze():
model.unfreeze()
for p in model.backbone.parameters():
assert p.requires_grad is True


def test_multilabel(tmpdir):

num_classes = 4
ds = DummyMultiLabelDataset(num_classes)
model = ImageClassifier(num_classes, multi_label=True)
train_dl = torch.utils.data.DataLoader(ds, batch_size=2)
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
image, _ = ds[0]
tchaton marked this conversation as resolved.
Show resolved Hide resolved
predictions = model.predict(image.unsqueeze(0))
assert len(predictions[0]) == num_classes