
[refactor] DataPipeline 1/n (#188)
* add data_pipeline

* update ci

* delete generated .py file

* update bolts

* update ci

* update

* Update flash/data/auto_dataset.py

Co-authored-by: Kaushik B <[email protected]>

* update

* Update tests/data/test_data_pipeline.py

Co-authored-by: Kaushik B <[email protected]>

* update

* update

* add some docstring

* update docstring

* update on comments

* Fixes

* Docs

* Docs

* update ci

* update on comments

* Update flash/data/batch.py

Co-authored-by: Kaushik B <[email protected]>

* Update flash/data/data_module.py

* Update flash/data/process.py

* Apply suggestions from code review

* cleaning

* add pip install

* update requirements

* try

* try

* try

* Update flash/data/auto_dataset.py

Co-authored-by: Jirka Borovec <[email protected]>

* update on comments

* last comments

* update

Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
5 people authored Mar 25, 2021
1 parent 024b3be commit e1a2d5d
Showing 37 changed files with 2,783 additions and 121 deletions.
20 changes: 9 additions & 11 deletions .github/workflows/ci-notebook.yml
@@ -40,10 +40,8 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install -U pip wheel
-        #pip install treon
-        pip install . --pre --upgrade --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
-        pip install --requirement requirements/test.txt --quiet --find-links https://download.pytorch.org/whl/torch_stable.html
-        pip install --requirement requirements/notebooks.txt --quiet --find-links https://download.pytorch.org/whl/torch_stable.html
+        pip install --requirement requirements/devel.txt --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+        pip install --requirement requirements/notebooks.txt --quiet --upgrade-strategy only-if-needed

     - name: Cache datasets
       uses: actions/cache@v2
@@ -59,10 +57,10 @@ jobs:
         # Look to see if there is a cache hit for the corresponding requirements file
         key: flash-datasets_predict

-    - name: Run Notebooks
-      run: |
-        jupyter nbconvert --to script flash_notebooks/image_classification.ipynb
-        jupyter nbconvert --to script flash_notebooks/tabular_classification.ipynb
-
-        ipython flash_notebooks/image_classification.py
-        ipython flash_notebooks/tabular_classification.py
+    #- name: Run Notebooks
+    #  run: |
+    #    jupyter nbconvert --to script flash_notebooks/image_classification.ipynb
+    #    jupyter nbconvert --to script flash_notebooks/tabular_classification.ipynb
+    #
+    #    ipython flash_notebooks/image_classification.py
+    #    ipython flash_notebooks/tabular_classification.py
10 changes: 4 additions & 6 deletions .github/workflows/ci-testing.yml
@@ -58,13 +58,11 @@ jobs:
     - name: Install dependencies
       run: |
         # python -m pip install --upgrade --user pip
-        python -m pip install . --pre --upgrade --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
-        python -m pip install --requirement requirements/test.txt --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
-        # pip install tox coverage
         python --version
-        python -m pip --version
-        python -m pip list
+        pip --version
+        pip install -e . --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+        python -m pip install --requirement requirements/devel.txt --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+        pip list
       shell: bash

     - name: Cache datasets
8 changes: 8 additions & 0 deletions .gitignore
@@ -139,3 +139,11 @@ titanic.csv
 data_folder
 *.pt
 *.zip
+flash_notebooks/*.py
+flash_notebooks/data
+MNIST*
+titanic
+coco128
+hymenoptera_data
+xsum
+imdb
16 changes: 1 addition & 15 deletions docs/source/general/data.rst
@@ -7,7 +7,7 @@ Data
 DataPipeline
 ------------

-To make tasks work for inference, one must create a ``DataPipeline``.
+To make tasks work for inference, one must create a ``DataPipeline``.
 The ``flash.core.data.DataPipeline`` exposes 6 hooks to override:

 .. code:: python
@@ -54,17 +54,3 @@ The ``flash.core.data.DataPipeline`` exposes 6 hooks to override:
         def after_uncollate(self, samples: Any) -> Any:
             """Override to apply transformations to samples"""
             return samples
-
-Use these utilities to download data.
-
------
-
-download_data
--------------
-
-.. autofunction:: flash.core.data.utils.download_data
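For context, a minimal sketch of the kind of ``DataPipeline`` subclass the documentation above describes. The class name and label list are hypothetical; only the ``after_uncollate`` hook and the ``flash.core.data.DataPipeline`` import come from this commit:

from typing import Any

from flash.core.data import DataPipeline

LABELS = ["cat", "dog"]  # hypothetical label set for illustration


class LabelPipeline(DataPipeline):
    """Hypothetical pipeline that overrides one of the six hooks."""

    def after_uncollate(self, samples: Any) -> Any:
        # map predicted class indices to human-readable labels
        return [LABELS[int(sample)] for sample in samples]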
26 changes: 13 additions & 13 deletions flash/core/finetuning.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from typing import List, Union

-import pytorch_lightning as pl
+from pytorch_lightning import LightningModule
 from pytorch_lightning.callbacks import BaseFinetuning
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import nn
@@ -22,12 +22,12 @@

 class NoFreeze(BaseFinetuning):

-    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
+    def freeze_before_training(self, pl_module: LightningModule) -> None:
         pass

     def finetune_function(
         self,
-        pl_module: pl.LightningModule,
+        pl_module: LightningModule,
         epoch: int,
         optimizer: Optimizer,
         opt_idx: int,
@@ -54,25 +54,25 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo
         self.attr_names = [attr_names] if isinstance(attr_names, str) else attr_names
         self.train_bn = train_bn

-    def freeze_before_training(self, pl_module: pl.LightningModule) -> None:
+    def freeze_before_training(self, pl_module: LightningModule) -> None:
         self.freeze_using_attr_names(pl_module, self.attr_names, train_bn=self.train_bn)

     def freeze_using_attr_names(self, pl_module, attr_names: List[str], train_bn: bool = True):
         for attr_name in attr_names:
             attr = getattr(pl_module, attr_name, None)
-            if attr is None or not isinstance(attr, nn.Module):
+            if not attr or not isinstance(attr, nn.Module):
                 MisconfigurationException(f"Your model must have a {attr} attribute")
             self.freeze(modules=attr, train_bn=train_bn)

-    def finetune_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
+    def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
         pass


 class Freeze(FlashBaseFinetuning):

     def finetune_function(
         self,
-        pl_module: pl.LightningModule,
+        pl_module: LightningModule,
         epoch: int,
         optimizer: Optimizer,
         opt_idx: int,
@@ -88,7 +88,7 @@ def __init__(self, attr_names: Union[str, List[str]] = "backbone", train_bn: boo

     def finetune_function(
         self,
-        pl_module: pl.LightningModule,
+        pl_module: LightningModule,
         epoch: int,
         optimizer: Optimizer,
         opt_idx: int,
@@ -97,7 +97,7 @@ def finetune_function(
             return
         modules = [getattr(pl_module, attr_name) for attr_name in self.attr_names]
         self.unfreeze_and_add_param_group(
-            module=modules,
+            modules=modules,
             optimizer=optimizer,
             train_bn=self.train_bn,
         )
@@ -119,7 +119,7 @@ def __init__(

     def finetune_function(
         self,
-        pl_module: pl.LightningModule,
+        pl_module: LightningModule,
         epoch: int,
         optimizer: Optimizer,
         opt_idx: int,
@@ -128,15 +128,15 @@ def finetune_function(
         if epoch == self.unfreeze_milestones[0]:
             # unfreeze num_layers last layers
             self.unfreeze_and_add_param_group(
-                module=backbone_modules[-self.num_layers:],
+                modules=backbone_modules[-self.num_layers:],
                 optimizer=optimizer,
                 train_bn=self.train_bn,
             )

         elif epoch == self.unfreeze_milestones[1]:
             # unfreeze remaining layers
             self.unfreeze_and_add_param_group(
-                module=backbone_modules[:-self.num_layers],
+                modules=backbone_modules[:-self.num_layers],
                 optimizer=optimizer,
                 train_bn=self.train_bn,
             )
@@ -151,7 +151,7 @@ def finetune_function(


 def instantiate_default_finetuning_callbacks(strategy):
-    if strategy is None or strategy not in _DEFAULTS_FINETUNE_STRATEGIES:
+    if not strategy or strategy not in _DEFAULTS_FINETUNE_STRATEGIES:
         raise MisconfigurationException(
             f"a strategy should be provided. Use {list(_DEFAULTS_FINETUNE_STRATEGIES)} or provide a callback"
             " instance of `flash.core.finetuning.FlashBaseFinetuning`. Found {strategy} "
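A hedged sketch of how these finetuning callbacks might be used. Only `Freeze`, the `attr_names`/`train_bn` arguments, and `instantiate_default_finetuning_callbacks` appear in the diff above; the strategy name "freeze" is an assumption about what `_DEFAULTS_FINETUNE_STRATEGIES` registers:

from flash.core.finetuning import Freeze, instantiate_default_finetuning_callbacks

# Resolve a registered strategy name into callback instances
# ("freeze" is assumed to be one of the default strategy keys).
callbacks = instantiate_default_finetuning_callbacks("freeze")

# Or build a callback directly and pass it to the trainer yourself:
# freeze the module stored under the "backbone" attribute, keeping batch norm trainable.
freeze_backbone = Freeze(attr_names="backbone", train_bn=True)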
12 changes: 7 additions & 5 deletions flash/core/model.py
@@ -15,8 +15,9 @@
 import os
 from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union

-import pytorch_lightning as pl
 import torch
+import torchmetrics
+from pytorch_lightning import LightningModule
 from torch import nn

 from flash.core.data import DataModule, DataPipeline
@@ -43,7 +44,7 @@ def wrapper(self, *args, **kwargs) -> Any:
     return wrapper


-class Task(pl.LightningModule):
+class Task(LightningModule):
     """A general Task.

     Args:
@@ -59,7 +60,7 @@ def __init__(
         model: Optional[nn.Module] = None,
         loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
-        metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None,
+        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
         learning_rate: float = 5e-5,
     ):
         super().__init__()
@@ -83,7 +84,8 @@ def step(self, batch: Any, batch_idx: int) -> Any:
         losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
         logs = {}
         for name, metric in self.metrics.items():
-            if isinstance(metric, pl.metrics.Metric):
+            if isinstance(metric, torchmetrics.metric.Metric):
+                output["y_hat"] = self.data_pipeline.before_uncollate(output["y_hat"])
                 metric(output["y_hat"], y)
                 logs[name] = metric  # log the metric itself if it is of type Metric
             else:
@@ -152,7 +154,7 @@ def predict(
         data_pipeline = data_pipeline or self.data_pipeline
         batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
         batch_x, batch_y = batch if len(batch) == 2 and isinstance(batch, (list, tuple)) else (batch, None)
-        predictions = self.forward(batch_x)
+        predictions = self.predict_step(batch_x, 0)
         output = data_pipeline.uncollate_fn(predictions)  # TODO: pass batch and x
         return output

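For illustration, a sketch of how the updated `predict` path might be exercised. The task class, checkpoint path, and file names are placeholders; only `predict`, `skip_collate_fn`, `predict_step`, and the task's `data_pipeline` come from the diff above:

from flash.vision import ImageClassifier  # hypothetical choice of Task subclass

# Placeholder checkpoint path; any trained flash Task would do.
model = ImageClassifier.load_from_checkpoint("path/to/checkpoint.pt")

# Inputs are collated by the task's data_pipeline unless skip_collate_fn=True,
# then routed through predict_step and uncollated back into per-sample outputs.
predictions = model.predict(["img_1.png", "img_2.png"])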
11 changes: 5 additions & 6 deletions flash/core/trainer.py
@@ -14,24 +14,23 @@
 import warnings
 from typing import List, Optional, Union

-import pytorch_lightning as pl
+from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from pytorch_lightning.callbacks import BaseFinetuning
-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch.utils.data import DataLoader

 from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, instantiate_default_finetuning_callbacks


-class Trainer(pl.Trainer):
+class Trainer(Trainer):

     def fit(
         self,
-        model: pl.LightningModule,
+        model: LightningModule,
         train_dataloader: Optional[DataLoader] = None,
         val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
-        datamodule: Optional[pl.LightningDataModule] = None,
+        datamodule: Optional[LightningDataModule] = None,
     ):
         r"""
         Runs the full optimization routine. Same as pytorch_lightning.Trainer().fit()
@@ -57,7 +56,7 @@ def finetune(
         model: LightningModule,
         train_dataloader: Optional[DataLoader] = None,
         val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
-        datamodule: Optional[pl.LightningDataModule] = None,
+        datamodule: Optional[LightningDataModule] = None,
         strategy: Optional[Union[str, BaseFinetuning]] = None,
     ):
         r"""
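And a sketch of the `finetune` entry point these signature changes affect. `model` and `datamodule` are placeholders assumed to be built elsewhere, "freeze" is an assumed default strategy name, and the package-level `flash.Trainer` export is assumed; the `strategy: Optional[Union[str, BaseFinetuning]]` parameter itself is shown in the diff above:

import flash

trainer = flash.Trainer(max_epochs=3)

# strategy accepts either a registered strategy name or a BaseFinetuning callback instance.
trainer.finetune(model, datamodule=datamodule, strategy="freeze")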