[1/N] Data Sources (#256)
* Initial commit

* POC Initial commit

* Remove unused code

* Some fixes

* Simplify data source

* Expand preprocess

* Fixes

* Fixes

* Cleaning

* Fixes

* Remove un-needed code

* Remove sequence data source

* Simplify data source

* Fix FilesDataSource

* Minor fix

* Add numpy and tensor data sources

* Fixes

* Onboard object detection

* update

* Add text classification

* Small update

* Add tabular

* Fixes

* Fixes

* Add summarization example

* Add translation

* assert empty data_source in datapipeline creation

* add more assertions for test_classification_task_predict_folder_path

* Add video

* add smoke tests for autodataset

* improve autodataset test

* Fix some tests

* Fix a test

* Fixes

* add tests for base and iterable

* add todo with detected error in callbacks test

* fix test_data_pipeline_init_and_assignement

* fix test_data_pipeline_is_overriden_and_resolve_function_hierarchy

* Fix some tests

* Fix some tests

* Fix some tests

* Fix a test

* Fixes

* Fixes

* Fixes

* Fixes

* Fixes

* deprecate csv test for image classification

* Fix video

* fix test_from_filepaths_splits

* Fixes

* Fixes

* Fixes

* Fixes

* Fixes

* Fixes

* Fixes

* Fixes

* Fix docs build

* Fix docs build

* Fix examples

* Fixes

* Fixes

* Bump huggingface minimal

* debugging

* debugging

* Fixes

* Fixes

* Respond to comments

* feedback

* Updates

* Fixes

* Fixes

* revert

* Updates

* Fixes

* Fixes

* Fixes

Co-authored-by: tchaton <[email protected]>
Co-authored-by: Edgar Riba <[email protected]>
3 people authored May 7, 2021
1 parent b3c049c commit ce63fd7
Showing 74 changed files with 2,422 additions and 2,720 deletions.
14 changes: 13 additions & 1 deletion docs/source/general/data.rst
@@ -21,6 +21,8 @@ Here are common terms you need to be familiar with:
- The :class:`~flash.data.data_module.DataModule` contains the dataset, transforms and dataloaders.
* - :class:`~flash.data.data_pipeline.DataPipeline`
- The :class:`~flash.data.data_pipeline.DataPipeline` is Flash internal object to manage :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` objects.
* - :class:`~flash.data.data_source.DataSource`
- The :class:`~flash.data.data_source.DataSource` provides a hook-based API for creating data sets.
* - :class:`~flash.data.process.Preprocess`
- The :class:`~flash.data.process.Preprocess` provides a simple hook-based API to encapsulate your pre-processing logic.
The :class:`~flash.data.process.Preprocess` provides multiple hooks such as :meth:`~flash.data.process.Preprocess.load_data`
@@ -275,6 +277,17 @@ Example::
API reference
*************

.. _data_source:

DataSource
__________

.. autoclass:: flash.data.data_source.DataSource
:members:


----------

.. _preprocess:

Preprocess
@@ -325,7 +338,6 @@ __________

.. autoclass:: flash.data.data_module.DataModule
:members:
from_load_data_inputs,
train_dataset,
val_dataset,
test_dataset,
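The DataSource entry added to the glossary above describes a hook-based API for creating datasets. As a rough sketch only (the class name, the folder-of-text-files use case, and the exact hook signatures are assumptions, not code from this commit), a custom data source built on that API could look like:

import os
from typing import Any, Dict, List, Optional

from flash.data.data_source import DataSource


class TextFolderDataSource(DataSource):
    """Hypothetical data source: one sample per .txt file found in a folder."""

    def load_data(self, data: str, dataset: Optional[Any] = None) -> List[Dict[str, Any]]:
        # Called once per stage with the user-provided input (here, a folder path).
        # Keep this cheap and defer reading each file to load_sample.
        return [{"input": os.path.join(data, name)} for name in sorted(os.listdir(data)) if name.endswith(".txt")]

    def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
        # Called per sample, typically from the generated dataset's __getitem__.
        with open(sample["input"]) as fp:
            sample["input"] = fp.read()
        return sample

The same load_data / load_sample split mirrors the hooks the Preprocess glossary entry refers to above.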
4 changes: 0 additions & 4 deletions docs/source/reference/image_classification.rst
@@ -183,8 +183,4 @@ ImageClassificationData

.. autoclass:: flash.vision.ImageClassificationData

.. automethod:: flash.vision.ImageClassificationData.from_filepaths

.. automethod:: flash.vision.ImageClassificationData.from_folders

.. autoclass:: flash.vision.ImageClassificationPreprocess
2 changes: 1 addition & 1 deletion docs/source/reference/tabular_classification.rst
@@ -165,4 +165,4 @@ TabularData

.. automethod:: flash.tabular.TabularData.from_csv

.. automethod:: flash.tabular.TabularData.from_df
.. automethod:: flash.tabular.TabularData.from_data_frame
2 changes: 0 additions & 2 deletions docs/source/reference/video_classification.rst
@@ -152,5 +152,3 @@ VideoClassificationData
-----------------------

.. autoclass:: flash.video.VideoClassificationData

.. automethod:: flash.video.VideoClassificationData.from_paths
21 changes: 8 additions & 13 deletions flash/core/classification.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, List, Mapping, Optional, Sequence, Union

import torch
@@ -20,20 +19,15 @@
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.model import Task
from flash.data.process import ProcessState, Serializer
from flash.data.data_source import LabelsState
from flash.data.process import Serializer


def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Calls BCE with logits and cast the target one_hot (y) encoding to floating point precision."""
return F.binary_cross_entropy_with_logits(x, y.float())


@dataclass(unsafe_hash=True, frozen=True)
class ClassificationState(ProcessState):

labels: Optional[List[str]]


class ClassificationTask(Task):

def __init__(
@@ -130,7 +124,7 @@ class Labels(Classes):
Args:
labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not
provided, will attempt to get them from the :class:`.ClassificationState`.
provided, will attempt to get them from the :class:`.LabelsState`.
multi_label: If true, treats outputs as multi label logits.
@@ -141,13 +135,16 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False
super().__init__(multi_label=multi_label, threshold=threshold)
self._labels = labels

if labels is not None:
self.set_state(LabelsState(labels))

def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
labels = None

if self._labels is not None:
labels = self._labels
else:
state = self.get_state(ClassificationState)
state = self.get_state(LabelsState)
if state is not None:
labels = state.labels

@@ -158,7 +155,5 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
return [labels[cls] for cls in classes]
return labels[classes]
else:
rank_zero_warn(
"No ClassificationState was found, this serializer will act as a Classes serializer.", UserWarning
)
rank_zero_warn("No LabelsState was found, this serializer will act as a Classes serializer.", UserWarning)
return classes
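The Labels serializer shown above now stores explicit labels in the shared LabelsState and falls back to that state when no labels are passed. A brief usage sketch (the label names and input tensor are made up, and it assumes the parent Classes serializer resolves a single-label prediction to its argmax index):

import torch

from flash.core.classification import Labels

serializer = Labels(labels=["cat", "dog"])  # explicit labels are also broadcast via set_state(LabelsState(...))
print(serializer.serialize(torch.tensor([0.1, 2.3])))  # expected: "dog" (class index 1 mapped through the labels)

# Labels() with no arguments instead looks up a LabelsState shared through the data
# pipeline, and warns and behaves like a plain Classes serializer if none was set.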
45 changes: 36 additions & 9 deletions flash/core/model.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from importlib import import_module
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
import inspect

import torch
import torchmetrics
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.optim.lr_scheduler import _LRScheduler
@@ -29,7 +29,8 @@
from flash.core.registry import FlashRegistry
from flash.core.schedulers import _SCHEDULERS_REGISTRY
from flash.core.utils import get_callable_dict
from flash.data.data_pipeline import DataPipeline
from flash.data.data_pipeline import DataPipeline, DataPipelineState
from flash.data.data_source import DataSource, DefaultDataSources
from flash.data.process import Postprocess, Preprocess, Serializer, SerializerMapping


@@ -103,6 +104,9 @@ def __init__(
self._postprocess: Optional[Postprocess] = postprocess
self._serializer: Optional[Serializer] = None

# TODO: create enum values to define what are the exact states
self._data_pipeline_state: Optional[DataPipelineState] = None

# Explicitly set the serializer to call the setter
self.serializer = serializer

@@ -154,6 +158,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
def predict(
self,
x: Any,
data_source: Optional[str] = None,
data_pipeline: Optional[DataPipeline] = None,
) -> Any:
"""
@@ -169,9 +174,9 @@ def predict(
"""
running_stage = RunningStage.PREDICTING

data_pipeline = self.build_data_pipeline(data_pipeline)
data_pipeline = self.build_data_pipeline(data_source or "default", data_pipeline)

x = [x for x in data_pipeline._generate_auto_dataset(x, running_stage)]
x = [x for x in data_pipeline.data_source.generate_dataset(x, running_stage)]
x = data_pipeline.worker_preprocessor(running_stage)(x)
# switch to self.device when #7188 merge in Lightning
x = self.transfer_batch_to_device(x, next(self.parameters()).device)
@@ -252,7 +257,11 @@ def serializer(self, serializer: Union[Serializer, Mapping[str, Serializer]]):
serializer = SerializerMapping(serializer)
self._serializer = serializer

def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> Optional[DataPipeline]:
def build_data_pipeline(
self,
data_source: Optional[str] = None,
data_pipeline: Optional[DataPipeline] = None,
) -> Optional[DataPipeline]:
"""Build a :class:`.DataPipeline` incorporating available
:class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess`
objects. These will be overridden in the following resolution order (lowest priority first):
@@ -269,20 +278,26 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O
Returns:
The fully resolved :class:`.DataPipeline`.
"""
preprocess, postprocess, serializer = None, None, None
old_data_source, preprocess, postprocess, serializer = None, None, None, None

# Datamodule
if self.datamodule is not None and getattr(self.datamodule, 'data_pipeline', None) is not None:
old_data_source = getattr(self.datamodule.data_pipeline, 'data_source', None)
preprocess = getattr(self.datamodule.data_pipeline, '_preprocess_pipeline', None)
postprocess = getattr(self.datamodule.data_pipeline, '_postprocess_pipeline', None)
serializer = getattr(self.datamodule.data_pipeline, '_serializer', None)

elif self.trainer is not None and hasattr(
self.trainer, 'datamodule'
) and getattr(self.trainer.datamodule, 'data_pipeline', None) is not None:
old_data_source = getattr(self.trainer.datamodule.data_pipeline, 'data_source', None)
preprocess = getattr(self.trainer.datamodule.data_pipeline, '_preprocess_pipeline', None)
postprocess = getattr(self.trainer.datamodule.data_pipeline, '_postprocess_pipeline', None)
serializer = getattr(self.trainer.datamodule.data_pipeline, '_serializer', None)
else:
# TODO: we should log with low severity level that we use defaults to create
# `preprocess`, `postprocess` and `serializer`.
pass

# Defaults / task attributes
preprocess, postprocess, serializer = Task._resolve(
@@ -305,8 +320,16 @@ def build_data_pipeline(self, data_pipeline: Optional[DataPipeline] = None) -> O
getattr(data_pipeline, '_serializer', None),
)

data_pipeline = DataPipeline(preprocess, postprocess, serializer)
data_pipeline.initialize()
data_source = data_source or old_data_source

if isinstance(data_source, str):
if preprocess is None:
data_source = DataSource() # TODO: warn the user that we are not using the specified data source
else:
data_source = preprocess.data_source_of_name(data_source)

data_pipeline = DataPipeline(data_source, preprocess, postprocess, serializer)
self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
return data_pipeline

@property
@@ -376,12 +399,16 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
# https://pytorch.org/docs/stable/notes/serialization.html
if self.data_pipeline is not None and 'data_pipeline' not in checkpoint:
checkpoint['data_pipeline'] = self.data_pipeline
if self._data_pipeline_state is not None and '_data_pipeline_state' not in checkpoint:
checkpoint['_data_pipeline_state'] = self._data_pipeline_state
super().on_save_checkpoint(checkpoint)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
super().on_load_checkpoint(checkpoint)
if 'data_pipeline' in checkpoint:
self.data_pipeline = checkpoint['data_pipeline']
if '_data_pipeline_state' in checkpoint:
self._data_pipeline_state = checkpoint['_data_pipeline_state']

@classmethod
def available_backbones(cls) -> List[str]:
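Taken together, the model.py changes above thread an optional data_source name from Task.predict into build_data_pipeline, where it is resolved via preprocess.data_source_of_name(...) (defaulting to "default", and falling back to a bare DataSource() with a planned warning when no Preprocess is available). An illustrative call under those assumptions (the ImageClassifier task and the "files" data-source name are not part of this diff):

from flash.vision import ImageClassifier

model = ImageClassifier(num_classes=2)

# The new data_source argument selects which registered DataSource the preprocess
# should use to interpret the raw input (here, a list of file paths).
predictions = model.predict(
    ["path/to/image_1.png", "path/to/image_2.png"],
    data_source="files",
)
print(predictions)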