Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[refactor] DataPipeline 1/n #188

Merged
merged 41 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
2f381ef
add data_pipeline
tchaton Mar 22, 2021
465522d
update ci
tchaton Mar 22, 2021
819c018
delete generate .py file
tchaton Mar 22, 2021
2b4756d
update bolts
tchaton Mar 22, 2021
d291f12
udpate ci
tchaton Mar 22, 2021
ffdd258
update
tchaton Mar 22, 2021
f183f19
Merge branch 'master' into data_pipeline_1_n
tchaton Mar 22, 2021
2e7bc4b
Update flash/data/auto_dataset.py
tchaton Mar 22, 2021
2c1e412
update
tchaton Mar 22, 2021
b8d2abc
Merge branch 'data_pipeline_1_n' of https://github.com/PyTorchLightni…
tchaton Mar 22, 2021
d278382
Update tests/data/test_data_pipeline.py
tchaton Mar 22, 2021
0e32fa1
update
tchaton Mar 22, 2021
eba35f6
Merge branch 'data_pipeline_1_n' of https://github.com/PyTorchLightni…
tchaton Mar 22, 2021
8bea3dd
update
tchaton Mar 22, 2021
2990b0b
add some docstring
tchaton Mar 23, 2021
276cf40
update docstring
tchaton Mar 23, 2021
06e5a09
update on comments
tchaton Mar 23, 2021
913bb45
Fixes
carmocca Mar 24, 2021
98aa56d
Docs
carmocca Mar 24, 2021
58c147f
Docs
carmocca Mar 24, 2021
98f75c4
Merge branch 'master' into data_pipeline_1_n
carmocca Mar 24, 2021
84ce3b1
update ci
tchaton Mar 24, 2021
86669c6
update on comments
tchaton Mar 25, 2021
54d0fc3
Update flash/data/batch.py
tchaton Mar 25, 2021
637ff25
Update flash/data/data_module.py
kaushikb11 Mar 25, 2021
dd3dfdb
Update flash/data/process.py
kaushikb11 Mar 25, 2021
4c487a9
Apply suggestions from code review
Borda Mar 25, 2021
ab96ac7
cleaning
Borda Mar 25, 2021
51ea5d9
add pip install
tchaton Mar 25, 2021
23aaebf
update requierements
tchaton Mar 25, 2021
41dd86c
try
Borda Mar 25, 2021
7d8c955
try
Borda Mar 25, 2021
8451011
try
Borda Mar 25, 2021
0a96800
Update flash/data/auto_dataset.py
tchaton Mar 25, 2021
a86f3d5
upate on comments
tchaton Mar 25, 2021
be8ffad
last comments
tchaton Mar 25, 2021
68c1002
update
tchaton Mar 25, 2021
e7f9bb0
update on comments
tchaton Mar 25, 2021
44325f1
update
tchaton Mar 25, 2021
a36c595
smaller
tchaton Mar 25, 2021
a38a18a
faster
tchaton Mar 25, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/ci-notebook.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -U pip wheel
#pip install treon
pip install . --pre --upgrade --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip install --requirement requirements/test.txt --quiet --find-links https://download.pytorch.org/whl/torch_stable.html
python -m pip install --requirement requirements/devel.txt --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip install --requirement requirements/notebooks.txt --quiet --find-links https://download.pytorch.org/whl/torch_stable.html

- name: Cache datasets
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ jobs:
- name: Install dependencies
run: |
# python -m pip install --upgrade --user pip
python -m pip install . --pre --upgrade --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
python -m pip install --requirement requirements/test.txt --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
python -m pip install --requirement requirements/devel.txt --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
python -m pip install -e .
# pip install tox coverage
python --version
python -m pip --version
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docs-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
with:
# git is required to clone the docs theme
# before custom requirement are resolved https://github.com/ammaraskar/sphinx-action/issues/16
pre-build-command: "apt-get update -y && apt-get install -y gcc git pandoc && pip install -e . && pip install -r requirements/docs.txt"
pre-build-command: "apt-get update -y && apt-get install -y gcc git pandoc && pip install -e . && pip install -r requirements/docs.txt && python -m pip install --requirement requirements/devel.txt --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html"
docs-folder: "docs/"
repo-token: "${{ secrets.GITHUB_TOKEN }}"
- uses: actions/upload-artifact@v2
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/docs-deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ jobs:
- name: Install dependencies
run: |
pip install . -U -f https://download.pytorch.org/whl/cpu/torch_stable.html -q --use-feature=2020-resolver
python -m pip install --requirement requirements/devel.txt --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip install -r requirements/docs.txt --use-feature=2020-resolver
python -m pip install -e .
tchaton marked this conversation as resolved.
Show resolved Hide resolved
# install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux
sudo apt-get update
sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,5 @@ titanic.csv
data_folder
*.pt
*.zip
data
flash_notebooks/*.py
16 changes: 1 addition & 15 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Data
DataPipeline
------------

To make tasks work for inference, one must create a ``DataPipeline``.
To make tasks work for inference, one must create a ``DataPipeline``.
The ``flash.core.data.DataPipeline`` exposes 6 hooks to override:

.. code:: python
Expand Down Expand Up @@ -54,17 +54,3 @@ The ``flash.core.data.DataPipeline`` exposes 6 hooks to override:
def after_uncollate(self, samples: Any) -> Any:
"""Override to apply transformations to samples"""
return samples






Use these utilities to download data.

-----

download_data
-------------

.. autofunction:: flash.core.data.utils.download_data
20 changes: 10 additions & 10 deletions flash/core/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class NoFreeze(BaseFinetuning):
def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
pass

def finetunning_function(
def finetune_function(
self,
pl_module: pl.LightningModule,
epoch: int,
Expand All @@ -42,7 +42,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo

FlashBaseFinetuning can be used to create a custom Flash Finetuning Callback.

Override ``finetunning_function`` to put your unfreeze logic.
Override ``finetune_function`` to put your unfreeze logic.

Args:
attr_names: Name(s) of the module attributes of the model to be frozen.
Expand All @@ -62,15 +62,15 @@ def freeze_using_attr_names(self, pl_module, attr_names: List[str], train_bn: bo
attr = getattr(pl_module, attr_name, None)
if attr is None or not isinstance(attr, nn.Module):
MisconfigurationException(f"Your model must have a {attr} attribute")
self.freeze(module=attr, train_bn=train_bn)
self.freeze(modules=attr, train_bn=train_bn)

def finetunning_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
pass


class Freeze(FlashBaseFinetuning):

def finetunning_function(
def finetune_function(
self,
pl_module: pl.LightningModule,
epoch: int,
Expand All @@ -86,7 +86,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo
super().__init__(attr_names, train_bn)
self.unfreeze_epoch = unfreeze_epoch

def finetunning_function(
def finetune_function(
self,
pl_module: pl.LightningModule,
epoch: int,
Expand All @@ -97,7 +97,7 @@ def finetunning_function(
return
modules = [getattr(pl_module, attr_name) for attr_name in self.attr_names]
self.unfreeze_and_add_param_group(
module=modules,
modules=modules,
optimizer=optimizer,
train_bn=self.train_bn,
)
Expand All @@ -117,7 +117,7 @@ def __init__(

super().__init__(attr_names, train_bn)

def finetunning_function(
def finetune_function(
self,
pl_module: pl.LightningModule,
epoch: int,
Expand All @@ -128,15 +128,15 @@ def finetunning_function(
if epoch == self.unfreeze_milestones[0]:
# unfreeze num_layers last layers
self.unfreeze_and_add_param_group(
module=backbone_modules[-self.num_layers:],
modules=backbone_modules[-self.num_layers:],
optimizer=optimizer,
train_bn=self.train_bn,
)

elif epoch == self.unfreeze_milestones[1]:
# unfreeze remaining layers
self.unfreeze_and_add_param_group(
module=backbone_modules[:-self.num_layers],
modules=backbone_modules[:-self.num_layers],
optimizer=optimizer,
train_bn=self.train_bn,
)
Expand Down
6 changes: 4 additions & 2 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import pytorch_lightning as pl
import torch
import torchmetrics
from torch import nn

from flash.core.data import DataModule, DataPipeline
Expand Down Expand Up @@ -83,7 +84,8 @@ def step(self, batch: Any, batch_idx: int) -> Any:
losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
logs = {}
for name, metric in self.metrics.items():
if isinstance(metric, pl.metrics.Metric):
if isinstance(metric, torchmetrics.metric.Metric):
output["y_hat"] = self.data_pipeline.before_uncollate(output["y_hat"])
metric(output["y_hat"], y)
logs[name] = metric # log the metric itself if it is of type Metric
else:
Expand Down Expand Up @@ -152,7 +154,7 @@ def predict(
data_pipeline = data_pipeline or self.data_pipeline
batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
batch_x, batch_y = batch if len(batch) == 2 and isinstance(batch, (list, tuple)) else (batch, None)
predictions = self.forward(batch_x)
predictions = self.predict_step(batch_x, 0)
output = data_pipeline.uncollate_fn(predictions) # TODO: pass batch and x
return output

Expand Down
137 changes: 137 additions & 0 deletions flash/data/auto_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from contextlib import contextmanager
tchaton marked this conversation as resolved.
Show resolved Hide resolved
from copy import deepcopy
from inspect import signature
from typing import Any, Callable, Optional, TYPE_CHECKING

import torch
from pytorch_lightning.core.decorators import parameter_validation
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.warning_utils import rank_zero_warn

from flash.data.process import Preprocess
from flash.data.utils import _STAGES_PREFIX

if TYPE_CHECKING:
from flash.data.data_pipeline import DataPipeline


class AutoDataset(torch.utils.data.Dataset):
    """Dataset built around the ``load_data`` and ``load_sample`` functions of a ``Preprocess`` object.

    ``load_data`` is applied once to ``data`` when the dataset is set up (triggered from ``__init__`` and
    whenever ``running_stage`` is assigned a new value); ``load_sample`` is applied to a single element of
    the loaded data inside ``__getitem__``.

    Both functions may optionally declare a second parameter named ``dataset``; when they do, this
    ``AutoDataset`` instance is passed as that argument.

    Args:
        data: Raw input handed to ``load_data``.
        load_data: Optional callable turning ``data`` into an indexable, sized collection.
        load_sample: Optional callable applied to each element returned by ``load_data``.
        data_pipeline: Optional pipeline whose ``_preprocess_pipeline`` provides ``load_data`` /
            ``load_sample`` when at least one of them was not given explicitly and a stage is set.
        running_stage: Optional stage used to resolve the stage-specific functions on the pipeline.
    """

    # stages for which fitting happens
    FITTING_STAGES = ("train", "val")
    # every stage prefix accepted by ``_setup``
    STAGES = ("train", "test", "val", "predict")
    # name of the optional parameter through which loaders receive this dataset
    DATASET_KEY = "dataset"

    def __init__(
        self,
        data: Any,
        load_data: Optional[Callable] = None,
        load_sample: Optional[Callable] = None,
        data_pipeline: Optional['DataPipeline'] = None,
        running_stage: Optional[RunningStage] = None
    ) -> None:
        super().__init__()

        has_explicit_loader = load_data is not None or load_sample is not None
        if has_explicit_loader and data_pipeline is not None:
            # NOTE(review): ``_setup`` still resolves both loaders from the pipeline when only one
            # of them was given explicitly and a stage is set - confirm this matches the message.
            rank_zero_warn(
                "``datapipeline`` is specified but load_sample and/or load_data are also specified. "
                "Won't use datapipeline"
            )

        # initial states
        self._load_data_called = False
        self._running_stage = None

        self.data = data
        self.data_pipeline = data_pipeline
        self.load_data = load_data
        self.load_sample = load_sample

        # goes through the property setter, which runs ``_setup``
        self.running_stage = running_stage

    @property
    def running_stage(self) -> Optional[RunningStage]:
        """The stage this dataset is currently configured for (``None`` if unset)."""
        return self._running_stage

    @running_stage.setter
    def running_stage(self, running_stage):
        # ``_setup`` runs when the stage changes, and also whenever the current stage is still ``None``
        if self._running_stage != running_stage or (self._running_stage is None):
            self._running_stage = running_stage
            self._setup(running_stage)

    @staticmethod
    def _code_of(fn: Optional[Callable]) -> Any:
        """Return ``fn.__code__`` or ``None`` when ``fn`` is ``None`` or has no code object."""
        # callables such as ``functools.partial`` objects have no ``__code__`` attribute
        return getattr(fn, "__code__", None)

    def _call_with_optional_dataset(self, fn: Callable, arg: Any) -> Any:
        """Call ``fn(arg)``, additionally passing ``self`` when ``fn`` declares a ``dataset`` parameter."""
        parameters = signature(fn).parameters
        if len(parameters) > 1 and self.DATASET_KEY in parameters:
            return fn(arg, self)
        return fn(arg)

    def _call_load_data(self, data):
        """Run ``self.load_data`` on ``data``."""
        return self._call_with_optional_dataset(self.load_data, data)

    def _call_load_sample(self, sample):
        """Run ``self.load_sample`` on ``sample``."""
        return self._call_with_optional_dataset(self.load_sample, sample)

    def _setup(self, stage: RunningStage):
        """Resolve the loader functions for ``stage`` if needed and (re)load the data.

        The data is loaded when ``load_data`` is available and either it has never been called or
        its code object differs from the one held before this call.
        """
        # kept as an ``assert`` so the raised exception type is unchanged
        assert stage is None or _STAGES_PREFIX[stage] in self.STAGES
        previous_code = self._code_of(self.load_data)

        resolve_from_pipeline = (
            self._running_stage is not None and self.data_pipeline is not None
            and (self.load_data is None or self.load_sample is None) and stage is not None
        )
        if resolve_from_pipeline:
            preprocess = self.data_pipeline._preprocess_pipeline
            resolve = self.data_pipeline._resolve_function_hierarchy
            self.load_data = getattr(preprocess, resolve('load_data', preprocess, stage, Preprocess))
            self.load_sample = getattr(preprocess, resolve('load_sample', preprocess, stage, Preprocess))

        if self.load_data is None:
            return

        changed = previous_code != self._code_of(self.load_data)
        if not changed and self._load_data_called:
            # same function and data already loaded: nothing to do
            return

        # warn only when an existing ``load_data`` was really replaced, not on the first load
        if changed and previous_code is not None:
            rank_zero_warn(
                "The load_data function of the Autogenerated Dataset changed. "
                "This is not expected! Preloading Data again to ensure compatibility. This may take some time."
            )
        with self._set_running_stage(stage):
            self._preprocessed_data = self._call_load_data(self.data)
        self._load_data_called = True

    @contextmanager
    def _set_running_stage(self, stage: RunningStage):
        """Temporarily set ``_running_stage`` on the pipeline's preprocess object.

        The attribute is reset to ``None`` on exit, including when the wrapped code raises.
        """
        preprocess = None
        if self.load_data is not None and self.data_pipeline is not None:
            preprocess = self.data_pipeline._preprocess_pipeline

        if preprocess is not None:
            preprocess._running_stage = stage
        try:
            yield
        finally:
            if preprocess is not None:
                preprocess._running_stage = None

    def _raise_if_loaders_missing(self) -> None:
        """Raise ``RuntimeError`` when neither ``load_sample`` nor ``load_data`` is available."""
        if self.load_sample is None and self.load_data is None:
            raise RuntimeError(
                "Names for LoadSample and LoadData could not be inferred."
                " Consider setting the RunningStage"
            )

    def __getitem__(self, index: int) -> Any:
        """Return the element at ``index``, passed through ``load_sample`` when it is set."""
        self._raise_if_loaders_missing()
        if self.load_sample is not None:
            return self._call_load_sample(self._preprocessed_data[index])
        return self._preprocessed_data[index]

    def __len__(self) -> int:
        """Return the number of elements produced by ``load_data``."""
        self._raise_if_loaders_missing()
        return len(self._preprocessed_data)
Loading