This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Add backbone API #204

Merged · 23 commits · Apr 7, 2021
Changes from 1 commit
23 changes: 4 additions & 19 deletions flash/core/model.py
@@ -259,23 +259,8 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         self.data_pipeline = checkpoint['data_pipeline']
 
     @classmethod
-    def available_backbones(cls, ) -> List:
+    def available_backbones(cls) -> List[str]:
         registry: Optional[FlashRegistry] = getattr(cls, "backbones", None)
-        return cls._available_models(registry=registry)
-
-    @staticmethod
-    def _available_models(registry: Optional[FlashRegistry] = None) -> List[str]:
-        if registry is not None:
-            return registry.available_keys()
-        return []
-
-    @staticmethod
-    def _register_function(
-        registry: Optional[FlashRegistry] = None,
-        fn: Optional[Callable] = None,
-        name: Optional[str] = None,
-        override: bool = False,
-        **metadata
-    ) -> Optional[Callable]:
-        if registry is not None:
-            return registry(fn=fn, name=name, override=override, **metadata)
+        if registry is None:
+            return []
+        return registry.available_keys()
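
A quick illustration of what the simplified classmethod gives you — a minimal sketch, not part of the diff, assuming a task class that attaches a FlashRegistry under the `backbones` class attribute (as ImageClassifier does further down):

from flash.vision import ImageClassifier

# Lists the names registered on the class's FlashRegistry;
# returns [] when the class has no `backbones` registry attached.
print(ImageClassifier.available_backbones())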
162 changes: 64 additions & 98 deletions flash/core/registry.py
@@ -1,12 +1,25 @@
-import hashlib
-from collections import defaultdict
-from dataclasses import dataclass
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from functools import partial
 from types import FunctionType
-from typing import Any, Callable, Dict, List, Optional, Set, Union
+from typing import Any, Dict, List, Optional, Union
 
 from pytorch_lightning.utilities import rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+_REGISTERED_FUNCTION = Dict[str, Any]
+
 
 class FlashRegistry:
     """
@@ -16,145 +29,102 @@ class FlashRegistry:
 
         backbones = FlashRegistry("backbones")

-        @backbones.register_function()
+        @backbones
         def my_model(nc_input=5, nc_output=6):
             return nn.Linear(nc_input, nc_output), nc_input, nc_output
 
         mlp, nc_input, nc_output = backbones("my_model")(nc_output=7)
 
-        backbones.register_function(my_model, name="cho")
-        assert backbones("cho")
+        backbones(my_model, name="foo")
+        assert backbones("foo")
 
     """
 
-    def __init__(self, registry_name: str, verbose: bool = False) -> None:
-        self._registry_name = registry_name
-        self._registered_functions: List[Dict[str, Any]] = []
+    def __init__(self, name: str, verbose: bool = False) -> None:
+        self.name = name
+        self.functions: List[_REGISTERED_FUNCTION] = []
         self._verbose = verbose
 
     def __len__(self) -> int:
-        return len(self._registered_functions)
+        return len(self.functions)
 
     def __contains__(self, key) -> bool:
-        return any(key == e["name"] for e in self._registered_functions)
+        return any(key == e["name"] for e in self.functions)
 
     def __repr__(self) -> str:
-        format_str = self.__class__.__name__ + \
-            f'(name={self._registry_name}, ' \
-            f'registered_items={self._registered_functions})'
-        return format_str
-
-    @property
-    def name(self) -> str:
-        return self._registry_name
-
-    @property
-    def registered_funcs(self) -> Dict[str, Any]:
-        return self._registered_functions
-
-    def validate_matches(self, key: str, matches: Dict, with_metadata: bool, key_in: bool = False):
-        if len(matches) == 1:
-            registered_function = matches[0]
-            if with_metadata:
-                return registered_function
-            return registered_function["fn"]
-        elif len(matches) == 0:
-            if key_in:
-                raise MisconfigurationException(
-                    f"Found {len(matches)} matches within {matches}. Add more metadata to filter them out."
-                )
-            raise MisconfigurationException(f"Key: {key} is not in {self.__repr__()}")
+        return f'{self.__class__.__name__}(name={self.name}, functions={self.functions})'
 
-    def get(self,
-            key: str,
-            with_metadata: bool = False,
-            strict: bool = True,
-            **metadata) -> Union[Callable, Dict[str, Any], List[Dict[str, Any]], List[Callable]]:
+    def get(
+        self,
+        key: str,
+        with_metadata: bool = False,
+        strict: bool = True,
+        **metadata: Dict[str, Any],
+    ) -> Union[callable, _REGISTERED_FUNCTION, List[_REGISTERED_FUNCTION], List[callable]]:
         """
         This function is used to gather matches from the registry:
 
         Args:
             key: Name of the registered function.
-            with_metadata: Whether to return associated metadata used during registration.
-            strict: Whether to return all matches if higher than 1.
-            metadata: All filtering metadata used for the registry.
-
+            with_metadata: Whether to include the associated metadata in the return value.
+            strict: Whether to return all matches or just one.
+            metadata: Metadata used to filter against existing registry item's metadata.
         """
-        matches = [e for e in self._registered_functions if key == e["name"]]
-        key_in = False
-        if len(matches) > 1:
-            key_in = True
-            matches = self._filter_matches_on_metadata(matches, with_metadata=with_metadata, metadata=metadata)
-            if not strict:
-                return [e if with_metadata else e["fn"] for e in matches]
-            if len(matches) > 1:
+        matches = [e for e in self.functions if key == e["name"]]
+        if not matches:
+            raise KeyError(f"Key: {key} is not in {repr(self)}")
+
+        if metadata:
+            matches = [m for m in matches if metadata.items() <= m["metadata"].items()]
+            if not matches:
                 raise MisconfigurationException(
-                    f"Found {len(matches)} matches within {matches}. Add more metadata to filter them out."
+                    "Found no matches that fit your metadata criteria. Try removing some metadata"
                 )
-        return self.validate_matches(key, matches, with_metadata, key_in=key_in)
-
-    def _filter_matches_on_metadata(self, matches, with_metadata: bool = False, **metadata) -> List[Dict[str, Any]]:
-        _matches = []
-        for item in matches:
-            if all(
-                self._extract_value_from_metadata(item["metadata"], k) == v for k, v in metadata["metadata"].items()
-            ):
-                _matches.append(item)
-        return _matches
+
+        matches = [e if with_metadata else e["fn"] for e in matches]
+        return matches[0] if strict else matches
 
     def remove(self, key: str) -> None:
-        _registered_functions = []
-        for item in self._registered_functions:
-            if item["name"] != key:
-                _registered_functions.append(item)
-        self._registered_functions = _registered_functions
+        self.functions = [f for f in self.functions if f["name"] != key]
 
     def _register_function(
-        self, fn: Callable, name: Optional[str] = None, override: bool = False, metadata: Dict[str, Any] = None
+        self, fn: callable, name: Optional[str] = None, override: bool = False, metadata: Dict[str, Any] = None
     ):
         if not isinstance(fn, FunctionType) and not isinstance(fn, partial):
-            raise MisconfigurationException("``register_function`` should be used with a function")
+            raise MisconfigurationException(f"You can only register a function, found: {fn}")
 
         name = name or fn.__name__
 
-        item = {"fn": fn, "name": name, "metadata": metadata}
-        if self._verbose:
-            rank_zero_info(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}")
-
-        matching_index = self._find_matching_index(item)
+        item = {"fn": fn, "name": name, "metadata": metadata or {}}
+
+        matching_index = self._find_matching_index(item)
         if override and matching_index is not None:
-            self._registered_functions[matching_index] = item
+            self.functions[matching_index] = item
         else:
             if matching_index is not None:
                 raise MisconfigurationException(
                     f"Function with name: {name} and metadata: {metadata} is already present within {self}."
-                    "HINT: Use `override=True`."
+                    " HINT: Use `override=True`."
                 )
-            self._registered_functions.append(item)
-
-    @staticmethod
-    def _extract_value_from_metadata(metadata: Dict, key: str) -> Optional[Any]:
-        if key in metadata:
-            return metadata[key]
-
-    def _find_matching_index(self, item: Dict[str, Any]) -> Optional[int]:
-        for idx, _item in enumerate(self._registered_functions):
-            if (
-                _item["fn"] == item["fn"] and _item["name"] == item["name"] and
-                all(self._extract_value_from_metadata(_item["metadata"], k) == v for k, v in item["metadata"].items())
-            ):
+            self.functions.append(item)
+
+    def _find_matching_index(self, item: _REGISTERED_FUNCTION) -> Optional[int]:
+        for idx, fn in enumerate(self.functions):
+            if fn["fn"] == item["fn"] and fn["name"] == item["name"] and fn["metadata"] == item["metadata"]:
                 return idx
 
     def __call__(
         self,
-        fn: Optional[Callable] = None,
+        fn: Optional[callable] = None,
         name: Optional[str] = None,
         override: bool = False,
-        **metadata
-    ) -> Callable:
+        **metadata: Dict[str, Any]
+    ) -> callable:
         """Register a callable"""
         if fn is not None:
+            if self._verbose:
+                print(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}")
             self._register_function(fn=fn, name=name, override=override, metadata=metadata)
             return fn
 
@@ -169,8 +139,4 @@ def _register(cls):
         return _register
 
     def available_keys(self) -> List[str]:
-        return sorted([v["name"] for v in self._registered_functions])
-
-
-IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones")
-OBJ_DETECTION_BACKBONES = FlashRegistry("backbones")
+        return sorted([v["name"] for v in self.functions])
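
To see the reworked registry end to end, here is a small self-contained sketch based only on the methods in this diff (the `namespace` metadata key is an arbitrary example, not something the registry prescribes):

from functools import partial

from flash.core.registry import FlashRegistry

backbones = FlashRegistry("backbones")

# Decorator-style registration: the registry instance itself is now callable.
@backbones
def my_model(nc_input=5, nc_output=6):
    return nc_input, nc_output

# Call-style registration of a partial, with an explicit name and metadata.
backbones(partial(my_model, nc_output=7), name="my_model", namespace="test")

fn = backbones.get("my_model")                    # strict lookup returns the first match
fn = backbones.get("my_model", namespace="test")  # metadata kwargs narrow the match
entries = backbones.get("my_model", strict=False, with_metadata=True)  # all records

assert "my_model" in backbones and len(backbones) == 2
backbones.remove("my_model")  # drops every entry registered under that name

Duplicate names are allowed as long as the (fn, name, metadata) triple differs; the metadata is what disambiguates entries at lookup time.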
2 changes: 1 addition & 1 deletion flash/data/data_pipeline.py
@@ -49,7 +49,7 @@ def __init__(self, ..., data, ...):
             def __getitem__(self, index):
                 return Preprocess.load_sample(self.preprocessed_data[index])
 
-            def __len__(self) -> int:
+            def __len__(self):
                 return len(self.preprocessed_data)
 
     2. Create a ``worker_collate_fn`` which is injected directly into the ``DataLoader``
1 change: 0 additions & 1 deletion flash/text/classification/model.py
@@ -21,7 +21,6 @@
 from transformers.modeling_outputs import SequenceClassifierOutput
 
 from flash.core.classification import ClassificationTask
-from flash.text.classification.data import TextClassificationData
 
 
 class TextClassifier(ClassificationTask):
35 changes: 21 additions & 14 deletions flash/vision/backbones.py
@@ -19,7 +19,7 @@
 from torch import nn as nn
 from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
 
-from flash.core.registry import IMAGE_CLASSIFIER_BACKBONES, OBJ_DETECTION_BACKBONES
+from flash.core.registry import FlashRegistry
 from flash.utils.imports import _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE
 
 if _TIMM_AVAILABLE:
@@ -40,20 +40,23 @@
 TORCHVISION_MODELS = MOBILENET_MODELS + VGG_MODELS + RESNET_MODELS + DENSENET_MODELS
 BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"]
 
+IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones")
+OBJ_DETECTION_BACKBONES = FlashRegistry("backbones")
+
 
 @IMAGE_CLASSIFIER_BACKBONES(name="simclr-imagenet", namespace="vision", package="bolts")
-def load_simclr_imagenet(
-    path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt", **__
-):
+def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt", **_):
     simclr: LightningModule = SimCLR.load_from_checkpoint(path_or_url, strict=False)
     # remove the last two layers & turn it into a Sequential model
     backbone = nn.Sequential(*list(simclr.encoder.children())[:-2])
     return backbone, 2048
 
 
 @IMAGE_CLASSIFIER_BACKBONES(name="swav-imagenet", namespace="vision", package="bolts")
-def load_swav_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar",
-                       **__) -> Tuple[nn.Module, int]:
+def load_swav_imagenet(
+    path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar",
+    **_,
+) -> Tuple[nn.Module, int]:
     swav: LightningModule = SwAV.load_from_checkpoint(path_or_url, strict=True)
     # remove the last two layers & turn it into a Sequential model
     backbone = nn.Sequential(*list(swav.model.children())[:-2])
@@ -96,10 +99,12 @@ def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
         type="resnet"
     )
 
-    def _fn_resnet_fpn(model_name: str,
-                       pretrained: bool = True,
-                       trainable_layers: bool = True,
-                       **kwargs) -> Tuple[nn.Module, int]:
+    def _fn_resnet_fpn(
+        model_name: str,
+        pretrained: bool = True,
+        trainable_layers: bool = True,
+        **kwargs,
+    ) -> Tuple[nn.Module, int]:
         backbone = resnet_fpn_backbone(
             model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs
         )
@@ -131,10 +136,12 @@ def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
     if model_name in TORCHVISION_MODELS:
         continue
 
-    def _fn_timm(model_name: str,
-                 pretrained: bool = True,
-                 num_classes: int = 0,
-                 global_pool: str = '') -> Tuple[nn.Module, int]:
+    def _fn_timm(
+        model_name: str,
+        pretrained: bool = True,
+        num_classes: int = 0,
+        global_pool: str = '',
+    ) -> Tuple[nn.Module, int]:
         backbone = timm.create_model(
             model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool
         )
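
Since the registries now live in this module, a consumer resolves a backbone factory by name, optionally filtered by the metadata attached in the decorators above — a sketch (calling the factory would download the bolts checkpoint):

from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES

# Plain or metadata-filtered lookup both resolve the loader registered above.
load_fn = IMAGE_CLASSIFIER_BACKBONES.get("simclr-imagenet", namespace="vision", package="bolts")
backbone, num_features = load_fn()  # num_features == 2048 per load_simclr_imagenet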
3 changes: 1 addition & 2 deletions flash/vision/classification/model.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from types import FunctionType
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union
 
 import torch
 from torch import nn
@@ -22,7 +22,6 @@
 from flash.core.classification import ClassificationTask
 from flash.core.registry import FlashRegistry
 from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES
-from flash.vision.classification.data import ImageClassificationData, ImageClassificationPreprocess
 
 
 class ImageClassifier(ClassificationTask):
3 changes: 2 additions & 1 deletion flash_examples/finetuning/image_classification.py
@@ -18,7 +18,7 @@
 from flash import Trainer
 from flash.core.finetuning import FreezeUnfreeze
 from flash.data.utils import download_data
-from flash.vision import IMAGE_CLASSIFIER_BACKBONES, ImageClassificationData, ImageClassifier
+from flash.vision import ImageClassificationData, ImageClassifier
 
 # 1. Download the data
 download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")
@@ -36,6 +36,7 @@
 @ImageClassifier.backbones(name="username/resnet18")
 def fn_resnet(pretrained: bool = True):
     model = torchvision.models.resnet18(pretrained)
+    # remove the last two layers & turn it into a Sequential model
     backbone = nn.Sequential(*list(model.children())[:-2])
     num_features = model.fc.in_features
     # backbones need to return the num_features to build the head
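
Once registered, the custom key should work like any built-in backbone name — a hedged sketch, assuming ImageClassifier accepts `backbone` and `num_classes` keyword arguments as elsewhere in the repo:

# The hymenoptera data downloaded above has two classes (ants / bees).
model = ImageClassifier(backbone="username/resnet18", num_classes=2)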
1 change: 0 additions & 1 deletion requirements.txt
@@ -17,4 +17,3 @@ sentencepiece>=0.1.95
 filelock # comes with 3rd-party dependency
 pycocotools>=2.0.2 ; python_version >= "3.7"
 kornia>=0.5.0
-pytest