This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Rename PreprocessTransform to InputTransform #868

Merged
4 commits merged on Oct 14, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))
 
+
+- Changed `PreprocessTransform` to `InputTransform` ([#868](https://github.com/PyTorchLightning/lightning-flash/pull/868))
+
+
 ### Fixed
 
 - Fixed a bug where additional kwargs (e.g. sampler) passed to tabular data would be ignored ([#792](https://github.com/PyTorchLightning/lightning-flash/pull/792))
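For downstream code, the rename is a one-line import swap, as the `flash/__init__.py` diff below shows. A minimal migration sketch, assuming a subclass that only overrides a per-sample hook (the subclass name and its body are hypothetical):

```python
# Before this PR:
# from flash import PreprocessTransform
# After this PR, the same class is exported under its new name:
from flash import InputTransform


class MyInputTransform(InputTransform):
    # Hypothetical hook override; per_sample_transform is one of the
    # placements enumerated by InputTransformPlacement later in this diff.
    def per_sample_transform(self, sample):
        return sample
```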
6 changes: 3 additions & 3 deletions flash/__init__.py
@@ -22,8 +22,8 @@
 from flash.core.data.callback import FlashCallback
 from flash.core.data.data_module import DataModule  # noqa: E402
 from flash.core.data.data_source import DataSource
-from flash.core.data.datasets import FlashDataset, FlashIterableDataset  # noqa: E402
-from flash.core.data.preprocess_transform import PreprocessTransform  # noqa: E402
+from flash.core.data.datasets import FlashDataset, FlashIterableDataset
+from flash.core.data.input_transform import InputTransform
 from flash.core.data.process import Postprocess, Preprocess, Serializer
 from flash.core.model import Task  # noqa: E402
 from flash.core.trainer import Trainer  # noqa: E402
@@ -45,7 +45,7 @@
     "FlashDataset",
     "FlashIterableDataset",
     "Preprocess",
-    "PreprocessTransform",
+    "InputTransform",
     "Postprocess",
     "Serializer",
     "Task",
32 changes: 17 additions & 15 deletions flash/core/data/datasets.py
@@ -20,7 +20,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch.utils.data import Dataset, IterableDataset
 
-from flash.core.data.preprocess_transform import PREPROCESS_TRANSFORM_TYPE, PreprocessTransform
+from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
 from flash.core.data.properties import Properties
 from flash.core.registry import FlashRegistry
 
@@ -35,8 +35,8 @@ class BaseDataset(Properties):
 
     DATASET_KEY = "dataset"
 
-    transforms_registry: Optional[FlashRegistry] = FlashRegistry("transforms")
-    transform: Optional[PreprocessTransform] = None
+    input_transforms_registry: Optional[FlashRegistry] = FlashRegistry("transforms")
+    transform: Optional[InputTransform] = None
 
     @abstractmethod
     def load_data(self, data: Any) -> Union[Iterable, Mapping]:
@@ -49,12 +49,12 @@ def load_data(self, data: Any) -> Union[Iterable, Mapping]:
     def load_sample(self, data: Any) -> Any:
         """The `load_sample` hook contains the logic to load a single sample."""
 
-    def __init__(self, running_stage: RunningStage, transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None) -> None:
+    def __init__(self, running_stage: RunningStage, transform: Optional[INPUT_TRANSFORM_TYPE] = None) -> None:
         super().__init__()
         self.running_stage = running_stage
         if transform:
-            self.transform = PreprocessTransform.from_transform(
-                transform, running_stage=running_stage, transforms_registry=self.transforms_registry
+            self.transform = InputTransform.from_transform(
+                transform, running_stage=running_stage, input_transforms_registry=self.input_transforms_registry
             )
 
     def pass_args_to_load_data(
@@ -110,7 +110,7 @@ def from_data(
         cls,
         *load_data_args,
         running_stage: Optional[RunningStage] = None,
-        transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **dataset_kwargs: Any,
     ) -> "BaseDataset":
         if not running_stage:
@@ -126,7 +126,7 @@ def from_train_data(
     def from_train_data(
         cls,
         *load_data_args,
-        transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **dataset_kwargs: Any,
     ) -> "BaseDataset":
         return cls.from_data(
@@ -137,7 +137,7 @@ def from_val_data(
     def from_val_data(
         cls,
         *load_data_args,
-        transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **dataset_kwargs: Any,
     ) -> "BaseDataset":
         return cls.from_data(
@@ -148,7 +148,7 @@ def from_test_data(
     def from_test_data(
         cls,
         *load_data_args,
-        transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **dataset_kwargs: Any,
     ) -> "BaseDataset":
         return cls.from_data(*load_data_args, running_stage=RunningStage.TESTING, transform=transform, **dataset_kwargs)
@@ -157,20 +157,22 @@ def from_test_data(
     def from_predict_data(
         cls,
         *load_data_args,
-        transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **dataset_kwargs: Any,
     ) -> "BaseDataset":
         return cls.from_data(
             *load_data_args, running_stage=RunningStage.PREDICTING, transform=transform, **dataset_kwargs
         )
 
     @classmethod
-    def register_transform(cls, enum: Union[LightningEnum, str], fn: Union[Type[PreprocessTransform], partial]) -> None:
-        if cls.transforms_registry is None:
+    def register_input_transform(
+        cls, enum: Union[LightningEnum, str], fn: Union[Type[InputTransform], partial]
+    ) -> None:
+        if cls.input_transforms_registry is None:
             raise MisconfigurationException(
-                "The class attribute `transforms_registry` should be set as a class attribute. "
+                "The class attribute `input_transforms_registry` should be set as a class attribute. "
             )
-        cls.transforms_registry(fn=fn, name=enum)
+        cls.input_transforms_registry(fn=fn, name=enum)
 
     def resolve_functions(self):
         raise NotImplementedError
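The renamed pieces compose as follows; a hedged sketch in which `ListDataset` and its sample data are invented, while `register_input_transform`, the tuple-from-registry transform form, and `from_train_data` come from this diff:

```python
from flash.core.data.datasets import FlashDataset
from flash.core.data.input_transform import InputTransform


class ListDataset(FlashDataset):
    # Hypothetical dataset over an in-memory list.
    def load_data(self, data):
        return data

    def load_sample(self, sample):
        return sample


# Register an InputTransform subclass (here, the base class itself) under a
# name, then request it by (name, kwargs) tuple: __init__ resolves the tuple
# against input_transforms_registry via InputTransform.from_transform.
ListDataset.register_input_transform("identity", InputTransform)
train_dataset = ListDataset.from_train_data([1, 2, 3], transform=("identity", {}))
```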
flash/core/data/{preprocess_transform.py → input_transform.py}
@@ -27,12 +27,12 @@
 from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX
 from flash.core.registry import FlashRegistry
 
-PREPROCESS_TRANSFORM_TYPE = Optional[
-    Union["PreprocessTransform", Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str]]
+INPUT_TRANSFORM_TYPE = Optional[
+    Union["InputTransform", Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str]]
 ]
 
 
-class PreprocessTransformPlacement(LightningEnum):
+class InputTransformPlacement(LightningEnum):
 
     PER_SAMPLE_TRANSFORM = "per_sample_transform"
     PER_BATCH_TRANSFORM = "per_batch_transform"
@@ -52,8 +52,8 @@ def wrapper(self, *args, **kwargs) -> Any:
         return wrapper
 
 
-class PreprocessTransform(Properties):
-    def configure_transforms(self, *args, **kwargs) -> Dict[PreprocessTransformPlacement, Callable]:
+class InputTransform(Properties):
+    def configure_transforms(self, *args, **kwargs) -> Dict[InputTransformPlacement, Callable]:
         """The default transforms to use.
 
         Will be overridden by transforms passed to the ``__init__``.
@@ -184,21 +184,21 @@ def per_batch_transform_on_device(self, batch: Any) -> Any:
     @classmethod
     def from_transform(
         cls,
-        transform: PREPROCESS_TRANSFORM_TYPE,
+        transform: INPUT_TRANSFORM_TYPE,
         running_stage: RunningStage,
-        transforms_registry: Optional[FlashRegistry] = None,
-    ) -> Optional["PreprocessTransform"]:
+        input_transforms_registry: Optional[FlashRegistry] = None,
+    ) -> Optional["InputTransform"]:
 
-        if isinstance(transform, PreprocessTransform):
+        if isinstance(transform, InputTransform):
             transform.running_stage = running_stage
             return transform
 
         if isinstance(transform, Callable):
-            return cls(running_stage, {PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: transform})
+            return cls(running_stage, {InputTransformPlacement.PER_SAMPLE_TRANSFORM: transform})
 
         if isinstance(transform, tuple) or isinstance(transform, (LightningEnum, str)):
-            enum, transform_kwargs = cls._sanitize_registry_transform(transform, transforms_registry)
-            transform_cls = transforms_registry.get(enum)
+            enum, transform_kwargs = cls._sanitize_registry_transform(transform, input_transforms_registry)
+            transform_cls = input_transforms_registry.get(enum)
             return transform_cls(running_stage, transform=None, **transform_kwargs)
 
         if not transform:
@@ -209,60 +209,66 @@ def from_transform(
     @classmethod
     def from_train_transform(
         cls,
-        transform: PREPROCESS_TRANSFORM_TYPE,
-        transforms_registry: Optional[FlashRegistry] = None,
-    ) -> Optional["PreprocessTransform"]:
+        transform: INPUT_TRANSFORM_TYPE,
+        input_transforms_registry: Optional[FlashRegistry] = None,
+    ) -> Optional["InputTransform"]:
         return cls.from_transform(
-            transform=transform, running_stage=RunningStage.TRAINING, transforms_registry=transforms_registry
+            transform=transform,
+            running_stage=RunningStage.TRAINING,
+            input_transforms_registry=input_transforms_registry,
         )
 
     @classmethod
     def from_val_transform(
         cls,
-        transform: PREPROCESS_TRANSFORM_TYPE,
-        transforms_registry: Optional[FlashRegistry] = None,
-    ) -> Optional["PreprocessTransform"]:
+        transform: INPUT_TRANSFORM_TYPE,
+        input_transforms_registry: Optional[FlashRegistry] = None,
+    ) -> Optional["InputTransform"]:
         return cls.from_transform(
-            transform=transform, running_stage=RunningStage.VALIDATING, transforms_registry=transforms_registry
+            transform=transform,
+            running_stage=RunningStage.VALIDATING,
+            input_transforms_registry=input_transforms_registry,
         )
 
     @classmethod
     def from_test_transform(
         cls,
-        transform: PREPROCESS_TRANSFORM_TYPE,
-        transforms_registry: Optional[FlashRegistry] = None,
-    ) -> Optional["PreprocessTransform"]:
+        transform: INPUT_TRANSFORM_TYPE,
+        input_transforms_registry: Optional[FlashRegistry] = None,
+    ) -> Optional["InputTransform"]:
         return cls.from_transform(
-            transform=transform, running_stage=RunningStage.TESTING, transforms_registry=transforms_registry
+            transform=transform, running_stage=RunningStage.TESTING, input_transforms_registry=input_transforms_registry
         )
 
     @classmethod
     def from_predict_transform(
         cls,
-        transform: PREPROCESS_TRANSFORM_TYPE,
-        transforms_registry: Optional[FlashRegistry] = None,
-    ) -> Optional["PreprocessTransform"]:
+        transform: INPUT_TRANSFORM_TYPE,
+        input_transforms_registry: Optional[FlashRegistry] = None,
+    ) -> Optional["InputTransform"]:
         return cls.from_transform(
-            transform=transform, running_stage=RunningStage.PREDICTING, transforms_registry=transforms_registry
+            transform=transform,
+            running_stage=RunningStage.PREDICTING,
+            input_transforms_registry=input_transforms_registry,
         )
 
     def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
         from flash.core.data.data_pipeline import DataPipeline
 
         resolved_function = getattr(
             self,
-            DataPipeline._resolve_function_hierarchy("configure_transforms", self, running_stage, PreprocessTransform),
+            DataPipeline._resolve_function_hierarchy("configure_transforms", self, running_stage, InputTransform),
         )
         params = inspect.signature(resolved_function).parameters
         transforms_out: Optional[Dict[str, Callable]] = resolved_function(
             **{k: v for k, v in self._transform_kwargs.items() if k in params}
         )
 
         transforms_out = transforms_out or {}
-        for placement in PreprocessTransformPlacement:
+        for placement in InputTransformPlacement:
             transform_name = f"configure_{placement.value}"
             resolved_function = getattr(
-                self, DataPipeline._resolve_function_hierarchy(transform_name, self, running_stage, PreprocessTransform)
+                self, DataPipeline._resolve_function_hierarchy(transform_name, self, running_stage, InputTransform)
             )
             params = inspect.signature(resolved_function).parameters
             transforms: Optional[Dict[str, Callable]] = resolved_function(
Expand All @@ -278,7 +284,7 @@ def _check_transforms(
if transform is None:
return transform

keys_diff = set(transform.keys()).difference([v for v in PreprocessTransformPlacement])
keys_diff = set(transform.keys()).difference([v for v in InputTransformPlacement])

if len(keys_diff) > 0:
raise MisconfigurationException(
@@ -315,11 +321,11 @@ def _get_transform(self, transform: Dict[str, Callable]) -> Callable:
 
     @classmethod
     def _sanitize_registry_transform(
-        cls, transform: Tuple[Union[LightningEnum, str], Any], transforms_registry: Optional[FlashRegistry]
+        cls, transform: Tuple[Union[LightningEnum, str], Any], input_transforms_registry: Optional[FlashRegistry]
     ) -> Tuple[Union[LightningEnum, str], Dict]:
         msg = "The transform should be provided as a tuple with the following types (LightningEnum, Dict[str, Any]) "
         msg += "when requesting transform from the registry."
-        if not transforms_registry:
+        if not input_transforms_registry:
             raise MisconfigurationException("You requested a transform from the registry, but it is empty.")
         if isinstance(transform, tuple) and len(transform) > 2:
             raise MisconfigurationException(msg)
@@ -337,7 +343,7 @@ def _sanitize_registry_transform(
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(running_stage={self.running_stage}, transform={self.transform})"
 
-    def __getitem__(self, placement: PreprocessTransformPlacement) -> Callable:
+    def __getitem__(self, placement: InputTransformPlacement) -> Callable:
         return self.transform[placement]
 
     def _make_collates(self, on_device: bool, collate: Callable) -> Tuple[Callable, Callable]:
@@ -359,18 +365,18 @@ def _create_collate_preprocessors(self) -> Tuple[Any]:
         prefix: str = _STAGES_PREFIX[self.running_stage]
 
         func_names: Dict[str, str] = {
-            k: DataPipeline._resolve_function_hierarchy(k, self, self.running_stage, PreprocessTransform)
-            for k in [v.value for v in PreprocessTransformPlacement]
+            k: DataPipeline._resolve_function_hierarchy(k, self, self.running_stage, InputTransform)
+            for k in [v.value for v in InputTransformPlacement]
         }
 
         collate_fn: Callable = getattr(self, func_names["collate"])
 
         per_batch_transform_overriden: bool = DataPipeline._is_overriden_recursive(
-            "per_batch_transform", self, PreprocessTransform, prefix=prefix
+            "per_batch_transform", self, InputTransform, prefix=prefix
         )
 
         per_sample_transform_on_device_overriden: bool = DataPipeline._is_overriden_recursive(
-            "per_sample_transform_on_device", self, PreprocessTransform, prefix=prefix
+            "per_sample_transform_on_device", self, InputTransform, prefix=prefix
         )
 
         is_per_overriden = per_batch_transform_overriden and per_sample_transform_on_device_overriden
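`from_transform` dispatches on every member of the `INPUT_TRANSFORM_TYPE` union defined at the top of this file. A hedged sketch of the callable form (the `RunningStage` import path is an assumption based on the pytorch_lightning release of that era):

```python
from pytorch_lightning.trainer.states import RunningStage

from flash.core.data.input_transform import InputTransform

# A bare callable is wrapped as {InputTransformPlacement.PER_SAMPLE_TRANSFORM: fn};
# an InputTransform instance would be passed through with running_stage reassigned,
# and a (name, kwargs) tuple would be resolved against input_transforms_registry.
transform = InputTransform.from_transform(
    lambda sample: sample,
    running_stage=RunningStage.TRAINING,
)
```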
12 changes: 6 additions & 6 deletions flash/core/data/new_data_module.py
@@ -29,7 +29,7 @@
 from flash.core.data.data_module import DataModule
 from flash.core.data.data_pipeline import DefaultPreprocess, Postprocess
 from flash.core.data.datasets import BaseDataset
-from flash.core.data.preprocess_transform import PREPROCESS_TRANSFORM_TYPE, PreprocessTransform
+from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
 
@@ -267,10 +267,10 @@ def create_flash_datasets(
         val_data: Optional[Any] = None,
         test_data: Optional[Any] = None,
         predict_data: Optional[Any] = None,
-        train_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
-        val_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
-        test_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
-        predict_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        train_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
+        val_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
+        test_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
+        predict_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **flash_dataset_kwargs,
     ) -> Tuple[Optional[BaseDataset]]:
         cls._verify_flash_dataset_enum(enum)
@@ -311,7 +311,7 @@ def _create_flash_dataset(
         flash_dataset_cls,
         *load_data_args,
         running_stage: RunningStage,
-        transform: Optional[PreprocessTransform],
+        transform: Optional[InputTransform],
         **kwargs,
     ) -> Optional[BaseDataset]:
         if load_data_args[0] is not None:
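`create_flash_datasets` accepts one `INPUT_TRANSFORM_TYPE` per running stage. A call-shape sketch only: the leading positional parameters and the owning class are not visible in this hunk, so everything before the keyword arguments is an assumption:

```python
from flash.core.data.new_data_module import DataModule  # assumption: owning class

datasets = DataModule.create_flash_datasets(
    "my_dataset",                 # assumption: a registry enum/name (see _verify_flash_dataset_enum)
    [1, 2, 3],                    # assumption: train_data
    train_transform=lambda x: x,  # any INPUT_TRANSFORM_TYPE
    predict_transform=None,
)
```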