Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
DataPipeline PoC (#141)
Browse files Browse the repository at this point in the history
* add prototype of DataPipeline

* Add Prototype of PostProcessingPipeline

* isort + pep8

* update post_processing_pipeline

* update data pipline

* add new prediction part

* change loader name

* update

* uypdate new datapipeline

* update model with new pipeline

* update

* update gitignore

* add autodataset

* add batch processing

* update

* update

* update

* add process file

* make datapipeline attaching and detaching more robust

* resolve flake8

* update

* push curr state

* Update flash/data/batch.py

Co-authored-by: Kaushik B <[email protected]>

* resolve some bugs

* tests

* update

* make everything nn.Module and check serialization

* resolve kornia example

* add prototype of DataPipeline

* Add Prototype of PostProcessingPipeline

* isort + pep8

* update post_processing_pipeline

* update data pipline

* add new prediction part

* change loader name

* update

* update

* update

* uypdate new datapipeline

* update model with new pipeline

* update gitignore

* add autodataset

* add batch processing

* update

* update

* add process file

* make datapipeline attaching and detaching more robust

* resolve flake8

* update

* push curr state

* Update flash/data/batch.py

Co-authored-by: Kaushik B <[email protected]>

* resolve some bugs

* update

* tests

* resolve kornia example

* make everything nn.Module and check serialization

* rebase_fixes

* add more tests

* update tabular

* add new hooks

* update tabular

* update

* Move func to data module

* fix vision to current version

* transfer text classification to new API

* add more tests

* update

* resolve most bugs

* address most comments

* remove kornia example

* add support for summurization example

* work with ObjectDetection

* Update gitignore

* updates

* resolve bug

* resolve image embedder

* Update Image Classifer

* Renaming

* fix recursion

* resolve bug

* Fix DataPipeline function resolution

* put back properties instead of attributes

* fix import

* fix examples

* add checks for loading

* fix recursion

* fix seq2seq dataset

* fix dm init in tests

* fix data parts

* resolve tests and flake8

* update on comments

* update notebooks

* devel

* update

* update

* update

* resolve the doc

* update

* don't apply flake8 on notebook

* resolve tests

* comment a notebook

* update

* update ci

* add fixes

* updaet

* update with lightning

* add a test for flash_special_arguments

* add data_pipeline

* update ci

* delete generate .py file

* update bolts

* udpate ci

* update

* Update flash/data/auto_dataset.py

Co-authored-by: Kaushik B <[email protected]>

* update

* Update tests/data/test_data_pipeline.py

Co-authored-by: Kaushik B <[email protected]>

* update

* update

* add some docstring

* update docstring

* update on comments

* Fixes

* Docs

* Docs

* update ci

* update on comments

* Update flash/data/batch.py

Co-authored-by: Kaushik B <[email protected]>

* Update flash/data/data_module.py

* Update flash/data/process.py

* Apply suggestions from code review

* cleaning

* add pip install

* switch back to master

* update requierements

* try

* try

* try

* update

* prune legacy

* update

* update

* update to latest

* delete extra files

* updates to Task class

* Update Datamodule

* resolve comments

* update

* update

* update

* update

* try

* update

* udpate

* update

* update

* update

* formatting

* update on comments

* update on comments

* General changes

* General changes

* update

* update

* add _data_pipeline back

* update

Co-authored-by: justusschock <[email protected]>
Co-authored-by: Justus Schock <[email protected]>
Co-authored-by: Kaushik B <[email protected]>
Co-authored-by: Kaushik Bokka <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
8 people authored Mar 29, 2021
1 parent 3b4c6b6 commit ba34bf4
Show file tree
Hide file tree
Showing 78 changed files with 1,979 additions and 1,877 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/ci-notebook.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ jobs:
# Look to see if there is a cache hit for the corresponding requirements file
key: flash-datasets_predict

#- name: Run Notebooks
# run: |
# jupyter nbconvert --to script flash_notebooks/image_classification.ipynb
# jupyter nbconvert --to script flash_notebooks/tabular_classification.ipynb
#
# ipython flash_notebooks/image_classification.py
# ipython flash_notebooks/tabular_classification.py
- name: Run Notebooks
run: |
# jupyter nbconvert --to script flash_notebooks/image_classification.ipynb
jupyter nbconvert --to script flash_notebooks/tabular_classification.ipynb
# ipython flash_notebooks/image_classification.py
ipython flash_notebooks/tabular_classification.py
2 changes: 1 addition & 1 deletion .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ jobs:
run: |
python --version
pip --version
pip install -e . --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip install --requirement requirements/devel.txt --pre --upgrade --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip install -e . --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
pip list
shell: bash

Expand Down
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ flash_notebooks/*.py
flash_notebooks/data
MNIST*
titanic
coco128
hymenoptera_data
xsum
imdb
xsum
coco128
wmt_en_ro
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,8 @@ download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/')
datamodule = TabularData.from_csv(
"./data/titanic/titanic.csv",
test_csv="./data/titanic/test.csv",
categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
numerical_input=["Fare"],
cat_cols=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
num_cols=["Fare"],
target="Survived",
val_size=0.25,
)
Expand Down
27 changes: 0 additions & 27 deletions docs/source/custom_task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,33 +67,6 @@ for the prediction of diabetes disease progression. We can create this
``DataModule`` below, wrapping the scikit-learn `Diabetes
dataset <https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset>`__.

.. testcode::

class DiabetesPipeline(flash.core.data.TaskDataPipeline):
def after_uncollate(self, samples):
return [f"disease progression: {float(s):.2f}" for s in samples]

class DiabetesData(flash.DataModule):
def __init__(self, batch_size=64, num_workers=0):
x, y = datasets.load_diabetes(return_X_y=True)
x = torch.from_numpy(x).float()
y = torch.from_numpy(y).float().unsqueeze(1)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0)

train_ds = TensorDataset(x_train, y_train)
test_ds = TensorDataset(x_test, y_test)

super().__init__(
train_ds=train_ds,
test_ds=test_ds,
batch_size=batch_size,
num_workers=num_workers
)
self.num_inputs = x.shape[1]

@staticmethod
def default_pipeline():
return DiabetesPipeline()

You’ll notice we added a ``DataPipeline``, which will be used when we
call ``.predict()`` on our model. In this case we want to nicely format
Expand Down
50 changes: 3 additions & 47 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,50 +7,6 @@ Data
DataPipeline
------------

To make tasks work for inference, one must create a ``DataPipeline``.
The ``flash.core.data.DataPipeline`` exposes 6 hooks to override:

.. code:: python
class DataPipeline:
"""
This class purpose is to facilitate the conversion of raw data to processed or batched data and back.
Several hooks are provided for maximum flexibility.
collate_fn:
- before_collate
- collate
- after_collate
uncollate_fn:
- before_uncollate
- uncollate
- after_uncollate
"""
def before_collate(self, samples: Any) -> Any:
"""Override to apply transformations to samples"""
return samples
def collate(self, samples: Any) -> Any:
"""Override to convert a set of samples to a batch"""
if not isinstance(samples, Tensor):
return default_collate(samples)
return samples
def after_collate(self, batch: Any) -> Any:
"""Override to apply transformations to the batch"""
return batch
def before_uncollate(self, batch: Any) -> Any:
"""Override to apply transformations to the batch"""
return batch
def uncollate(self, batch: Any) -> ny:
"""Override to convert a batch to a set of samples"""
samples = batch
return samples
def after_uncollate(self, samples: Any) -> Any:
"""Override to apply transformations to samples"""
return samplesA
To make tasks work for inference, one must create a ``Preprocess`` and ``PostProcess``.
The ``flash.data.process.Preprocess`` exposes 9 hooks to override which can specifialzed for each stage using
``train``, ``val``, ``test``, ``predict`` prefixes.
4 changes: 1 addition & 3 deletions docs/source/reference/image_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Use the :class:`~flash.vision.ImageClassifier` pretrained model for inference on
print(predictions)
# 3b. Or generate predictions with a whole folder!
datamodule = ImageClassificationData.from_folder(folder="data/hymenoptera_data/predict/")
datamodule = ImageClassificationData.from_folders(predict_folder="data/hymenoptera_data/predict/")
predictions = Trainer().predict(model, datamodule=datamodule)
print(predictions)
Expand Down Expand Up @@ -185,5 +185,3 @@ ImageClassificationData
.. automethod:: flash.vision.ImageClassificationData.from_filepaths

.. automethod:: flash.vision.ImageClassificationData.from_folders

.. automethod:: flash.vision.ImageClassificationData.from_folder
12 changes: 6 additions & 6 deletions docs/source/reference/image_embedder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ Use the :class:`~flash.vision.ImageEmbedder` pretrained model for inference on a

.. code-block:: python
from flash.vision import ImageEmbedder
from flash.vision import ImageEmbedder
# Load finetuned task
embedder = ImageEmbedder(backbone="resnet18")
# Load finetuned task
embedder = ImageEmbedder(backbone="resnet18")
# 2. Perform inference on an image file
embeddings = embedder.predict("path/to/image.png")
print(embeddings)
# 2. Perform inference on an image file
embeddings = embedder.predict("path/to/image.png")
print(embeddings)
Or on a random image tensor

Expand Down
12 changes: 6 additions & 6 deletions docs/source/reference/object_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ Use the :class:`~flash.vision.ObjectDetector` pretrained model for inference on

.. code-block:: python
from flash.vision import ObjectDetector
from flash.vision import ObjectDetector
# 1. Load the model
detector = ObjectDetector()
# 1. Load the model
detector = ObjectDetector()
# 2. Perform inference on an image file
predictions = detector.predict("path/to/image.png")
print(predictions)
# 2. Perform inference on an image file
predictions = detector.predict("path/to/image.png")
print(predictions)
Or on a random image tensor

Expand Down
13 changes: 6 additions & 7 deletions docs/source/reference/tabular_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ We can use the Flash Tabular classification task to predict the probability a pa
We can create :class:`~flash.tabular.TabularData` from csv files using the :func:`~flash.tabular.TabularData.from_csv` method. We will pass in:

* **train_csv**- csv file containing the training data converted to a Pandas DataFrame
* **categorical_input**- a list of the names of columns that contain categorical data (strings or integers)
* **numerical_input**- a list of the names of columns that contain numerical continuous data (floats)
* **cat_cols**- a list of the names of columns that contain categorical data (strings or integers)
* **num_cols**- a list of the names of columns that contain numerical continuous data (floats)
* **target**- the name of the column we want to predict


Expand All @@ -56,8 +56,8 @@ Next, we create the :class:`~flash.tabular.TabularClassifier` task, using the Da
datamodule = TabularData.from_csv(
"./data/titanic/titanic.csv",
test_csv="./data/titanic/test.csv",
categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
numerical_input=["Fare"],
cat_cols=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
num_cols=["Fare"],
target="Survived",
val_size=0.25,
)
Expand Down Expand Up @@ -120,8 +120,8 @@ Or you can finetune your own model and use that for prediction:
datamodule = TabularData.from_csv(
"my_data_file.csv",
test_csv="./data/titanic/test.csv",
categorical_input=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
numerical_input=["Fare"],
cat_cols=["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"],
num_cols=["Fare"],
target="Survived",
val_size=0.25,
)
Expand Down Expand Up @@ -166,4 +166,3 @@ TabularData
.. automethod:: flash.tabular.TabularData.from_csv

.. automethod:: flash.tabular.TabularData.from_df

7 changes: 2 additions & 5 deletions flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from flash import tabular, text, vision # noqa: E402
from flash.core import data, utils # noqa: E402
from flash.core.classification import ClassificationTask # noqa: E402
from flash.core.data import DataModule # noqa: E402
from flash.core.data.utils import download_data # noqa: E402
from flash.core.model import Task # noqa: E402
from flash.core.trainer import Trainer # noqa: E402
from flash.data.data_module import DataModule # noqa: E402
from flash.data.utils import download_data # noqa: E402

__all__ = [
"Task",
Expand All @@ -42,7 +41,5 @@
"vision",
"text",
"tabular",
"data",
"utils",
"download_data",
]
17 changes: 6 additions & 11 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,18 @@
import torch
from torch import Tensor

from flash.core.data import TaskDataPipeline
from flash.core.model import Task
from flash.data.process import Postprocess


class ClassificationDataPipeline(TaskDataPipeline):
class ClassificationPostprocess(Postprocess):

def before_uncollate(self, batch: Union[Tensor, tuple]) -> Tensor:
if isinstance(batch, tuple):
batch = batch[0]
return torch.softmax(batch, -1)

def after_uncollate(self, samples: Any) -> Any:
def per_sample_transform(self, samples: Any) -> Any:
return torch.argmax(samples, -1).tolist()


class ClassificationTask(Task):

@staticmethod
def default_pipeline() -> ClassificationDataPipeline:
return ClassificationDataPipeline()
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._postprocess = ClassificationPostprocess()
3 changes: 0 additions & 3 deletions flash/core/data/__init__.py

This file was deleted.

Loading

0 comments on commit ba34bf4

Please sign in to comment.