Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[feat] Add backbone API (#204)
Browse files Browse the repository at this point in the history
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
  • Loading branch information
3 people authored Apr 7, 2021
1 parent 6a4948a commit 7853efd
Show file tree
Hide file tree
Showing 27 changed files with 492 additions and 257 deletions.
8 changes: 8 additions & 0 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytorch_lightning.trainer.states import RunningStage
from torch import nn

from flash.core.registry import FlashRegistry
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess

Expand Down Expand Up @@ -256,3 +257,10 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
super().on_load_checkpoint(checkpoint)
if 'data_pipeline' in checkpoint:
self.data_pipeline = checkpoint['data_pipeline']

@classmethod
def available_backbones(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "backbones", None)
if registry is None:
return []
return registry.available_keys()
147 changes: 147 additions & 0 deletions flash/core/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from types import FunctionType
from typing import Any, Dict, List, Optional, Union

from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_REGISTERED_FUNCTION = Dict[str, Any]


class FlashRegistry:
"""
This class is used to register function or partial to a registry:
Example::
backbones = FlashRegistry("backbones")
@backbones
def my_model(nc_input=5, nc_output=6):
return nn.Linear(nc_input, nc_output), nc_input, nc_output
mlp, nc_input, nc_output = backbones("my_model")(nc_output=7)
backbones(my_model, name="foo")
assert backbones("foo")
"""

def __init__(self, name: str, verbose: bool = False) -> None:
self.name = name
self.functions: List[_REGISTERED_FUNCTION] = []
self._verbose = verbose

def __len__(self) -> int:
return len(self.functions)

def __contains__(self, key) -> bool:
return any(key == e["name"] for e in self.functions)

def __repr__(self) -> str:
return f'{self.__class__.__name__}(name={self.name}, functions={self.functions})'

def get(
self,
key: str,
with_metadata: bool = False,
strict: bool = True,
**metadata,
) -> Union[callable, _REGISTERED_FUNCTION, List[_REGISTERED_FUNCTION], List[callable]]:
"""
This function is used to gather matches from the registry:
Args:
key: Name of the registered function.
with_metadata: Whether to include the associated metadata in the return value.
strict: Whether to return all matches or just one.
metadata: Metadata used to filter against existing registry item's metadata.
"""
matches = [e for e in self.functions if key == e["name"]]
if not matches:
raise KeyError(f"Key: {key} is not in {repr(self)}")

if metadata:
matches = [m for m in matches if metadata.items() <= m["metadata"].items()]
if not matches:
raise KeyError("Found no matches that fit your metadata criteria. Try removing some metadata")

matches = [e if with_metadata else e["fn"] for e in matches]
return matches[0] if strict else matches

def remove(self, key: str) -> None:
self.functions = [f for f in self.functions if f["name"] != key]

def _register_function(
self,
fn: callable,
name: Optional[str] = None,
override: bool = False,
metadata: Optional[Dict[str, Any]] = None
):
if not isinstance(fn, FunctionType) and not isinstance(fn, partial):
raise MisconfigurationException(f"You can only register a function, found: {fn}")

name = name or fn.__name__

if self._verbose:
rank_zero_info(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}")

item = {"fn": fn, "name": name, "metadata": metadata or {}}

matching_index = self._find_matching_index(item)
if override and matching_index is not None:
self.functions[matching_index] = item
else:
if matching_index is not None:
raise MisconfigurationException(
f"Function with name: {name} and metadata: {metadata} is already present within {self}."
" HINT: Use `override=True`."
)
self.functions.append(item)

def _find_matching_index(self, item: _REGISTERED_FUNCTION) -> Optional[int]:
for idx, fn in enumerate(self.functions):
if (
fn["fn"] == item["fn"] and fn["name"] == item["name"]
and item["metadata"].items() <= fn["metadata"].items()
):
return idx

def __call__(
self,
fn: Optional[callable] = None,
name: Optional[str] = None,
override: bool = False,
**metadata
) -> callable:
"""Register a function"""
if fn is not None:
self._register_function(fn=fn, name=name, override=override, metadata=metadata)
return fn

# raise the error ahead of time
if not (name is None or isinstance(name, str)):
raise TypeError(f'`name` must be a str, found {name}')

def _register(cls):
self._register_function(fn=cls, name=name, override=override, metadata=metadata)
return cls

return _register

def available_keys(self) -> List[str]:
return sorted(v["name"] for v in self.functions)
40 changes: 15 additions & 25 deletions flash/data/base_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,45 +24,35 @@ def __init__(self, enabled: bool = False):
self.enabled = enabled
self._preprocess = None

def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None:
if self.enabled:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault(fn_name, [])
store[fn_name].append(data)

def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("load_sample", [])
store["load_sample"].append(sample)
self._store(sample, "load_sample", running_stage)

def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("pre_tensor_transform", [])
store["pre_tensor_transform"].append(sample)
self._store(sample, "pre_tensor_transform", running_stage)

def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("to_tensor_transform", [])
store["to_tensor_transform"].append(sample)
self._store(sample, "to_tensor_transform", running_stage)

def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("post_tensor_transform", [])
store["post_tensor_transform"].append(sample)
self._store(sample, "post_tensor_transform", running_stage)

def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("per_batch_transform", [])
store["per_batch_transform"].append(batch)
self._store(batch, "per_batch_transform", running_stage)

def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("collate", [])
store["collate"].append(batch)
self._store(batch, "collate", running_stage)

def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("per_sample_transform_on_device", [])
store["per_sample_transform_on_device"].append(samples)
self._store(samples, "per_sample_transform_on_device", running_stage)

def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("per_batch_transform_on_device", [])
store["per_batch_transform_on_device"].append(batch)
self._store(batch, "per_batch_transform_on_device", running_stage)

@contextmanager
def enable(self):
Expand All @@ -74,7 +64,7 @@ def attach_to_datamodule(self, datamodule) -> None:
datamodule.viz = self

def attach_to_preprocess(self, preprocess: Preprocess) -> None:
preprocess.callbacks = [self]
preprocess.add_callbacks([self])
self._preprocess = preprocess

def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None:
Expand Down
17 changes: 9 additions & 8 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ def viz(self, viz: BaseViz) -> None:
def configure_vis(*args, **kwargs) -> BaseViz:
return BaseViz()

def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]:
iter_name = f"_{stage}_iter"
dataloader_fn = getattr(self, f"{stage}_dataloader")
iterator = iter(dataloader_fn())
setattr(self, iter_name, iterator)
return iterator

def show(self, batch: Dict[str, Any], stage: RunningStage) -> None:
"""
This function is a hook for users to override with their visualization on a batch.
Expand All @@ -114,21 +121,15 @@ def _show_batch(self, stage: RunningStage, reset: bool = True) -> None:
"""
iter_name = f"_{stage}_iter"

def _reset_iterator() -> Iterable[Any]:
dataloader_fn = getattr(self, f"{stage}_dataloader")
iterator = iter(dataloader_fn())
setattr(self, iter_name, iterator)
return iterator

if not hasattr(self, iter_name):
_reset_iterator()
self._reset_iterator(stage)

iter_dataloader = getattr(self, iter_name)
with self.viz.enable():
try:
_ = next(iter_dataloader)
except StopIteration:
iter_dataloader = _reset_iterator()
iter_dataloader = self._reset_iterator(stage)
_ = next(iter_dataloader)
self.show(self.viz.batches[stage], stage)
if reset:
Expand Down
3 changes: 3 additions & 0 deletions flash/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def callbacks(self) -> List['FlashCallback']:

@callbacks.setter
def callbacks(self, callbacks: List['FlashCallback']):
self._callbacks = callbacks

def add_callbacks(self, callbacks: List['FlashCallback']):
_callbacks = [c for c in callbacks if c not in self._callbacks]
self._callbacks.extend(_callbacks)

Expand Down
1 change: 0 additions & 1 deletion flash/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import os.path
import zipfile
from contextlib import ContextDecorator, contextmanager
from typing import Any, Callable, Dict, Iterable, Mapping, Type

import requests
Expand Down
1 change: 0 additions & 1 deletion flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from transformers.modeling_outputs import SequenceClassifierOutput

from flash.core.classification import ClassificationTask
from flash.text.classification.data import TextClassificationData


class TextClassifier(ClassificationTask):
Expand Down
1 change: 1 addition & 0 deletions flash/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
_KORNIA_AVAILABLE = _module_available("kornia")
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
_TORCHVISION_AVAILABLE = _module_available("torchvision")
1 change: 1 addition & 0 deletions flash/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES, OBJ_DETECTION_BACKBONES
from flash.vision.classification import ImageClassificationData, ImageClassifier
from flash.vision.detection import ObjectDetectionData, ObjectDetector
from flash.vision.embedding import ImageEmbedder
Loading

0 comments on commit 7853efd

Please sign in to comment.