This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 211
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* wip * add base_viz + new features for DataPipeline * update * resolve flake8 * update * resolve tests * update * wip * update * resolve doc * resolve doc * update doc * update * update * update * convert to staticmethod * initial visualisation implementation * implement test case using Kornia transforms * update on comments * resolve bug * update * update * add test * update * resolve tests * resolve flake8 * update * update * update * resolve test Co-authored-by: Edgar Riba <[email protected]>
- Loading branch information
Showing
17 changed files
with
650 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
from contextlib import contextmanager | ||
from typing import Any, Dict, List, Sequence | ||
|
||
from pytorch_lightning.trainer.states import RunningStage | ||
from torch import Tensor | ||
|
||
from flash.core.utils import _is_overriden | ||
from flash.data.callback import FlashCallback | ||
from flash.data.process import Preprocess | ||
from flash.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX | ||
|
||
|
||
class BaseViz(FlashCallback):
    """
    This class is used to profile ``Preprocess`` hook outputs and visualize the data transformations.
    It is disabled by default.

    Captured outputs are accumulated per running stage and per hook name:

        batches: Dict = {"train": {"to_tensor_transform": [], ...}, ...}
    """

    def __init__(self, enabled: bool = False):
        # One store per stage prefix ("train", "val", ...); each store maps a
        # preprocess hook name to the list of outputs captured from that hook.
        self.batches = {k: {} for k in _STAGES_PREFIX.values()}
        self.enabled = enabled
        self._preprocess = None

    def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None:
        """Append ``data`` to the ``fn_name`` bucket of the ``running_stage`` store."""
        store = self.batches[_STAGES_PREFIX[running_stage]]
        store.setdefault(fn_name, [])
        store[fn_name].append(data)

    def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
        self._store(sample, "load_sample", running_stage)

    def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
        self._store(sample, "pre_tensor_transform", running_stage)

    def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
        self._store(sample, "to_tensor_transform", running_stage)

    def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
        self._store(sample, "post_tensor_transform", running_stage)

    def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
        self._store(batch, "per_batch_transform", running_stage)

    def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        self._store(batch, "collate", running_stage)

    def on_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None:
        self._store(samples, "per_sample_transform_on_device", running_stage)

    def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
        self._store(batch, "per_batch_transform_on_device", running_stage)

    @contextmanager
    def enable(self):
        """Context manager that temporarily enables profiling.

        ``enabled`` is restored in a ``finally`` block so an exception raised
        inside the ``with`` body cannot leave profiling switched on.
        """
        self.enabled = True
        try:
            yield
        finally:
            self.enabled = False

    def attach_to_datamodule(self, datamodule) -> None:
        """Register this visualizer on ``datamodule`` as its ``viz`` attribute."""
        datamodule.viz = self

    def attach_to_preprocess(self, preprocess: Preprocess) -> None:
        """Attach this visualizer to ``preprocess``.

        NOTE: this replaces any callbacks previously set on the preprocess.
        """
        preprocess.callbacks = [self]
        self._preprocess = preprocess

    def show(self, batch: Dict[str, Any], running_stage: RunningStage) -> None:
        """
        This function is a hook for users to override with their visualization on a batch.

        For each preprocess hook whose ``show_<hook>`` method is overridden in a
        subclass, dispatch the captured data for that hook to the override.
        """
        for func_name in _PREPROCESS_FUNCS:
            hook_name = f"show_{func_name}"
            if _is_overriden(hook_name, self, BaseViz):
                getattr(self, hook_name)(batch[func_name], running_stage)

    def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
        pass

    def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        pass

    def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        pass

    def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
        pass

    def show_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        pass

    def show_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
        pass

    def show_per_sample_transform_on_device(self, samples: Sequence, running_stage: RunningStage) -> None:
        pass

    def show_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
        pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from typing import Any, List, Sequence | ||
|
||
from pytorch_lightning.callbacks import Callback | ||
from pytorch_lightning.trainer.states import RunningStage | ||
from torch import Tensor | ||
|
||
|
||
class FlashCallback(Callback):
    """Base callback exposing hooks for each ``Preprocess`` transform stage.

    Subclasses override the hooks of interest; every hook receives the data
    produced by the corresponding preprocess function together with the current
    ``RunningStage``. All hooks are no-ops by default.
    """

    def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
        """Called once a sample has been loaded using ``load_sample``."""

    def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
        """Called once ``pre_tensor_transform`` has been applied to a sample."""

    def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
        """Called once ``to_tensor_transform`` has been applied to a sample."""

    def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
        """Called once ``post_tensor_transform`` has been applied to a sample."""

    def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
        """Called once ``per_batch_transform`` has been applied to a batch."""

    def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        """Called once ``collate`` has been applied to a sequence of samples."""

    def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None:
        """Called once ``per_sample_transform_on_device`` has been applied to a sample."""

    def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
        """Called once ``per_batch_transform_on_device`` has been applied to a batch."""
|
||
|
||
class ControlFlow(FlashCallback):
    """Fan-out callback: forwards every hook call to a list of callbacks."""

    def __init__(self, callbacks: List[FlashCallback]):
        self._callbacks = callbacks

    def run_for_all_callbacks(self, *args, method_name: str, **kwargs):
        """Invoke ``method_name`` with the given arguments on every registered callback."""
        if not self._callbacks:
            return
        for callback in self._callbacks:
            getattr(callback, method_name)(*args, **kwargs)

    def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(sample, running_stage, method_name="on_load_sample")

    def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(sample, running_stage, method_name="on_pre_tensor_transform")

    def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(sample, running_stage, method_name="on_to_tensor_transform")

    def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(sample, running_stage, method_name="on_post_tensor_transform")

    def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform")

    def on_collate(self, batch: Sequence, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(batch, running_stage, method_name="on_collate")

    def on_per_sample_transform_on_device(self, sample: Any, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(sample, running_stage, method_name="on_per_sample_transform_on_device")

    def on_per_batch_transform_on_device(self, batch: Any, running_stage: RunningStage) -> None:
        self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform_on_device")
Oops, something went wrong.