This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Add multi label support #230

Merged
merged 6 commits on Apr 20, 2021
23 changes: 17 additions & 6 deletions flash/core/classification.py
@@ -11,25 +11,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, List, Optional
 
 import torch
 import torch.nn.functional as F
 
 from flash.core.model import Task
-from flash.data.process import Postprocess
+from flash.data.process import Postprocess, Preprocess
 
 
 class ClassificationPostprocess(Postprocess):
 
-    def per_sample_transform(self, samples: Any) -> Any:
-        return torch.argmax(samples, -1).tolist()
+    def __init__(self, multi_label: bool = False, save_path: Optional[str] = None):
+        super().__init__(save_path=save_path)
+        self.multi_label = multi_label
+
+    def per_sample_transform(self, samples: Any) -> List[Any]:
+        if self.multi_label:
+            return F.sigmoid(samples).tolist()
+        else:
+            return torch.argmax(samples, -1).tolist()
 
 
 class ClassificationTask(Task):
 
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, default_postprocess=ClassificationPostprocess(), **kwargs)
+    postprocess_cls = ClassificationPostprocess
+
+    def __init__(self, *args, postprocess: Optional[Postprocess] = None, **kwargs):
+        super().__init__(*args, postprocess=postprocess or self.postprocess_cls(), **kwargs)
+
+    def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
+        if getattr(self.hparams, "multi_label", False):
+            return F.sigmoid(x)
+        return F.softmax(x, -1)
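
For illustration, a minimal sketch (not part of the diff) of what the two branches of the new per_sample_transform produce on a batch of raw logits; torch.sigmoid is used here and is equivalent to the F.sigmoid call above:

    import torch

    logits = torch.tensor([[1.5, -0.5, 0.2],
                           [-1.0, 2.0, 0.3]])

    # Single-label mode: argmax over the class dimension, one index per sample.
    print(torch.argmax(logits, -1).tolist())  # [0, 1]

    # Multi-label mode: element-wise sigmoid, one independent probability per
    # class, so each sample keeps all num_classes scores.
    print(torch.sigmoid(logits).tolist())  # ~[[0.82, 0.38, 0.55], [0.27, 0.88, 0.57]]
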
34 changes: 19 additions & 15 deletions flash/core/model.py
@@ -60,8 +60,8 @@ class Task(LightningModule):
         optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
         metrics: Metrics to compute for training and evaluation.
         learning_rate: Learning rate to use for training, defaults to `5e-5`.
-        default_preprocess: :class:`.Preprocess` to use as the default for this task.
-        default_postprocess: :class:`.Postprocess` to use as the default for this task.
+        preprocess: :class:`~flash.data.process.Preprocess` to use as the default for this task.
+        postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task.
     """
 
     def __init__(
@@ -71,8 +71,8 @@ def __init__(
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
         metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
         learning_rate: float = 5e-5,
-        default_preprocess: Preprocess = None,
-        default_postprocess: Postprocess = None,
+        preprocess: Optional[Preprocess] = None,
+        postprocess: Optional[Postprocess] = None,
     ):
         super().__init__()
         if model is not None:
@@ -84,8 +84,8 @@ def __init__(
         # TODO: should we save more? Bug on some regarding yaml if we save metrics
         self.save_hyperparameters("learning_rate", "optimizer")
 
-        self._preprocess = default_preprocess
-        self._postprocess = default_postprocess
+        self._preprocess = preprocess
+        self._postprocess = postprocess
 
     def step(self, batch: Any, batch_idx: int) -> Any:
         """
@@ -181,17 +181,19 @@ def _resolve(
         new_preprocess: Optional[Preprocess],
         new_postprocess: Optional[Postprocess],
     ) -> Tuple[Optional[Preprocess], Optional[Postprocess]]:
-        """Resolves the correct :class:`.Preprocess` and :class:`.Postprocess` to use, choosing ``new_*`` if it is not
-        None or a base class (:class:`.Preprocess` or :class:`.Postprocess`) and ``old_*`` otherwise.
+        """Resolves the correct :class:`~flash.data.process.Preprocess` and
+        :class:`~flash.data.process.Postprocess` to use, choosing ``new_*`` if it is neither ``None``
+        nor a bare base class instance (:class:`~flash.data.process.Preprocess` or
+        :class:`~flash.data.process.Postprocess`), and ``old_*`` otherwise.
 
         Args:
-            old_preprocess: :class:`.Preprocess` to be overridden.
-            old_postprocess: :class:`.Postprocess` to be overridden.
-            new_preprocess: :class:`.Preprocess` to override with.
-            new_postprocess: :class:`.Postprocess` to override with.
+            old_preprocess: :class:`~flash.data.process.Preprocess` to be overridden.
+            old_postprocess: :class:`~flash.data.process.Postprocess` to be overridden.
+            new_preprocess: :class:`~flash.data.process.Preprocess` to override with.
+            new_postprocess: :class:`~flash.data.process.Postprocess` to override with.
 
         Returns:
-            The resolved :class:`.Preprocess` and :class:`.Postprocess`.
+            The resolved :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.
         """
         preprocess = old_preprocess
         if new_preprocess is not None and type(new_preprocess) != Preprocess:
@@ -204,7 +206,8 @@ def _resolve(
         return preprocess, postprocess
 
     def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
-        """Build a :class:`.DataPipeline` incorporating available :class:`.Preprocess` and :class:`.Postprocess`
+        """Build a :class:`.DataPipeline` incorporating available
+        :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
         objects. These will be overridden in the following resolution order (lowest priority first):
 
         - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
@@ -213,7 +216,8 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
         - :class:`.DataPipeline` passed to this method.
 
         Args:
-            data_pipeline: Optional highest priority source of :class:`.Preprocess` and :class:`.Postprocess`.
+            data_pipeline: Optional highest priority source of
+                :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.
 
         Returns:
             The fully resolved :class:`.DataPipeline`.
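
For illustration, a minimal sketch of how the renamed ``postprocess`` argument interacts with ``_resolve``; the ``TopKPostprocess`` subclass and the tiny model are hypothetical:

    import torch
    from torch import nn

    from flash.core.classification import ClassificationTask
    from flash.data.process import Postprocess

    class TopKPostprocess(Postprocess):
        # Hypothetical subclass: return the two highest-scoring class indices.
        def per_sample_transform(self, samples):
            return torch.topk(samples, k=2, dim=-1).indices.tolist()

    # A proper subclass instance replaces the task default, because `_resolve`
    # only discards `new_*` objects whose type is exactly the base class.
    task = ClassificationTask(nn.Linear(16, 4), postprocess=TopKPostprocess())

    # By contrast, a bare `Postprocess()` would be filtered out by the
    # `type(new_postprocess) != Postprocess` check and the default kept.
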
2 changes: 1 addition & 1 deletion flash/data/data_pipeline.py
@@ -359,7 +359,7 @@ def _create_uncollate_postprocessors(self, stage: RunningStage) -> _PostProcesso
         # since postprocessing is exclusive for prediction, we don't have to check the resolution hierarchy here.
         if postprocess._save_path:
             save_per_sample: bool = self._is_overriden_recursive(
-                "save_sample", postprocess, object_type=Postprocess, prefix=_STAGES_PREFIX[stage]
+                "save_sample", postprocess, Postprocess, prefix=_STAGES_PREFIX[stage]
             )
 
             if save_per_sample:
23 changes: 19 additions & 4 deletions flash/vision/classification/model.py
@@ -24,6 +24,11 @@
 from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES
 
 
+def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Calls BCE with logits and casts the one-hot target encoding (y) to floating point precision."""
+    return F.binary_cross_entropy_with_logits(x, y.float())
+
+
 class ImageClassifier(ClassificationTask):
     """Task that classifies images.
 
@@ -57,6 +62,7 @@ class ImageClassifier(ClassificationTask):
         metrics: Metrics to compute for training and evaluation,
             defaults to :class:`torchmetrics.Accuracy`.
         learning_rate: Learning rate to use for training, defaults to ``1e-3``.
+        multi_label: Whether the targets are multi-label, i.e. each sample can belong to several classes.
     """
 
     backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES
 
@@ -68,17 +74,26 @@ def __init__(
         backbone_kwargs: Optional[Dict] = None,
         head: Optional[Union[FunctionType, nn.Module]] = None,
         pretrained: bool = True,
-        loss_fn: Callable = F.cross_entropy,
+        loss_fn: Optional[Callable] = None,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
-        metrics: Union[Callable, Mapping, Sequence, None] = Accuracy(),
+        metrics: Optional[Union[Callable, Mapping, Sequence]] = None,
         learning_rate: float = 1e-3,
+        multi_label: bool = False,
     ):
 
+        if metrics is None:
+            metrics = Accuracy(subset_accuracy=multi_label)
+
+        if loss_fn is None:
+            loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
+
         super().__init__(
             model=None,
             loss_fn=loss_fn,
             optimizer=optimizer,
             metrics=metrics,
             learning_rate=learning_rate,
+            postprocess=self.postprocess_cls(multi_label)
         )
 
         self.save_hyperparameters()
@@ -98,6 +113,6 @@ def __init__(
             nn.Linear(num_features, num_classes),
         )
 
-    def forward(self, x) -> Any:
+    def forward(self, x) -> torch.Tensor:
         x = self.backbone(x)
-        return torch.softmax(self.head(x), -1)
+        return self.head(x)
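
For illustration, a minimal sketch of the new defaults (assuming the ``from flash.vision import ImageClassifier`` import path used elsewhere in this repository):

    from flash.vision import ImageClassifier

    # multi_label=True: defaults to the BCE-with-logits helper above,
    # Accuracy(subset_accuracy=True), and a sigmoid applied in postprocessing.
    multi = ImageClassifier(num_classes=4, multi_label=True)

    # Default single-label: F.cross_entropy, Accuracy(), argmax in postprocessing.
    single = ImageClassifier(num_classes=4)

Because ``forward`` now returns raw logits instead of softmax probabilities, both default losses receive the unsquashed values they expect; probabilities are only produced in ``to_metrics_format`` and in the postprocess.
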
4 changes: 2 additions & 2 deletions flash/vision/embedding/model.py
@@ -66,7 +66,7 @@ def __init__(
             optimizer=optimizer,
             metrics=metrics,
             learning_rate=learning_rate,
-            default_preprocess=ImageClassificationPreprocess(
+            preprocess=ImageClassificationPreprocess(
                 predict_transform=ImageClassificationData.default_val_transforms(),
             )
         )
@@ -98,7 +98,7 @@ def apply_pool(self, x):
         x = self.pooling_fn(x, dim=-1)
         return x
 
-    def forward(self, x) -> Any:
+    def forward(self, x) -> torch.Tensor:
         x = self.backbone(x)
 
         # bolts ssl models return lists
28 changes: 28 additions & 0 deletions tests/vision/classification/test_model.py
@@ -30,6 +30,18 @@ def __len__(self) -> int:
         return 100
 
 
+class DummyMultiLabelDataset(torch.utils.data.Dataset):
+
+    def __init__(self, num_classes: int):
+        self.num_classes = num_classes
+
+    def __getitem__(self, index):
+        return torch.rand(3, 224, 224), torch.randint(0, 2, (self.num_classes, ))
+
+    def __len__(self) -> int:
+        return 100
+
+
 # ==============================
 
 
@@ -67,3 +79,19 @@ def test_unfreeze():
     model.unfreeze()
     for p in model.backbone.parameters():
         assert p.requires_grad is True
+
+
+def test_multilabel(tmpdir):
+
+    num_classes = 4
+    ds = DummyMultiLabelDataset(num_classes)
+    model = ImageClassifier(num_classes, multi_label=True)
+    train_dl = torch.utils.data.DataLoader(ds, batch_size=2)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
+    image, label = ds[0]
+    predictions = model.predict(image.unsqueeze(0))
+    assert (torch.tensor(predictions) > 1).sum() == 0
+    assert (torch.tensor(predictions) < 0).sum() == 0
+    assert len(predictions[0]) == num_classes == len(label)
+    assert len(torch.unique(label)) <= 2
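
As a usage note (not part of the diff): the multi-label predictions returned above are sigmoid scores in [0, 1], so turning them into hard labels still requires a threshold; the conventional 0.5 cutoff is a choice, not something this PR fixes:

    import torch

    # `scores` stands in for `model.predict(image.unsqueeze(0))` from the test above.
    scores = torch.tensor([[0.82, 0.11, 0.47, 0.91]])
    print((scores > 0.5).int().tolist())  # [[1, 0, 0, 1]]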