This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Add backbone API #204

Merged · 23 commits · Apr 7, 2021
8 changes: 8 additions & 0 deletions flash/core/model.py
@@ -23,6 +23,7 @@
from pytorch_lightning.trainer.states import RunningStage
from torch import nn

+from flash.core.registry import FlashRegistry
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess

@@ -256,3 +257,10 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        super().on_load_checkpoint(checkpoint)
        if 'data_pipeline' in checkpoint:
            self.data_pipeline = checkpoint['data_pipeline']

+    @classmethod
+    def available_backbones(cls) -> List[str]:
+        registry: Optional[FlashRegistry] = getattr(cls, "backbones", None)
+        if registry is None:
+            return []
+        return registry.available_keys()
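
With `available_backbones` on the base `Task`, callers can discover what a task supports before building it. A minimal sketch, assuming `ImageClassifier` wires up a `backbones` registry as the vision changes in this PR suggest:

    from flash.vision import ImageClassifier

    # Discover the registered backbones without instantiating a model.
    names = ImageClassifier.available_backbones()
    print(names)  # e.g. ['resnet18', 'resnet50', ...], depending on what is registered

    # A task class without a `backbones` registry simply reports an empty list.
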
147 changes: 147 additions & 0 deletions flash/core/registry.py
@@ -0,0 +1,147 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from types import FunctionType
from typing import Any, Callable, Dict, List, Optional, Union

from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_REGISTERED_FUNCTION = Dict[str, Any]


class FlashRegistry:
    """
    This class is used to register functions or partials within a registry.

    Example::

        backbones = FlashRegistry("backbones")

        @backbones
        def my_model(nc_input=5, nc_output=6):
            return nn.Linear(nc_input, nc_output), nc_input, nc_output

        mlp, nc_input, nc_output = backbones.get("my_model")(nc_output=7)

        backbones(my_model, name="foo")
        assert backbones.get("foo")

    """

    def __init__(self, name: str, verbose: bool = False) -> None:
        self.name = name
        self.functions: List[_REGISTERED_FUNCTION] = []
        self._verbose = verbose

    def __len__(self) -> int:
        return len(self.functions)

    def __contains__(self, key) -> bool:
        return any(key == e["name"] for e in self.functions)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(name={self.name}, functions={self.functions})'

    def get(
        self,
        key: str,
        with_metadata: bool = False,
        strict: bool = True,
        **metadata,
    ) -> Union[Callable, _REGISTERED_FUNCTION, List[_REGISTERED_FUNCTION], List[Callable]]:
        """
        This function is used to gather matches from the registry.

        Args:
            key: Name of the registered function.
            with_metadata: Whether to include the associated metadata in the return value.
            strict: Whether to return only the first match (``True``) or all matches (``False``).
            metadata: Metadata used to filter against existing registry items' metadata.
        """
        matches = [e for e in self.functions if key == e["name"]]
        if not matches:
            raise KeyError(f"Key: {key} is not in {repr(self)}")

        if metadata:
            matches = [m for m in matches if metadata.items() <= m["metadata"].items()]
            if not matches:
                raise KeyError("Found no matches that fit your metadata criteria. Try removing some metadata.")

        matches = [e if with_metadata else e["fn"] for e in matches]
        return matches[0] if strict else matches

    def remove(self, key: str) -> None:
        self.functions = [f for f in self.functions if f["name"] != key]

    def _register_function(
        self,
        fn: Callable,
        name: Optional[str] = None,
        override: bool = False,
        metadata: Optional[Dict[str, Any]] = None
    ):
        if not isinstance(fn, (FunctionType, partial)):
            raise MisconfigurationException(f"You can only register a function, found: {fn}")

        name = name or fn.__name__

        if self._verbose:
            rank_zero_info(f"Registering: {fn.__name__} function with name: {name} and metadata: {metadata}")

        item = {"fn": fn, "name": name, "metadata": metadata or {}}

        matching_index = self._find_matching_index(item)
        if override and matching_index is not None:
            self.functions[matching_index] = item
        else:
            if matching_index is not None:
                raise MisconfigurationException(
                    f"Function with name: {name} and metadata: {metadata} is already present within {self}."
                    " HINT: Use `override=True`."
                )
            self.functions.append(item)

    def _find_matching_index(self, item: _REGISTERED_FUNCTION) -> Optional[int]:
        for idx, fn in enumerate(self.functions):
            if (
                fn["fn"] == item["fn"] and fn["name"] == item["name"]
                and item["metadata"].items() <= fn["metadata"].items()
            ):
                return idx

    def __call__(
        self,
        fn: Optional[Callable] = None,
        name: Optional[str] = None,
        override: bool = False,
        **metadata
    ) -> Callable:
        """Register a function, either directly or as a decorator."""
        if fn is not None:
            self._register_function(fn=fn, name=name, override=override, metadata=metadata)
            return fn

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'`name` must be a str, found {name}')

        def _register(cls):
            self._register_function(fn=cls, name=name, override=override, metadata=metadata)
            return cls

        return _register

    def available_keys(self) -> List[str]:
        return sorted(v["name"] for v in self.functions)
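
Taken together, the registry supports decorator registration, named registration with metadata, filtered lookups, and removal. A quick usage sketch against the code above (the names and the `namespace` metadata key are illustrative, not from the PR):

    from functools import partial

    from torch import nn

    from flash.core.registry import FlashRegistry

    backbones = FlashRegistry("backbones")

    @backbones  # decorator form, registered under the function's __name__
    def mlp(num_features: int = 16):
        return nn.Linear(num_features, 10)

    # Named registration with metadata; a partial freezes an argument.
    backbones(partial(mlp, num_features=32), name="mlp", namespace="vision")

    assert "mlp" in backbones
    fn = backbones.get("mlp")  # strict lookup returns the first match

    # Metadata narrows the candidates, and strict=False returns every match.
    matches = backbones.get("mlp", strict=False, with_metadata=True, namespace="vision")
    assert matches[0]["metadata"] == {"namespace": "vision"}

    backbones.remove("mlp")
    assert "mlp" not in backbones
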
40 changes: 15 additions & 25 deletions flash/data/base_viz.py
@@ -24,45 +24,35 @@ def __init__(self, enabled: bool = False):
        self.enabled = enabled
        self._preprocess = None

+    def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None:
+        if self.enabled:
+            store = self.batches[_STAGES_PREFIX[running_stage]]
+            store.setdefault(fn_name, [])
+            store[fn_name].append(data)
+
    def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
-        store = self.batches[_STAGES_PREFIX[running_stage]]
-        store.setdefault("load_sample", [])
-        store["load_sample"].append(sample)
+        self._store(sample, "load_sample", running_stage)

    def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
-        store = self.batches[_STAGES_PREFIX[running_stage]]
-        store.setdefault("pre_tensor_transform", [])
-        store["pre_tensor_transform"].append(sample)
+        self._store(sample, "pre_tensor_transform", running_stage)

    def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
-        store = self.batches[_STAGES_PREFIX[running_stage]]
-        store.setdefault("to_tensor_transform", [])
-        store["to_tensor_transform"].append(sample)
+        self._store(sample, "to_tensor_transform", running_stage)

    def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
-        store = self.batches[_STAGES_PREFIX[running_stage]]
-        store.setdefault("post_tensor_transform", [])
-        store["post_tensor_transform"].append(sample)
+        self._store(sample, "post_tensor_transform", running_stage)

    def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
-        store = self.batches[_STAGES_PREFIX[running_stage]]
-        store.setdefault("per_batch_transform", [])
-        store["per_batch_transform"].append(batch)
+        self._store(batch, "per_batch_transform", running_stage)

    def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
-        store = self.batches[_STAGES_PREFIX[running_stage]]
-        store.setdefault("collate", [])
-        store["collate"].append(batch)
+        self._store(batch, "collate", running_stage)

    def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None:
-        store = self.batches[_STAGES_PREFIX[running_stage]]
-        store.setdefault("per_sample_transform_on_device", [])
-        store["per_sample_transform_on_device"].append(samples)
+        self._store(samples, "per_sample_transform_on_device", running_stage)

    def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
-        store = self.batches[_STAGES_PREFIX[running_stage]]
-        store.setdefault("per_batch_transform_on_device", [])
-        store["per_batch_transform_on_device"].append(batch)
+        self._store(batch, "per_batch_transform_on_device", running_stage)

    @contextmanager
    def enable(self):
@@ -74,7 +64,7 @@ def attach_to_datamodule(self, datamodule) -> None:
        datamodule.viz = self

    def attach_to_preprocess(self, preprocess: Preprocess) -> None:
-        preprocess.callbacks = [self]
+        preprocess.add_callbacks([self])
        self._preprocess = preprocess

    def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None:
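
The body of the `enable` context manager is collapsed in this diff. Given that `_store` gates on `self.enabled` and `DataModule._show_batch` wraps iteration in `with self.viz.enable():`, it plausibly toggles the flag around the block. A hedged sketch of what the method might look like, not the PR's verbatim code:

    from contextlib import contextmanager

    @contextmanager
    def enable(self):
        """Temporarily record batches, restoring the previous behaviour afterwards."""
        self.enabled = True  # assumption: this flag gates _store, as shown above
        try:
            yield
        finally:
            self.enabled = False
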
17 changes: 9 additions & 8 deletions flash/data/data_module.py
@@ -102,6 +102,13 @@ def viz(self, viz: BaseViz) -> None:
    def configure_vis(*args, **kwargs) -> BaseViz:
        return BaseViz()

+    def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]:
+        iter_name = f"_{stage}_iter"
+        dataloader_fn = getattr(self, f"{stage}_dataloader")
+        iterator = iter(dataloader_fn())
+        setattr(self, iter_name, iterator)
+        return iterator

    def show(self, batch: Dict[str, Any], stage: RunningStage) -> None:
        """
        This function is a hook for users to override with their visualization on a batch.
@@ -114,21 +121,15 @@ def _show_batch(self, stage: RunningStage, reset: bool = True) -> None:
"""
iter_name = f"_{stage}_iter"

def _reset_iterator() -> Iterable[Any]:
dataloader_fn = getattr(self, f"{stage}_dataloader")
iterator = iter(dataloader_fn())
setattr(self, iter_name, iterator)
return iterator

if not hasattr(self, iter_name):
_reset_iterator()
self._reset_iterator(stage)

iter_dataloader = getattr(self, iter_name)
with self.viz.enable():
try:
_ = next(iter_dataloader)
except StopIteration:
iter_dataloader = _reset_iterator()
iter_dataloader = self._reset_iterator(stage)
_ = next(iter_dataloader)
self.show(self.viz.batches[stage], stage)
if reset:
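
Hoisting `_reset_iterator` from a local closure to a method makes the per-stage iterator a cached attribute that any caller can rebuild, with a wrap-around on exhaustion. The pattern, distilled into a self-contained toy (not Flash code):

    class Loader:

        def train_dataloader(self):
            return [1, 2]  # stand-in for a real DataLoader

        def _reset_iterator(self, stage: str):
            iterator = iter(getattr(self, f"{stage}_dataloader")())
            setattr(self, f"_{stage}_iter", iterator)
            return iterator

        def next_batch(self, stage: str = "train"):
            if not hasattr(self, f"_{stage}_iter"):
                self._reset_iterator(stage)
            try:
                return next(getattr(self, f"_{stage}_iter"))
            except StopIteration:  # exhausted: rebuild the iterator and wrap around
                return next(self._reset_iterator(stage))

    loader = Loader()
    assert [loader.next_batch() for _ in range(3)] == [1, 2, 1]
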
3 changes: 3 additions & 0 deletions flash/data/process.py
@@ -161,6 +161,9 @@ def callbacks(self) -> List['FlashCallback']:

    @callbacks.setter
    def callbacks(self, callbacks: List['FlashCallback']):
        self._callbacks = callbacks

+    def add_callbacks(self, callbacks: List['FlashCallback']):
+        _callbacks = [c for c in callbacks if c not in self._callbacks]
+        self._callbacks.extend(_callbacks)

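
`add_callbacks` is what lets `BaseViz.attach_to_preprocess` append itself without clobbering callbacks registered earlier, whereas the setter replaces the whole list. The difference, distilled with toy stand-ins for `FlashCallback` instances:

    class Process:
        def __init__(self):
            self._callbacks = []

        def add_callbacks(self, callbacks):
            # Extend, skipping callbacks that are already attached.
            self._callbacks.extend(c for c in callbacks if c not in self._callbacks)

    viz, logger = object(), object()  # stand-ins for FlashCallback instances
    p = Process()
    p.add_callbacks([logger])
    p.add_callbacks([viz, logger])  # logger is not duplicated
    assert p._callbacks == [logger, viz]
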
1 change: 0 additions & 1 deletion flash/data/utils.py
@@ -14,7 +14,6 @@

import os.path
import zipfile
-from contextlib import ContextDecorator, contextmanager
from typing import Any, Callable, Dict, Iterable, Mapping, Type

import requests
1 change: 0 additions & 1 deletion flash/text/classification/model.py
@@ -21,7 +21,6 @@
from transformers.modeling_outputs import SequenceClassifierOutput

from flash.core.classification import ClassificationTask
-from flash.text.classification.data import TextClassificationData


class TextClassifier(ClassificationTask):
1 change: 1 addition & 0 deletions flash/utils/imports.py
@@ -4,3 +4,4 @@
_KORNIA_AVAILABLE = _module_available("kornia")
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
+_TORCHVISION_AVAILABLE = _module_available("torchvision")
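
These module-availability flags are probed once at import time so call sites can branch cheaply. A typical consumer pattern (the helper below is hypothetical, for illustration only):

    from flash.utils.imports import _TORCHVISION_AVAILABLE

    if _TORCHVISION_AVAILABLE:
        import torchvision

    def resnet18_backbone():
        # Hypothetical helper: guard the optional dependency at the call site.
        if not _TORCHVISION_AVAILABLE:
            raise ModuleNotFoundError("Please install torchvision to use this backbone.")
        return torchvision.models.resnet18(pretrained=False)
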
1 change: 1 addition & 0 deletions flash/vision/__init__.py
@@ -1,3 +1,4 @@
+from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES, OBJ_DETECTION_BACKBONES
from flash.vision.classification import ImageClassificationData, ImageClassifier
from flash.vision.detection import ObjectDetectionData, ObjectDetector
from flash.vision.embedding import ImageEmbedder