This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Pytorch video #216

Merged: 70 commits into master from pytorch_video on Apr 30, 2021
70 commits (the file diff below shows changes from 12 commits)
ae190b9
update
tchaton Apr 15, 2021
a4cb3a3
update
tchaton Apr 15, 2021
cdba489
Update flash/vision/video/classification/data.py
kaushikb11 Apr 15, 2021
c6b1cd7
update
tchaton Apr 15, 2021
5f1aae9
Merge branch 'pytorch_video' of https://github.com/PyTorchLightning/l…
tchaton Apr 15, 2021
2c427f6
Update flash/vision/video/classification/model.py
tchaton Apr 15, 2021
0c4a092
update
tchaton Apr 15, 2021
db6df41
Merge branch 'pytorch_video' of https://github.com/PyTorchLightning/l…
tchaton Apr 15, 2021
19ea5f1
update
tchaton Apr 15, 2021
b21f152
typo
tchaton Apr 15, 2021
aea8214
update
tchaton Apr 15, 2021
fbc43c8
update
tchaton Apr 15, 2021
73e0191
resolve some internal bugs
tchaton Apr 16, 2021
a1ff7b6
update on comments
tchaton Apr 16, 2021
3227e77
move files
tchaton Apr 16, 2021
98b4e13
update
tchaton Apr 16, 2021
9e17b50
update
tchaton Apr 16, 2021
eb286de
update
tchaton Apr 16, 2021
b122059
filter for 3.6
tchaton Apr 16, 2021
ae8197d
update on comments
tchaton Apr 16, 2021
c4526f4
update
tchaton Apr 16, 2021
0c2f852
update
tchaton Apr 16, 2021
c949061
update
tchaton Apr 16, 2021
fa30ea5
clean auto dataset
tchaton Apr 16, 2021
2777b9e
typo
tchaton Apr 16, 2021
17bfe73
update
tchaton Apr 16, 2021
b9bae51
update on comments:
tchaton Apr 16, 2021
38c9610
add doc
tchaton Apr 16, 2021
8a04ceb
remove backbone section
tchaton Apr 16, 2021
383f939
update
tchaton Apr 16, 2021
ab21afa
update
tchaton Apr 16, 2021
11bdd62
update
tchaton Apr 16, 2021
3ac8437
update
tchaton Apr 16, 2021
5a9158b
map to None
tchaton Apr 16, 2021
8ad791b
update
tchaton Apr 16, 2021
4feef51
update
tchaton Apr 16, 2021
35bb690
update on comments
tchaton Apr 16, 2021
1b4d565
update script
tchaton Apr 16, 2021
912fce0
update on comments
tchaton Apr 16, 2021
c6919f4
Update docs/source/reference/video_classification.rst
carmocca Apr 16, 2021
aeb6fee
Merge branch 'master' into pytorch_video
tchaton Apr 18, 2021
25bea44
Merge branch 'master' into pytorch_video
tchaton Apr 19, 2021
ab4b6d4
Merge branch 'master' into pytorch_video
tchaton Apr 19, 2021
480aa18
Merge branch 'master' into pytorch_video
tchaton Apr 27, 2021
41bdc5b
update
tchaton Apr 27, 2021
3c660ef
Merge branch 'master' into pytorch_video
tchaton Apr 27, 2021
04382a5
update
tchaton Apr 27, 2021
cf3ef94
update
tchaton Apr 27, 2021
9e9b656
Merge branch 'pytorch_video' of https://github.com/PyTorchLightning/l…
tchaton Apr 27, 2021
754f43c
update
tchaton Apr 27, 2021
7a09783
Updates
ethanwharris Apr 27, 2021
6697e91
update
tchaton Apr 27, 2021
231171a
update
tchaton Apr 27, 2021
92aa151
update
tchaton Apr 27, 2021
939a251
update
tchaton Apr 27, 2021
63babc6
iupdate:
tchaton Apr 27, 2021
530367d
update
tchaton Apr 29, 2021
81733ed
update
tchaton Apr 29, 2021
fdd85a2
Merge branch 'master' into pytorch_video
tchaton Apr 29, 2021
ed043b3
resolve ci
tchaton Apr 29, 2021
1735b6f
update
tchaton Apr 29, 2021
5d80e45
update
tchaton Apr 30, 2021
f201a70
updates
tchaton Apr 30, 2021
aff7657
update
tchaton Apr 30, 2021
b18457a
update
tchaton Apr 30, 2021
78dbc3a
update
tchaton Apr 30, 2021
80f7e71
update
tchaton Apr 30, 2021
1999639
update
tchaton Apr 30, 2021
3b0bd8f
update
tchaton Apr 30, 2021
43e2bc3
update
tchaton Apr 30, 2021
15 changes: 11 additions & 4 deletions flash/core/model.py
@@ -172,10 +172,10 @@ def configure_finetune_callback(self) -> List[Callback]:

@staticmethod
def _resolve(
old_preprocess: Optional[Preprocess],
old_postprocess: Optional[Postprocess],
new_preprocess: Optional[Preprocess],
new_postprocess: Optional[Postprocess],
) -> Tuple[Optional[Preprocess], Optional[Postprocess]]:
"""Resolves the correct :class:`.Preprocess` and :class:`.Postprocess` to use, choosing ``new_*`` if it is not
None or a base class (:class:`.Preprocess` or :class:`.Postprocess`) and ``old_*`` otherwise.
@@ -308,3 +308,10 @@ def available_backbones(cls) -> List[str]:
if registry is None:
return []
return registry.available_keys()

@classmethod
def available_models(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "models", None)
if registry is None:
return []
return registry.available_keys()
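
The new ``available_models`` classmethod mirrors ``available_backbones``: both read a class-level ``FlashRegistry`` and return its keys. A minimal usage sketch (the import path follows ``flash/vision/video/__init__.py`` below; the printed keys are illustrative, since the registry contents are not shown in this diff):

from flash.vision.video import VideoClassifier

# Query the class-level `models` registry without instantiating a model.
print(VideoClassifier.available_models())
# e.g. ['slow_r50', 'x3d_xs', ...] -- whatever keys the registry actually holds
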
127 changes: 124 additions & 3 deletions flash/data/auto_dataset.py
@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from inspect import signature
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING
from typing import Any, Callable, Iterable, Iterator, Optional, TYPE_CHECKING

import torch
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.warning_utils import rank_zero_warn
from torch.utils.data import Dataset
from torch.utils.data import Dataset, IterableDataset

from flash.data.callback import ControlFlow
from flash.data.process import Preprocess
@@ -141,3 +141,124 @@ def __len__(self) -> int:
if not self.load_sample and not self.load_data:
raise RuntimeError("`__len__` for `load_sample` and `load_data` could not be inferred.")
return len(self.preprocessed_data)


class IterableAutoDataset(IterableDataset):

"""
Encapsulates a Preprocess object's ``load_data`` and ``load_sample`` functions for iterable data.
``load_data`` is called from ``__init__`` when a ``running_stage`` is provided;
``load_sample`` is called from ``__next__`` for each sample while iterating.
"""

DATASET_KEY = "dataset"

def __init__(
self,
data: Any,
load_data: Optional[Callable] = None,
load_sample: Optional[Callable] = None,
data_pipeline: Optional['DataPipeline'] = None,
running_stage: Optional[RunningStage] = None
) -> None:
super().__init__()

if load_data or load_sample:
if data_pipeline:
rank_zero_warn(
"``data_pipeline`` is specified, but ``load_sample`` and/or ``load_data`` are also specified. "
"The ``data_pipeline`` will not be used."
)
# initial states
self._load_data_called = False
self._running_stage = None

self.data = data
self.data_pipeline = data_pipeline
self.load_data = load_data
self.load_sample = load_sample
self.dataset: Optional[IterableDataset] = None
self.dataset_iter: Optional[Iterator] = None

# trigger the setup only if `running_stage` is provided
self.running_stage = running_stage

@property
def running_stage(self) -> Optional[RunningStage]:
return self._running_stage

@running_stage.setter
def running_stage(self, running_stage: RunningStage) -> None:
if self._running_stage != running_stage or (not self._running_stage):
self._running_stage = running_stage
self._load_data_context = CurrentRunningStageFuncContext(self._running_stage, "load_data", self.preprocess)
self._load_sample_context = CurrentRunningStageFuncContext(
self._running_stage, "load_sample", self.preprocess
)
self._setup(running_stage)

@property
def preprocess(self) -> Optional[Preprocess]:
if self.data_pipeline is not None:
return self.data_pipeline._preprocess_pipeline

@property
def control_flow_callback(self) -> Optional[ControlFlow]:
preprocess = self.preprocess
if preprocess is not None:
return ControlFlow(preprocess.callbacks)

def _call_load_data(self, data: Any) -> Iterable:
parameters = signature(self.load_data).parameters
if len(parameters) > 1 and self.DATASET_KEY in parameters:
return self.load_data(data, self)
else:
return self.load_data(data)

def _call_load_sample(self, sample: Any) -> Any:
parameters = signature(self.load_sample).parameters
if len(parameters) > 1 and self.DATASET_KEY in parameters:
return self.load_sample(sample, self)
else:
return self.load_sample(sample)

def _setup(self, stage: Optional[RunningStage]) -> None:
assert not stage or _STAGES_PREFIX[stage] in _STAGES_PREFIX_VALUES
previous_load_data = self.load_data.__code__ if self.load_data else None

if self._running_stage and self.data_pipeline and (not self.load_data or not self.load_sample) and stage:
self.load_data = getattr(
self.preprocess,
self.data_pipeline._resolve_function_hierarchy('load_data', self.preprocess, stage, Preprocess)
)
self.load_sample = getattr(
self.preprocess,
self.data_pipeline._resolve_function_hierarchy('load_sample', self.preprocess, stage, Preprocess)
)
if self.load_data and (previous_load_data != self.load_data.__code__ or not self._load_data_called):
if previous_load_data:
rank_zero_warn(
"The load_data function of the Autogenerated Dataset changed. "
"This is not expected! Preloading Data again to ensure compatibility. This may take some time."
)
with self._load_data_context:
self.dataset = self._call_load_data(self.data)
self.dataset_iter = None
self._load_data_called = True

def __iter__(self):
self.dataset_iter = iter(self.dataset)
return self

def __next__(self) -> Any:
if not self.load_sample and not self.load_data:
raise RuntimeError("`__next__` for `load_sample` and `load_data` could not be inferred.")

data = next(self.dataset_iter)

if self.load_sample:
with self._load_sample_context:
data: Any = self._call_load_sample(data)
if self.control_flow_callback:
self.control_flow_callback.on_load_sample(data, self.running_stage)
return data
return data
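
``IterableAutoDataset`` follows the standard ``IterableDataset`` contract: ``__iter__`` returns an iterator and samples come from ``__next__``, so there is no ``__len__`` and no random access. A self-contained sketch of that contract in plain PyTorch (toy code, not part of this PR):

from torch.utils.data import DataLoader, IterableDataset

class RangeStream(IterableDataset):
    """A toy stream: yields 0..n-1, with no __len__ and no indexing."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

loader = DataLoader(RangeStream(5), batch_size=2)  # shuffle must stay unset/False
for batch in loader:
    print(batch)  # tensor([0, 1]), then tensor([2, 3]), then tensor([4])
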
51 changes: 40 additions & 11 deletions flash/data/data_module.py
@@ -21,9 +21,9 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import Subset
from torch.utils.data.dataset import IterableDataset, Subset

from flash.data.auto_dataset import AutoDataset
from flash.data.auto_dataset import AutoDataset, IterableAutoDataset
from flash.data.base_viz import BaseVisualization
from flash.data.callback import BaseDataFetcher
from flash.data.data_pipeline import DataPipeline, Postprocess, Preprocess
@@ -215,7 +215,8 @@ def _train_dataloader(self) -> DataLoader:
return DataLoader(
train_ds,
batch_size=self.batch_size,
shuffle=True,
shuffle=False if isinstance(train_ds, (IterableDataset, IterableAutoDataset)) else True,  # IterableDataset can't be shuffled
num_workers=self.num_workers,
pin_memory=True,
drop_last=True,
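
The conditional above exists because ``DataLoader`` refuses ``shuffle=True`` for iterable-style datasets; shuffling requires random access. A quick illustration, reusing the toy ``RangeStream`` from the earlier sketch:

from torch.utils.data import DataLoader

try:
    DataLoader(RangeStream(5), batch_size=2, shuffle=True)
except ValueError as err:
    print(err)  # "DataLoader with IterableDataset: expected unspecified shuffle option..."
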
@@ -287,7 +288,8 @@ def autogenerate_dataset(
whole_data_load_fn: Optional[Callable] = None,
per_sample_load_fn: Optional[Callable] = None,
data_pipeline: Optional[DataPipeline] = None,
) -> AutoDataset:
use_iterable_auto_dataset: bool = False,
) -> Union[AutoDataset, IterableAutoDataset]:
"""
This function is used to generate an ``AutoDataset`` from a ``DataPipeline`` if provided
or from the provided ``whole_data_load_fn``, ``per_sample_load_fn`` functions directly
@@ -304,6 +306,10 @@
cls.preprocess_cls,
DataPipeline._resolve_function_hierarchy('load_sample', cls.preprocess_cls, running_stage, Preprocess)
)
if use_iterable_auto_dataset:
return IterableAutoDataset(
data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage
)
return AutoDataset(data, whole_data_load_fn, per_sample_load_fn, data_pipeline, running_stage=running_stage)

@staticmethod
@@ -374,15 +380,25 @@ def _generate_dataset_if_possible(
running_stage: RunningStage,
whole_data_load_fn: Optional[Callable] = None,
per_sample_load_fn: Optional[Callable] = None,
data_pipeline: Optional[DataPipeline] = None
data_pipeline: Optional[DataPipeline] = None,
use_iterable_auto_dataset: bool = False,
) -> Optional[Union[AutoDataset, IterableAutoDataset]]:
if data is None:
return

if data_pipeline:
return data_pipeline._generate_auto_dataset(data, running_stage=running_stage)
return data_pipeline._generate_auto_dataset(
data, running_stage=running_stage, use_iterable_auto_dataset=use_iterable_auto_dataset
)

return cls.autogenerate_dataset(data, running_stage, whole_data_load_fn, per_sample_load_fn, data_pipeline)
return cls.autogenerate_dataset(
data,
running_stage,
whole_data_load_fn,
per_sample_load_fn,
data_pipeline,
use_iterable_auto_dataset=use_iterable_auto_dataset
)

@classmethod
def from_load_data_inputs(
Expand All @@ -393,6 +409,7 @@ def from_load_data_inputs(
predict_load_data_input: Optional[Any] = None,
preprocess: Optional[Preprocess] = None,
postprocess: Optional[Postprocess] = None,
use_iterable_auto_dataset: bool = False,
**kwargs,
) -> 'DataModule':
"""
@@ -424,16 +441,28 @@ def from_load_data_inputs(
data_fetcher.attach_to_preprocess(data_pipeline._preprocess_pipeline)

train_dataset = cls._generate_dataset_if_possible(
train_load_data_input, running_stage=RunningStage.TRAINING, data_pipeline=data_pipeline
train_load_data_input,
running_stage=RunningStage.TRAINING,
data_pipeline=data_pipeline,
use_iterable_auto_dataset=use_iterable_auto_dataset
)
val_dataset = cls._generate_dataset_if_possible(
val_load_data_input, running_stage=RunningStage.VALIDATING, data_pipeline=data_pipeline
val_load_data_input,
running_stage=RunningStage.VALIDATING,
data_pipeline=data_pipeline,
use_iterable_auto_dataset=use_iterable_auto_dataset
)
test_dataset = cls._generate_dataset_if_possible(
test_load_data_input, running_stage=RunningStage.TESTING, data_pipeline=data_pipeline
test_load_data_input,
running_stage=RunningStage.TESTING,
data_pipeline=data_pipeline,
use_iterable_auto_dataset=use_iterable_auto_dataset
)
predict_dataset = cls._generate_dataset_if_possible(
predict_load_data_input, running_stage=RunningStage.PREDICTING, data_pipeline=data_pipeline
predict_load_data_input,
running_stage=RunningStage.PREDICTING,
data_pipeline=data_pipeline,
use_iterable_auto_dataset=use_iterable_auto_dataset
)
datamodule = cls(
train_dataset=train_dataset,
…
13 changes: 10 additions & 3 deletions flash/data/data_pipeline.py
@@ -23,7 +23,7 @@
from torch.utils.data._utils.collate import default_collate, default_convert
from torch.utils.data.dataloader import DataLoader

from flash.data.auto_dataset import AutoDataset
from flash.data.auto_dataset import AutoDataset, IterableAutoDataset
from flash.data.batch import _PostProcessor, _PreProcessor, _Sequential
from flash.data.process import Postprocess, Preprocess
from flash.data.utils import _POSTPROCESS_FUNCS, _PREPROCESS_FUNCS, _STAGES_PREFIX
@@ -297,7 +297,7 @@ def _attach_preprocess_to_model(
if isinstance(dataloader, (_PatchDataLoader, Callable)):
dataloader = dataloader()

if not dataloader:
if dataloader is None:
continue

if isinstance(dataloader, Sequence):
@@ -458,7 +458,14 @@ def fn():

return fn

def _generate_auto_dataset(self, data: Union[Iterable, Any], running_stage: RunningStage = None) -> AutoDataset:
def _generate_auto_dataset(
self,
data: Union[Iterable, Any],
running_stage: RunningStage = None,
use_iterable_auto_dataset: bool = False
) -> Union[AutoDataset, IterableAutoDataset]:
if use_iterable_auto_dataset:
return IterableAutoDataset(data, data_pipeline=self, running_stage=running_stage)
return AutoDataset(data=data, data_pipeline=self, running_stage=running_stage)

def to_dataloader(
…
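
A sketch of the dispatch ``use_iterable_auto_dataset`` enables end to end (``pipeline`` and ``inputs`` are hypothetical values; the only point is which dataset class comes back):

from pytorch_lightning.trainer.states import RunningStage

from flash.data.auto_dataset import AutoDataset, IterableAutoDataset

# Hypothetical: `pipeline` is an already-constructed DataPipeline, `inputs` some raw data.
ds = pipeline._generate_auto_dataset(inputs, running_stage=RunningStage.TRAINING)
assert isinstance(ds, AutoDataset)  # map-style: supports len() and indexing

stream = pipeline._generate_auto_dataset(
    inputs, running_stage=RunningStage.TRAINING, use_iterable_auto_dataset=True
)
assert isinstance(stream, IterableAutoDataset)  # stream-style: iteration only
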
1 change: 1 addition & 0 deletions flash/utils/imports.py
@@ -5,3 +5,4 @@
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
_TORCHVISION_AVAILABLE = _module_available("torchvision")
_PYTORCH_VIDEO_AVAILABLE = _module_available("pytorchvideo")
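
``_PYTORCH_VIDEO_AVAILABLE`` follows the existing pattern in this file: probe the optional dependency once at import time, then gate imports on the flag. A hedged sketch of typical use (the PR's actual call sites are not shown in this diff view):

from flash.utils.imports import _PYTORCH_VIDEO_AVAILABLE

if _PYTORCH_VIDEO_AVAILABLE:
    import pytorchvideo  # safe: _module_available already confirmed the module resolves
else:
    pytorchvideo = None  # callers must check the flag before touching the API
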
2 changes: 2 additions & 0 deletions flash/vision/video/__init__.py
@@ -0,0 +1,2 @@
from flash.vision.video.classification.data import VideoClassificationData
from flash.vision.video.classification.model import VideoClassifier
Empty file.
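
Putting the two public exports together, a heavily hedged end-to-end sketch: ``VideoClassificationData`` and ``VideoClassifier`` are real exports from this PR, but ``from_paths`` and every argument below are placeholders, since the data and model signatures are not visible in this diff view.

from flash import Trainer
from flash.vision.video import VideoClassificationData, VideoClassifier

# Hypothetical constructor and arguments; consult the PR's docs/example script for the real API.
datamodule = VideoClassificationData.from_paths(train_folder="data/kinetics/train")
model = VideoClassifier(num_classes=datamodule.num_classes)

trainer = Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule)
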