This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[1/N] Data Sources #256

Merged: 84 commits into master from feature/data_sources on May 7, 2021
Note: the diff below shows changes from 10 of the 84 commits.

Commits (84)
735740e Initial commit (ethanwharris, Apr 28, 2021)
be01397 POC Initial commit (ethanwharris, Apr 29, 2021)
214df85 Remove unused code (ethanwharris, Apr 29, 2021)
8f93bfb Some fixes (ethanwharris, Apr 30, 2021)
e8ee4c0 Simplify data source (ethanwharris, Apr 30, 2021)
653057d Expand preprocess (ethanwharris, Apr 30, 2021)
0184332 Fixes (ethanwharris, Apr 30, 2021)
5172a06 Fixes (ethanwharris, Apr 30, 2021)
5c3f597 Cleaning (ethanwharris, Apr 30, 2021)
44d70e1 Fixes (ethanwharris, Apr 30, 2021)
08657ea Remove un-needed code (ethanwharris, May 4, 2021)
73be792 Remove sequence data source (ethanwharris, May 4, 2021)
3381840 Simplify data source (ethanwharris, May 4, 2021)
e01987d Fix FilesDataSource (ethanwharris, May 4, 2021)
e385dfa Minor fix (ethanwharris, May 4, 2021)
dc90754 Add numpy and tensor data sources (ethanwharris, May 4, 2021)
c437043 Fixes (ethanwharris, May 4, 2021)
b32ee34 Onboard object detection (ethanwharris, May 5, 2021)
bfd320d update (tchaton, May 5, 2021)
7e050be Add text classification (ethanwharris, May 5, 2021)
34c41d4 Merge branch 'feature/data_sources' of https://github.com/PyTorchLigh… (ethanwharris, May 5, 2021)
6e0f69d Small update (ethanwharris, May 5, 2021)
a2082bc Add tabular (ethanwharris, May 5, 2021)
fd07644 Fixes (ethanwharris, May 5, 2021)
19e966d Fixes (ethanwharris, May 5, 2021)
3b7ab0e Add summarization example (ethanwharris, May 6, 2021)
d9c00c5 Add translation (ethanwharris, May 6, 2021)
2da3339 Merge branch 'master' into feature/data_sources (ethanwharris, May 6, 2021)
d5b8c4a assert empty data_source in datapipeline creation (edgarriba, May 6, 2021)
dd35da6 add more assertions for test_classification_task_predict_folder_path (edgarriba, May 6, 2021)
f2c3f20 Add video (ethanwharris, May 6, 2021)
e186926 Merge branch 'feature/data_sources' of https://github.com/PyTorchLigh… (ethanwharris, May 6, 2021)
83024bb add smoke tests for autodataset (edgarriba, May 6, 2021)
8309080 improve autodataset test (edgarriba, May 6, 2021)
f1c44a1 Fix some tests (ethanwharris, May 6, 2021)
47e8f3f Fix a test (ethanwharris, May 6, 2021)
f3a238e Fixes (ethanwharris, May 6, 2021)
eb5cfdd add tests for base and iterable (edgarriba, May 6, 2021)
a997b9d add todo with detected error in callbacks test (edgarriba, May 6, 2021)
b18f0fd fix test_data_pipeline_init_and_assignement (edgarriba, May 6, 2021)
bda0a12 fix test_data_pipeline_is_overriden_and_resolve_function_hierarchy (edgarriba, May 6, 2021)
e4a4f8a Fix some tests (ethanwharris, May 6, 2021)
f5f000f Merge branch 'feature/data_sources' of https://github.com/PyTorchLigh… (ethanwharris, May 6, 2021)
464fffe Fix some tests (ethanwharris, May 6, 2021)
e7d6b66 Fix some tests (ethanwharris, May 6, 2021)
cc57f86 Fix a test (ethanwharris, May 6, 2021)
3a63083 Fixes (ethanwharris, May 6, 2021)
3489953 Fixes (ethanwharris, May 6, 2021)
1ccf7ab Fixes (ethanwharris, May 6, 2021)
64aff9e Fixes (ethanwharris, May 6, 2021)
33506b3 Fixes (ethanwharris, May 6, 2021)
1f50432 deprecate csv test for image classification (edgarriba, May 6, 2021)
6b587fe Fix video (ethanwharris, May 6, 2021)
2794a98 Merge branch 'feature/data_sources' of https://github.com/PyTorchLigh… (ethanwharris, May 6, 2021)
4215a47 fix test_from_filepaths_splits (edgarriba, May 6, 2021)
0256c04 Fixes (ethanwharris, May 6, 2021)
9806b85 Merge branch 'feature/data_sources' of https://github.com/PyTorchLigh… (ethanwharris, May 6, 2021)
1d5c41b Fixes (ethanwharris, May 6, 2021)
8064c65 Fixes (ethanwharris, May 6, 2021)
a32560c Fixes (ethanwharris, May 6, 2021)
4d34d94 Fixes (ethanwharris, May 6, 2021)
c85a8db Fixes (ethanwharris, May 6, 2021)
c93a649 Fixes (ethanwharris, May 6, 2021)
02fd77b Fixes (ethanwharris, May 6, 2021)
3d780fa Fix docs build (ethanwharris, May 6, 2021)
704f558 Fix docs build (ethanwharris, May 7, 2021)
edfc38e Fix examples (ethanwharris, May 7, 2021)
4679cb5 Fixes (ethanwharris, May 7, 2021)
05a1e98 Fixes (ethanwharris, May 7, 2021)
46b6a4f Bump huggingface minimal (ethanwharris, May 7, 2021)
5b2013e debugging (ethanwharris, May 7, 2021)
75f3469 debugging (ethanwharris, May 7, 2021)
950b13f Fixes (ethanwharris, May 7, 2021)
f47208c Fixes (ethanwharris, May 7, 2021)
db0c991 Respond to comments (ethanwharris, May 7, 2021)
db1cdf1 feedback (ethanwharris, May 7, 2021)
88cbc65 Updates (ethanwharris, May 7, 2021)
4ee1dd4 Fixes (ethanwharris, May 7, 2021)
ce3fcf2 Fixes (ethanwharris, May 7, 2021)
ed22b10 revert (ethanwharris, May 7, 2021)
f453d03 Updates (ethanwharris, May 7, 2021)
1ae8c56 Fixes (ethanwharris, May 7, 2021)
1088022 Fixes (ethanwharris, May 7, 2021)
9032be4 Fixes (ethanwharris, May 7, 2021)
14 changes: 5 additions & 9 deletions flash/core/classification.py
@@ -20,20 +20,15 @@
 from pytorch_lightning.utilities import rank_zero_warn
 
 from flash.core.model import Task
-from flash.data.process import ProcessState, Serializer
+from flash.data.data_source import LabelsState
+from flash.data.process import Serializer
 
 
 def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     """Calls BCE with logits and cast the target one_hot (y) encoding to floating point precision."""
     return F.binary_cross_entropy_with_logits(x, y.float())
 
 
-@dataclass(unsafe_hash=True, frozen=True)
-class ClassificationState(ProcessState):
-
-    labels: Optional[List[str]]
-
-
 class ClassificationTask(Task):
 
     def __init__(
@@ -140,15 +135,16 @@ class Labels(Classes):
     def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False, threshold: float = 0.5):
         super().__init__(multi_label=multi_label, threshold=threshold)
         self._labels = labels
-        self.set_state(ClassificationState(labels))
+        if labels is not None:
+            self.set_state(LabelsState(labels))
 
     def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
         labels = None
 
         if self._labels is not None:
             labels = self._labels
         else:
-            state = self.get_state(ClassificationState)
+            state = self.get_state(LabelsState)
             if state is not None:
                 labels = state.labels
 
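The upshot of this file's change: the serializer-local `ClassificationState` is replaced by the shared `LabelsState` from `flash.data.data_source`, and `Labels` only publishes state when labels were actually provided. A minimal, self-contained sketch of the resolution order in `serialize` (the classes below are simplified stand-ins for illustration, not the flash API):

```python
from typing import Any, List, Optional


class LabelsState:
    """Stand-in for flash.data.data_source.LabelsState: shared label metadata."""

    def __init__(self, labels: Optional[List[str]]):
        self.labels = labels


class LabelsSerializerSketch:
    """Mimics Labels.serialize: explicit labels win, then shared state, then the raw index."""

    def __init__(self, labels: Optional[List[str]] = None):
        self._labels = labels
        self._state = {}  # stand-in for the process-state mechanism
        if labels is not None:
            self.set_state(LabelsState(labels))

    def set_state(self, state: Any) -> None:
        self._state[type(state)] = state

    def get_state(self, state_type: type) -> Optional[Any]:
        return self._state.get(state_type)

    def serialize(self, class_idx: int) -> Any:
        labels = self._labels
        if labels is None:
            state = self.get_state(LabelsState)
            if state is not None:
                labels = state.labels
        # Fall back to the raw class index when no labels are known.
        return labels[class_idx] if labels is not None else class_idx


print(LabelsSerializerSketch(["cat", "dog"]).serialize(1))  # -> dog
```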
40 changes: 31 additions & 9 deletions flash/core/model.py
@@ -20,7 +20,6 @@
 from pytorch_lightning import LightningModule
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import nn
 from torch.optim.lr_scheduler import _LRScheduler
@@ -29,7 +28,8 @@
 from flash.core.registry import FlashRegistry
 from flash.core.schedulers import _SCHEDULERS_REGISTRY
 from flash.core.utils import get_callable_dict
-from flash.data.data_pipeline import DataPipeline
+from flash.data.data_pipeline import DataPipeline, DataPipelineState
+from flash.data.data_source import DataSource, DefaultDataSource
 from flash.data.process import Postprocess, Preprocess, Serializer, SerializerMapping
@@ -103,14 +103,17 @@ def __init__(
         self._postprocess: Optional[Postprocess] = postprocess
         self._serializer: Optional[Serializer] = None
 
+        self._data_pipeline_state: Optional[DataPipelineState] = None
+
         # Explicitly set the serializer to call the setter
         self.serializer = serializer
 
     def step(self, batch: Any, batch_idx: int) -> Any:
         """
         The training/validation/test step. Override for custom behavior.
         """
-        x, y = batch
+        x, y = batch['input'], batch['target']
+        # x, y = batch
         y_hat = self(x)
         output = {"y_hat": y_hat}
         losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
@@ -154,6 +157,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
     def predict(
         self,
         x: Any,
+        data_source: Union[str, DefaultDataSource, DataSource] = DefaultDataSource.FILES,
         data_pipeline: Optional[DataPipeline] = None,
     ) -> Any:
         """
@@ -169,9 +173,9 @@
         """
         running_stage = RunningStage.PREDICTING
 
-        data_pipeline = self.build_data_pipeline(data_pipeline)
+        data_pipeline = self.build_data_pipeline(data_source, data_pipeline)
 
-        x = [x for x in data_pipeline._generate_auto_dataset(x, running_stage)]
+        x = [x for x in data_pipeline._data_source.generate_dataset(x, running_stage)]
         x = data_pipeline.worker_preprocessor(running_stage)(x)
         # switch to self.device when #7188 merge in Lightning
         x = self.transfer_batch_to_device(x, next(self.parameters()).device)
@@ -181,6 +185,7 @@
         return predictions
 
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+        batch = batch['input']
         if isinstance(batch, tuple):
             batch = batch[0]
         elif isinstance(batch, list):
@@ -252,7 +257,11 @@ def serializer(self, serializer: Union[Serializer, Mapping[str, Serializer]]):
             serializer = SerializerMapping(serializer)
         self._serializer = serializer
 
-    def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
+    def build_data_pipeline(
+        self,
+        data_source: Optional[Union[str, DefaultDataSource, DataSource]] = None,
+        data_pipeline: Optional[DataPipeline] = None,
+    ) -> Optional[DataPipeline]:
         """Build a :class:`.DataPipeline` incorporating available
         :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
         objects. These will be overridden in the following resolution order (lowest priority first):
@@ -269,17 +278,19 @@
         Returns:
             The fully resolved :class:`.DataPipeline`.
         """
-        preprocess, postprocess, serializer = None, None, None
+        old_data_source, preprocess, postprocess, serializer = None, None, None, None
 
         # Datamodule
         if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None:
+            old_data_source = getattr(self.datamodule.data_pipeline, '_data_source', None)
             preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None)
             postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None)
             serializer = getattr(self.datamodule.data_pipeline, '_serializer', None)
 
         elif self.trainer is not None and hasattr(
             self.trainer, 'datamodule'
         ) and getattr(self.trainer.datamodule, 'data_pipeline', None) is not None:
+            old_data_source = getattr(self.trainer.datamodule.data_pipeline, '_data_source', None)
             preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None)
             postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None)
             serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None)
@@ -305,8 +316,19 @@
             getattr(data_pipeline, '_serializer', None),
         )
 
-        data_pipeline = DataPipeline(preprocess, postprocess, serializer)
-        data_pipeline.initialize()
+        data_source = data_source or old_data_source
+
+        if str(data_source) == data_source:
+            data_source = DefaultDataSource(data_source)
+
+        if not isinstance(data_source, DataSource):
+            data_source = preprocess.data_source_of_type(data_source.as_type())()
+
+        if old_data_source is not None:
+            data_source._state.update(old_data_source._state)  # TODO: This is a hack
+
+        data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer)
+        self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
        return data_pipeline
 
     @property
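Two behavioural points fall out of this file: batches now flow through `step` as dictionaries with 'input' and 'target' keys, and `build_data_pipeline` accepts a `data_source` that may be a string, a `DefaultDataSource` member, or a `DataSource` instance. A self-contained sketch of that string -> enum -> instance resolution chain (the registry dict stands in for `preprocess.data_source_of_type`, and all class bodies are simplified assumptions, not the flash API):

```python
from enum import Enum


class DataSource:
    """Stand-in base class."""


class FilesDataSource(DataSource):
    """Stand-in for a data source that loads from file paths."""


class DefaultDataSource(Enum):
    FILES = "files"
    FOLDERS = "folders"


# Stand-in for preprocess.data_source_of_type(data_source.as_type()):
_SOURCES = {DefaultDataSource.FILES: FilesDataSource}


def resolve_data_source(data_source):
    # Strings are coerced to the enum first, as in the diff above:
    if str(data_source) == data_source:
        data_source = DefaultDataSource(data_source)
    # Anything that is not already a DataSource instance is looked up
    # and instantiated:
    if not isinstance(data_source, DataSource):
        data_source = _SOURCES[data_source]()
    return data_source


assert isinstance(resolve_data_source("files"), FilesDataSource)
assert isinstance(resolve_data_source(DefaultDataSource.FILES), FilesDataSource)
```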
154 changes: 39 additions & 115 deletions flash/data/auto_dataset.py
@@ -12,22 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from inspect import signature
-from typing import Any, Callable, Iterable, Iterator, Optional, TYPE_CHECKING
+from typing import Any, Generic, Iterable, Sequence, TYPE_CHECKING, TypeVar
 
-import torch
 from pytorch_lightning.trainer.states import RunningStage
-from pytorch_lightning.utilities.warning_utils import rank_zero_warn
 from torch.utils.data import Dataset, IterableDataset
 
 from flash.data.callback import ControlFlow
-from flash.data.process import Preprocess
-from flash.data.utils import _STAGES_PREFIX, _STAGES_PREFIX_VALUES, CurrentRunningStageFuncContext
+from flash.data.utils import CurrentRunningStageFuncContext
 
 if TYPE_CHECKING:
-    from flash.data.data_pipeline import DataPipeline
+    from flash.data.data_source import DataSource
Review comment (Member), on the forward-declared DataSource import:

    Not sure, but I think sphinx will have issues with forward declarations like this.

Reply (Collaborator, Author):

    Docs build is working for now, I'm not sure we can avoid a circular import here but could maybe just import the module and type as data_source.DataSource.
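For readers following the thread: the pattern in question imports `DataSource` only under `typing.TYPE_CHECKING` and refers to it via a string (a forward reference), so no circular import happens at runtime. A tiny sketch of the idea (the `describe` function is hypothetical, added only for illustration):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers (mypy, IDEs), never at runtime.
    from flash.data.data_source import DataSource


def describe(source: 'DataSource') -> str:
    # The quoted annotation is a forward reference, so importing this module
    # never imports flash.data.data_source at runtime.
    return type(source).__name__
```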


 
+DATA_TYPE = TypeVar('DATA_TYPE')
 
 
-class BaseAutoDataset:
+class BaseAutoDataset(Generic[DATA_TYPE]):
 
     DATASET_KEY = "dataset"
     """
@@ -38,141 +37,66 @@ class BaseAutoDataset:
 
     def __init__(
         self,
-        data: Any,
-        load_data: Optional[Callable] = None,
-        load_sample: Optional[Callable] = None,
-        data_pipeline: Optional['DataPipeline'] = None,
-        running_stage: Optional[RunningStage] = None
+        data: DATA_TYPE,
+        data_source: 'DataSource',
+        running_stage: RunningStage,
     ) -> None:
         super().__init__()
 
-        if load_data or load_sample:
-            if data_pipeline:
-                rank_zero_warn(
-                    "``datapipeline`` is specified but load_sample and/or load_data are also specified. "
-                    "Won't use datapipeline"
-                )
-        # initial states
-        self._load_data_called = False
-        self._running_stage = None
-
         self.data = data
-        self.data_pipeline = data_pipeline
-        self.load_data = load_data
-        self.load_sample = load_sample
+        self.data_source = data_source
 
-        # trigger the setup only if `running_stage` is provided
+        self._running_stage = None
         self.running_stage = running_stage
 
     @property
-    def running_stage(self) -> Optional[RunningStage]:
+    def running_stage(self) -> RunningStage:
         return self._running_stage
 
     @running_stage.setter
     def running_stage(self, running_stage: RunningStage) -> None:
-        if self._running_stage != running_stage or (not self._running_stage):
-            self._running_stage = running_stage
-            self._load_data_context = CurrentRunningStageFuncContext(self._running_stage, "load_data", self.preprocess)
-            self._load_sample_context = CurrentRunningStageFuncContext(
-                self._running_stage, "load_sample", self.preprocess
-            )
-            self._setup(running_stage)
+        from flash.data.data_pipeline import DataPipeline
+        from flash.data.data_source import DataSource  # Hack to avoid circular import TODO: something better than this
 
-    @property
-    def preprocess(self) -> Optional[Preprocess]:
-        if self.data_pipeline is not None:
-            return self.data_pipeline._preprocess_pipeline
+        self._running_stage = running_stage
 
-    @property
-    def control_flow_callback(self) -> Optional[ControlFlow]:
-        preprocess = self.preprocess
-        if preprocess is not None:
-            return ControlFlow(preprocess.callbacks)
-
-    def _call_load_data(self, data: Any) -> Iterable:
-        parameters = signature(self.load_data).parameters
-        if len(parameters) > 1 and self.DATASET_KEY in parameters:
-            return self.load_data(data, self)
-        else:
-            return self.load_data(data)
+        self._load_sample_context = CurrentRunningStageFuncContext(self.running_stage, "load_sample", self.data_source)
 
-    def _call_load_sample(self, sample: Any) -> Any:
-        parameters = signature(self.load_sample).parameters
-        if len(parameters) > 1 and self.DATASET_KEY in parameters:
-            return self.load_sample(sample, self)
-        else:
-            return self.load_sample(sample)
-
-    def _setup(self, stage: Optional[RunningStage]) -> None:
-        assert not stage or _STAGES_PREFIX[stage] in _STAGES_PREFIX_VALUES
-        previous_load_data = self.load_data.__code__ if self.load_data else None
-
-        if self._running_stage and self.data_pipeline and (not self.load_data or not self.load_sample) and stage:
-            self.load_data = getattr(
-                self.preprocess,
-                self.data_pipeline._resolve_function_hierarchy('load_data', self.preprocess, stage, Preprocess)
+        self.load_sample = getattr(
+            self.data_source,
+            DataPipeline._resolve_function_hierarchy(
+                'load_sample',
+                self.data_source,
+                self.running_stage,
+                DataSource,
             )
-            self.load_sample = getattr(
-                self.preprocess,
-                self.data_pipeline._resolve_function_hierarchy('load_sample', self.preprocess, stage, Preprocess)
-            )
-        if self.load_data and (previous_load_data != self.load_data.__code__ or not self._load_data_called):
-            if previous_load_data:
-                rank_zero_warn(
-                    "The load_data function of the Autogenerated Dataset changed. "
-                    "This is not expected! Preloading Data again to ensure compatibility. This may take some time."
-                )
-            self.setup()
-            self._load_data_called = True
-
-    def setup(self):
-        raise NotImplementedError
+        )
 
+    def _call_load_sample(self, sample: Any) -> Any:
+        if self.load_sample:
+            with self._load_sample_context:
+                parameters = signature(self.load_sample).parameters
+                if len(parameters) > 1 and self.DATASET_KEY in parameters:
+                    sample = self.load_sample(sample, self)
+                else:
+                    sample = self.load_sample(sample)
+        return sample
 
-class AutoDataset(BaseAutoDataset, Dataset):
 
-    def setup(self):
-        with self._load_data_context:
-            self.preprocessed_data = self._call_load_data(self.data)
+class AutoDataset(BaseAutoDataset[Sequence[Any]], Dataset):
 
     def __getitem__(self, index: int) -> Any:
-        if not self.load_sample and not self.load_data:
-            raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.")
-        if self.load_sample:
-            with self._load_sample_context:
-                data: Any = self._call_load_sample(self.preprocessed_data[index])
-                if self.control_flow_callback:
-                    self.control_flow_callback.on_load_sample(data, self.running_stage)
-            return data
-        return self.preprocessed_data[index]
+        return self._call_load_sample(self.data[index])
 
     def __len__(self) -> int:
-        if not self.load_sample and not self.load_data:
-            raise RuntimeError("`__len__` for `load_sample` and `load_data` could not be inferred.")
-        return len(self.preprocessed_data)
+        return len(self.data)
 
 
-class IterableAutoDataset(BaseAutoDataset, IterableDataset):
+class IterableAutoDataset(BaseAutoDataset[Iterable[Any]], IterableDataset):
 
-    def setup(self):
-        with self._load_data_context:
-            self.dataset = self._call_load_data(self.data)
-        self.dataset_iter = None
-
     def __iter__(self):
-        self.dataset_iter = iter(self.dataset)
+        self.data_iter = iter(self.data)
         return self
 
     def __next__(self) -> Any:
-        if not self.load_sample and not self.load_data:
-            raise RuntimeError("`__getitem__` for `load_sample` and `load_data` could not be inferred.")
-
-        data = next(self.dataset_iter)
-
-        if self.load_sample:
-            with self._load_sample_context:
-                data: Any = self._call_load_sample(data)
-                if self.control_flow_callback:
-                    self.control_flow_callback.on_load_sample(data, self.running_stage)
-            return data
-        return data
+        return self._call_load_sample(next(self.data_iter))
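After this rewrite, the `AutoDataset` classes are thin wrappers: `load_data` has moved out (the data passed in is assumed to be already loaded), and `load_sample` is owned by the `DataSource` and resolved per running stage. A runnable sketch of the resulting division of labour, with simplified stand-in classes rather than the flash API:

```python
from torch.utils.data import Dataset


class ParityDataSource:
    """Stand-in DataSource: load_sample turns raw items into input/target dicts."""

    def load_sample(self, sample):
        return {"input": sample, "target": sample % 2}


class AutoDatasetSketch(Dataset):
    """Mimics the slimmed-down AutoDataset: index into data, then load_sample."""

    def __init__(self, data, data_source):
        self.data = data  # already-loaded sequence, no load_data step
        self.data_source = data_source

    def __getitem__(self, index):
        return self.data_source.load_sample(self.data[index])

    def __len__(self):
        return len(self.data)


dataset = AutoDatasetSketch(list(range(4)), ParityDataSource())
print(dataset[3])  # {'input': 3, 'target': 1}
```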
2 changes: 2 additions & 0 deletions flash/data/batch.py
@@ -57,6 +57,8 @@ def __init__(
         self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", preprocess)
 
     def forward(self, sample: Any) -> Any:
+        self.callback.on_load_sample(sample, self.stage)
+
         with self._current_stage_context:
             with self._pre_tensor_transform_context:
                 sample = self.pre_tensor_transform(sample)
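The two added lines mean the `on_load_sample` callback now fires on every raw sample entering the worker preprocessor, before any transform runs. A minimal sketch of the observable effect (the callback protocol here is a simplified assumption):

```python
class RecordingCallback:
    """Stand-in data-fetcher callback that records raw samples."""

    def __init__(self):
        self.seen = []

    def on_load_sample(self, sample, stage):
        self.seen.append((stage, sample))


def preprocessor_forward(sample, callback, stage="train"):
    callback.on_load_sample(sample, stage)  # fires first, as in the diff
    # ... pre_tensor_transform / to_tensor_transform / post_tensor_transform
    return sample


callback = RecordingCallback()
preprocessor_forward({"input": 1}, callback)
print(callback.seen)  # [('train', {'input': 1})]
```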
3 changes: 0 additions & 3 deletions flash/data/callback.py
@@ -190,9 +190,6 @@ def enable(self):
         yield
         self.enabled = False
 
-    def attach_to_datamodule(self, datamodule) -> None:
-        datamodule.data_fetcher = self
-
     def attach_to_preprocess(self, preprocess: 'flash.data.process.Preprocess') -> None:
         preprocess.add_callbacks([self])
         self._preprocess = preprocess