Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[2/N] Data Sources #264

Merged
merged 10 commits into from
May 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .readthedocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ python:
version: 3.7
install:
- requirements: requirements/docs.txt
#- requirements: requirements.txt
#- requirements: requirements.txt
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ clean:
rm -rf .pytest_cache
rm -rf ./docs/build
rm -rf ./docs/source/**/generated
rm -rf ./docs/source/api
rm -rf ./docs/source/api
2 changes: 1 addition & 1 deletion docs/source/_static/images/logo.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 2 additions & 4 deletions flash/vision/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]
),
"post_tensor_transform": ApplyToKeys(
DefaultDataKeys.INPUT,
# TODO (Edgar): replace with resize once kornia is fixed
K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
K.geometry.Resize(image_size),
K.augmentation.RandomHorizontalFlip(),
),
"per_batch_transform_on_device": ApplyToKeys(
Expand Down Expand Up @@ -70,8 +69,7 @@ def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
),
"post_tensor_transform": ApplyToKeys(
DefaultDataKeys.INPUT,
# TODO (Edgar): replace with resize once kornia is fixed
K.augmentation.RandomResizedCrop(image_size, scale=(1.0, 1.0), ratio=(1.0, 1.0)),
K.geometry.Resize(image_size),
),
"per_batch_transform_on_device": ApplyToKeys(
DefaultDataKeys.INPUT,
Expand Down
76 changes: 41 additions & 35 deletions flash_notebooks/custom_task_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,21 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Any, List, Tuple\n",
"from typing import Any, List, Tuple, Dict\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from pytorch_lightning import seed_everything\n",
"from sklearn import datasets\n",
"from sklearn.model_selection import train_test_split\n",
"from torch import nn\n",
"from torch import nn, Tensor\n",
"\n",
"import flash\n",
"from flash.data.auto_dataset import AutoDataset\n",
"from flash.data.process import Postprocess, Preprocess"
"from flash.data.data_source import DataSource\n",
"from flash.data.process import Preprocess\n",
"\n",
"ND = np.ndarray"
]
},
{
Expand Down Expand Up @@ -152,24 +155,43 @@
"metadata": {},
"outputs": [],
"source": [
"class NumpyRegressionPreprocess(Preprocess):\n",
"class NumpyDataSource(DataSource):\n",
"\n",
" def load_data(self, data: Tuple[np.ndarray, np.ndarray], dataset: AutoDataset) -> List[Tuple[np.ndarray, float]]:\n",
" def load_data(self, data: Tuple[ND, ND], dataset: AutoDataset) -> List[Tuple[ND, float]]:\n",
" if self.training:\n",
" dataset.num_inputs = data[0].shape[1]\n",
" return [(x, y) for x, y in zip(*data)]\n",
"\n",
" def to_tensor_transform(self, sample: Any) -> Tuple[torch.Tensor, torch.Tensor]:\n",
" def predict_load_data(self, data: ND) -> ND:\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class NumpyPreprocess(Preprocess):\n",
"\n",
" def __init__(self):\n",
" super().__init__(data_sources={\"numpy\": NumpyDataSource()}, default_data_source=\"numpy\")\n",
"\n",
" def to_tensor_transform(self, sample: Any) -> Tuple[Tensor, Tensor]:\n",
" x, y = sample\n",
" x = torch.from_numpy(x).float()\n",
" y = torch.tensor(y, dtype=torch.float)\n",
" return x, y\n",
"\n",
" def predict_load_data(self, data: np.ndarray) -> np.ndarray:\n",
" return data\n",
" def predict_to_tensor_transform(self, sample: ND) -> ND:\n",
" return torch.from_numpy(sample).float()\n",
"\n",
" def get_state_dict(self) -> Dict[str, Any]:\n",
" return {}\n",
"\n",
" def predict_to_tensor_transform(self, sample: np.ndarray) -> np.ndarray:\n",
" return torch.from_numpy(sample).float()\n"
" @classmethod\n",
" def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool):\n",
" return cls()"
]
},
{
Expand All @@ -181,15 +203,16 @@
"class SklearnDataModule(flash.DataModule):\n",
"\n",
" @classmethod\n",
" def from_dataset(cls, x: np.ndarray, y: np.ndarray, batch_size: int = 64, num_workers: int = 0):\n",
" def from_dataset(cls, x: np.ndarray, y: np.ndarray, preprocess: Preprocess, batch_size: int = 64, num_workers: int = 0):\n",
"\n",
" preprocess = NumpyRegressionPreprocess()\n",
" preprocess = preprocess\n",
"\n",
" x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0)\n",
"\n",
" dm = cls.from_load_data_inputs(\n",
" train_load_data_input=(x_train, y_train),\n",
" test_load_data_input=(x_test, y_test),\n",
" dm = cls.from_data_source(\n",
" \"numpy\",\n",
" train_data=(x_train, y_train),\n",
" test_data=(x_test, y_test),\n",
" preprocess=preprocess,\n",
" batch_size=batch_size,\n",
" num_workers=num_workers\n",
Expand All @@ -204,7 +227,8 @@
"metadata": {},
"outputs": [],
"source": [
"datamodule = SklearnDataModule.from_dataset(*datasets.load_diabetes(return_X_y=True))"
"x, y = datasets.load_diabetes(return_X_y=True)\n",
"datamodule = SklearnDataModule.from_dataset(x, y, NumpyPreprocess())"
]
},
{
Expand Down Expand Up @@ -350,25 +374,7 @@
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"metadata": {},
"nbformat": 4,
"nbformat_minor": 4
}
Loading