Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
support list & callable transforms (#693)
Browse files (browse the repository at this point in the history)
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
  • Loading branch information
3 people authored Sep 2, 2021
1 parent bc82b6b commit 4ebc45d
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 21 deletions.
7 changes: 4 additions & 3 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,9 @@ def from_data_source(
val_data: Any = None,
test_data: Any = None,
predict_data: Any = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
Expand Down Expand Up @@ -522,6 +522,7 @@ def from_data_source(
},
)
"""

preprocess = preprocess or cls.preprocess_cls(
train_transform,
val_transform,
Expand Down
24 changes: 15 additions & 9 deletions flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import inspect
import os
from abc import ABC, abstractclassmethod, abstractmethod
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union

import torch
from pytorch_lightning.trainer.states import RunningStage
Expand All @@ -28,6 +28,7 @@
from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataKeys, DefaultDataSources
from flash.core.data.properties import Properties
from flash.core.data.states import CollateFn
from flash.core.data.transforms import ApplyToKeys
from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext


Expand Down Expand Up @@ -177,10 +178,10 @@ def pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image:

def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
data_sources: Optional[Dict[str, "DataSource"]] = None,
deserializer: Optional["Deserializer"] = None,
default_data_source: Optional[str] = None,
Expand Down Expand Up @@ -252,6 +253,11 @@ def _check_transforms(
if transform is None:
return transform

if isinstance(transform, list):
transform = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, torch.nn.Sequential(*transform))}
elif callable(transform):
transform = {"pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, transform)}

if not isinstance(transform, Dict):
raise MisconfigurationException(
"Transform should be a dict. " f"Here are the available keys for your transforms: {_PREPROCESS_FUNCS}."
Expand Down Expand Up @@ -439,10 +445,10 @@ def data_source_of_name(self, data_source_name: str) -> DataSource:
class DefaultPreprocess(Preprocess):
def __init__(
self,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
data_sources: Optional[Dict[str, "DataSource"]] = None,
default_data_source: Optional[str] = None,
):
Expand Down
12 changes: 6 additions & 6 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ def from_data_frame(
predict_data_frame: Optional[pd.DataFrame] = None,
predict_images_root: Optional[str] = None,
predict_resolver: Optional[Callable[[str, str], str]] = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
Expand Down Expand Up @@ -217,9 +217,9 @@ def from_csv(
predict_file: Optional[str] = None,
predict_images_root: Optional[str] = None,
predict_resolver: Optional[Callable[[str, str], str]] = None,
train_transform: Optional[Dict[str, Callable]] = None,
val_transform: Optional[Dict[str, Callable]] = None,
test_transform: Optional[Dict[str, Callable]] = None,
train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
Expand Down
9 changes: 6 additions & 3 deletions tests/core/data/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ class Serializer2State(ProcessState):


def test_saving_with_serializers(tmpdir):

checkpoint_file = os.path.join(tmpdir, "tmp.ckpt")

class CustomModel(Task):
Expand Down Expand Up @@ -122,7 +121,6 @@ def __init__(self):


def test_data_source_of_name():

preprocess = CustomPreprocess()

assert preprocess.data_source_of_name("test")() == "test"
Expand All @@ -135,7 +133,6 @@ def test_data_source_of_name():


def test_available_data_sources():

preprocess = CustomPreprocess()

assert DefaultDataSources.TENSORS in preprocess.available_data_sources()
Expand All @@ -147,3 +144,9 @@ def test_available_data_sources():
assert DefaultDataSources.TENSORS in data_module.available_data_sources()
assert "test" in data_module.available_data_sources()
assert len(data_module.available_data_sources()) == 3


def test_check_transforms():
    """Smoke-test that ``DefaultPreprocess`` can be constructed with non-dict transforms.

    ``train_transform`` is supplied first as a single callable and then as a
    list containing that callable; the test passes if neither construction
    raises.
    """
    identity = torch.nn.Identity()
    # Same order as before: bare callable first, then the list form.
    for candidate in (identity, [identity]):
        DefaultPreprocess(train_transform=candidate)

0 comments on commit 4ebc45d

Please sign in to comment.