Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Remove old input transform and data module #1058

Merged
merged 15 commits into from
Dec 13, 2021
7 changes: 1 addition & 6 deletions docs/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,10 @@ _______________________
:nosignatures:
:template: classtemplate.rst

~flash.core.data.io.input_transform.BaseInputTransform
~flash.core.data.io.input_transform.DefaultInputTransform
~flash.core.data.process.DeserializerMapping
~flash.core.data.process.Deserializer
~flash.core.data.io.output_transform.OutputTransform
~flash.core.data.io.input_transform.InputTransform
~flash.core.data.input_transform.InputTransform

flash.core.data.properties
__________________________
Expand Down Expand Up @@ -144,9 +142,6 @@ _____________________
:nosignatures:
:template: classtemplate.rst

~flash.core.data.utils.CurrentFuncContext
~flash.core.data.utils.CurrentRunningStageContext
~flash.core.data.utils.CurrentRunningStageFuncContext
~flash.core.data.utils.FuncModule

.. autosummary::
Expand Down
2 changes: 1 addition & 1 deletion docs/source/api/flash.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ flash
~flash.core.data.callback.FlashCallback
~flash.core.data.io.output_transform.OutputTransform
~flash.core.data.io.output.Output
~flash.core.data.io.input_transform.InputTransform
~flash.core.data.input_transform.InputTransform
~flash.core.model.Task
~flash.core.trainer.Trainer
2 changes: 1 addition & 1 deletion flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from flash.core.data.callback import FlashCallback
from flash.core.data.data_module import DataModule
from flash.core.data.input_transform import InputTransform
from flash.core.data.io.input import DataKeys, Input
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.io.output import Output
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.process import Serializer
Expand Down
2 changes: 1 addition & 1 deletion flash/audio/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
)
from flash.audio.classification.input_transform import AudioClassificationInputTransform
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE
from flash.core.data.io.input import Input
from flash.core.data.new_data_module import DataModule
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.registry import FlashRegistry
from flash.core.utilities.stages import RunningStage
Expand Down
2 changes: 1 addition & 1 deletion flash/audio/speech_recognition/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
SpeechRecognitionPathsInput,
)
from flash.audio.speech_recognition.output_transform import SpeechRecognitionOutputTransform
from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.data.io.input import Input
from flash.core.data.new_data_module import DataModule
from flash.core.registry import FlashRegistry
from flash.core.utilities.stages import RunningStage

Expand Down
33 changes: 2 additions & 31 deletions flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,14 @@
from torch import Tensor

from flash.core.data.callback import ControlFlow, FlashCallback
from flash.core.data.utils import convert_to_modules, CurrentFuncContext, CurrentRunningStageContext
from flash.core.data.utils import convert_to_modules
from flash.core.utilities.stages import RunningStage

if TYPE_CHECKING:
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.input_transform import InputTransform
from flash.core.data.process import Deserializer


class _DeserializeProcessor(torch.nn.Module):
def __init__(
self,
deserializer: "Deserializer",
input_transform: "InputTransform",
per_sample_transform: Callable,
callbacks: Optional[List[FlashCallback]] = None,
):
super().__init__()
self.input_transform = input_transform
self.callback = ControlFlow(callbacks or [])
self.deserializer = convert_to_modules(deserializer)
self.per_sample_transform = convert_to_modules(per_sample_transform)

self._current_stage_context = CurrentRunningStageContext(RunningStage.PREDICTING, input_transform, reset=False)
self._per_sample_transform_context = CurrentFuncContext("per_sample_transform", input_transform)

def forward(self, sample: str):

sample = self.deserializer(sample)

with self._current_stage_context:
with self._per_sample_transform_context:
sample = self.per_sample_transform(sample)
self.callback.on_per_sample_transform(sample, RunningStage.PREDICTING)

return sample


class _DeserializeProcessorV2(torch.nn.Module):
def __init__(
self,
Expand Down
Loading