VISSL datapipeline fix (#880)
ananyahjha93 authored Oct 21, 2021
1 parent b41722a commit 1dcdbbe
Showing 5 changed files with 175 additions and 10 deletions.
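
In short: the Preprocess per-stage transform hooks (pre_tensor_transform, to_tensor_transform, post_tensor_transform, per_batch_transform, per_sample_transform_on_device, per_batch_transform_on_device) are now resolved through ProcessState lookups, so that a Task such as the VISSL ImageEmbedder can override an individual transform via set_state, or disable it by registering a state whose transform is None.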
69 changes: 61 additions & 8 deletions flash/core/data/process.py
@@ -25,8 +25,16 @@
 from flash.core.data.batch import default_uncollate
 from flash.core.data.callback import FlashCallback
 from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataKeys, DefaultDataSources
-from flash.core.data.properties import Properties
-from flash.core.data.states import CollateFn
+from flash.core.data.properties import ProcessState, Properties
+from flash.core.data.states import (
+    CollateFn,
+    PerBatchTransform,
+    PerBatchTransformOnDevice,
+    PerSampleTransformOnDevice,
+    PostTensorTransform,
+    PreTensorTransform,
+    ToTensorTransform,
+)
 from flash.core.data.transforms import ApplyToKeys
 from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext
 from flash.core.utilities.stages import RunningStage
@@ -346,17 +354,62 @@ def _apply_sample_transform(self, sample: Any) -> Any:
             return [self.current_transform(s) for s in sample]
         return self.current_transform(sample)
 
     def _apply_batch_transform(self, batch: Any):
         return self.current_transform(batch)
 
+    def _apply_transform_on_sample(self, sample: Any, transform: Callable):
+        if isinstance(sample, list):
+            return [transform(s) for s in sample]
+
+        return transform(sample)
+
+    def _apply_transform_on_batch(self, batch: Any, transform: Callable):
+        return transform(batch)
+
+    def _apply_process_state_transform(
+        self,
+        process_state: ProcessState,
+        sample: Optional[Any] = None,
+        batch: Optional[Any] = None,
+    ):
+        # exactly one of `sample` / `batch` must be provided
+        if sample is None:
+            assert batch is not None, "sample not provided, batch should not be None"
+            mode = "batch"
+        else:
+            assert batch is None, "sample provided, batch should be None"
+            mode = "sample"
+
+        process_state_transform = self.get_state(process_state)
+
+        if process_state_transform is not None:
+            if process_state_transform.transform is not None:
+                if mode == "sample":
+                    return self._apply_transform_on_sample(sample, process_state_transform.transform)
+                else:
+                    return self._apply_transform_on_batch(batch, process_state_transform.transform)
+            else:
+                # a registered state with transform=None disables the hook (identity)
+                if mode == "sample":
+                    return sample
+                else:
+                    return batch
+        else:
+            # no state registered: fall back to the currently-attached transform
+            if mode == "sample":
+                return self._apply_sample_transform(sample)
+            else:
+                return self._apply_batch_transform(batch)
+
     def pre_tensor_transform(self, sample: Any) -> Any:
         """Transforms to apply on a single object."""
-        return self._apply_sample_transform(sample)
+        return self._apply_process_state_transform(PreTensorTransform, sample=sample)
 
     def to_tensor_transform(self, sample: Any) -> Tensor:
         """Transforms to convert single object to a tensor."""
-        return self._apply_sample_transform(sample)
+        return self._apply_process_state_transform(ToTensorTransform, sample=sample)
 
     def post_tensor_transform(self, sample: Tensor) -> Tensor:
         """Transforms to apply on a tensor."""
-        return self._apply_sample_transform(sample)
+        return self._apply_process_state_transform(PostTensorTransform, sample=sample)
 
     def per_batch_transform(self, batch: Any) -> Any:
         """Transforms to apply to a whole batch (if possible use this for efficiency).
@@ -366,7 +419,7 @@ def per_batch_transform(self, batch: Any) -> Any:
         This option is mutually exclusive with :meth:`per_sample_transform_on_device`,
         since if both are specified, uncollation has to be applied.
         """
-        return self.current_transform(batch)
+        return self._apply_process_state_transform(PerBatchTransform, batch=batch)
 
     def collate(self, samples: Sequence, metadata=None) -> Any:
         """Transform to convert a sequence of samples to a collated batch."""
@@ -400,7 +453,7 @@ def per_sample_transform_on_device(self, sample: Any) -> Any:
         This function won't be called within the dataloader workers, since to make that happen
         each of the workers would have to create its own CUDA context, which would pollute GPU memory (if on GPU).
         """
-        return self.current_transform(sample)
+        return self._apply_process_state_transform(PerSampleTransformOnDevice, sample=sample)
 
     def per_batch_transform_on_device(self, batch: Any) -> Any:
         """Transforms to apply to a whole batch (if possible use this for efficiency).
@@ -410,7 +463,7 @@ def per_batch_transform_on_device(self, batch: Any) -> Any:
         This function won't be called within the dataloader workers, since to make that happen
         each of the workers would have to create its own CUDA context, which would pollute GPU memory (if on GPU).
         """
-        return self.current_transform(batch)
+        return self._apply_process_state_transform(PerBatchTransformOnDevice, batch=batch)
 
     def available_data_sources(self) -> Sequence[str]:
         """Get the list of available data source names for use with this
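
The resolution order implemented by _apply_process_state_transform above: a transform registered as a state wins; a registered state whose transform is None turns the hook into a passthrough; no state at all falls back to the transform currently attached to the Preprocess. Below is a minimal, self-contained sketch of that precedence; MiniPreprocess and its internals are illustrative stand-ins, not the Flash API.

from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class ToTensorTransformState:  # stand-in for flash.core.data.states.ToTensorTransform
    transform: Optional[Callable] = None


class MiniPreprocess:
    """Illustrative stand-in: one hook, same precedence rules as above."""

    def __init__(self, default: Callable):
        self._states = {}        # states keyed by their type, set via set_state
        self._default = default  # plays the role of self.current_transform

    def set_state(self, state: ToTensorTransformState) -> None:
        self._states[type(state)] = state

    def to_tensor_transform(self, sample: Any) -> Any:
        state = self._states.get(ToTensorTransformState)
        if state is not None:
            # a registered state wins; transform=None means "disable the hook"
            return state.transform(sample) if state.transform is not None else sample
        return self._default(sample)  # no state: fall back to the attached transform


pre = MiniPreprocess(default=lambda s: ("default", s))
assert pre.to_tensor_transform(1) == ("default", 1)  # no state -> default transform
pre.set_state(ToTensorTransformState(lambda s: ("task", s)))
assert pre.to_tensor_transform(1) == ("task", 1)     # state transform wins
pre.set_state(ToTensorTransformState(None))
assert pre.to_tensor_transform(1) == 1               # transform=None -> passthrough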
18 changes: 18 additions & 0 deletions flash/core/data/states.py
@@ -22,6 +22,24 @@ class PostTensorTransform(ProcessState):
     transform: Optional[Callable] = None
 
 
+@dataclass(unsafe_hash=True, frozen=True)
+class PerSampleTransformOnDevice(ProcessState):
+
+    transform: Optional[Callable] = None
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class PerBatchTransform(ProcessState):
+
+    transform: Optional[Callable] = None
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class PerBatchTransformOnDevice(ProcessState):
+
+    transform: Optional[Callable] = None
+
+
 @dataclass(unsafe_hash=True, frozen=True)
 class CollateFn(ProcessState):
 
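
Each new state follows the same pattern as the existing ones: a frozen, hashable dataclass wrapping a single optional callable, i.e. an immutable value that can only be swapped out wholesale via set_state. A small hedged illustration of those value semantics (assumes this commit's states.py is importable):

from flash.core.data.states import PerBatchTransformOnDevice

s1 = PerBatchTransformOnDevice(None)
s2 = PerBatchTransformOnDevice(None)
assert s1 == s2 and hash(s1) == hash(s2)  # dataclass value semantics
# s1.transform = lambda b: b  # would raise dataclasses.FrozenInstanceError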
16 changes: 15 additions & 1 deletion flash/image/embedding/model.py
@@ -16,7 +16,15 @@

 from flash.core.adapter import AdapterTask
 from flash.core.data.data_source import DefaultDataKeys
-from flash.core.data.states import CollateFn, PostTensorTransform, PreTensorTransform, ToTensorTransform
+from flash.core.data.states import (
+    CollateFn,
+    PerBatchTransform,
+    PerBatchTransformOnDevice,
+    PerSampleTransformOnDevice,
+    PostTensorTransform,
+    PreTensorTransform,
+    ToTensorTransform,
+)
 from flash.core.data.transforms import ApplyToKeys
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.imports import _VISSL_AVAILABLE
@@ -88,6 +96,9 @@ def __init__(
         if training_strategy_kwargs is None:
             training_strategy_kwargs = {}
 
+        if pretraining_transform_kwargs is None:
+            pretraining_transform_kwargs = {}
+
         backbone, _ = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)
 
         metadata = self.training_strategies.get(training_strategy, with_metadata=True)
@@ -118,6 +129,9 @@ def __init__(
         self.adapter.set_state(ToTensorTransform(to_tensor_transform))
         self.adapter.set_state(PostTensorTransform(None))
         self.adapter.set_state(PreTensorTransform(None))
+        self.adapter.set_state(PerSampleTransformOnDevice(None))
+        self.adapter.set_state(PerBatchTransform(None))
+        self.adapter.set_state(PerBatchTransformOnDevice(None))
 
         warnings.warn(
             "Warning: VISSL ImageEmbedder overrides any user provided transforms"
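
Taken together with the existing PostTensorTransform and PreTensorTransform calls, the adapter now pins every stage hook: the VISSL-specific to_tensor_transform is enforced, and the remaining per-sample and per-batch hooks are explicitly turned into passthroughs rather than left to fall back to whatever transforms the DataModule attached; this is what the warning above is telling the user.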
79 changes: 79 additions & 0 deletions tests/core/data/test_data_pipeline.py
@@ -31,6 +31,7 @@
 from flash.core.data.data_source import DataSource
 from flash.core.data.process import DefaultPreprocess, Deserializer, Postprocess, Preprocess, Serializer
 from flash.core.data.properties import ProcessState
+from flash.core.data.states import PerBatchTransformOnDevice, ToTensorTransform
 from flash.core.model import Task
 from flash.core.utilities.imports import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
 from flash.core.utilities.stages import RunningStage
@@ -733,6 +734,84 @@ def test_datapipeline_transformations(tmpdir):
     assert data_source.predict_load_data_called
 
 
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_datapipeline_transformations_overridden_by_task():
+    # define the preprocess (with a dummy data source returning random images)
+    class ImageDataSource(DataSource):
+        def load_data(self, folder: str):
+            # from a folder -> return file paths
+            return ["a.jpg", "b.jpg"]
+
+        def load_sample(self, path: str) -> np.ndarray:
+            # stand-in for loading the image at `path`
+            return np.random.uniform(0, 1, (64, 64, 3))
+
+    class ImageClassificationPreprocess(DefaultPreprocess):
+        def __init__(
+            self,
+            train_transform=None,
+            val_transform=None,
+            test_transform=None,
+            predict_transform=None,
+        ):
+            super().__init__(
+                train_transform=train_transform,
+                val_transform=val_transform,
+                test_transform=test_transform,
+                predict_transform=predict_transform,
+                data_sources={"default": ImageDataSource()},
+            )
+
+        def default_transforms(self):
+            return {
+                "to_tensor_transform": T.Compose([T.ToTensor()]),
+                "per_batch_transform_on_device": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            }
+
+    # define a task which overrides the transforms using set_state
+    class CustomModel(Task):
+        def __init__(self):
+            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
+
+            # override the default transform to resize images
+            self.set_state(ToTensorTransform(T.Compose([T.ToTensor(), T.Resize(128)])))
+
+            # remove normalization => images stay in the [0, 1] range
+            self.set_state(PerBatchTransformOnDevice(None))
+
+        def training_step(self, batch, batch_idx):
+            assert batch.shape == torch.Size([2, 3, 128, 128])
+            assert torch.max(batch) <= 1.0
+            assert torch.min(batch) >= 0.0
+
+        def validation_step(self, batch, batch_idx):
+            assert batch.shape == torch.Size([2, 3, 128, 128])
+            assert torch.max(batch) <= 1.0
+            assert torch.min(batch) >= 0.0
+
+    class CustomDataModule(DataModule):
+
+        preprocess_cls = ImageClassificationPreprocess
+
+    datamodule = CustomDataModule.from_data_source(
+        "default",
+        "train_folder",
+        "val_folder",
+        None,
+        batch_size=2,
+    )
+
+    # run the trainer
+    model = CustomModel()
+    trainer = Trainer(
+        max_epochs=1,
+        limit_train_batches=2,
+        limit_val_batches=1,
+        num_sanity_val_steps=1,
+    )
+    trainer.fit(model, datamodule=datamodule)
 
 
 def test_is_overriden_recursive(tmpdir):
     class TestPreprocess(DefaultPreprocess):
         def collate(self, *_):
3 changes: 2 additions & 1 deletion tests/text/classification/test_model.py
@@ -95,7 +95,8 @@ def test_load_from_checkpoint_dependency_error():
"cli_args",
(
["flash", "text_classification", "--trainer.fast_dev_run", "True"],
["flash", "text_classification", "--trainer.fast_dev_run", "True", "from_toxic"],
# TODO: update this to work with Pietro's new text data loading (separate PR)
# ["flash", "text_classification", "--trainer.fast_dev_run", "True", "from_toxic"],
),
)
def test_cli(cli_args):
