This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Data Pipeline Refactor: Improve new DataModule (#1029)
tchaton authored Dec 6, 2021
1 parent e73c420 commit d046de4
Showing 8 changed files with 402 additions and 175 deletions.
2 changes: 2 additions & 0 deletions flash/core/data/data_pipeline.py
@@ -285,6 +285,7 @@ def _create_collate_input_transform_processors(
self._identity if is_serving else per_sample_transform,
getattr(input_transform, func_names["per_batch_transform"]),
stage,
callbacks=input_transform.callbacks,
)
worker_input_transform_processor._original_collate_fn = original_collate_fn
device_input_transform_processor = _InputTransformProcessor(
@@ -295,6 +296,7 @@ def _create_collate_input_transform_processors(
stage,
apply_per_sample_transform=device_collate_fn != self._identity,
on_device=True,
callbacks=input_transform.callbacks,
)
return deserialize_processor, worker_input_transform_processor, device_input_transform_processor

157 changes: 85 additions & 72 deletions flash/core/data/input_transform.py
@@ -14,13 +14,14 @@
import inspect
from dataclasses import dataclass
from functools import partial, wraps
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data._utils.collate import default_collate

import flash
from flash.core.data.callback import FlashCallback
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import _InputTransformProcessor
from flash.core.data.properties import Properties
@@ -31,7 +32,7 @@
from flash.core.utilities.stages import RunningStage

INPUT_TRANSFORM_TYPE = Optional[
Union["InputTransform", Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str]]
Union[Type["InputTransform"], Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str]]
]


@@ -86,20 +87,28 @@ def __repr__(self):
return format_string


class InputTransformState(dict):
pass


@dataclass
class InputTransform(Properties):

running_stage: RunningStage
data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None

def __post_init__(self):
transform_kwargs = {
k: v for k, v in self.__dict__.items() if k not in ("_running_stage", "data_pipeline_state")
}
# used to keep track of provided transforms
self._collate_in_worker_from_transform: Optional[bool] = None
self._transform = None
self._transform = self._check_transforms(self._resolve_transforms(self.running_stage), self.running_stage)
self.callbacks = []
# Hack
Properties.__init__(self, data_pipeline_state=self.data_pipeline_state, running_stage=self.running_stage)
self.set_state(InputTransformState(**transform_kwargs))
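Any extra dataclass field declared on an InputTransform subclass is now captured into an InputTransformState dict and registered via set_state, so the rest of the pipeline can look the transform's configuration up later. A minimal sketch of that behaviour, assuming Flash at this commit is installed; MyInputTransform and its image_size field are invented for illustration:

from dataclasses import dataclass
from typing import Callable

from flash.core.data.input_transform import InputTransform, InputTransformState
from flash.core.utilities.stages import RunningStage


@dataclass
class MyInputTransform(InputTransform):
    # Hypothetical extra field: everything except the running stage and the
    # data pipeline state ends up in the captured InputTransformState.
    image_size: int = 224

    def per_sample_transform(self) -> Callable:
        return lambda sample: sample


transform = MyInputTransform(running_stage=RunningStage.TRAINING, image_size=196)
state = transform.get_state(InputTransformState)
print(state)  # expected to contain the user-provided kwargs, e.g. image_size=196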

@property
def current_transform(self) -> Callable:
@@ -115,16 +124,6 @@ def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]:
"transform": self._transform,
}

@property
def dataloader_collate_fn(self):
"""Generate the function to be injected within the DataLoader as the collate_fn."""
return self._create_collate_input_transform_processors()[0]

@property
def on_after_batch_transfer_fn(self):
"""Generate the function to be injected after the on_after_batch_transfer from the LightningModule."""
return self._create_collate_input_transform_processors()[1]

########################
# PER SAMPLE TRANSFORM #
########################
@@ -930,70 +929,16 @@ def _get_transform(self, transform: Dict[str, Callable]) -> Callable:
return transform[self.current_fn]
return self._identity

def __repr__(self) -> str:
return f"{self.__class__.__name__}(running_stage={self.running_stage}, transform={self._transform})"
def __str__(self) -> str:
state = self.get_state(InputTransformState)
return (
f"{self.__class__.__name__}("
+ f"running_stage={self.running_stage}, state: {state}, transform={self._transform})"
)

def __getitem__(self, placement: InputTransformPlacement) -> Callable:
return self._transform[placement]

def _make_collates(self, on_device: bool, collate: Callable) -> Tuple[Callable, Callable]:
if on_device:
return self._identity, collate
return collate, self._identity

def _create_collate_input_transform_processors(self) -> Tuple[Any]:
from flash.core.data.data_pipeline import DataPipeline

prefix: str = _STAGES_PREFIX[self.running_stage]

per_batch_transform_overridden: bool = DataPipeline._is_overridden_recursive(
"per_batch_transform", self, InputTransform, prefix=prefix
)

per_sample_transform_on_device_overridden: bool = DataPipeline._is_overridden_recursive(
"per_sample_transform_on_device", self, InputTransform, prefix=prefix
)

is_per_overridden = per_batch_transform_overridden and per_sample_transform_on_device_overridden
if self._collate_in_worker_from_transform is None and is_per_overridden:
raise MisconfigurationException(
f"{self.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` "
f"are mutually exclusive for stage {self.running_stage}"
)

if isinstance(self._collate_in_worker_from_transform, bool):
worker_collate_fn, device_collate_fn = self._make_collates(
not self._collate_in_worker_from_transform, self._collate
)
else:
worker_collate_fn, device_collate_fn = self._make_collates(
per_sample_transform_on_device_overridden, self._collate
)

worker_collate_fn = (
worker_collate_fn.collate_fn
if isinstance(worker_collate_fn, _InputTransformProcessor)
else worker_collate_fn
)

worker_input_transform_processor = _InputTransformProcessor(
self,
worker_collate_fn,
self._per_sample_transform,
self._per_batch_transform,
self.running_stage,
)
device_input_transform_processor = _InputTransformProcessor(
self,
device_collate_fn,
self._per_sample_transform_on_device,
self._per_batch_transform_on_device,
self.running_stage,
apply_per_sample_transform=device_collate_fn != self._identity,
on_device=True,
)
return worker_input_transform_processor, device_input_transform_processor


@dataclass
class LambdaInputTransform(InputTransform):
@@ -1037,6 +982,9 @@ def create_transform(
transform._data_pipeline_state = data_pipeline_state
return transform

if inspect.isclass(transform) and issubclass(transform, InputTransform):
return transform(running_stage=running_stage, data_pipeline_state=data_pipeline_state)

if isinstance(transform, Callable):
return LambdaInputTransform(
running_stage=running_stage, transform=transform, data_pipeline_state=data_pipeline_state
@@ -1051,3 +999,68 @@ def create_transform(
return None

raise MisconfigurationException(f"The format for the transform isn't correct. Found {transform}")


def _make_collates(input_transform: "InputTransform", on_device: bool, collate: Callable) -> Tuple[Callable, Callable]:
if on_device:
return input_transform._identity, collate
return collate, input_transform._identity
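The now module-level _make_collates only decides on which side of the host-to-device transfer collation happens: either the DataLoader worker collates and the device side is a no-op, or the other way around. A small illustration, assuming an identity LambdaInputTransform:

from torch.utils.data._utils.collate import default_collate

from flash.core.data.input_transform import LambdaInputTransform, _make_collates
from flash.core.utilities.stages import RunningStage

it = LambdaInputTransform(running_stage=RunningStage.TRAINING, transform=lambda x: x)

worker_collate, device_collate = _make_collates(it, False, default_collate)
# -> (default_collate, identity): samples are collated inside the workers.

worker_collate, device_collate = _make_collates(it, True, default_collate)
# -> (identity, default_collate): collation runs only after the device transfer.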


def _create_collate_input_transform_processors(
input_transform: "InputTransform", callbacks: List[FlashCallback]
) -> Tuple[_InputTransformProcessor, _InputTransformProcessor]:
"""This utility is used to create the 2 `_InputTransformProcessor` objects which contain the transforms used as
the DataLoader `collate_fn` and the DataModule `on_after_batch_transfer` hook."""

from flash.core.data.data_pipeline import DataPipeline

prefix: str = _STAGES_PREFIX[input_transform.running_stage]

per_batch_transform_overridden: bool = DataPipeline._is_overridden_recursive(
"per_batch_transform", input_transform, InputTransform, prefix=prefix
)

per_sample_transform_on_device_overridden: bool = DataPipeline._is_overridden_recursive(
"per_sample_transform_on_device", input_transform, InputTransform, prefix=prefix
)

is_per_overridden = per_batch_transform_overridden and per_sample_transform_on_device_overridden
if input_transform._collate_in_worker_from_transform is None and is_per_overridden:
raise MisconfigurationException(
f"{input_transform.__class__.__name__}: `per_batch_transform` and `per_sample_transform_on_device` "
f"are mutually exclusive for stage {input_transform.running_stage}"
)

if isinstance(input_transform._collate_in_worker_from_transform, bool):
worker_collate_fn, device_collate_fn = _make_collates(
input_transform, not input_transform._collate_in_worker_from_transform, input_transform._collate
)
else:
worker_collate_fn, device_collate_fn = _make_collates(
input_transform, per_sample_transform_on_device_overridden, input_transform._collate
)

worker_collate_fn = (
worker_collate_fn.collate_fn if isinstance(worker_collate_fn, _InputTransformProcessor) else worker_collate_fn
)

worker_input_transform_processor = _InputTransformProcessor(
input_transform,
worker_collate_fn,
input_transform._per_sample_transform,
input_transform._per_batch_transform,
input_transform.running_stage,
callbacks=callbacks,
)
device_input_transform_processor = _InputTransformProcessor(
input_transform,
device_collate_fn,
input_transform._per_sample_transform_on_device,
input_transform._per_batch_transform_on_device,
input_transform.running_stage,
apply_per_sample_transform=device_collate_fn != input_transform._identity,
on_device=True,
callbacks=callbacks,
)
return worker_input_transform_processor, device_input_transform_processor
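As the docstring above says, element 0 of the returned pair becomes the DataLoader collate_fn and element 1 is meant to run from the DataModule's on_after_batch_transfer hook once the batch is on device. A minimal end-to-end sketch, assuming an identity LambdaInputTransform over raw tensors is an acceptable input for both processors:

import torch
from torch.utils.data import DataLoader

from flash.core.data.input_transform import LambdaInputTransform, _create_collate_input_transform_processors
from flash.core.utilities.stages import RunningStage

transform = LambdaInputTransform(running_stage=RunningStage.TRAINING, transform=lambda x: x)
worker_processor, device_processor = _create_collate_input_transform_processors(transform, [])

# Element 0 runs inside the DataLoader workers ...
loader = DataLoader([torch.rand(3) for _ in range(8)], batch_size=4, collate_fn=worker_processor)
batch = next(iter(loader))

# ... element 1 is intended for `on_after_batch_transfer`, i.e. once the batch
# has been moved to the training device.
batch = device_processor(batch)

The _create_dataloader_collate_fn and _create_on_after_batch_transfer_fn helpers added to InputBase below simply return these two elements.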
27 changes: 15 additions & 12 deletions flash/core/data/io/input_base.py
@@ -23,6 +23,7 @@
from torch.utils.data import Dataset

import flash
from flash.core.data.callback import FlashCallback
from flash.core.data.properties import Properties
from flash.core.registry import FlashRegistry
from flash.core.utilities.stages import RunningStage
@@ -138,6 +139,20 @@ def __init__(
if len(args) >= 1 and args[0] is not None:
self.data = self._call_load_data(*args, **kwargs)

def _create_dataloader_collate_fn(self, callbacks: List[FlashCallback]) -> Optional[Callable]:
from flash.core.data.input_transform import _create_collate_input_transform_processors

if not self.transform:
return
return _create_collate_input_transform_processors(self.transform, callbacks)[0]

def _create_on_after_batch_transfer_fn(self, callbacks: List[FlashCallback]) -> Optional[Callable]:
from flash.core.data.input_transform import _create_collate_input_transform_processors

if not self.transform:
return
return _create_collate_input_transform_processors(self.transform, callbacks)[1]

def _call_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable]:
from flash.core.data.data_pipeline import DataPipeline

@@ -236,18 +251,6 @@ def register_input_transform(
)
cls.input_transforms_registry(fn=fn, name=enum)

@property
def dataloader_collate_fn(self) -> Optional[Callable]:
if self.transform:
self.transform.running_stage = self.running_stage
return self.transform.dataloader_collate_fn

@property
def on_after_batch_transfer_fn(self) -> Optional[Callable]:
if self.transform:
self.transform.running_stage = self.running_stage
return self.transform.on_after_batch_transfer_fn


class Input(InputBase, Dataset):
def __getitem__(self, index: int) -> Any:
3 changes: 2 additions & 1 deletion flash/core/data/io/input_transform.py
@@ -522,10 +522,11 @@ def __init__(
stage: RunningStage,
apply_per_sample_transform: bool = True,
on_device: bool = False,
callbacks: Optional[List[FlashCallback]] = None,
):
super().__init__()
self.input_transform = input_transform
self.callback = ControlFlow(self.input_transform.callbacks)
self.callback = ControlFlow(callbacks or [])
self.collate_fn = convert_to_modules(collate_fn)
self.per_sample_transform = convert_to_modules(per_sample_transform)
self.per_batch_transform = convert_to_modules(per_batch_transform)
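Because the processor no longer reads input_transform.callbacks, whoever builds it now decides which FlashCallbacks observe the pipeline. A hedged sketch of constructing a worker-side processor directly with the new callbacks argument; PrintBatchCallback is invented, and the on_collate hook name is assumed from FlashCallback:

from torch.utils.data._utils.collate import default_collate

from flash.core.data.callback import FlashCallback
from flash.core.data.input_transform import LambdaInputTransform
from flash.core.data.io.input_transform import _InputTransformProcessor
from flash.core.utilities.stages import RunningStage


class PrintBatchCallback(FlashCallback):
    def on_collate(self, batch, running_stage):
        print(f"[{running_stage}] collated batch: {batch}")


transform = LambdaInputTransform(running_stage=RunningStage.TRAINING, transform=lambda x: x)

# Callbacks are injected explicitly instead of being read from the transform.
processor = _InputTransformProcessor(
    transform,
    default_collate,
    transform._per_sample_transform,
    transform._per_batch_transform,
    RunningStage.TRAINING,
    callbacks=[PrintBatchCallback()],
)
print(processor([1, 2, 3]))  # the callback should see the collated batch tensor([1, 2, 3])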
(Diffs for the remaining four changed files are not shown.)
