This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit 63b4d3a: Updates

ethanwharris committed Nov 9, 2021
1 parent a49ee38

Showing 10 changed files with 142 additions and 142 deletions.
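In short, the commit renames the private batch-processing pieces of the input-transform stack from `*Preprocessor` to `*Processor` (classes, accessors, and the attach/detach helpers alike). Below is a minimal sketch of a call site under the new names; it is illustrative only, with `model` assumed to be a flash `Task` and `samples` a list of raw inputs:

    from flash.core.utilities.stages import RunningStage

    # Build the pipeline and fetch the renamed worker-side processor, then run it
    # over raw samples (the same pattern Task.predict uses in flash/core/model.py below).
    data_pipeline = model.build_data_pipeline()
    collate = data_pipeline.worker_input_transform_processor(RunningStage.PREDICTING)
    batch = collate(samples)  # per_sample_transform -> collate_fn -> per_batch_transform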
docs/source/general/data.rst (4 changes: 2 additions & 2 deletions)

@@ -337,7 +337,7 @@ Flash takes care of calling the right hooks for each stage.

 Example::

-    # This will be wrapped into a :class:`~flash.core.data.io.input_transform.flash.core.data.io.input_transform._InputTransformPreprocessor`.
+    # This will be wrapped into a :class:`~flash.core.data.io.input_transform.flash.core.data.io.input_transform._InputTransformProcessor`.
     def collate_fn(samples: Sequence[Any]) -> Any:

     # This will be wrapped into a :class:`~flash.core.data.io.input_transform._InputTransformSequential`

@@ -367,7 +367,7 @@ Flash takes care of calling the right hooks for each stage.

 Example::

-    # This will be wrapped into a :class:`~flash.core.data.io.input_transform._InputTransformPreprocessor`
+    # This will be wrapped into a :class:`~flash.core.data.io.input_transform._InputTransformProcessor`
     def collate_fn(samples: Sequence[Any]) -> Any:

     # if ``per_batch_transform`` hook is overridden, those functions below will be no-ops
flash/core/data/data_module.py (2 changes: 1 addition & 1 deletion)

@@ -281,7 +281,7 @@ def set_running_stages(self):

 def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) -> Optional[Callable]:
     if isinstance(dataset, (BaseAutoDataset, SplitDataset)):
-        return self.data_pipeline.worker_input_transform_preprocessor(running_stage)
+        return self.data_pipeline.worker_input_transform_processor(running_stage)

 def _train_dataloader(self) -> DataLoader:
     """Configure the train dataloader of the datamodule."""
flash/core/data/data_pipeline.py (42 changes: 21 additions & 21 deletions)

@@ -27,7 +27,7 @@
 from flash.core.data.batch import _DeserializeProcessor
 from flash.core.data.data_source import DataSource
 from flash.core.data.io.input_transform import (
-    _InputTransformPreprocessor,
+    _InputTransformProcessor,
     _InputTransformSequential,
     DefaultInputTransform,
     InputTransform,

@@ -162,17 +162,17 @@ def _identity(samples: Sequence[Any]) -> Sequence[Any]:
     return samples

 def deserialize_processor(self) -> _DeserializeProcessor:
-    return self._create_collate_input_transform_preprocessors(RunningStage.PREDICTING)[0]
+    return self._create_collate_input_transform_processors(RunningStage.PREDICTING)[0]

-def worker_input_transform_preprocessor(
+def worker_input_transform_processor(
     self, running_stage: RunningStage, collate_fn: Optional[Callable] = None, is_serving: bool = False
-) -> _InputTransformPreprocessor:
-    return self._create_collate_input_transform_preprocessors(
+) -> _InputTransformProcessor:
+    return self._create_collate_input_transform_processors(
         running_stage, collate_fn=collate_fn, is_serving=is_serving
     )[1]

-def device_input_transform_preprocessor(self, running_stage: RunningStage) -> _InputTransformPreprocessor:
-    return self._create_collate_input_transform_preprocessors(running_stage)[2]
+def device_input_transform_processor(self, running_stage: RunningStage) -> _InputTransformProcessor:
+    return self._create_collate_input_transform_processors(running_stage)[2]

 def output_transform_processor(self, running_stage: RunningStage, is_serving=False) -> _OutputTransformProcessor:
     return self._create_output_transform_processor(running_stage, is_serving=is_serving)

@@ -211,12 +211,12 @@ def _make_collates(self, on_device: bool, collate: Callable) -> Tuple[Callable,
         return self._identity, collate
     return collate, self._identity

-def _create_collate_input_transform_preprocessors(
+def _create_collate_input_transform_processors(
     self,
     stage: RunningStage,
     collate_fn: Optional[Callable] = None,
     is_serving: bool = False,
-) -> Tuple[_DeserializeProcessor, _InputTransformPreprocessor, _InputTransformPreprocessor]:
+) -> Tuple[_DeserializeProcessor, _InputTransformProcessor, _InputTransformProcessor]:

     original_collate_fn = collate_fn

@@ -261,7 +261,7 @@ def _create_collate_input_transform_preprocessors(

     worker_collate_fn = (
         worker_collate_fn.collate_fn
-        if isinstance(worker_collate_fn, _InputTransformPreprocessor)
+        if isinstance(worker_collate_fn, _InputTransformProcessor)
         else worker_collate_fn
     )

@@ -275,7 +275,7 @@ def _create_collate_input_transform_preprocessors(
         getattr(input_transform, func_names["pre_tensor_transform"]),
         getattr(input_transform, func_names["to_tensor_transform"]),
     )
-    worker_input_transform_preprocessor = _InputTransformPreprocessor(
+    worker_input_transform_processor = _InputTransformProcessor(
         input_transform,
         worker_collate_fn,
         _InputTransformSequential(

@@ -289,8 +289,8 @@ def _create_collate_input_transform_preprocessors(
         getattr(input_transform, func_names["per_batch_transform"]),
         stage,
     )
-    worker_input_transform_preprocessor._original_collate_fn = original_collate_fn
-    device_input_transform_preprocessor = _InputTransformPreprocessor(
+    worker_input_transform_processor._original_collate_fn = original_collate_fn
+    device_input_transform_processor = _InputTransformProcessor(
         input_transform,
         device_collate_fn,
         getattr(input_transform, func_names["per_sample_transform_on_device"]),

@@ -299,11 +299,11 @@ def _create_collate_input_transform_preprocessors(
         apply_per_sample_transform=device_collate_fn != self._identity,
         on_device=True,
     )
-    return deserialize_processor, worker_input_transform_preprocessor, device_input_transform_preprocessor
+    return deserialize_processor, worker_input_transform_processor, device_input_transform_processor

 @staticmethod
 def _model_transfer_to_device_wrapper(
-    func: Callable, input_transform: _InputTransformPreprocessor, model: "Task", stage: RunningStage
+    func: Callable, input_transform: _InputTransformProcessor, model: "Task", stage: RunningStage
 ) -> Callable:

     if not isinstance(func, _StageOrchestrator):

@@ -380,7 +380,7 @@ def _set_loader(model: "Task", loader_name: str, new_loader: DataLoader) -> None
         setattr(curr_attr, final_name, new_loader)
     setattr(model, final_name, new_loader)

-def _attach_preprocess_to_model(
+def _attach_input_transform_to_model(
     self,
     model: "Task",
     stage: Optional[RunningStage] = None,

@@ -421,7 +421,7 @@ def _attach_preprocess_to_model(
     if isinstance(loader, DataLoader):
         dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")}

-        _, dl_args["collate_fn"], device_collate_fn = self._create_collate_input_transform_preprocessors(
+        _, dl_args["collate_fn"], device_collate_fn = self._create_collate_input_transform_processors(
             stage=stage, collate_fn=dl_args["collate_fn"], is_serving=is_serving
         )

@@ -486,18 +486,18 @@ def _attach_to_model(
     is_serving: bool = False,
 ):
     # not necessary to detach. preprocessing and postprocessing for stage will be overwritten.
-    self._attach_preprocess_to_model(model, stage)
+    self._attach_input_transform_to_model(model, stage)

     if not stage or stage == RunningStage.PREDICTING:
         self._attach_output_transform_to_model(model, RunningStage.PREDICTING, is_serving=is_serving)

 def _detach_from_model(self, model: "Task", stage: Optional[RunningStage] = None):
-    self._detach_preprocessing_from_model(model, stage)
+    self._detach_input_transform_from_model(model, stage)

     if not stage or stage == RunningStage.PREDICTING:
         self._detach_output_transform_from_model(model)

-def _detach_preprocessing_from_model(self, model: "Task", stage: Optional[RunningStage] = None):
+def _detach_input_transform_from_model(self, model: "Task", stage: Optional[RunningStage] = None):
     if not stage:
         stages = [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
     elif isinstance(stage, RunningStage):

@@ -542,7 +542,7 @@ def _detach_preprocessing_from_model(model: "Task", stage: Optional[Runnin
     if default_collate:
         dl_args["collate_fn"] = default_collate

-    if isinstance(dl_args["collate_fn"], _InputTransformPreprocessor):
+    if isinstance(dl_args["collate_fn"], _InputTransformProcessor):
         dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn

     if isinstance(dl_args["dataset"], (IterableAutoDataset, IterableDataset)):
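For orientation (an editorial sketch, not part of the diff): the renamed accessors above all index into the 3-tuple returned by `_create_collate_input_transform_processors`, which is why these renames have to move together. Assuming `pipeline` is a `DataPipeline` and `stage` a `RunningStage`:

    # Private API, simplified: the three processors come from one builder.
    deserialize, worker, device = pipeline._create_collate_input_transform_processors(stage)
    # pipeline.deserialize_processor()                -> element [0]
    # pipeline.worker_input_transform_processor(...)  -> element [1], injected as the DataLoader collate_fn
    # pipeline.device_input_transform_processor(...)  -> element [2], applied after batch transfer to device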
flash/core/data/input_transform.py (16 changes: 8 additions & 8 deletions)

@@ -21,7 +21,7 @@

 from flash.core.data.data_pipeline import DataPipeline
 from flash.core.data.data_source import DefaultDataKeys
-from flash.core.data.io.input_transform import _InputTransformPreprocessor
+from flash.core.data.io.input_transform import _InputTransformProcessor
 from flash.core.data.properties import Properties
 from flash.core.data.states import CollateFn
 from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _STAGES_PREFIX

@@ -356,14 +356,14 @@ def _make_collates(self, on_device: bool, collate: Callable) -> Tuple[Callable,
 @property
 def dataloader_collate_fn(self):
     """Generate the function to be injected within the DataLoader as the collate_fn."""
-    return self._create_collate_input_transform_preprocessors()[0]
+    return self._create_collate_input_transform_processors()[0]

 @property
 def on_after_batch_transfer_fn(self):
     """Generate the function to be injected after the on_after_batch_transfer from the LightningModule."""
-    return self._create_collate_input_transform_preprocessors()[1]
+    return self._create_collate_input_transform_processors()[1]

-def _create_collate_input_transform_preprocessors(self) -> Tuple[Any]:
+def _create_collate_input_transform_processors(self) -> Tuple[Any]:
     prefix: str = _STAGES_PREFIX[self.running_stage]

     func_names: Dict[str, str] = {

@@ -399,18 +399,18 @@ def _create_collate_input_transform_preprocessors(self) -> Tuple[Any]:

     worker_collate_fn = (
         worker_collate_fn.collate_fn
-        if isinstance(worker_collate_fn, _InputTransformPreprocessor)
+        if isinstance(worker_collate_fn, _InputTransformProcessor)
         else worker_collate_fn
     )

-    worker_input_transform_preprocessor = _InputTransformPreprocessor(
+    worker_input_transform_processor = _InputTransformProcessor(
         self,
         worker_collate_fn,
         getattr(self, func_names["per_sample_transform"]),
         getattr(self, func_names["per_batch_transform"]),
         self.running_stage,
     )
-    device_input_transform_preprocessor = _InputTransformPreprocessor(
+    device_input_transform_processor = _InputTransformProcessor(
         self,
         device_collate_fn,
         getattr(self, func_names["per_sample_transform_on_device"]),

@@ -419,4 +419,4 @@ def _create_collate_input_transform_preprocessors(self) -> Tuple[Any]:
         apply_per_sample_transform=device_collate_fn != self._identity,
         on_device=True,
     )
-    return worker_input_transform_preprocessor, device_input_transform_preprocessor
+    return worker_input_transform_processor, device_input_transform_processor
flash/core/data/io/input_transform.py (8 changes: 4 additions & 4 deletions)

@@ -535,8 +535,8 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool):


 class _InputTransformSequential(torch.nn.Module):
-    """This class is used to chain 3 functions together for the _InputTransformPreprocessor
-    ``per_sample_transform`` function.
+    """This class is used to chain 3 functions together for the _InputTransformProcessor ``per_sample_transform``
+    function.

     1. ``pre_tensor_transform``
     2. ``to_tensor_transform``

@@ -604,7 +604,7 @@ def __str__(self) -> str:
     )


-class _InputTransformPreprocessor(torch.nn.Module):
+class _InputTransformProcessor(torch.nn.Module):
     """
     This class is used to encapsultate the following functions of a InputTransformInputTransform Object:
     Inside a worker:

@@ -703,7 +703,7 @@ def forward(self, samples: Sequence[Any]) -> Any:
 def __str__(self) -> str:
     # todo: define repr function which would take object and string attributes to be shown
     return (
-        "_InputTransformPreprocessor:\n"
+        "_InputTransformProcessor:\n"
         f"\t(per_sample_transform): {str(self.per_sample_transform)}\n"
         f"\t(collate_fn): {str(self.collate_fn)}\n"
         f"\t(per_batch_transform): {str(self.per_batch_transform)}\n"
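Editorial sketch (not in the diff): the `__str__` fields above spell out the renamed class's data flow. Ignoring the `torch.nn.Module` plumbing, callbacks, and serving flags, `_InputTransformProcessor.forward` behaves roughly like:

    from typing import Any, Callable, Sequence

    def processor_forward(
        samples: Sequence[Any],
        per_sample_transform: Callable,  # on workers, an _InputTransformSequential:
        collate_fn: Callable,            #   pre_tensor -> to_tensor -> post_tensor
        per_batch_transform: Callable,
    ) -> Any:
        # Approximation of _InputTransformProcessor.forward with
        # apply_per_sample_transform=True (the usual worker path).
        samples = [per_sample_transform(sample) for sample in samples]
        batch = collate_fn(samples)
        return per_batch_transform(batch)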
flash/core/model.py (4 changes: 2 additions & 2 deletions)

@@ -492,13 +492,13 @@ def predict(
     dataset = data_pipeline.data_source.generate_dataset(x, running_stage)
     dataloader = self.process_predict_dataset(dataset)
     x = list(dataloader.dataset)
-    x = data_pipeline.worker_input_transform_preprocessor(running_stage, collate_fn=dataloader.collate_fn)(x)
+    x = data_pipeline.worker_input_transform_processor(running_stage, collate_fn=dataloader.collate_fn)(x)
     # todo (tchaton): Remove this when sync with Lightning master.
     if len(inspect.signature(self.transfer_batch_to_device).parameters) == 3:
         x = self.transfer_batch_to_device(x, self.device, 0)
     else:
         x = self.transfer_batch_to_device(x, self.device)
-    x = data_pipeline.device_input_transform_preprocessor(running_stage)(x)
+    x = data_pipeline.device_input_transform_processor(running_stage)(x)
     x = x[0] if isinstance(x, list) else x
     predictions = self.predict_step(x, 0)  # batch_idx is always 0 when running with `model.predict`
     predictions = data_pipeline.output_transform_processor(running_stage)(predictions)
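(Editorial note: the call order in `predict` matches the worker/device split the renamed accessors expose: the worker-side processor runs before `transfer_batch_to_device`, and the device-side processor runs after it.)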
flash/core/serve/flash_components.py (8 changes: 4 additions & 4 deletions)

@@ -55,10 +55,10 @@ def __init__(self, model):
     self.model = model
     self.model.eval()
     self.data_pipeline = model.build_data_pipeline()
-    self.worker_input_transform_preprocessor = self.data_pipeline.worker_input_transform_preprocessor(
+    self.worker_input_transform_processor = self.data_pipeline.worker_input_transform_processor(
         RunningStage.PREDICTING, is_serving=True
     )
-    self.device_input_transform_preprocessor = self.data_pipeline.device_input_transform_preprocessor(
+    self.device_input_transform_processor = self.data_pipeline.device_input_transform_processor(
         RunningStage.PREDICTING
     )
     self.output_transform_processor = self.data_pipeline.output_transform_processor(

@@ -74,12 +74,12 @@ def __init__(self, model):
     )
 def predict(self, inputs):
     with torch.no_grad():
-        inputs = self.worker_input_transform_preprocessor(inputs)
+        inputs = self.worker_input_transform_processor(inputs)
         if self.extra_arguments:
             inputs = self.model.transfer_batch_to_device(inputs, self.device, 0)
         else:
             inputs = self.model.transfer_batch_to_device(inputs, self.device)
-        inputs = self.device_input_transform_preprocessor(inputs)
+        inputs = self.device_input_transform_processor(inputs)
         preds = self.model.predict_step(inputs, 0)
         preds = self.output_transform_processor(preds)
         return preds
[test file; the header with its file name and change counts was not captured in this view]

@@ -16,10 +16,16 @@
 import pytest
 import torch
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from torch.utils.data._utils.collate import default_collate

-from flash.core.data.data_module import DataModule
+from flash import DataModule
 from flash.core.data.data_source import DefaultDataSources
-from flash.core.data.io.input_transform import DefaultInputTransform
+from flash.core.data.io.input_transform import (
+    _InputTransformProcessor,
+    _InputTransformSequential,
+    DefaultInputTransform,
+)
+from flash.core.utilities.stages import RunningStage


 class CustomInputTransform(DefaultInputTransform):

@@ -33,6 +39,46 @@ def __init__(self):
     )


+def test_input_transform_processor_str():
+    input_transform_processor = _InputTransformProcessor(
+        Mock(name="input_transform"),
+        default_collate,
+        torch.relu,
+        torch.softmax,
+        RunningStage.TRAINING,
+        False,
+        True,
+    )
+    assert str(input_transform_processor) == (
+        "_InputTransformProcessor:\n"
+        "\t(per_sample_transform): FuncModule(relu)\n"
+        "\t(collate_fn): FuncModule(default_collate)\n"
+        "\t(per_batch_transform): FuncModule(softmax)\n"
+        "\t(apply_per_sample_transform): False\n"
+        "\t(on_device): True\n"
+        "\t(stage): RunningStage.TRAINING"
+    )
+
+
+def test_sequential_str():
+    sequential = _InputTransformSequential(
+        Mock(name="input_transform"),
+        torch.softmax,
+        torch.as_tensor,
+        torch.relu,
+        RunningStage.TRAINING,
+        True,
+    )
+    assert str(sequential) == (
+        "_InputTransformSequential:\n"
+        "\t(pre_tensor_transform): FuncModule(softmax)\n"
+        "\t(to_tensor_transform): FuncModule(as_tensor)\n"
+        "\t(post_tensor_transform): FuncModule(relu)\n"
+        "\t(assert_contains_tensor): True\n"
+        "\t(stage): RunningStage.TRAINING"
+    )
+
+
 def test_data_source_of_name():
     input_transform = CustomInputTransform()
[Diffs for the remaining changed files were not loaded in the captured page.]