Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Add backbone API #204

Merged
merged 23 commits into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class Task(LightningModule):
learning_rate: Learning rate to use for training, defaults to `5e-5`
"""

register = None
tchaton marked this conversation as resolved.
Show resolved Hide resolved
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def __init__(
self,
model: Optional[nn.Module] = None,
Expand Down Expand Up @@ -256,3 +258,20 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
super().on_load_checkpoint(checkpoint)
if 'data_pipeline' in checkpoint:
self.data_pipeline = checkpoint['data_pipeline']

@classmethod
def available_models(cls) -> List[str]:
    """Return the keys exposed by the class-level registry.

    Returns:
        The result of ``cls.register.available_keys()``, or an empty list
        when no registry is attached to the class (``cls.register is None``).
    """
    registry = cls.register
    return [] if registry is None else registry.available_keys()
tchaton marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def register_function(
    cls,
    fn: Optional[Callable] = None,
    name: Optional[str] = None,
    override: bool = False,
    **metadata
) -> Optional[Callable]:
    """Forward a registration request to the class-level registry.

    Args:
        fn: Callable to register; passed through unchanged.
        name: Name to register the callable under; passed through unchanged.
        override: Passed through to the registry's ``register_function``.
        metadata: Extra keyword arguments forwarded as registration metadata.

    Returns:
        Whatever ``cls.register.register_function`` returns, or ``None`` when
        the class has no registry attached (``cls.register is None``).
    """
    registry = cls.register
    if registry is None:
        # No registry configured on this class: nothing is registered.
        return None
    return registry.register_function(fn=fn, name=name, override=override, **metadata)
tchaton marked this conversation as resolved.
Show resolved Hide resolved
189 changes: 189 additions & 0 deletions flash/core/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import hashlib
from collections import defaultdict
from functools import partial
from types import FunctionType
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Union

from pytorch_lightning.utilities.exceptions import MisconfigurationException


class FlashRegistry(Dict):
    """Registry of functions (or ``functools.partial`` objects) addressed by name.

    Several callables may be registered under the same name as long as their
    metadata differs; the metadata is then used by :meth:`get` to select one.
    Entries are stored internally under a SHA-256 digest of ``name + str(metadata)``.

    Example::

        backbones = FlashRegistry("backbones")

        @backbones.register_function()
        def my_model(nc_input=5, nc_output=6):
            return nn.Linear(nc_input, nc_output), nc_input, nc_output

        mlp, nc_input, nc_output = backbones.get("my_model")(nc_output=7)

        backbones.register_function(my_model, name="cho")
        assert backbones.get("cho")

    """

    def __init__(self, registry_name: str, verbose: bool = False):
        """Create an empty registry.

        Args:
            registry_name: Human-readable name, shown in ``repr`` and error messages.
            verbose: When True, print a line for every direct registration.
        """
        super().__init__()
        self._registry_name = registry_name
        # digest -> {"fn": callable, "name": str, "metadata": dict}
        self._registered_functions: Dict[str, Dict[str, Any]] = {}
        # digest -> registered name (several digests may share one name)
        self._registered_functions_mapping: Dict[str, str] = {}
        self._verbose = verbose

    def __len__(self):
        """Return the number of registered entries (not the number of distinct names)."""
        return len(self._registered_functions)

    def __contains__(self, key) -> bool:
        """Return True when ``key`` is a registered name (or an internal digest)."""
        # Digest lookup is kept for backward compatibility; name lookup is what
        # callers of ``"my_model" in registry`` actually expect.
        return key in self._registered_functions or key in self._registered_functions_mapping.values()

    def __repr__(self):
        return (
            f'{self.__class__.__name__}(name={self._registry_name}, '
            f'registered_items={dict(self._registered_functions)})'
        )

    @property
    def name(self) -> str:
        """Name given to the registry at construction time."""
        return self._registry_name

    @property
    def registered_funcs(self) -> Dict[str, Any]:
        """Internal ``digest -> entry`` mapping (the live object, not a copy)."""
        return self._registered_functions

    def __getitem__(self, key: str) -> Callable:
        """Shorthand for ``self.get(key)``."""
        return self.get(key)

    @staticmethod
    def _callable_name(fn: Any) -> Optional[str]:
        """Best-effort name of ``fn``; partials fall back to the wrapped callable's name."""
        fn_name = getattr(fn, "__name__", None)
        if fn_name is None and isinstance(fn, partial):
            fn_name = getattr(fn.func, "__name__", None)
        return fn_name

    def validate_matches(self, key: str, matches: Dict, with_metadata: bool, key_in: bool = False):
        """Resolve a single-entry ``matches`` mapping into the registered entry.

        Args:
            key: Name that was looked up (used in the error message only).
            matches: Mapping keyed by internal digest; exactly one entry is expected.
            with_metadata: Return the whole entry dict instead of just the function.
            key_in: True when the name exists but metadata filtering narrowed the
                candidates down, which selects the error message used for zero matches.

        Raises:
            MisconfigurationException: If ``matches`` does not hold exactly one entry.
        """
        if len(matches) == 1:
            registered_function = self._registered_functions[next(iter(matches))]
            return registered_function if with_metadata else registered_function["fn"]
        if len(matches) > 1 or key_in:
            # Ambiguous (or fully filtered-out) lookups must not silently return None.
            raise MisconfigurationException(
                f"Found {len(matches)} matches within {matches}. Add more metadata to filter them out."
            )
        raise MisconfigurationException(f"Key: {key} is not in {self.__repr__()}")

    def __call__(self,
                 key: str,
                 with_metadata: bool = False,
                 strict: bool = True,
                 **metadata) -> Union[Callable, Dict[str, Any], List[Dict[str, Any]], List[Callable]]:
        """Alias of :meth:`get`."""
        return self.get(key, with_metadata=with_metadata, strict=strict, **metadata)

    def get(self,
            key: str,
            with_metadata: bool = False,
            strict: bool = True,
            **metadata) -> Union[Callable, Dict[str, Any], List[Dict[str, Any]], List[Callable]]:
        """
        This function is used to gather matches from the registry:

        Args:
            key: Name of the registered function.
            with_metadata: Whether to return associated metadata used during registration.
            strict: Whether to require a single match. When False and several entries
                share the name, all entries passing the metadata filter are returned as a list.
            metadata: All filtering metadata used for the registry.

        Raises:
            MisconfigurationException: If the key is unknown, or if ``strict`` is True
                and the metadata does not narrow the candidates to exactly one entry.
        """
        matches = {digest: name for digest, name in self._registered_functions_mapping.items() if key == name}
        key_in = False
        if len(matches) > 1:
            # Several entries share this name: disambiguate with the metadata.
            key_in = True
            matches = self._filter_matches_on_metadata(matches, with_metadata=with_metadata, metadata=metadata)
            if not strict:
                entries = [next(iter(match.values())) for match in matches]
                if with_metadata:
                    return entries
                return [entry["fn"] for entry in entries]
            if len(matches) > 1:
                raise MisconfigurationException(
                    f"Found {len(matches)} matches within {matches}. Add more metadata to filter them out."
                )
            elif len(matches) == 1:
                matches = matches[0]
        return self.validate_matches(key, matches, with_metadata, key_in=key_in)

    def remove(self, key: str) -> None:
        """Remove every entry registered under the name ``key`` (no-op if there is none)."""
        digests = [digest for digest, name in self._registered_functions_mapping.items() if key == name]
        for digest in digests:
            del self._registered_functions_mapping[digest]
            del self._registered_functions[digest]

    def _register_function(
        self, fn: Callable, name: Optional[str] = None, override: bool = False, metadata: Dict[str, Any] = None
    ):
        """Store ``fn`` under ``name`` and ``metadata``.

        Raises:
            MisconfigurationException: If ``fn`` is neither a function nor a partial,
                if no name can be determined, or if an entry with the same name and
                metadata already exists and ``override`` is False.
        """
        if not isinstance(fn, (FunctionType, partial)):
            raise MisconfigurationException("``register_function`` should be used with a function")

        # ``functools.partial`` objects have no ``__name__``; fall back to the wrapped callable.
        name = name or self._callable_name(fn)
        if name is None:
            raise MisconfigurationException("A ``name`` is required to register a callable without ``__name__``")

        registered_function = {"fn": fn, "name": name, "metadata": metadata}

        hash_algo = hashlib.sha256()
        hash_algo.update((name + str(metadata)).encode('utf-8'))
        digest = hash_algo.hexdigest()

        if not override and digest in self._registered_functions_mapping:
            raise MisconfigurationException(
                f"Function with name: {name} and metadata: {metadata} is already present within {self}"
            )
        self._registered_functions[digest] = registered_function
        self._registered_functions_mapping[digest] = name

    def register_function(
        self,
        fn: Optional[Callable] = None,
        name: Optional[str] = None,
        override: bool = False,
        **metadata
    ) -> Callable:
        """Register a callable, either directly or as a decorator.

        Args:
            fn: Callable to register. When omitted, a decorator is returned.
            name: Name to register under; defaults to the callable's own name.
            override: Replace an existing entry with the same name and metadata.
            metadata: Arbitrary keyword metadata stored with the entry and usable
                as filters in :meth:`get`.

        Returns:
            ``fn`` itself when given, otherwise a decorator that registers and
            returns the decorated callable.

        Raises:
            TypeError: In decorator mode, if ``name`` is neither None nor a str.
        """
        if fn is not None:
            if self._verbose:
                print(f"Registering: {self._callable_name(fn)} function with name: {name} and metadata: {metadata}")
            self._register_function(fn=fn, name=name, override=override, metadata=metadata)
            return fn

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {type(name)}')

        def _register(cls):
            self._register_function(fn=cls, name=name, override=override, metadata=metadata)
            return cls

        return _register

    def _filter_matches_on_metadata(self, matches, with_metadata: bool = False, **metadata) -> List[Dict[str, Any]]:
        """Keep the entries whose stored metadata agrees with every requested item.

        The filters are read from the ``metadata`` keyword (a dict). A requested key
        missing from an entry's stored metadata compares as ``None``.

        Returns:
            A list of single-item ``{digest: entry}`` dicts, one per surviving match.
        """
        wanted = metadata.get("metadata") or {}
        filtered = []
        for digest in matches:
            registered_function = self._registered_functions[digest]
            # Entries registered through ``_register_function`` directly may carry None.
            stored = registered_function["metadata"] or {}
            if all(stored.get(k) == v for k, v in wanted.items()):
                filtered.append({digest: registered_function})
        return filtered

    def available_keys(self) -> List[str]:
        """Return the sorted names of all entries (a name repeats once per entry)."""
        return sorted(entry["name"] for entry in self._registered_functions.values())


# Module-level registry instances. Both are created with the same registry
# name, "backbones"; they are nevertheless two independent FlashRegistry objects.
# NOTE(review): going by the variable names, one presumably holds image-classification
# backbones and the other object-detection backbones — confirm at the registration sites.
IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones")
OBJ_DETECTION_BACKBONES = FlashRegistry("backbones")
40 changes: 15 additions & 25 deletions flash/data/base_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,45 +24,35 @@ def __init__(self, enabled: bool = False):
self.enabled = enabled
self._preprocess = None

def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None:
    """Append ``data`` to the list kept for ``fn_name`` in the current stage's store.

    Does nothing unless ``self.enabled`` is truthy. The per-stage store is looked
    up in ``self.batches`` under ``_STAGES_PREFIX[running_stage]``; the list for
    ``fn_name`` is created on first use.
    """
    if not self.enabled:
        return
    stage_store = self.batches[_STAGES_PREFIX[running_stage]]
    stage_store.setdefault(fn_name, []).append(data)

def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("load_sample", [])
store["load_sample"].append(sample)
self._store(sample, "load_sample", running_stage)

def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("pre_tensor_transform", [])
store["pre_tensor_transform"].append(sample)
self._store(sample, "pre_tensor_transform", running_stage)

def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("to_tensor_transform", [])
store["to_tensor_transform"].append(sample)
self._store(sample, "to_tensor_transform", running_stage)

def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("post_tensor_transform", [])
store["post_tensor_transform"].append(sample)
self._store(sample, "post_tensor_transform", running_stage)

def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("per_batch_transform", [])
store["per_batch_transform"].append(batch)
self._store(batch, "per_batch_transform", running_stage)

def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("collate", [])
store["collate"].append(batch)
self._store(batch, "collate", running_stage)

def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("per_sample_transform_on_device", [])
store["per_sample_transform_on_device"].append(samples)
self._store(samples, "per_sample_transform_on_device", running_stage)

def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
store = self.batches[_STAGES_PREFIX[running_stage]]
store.setdefault("per_batch_transform_on_device", [])
store["per_batch_transform_on_device"].append(batch)
self._store(batch, "per_batch_transform_on_device", running_stage)

@contextmanager
def enable(self):
Expand All @@ -74,7 +64,7 @@ def attach_to_datamodule(self, datamodule) -> None:
datamodule.viz = self

def attach_to_preprocess(self, preprocess: Preprocess) -> None:
preprocess.callbacks = [self]
preprocess.add_callbacks([self])
self._preprocess = preprocess

def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None:
Expand Down
17 changes: 9 additions & 8 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ def viz(self, viz: BaseViz) -> None:
def configure_vis(*args, **kwargs) -> BaseViz:
return BaseViz()

def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]:
    """Build a fresh iterator over the dataloader of ``stage`` and cache it.

    Calls ``self.<stage>_dataloader()``, wraps the result in ``iter`` and stores
    the iterator on the instance as ``_<stage>_iter``.

    Returns:
        The newly created iterator.
    """
    fresh_iterator = iter(getattr(self, f"{stage}_dataloader")())
    setattr(self, f"_{stage}_iter", fresh_iterator)
    return fresh_iterator

def show(self, batch: Dict[str, Any], stage: RunningStage) -> None:
"""
This function is a hook for users to override with their visualization on a batch.
Expand All @@ -114,21 +121,15 @@ def _show_batch(self, stage: RunningStage, reset: bool = True) -> None:
"""
iter_name = f"_{stage}_iter"

def _reset_iterator() -> Iterable[Any]:
dataloader_fn = getattr(self, f"{stage}_dataloader")
iterator = iter(dataloader_fn())
setattr(self, iter_name, iterator)
return iterator

if not hasattr(self, iter_name):
_reset_iterator()
self._reset_iterator(stage)

iter_dataloader = getattr(self, iter_name)
with self.viz.enable():
try:
_ = next(iter_dataloader)
except StopIteration:
iter_dataloader = _reset_iterator()
iter_dataloader = self._reset_iterator(stage)
_ = next(iter_dataloader)
self.show(self.viz.batches[stage], stage)
if reset:
Expand Down
3 changes: 3 additions & 0 deletions flash/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def callbacks(self) -> List['FlashCallback']:

@callbacks.setter
def callbacks(self, callbacks: List['FlashCallback']):
self._callbacks = callbacks

def add_callbacks(self, callbacks: List['FlashCallback']):
    """Append the callbacks that are not registered yet, preserving their order.

    Args:
        callbacks: Candidate callbacks. Entries already present in
            ``self._callbacks`` — including repeats within ``callbacks``
            itself — are skipped, so no callback ends up registered twice.
    """
    for callback in callbacks:
        # Check against the live list so duplicates inside ``callbacks`` are dropped too.
        if callback not in self._callbacks:
            self._callbacks.append(callback)

Expand Down
1 change: 1 addition & 0 deletions flash/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
_KORNIA_AVAILABLE = _module_available("kornia")
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
_TORCHVISION_AVAILABLE = _module_available("torchvision")
1 change: 1 addition & 0 deletions flash/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES, OBJ_DETECTION_BACKBONES
from flash.vision.classification import ImageClassificationData, ImageClassifier
from flash.vision.detection import ObjectDetectionData, ObjectDetector
from flash.vision.embedding import ImageEmbedder
Loading