[feat] Add multi label support (#230)
* add multilabel

* change types

* add a check

* resolve a bug

* update on comments

* update
tchaton authored Apr 20, 2021
1 parent 781fa98 commit 42f8db4
Showing 6 changed files with 86 additions and 28 deletions.
23 changes: 17 additions & 6 deletions flash/core/classification.py
@@ -11,25 +11,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, List, Optional

 import torch
+import torch.nn.functional as F

 from flash.core.model import Task
-from flash.data.process import Postprocess
+from flash.data.process import Postprocess, Preprocess


 class ClassificationPostprocess(Postprocess):

-    def per_sample_transform(self, samples: Any) -> Any:
-        return torch.argmax(samples, -1).tolist()
+    def __init__(self, multi_label: bool = False, save_path: Optional[str] = None):
+        super().__init__(save_path=save_path)
+        self.multi_label = multi_label
+
+    def per_sample_transform(self, samples: Any) -> List[Any]:
+        if self.multi_label:
+            return F.sigmoid(samples).tolist()
+        else:
+            return torch.argmax(samples, -1).tolist()


 class ClassificationTask(Task):

-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, default_postprocess=ClassificationPostprocess(), **kwargs)
+    postprocess_cls = ClassificationPostprocess

+    def __init__(self, *args, postprocess: Optional[Preprocess] = None, **kwargs):
+        super().__init__(*args, postprocess=postprocess or self.postprocess_cls(), **kwargs)

     def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
+        if getattr(self.hparams, "multi_label", False):
+            return F.sigmoid(x)
         return F.softmax(x, -1)
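
For orientation, a minimal sketch (not part of the commit) of what the two branches of the new per_sample_transform return for a batch of logits; the commit's F.sigmoid is a deprecated alias of torch.sigmoid, which is used here:

import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])  # one sample, three classes

# multi_label=True: independent per-class probabilities in [0, 1]
print(torch.sigmoid(logits).tolist())    # approx. [[0.881, 0.269, 0.622]]

# multi_label=False: a single predicted class index per sample
print(torch.argmax(logits, -1).tolist())  # [0]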
34 changes: 19 additions & 15 deletions flash/core/model.py
@@ -60,8 +60,8 @@ class Task(LightningModule):
         optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
         metrics: Metrics to compute for training and evaluation.
         learning_rate: Learning rate to use for training, defaults to `5e-5`.
-        default_preprocess: :class:`.Preprocess` to use as the default for this task.
-        default_postprocess: :class:`.Postprocess` to use as the default for this task.
+        preprocess: :class:`~flash.data.process.Preprocess` to use as the default for this task.
+        postprocess: :class:`~flash.data.process.Postprocess` to use as the default for this task.
     """

     def __init__(
@@ -71,8 +71,8 @@ def __init__(
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
         metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
         learning_rate: float = 5e-5,
-        default_preprocess: Preprocess = None,
-        default_postprocess: Postprocess = None,
+        preprocess: Preprocess = None,
+        postprocess: Postprocess = None,
     ):
         super().__init__()
         if model is not None:
@@ -84,8 +84,8 @@ def __init__(
         # TODO: should we save more? Bug on some regarding yaml if we save metrics
         self.save_hyperparameters("learning_rate", "optimizer")

-        self._preprocess = default_preprocess
-        self._postprocess = default_postprocess
+        self._preprocess = preprocess
+        self._postprocess = postprocess

     def step(self, batch: Any, batch_idx: int) -> Any:
         """
@@ -181,17 +181,19 @@ def _resolve(
         new_preprocess: Optional[Preprocess],
         new_postprocess: Optional[Postprocess],
     ) -> Tuple[Optional[Preprocess], Optional[Postprocess]]:
-        """Resolves the correct :class:`.Preprocess` and :class:`.Postprocess` to use, choosing ``new_*`` if it is not
-        None or a base class (:class:`.Preprocess` or :class:`.Postprocess`) and ``old_*`` otherwise.
+        """Resolves the correct :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
+        to use, choosing ``new_*`` if it is not None or a base class
+        (:class:`~flash.data.process.Preprocess` or :class:`~flash.data.process.Postprocess`)
+        and ``old_*`` otherwise.

         Args:
-            old_preprocess: :class:`.Preprocess` to be overridden.
-            old_postprocess: :class:`.Postprocess` to be overridden.
-            new_preprocess: :class:`.Preprocess` to override with.
-            new_postprocess: :class:`.Postprocess` to override with.
+            old_preprocess: :class:`~flash.data.process.Preprocess` to be overridden.
+            old_postprocess: :class:`~flash.data.process.Postprocess` to be overridden.
+            new_preprocess: :class:`~flash.data.process.Preprocess` to override with.
+            new_postprocess: :class:`~flash.data.process.Postprocess` to override with.

         Returns:
-            The resolved :class:`.Preprocess` and :class:`.Postprocess`.
+            The resolved :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.
         """
         preprocess = old_preprocess
         if new_preprocess is not None and type(new_preprocess) != Preprocess:
@@ -204,7 +206,8 @@ def _resolve(
         return preprocess, postprocess

     def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
-        """Build a :class:`.DataPipeline` incorporating available :class:`.Preprocess` and :class:`.Postprocess`
+        """Build a :class:`.DataPipeline` incorporating available
+        :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
         objects. These will be overridden in the following resolution order (lowest priority first):

         - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`.
@@ -213,7 +216,8 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
         - :class:`.DataPipeline` passed to this method.

         Args:
-            data_pipeline: Optional highest priority source of :class:`.Preprocess` and :class:`.Postprocess`.
+            data_pipeline: Optional highest priority source of
+                :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`.

         Returns:
             The fully resolved :class:`.DataPipeline`.
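
As a standalone illustration (my sketch, not code from the commit), the override rule that _resolve applies to each preprocess/postprocess pair boils down to:

from typing import Optional, TypeVar

T = TypeVar("T")

def resolve(old: Optional[T], new: Optional[T], base: type) -> Optional[T]:
    # Prefer `new` only when it is provided and is not the bare base class,
    # i.e. it is an instance of a genuine Preprocess/Postprocess subclass.
    if new is not None and type(new) != base:
        return new
    return old

Under this rule, resolve(old, ClassificationPostprocess(multi_label=True), Postprocess) picks the new postprocess, while resolve(old, Postprocess(), Postprocess) keeps the old one, matching the docstring above.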
2 changes: 1 addition & 1 deletion flash/data/data_pipeline.py
@@ -359,7 +359,7 @@ def _create_uncollate_postprocessors(self, stage: RunningStage) -> _PostProcessor:
         # since postprocessing is exclusive for prediction, we don't have to check the resolution hierarchy here.
         if postprocess._save_path:
             save_per_sample: bool = self._is_overriden_recursive(
-                "save_sample", postprocess, object_type=Postprocess, prefix=_STAGES_PREFIX[stage]
+                "save_sample", postprocess, Postprocess, prefix=_STAGES_PREFIX[stage]
             )

             if save_per_sample:
23 changes: 19 additions & 4 deletions flash/vision/classification/model.py
@@ -24,6 +24,11 @@
 from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES


+def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """Calls BCE with logits and casts the one-hot target encoding (y) to floating point precision."""
+    return F.binary_cross_entropy_with_logits(x, y.float())
+
+
 class ImageClassifier(ClassificationTask):
     """Task that classifies images.
@@ -57,6 +62,7 @@ class ImageClassifier(ClassificationTask):
         metrics: Metrics to compute for training and evaluation,
             defaults to :class:`torchmetrics.Accuracy`.
         learning_rate: Learning rate to use for training, defaults to ``1e-3``.
+        multi_label: Whether the targets are multi-label (a sample can belong to several classes at once).
     """

     backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES
@@ -68,17 +74,26 @@ def __init__(
         backbone_kwargs: Optional[Dict] = None,
         head: Optional[Union[FunctionType, nn.Module]] = None,
         pretrained: bool = True,
-        loss_fn: Callable = F.cross_entropy,
+        loss_fn: Optional[Callable] = None,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
-        metrics: Union[Callable, Mapping, Sequence, None] = Accuracy(),
+        metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None,
         learning_rate: float = 1e-3,
+        multi_label: bool = False,
     ):
+
+        if metrics is None:
+            metrics = Accuracy(subset_accuracy=multi_label)
+
+        if loss_fn is None:
+            loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
+
         super().__init__(
             model=None,
             loss_fn=loss_fn,
             optimizer=optimizer,
             metrics=metrics,
             learning_rate=learning_rate,
+            postprocess=self.postprocess_cls(multi_label)
         )

         self.save_hyperparameters()
@@ -98,6 +113,6 @@ def __init__(
             nn.Linear(num_features, num_classes),
         )

-    def forward(self, x) -> Any:
+    def forward(self, x) -> torch.Tensor:
         x = self.backbone(x)
-        return torch.softmax(self.head(x), -1)
+        return self.head(x)
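
Taken together, a usage sketch of the new defaults (the import path is an assumption based on this release's package layout):

from flash.vision import ImageClassifier

# With multi_label=True the task now defaults, unless overridden, to:
#   loss_fn     -> binary_cross_entropy_with_logits (targets cast to float, which
#                  F.binary_cross_entropy_with_logits requires)
#   metrics     -> Accuracy(subset_accuracy=True)
#   postprocess -> ClassificationPostprocess(multi_label=True), i.e. sigmoid outputs
model = ImageClassifier(num_classes=4, multi_label=True)

Note that forward now returns raw logits rather than softmax probabilities: the activation moves into the loss, to_metrics_format, and the postprocess, so it is applied exactly once and with the right function for each consumer.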
4 changes: 2 additions & 2 deletions flash/vision/embedding/model.py
@@ -66,7 +66,7 @@ def __init__(
             optimizer=optimizer,
             metrics=metrics,
             learning_rate=learning_rate,
-            default_preprocess=ImageClassificationPreprocess(
+            preprocess=ImageClassificationPreprocess(
                 predict_transform=ImageClassificationData.default_val_transforms(),
             )
         )
@@ -98,7 +98,7 @@ def apply_pool(self, x):
         x = self.pooling_fn(x, dim=-1)
         return x

-    def forward(self, x) -> Any:
+    def forward(self, x) -> torch.Tensor:
         x = self.backbone(x)

         # bolts ssl models return lists
28 changes: 28 additions & 0 deletions tests/vision/classification/test_model.py
@@ -30,6 +30,18 @@ def __len__(self) -> int:
         return 100


+class DummyMultiLabelDataset(torch.utils.data.Dataset):
+
+    def __init__(self, num_classes: int):
+        self.num_classes = num_classes
+
+    def __getitem__(self, index):
+        return torch.rand(3, 224, 224), torch.randint(0, 2, (self.num_classes, ))
+
+    def __len__(self) -> int:
+        return 100
+
+
 # ==============================

@@ -67,3 +79,19 @@ def test_unfreeze():
     model.unfreeze()
     for p in model.backbone.parameters():
         assert p.requires_grad is True
+
+
+def test_multilabel(tmpdir):
+
+    num_classes = 4
+    ds = DummyMultiLabelDataset(num_classes)
+    model = ImageClassifier(num_classes, multi_label=True)
+    train_dl = torch.utils.data.DataLoader(ds, batch_size=2)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
+    image, label = ds[0]
+    predictions = model.predict(image.unsqueeze(0))
+    assert (torch.tensor(predictions) > 1).sum() == 0
+    assert (torch.tensor(predictions) < 0).sum() == 0
+    assert len(predictions[0]) == num_classes == len(label)
+    assert len(torch.unique(label)) <= 2
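
The assertions only check that the predictions are valid probabilities; turning them into label sets would still require thresholding, e.g. (a sketch, the 0.5 cutoff is an arbitrary assumption, not from the commit):

import torch

probs = torch.tensor([[0.88, 0.27, 0.62, 0.10]])  # e.g. sigmoid outputs from model.predict
labels = (probs > 0.5).int()                      # tensor([[1, 0, 1, 0]])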
