
Rename PreprocessTransform to InputTransform (#868)
tchaton authored Oct 14, 2021
1 parent 4b24b44 commit b22e786
Showing 9 changed files with 172 additions and 180 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))


+- Changed `PreprocessTransform` to `InputTransform` ([#868](https://github.com/PyTorchLightning/lightning-flash/pull/868))


### Fixed

- Fixed a bug where additional kwargs (e.g. sampler) passed to tabular data would be ignored ([#792](https://github.com/PyTorchLightning/lightning-flash/pull/792))
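For downstream code, the rename is a one-line swap at the import site; the `flash/__init__.py` diff below shows the new name exported at the top level. A minimal before/after sketch:

```python
# Before this commit:
# from flash import PreprocessTransform

# After this commit, the same class is exported under its new name:
from flash import InputTransform
```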
6 changes: 3 additions & 3 deletions flash/__init__.py
@@ -22,8 +22,8 @@
from flash.core.data.callback import FlashCallback
from flash.core.data.data_module import DataModule # noqa: E402
from flash.core.data.data_source import DataSource
-from flash.core.data.datasets import FlashDataset, FlashIterableDataset # noqa: E402
-from flash.core.data.preprocess_transform import PreprocessTransform # noqa: E402
+from flash.core.data.datasets import FlashDataset, FlashIterableDataset
+from flash.core.data.input_transform import InputTransform
from flash.core.data.process import Postprocess, Preprocess, Serializer
from flash.core.model import Task # noqa: E402
from flash.core.trainer import Trainer # noqa: E402
@@ -45,7 +45,7 @@
"FlashDataset",
"FlashIterableDataset",
"Preprocess",
"PreprocessTransform",
"InputTransform",
"Postprocess",
"Serializer",
"Task",
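As the diff to the renamed `flash/core/data/input_transform.py` below shows, `from_transform` wraps a bare callable as a per-sample transform, and per-stage helpers such as `from_train_transform` build on it. A minimal sketch (the identity function is a placeholder, not part of this diff):

```python
from flash.core.data.input_transform import InputTransform

# A bare callable becomes the PER_SAMPLE_TRANSFORM for the training stage
# (see the Callable branch of `from_transform` in the diff below).
train_transform = InputTransform.from_train_transform(lambda sample: sample)
print(train_transform)  # InputTransform(running_stage=..., transform=...)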
32 changes: 17 additions & 15 deletions flash/core/data/datasets.py
@@ -20,7 +20,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import Dataset, IterableDataset

-from flash.core.data.preprocess_transform import PREPROCESS_TRANSFORM_TYPE, PreprocessTransform
+from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.data.properties import Properties
from flash.core.registry import FlashRegistry

@@ -35,8 +35,8 @@ class BaseDataset(Properties):

DATASET_KEY = "dataset"

-transforms_registry: Optional[FlashRegistry] = FlashRegistry("transforms")
-transform: Optional[PreprocessTransform] = None
+input_transforms_registry: Optional[FlashRegistry] = FlashRegistry("transforms")
+transform: Optional[InputTransform] = None

@abstractmethod
def load_data(self, data: Any) -> Union[Iterable, Mapping]:
@@ -49,12 +49,12 @@ def load_data(self, data: Any) -> Union[Iterable, Mapping]:
def load_sample(self, data: Any) -> Any:
"""The `load_sample` hook contains the logic to load a single sample."""

-def __init__(self, running_stage: RunningStage, transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None) -> None:
+def __init__(self, running_stage: RunningStage, transform: Optional[INPUT_TRANSFORM_TYPE] = None) -> None:
super().__init__()
self.running_stage = running_stage
if transform:
-self.transform = PreprocessTransform.from_transform(
-transform, running_stage=running_stage, transforms_registry=self.transforms_registry
+self.transform = InputTransform.from_transform(
+transform, running_stage=running_stage, input_transforms_registry=self.input_transforms_registry
)

def pass_args_to_load_data(
@@ -110,7 +110,7 @@ def from_data(
cls,
*load_data_args,
running_stage: Optional[RunningStage] = None,
-transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+transform: Optional[INPUT_TRANSFORM_TYPE] = None,
**dataset_kwargs: Any,
) -> "BaseDataset":
if not running_stage:
@@ -126,7 +126,7 @@
def from_train_data(
cls,
*load_data_args,
-transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+transform: Optional[INPUT_TRANSFORM_TYPE] = None,
**dataset_kwargs: Any,
) -> "BaseDataset":
return cls.from_data(
@@ -137,7 +137,7 @@ def from_train_data(
def from_val_data(
cls,
*load_data_args,
-transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+transform: Optional[INPUT_TRANSFORM_TYPE] = None,
**dataset_kwargs: Any,
) -> "BaseDataset":
return cls.from_data(
@@ -148,7 +148,7 @@ def from_val_data(
def from_test_data(
cls,
*load_data_args,
-transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+transform: Optional[INPUT_TRANSFORM_TYPE] = None,
**dataset_kwargs: Any,
) -> "BaseDataset":
return cls.from_data(*load_data_args, running_stage=RunningStage.TESTING, transform=transform, **dataset_kwargs)
@@ -157,20 +157,22 @@ def from_test_data(
def from_predict_data(
cls,
*load_data_args,
-transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+transform: Optional[INPUT_TRANSFORM_TYPE] = None,
**dataset_kwargs: Any,
) -> "BaseDataset":
return cls.from_data(
*load_data_args, running_stage=RunningStage.PREDICTING, transform=transform, **dataset_kwargs
)

@classmethod
-def register_transform(cls, enum: Union[LightningEnum, str], fn: Union[Type[PreprocessTransform], partial]) -> None:
-if cls.transforms_registry is None:
+def register_input_transform(
+cls, enum: Union[LightningEnum, str], fn: Union[Type[InputTransform], partial]
+) -> None:
+if cls.input_transforms_registry is None:
raise MisconfigurationException(
-"The class attribute `transforms_registry` should be set as a class attribute. "
+"The class attribute `input_transforms_registry` should be set as a class attribute. "
)
-cls.transforms_registry(fn=fn, name=enum)
+cls.input_transforms_registry(fn=fn, name=enum)

def resolve_functions(self):
raise NotImplementedError
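Taken together, the `datasets.py` changes rename the registry surface (`register_transform` becomes `register_input_transform`, `transforms_registry` becomes `input_transforms_registry`) without changing its mechanics. A hedged, illustration-only sketch; `MyDataset` and `MyInputTransform` are hypothetical names, assuming `FlashDataset` can be subclassed this way:

```python
from typing import Any, List

from flash.core.data.datasets import FlashDataset
from flash.core.data.input_transform import InputTransform


class MyInputTransform(InputTransform):
    """Hypothetical transform class, registered below by name."""


class MyDataset(FlashDataset):
    """Hypothetical dataset: loads a list and returns samples unchanged."""

    def load_data(self, data: List[Any]) -> List[Any]:
        return data

    def load_sample(self, sample: Any) -> Any:
        return sample


# `register_transform` is now `register_input_transform` ...
MyDataset.register_input_transform("my_transform", MyInputTransform)

# ... and `transform` still accepts a (name, kwargs) registry tuple.
dataset = MyDataset.from_train_data([1, 2, 3], transform=("my_transform", {}))
```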
flash/core/data/{preprocess_transform.py → input_transform.py}
@@ -27,12 +27,12 @@
from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX
from flash.core.registry import FlashRegistry

-PREPROCESS_TRANSFORM_TYPE = Optional[
-Union["PreprocessTransform", Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str]]
+INPUT_TRANSFORM_TYPE = Optional[
+Union["InputTransform", Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str]]
]


-class PreprocessTransformPlacement(LightningEnum):
+class InputTransformPlacement(LightningEnum):

PER_SAMPLE_TRANSFORM = "per_sample_transform"
PER_BATCH_TRANSFORM = "per_batch_transform"
@@ -52,8 +52,8 @@ def wrapper(self, *args, **kwargs) -> Any:
return wrapper


-class PreprocessTransform(Properties):
-def configure_transforms(self, *args, **kwargs) -> Dict[PreprocessTransformPlacement, Callable]:
+class InputTransform(Properties):
+def configure_transforms(self, *args, **kwargs) -> Dict[InputTransformPlacement, Callable]:
"""The default transforms to use.
Will be overridden by transforms passed to the ``__init__``.
@@ -184,21 +184,21 @@ def per_batch_transform_on_device(self, batch: Any) -> Any:
@classmethod
def from_transform(
cls,
-transform: PREPROCESS_TRANSFORM_TYPE,
+transform: INPUT_TRANSFORM_TYPE,
running_stage: RunningStage,
-transforms_registry: Optional[FlashRegistry] = None,
-) -> Optional["PreprocessTransform"]:
+input_transforms_registry: Optional[FlashRegistry] = None,
+) -> Optional["InputTransform"]:

-if isinstance(transform, PreprocessTransform):
+if isinstance(transform, InputTransform):
transform.running_stage = running_stage
return transform

if isinstance(transform, Callable):
-return cls(running_stage, {PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: transform})
+return cls(running_stage, {InputTransformPlacement.PER_SAMPLE_TRANSFORM: transform})

if isinstance(transform, tuple) or isinstance(transform, (LightningEnum, str)):
-enum, transform_kwargs = cls._sanitize_registry_transform(transform, transforms_registry)
-transform_cls = transforms_registry.get(enum)
+enum, transform_kwargs = cls._sanitize_registry_transform(transform, input_transforms_registry)
+transform_cls = input_transforms_registry.get(enum)
return transform_cls(running_stage, transform=None, **transform_kwargs)

if not transform:
@@ -209,60 +209,66 @@ def from_transform(
@classmethod
def from_train_transform(
cls,
-transform: PREPROCESS_TRANSFORM_TYPE,
-transforms_registry: Optional[FlashRegistry] = None,
-) -> Optional["PreprocessTransform"]:
+transform: INPUT_TRANSFORM_TYPE,
+input_transforms_registry: Optional[FlashRegistry] = None,
+) -> Optional["InputTransform"]:
return cls.from_transform(
-transform=transform, running_stage=RunningStage.TRAINING, transforms_registry=transforms_registry
+transform=transform,
+running_stage=RunningStage.TRAINING,
+input_transforms_registry=input_transforms_registry,
)

@classmethod
def from_val_transform(
cls,
-transform: PREPROCESS_TRANSFORM_TYPE,
-transforms_registry: Optional[FlashRegistry] = None,
-) -> Optional["PreprocessTransform"]:
+transform: INPUT_TRANSFORM_TYPE,
+input_transforms_registry: Optional[FlashRegistry] = None,
+) -> Optional["InputTransform"]:
return cls.from_transform(
-transform=transform, running_stage=RunningStage.VALIDATING, transforms_registry=transforms_registry
+transform=transform,
+running_stage=RunningStage.VALIDATING,
+input_transforms_registry=input_transforms_registry,
)

@classmethod
def from_test_transform(
cls,
-transform: PREPROCESS_TRANSFORM_TYPE,
-transforms_registry: Optional[FlashRegistry] = None,
-) -> Optional["PreprocessTransform"]:
+transform: INPUT_TRANSFORM_TYPE,
+input_transforms_registry: Optional[FlashRegistry] = None,
+) -> Optional["InputTransform"]:
return cls.from_transform(
-transform=transform, running_stage=RunningStage.TESTING, transforms_registry=transforms_registry
+transform=transform, running_stage=RunningStage.TESTING, input_transforms_registry=input_transforms_registry
)

@classmethod
def from_predict_transform(
cls,
-transform: PREPROCESS_TRANSFORM_TYPE,
-transforms_registry: Optional[FlashRegistry] = None,
-) -> Optional["PreprocessTransform"]:
+transform: INPUT_TRANSFORM_TYPE,
+input_transforms_registry: Optional[FlashRegistry] = None,
+) -> Optional["InputTransform"]:
return cls.from_transform(
-transform=transform, running_stage=RunningStage.PREDICTING, transforms_registry=transforms_registry
+transform=transform,
+running_stage=RunningStage.PREDICTING,
+input_transforms_registry=input_transforms_registry,
)

def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
from flash.core.data.data_pipeline import DataPipeline

resolved_function = getattr(
self,
-DataPipeline._resolve_function_hierarchy("configure_transforms", self, running_stage, PreprocessTransform),
+DataPipeline._resolve_function_hierarchy("configure_transforms", self, running_stage, InputTransform),
)
params = inspect.signature(resolved_function).parameters
transforms_out: Optional[Dict[str, Callable]] = resolved_function(
**{k: v for k, v in self._transform_kwargs.items() if k in params}
)

transforms_out = transforms_out or {}
-for placement in PreprocessTransformPlacement:
+for placement in InputTransformPlacement:
transform_name = f"configure_{placement.value}"
resolved_function = getattr(
-self, DataPipeline._resolve_function_hierarchy(transform_name, self, running_stage, PreprocessTransform)
+self, DataPipeline._resolve_function_hierarchy(transform_name, self, running_stage, InputTransform)
)
params = inspect.signature(resolved_function).parameters
transforms: Optional[Dict[str, Callable]] = resolved_function(
@@ -278,7 +284,7 @@ def _check_transforms(
if transform is None:
return transform

-keys_diff = set(transform.keys()).difference([v for v in PreprocessTransformPlacement])
+keys_diff = set(transform.keys()).difference([v for v in InputTransformPlacement])

if len(keys_diff) > 0:
raise MisconfigurationException(
@@ -315,11 +321,11 @@ def _get_transform(self, transform: Dict[str, Callable]) -> Callable:

@classmethod
def _sanitize_registry_transform(
-cls, transform: Tuple[Union[LightningEnum, str], Any], transforms_registry: Optional[FlashRegistry]
+cls, transform: Tuple[Union[LightningEnum, str], Any], input_transforms_registry: Optional[FlashRegistry]
) -> Tuple[Union[LightningEnum, str], Dict]:
msg = "The transform should be provided as a tuple with the following types (LightningEnum, Dict[str, Any]) "
msg += "when requesting transform from the registry."
-if not transforms_registry:
+if not input_transforms_registry:
raise MisconfigurationException("You requested a transform from the registry, but it is empty.")
if isinstance(transform, tuple) and len(transform) > 2:
raise MisconfigurationException(msg)
@@ -337,7 +343,7 @@ def _sanitize_registry_transform(
def __repr__(self) -> str:
return f"{self.__class__.__name__}(running_stage={self.running_stage}, transform={self.transform})"

-def __getitem__(self, placement: PreprocessTransformPlacement) -> Callable:
+def __getitem__(self, placement: InputTransformPlacement) -> Callable:
return self.transform[placement]

def _make_collates(self, on_device: bool, collate: Callable) -> Tuple[Callable, Callable]:
@@ -359,18 +365,18 @@ def _create_collate_preprocessors(self) -> Tuple[Any]:
prefix: str = _STAGES_PREFIX[self.running_stage]

func_names: Dict[str, str] = {
-k: DataPipeline._resolve_function_hierarchy(k, self, self.running_stage, PreprocessTransform)
-for k in [v.value for v in PreprocessTransformPlacement]
+k: DataPipeline._resolve_function_hierarchy(k, self, self.running_stage, InputTransform)
+for k in [v.value for v in InputTransformPlacement]
}

collate_fn: Callable = getattr(self, func_names["collate"])

per_batch_transform_overriden: bool = DataPipeline._is_overriden_recursive(
"per_batch_transform", self, PreprocessTransform, prefix=prefix
"per_batch_transform", self, InputTransform, prefix=prefix
)

per_sample_transform_on_device_overriden: bool = DataPipeline._is_overriden_recursive(
"per_sample_transform_on_device", self, PreprocessTransform, prefix=prefix
"per_sample_transform_on_device", self, InputTransform, prefix=prefix
)

is_per_overriden = per_batch_transform_overriden and per_sample_transform_on_device_overriden
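A hedged sketch of the `configure_transforms` contract shown above: a subclass returns a dict keyed by `InputTransformPlacement`, and, per the registry branch of `from_transform`, can be instantiated as `cls(running_stage, transform=None)`. The identity callables and the `RunningStage` import path are assumptions, not part of this diff:

```python
from typing import Callable, Dict

from pytorch_lightning.trainer.states import RunningStage  # assumed import path

from flash.core.data.input_transform import InputTransform, InputTransformPlacement


class MyInputTransform(InputTransform):
    """Hypothetical subclass supplying default transforms."""

    def configure_transforms(self) -> Dict[InputTransformPlacement, Callable]:
        return {
            InputTransformPlacement.PER_SAMPLE_TRANSFORM: lambda sample: sample,
            InputTransformPlacement.PER_BATCH_TRANSFORM: lambda batch: batch,
        }


# Mirrors `transform_cls(running_stage, transform=None, **transform_kwargs)`
# from the registry branch of `from_transform`.
transform = MyInputTransform(RunningStage.TRAINING, transform=None)
```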
12 changes: 6 additions & 6 deletions flash/core/data/new_data_module.py
@@ -29,7 +29,7 @@
from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DefaultPreprocess, Postprocess
from flash.core.data.datasets import BaseDataset
-from flash.core.data.preprocess_transform import PREPROCESS_TRANSFORM_TYPE, PreprocessTransform
+from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE

@@ -267,10 +267,10 @@ def create_flash_datasets(
val_data: Optional[Any] = None,
test_data: Optional[Any] = None,
predict_data: Optional[Any] = None,
-train_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
-val_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
-test_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
-predict_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+train_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
+val_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
+test_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
+predict_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
**flash_dataset_kwargs,
) -> Tuple[Optional[BaseDataset]]:
cls._verify_flash_dataset_enum(enum)
@@ -311,7 +311,7 @@ def _create_flash_dataset(
flash_dataset_cls,
*load_data_args,
running_stage: RunningStage,
-transform: Optional[PreprocessTransform],
+transform: Optional[InputTransform],
**kwargs,
) -> Optional[BaseDataset]:
if load_data_args[0] is not None:
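Each `*_transform` keyword above now accepts the same `INPUT_TRANSFORM_TYPE` union defined earlier in this diff. A sketch enumerating the four accepted forms; the registry name is a hypothetical entry, not one shipped by this commit:

```python
from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform

# 1. An InputTransform instance.
as_instance: INPUT_TRANSFORM_TYPE = InputTransform.from_train_transform(lambda x: x)

# 2. A bare callable (wrapped as the per-sample transform).
as_callable: INPUT_TRANSFORM_TYPE = lambda x: x

# 3. A (name, kwargs) tuple resolved against the input transforms registry.
as_registry_tuple: INPUT_TRANSFORM_TYPE = ("my_transform", {})  # hypothetical entry

# 4. A bare registry name.
as_registry_name: INPUT_TRANSFORM_TYPE = "my_transform"  # hypothetical entry
```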
