Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[2/N] Data Sources #264

Merged
merged 10 commits into from
May 8, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ python:
version: 3.7
install:
- requirements: requirements/docs.txt
#- requirements: requirements.txt
#- requirements: requirements.txt
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ clean:
rm -rf .pytest_cache
rm -rf ./docs/build
rm -rf ./docs/source/**/generated
rm -rf ./docs/source/api
rm -rf ./docs/source/api
2 changes: 1 addition & 1 deletion docs/source/_static/images/logo.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 2 additions & 4 deletions flash/vision/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]
),
"post_tensor_transform": ApplyToKeys(
DefaultDataKeys.INPUT,
# TODO (Edgar): replace with resize once kornia is fixed
K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
K.geometry.Resize(image_size),
K.augmentation.RandomHorizontalFlip(),
),
"per_batch_transform_on_device": ApplyToKeys(
Expand Down Expand Up @@ -70,8 +69,7 @@ def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
),
"post_tensor_transform": ApplyToKeys(
DefaultDataKeys.INPUT,
# TODO (Edgar): replace with resize once kornia is fixed
K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
K.geometry.Resize(image_size),
),
"per_batch_transform_on_device": ApplyToKeys(
DefaultDataKeys.INPUT,
Expand Down
76 changes: 41 additions & 35 deletions flash_notebooks/custom_task_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,21 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, List, Tuple\n",
"from typing import Any, List, Tuple, Dict\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from pytorch_lightning import seed_everything\n",
"from sklearn import datasets\n",
"from sklearn.model_selection import train_test_split\n",
"from torch import nn\n",
"from torch import nn, Tensor\n",
"\n",
"import flash\n",
"from flash.data.auto_dataset import AutoDataset\n",
"from flash.data.process import Postprocess, Preprocess"
"from flash.data.data_source import DataSource\n",
"from flash.data.process import Preprocess\n",
"\n",
"ND = np.ndarray"
]
},
{
Expand Down Expand Up @@ -152,24 +155,43 @@
"metadata": {},
"outputs": [],
"source": [
"class NumpyRegressionPreprocess(Preprocess):\n",
"class NumpyDataSource(DataSource):\n",
"\n",
" def load_data(self, data: Tuple[np.ndarray, np.ndarray], dataset: AutoDataset) -> List[Tuple[np.ndarray, float]]:\n",
" def load_data(self, data: Tuple[ND, ND], dataset: AutoDataset) -> List[Tuple[ND, float]]:\n",
" if self.training:\n",
" dataset.num_inputs = data[0].shape[1]\n",
" return [(x, y) for x, y in zip(*data)]\n",
"\n",
" def to_tensor_transform(self, sample: Any) -> Tuple[torch.Tensor, torch.Tensor]:\n",
" def predict_load_data(self, data: ND) -> ND:\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class NumpyPreprocess(Preprocess):\n",
"\n",
" def __init__(self):\n",
" super().__init__(data_sources={\"numpy\": NumpyDataSource()}, default_data_source=\"numpy\")\n",
"\n",
" def to_tensor_transform(self, sample: Any) -> Tuple[Tensor, Tensor]:\n",
" x, y = sample\n",
" x = torch.from_numpy(x).float()\n",
" y = torch.tensor(y, dtype=torch.float)\n",
" return x, y\n",
"\n",
" def predict_load_data(self, data: np.ndarray) -> np.ndarray:\n",
" return data\n",
" def predict_to_tensor_transform(self, sample: ND) -> ND:\n",
" return torch.from_numpy(sample).float()\n",
"\n",
" def get_state_dict(self) -> Dict[str, Any]:\n",
" return {}\n",
"\n",
" def predict_to_tensor_transform(self, sample: np.ndarray) -> np.ndarray:\n",
" return torch.from_numpy(sample).float()\n"
" @classmethod\n",
" def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool):\n",
" return cls()"
]
},
{
Expand All @@ -181,15 +203,16 @@
"class SklearnDataModule(flash.DataModule):\n",
"\n",
" @classmethod\n",
" def from_dataset(cls, x: np.ndarray, y: np.ndarray, batch_size: int = 64, num_workers: int = 0):\n",
" def from_dataset(cls, x: np.ndarray, y: np.ndarray, preprocess: Preprocess, batch_size: int = 64, num_workers: int = 0):\n",
"\n",
" preprocess = NumpyRegressionPreprocess()\n",
" preprocess = preprocess\n",
"\n",
" x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0)\n",
"\n",
" dm = cls.from_load_data_inputs(\n",
" train_load_data_input=(x_train, y_train),\n",
" test_load_data_input=(x_test, y_test),\n",
" dm = cls.from_data_source(\n",
" \"numpy\",\n",
" train_data=(x_train, y_train),\n",
" test_data=(x_test, y_test),\n",
" preprocess=preprocess,\n",
" batch_size=batch_size,\n",
" num_workers=num_workers\n",
Expand All @@ -204,7 +227,8 @@
"metadata": {},
"outputs": [],
"source": [
"datamodule = SklearnDataModule.from_dataset(*datasets.load_diabetes(return_X_y=True))"
"x, y = datasets.load_diabetes(return_X_y=True)\n",
"datamodule = SklearnDataModule.from_dataset(x, y, NumpyPreprocess())"
]
},
{
Expand Down Expand Up @@ -350,25 +374,7 @@
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"metadata": {},
"nbformat": 4,
"nbformat_minor": 4
}
Loading