diff --git a/flash/core/classification.py b/flash/core/classification.py
index 970466dbf4..9650eadc34 100644
--- a/flash/core/classification.py
+++ b/flash/core/classification.py
@@ -11,25 +11,36 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, List, Optional
 
 import torch
 import torch.nn.functional as F
 
 from flash.core.model import Task
 from flash.data.process import Postprocess
 
 
 class ClassificationPostprocess(Postprocess):
 
-    def per_sample_transform(self, samples: Any) -> Any:
-        return torch.argmax(samples, -1).tolist()
+    def __init__(self, multi_label: bool = False, save_path: Optional[str] = None):
+        super().__init__(save_path=save_path)
+        self.multi_label = multi_label
+
+    def per_sample_transform(self, samples: Any) -> List[Any]:
+        if self.multi_label:
+            return torch.sigmoid(samples).tolist()
+        else:
+            return torch.argmax(samples, -1).tolist()
 
 
 class ClassificationTask(Task):
 
-    def __init__(self, *args, **kwargs) -> None:
-        super().__init__(*args, default_postprocess=ClassificationPostprocess(), **kwargs)
+    postprocess_cls = ClassificationPostprocess
+
+    def __init__(self, *args, postprocess: Optional[Postprocess] = None, **kwargs) -> None:
+        super().__init__(*args, postprocess=postprocess or self.postprocess_cls(), **kwargs)
 
     def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
+        if getattr(self.hparams, "multi_label", False):
+            return torch.sigmoid(x)
         return F.softmax(x, -1)
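
For reviewers, a minimal sketch of what the reworked `ClassificationPostprocess` produces. The logits here are made-up values, and the transform is called directly on a batch tensor for illustration (in the pipeline it runs per sample after uncollation):

```python
import torch

from flash.core.classification import ClassificationPostprocess

logits = torch.tensor([[1.2, -0.3, 0.5], [-2.0, 0.1, 3.0]])  # hypothetical (batch=2, num_classes=3)

# Single-label (default): argmax over the class dimension -> one class index per sample.
assert ClassificationPostprocess().per_sample_transform(logits) == [0, 2]

# Multi-label: an independent sigmoid per class -> one probability per class per sample.
probs = ClassificationPostprocess(multi_label=True).per_sample_transform(logits)
assert len(probs[0]) == 3 and all(0.0 <= p <= 1.0 for p in probs[0])
```
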
+ """Resolves the correct :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` to use, + choosing ``new_*`` if it is not None or a base class + (:class:`~flash.data.process.Preprocess` or :class:`~flash.data.process.Postprocess`) + and ``old_*`` otherwise. Args: - old_preprocess: :class:`.Preprocess` to be overridden. - old_postprocess: :class:`.Postprocess` to be overridden. - new_preprocess: :class:`.Preprocess` to override with. - new_postprocess: :class:`.Postprocess` to override with. + old_preprocess: :class:`~flash.data.process.Preprocess` to be overridden. + old_postprocess: :class:`~flash.data.process.Postprocess` to be overridden. + new_preprocess: :class:`~flash.data.process.Preprocess` to override with. + new_postprocess: :class:`~flash.data.process.Postprocess` to override with. Returns: - The resolved :class:`.Preprocess` and :class:`.Postprocess`. + The resolved :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`. """ preprocess = old_preprocess if new_preprocess is not None and type(new_preprocess) != Preprocess: @@ -204,7 +206,8 @@ def _resolve( return preprocess, postprocess def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]: - """Build a :class:`.DataPipeline` incorporating available :class:`.Preprocess` and :class:`.Postprocess` + """Build a :class:`.DataPipeline` incorporating available + :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` objects. These will be overridden in the following resolution order (lowest priority first): - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`. @@ -213,7 +216,8 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O - :class:`.DataPipeline` passed to this method. Args: - data_pipeline: Optional highest priority source of :class:`.Preprocess` and :class:`.Postprocess`. + data_pipeline: Optional highest priority source of + :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`. Returns: The fully resolved :class:`.DataPipeline`. diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 2125b79909..46be3c823c 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -359,7 +359,7 @@ def _create_uncollate_postprocessors(self, stage: RunningStage) -> _PostProcesso # since postprocessing is exclusive for prediction, we don't have to check the resolution hierarchy here. if postprocess._save_path: save_per_sample: bool = self._is_overriden_recursive( - "save_sample", postprocess, object_type=Postprocess, prefix=_STAGES_PREFIX[stage] + "save_sample", postprocess, Postprocess, prefix=_STAGES_PREFIX[stage] ) if save_per_sample: diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py index 7a046f52e5..5b6d9dca30 100644 --- a/flash/vision/classification/model.py +++ b/flash/vision/classification/model.py @@ -24,6 +24,11 @@ from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES +def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Calls BCE with logits and cast the target one_hot (y) encoding to floating point precision.""" + return F.binary_cross_entropy_with_logits(x, y.float()) + + class ImageClassifier(ClassificationTask): """Task that classifies images. 
diff --git a/tests/vision/classification/test_model.py b/tests/vision/classification/test_model.py
index 5079c332d8..067fa994b8 100644
--- a/tests/vision/classification/test_model.py
+++ b/tests/vision/classification/test_model.py
@@ -30,6 +30,18 @@ def __len__(self) -> int:
         return 100
 
 
+class DummyMultiLabelDataset(torch.utils.data.Dataset):
+
+    def __init__(self, num_classes: int):
+        self.num_classes = num_classes
+
+    def __getitem__(self, index):
+        return torch.rand(3, 224, 224), torch.randint(0, 2, (self.num_classes,))
+
+    def __len__(self) -> int:
+        return 100
+
+
 # ==============================
@@ -67,3 +79,19 @@ def test_unfreeze():
     model.unfreeze()
     for p in model.backbone.parameters():
         assert p.requires_grad is True
+
+
+def test_multilabel(tmpdir):
+
+    num_classes = 4
+    ds = DummyMultiLabelDataset(num_classes)
+    model = ImageClassifier(num_classes, multi_label=True)
+    train_dl = torch.utils.data.DataLoader(ds, batch_size=2)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
+    image, label = ds[0]
+    predictions = model.predict(image.unsqueeze(0))
+    assert (torch.tensor(predictions) > 1).sum() == 0
+    assert (torch.tensor(predictions) < 0).sum() == 0
+    assert len(predictions[0]) == num_classes == len(label)
+    assert len(torch.unique(label)) <= 2
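
As the test asserts, multi-label predictions come back as per-class sigmoid probabilities in [0, 1]. One way to turn them into discrete label sets, shown below with made-up values; the 0.5 threshold is a common convention, not something this diff prescribes:

```python
import torch

predictions = [[0.91, 0.12, 0.55, 0.03]]  # hypothetical model.predict(...) output for one sample

# Threshold each per-class probability at 0.5 (an assumed convention) to recover
# the set of predicted class indices for each sample.
labels = [(torch.tensor(p) > 0.5).nonzero(as_tuple=True)[0].tolist() for p in predictions]
assert labels == [[0, 2]]
```
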