This repository was archived by the owner on Oct 9, 2023. It is now read-only.

[RFC] Fix capitalisation of Pre/PostProcessor (#366)
* Replace Pre/PostProcess with Pre/Postprocess

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
akihironitta and mergify[bot] authored Jun 8, 2021
1 parent 849dd81 commit ec08f63
Showing 5 changed files with 30 additions and 30 deletions.
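Since only the capitalisation of these private wrapper classes changes, downstream code that imports them needs the same one-line update. A minimal sketch of the rename (the old names no longer exist after this commit):

# Before this commit:
# from flash.core.data.batch import _PostProcessor, _PreProcessor

# After this commit:
from flash.core.data.batch import _Postprocessor, _Preprocessor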
6 changes: 3 additions & 3 deletions docs/source/general/data.rst
@@ -393,7 +393,7 @@ Flash takes care of calling the right hooks for each stage.

Example::

-# This will be wrapped into a :class:`~flash.core.data.batch._PreProcessor`.
+# This will be wrapped into a :class:`~flash.core.data.batch._Preprocessor`.
def collate_fn(samples: Sequence[Any]) -> Any:

# This will be wrapped into a :class:`~flash.core.data.batch._Sequential`
@@ -423,7 +423,7 @@ Flash takes care of calling the right hooks for each stage.

Example::

-# This will be wrapped into a :class:`~flash.core.data.batch._PreProcessor`
+# This will be wrapped into a :class:`~flash.core.data.batch._Preprocessor`
def collate_fn(samples: Sequence[Any]) -> Any:

# if ``per_batch_transform`` hook is overridden, those functions below will be no-ops
@@ -459,7 +459,7 @@ Here is the pseudo-code:

Example::

-# This will be wrapped into a :class:`~flash.core.data.batch._PreProcessor`
+# This will be wrapped into a :class:`~flash.core.data.batch._Preprocessor`
def uncollate_fn(batch: Any) -> Any:

batch = per_batch_transform(batch)
8 changes: 4 additions & 4 deletions flash/core/data/batch.py
@@ -98,7 +98,7 @@ def __str__(self) -> str:
)


-class _PreProcessor(torch.nn.Module):
+class _Preprocessor(torch.nn.Module):
"""
This class is used to encapsulate the following functions of a Preprocess Object:
Inside a worker:
@@ -190,7 +190,7 @@ def forward(self, samples: Sequence[Any]) -> Any:
def __str__(self) -> str:
# todo: define repr function which would take object and string attributes to be shown
return (
"_PreProcessor:\n"
"_Preprocessor:\n"
f"\t(per_sample_transform): {str(self.per_sample_transform)}\n"
f"\t(collate_fn): {str(self.collate_fn)}\n"
f"\t(per_batch_transform): {str(self.per_batch_transform)}\n"
@@ -200,7 +200,7 @@ def __str__(self) -> str:
)


-class _PostProcessor(torch.nn.Module):
+class _Postprocessor(torch.nn.Module):
"""
This class is used to encapsulate the following functions of a Postprocess Object:
Inside main process:
@@ -245,7 +245,7 @@ def forward(self, batch: Sequence[Any]):

def __str__(self) -> str:
return (
"_PostProcessor:\n"
"_Postprocessor:\n"
f"\t(per_batch_transform): {str(self.per_batch_transform)}\n"
f"\t(uncollate_fn): {str(self.uncollate_fn)}\n"
f"\t(per_sample_transform): {str(self.per_sample_transform)}\n"
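For orientation, here is a minimal, hypothetical sketch of what a wrapper in the spirit of the renamed _Preprocessor does: apply per_sample_transform to each sample, collate, then apply per_batch_transform, matching the attribute order listed in its __str__ above. The real class takes more arguments (the Preprocess object, the running stage, worker/device splits) and is not reproduced here.

from typing import Any, Callable, Sequence

import torch


class PreprocessorSketch(torch.nn.Module):
    # Illustrative stand-in only, not flash.core.data.batch._Preprocessor.
    def __init__(self, per_sample_transform: Callable, collate_fn: Callable, per_batch_transform: Callable):
        super().__init__()
        self.per_sample_transform = per_sample_transform
        self.collate_fn = collate_fn
        self.per_batch_transform = per_batch_transform

    def forward(self, samples: Sequence[Any]) -> Any:
        # Per-sample work first, then build the batch, then batch-level work.
        samples = [self.per_sample_transform(sample) for sample in samples]
        batch = self.collate_fn(samples)
        return self.per_batch_transform(batch)

Packaging the chain as a single nn.Module is what lets the pipeline install it as a DataLoader collate_fn, which is why _detach_preprocessing_from_model below checks dl_args['collate_fn'] for this type.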
26 changes: 13 additions & 13 deletions flash/core/data/data_pipeline.py
@@ -25,7 +25,7 @@
from torch.utils.data._utils.collate import default_collate

from flash.core.data.auto_dataset import IterableAutoDataset
-from flash.core.data.batch import _PostProcessor, _PreProcessor, _Sequential
+from flash.core.data.batch import _Postprocessor, _Preprocessor, _Sequential
from flash.core.data.data_source import DataSource
from flash.core.data.process import DefaultPreprocess, Postprocess, Preprocess, Serializer
from flash.core.data.properties import ProcessState
@@ -163,13 +163,13 @@ def _is_overriden_recursive(
def _identity(samples: Sequence[Any]) -> Sequence[Any]:
return samples

-def worker_preprocessor(self, running_stage: RunningStage) -> _PreProcessor:
+def worker_preprocessor(self, running_stage: RunningStage) -> _Preprocessor:
return self._create_collate_preprocessors(running_stage)[0]

-def device_preprocessor(self, running_stage: RunningStage) -> _PreProcessor:
+def device_preprocessor(self, running_stage: RunningStage) -> _Preprocessor:
return self._create_collate_preprocessors(running_stage)[1]

-def postprocessor(self, running_stage: RunningStage) -> _PostProcessor:
+def postprocessor(self, running_stage: RunningStage) -> _Postprocessor:
return self._create_uncollate_postprocessors(running_stage)

@classmethod
@@ -208,7 +208,7 @@ def _create_collate_preprocessors(
self,
stage: RunningStage,
collate_fn: Optional[Callable] = None,
-) -> Tuple[_PreProcessor, _PreProcessor]:
+) -> Tuple[_Preprocessor, _Preprocessor]:

original_collate_fn = collate_fn

@@ -254,14 +254,14 @@ def _create_collate_preprocessors(
)

worker_collate_fn = worker_collate_fn.collate_fn if isinstance(
-worker_collate_fn, _PreProcessor
+worker_collate_fn, _Preprocessor
) else worker_collate_fn

assert_contains_tensor = self._is_overriden_recursive(
"to_tensor_transform", preprocess, Preprocess, prefix=_STAGES_PREFIX[stage]
)

-worker_preprocessor = _PreProcessor(
+worker_preprocessor = _Preprocessor(
preprocess, worker_collate_fn,
_Sequential(
preprocess,
@@ -273,7 +273,7 @@ def _create_collate_preprocessors(
), getattr(preprocess, func_names['per_batch_transform']), stage
)
worker_preprocessor._original_collate_fn = original_collate_fn
-device_preprocessor = _PreProcessor(
+device_preprocessor = _Preprocessor(
preprocess,
device_collate_fn,
getattr(preprocess, func_names['per_sample_transform_on_device']),
@@ -286,7 +286,7 @@

@staticmethod
def _model_transfer_to_device_wrapper(
-func: Callable, preprocessor: _PreProcessor, model: 'Task', stage: RunningStage
+func: Callable, preprocessor: _Preprocessor, model: 'Task', stage: RunningStage
) -> Callable:

if not isinstance(func, _StageOrchestrator):
@@ -296,7 +296,7 @@ def _model_transfer_to_device_wrapper(
return func

@staticmethod
-def _model_predict_step_wrapper(func: Callable, postprocessor: _PostProcessor, model: 'Task') -> Callable:
+def _model_predict_step_wrapper(func: Callable, postprocessor: _Postprocessor, model: 'Task') -> Callable:

if not isinstance(func, _StageOrchestrator):
_original = func
@@ -400,7 +400,7 @@ def _attach_preprocess_to_model(
self._model_transfer_to_device_wrapper(model.transfer_batch_to_device, device_collate_fn, model, stage)
)

-def _create_uncollate_postprocessors(self, stage: RunningStage) -> _PostProcessor:
+def _create_uncollate_postprocessors(self, stage: RunningStage) -> _Postprocessor:
save_per_sample = None
save_fn = None

@@ -422,7 +422,7 @@ def _create_uncollate_postprocessors(self, stage: RunningStage) -> _PostProcesso
else:
save_fn: Callable = getattr(postprocess, func_names["save_data"])

-return _PostProcessor(
+return _Postprocessor(
getattr(postprocess, func_names["uncollate"]),
getattr(postprocess, func_names["per_batch_transform"]),
getattr(postprocess, func_names["per_sample_transform"]),
@@ -491,7 +491,7 @@ def _detach_preprocessing_from_model(self, model: 'Task', stage: Optional[Runnin
if isinstance(loader, DataLoader):
dl_args = {k: v for k, v in vars(loader).items() if not k.startswith("_")}

-if isinstance(dl_args['collate_fn'], _PreProcessor):
+if isinstance(dl_args['collate_fn'], _Preprocessor):
dl_args["collate_fn"] = dl_args["collate_fn"]._original_collate_fn

if isinstance(dl_args["dataset"], IterableAutoDataset):
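The renamed classes are exactly what the DataPipeline accessors above hand out. A hypothetical usage sketch, assuming a pipeline can be built directly from a DefaultPreprocess/Postprocess pair and that RunningStage comes from pytorch_lightning.trainer.states (constructor details and import paths may differ by version):

from pytorch_lightning.trainer.states import RunningStage

from flash.core.data.batch import _Postprocessor, _Preprocessor
from flash.core.data.data_pipeline import DataPipeline
from flash.core.data.process import DefaultPreprocess, Postprocess

data_pipeline = DataPipeline(preprocess=DefaultPreprocess(), postprocess=Postprocess())
stage = RunningStage.TRAINING

worker_pre = data_pipeline.worker_preprocessor(stage)   # collate-time work done in DataLoader workers
device_pre = data_pipeline.device_preprocessor(stage)   # transforms applied after transfer to device
post = data_pipeline.postprocessor(stage)                # uncollate/per-sample work on model outputs

assert isinstance(worker_pre, _Preprocessor)
assert isinstance(device_pre, _Preprocessor)
assert isinstance(post, _Postprocessor)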
10 changes: 5 additions & 5 deletions tests/core/data/test_batch.py
@@ -19,7 +19,7 @@
from torch.testing import assert_allclose
from torch.utils.data._utils.collate import default_collate

-from flash.core.data.batch import _PostProcessor, _PreProcessor, _Sequential, default_uncollate
+from flash.core.data.batch import _Postprocessor, _Preprocessor, _Sequential, default_uncollate


def test_sequential_str():
@@ -42,7 +42,7 @@ def test_sequential_str():


def test_preprocessor_str():
-preprocessor = _PreProcessor(
+preprocessor = _Preprocessor(
Mock(name="preprocess"),
default_collate,
torch.relu,
@@ -52,7 +52,7 @@ def test_preprocessor_str():
True,
)
assert str(preprocessor) == (
"_PreProcessor:\n"
"_Preprocessor:\n"
"\t(per_sample_transform): FuncModule(relu)\n"
"\t(collate_fn): FuncModule(default_collate)\n"
"\t(per_batch_transform): FuncModule(softmax)\n"
@@ -63,14 +63,14 @@ def test_preprocessor_str():


def test_postprocessor_str():
-postprocessor = _PostProcessor(
+postprocessor = _Postprocessor(
default_uncollate,
torch.relu,
torch.softmax,
None,
)
assert str(postprocessor) == (
"_PostProcessor:\n"
"_Postprocessor:\n"
"\t(per_batch_transform): FuncModule(relu)\n"
"\t(uncollate_fn): FuncModule(default_uncollate)\n"
"\t(per_sample_transform): FuncModule(softmax)\n"
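The default_uncollate imported above is the batch-to-samples counterpart of default_collate: a collated batch is split back into per-sample entries before any per-sample transform runs. A small hypothetical round trip to illustrate:

import torch
from torch.utils.data._utils.collate import default_collate

from flash.core.data.batch import default_uncollate

samples = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
batch = default_collate(samples)   # stacks the samples into a single (2, 2) tensor
back = default_uncollate(batch)    # splits the batch back into per-sample tensors
assert len(back) == len(samples)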
10 changes: 5 additions & 5 deletions tests/core/data/test_data_pipeline.py
@@ -26,7 +26,7 @@
from torch.utils.data._utils.collate import default_collate

from flash.core.data.auto_dataset import IterableAutoDataset
-from flash.core.data.batch import _PostProcessor, _PreProcessor
+from flash.core.data.batch import _Postprocessor, _Preprocessor
from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import _StageOrchestrator, DataPipeline, DataPipelineState
from flash.core.data.data_source import DataSource
@@ -318,7 +318,7 @@ def train_dataloader(self) -> Any:
assert model.train_dataloader().collate_fn == default_collate
assert model.transfer_batch_to_device.__self__ == model
model.on_train_dataloader()
-assert isinstance(model.train_dataloader().collate_fn, _PreProcessor)
+assert isinstance(model.train_dataloader().collate_fn, _Preprocessor)
assert isinstance(model.transfer_batch_to_device, _StageOrchestrator)
model.on_fit_end()
assert model.transfer_batch_to_device.__self__ == model
@@ -412,7 +412,7 @@ def _compare_pre_processor(self, p1, p2):
assert p1.per_batch_transform.func == p2.per_batch_transform.func

def _assert_stage_orchestrator_state(
-self, stage_mapping: Dict, current_running_stage: RunningStage, cls=_PreProcessor
+self, stage_mapping: Dict, current_running_stage: RunningStage, cls=_Preprocessor
):
assert isinstance(stage_mapping[current_running_stage], cls)
assert stage_mapping[current_running_stage]
@@ -471,7 +471,7 @@ def on_predict_dataloader(self) -> None:
assert isinstance(self.predict_step, _StageOrchestrator)
self._assert_stage_orchestrator_state(self.transfer_batch_to_device._stage_mapping, current_running_stage)
self._assert_stage_orchestrator_state(
-self.predict_step._stage_mapping, current_running_stage, cls=_PostProcessor
+self.predict_step._stage_mapping, current_running_stage, cls=_Postprocessor
)

def on_fit_end(self) -> None:
@@ -505,7 +505,7 @@ def test_stage_orchestrator_state_attach_detach(tmpdir):

class CustomDataPipeline(DataPipeline):

-def _attach_postprocess_to_model(self, model: 'Task', _postprocesssor: _PostProcessor) -> 'Task':
+def _attach_postprocess_to_model(self, model: 'Task', _postprocesssor: _Postprocessor) -> 'Task':
model.predict_step = self._model_predict_step_wrapper(model.predict_step, _postprocesssor, model)
return model

