Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add BaseViz Callback (2 / 2) (#201)
Browse files Browse the repository at this point in the history
* wip

* add base_viz + new features for DataPipeline

* update

* resolve flake8

* update

* resolve tests

* update

* wip

* update

* resolve doc

* resolve doc

* update doc

* update

* update

* update

* convert to staticmethod

* initial visualisation implementation

* implement test case using Kornia transforms

* update on comments

* resolve bug

* update

* update

* add test

* update

* resolve tests

* resolve flake8

* update

* update

* update

* resolve test

Co-authored-by: Edgar Riba <[email protected]>
  • Loading branch information
tchaton and edgarriba authored Apr 6, 2021
1 parent 6e548a4 commit 6a4948a
Show file tree
Hide file tree
Showing 17 changed files with 650 additions and 88 deletions.
11 changes: 0 additions & 11 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,22 +245,11 @@ def on_fit_end(self) -> None:
self.data_pipeline._detach_from_model(self)
super().on_fit_end()

@staticmethod
def _sanetize_funcs(obj: Any) -> Any:
if hasattr(obj, "__dict__"):
for k, v in obj.__dict__.items():
if isinstance(v, Callable):
obj.__dict__[k] = inspect.unwrap(v)
return obj

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# TODO: Is this the best way to do this? or should we also use some kind of hparams here?
# This may be an issue since here we create the same problems with pickle as in
# https://pytorch.org/docs/stable/notes/serialization.html
if self.data_pipeline is not None and 'data_pipeline' not in checkpoint:
self._preprocess = self._sanetize_funcs(self._preprocess)
checkpoint['data_pipeline'] = self.data_pipeline
# todo (tchaton) re-wrap visualization
super().on_save_checkpoint(checkpoint)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
Expand Down
14 changes: 13 additions & 1 deletion flash/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Mapping, Sequence, Union
from typing import Any, Callable, Dict, Mapping, Sequence, Type, Union


def get_callable_name(fn_or_class: Union[Callable, object]) -> str:
Expand All @@ -25,3 +25,15 @@ def get_callable_dict(fn: Union[Callable, Mapping, Sequence]) -> Union[Dict, Map
return {get_callable_name(f): f for f in fn}
elif callable(fn):
return {get_callable_name(fn): fn}


def _is_overriden(method_name: str, instance: object, parent: Type[object]) -> bool:
"""
Cropped Version of
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/utilities/model_helpers.py
"""

if not hasattr(instance, method_name):
return False

return getattr(instance, method_name).__code__ != getattr(parent, method_name).__code__
12 changes: 11 additions & 1 deletion flash/data/auto_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from pytorch_lightning.utilities.warning_utils import rank_zero_warn
from torch.utils.data import Dataset

from flash.data.callback import ControlFlow
from flash.data.process import Preprocess
from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, CurrentRunningStageFuncContext

Expand Down Expand Up @@ -82,6 +83,12 @@ def preprocess(self) -> Optional[Preprocess]:
if self.data_pipeline is not None:
return self.data_pipeline._preprocess_pipeline

@property
def control_flow_callback(self) -> Optional[ControlFlow]:
    """Build a ``ControlFlow`` over the callbacks of the current preprocess, or ``None`` when there is none."""
    current_preprocess = self.preprocess
    if current_preprocess is None:
        return None
    return ControlFlow(current_preprocess.callbacks)

def _call_load_data(self, data: Any) -> Iterable:
parameters = signature(self.load_data).parameters
if len(parameters) > 1 and self.DATASET_KEY in parameters:
Expand Down Expand Up @@ -124,7 +131,10 @@ def __getitem__(self, index: int) -> Any:
raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.")
if self.load_sample:
with self._load_sample_context:
return self._call_load_sample(self.preprocessed_data[index])
data: Any = self._call_load_sample(self.preprocessed_data[index])
if self.control_flow_callback:
self.control_flow_callback.on_load_sample(data, self.running_stage)
return data
return self.preprocessed_data[index]

def __len__(self) -> int:
Expand Down
111 changes: 111 additions & 0 deletions flash/data/base_viz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
from contextlib import contextmanager
from typing import Any, Dict, List, Sequence

from pytorch_lightning.trainer.states import RunningStage
from torch import Tensor

from flash.core.utils import _is_overriden
from flash.data.callback import FlashCallback
from flash.data.process import Preprocess
from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX


class BaseViz(FlashCallback):
    """
    This class is used to profile ``Preprocess`` hook outputs and visualize the data transformations.

    Every ``on_*`` hook appends the object it receives to ``batches``, which is keyed first by the
    stage prefix taken from ``_STAGES_PREFIX`` and then by the hook name:
    batches: Dict = {"train": {"to_tensor_transform": [], ...}, ...}

    ``enabled`` defaults to ``False`` and is toggled by :meth:`enable`.
    NOTE(review): none of the hooks in this class consult ``enabled``; they record unconditionally.
    Confirm whether callers are expected to gate on the flag themselves.
    """

    def __init__(self, enabled: bool = False):
        # One empty store per stage prefix; hook names are added lazily on first use.
        self.batches = {k: {} for k in _STAGES_PREFIX.values()}
        self.enabled = enabled
        self._preprocess = None

    def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None:
        """Append ``data`` to the list kept for ``fn_name`` under the store of ``running_stage``."""
        store = self.batches[_STAGES_PREFIX[running_stage]]
        store.setdefault(fn_name, []).append(data)

    def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
        """Record a sample under the ``load_sample`` key."""
        self._store(sample, "load_sample", running_stage)

    def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
        """Record a sample under the ``pre_tensor_transform`` key."""
        self._store(sample, "pre_tensor_transform", running_stage)

    def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
        """Record a sample under the ``to_tensor_transform`` key."""
        self._store(sample, "to_tensor_transform", running_stage)

    def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
        """Record a sample under the ``post_tensor_transform`` key."""
        self._store(sample, "post_tensor_transform", running_stage)

    def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
        """Record a batch under the ``per_batch_transform`` key."""
        self._store(batch, "per_batch_transform", running_stage)

    def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        """Record a batch under the ``collate`` key."""
        self._store(batch, "collate", running_stage)

    def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None:
        """Record the received object under the ``per_sample_transform_on_device`` key."""
        self._store(samples, "per_sample_transform_on_device", running_stage)

    def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
        """Record a batch under the ``per_batch_transform_on_device`` key."""
        self._store(batch, "per_batch_transform_on_device", running_stage)

    @contextmanager
    def enable(self):
        """Set ``enabled`` to ``True`` for the duration of the ``with`` block, then reset it to ``False``."""
        self.enabled = True
        try:
            yield
        finally:
            # Reset even when the body raises, so a failure does not leave the callback enabled.
            self.enabled = False

    def attach_to_datamodule(self, datamodule) -> None:
        """Expose this callback on ``datamodule`` as its ``viz`` attribute."""
        datamodule.viz = self

    def attach_to_preprocess(self, preprocess: Preprocess) -> None:
        """Make this callback the only entry of ``preprocess.callbacks`` and keep a reference to ``preprocess``."""
        # Replaces (does not extend) whatever callbacks were set on the preprocess before.
        preprocess.callbacks = [self]
        self._preprocess = preprocess

    def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None:
        """
        Dispatch the recorded data to the ``show_*`` hooks that a subclass has overridden.

        Args:
            batch: Mapping from hook name to the data recorded for it; it is indexed with each
                name of ``_PREPROCESS_FUNCS`` whose ``show_<name>`` hook is overridden.
            running_stage: Stage forwarded unchanged to each ``show_*`` hook.
        """
        for func_name in _PREPROCESS_FUNCS:
            hook_name = f"show_{func_name}"
            # Hooks left at their ``BaseViz`` default are no-ops, so they are skipped entirely.
            if _is_overriden(hook_name, self, BaseViz):
                getattr(self, hook_name)(batch[func_name], running_stage)

    def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
        """Override to visualize the data recorded for ``load_sample``."""

    def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        """Override to visualize the data recorded for ``pre_tensor_transform``."""

    def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        """Override to visualize the data recorded for ``to_tensor_transform``."""

    def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        """Override to visualize the data recorded for ``post_tensor_transform``."""

    def show_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        """Override to visualize the data recorded for ``collate``."""

    def show_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
        """Override to visualize the data recorded for ``per_batch_transform``."""

    def show_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None:
        """Override to visualize the data recorded for ``per_sample_transform_on_device``."""

    def show_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
        """Override to visualize the data recorded for ``per_batch_transform_on_device``."""
29 changes: 23 additions & 6 deletions flash/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import Tensor

from flash.data.callback import ControlFlow
from flash.data.utils import _contains_any_tensor, convert_to_modules, CurrentFuncContext, CurrentRunningStageContext

if TYPE_CHECKING:
Expand All @@ -43,6 +44,7 @@ def __init__(
):
super().__init__()
self.preprocess = preprocess
self.callback = ControlFlow(self.preprocess.callbacks)
self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
self.to_tensor_transform = convert_to_modules(to_tensor_transform)
self.post_tensor_transform = convert_to_modules(post_tensor_transform)
Expand All @@ -58,9 +60,11 @@ def forward(self, sample: Any) -> Any:
with self._current_stage_context:
with self._pre_tensor_transform_context:
sample = self.pre_tensor_transform(sample)
self.callback.on_pre_tensor_transform(sample, self.stage)

with self._to_tensor_transform_context:
sample = self.to_tensor_transform(sample)
self.callback.on_to_tensor_transform(sample, self.stage)

if self.assert_contains_tensor:
if not _contains_any_tensor(sample):
Expand All @@ -71,6 +75,7 @@ def forward(self, sample: Any) -> Any:

with self._post_tensor_transform_context:
sample = self.post_tensor_transform(sample)
self.callback.on_post_tensor_transform(sample, self.stage)

return sample

Expand Down Expand Up @@ -112,36 +117,48 @@ def __init__(
per_batch_transform: Callable,
stage: RunningStage,
apply_per_sample_transform: bool = True,
on_device: bool = False
on_device: bool = False,
):
super().__init__()
self.preprocess = preprocess
self.callback = ControlFlow(self.preprocess.callbacks)
self.collate_fn = convert_to_modules(collate_fn)
self.per_sample_transform = convert_to_modules(per_sample_transform)
self.per_batch_transform = convert_to_modules(per_batch_transform)
self.apply_per_sample_transform = apply_per_sample_transform
self.stage = stage
self.on_device = on_device

extension = f"{'on_device' if self.on_device else ''}"
extension = f"{'_on_device' if self.on_device else ''}"
self._current_stage_context = CurrentRunningStageContext(stage, preprocess)
self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform_{extension}", preprocess)
self._per_sample_transform_context = CurrentFuncContext(f"per_sample_transform{extension}", preprocess)
self._collate_context = CurrentFuncContext("collate", preprocess)
self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform_{extension}", preprocess)
self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess)

def forward(self, samples: Sequence[Any]) -> Any:
with self._current_stage_context:

if self.apply_per_sample_transform:
with self._per_sample_transform_context:
samples = [self.per_sample_transform(sample) for sample in samples]
samples = type(samples)(samples)
_samples = []
for sample in samples:
sample = self.per_sample_transform(sample)
if self.on_device:
self.callback.on_per_sample_transform_on_device(sample, self.stage)
_samples.append(sample)

samples = type(_samples)(_samples)

with self._collate_context:
samples = self.collate_fn(samples)
self.callback.on_collate(samples, self.stage)

with self._per_batch_transform_context:
samples = self.per_batch_transform(samples)
if self.on_device:
self.callback.on_per_batch_transform_on_device(samples, self.stage)
else:
self.callback.on_per_batch_transform(samples, self.stage)
return samples

def __str__(self) -> str:
Expand Down
67 changes: 67 additions & 0 deletions flash/data/callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Any, List, Sequence

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage
from torch import Tensor


class FlashCallback(Callback):
    """Callback interface exposing one no-op hook per data-processing step; subclasses override what they need."""

    def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
        """Hook invoked after ``load_sample`` has produced a sample."""

    def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
        """Hook invoked after ``pre_tensor_transform`` has been applied to a sample."""

    def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
        """Hook invoked after ``to_tensor_transform`` has been applied to a sample."""

    def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
        """Hook invoked after ``post_tensor_transform`` has been applied to a sample."""

    def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
        """Hook invoked after ``per_batch_transform`` has been applied to a batch."""

    def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        """Hook invoked after ``collate`` has been applied to a sequence of samples."""

    def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None:
        """Hook invoked after ``per_sample_transform_on_device`` has been applied to a sample."""

    def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
        """Hook invoked after ``per_batch_transform_on_device`` has been applied to a batch."""


class ControlFlow(FlashCallback):
    """Composite callback that forwards every hook call to each callback of a given list, in order."""

    def __init__(self, callbacks: List[FlashCallback]):
        self._callbacks = callbacks

    def run_for_all_callbacks(self, *args, method_name: str, **kwargs):
        """Call ``method_name`` with ``*args`` / ``**kwargs`` on every registered callback."""
        # A falsy ``_callbacks`` (``None`` or an empty list) means there is nothing to dispatch to.
        for callback in self._callbacks or ():
            hook = getattr(callback, method_name)
            hook(*args, **kwargs)

    def _forward(self, hook_name: str, payload: Any, running_stage: RunningStage) -> None:
        """Relay one hook invocation through ``run_for_all_callbacks``."""
        self.run_for_all_callbacks(payload, running_stage, method_name=hook_name)

    def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
        self._forward("on_load_sample", sample, running_stage)

    def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
        self._forward("on_pre_tensor_transform", sample, running_stage)

    def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
        self._forward("on_to_tensor_transform", sample, running_stage)

    def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
        self._forward("on_post_tensor_transform", sample, running_stage)

    def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
        self._forward("on_per_batch_transform", batch, running_stage)

    def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        self._forward("on_collate", batch, running_stage)

    def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None:
        self._forward("on_per_sample_transform_on_device", sample, running_stage)

    def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
        self._forward("on_per_batch_transform_on_device", batch, running_stage)
Loading

0 comments on commit 6a4948a

Please sign in to comment.