Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add FlashRegistry of Available Heads for `flash.image.ImageClassifi…
Browse files Browse the repository at this point in the history
…er` (#1152)

Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
ajndkr and ethanwharris authored Feb 4, 2022
1 parent 8e4abf3 commit 2ca83bc
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `SemanticSegmentationData.from_folders` where mask files have different extensions to the image files ([#1130](https://github.com/PyTorchLightning/lightning-flash/pull/1130))

- Added `FlashRegistry` of Available Heads for `flash.image.ImageClassifier` ([#1152](https://github.com/PyTorchLightning/lightning-flash/pull/1152))

### Changed

- Changed `Wav2Vec2Processor` to `AutoProcessor` and seperate it from backbone [optional] ([#1075](https://github.com/PyTorchLightning/lightning-flash/pull/1075))
Expand Down
10 changes: 7 additions & 3 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,10 +833,12 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self._data_pipeline_state = checkpoint["_data_pipeline_state"]

@classmethod
def available_backbones(cls, head: Optional[str] = None) -> Union[Dict[str, List[str]], List[str]]:
def available_backbones(
cls, head: Optional[str] = None
) -> Optional[Union[Dict[str, Optional[List[str]]], List[str]]]:
if head is None:
registry: Optional[FlashRegistry] = getattr(cls, "backbones", None)
if registry is not None:
if registry is not None and getattr(cls, "heads", None) is None:
return registry.available_keys()
heads = cls.available_heads()
else:
Expand All @@ -848,7 +850,9 @@ def available_backbones(cls, head: Optional[str] = None) -> Union[Dict[str, List
if "backbones" in metadata:
backbones = metadata["backbones"].available_keys()
else:
backbones = cls.available_backbones()
backbones = getattr(cls, "backbones", None)
if backbones is not None:
backbones = backbones.available_keys()
result[head] = backbones

if len(result) == 1:
Expand Down
40 changes: 40 additions & 0 deletions flash/image/classification/heads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from torch import nn

from flash.core.registry import FlashRegistry # noqa: F401

# define ImageClassifier registry
IMAGE_CLASSIFIER_HEADS = FlashRegistry("classifier_heads")


def _load_linear_head(num_features: int, num_classes: int) -> nn.Module:
"""Loads a linear head.
Args:
num_features: Number of input features.
num_classes: Number of output classes.
Returns:
nn.Module: Linear head.
"""
return nn.Linear(num_features, num_classes)


IMAGE_CLASSIFIER_HEADS(
partial(_load_linear_head),
name="linear",
)
14 changes: 9 additions & 5 deletions flash/image/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from flash.image.classification.adapters import TRAINING_STRATEGIES
from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES
from flash.image.classification.heads import IMAGE_CLASSIFIER_HEADS
from flash.image.classification.input_transform import ImageClassificationInputTransform
from flash.image.data import ImageDeserializer

Expand Down Expand Up @@ -61,6 +62,8 @@ def fn_resnet(pretrained: bool = True):
Args:
num_classes: Number of classes to classify.
backbone: A string or (model, num_features) tuple to use to compute image features, defaults to ``"resnet18"``.
head: A string from ``ImageClassifier.available_heads()``, an ``nn.Module``, or a function of (``num_features``,
``num_classes``) which returns an ``nn.Module`` to use as the model head.
pretrained: A bool or string to specify the pretrained weights of the backbone, defaults to ``True``
which loads the default supervised pretrained weights.
loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`.
Expand All @@ -79,6 +82,7 @@ def fn_resnet(pretrained: bool = True):
"""

backbones: FlashRegistry = IMAGE_CLASSIFIER_BACKBONES
heads: FlashRegistry = IMAGE_CLASSIFIER_HEADS
training_strategies: FlashRegistry = TRAINING_STRATEGIES
required_extras: str = "image"

Expand All @@ -87,7 +91,7 @@ def __init__(
num_classes: Optional[int] = None,
backbone: Union[str, Tuple[nn.Module, int]] = "resnet18",
backbone_kwargs: Optional[Dict] = None,
head: Optional[Union[FunctionType, nn.Module]] = None,
head: Union[str, FunctionType, nn.Module] = "linear",
pretrained: Union[bool, str] = True,
loss_fn: LOSS_FN_TYPE = None,
optimizer: OPTIMIZER_TYPE = "Adam",
Expand Down Expand Up @@ -123,10 +127,10 @@ def __init__(
else:
backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)

head = head(num_features, num_classes) if isinstance(head, FunctionType) else head
head = head or nn.Sequential(
nn.Linear(num_features, num_classes),
)
if isinstance(head, str):
head = self.heads.get(head)(num_features=num_features, num_classes=num_classes)
else:
head = head(num_features, num_classes) if isinstance(head, FunctionType) else head

adapter_from_class = self.training_strategies.get(training_strategy)
adapter = adapter_from_class(
Expand Down
2 changes: 1 addition & 1 deletion tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_available_backbones():
class Foo(ImageClassifier):
backbones = None

assert Foo.available_backbones() == {}
assert Foo.available_backbones() is None


@pytest.mark.skipif(_IMAGE_AVAILABLE, reason="image libraries are installed.")
Expand Down
9 changes: 9 additions & 0 deletions tests/image/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ def test_init_train(tmpdir, backbone, metrics):
trainer.finetune(model, train_dl, strategy="freeze")


@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
@pytest.mark.parametrize("head", ["linear", torch.nn.Linear(512, 10)])
def test_init_train_head(tmpdir, head):
model = ImageClassifier(10, backbone="resnet18", head=head, metrics=None)
train_dl = torch.utils.data.DataLoader(DummyDataset())
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.finetune(model, train_dl, strategy="freeze")


@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_non_existent_backbone():
with pytest.raises(KeyError):
Expand Down

0 comments on commit 2ca83bc

Please sign in to comment.