From 27163475d062c96de158262a5cc585c7dff278c2 Mon Sep 17 00:00:00 2001 From: Suman Michael Date: Thu, 15 Jul 2021 22:33:07 +0530 Subject: [PATCH 01/27] Revert "Added TabularRegressionData extending TabularData (#574)" This reverts commit c318e4ad --- README.md | 4 +- flash/tabular/data.py | 510 ------------------ flash/tabular/regression/__init__.py | 1 - flash/tabular/regression/data.py | 18 - flash_examples/tabular_classification.py | 4 +- tests/tabular/classification/test_data.py | 18 +- .../test_data_model_integration.py | 4 +- tests/tabular/classification/test_model.py | 5 +- 8 files changed, 18 insertions(+), 546 deletions(-) delete mode 100644 flash/tabular/data.py delete mode 100644 flash/tabular/regression/__init__.py delete mode 100644 flash/tabular/regression/data.py diff --git a/README.md b/README.md index b5d9a59187..88cccbcbe5 100644 --- a/README.md +++ b/README.md @@ -260,13 +260,13 @@ To illustrate, say we want to build a model to predict if a passenger survived o from torchmetrics.classification import Accuracy, Precision, Recall import flash from flash.core.data.utils import download_data -from flash.tabular import TabularClassifier, TabularClassificationData +from flash.tabular import TabularClassifier, TabularData # 1. Download the data download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", 'data/') # 2. Load the data -datamodule = TabularClassificationData.from_csv( +datamodule = TabularData.from_csv( ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], "Fare", target_fields="Survived", diff --git a/flash/tabular/data.py b/flash/tabular/data.py deleted file mode 100644 index f6a9d717e5..0000000000 --- a/flash/tabular/data.py +++ /dev/null @@ -1,510 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from io import StringIO -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -from pytorch_lightning.utilities.exceptions import MisconfigurationException - -from flash.core.classification import LabelsState -from flash.core.data.callback import BaseDataFetcher -from flash.core.data.data_module import DataModule -from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Deserializer, Postprocess, Preprocess -from flash.core.utilities.imports import _PANDAS_AVAILABLE -from flash.tabular.classification.utils import ( - _compute_normalization, - _generate_codes, - _pre_transform, - _to_cat_vars_numpy, - _to_num_vars_numpy, -) - -if _PANDAS_AVAILABLE: - import pandas as pd - from pandas.core.frame import DataFrame -else: - DataFrame = object - - -class TabularDataFrameDataSource(DataSource[DataFrame]): - - def __init__( - self, - cat_cols: Optional[List[str]] = None, - num_cols: Optional[List[str]] = None, - target_col: Optional[str] = None, - mean: Optional[DataFrame] = None, - std: Optional[DataFrame] = None, - codes: Optional[Dict[str, Any]] = None, - target_codes: Optional[Dict[str, Any]] = None, - classes: Optional[List[str]] = None, - is_regression: bool = True, - ): - super().__init__() - - self.cat_cols = cat_cols - self.num_cols = num_cols - self.target_col = target_col - self.mean = mean - self.std = std - self.codes = codes - self.target_codes = target_codes - self.is_regression = is_regression - - self.set_state(LabelsState(classes)) - self.num_classes = len(classes) - - def common_load_data( - self, - df: DataFrame, - dataset: Optional[Any] = None, - ): - # impute_data - # compute train dataset stats - dfs = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, - self.target_codes) - - df = dfs[0] - - if dataset is not None: - dataset.num_samples = len(df) - - cat_vars = _to_cat_vars_numpy(df, self.cat_cols) - num_vars = _to_num_vars_numpy(df, self.num_cols) - - cat_vars = np.stack(cat_vars, 1) # if len(cat_vars) else np.zeros((len(self), 0)) - num_vars = np.stack(num_vars, 1) # if len(num_vars) else np.zeros((len(self), 0)) - return df, cat_vars, num_vars - - def load_data(self, data: DataFrame, dataset: Optional[Any] = None): - df, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) - target = df[self.target_col].to_numpy().astype(np.float32 if self.is_regression else np.int64) - return [{ - DefaultDataKeys.INPUT: (c, n), - DefaultDataKeys.TARGET: t - } for c, n, t in zip(cat_vars, num_vars, target)] - - def predict_load_data(self, data: DataFrame, dataset: Optional[Any] = None): - _, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) - return [{DefaultDataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] - - -class TabularCSVDataSource(TabularDataFrameDataSource): - - def load_data(self, data: str, dataset: Optional[Any] = None): - return super().load_data(pd.read_csv(data), dataset=dataset) - - def predict_load_data(self, data: str, dataset: Optional[Any] = None): - return super().predict_load_data(pd.read_csv(data), dataset=dataset) - - -class TabularDeserializer(Deserializer): - - def __init__( - self, - cat_cols: Optional[List[str]] = None, - num_cols: Optional[List[str]] = None, - target_col: Optional[str] = None, - mean: Optional[DataFrame] = None, - std: Optional[DataFrame] = None, - codes: Optional[Dict[str, Any]] = None, - target_codes: Optional[Dict[str, Any]] = None, - classes: Optional[List[str]] = None, - is_regression: bool = True - ): - super().__init__() - self.cat_cols = cat_cols - self.num_cols = num_cols - self.target_col = target_col - self.mean = mean - self.std = std - self.codes = codes - self.target_codes = target_codes - self.classes = classes - self.is_regression = is_regression - - def deserialize(self, data: str) -> Any: - df = pd.read_csv(StringIO(data)) - df = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, - self.target_codes)[0] - - cat_vars = _to_cat_vars_numpy(df, self.cat_cols) - num_vars = _to_num_vars_numpy(df, self.num_cols) - - cat_vars = np.stack(cat_vars, 1) - num_vars = np.stack(num_vars, 1) - - return [{DefaultDataKeys.INPUT: [c, n]} for c, n in zip(cat_vars, num_vars)] - - @property - def example_input(self) -> str: - row = {} - for cat_col in self.cat_cols: - row[cat_col] = ["test"] - for num_col in self.num_cols: - row[num_col] = [0] - return str(DataFrame.from_dict(row).to_csv()) - - -class TabularPreprocess(Preprocess): - - def __init__( - self, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, - cat_cols: Optional[List[str]] = None, - num_cols: Optional[List[str]] = None, - target_col: Optional[str] = None, - mean: Optional[DataFrame] = None, - std: Optional[DataFrame] = None, - codes: Optional[Dict[str, Any]] = None, - target_codes: Optional[Dict[str, Any]] = None, - classes: Optional[List[str]] = None, - is_regression: bool = True, - deserializer: Optional[Deserializer] = None - ): - self.cat_cols = cat_cols - self.num_cols = num_cols - self.target_col = target_col - self.mean = mean - self.std = std - self.codes = codes - self.target_codes = target_codes - self.classes = classes - self.is_regression = is_regression - - super().__init__( - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, - data_sources={ - DefaultDataSources.CSV: TabularCSVDataSource( - cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression - ), - "data_frame": TabularDataFrameDataSource( - cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression - ), - }, - default_data_source=DefaultDataSources.CSV, - deserializer=deserializer or TabularDeserializer( - cat_cols=cat_cols, - num_cols=num_cols, - target_col=target_col, - mean=mean, - std=std, - codes=codes, - target_codes=target_codes, - classes=classes, - is_regression=is_regression - ) - ) - - def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: - return { - **self.transforms, - "cat_cols": self.cat_cols, - "num_cols": self.num_cols, - "target_col": self.target_col, - "mean": self.mean, - "std": self.std, - "codes": self.codes, - "target_codes": self.target_codes, - "classes": self.classes, - "is_regression": self.is_regression, - } - - @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> 'Preprocess': - return cls(**state_dict) - - -class TabularPostprocess(Postprocess): - - def uncollate(self, batch: Any) -> Any: - return batch - - -class TabularData(DataModule): - """Data module for tabular tasks""" - - preprocess_cls = TabularPreprocess - postprocess_cls = TabularPostprocess - - is_regression: bool = False - - @property - def codes(self) -> Dict[str, str]: - return self._data_source.codes - - @property - def num_classes(self) -> int: - return self._data_source.num_classes - - @property - def cat_cols(self) -> Optional[List[str]]: - return self._data_source.cat_cols - - @property - def num_cols(self) -> Optional[List[str]]: - return self._data_source.num_cols - - @property - def num_features(self) -> int: - return len(self.cat_cols) + len(self.num_cols) - - @property - def emb_sizes(self) -> list: - """Recommended embedding sizes.""" - - # https://developers.googleblog.com/2017/11/introducing-tensorflow-feature-columns.html - # The following "formula" provides a general rule of thumb about the number of embedding dimensions: - # embedding_dimensions = number_of_categories**0.25 - num_classes = [len(self.codes[cat]) for cat in self.cat_cols] - emb_dims = [max(int(n**0.25), 16) for n in num_classes] - return list(zip(num_classes, emb_dims)) - - @staticmethod - def _sanetize_cols(cat_cols: Optional[Union[str, List[str]]], num_cols: Optional[Union[str, List[str]]]): - if cat_cols is None and num_cols is None: - raise RuntimeError('Both `cat_cols` and `num_cols` are None!') - - return cat_cols or [], num_cols or [] - - @classmethod - def compute_state( - cls, - train_data_frame: DataFrame, - val_data_frame: Optional[DataFrame], - test_data_frame: Optional[DataFrame], - predict_data_frame: Optional[DataFrame], - target_fields: str, - numerical_fields: List[str], - categorical_fields: List[str], - ) -> Tuple[float, float, List[str], Dict[str, Any], Dict[str, Any]]: - - if train_data_frame is None: - raise MisconfigurationException( - "train_data_frame is required to instantiate the TabularDataFrameDataSource" - ) - - data_frames = [train_data_frame] - - if val_data_frame is not None: - data_frames += [val_data_frame] - - if test_data_frame is not None: - data_frames += [test_data_frame] - - if predict_data_frame is not None: - data_frames += [predict_data_frame] - - mean, std = _compute_normalization(data_frames[0], numerical_fields) - - classes = list(data_frames[0][target_fields].unique()) - - if data_frames[0][target_fields].dtype == object: - # if the target_fields is a category, not an int - target_codes = _generate_codes(data_frames, [target_fields]) - else: - target_codes = None - codes = _generate_codes(data_frames, categorical_fields) - - return mean, std, classes, codes, target_codes - - @classmethod - def from_data_frame( - cls, - categorical_fields: Optional[Union[str, List[str]]], - numerical_fields: Optional[Union[str, List[str]]], - target_fields: Optional[str] = None, - train_data_frame: Optional[DataFrame] = None, - val_data_frame: Optional[DataFrame] = None, - test_data_frame: Optional[DataFrame] = None, - predict_data_frame: Optional[DataFrame] = None, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, - val_split: Optional[float] = None, - batch_size: int = 4, - num_workers: Optional[int] = None, - **preprocess_kwargs: Any, - ): - """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. - - Args: - categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs. - numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs. - target_fields: The field or fields (columns) in the CSV file to use for the target. - train_data_frame: The pandas ``DataFrame`` containing the training data. - val_data_frame: The pandas ``DataFrame`` containing the validation data. - test_data_frame: The pandas ``DataFrame`` containing the testing data. - predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the - :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` - will be constructed and used. - val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. - - Returns: - The constructed data module. - - Examples:: - - data_module = TabularData.from_data_frame( - "categorical_input", - "numerical_input", - "target", - train_data_frame=train_data, - ) - """ - categorical_fields, numerical_fields = cls._sanetize_cols(categorical_fields, numerical_fields) - - if not isinstance(categorical_fields, list): - categorical_fields = [categorical_fields] - - if not isinstance(numerical_fields, list): - numerical_fields = [numerical_fields] - - mean, std, classes, codes, target_codes = cls.compute_state( - train_data_frame=train_data_frame, - val_data_frame=val_data_frame, - test_data_frame=test_data_frame, - predict_data_frame=predict_data_frame, - target_fields=target_fields, - numerical_fields=numerical_fields, - categorical_fields=categorical_fields, - ) - - return cls.from_data_source( - "data_frame", - train_data_frame, - val_data_frame, - test_data_frame, - predict_data_frame, - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - predict_transform=predict_transform, - data_fetcher=data_fetcher, - preprocess=preprocess, - val_split=val_split, - batch_size=batch_size, - num_workers=num_workers, - cat_cols=categorical_fields, - num_cols=numerical_fields, - target_col=target_fields, - mean=mean, - std=std, - codes=codes, - target_codes=target_codes, - classes=classes, - is_regression=cls.is_regression, - **preprocess_kwargs, - ) - - @classmethod - def from_csv( - cls, - categorical_fields: Optional[Union[str, List[str]]], - numerical_fields: Optional[Union[str, List[str]]], - target_fields: Optional[str] = None, - train_file: Optional[str] = None, - val_file: Optional[str] = None, - test_file: Optional[str] = None, - predict_file: Optional[str] = None, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, - val_split: Optional[float] = None, - batch_size: int = 4, - num_workers: Optional[int] = None, - **preprocess_kwargs: Any, - ) -> 'DataModule': - """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files. - - Args: - categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs. - numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs. - target_fields: The field or fields (columns) in the CSV file to use for the target. - train_file: The CSV file containing the training data. - val_file: The CSV file containing the validation data. - test_file: The CSV file containing the testing data. - predict_file: The CSV file containing the data to use when predicting. - train_transform: The dictionary of transforms to use during training which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - val_transform: The dictionary of transforms to use during validation which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - test_transform: The dictionary of transforms to use during testing which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - predict_transform: The dictionary of transforms to use during predicting which maps - :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. - data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the - :class:`~flash.core.data.data_module.DataModule`. - preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the - :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` - will be constructed and used. - val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. - preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used - if ``preprocess = None``. - - Returns: - The constructed data module. - - Examples:: - - data_module = TabularData.from_csv( - "categorical_input", - "numerical_input", - "target", - train_file="train_data.csv", - ) - """ - return cls.from_data_frame( - categorical_fields=categorical_fields, - numerical_fields=numerical_fields, - target_fields=target_fields, - train_data_frame=pd.read_csv(train_file) if train_file is not None else None, - val_data_frame=pd.read_csv(val_file) if val_file is not None else None, - test_data_frame=pd.read_csv(test_file) if test_file is not None else None, - predict_data_frame=pd.read_csv(predict_file) if predict_file is not None else None, - preprocess=preprocess, - val_split=val_split, - batch_size=batch_size, - num_workers=num_workers, - ) diff --git a/flash/tabular/regression/__init__.py b/flash/tabular/regression/__init__.py deleted file mode 100644 index a93e599ff0..0000000000 --- a/flash/tabular/regression/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from flash.tabular.regression.data import TabularRegressionData # noqa: F401 diff --git a/flash/tabular/regression/data.py b/flash/tabular/regression/data.py deleted file mode 100644 index 04dd8cd3b4..0000000000 --- a/flash/tabular/regression/data.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from flash.tabular.data import TabularData - - -class TabularRegressionData(TabularData): - is_regression = True diff --git a/flash_examples/tabular_classification.py b/flash_examples/tabular_classification.py index 9e6b0ab049..fa3a2cc23e 100644 --- a/flash_examples/tabular_classification.py +++ b/flash_examples/tabular_classification.py @@ -13,12 +13,12 @@ # limitations under the License. import flash from flash.core.data.utils import download_data -from flash.tabular import TabularClassificationData, TabularClassifier +from flash.tabular import TabularClassifier, TabularData # 1. Create the DataModule download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data") -datamodule = TabularClassificationData.from_csv( +datamodule = TabularData.from_csv( ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], "Fare", target_fields="Survived", diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py index 6bf2cae4fb..baa87b3451 100644 --- a/tests/tabular/classification/test_data.py +++ b/tests/tabular/classification/test_data.py @@ -23,7 +23,7 @@ if _PANDAS_AVAILABLE: import pandas as pd - from flash.tabular import TabularClassificationData + from flash.tabular import TabularData from flash.tabular.classification.utils import _categorize, _normalize TEST_DF_1 = pd.DataFrame( @@ -73,19 +73,19 @@ def test_emb_sizes(): self.codes = {"category": [None, "a", "b", "c"]} self.cat_cols = ["category"] # use __get__ to test property with mocked self - es = TabularClassificationData.emb_sizes.__get__(self) # pylint: disable=E1101 + es = TabularData.emb_sizes.__get__(self) # pylint: disable=E1101 assert es == [(4, 16)] self.codes = {} self.cat_cols = [] # use __get__ to test property with mocked self - es = TabularClassificationData.emb_sizes.__get__(self) # pylint: disable=E1101 + es = TabularData.emb_sizes.__get__(self) # pylint: disable=E1101 assert es == [] self.codes = {"large": ["a"] * 100_000, "larger": ["b"] * 1_000_000} self.cat_cols = ["large", "larger"] # use __get__ to test property with mocked self - es = TabularClassificationData.emb_sizes.__get__(self) # pylint: disable=E1101 + es = TabularData.emb_sizes.__get__(self) # pylint: disable=E1101 assert es == [(100_000, 17), (1_000_000, 31)] @@ -94,7 +94,7 @@ def test_tabular_data(tmpdir): train_data_frame = TEST_DF_1.copy() val_data_frame = TEST_DF_2.copy() test_data_frame = TEST_DF_2.copy() - dm = TabularClassificationData.from_data_frame( + dm = TabularData.from_data_frame( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -122,7 +122,7 @@ def test_categorical_target(tmpdir): # change int label to string df["label"] = df["label"].astype(str) - dm = TabularClassificationData.from_data_frame( + dm = TabularData.from_data_frame( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -146,7 +146,7 @@ def test_from_data_frame(tmpdir): train_data_frame = TEST_DF_1.copy() val_data_frame = TEST_DF_2.copy() test_data_frame = TEST_DF_2.copy() - dm = TabularClassificationData.from_data_frame( + dm = TabularData.from_data_frame( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -173,7 +173,7 @@ def test_from_csv(tmpdir): TEST_DF_2.to_csv(val_csv) TEST_DF_2.to_csv(test_csv) - dm = TabularClassificationData.from_csv( + dm = TabularData.from_csv( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -196,7 +196,7 @@ def test_from_csv(tmpdir): def test_empty_inputs(): train_data_frame = TEST_DF_1.copy() with pytest.raises(RuntimeError): - TabularClassificationData.from_data_frame( + TabularData.from_data_frame( numerical_fields=None, categorical_fields=None, target_fields="label", diff --git a/tests/tabular/classification/test_data_model_integration.py b/tests/tabular/classification/test_data_model_integration.py index e30cac67c8..349aeeaaba 100644 --- a/tests/tabular/classification/test_data_model_integration.py +++ b/tests/tabular/classification/test_data_model_integration.py @@ -15,7 +15,7 @@ import pytorch_lightning as pl from flash.core.utilities.imports import _TABULAR_AVAILABLE -from flash.tabular import TabularClassificationData, TabularClassifier +from flash.tabular import TabularClassifier, TabularData from tests.helpers.utils import _TABULAR_TESTING if _TABULAR_AVAILABLE: @@ -37,7 +37,7 @@ def test_classification(tmpdir): train_data_frame = TEST_DF_1.copy() val_data_frame = TEST_DF_1.copy() test_data_frame = TEST_DF_1.copy() - data = TabularClassificationData.from_data_frame( + data = TabularData.from_data_frame( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index a64c2d090d..d3cc3db332 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -21,7 +21,8 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _TABULAR_AVAILABLE -from flash.tabular import TabularClassificationData, TabularClassifier +from flash.tabular import TabularClassifier +from flash.tabular.classification.data import TabularData from tests.helpers.utils import _SERVE_TESTING, _TABULAR_TESTING # ======== Mock functions ======== @@ -99,7 +100,7 @@ def test_jit(tmpdir): @mock.patch("flash._IS_TESTING", True) def test_serve(): train_data = {"num_col": [1.4, 2.5], "cat_col": ["positive", "negative"], "target": [1, 2]} - datamodule = TabularClassificationData.from_data_frame( + datamodule = TabularData.from_data_frame( "cat_col", "num_col", "target", From a34be7d1eb93efad13f16ba6c976a1136411adbf Mon Sep 17 00:00:00 2001 From: Suman Michael Date: Thu, 15 Jul 2021 23:18:38 +0530 Subject: [PATCH 02/27] added DataModule, PreProcess, DataSource for TabularForecasting --- flash/core/data/data_source.py | 1 + flash/core/utilities/imports.py | 3 +- flash/tabular/__init__.py | 9 +- flash/tabular/classification/__init__.py | 2 +- flash/tabular/classification/data.py | 503 ++++++++++++++++++++++- flash/tabular/forecasting/__init__.py | 5 + flash/tabular/forecasting/data.py | 210 ++++++++++ requirements/datatype_tabular.txt | 1 + 8 files changed, 726 insertions(+), 8 deletions(-) create mode 100644 flash/tabular/forecasting/__init__.py create mode 100644 flash/tabular/forecasting/data.py diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py index d3c7c611ef..7aff883fdb 100644 --- a/flash/core/data/data_source.py +++ b/flash/core/data/data_source.py @@ -156,6 +156,7 @@ class DefaultDataSources(LightningEnum): JSON = "json" DATASET = "dataset" FIFTYONE = "fiftyone" + DATAFRAME = "dataframe" # TODO: Create a FlashEnum class??? def __hash__(self) -> int: diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 9922f49eba..213fc455cd 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -68,6 +68,7 @@ def _compare_version(package: str, op, version) -> bool: _PANDAS_AVAILABLE = _module_available("pandas") _SKLEARN_AVAILABLE = _module_available("sklearn") _TABNET_AVAILABLE = _module_available("pytorch_tabnet") +_FORECASTING_AVAILABLE = _module_available("pytorch_forecasting") _KORNIA_AVAILABLE = _module_available("kornia") _COCO_AVAILABLE = _module_available("pycocotools") _TIMM_AVAILABLE = _module_available("timm") @@ -94,7 +95,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") _TEXT_AVAILABLE = _TRANSFORMERS_AVAILABLE -_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE +_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE and _FORECASTING_AVAILABLE _VIDEO_AVAILABLE = _PYTORCHVIDEO_AVAILABLE _IMAGE_AVAILABLE = all([ _TORCHVISION_AVAILABLE, diff --git a/flash/tabular/__init__.py b/flash/tabular/__init__.py index 22698efc99..7127e23641 100644 --- a/flash/tabular/__init__.py +++ b/flash/tabular/__init__.py @@ -1,3 +1,6 @@ -from flash.tabular.classification import TabularClassificationData, TabularClassifier # noqa: F401 -from flash.tabular.data import TabularData # noqa: F401 -from flash.tabular.regression import TabularRegressionData # noqa: F401 +from flash.tabular.classification import TabularClassifier, TabularData # noqa: F401 +from flash.tabular.forecasting import ( + TabularForecastingData, + TabularForecastingPreprocess, + TabularForecastingDataFrameDataSource +) \ No newline at end of file diff --git a/flash/tabular/classification/__init__.py b/flash/tabular/classification/__init__.py index 6134277abf..45724db27b 100644 --- a/flash/tabular/classification/__init__.py +++ b/flash/tabular/classification/__init__.py @@ -1,2 +1,2 @@ -from flash.tabular.classification.data import TabularClassificationData # noqa: F401 +from flash.tabular.classification.data import TabularData # noqa: F401 from flash.tabular.classification.model import TabularClassifier # noqa: F401 diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py index 63cdda9ea2..c2a60e24da 100644 --- a/flash/tabular/classification/data.py +++ b/flash/tabular/classification/data.py @@ -11,8 +11,505 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from flash.tabular.data import TabularData +from io import StringIO +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +from pytorch_lightning.utilities.exceptions import MisconfigurationException -class TabularClassificationData(TabularData): - is_regression = False +from flash.core.classification import LabelsState +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.process import Deserializer, Postprocess, Preprocess +from flash.core.utilities.imports import _PANDAS_AVAILABLE +from flash.tabular.classification.utils import ( + _compute_normalization, + _generate_codes, + _pre_transform, + _to_cat_vars_numpy, + _to_num_vars_numpy, +) + +if _PANDAS_AVAILABLE: + import pandas as pd + from pandas.core.frame import DataFrame +else: + DataFrame = object + + +class TabularDataFrameDataSource(DataSource[DataFrame]): + + def __init__( + self, + cat_cols: Optional[List[str]] = None, + num_cols: Optional[List[str]] = None, + target_col: Optional[str] = None, + mean: Optional[DataFrame] = None, + std: Optional[DataFrame] = None, + codes: Optional[Dict[str, Any]] = None, + target_codes: Optional[Dict[str, Any]] = None, + classes: Optional[List[str]] = None, + is_regression: bool = True, + ): + super().__init__() + + self.cat_cols = cat_cols + self.num_cols = num_cols + self.target_col = target_col + self.mean = mean + self.std = std + self.codes = codes + self.target_codes = target_codes + self.is_regression = is_regression + + self.set_state(LabelsState(classes)) + self.num_classes = len(classes) + + def common_load_data( + self, + df: DataFrame, + dataset: Optional[Any] = None, + ): + # impute_data + # compute train dataset stats + dfs = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, + self.target_codes) + + df = dfs[0] + + if dataset is not None: + dataset.num_samples = len(df) + + cat_vars = _to_cat_vars_numpy(df, self.cat_cols) + num_vars = _to_num_vars_numpy(df, self.num_cols) + + cat_vars = np.stack(cat_vars, 1) # if len(cat_vars) else np.zeros((len(self), 0)) + num_vars = np.stack(num_vars, 1) # if len(num_vars) else np.zeros((len(self), 0)) + return df, cat_vars, num_vars + + def load_data(self, data: DataFrame, dataset: Optional[Any] = None): + df, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) + target = df[self.target_col].to_numpy().astype(np.float32 if self.is_regression else np.int64) + return [{ + DefaultDataKeys.INPUT: (c, n), + DefaultDataKeys.TARGET: t + } for c, n, t in zip(cat_vars, num_vars, target)] + + def predict_load_data(self, data: DataFrame, dataset: Optional[Any] = None): + _, cat_vars, num_vars = self.common_load_data(data, dataset=dataset) + return [{DefaultDataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] + + +class TabularCSVDataSource(TabularDataFrameDataSource): + + def load_data(self, data: str, dataset: Optional[Any] = None): + return super().load_data(pd.read_csv(data), dataset=dataset) + + def predict_load_data(self, data: str, dataset: Optional[Any] = None): + return super().predict_load_data(pd.read_csv(data), dataset=dataset) + + +class TabularDeserializer(Deserializer): + + def __init__( + self, + cat_cols: Optional[List[str]] = None, + num_cols: Optional[List[str]] = None, + target_col: Optional[str] = None, + mean: Optional[DataFrame] = None, + std: Optional[DataFrame] = None, + codes: Optional[Dict[str, Any]] = None, + target_codes: Optional[Dict[str, Any]] = None, + classes: Optional[List[str]] = None, + is_regression: bool = True + ): + super().__init__() + self.cat_cols = cat_cols + self.num_cols = num_cols + self.target_col = target_col + self.mean = mean + self.std = std + self.codes = codes + self.target_codes = target_codes + self.classes = classes + self.is_regression = is_regression + + def deserialize(self, data: str) -> Any: + df = pd.read_csv(StringIO(data)) + df = _pre_transform([df], self.num_cols, self.cat_cols, self.codes, self.mean, self.std, self.target_col, + self.target_codes)[0] + + cat_vars = _to_cat_vars_numpy(df, self.cat_cols) + num_vars = _to_num_vars_numpy(df, self.num_cols) + + cat_vars = np.stack(cat_vars, 1) + num_vars = np.stack(num_vars, 1) + + return [{DefaultDataKeys.INPUT: [c, n]} for c, n in zip(cat_vars, num_vars)] + + @property + def example_input(self) -> str: + row = {} + for cat_col in self.cat_cols: + row[cat_col] = ["test"] + for num_col in self.num_cols: + row[num_col] = [0] + return str(DataFrame.from_dict(row).to_csv()) + + +class TabularPreprocess(Preprocess): + + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + cat_cols: Optional[List[str]] = None, + num_cols: Optional[List[str]] = None, + target_col: Optional[str] = None, + mean: Optional[DataFrame] = None, + std: Optional[DataFrame] = None, + codes: Optional[Dict[str, Any]] = None, + target_codes: Optional[Dict[str, Any]] = None, + classes: Optional[List[str]] = None, + is_regression: bool = True, + deserializer: Optional[Deserializer] = None + ): + self.cat_cols = cat_cols + self.num_cols = num_cols + self.target_col = target_col + self.mean = mean + self.std = std + self.codes = codes + self.target_codes = target_codes + self.classes = classes + self.is_regression = is_regression + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.CSV: TabularCSVDataSource( + cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression + ), + "data_frame": TabularDataFrameDataSource( + cat_cols, num_cols, target_col, mean, std, codes, target_codes, classes, is_regression + ), + }, + default_data_source=DefaultDataSources.CSV, + deserializer=deserializer or TabularDeserializer( + cat_cols=cat_cols, + num_cols=num_cols, + target_col=target_col, + mean=mean, + std=std, + codes=codes, + target_codes=target_codes, + classes=classes, + is_regression=is_regression + ) + ) + + def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: + return { + **self.transforms, + "cat_cols": self.cat_cols, + "num_cols": self.num_cols, + "target_col": self.target_col, + "mean": self.mean, + "std": self.std, + "codes": self.codes, + "target_codes": self.target_codes, + "classes": self.classes, + "is_regression": self.is_regression, + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> 'Preprocess': + return cls(**state_dict) + + +class TabularPostprocess(Postprocess): + + def uncollate(self, batch: Any) -> Any: + return batch + + +class TabularData(DataModule): + """Data module for tabular tasks""" + + preprocess_cls = TabularPreprocess + postprocess_cls = TabularPostprocess + + @property + def codes(self) -> Dict[str, str]: + return self._data_source.codes + + @property + def num_classes(self) -> int: + return self._data_source.num_classes + + @property + def cat_cols(self) -> Optional[List[str]]: + return self._data_source.cat_cols + + @property + def num_cols(self) -> Optional[List[str]]: + return self._data_source.num_cols + + @property + def num_features(self) -> int: + return len(self.cat_cols) + len(self.num_cols) + + @property + def emb_sizes(self) -> list: + """Recommended embedding sizes.""" + + # https://developers.googleblog.com/2017/11/introducing-tensorflow-feature-columns.html + # The following "formula" provides a general rule of thumb about the number of embedding dimensions: + # embedding_dimensions = number_of_categories**0.25 + num_classes = [len(self.codes[cat]) for cat in self.cat_cols] + emb_dims = [max(int(n**0.25), 16) for n in num_classes] + return list(zip(num_classes, emb_dims)) + + @staticmethod + def _sanetize_cols(cat_cols: Optional[Union[str, List[str]]], num_cols: Optional[Union[str, List[str]]]): + if cat_cols is None and num_cols is None: + raise RuntimeError('Both `cat_cols` and `num_cols` are None!') + + return cat_cols or [], num_cols or [] + + @classmethod + def compute_state( + cls, + train_data_frame: DataFrame, + val_data_frame: Optional[DataFrame], + test_data_frame: Optional[DataFrame], + predict_data_frame: Optional[DataFrame], + target_fields: str, + numerical_fields: List[str], + categorical_fields: List[str], + ) -> Tuple[float, float, List[str], Dict[str, Any], Dict[str, Any]]: + + if train_data_frame is None: + raise MisconfigurationException( + "train_data_frame is required to instantiate the TabularDataFrameDataSource" + ) + + data_frames = [train_data_frame] + + if val_data_frame is not None: + data_frames += [val_data_frame] + + if test_data_frame is not None: + data_frames += [test_data_frame] + + if predict_data_frame is not None: + data_frames += [predict_data_frame] + + mean, std = _compute_normalization(data_frames[0], numerical_fields) + + classes = list(data_frames[0][target_fields].unique()) + + if data_frames[0][target_fields].dtype == object: + # if the target_fields is a category, not an int + target_codes = _generate_codes(data_frames, [target_fields]) + else: + target_codes = None + codes = _generate_codes(data_frames, categorical_fields) + + return mean, std, classes, codes, target_codes + + @classmethod + def from_data_frame( + cls, + categorical_fields: Optional[Union[str, List[str]]], + numerical_fields: Optional[Union[str, List[str]]], + target_fields: Optional[str] = None, + train_data_frame: Optional[DataFrame] = None, + val_data_frame: Optional[DataFrame] = None, + test_data_frame: Optional[DataFrame] = None, + predict_data_frame: Optional[DataFrame] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + is_regression: bool = False, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. + + Args: + categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs. + numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs. + target_fields: The field or fields (columns) in the CSV file to use for the target. + train_data_frame: The pandas ``DataFrame`` containing the training data. + val_data_frame: The pandas ``DataFrame`` containing the validation data. + test_data_frame: The pandas ``DataFrame`` containing the testing data. + predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + is_regression: If ``True``, targets will be formatted as floating point. If ``False``, targets will be + formatted as integers. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = TabularData.from_data_frame( + "categorical_input", + "numerical_input", + "target", + train_data_frame=train_data, + ) + """ + categorical_fields, numerical_fields = cls._sanetize_cols(categorical_fields, numerical_fields) + + if not isinstance(categorical_fields, list): + categorical_fields = [categorical_fields] + + if not isinstance(numerical_fields, list): + numerical_fields = [numerical_fields] + + mean, std, classes, codes, target_codes = cls.compute_state( + train_data_frame=train_data_frame, + val_data_frame=val_data_frame, + test_data_frame=test_data_frame, + predict_data_frame=predict_data_frame, + target_fields=target_fields, + numerical_fields=numerical_fields, + categorical_fields=categorical_fields, + ) + + return cls.from_data_source( + "data_frame", + train_data_frame, + val_data_frame, + test_data_frame, + predict_data_frame, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + cat_cols=categorical_fields, + num_cols=numerical_fields, + target_col=target_fields, + mean=mean, + std=std, + codes=codes, + target_codes=target_codes, + classes=classes, + is_regression=is_regression, + **preprocess_kwargs, + ) + + @classmethod + def from_csv( + cls, + categorical_fields: Optional[Union[str, List[str]]], + numerical_fields: Optional[Union[str, List[str]]], + target_fields: Optional[str] = None, + train_file: Optional[str] = None, + val_file: Optional[str] = None, + test_file: Optional[str] = None, + predict_file: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + is_regression: bool = False, + **preprocess_kwargs: Any, + ) -> 'DataModule': + """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files. + + Args: + categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs. + numerical_fields: The field or fields (columns) in the CSV file containing numerical inputs. + target_fields: The field or fields (columns) in the CSV file to use for the target. + train_file: The CSV file containing the training data. + val_file: The CSV file containing the validation data. + test_file: The CSV file containing the testing data. + predict_file: The CSV file containing the data to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + is_regression: If ``True``, targets will be formatted as floating point. If ``False``, targets will be + formatted as integers. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = TabularData.from_csv( + "categorical_input", + "numerical_input", + "target", + train_file="train_data.csv", + ) + """ + return cls.from_data_frame( + categorical_fields=categorical_fields, + numerical_fields=numerical_fields, + target_fields=target_fields, + train_data_frame=pd.read_csv(train_file) if train_file is not None else None, + val_data_frame=pd.read_csv(val_file) if val_file is not None else None, + test_data_frame=pd.read_csv(test_file) if test_file is not None else None, + predict_data_frame=pd.read_csv(predict_file) if predict_file is not None else None, + is_regression=is_regression, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + ) diff --git a/flash/tabular/forecasting/__init__.py b/flash/tabular/forecasting/__init__.py new file mode 100644 index 0000000000..d1036c9c35 --- /dev/null +++ b/flash/tabular/forecasting/__init__.py @@ -0,0 +1,5 @@ +from flash.tabular.forecasting.data import ( + TabularForecastingData, + TabularForecastingPreprocess, + TabularForecastingDataFrameDataSource +) \ No newline at end of file diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py new file mode 100644 index 0000000000..068c2665af --- /dev/null +++ b/flash/tabular/forecasting/data.py @@ -0,0 +1,210 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence, Mapping + +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.process import Deserializer, Postprocess, Preprocess +from flash.core.data.properties import ProcessState +from flash.core.utilities.imports import _PANDAS_AVAILABLE, _FORECASTING_AVAILABLE, requires_extras + +if _PANDAS_AVAILABLE: + from pandas.core.frame import DataFrame +else: + DataFrame = object + +if _FORECASTING_AVAILABLE: + from pytorch_forecasting import TimeSeriesDataSet + + +@dataclass(unsafe_hash=True, frozen=True) +class TimeSeriesDataSetState(ProcessState): + """ + A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, + a mapping from class index to label. + """ + + time_series_dataset: Optional[TimeSeriesDataSet] + + +class TabularForecastingDataFrameDataSource(DataSource[DataFrame]): + + @requires_extras("tabular") + def __init__( + self, + time_idx: str, + target: Union[str, List[str]], + group_ids: List[str], + **data_source_kwargs: Any + ): + self.time_idx = time_idx + self.target = target + self.group_ids = group_ids + self.data_source_kwargs = data_source_kwargs + super().__init__() + + self.dataset = None + + def load_data(self, data: DataFrame, dataset: Optional[Any] = None): + if self.training: + dataset.time_series_dataset = TimeSeriesDataSet( + data, time_idx=self.time_idx, group_ids=self.group_ids, target=self.target, **self.data_source_kwargs + ) + self.set_state(TimeSeriesDataSetState(dataset.time_series_dataset)) + return dataset.time_series_dataset + else: + train_time_series_dataset = self.get_state(TimeSeriesDataSetState).time_series_dataset + eval_time_series_dataset = TimeSeriesDataSet.from_dataset( + train_time_series_dataset, data, + min_prediction_idx=train_time_series_dataset.index.time.max() + 1, + stop_randomization=True + ) + return eval_time_series_dataset + + @staticmethod + def load_sample(sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: + return {DefaultDataKeys.INPUT: sample[0], DefaultDataKeys.TARGET: sample[1]} + + +class TabularForecastingPreprocess(Preprocess): + + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + deserializer: Optional[Deserializer] = None, + **data_source_kwargs: Any + ): + self.data_source_kwargs = data_source_kwargs + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.DATAFRAME: TabularForecastingDataFrameDataSource( + **data_source_kwargs + ), + }, + deserializer=deserializer + ) + + def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: + return { + **self.transforms, + **self.data_source_kwargs + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> 'Preprocess': + return cls(**state_dict) + + +class TabularForecastingData(DataModule): + """Data module for tabular tasks""" + + preprocess_cls = TabularForecastingPreprocess + + @property + def data_source(self) -> DataSource: + return self._data_source + + @classmethod + def from_data_frame( + cls, + time_idx: str, + target: Union[str, List[str]], + group_ids: List[str], + train_data_frame: Optional[DataFrame] = None, + val_data_frame: Optional[DataFrame] = None, + test_data_frame: Optional[DataFrame] = None, + predict_data_frame: Optional[DataFrame] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. + + Args: + group_ids: + target: + time_idx: + train_data_frame: The pandas ``DataFrame`` containing the training data. + val_data_frame: The pandas ``DataFrame`` containing the validation data. + test_data_frame: The pandas ``DataFrame`` containing the testing data. + predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = TabularData.from_data_frame( + "categorical_input", + "numerical_input", + "target", + train_data_frame=train_data, + ) + + """ + + return cls.from_data_source( + time_idx=time_idx, + target=target, + group_ids=group_ids, + data_source=DefaultDataSources.DATAFRAME, + train_data=train_data_frame, + val_data=val_data_frame, + test_data=test_data_frame, + predict_data=predict_data_frame, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) diff --git a/requirements/datatype_tabular.txt b/requirements/datatype_tabular.txt index 521a5930a3..bbd5720096 100644 --- a/requirements/datatype_tabular.txt +++ b/requirements/datatype_tabular.txt @@ -1,2 +1,3 @@ pytorch-tabnet==3.1 scikit-learn +pytorch-forecasting From 42aa6ce1306c163961d19dc14be7b333d4b398a7 Mon Sep 17 00:00:00 2001 From: Suman Michael Date: Fri, 16 Jul 2021 19:45:24 +0530 Subject: [PATCH 03/27] added TABULAR_FORECASTING_BACKBONES --- flash/tabular/forecasting/__init__.py | 3 ++- flash/tabular/forecasting/backbone.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 flash/tabular/forecasting/backbone.py diff --git a/flash/tabular/forecasting/__init__.py b/flash/tabular/forecasting/__init__.py index d1036c9c35..22806630b9 100644 --- a/flash/tabular/forecasting/__init__.py +++ b/flash/tabular/forecasting/__init__.py @@ -2,4 +2,5 @@ TabularForecastingData, TabularForecastingPreprocess, TabularForecastingDataFrameDataSource -) \ No newline at end of file +) +from flash.tabular.forecasting.backbone import TABULAR_FORECASTING_BACKBONES diff --git a/flash/tabular/forecasting/backbone.py b/flash/tabular/forecasting/backbone.py new file mode 100644 index 0000000000..93a3c31718 --- /dev/null +++ b/flash/tabular/forecasting/backbone.py @@ -0,0 +1,14 @@ +from pytorch_forecasting import TemporalFusionTransformer + +from flash.core.registry import FlashRegistry +from flash.tabular.forecasting import TabularForecastingData + +TABULAR_FORECASTING_BACKBONES = FlashRegistry("backbones") + + +@TABULAR_FORECASTING_BACKBONES(name="temporal_fusion_transformer", namespace="tabular/forecasting") +def temporal_fusion_transformer(tabular_forecasting_data: TabularForecastingData, **kwargs): + return TemporalFusionTransformer.from_dataset( + tabular_forecasting_data.train_dataset.time_series_dataset, + **kwargs + ) From 00b43aac8f279a765738355b7bb65bc72f25838e Mon Sep 17 00:00:00 2001 From: Suman Michael Date: Tue, 10 Aug 2021 11:19:47 +0530 Subject: [PATCH 04/27] [WIP] added model.py in tabular forecasting --- flash/tabular/forecasting/model.py | 91 ++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 flash/tabular/forecasting/model.py diff --git a/flash/tabular/forecasting/model.py b/flash/tabular/forecasting/model.py new file mode 100644 index 0000000000..bd015f0642 --- /dev/null +++ b/flash/tabular/forecasting/model.py @@ -0,0 +1,91 @@ +from typing import Union, Optional, Tuple, Dict, Callable, Type, Any, Mapping, Sequence, List + +from pytorch_forecasting import BaseModel, QuantileLoss, SMAPE +from torch.optim import Optimizer + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.process import Deserializer, Postprocess +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _TORCH_AVAILABLE + +if _TORCH_AVAILABLE: + import torch + import torchmetrics + from torch import nn + from torch.optim.lr_scheduler import _LRScheduler +else: + _LRScheduler = object + +from flash import Task, Serializer, Preprocess +from flash.tabular.forecasting import ( + TabularForecastingData, + TABULAR_FORECASTING_BACKBONES +) + + +class TabularForecaster(Task): + backbones: FlashRegistry = TABULAR_FORECASTING_BACKBONES + + def __init__( + self, + tabular_forecasting_data: TabularForecastingData, + backbone: Union[str, Tuple[nn.Module, int]] = "temporal_fusion_transformer", + backbone_kwargs: Optional[Dict] = None, + loss_fn: Optional[Callable] = None, + optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer, str] = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + learning_rate: float = 1e-2, + **task_kwargs + ): + + super().__init__( + model=None, + loss_fn=None, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, + metrics=metrics, + learning_rate=learning_rate, + **task_kwargs + ) + + self.save_hyperparameters() + + if not backbone_kwargs: + backbone_kwargs = {} + + if isinstance(backbone, tuple): + self.backbone = backbone + else: + self.backbone = self.backbones.get(backbone)( + tabular_forecasting_data=tabular_forecasting_data, + **backbone_kwargs + ) + self.model = self.backbone + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return self.model.training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return self.model.validation_step(batch, batch_idx) + + def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: + return self.model.configure_optimizers() + + + + # More hooks to map + + @classmethod + def from_data(cls, tabular_forecasting_data: TabularForecastingData, **kwargs): + return cls( + tabular_forecasting_data=tabular_forecasting_data, + backbone_kwargs={"loss": SMAPE()}, + **kwargs + ) From c3c4282a9481ee1fd4f9ecf14979061db085493b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Aug 2021 09:59:12 +0000 Subject: [PATCH 05/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/tabular/__init__.py | 4 +- flash/tabular/forecasting/__init__.py | 4 +- flash/tabular/forecasting/backbone.py | 5 +- flash/tabular/forecasting/data.py | 93 +++++++++++---------------- flash/tabular/forecasting/model.py | 44 +++++-------- 5 files changed, 61 insertions(+), 89 deletions(-) diff --git a/flash/tabular/__init__.py b/flash/tabular/__init__.py index 7127e23641..a3597f8b0f 100644 --- a/flash/tabular/__init__.py +++ b/flash/tabular/__init__.py @@ -1,6 +1,6 @@ from flash.tabular.classification import TabularClassifier, TabularData # noqa: F401 from flash.tabular.forecasting import ( TabularForecastingData, + TabularForecastingDataFrameDataSource, TabularForecastingPreprocess, - TabularForecastingDataFrameDataSource -) \ No newline at end of file +) diff --git a/flash/tabular/forecasting/__init__.py b/flash/tabular/forecasting/__init__.py index 22806630b9..4bc22bb33e 100644 --- a/flash/tabular/forecasting/__init__.py +++ b/flash/tabular/forecasting/__init__.py @@ -1,6 +1,6 @@ +from flash.tabular.forecasting.backbone import TABULAR_FORECASTING_BACKBONES from flash.tabular.forecasting.data import ( TabularForecastingData, + TabularForecastingDataFrameDataSource, TabularForecastingPreprocess, - TabularForecastingDataFrameDataSource ) -from flash.tabular.forecasting.backbone import TABULAR_FORECASTING_BACKBONES diff --git a/flash/tabular/forecasting/backbone.py b/flash/tabular/forecasting/backbone.py index 93a3c31718..5e9407bcec 100644 --- a/flash/tabular/forecasting/backbone.py +++ b/flash/tabular/forecasting/backbone.py @@ -8,7 +8,4 @@ @TABULAR_FORECASTING_BACKBONES(name="temporal_fusion_transformer", namespace="tabular/forecasting") def temporal_fusion_transformer(tabular_forecasting_data: TabularForecastingData, **kwargs): - return TemporalFusionTransformer.from_dataset( - tabular_forecasting_data.train_dataset.time_series_dataset, - **kwargs - ) + return TemporalFusionTransformer.from_dataset(tabular_forecasting_data.train_dataset.time_series_dataset, **kwargs) diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index 068c2665af..ae0f729a73 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence, Mapping +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources from flash.core.data.process import Deserializer, Postprocess, Preprocess from flash.core.data.properties import ProcessState -from flash.core.utilities.imports import _PANDAS_AVAILABLE, _FORECASTING_AVAILABLE, requires_extras +from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE, requires_extras if _PANDAS_AVAILABLE: from pandas.core.frame import DataFrame @@ -32,24 +32,15 @@ @dataclass(unsafe_hash=True, frozen=True) class TimeSeriesDataSetState(ProcessState): - """ - A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, - a mapping from class index to label. - """ + """A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, a mapping from class index to + label.""" time_series_dataset: Optional[TimeSeriesDataSet] class TabularForecastingDataFrameDataSource(DataSource[DataFrame]): - @requires_extras("tabular") - def __init__( - self, - time_idx: str, - target: Union[str, List[str]], - group_ids: List[str], - **data_source_kwargs: Any - ): + def __init__(self, time_idx: str, target: Union[str, List[str]], group_ids: List[str], **data_source_kwargs: Any): self.time_idx = time_idx self.target = target self.group_ids = group_ids @@ -68,9 +59,10 @@ def load_data(self, data: DataFrame, dataset: Optional[Any] = None): else: train_time_series_dataset = self.get_state(TimeSeriesDataSetState).time_series_dataset eval_time_series_dataset = TimeSeriesDataSet.from_dataset( - train_time_series_dataset, data, - min_prediction_idx=train_time_series_dataset.index.time.max() + 1, - stop_randomization=True + train_time_series_dataset, + data, + min_prediction_idx=train_time_series_dataset.index.time.max() + 1, + stop_randomization=True, ) return eval_time_series_dataset @@ -80,15 +72,14 @@ def load_sample(sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any class TabularForecastingPreprocess(Preprocess): - def __init__( - self, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, - deserializer: Optional[Deserializer] = None, - **data_source_kwargs: Any + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + deserializer: Optional[Deserializer] = None, + **data_source_kwargs: Any, ): self.data_source_kwargs = data_source_kwargs super().__init__( @@ -97,26 +88,21 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ - DefaultDataSources.DATAFRAME: TabularForecastingDataFrameDataSource( - **data_source_kwargs - ), + DefaultDataSources.DATAFRAME: TabularForecastingDataFrameDataSource(**data_source_kwargs), }, - deserializer=deserializer + deserializer=deserializer, ) def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: - return { - **self.transforms, - **self.data_source_kwargs - } + return {**self.transforms, **self.data_source_kwargs} @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> 'Preprocess': + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> "Preprocess": return cls(**state_dict) class TabularForecastingData(DataModule): - """Data module for tabular tasks""" + """Data module for tabular tasks.""" preprocess_cls = TabularForecastingPreprocess @@ -126,24 +112,24 @@ def data_source(self) -> DataSource: @classmethod def from_data_frame( - cls, - time_idx: str, - target: Union[str, List[str]], - group_ids: List[str], - train_data_frame: Optional[DataFrame] = None, - val_data_frame: Optional[DataFrame] = None, - test_data_frame: Optional[DataFrame] = None, - predict_data_frame: Optional[DataFrame] = None, - train_transform: Optional[Dict[str, Callable]] = None, - val_transform: Optional[Dict[str, Callable]] = None, - test_transform: Optional[Dict[str, Callable]] = None, - predict_transform: Optional[Dict[str, Callable]] = None, - data_fetcher: Optional[BaseDataFetcher] = None, - preprocess: Optional[Preprocess] = None, - val_split: Optional[float] = None, - batch_size: int = 4, - num_workers: Optional[int] = None, - **preprocess_kwargs: Any, + cls, + time_idx: str, + target: Union[str, List[str]], + group_ids: List[str], + train_data_frame: Optional[DataFrame] = None, + val_data_frame: Optional[DataFrame] = None, + test_data_frame: Optional[DataFrame] = None, + predict_data_frame: Optional[DataFrame] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, ): """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. @@ -185,7 +171,6 @@ def from_data_frame( "target", train_data_frame=train_data, ) - """ return cls.from_data_source( diff --git a/flash/tabular/forecasting/model.py b/flash/tabular/forecasting/model.py index bd015f0642..2ab8540484 100644 --- a/flash/tabular/forecasting/model.py +++ b/flash/tabular/forecasting/model.py @@ -1,4 +1,4 @@ -from typing import Union, Optional, Tuple, Dict, Callable, Type, Any, Mapping, Sequence, List +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union from pytorch_forecasting import BaseModel, QuantileLoss, SMAPE from torch.optim import Optimizer @@ -16,29 +16,26 @@ else: _LRScheduler = object -from flash import Task, Serializer, Preprocess -from flash.tabular.forecasting import ( - TabularForecastingData, - TABULAR_FORECASTING_BACKBONES -) +from flash import Preprocess, Serializer, Task +from flash.tabular.forecasting import TABULAR_FORECASTING_BACKBONES, TabularForecastingData class TabularForecaster(Task): backbones: FlashRegistry = TABULAR_FORECASTING_BACKBONES def __init__( - self, - tabular_forecasting_data: TabularForecastingData, - backbone: Union[str, Tuple[nn.Module, int]] = "temporal_fusion_transformer", - backbone_kwargs: Optional[Dict] = None, - loss_fn: Optional[Callable] = None, - optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer, str] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, - metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, - learning_rate: float = 1e-2, - **task_kwargs + self, + tabular_forecasting_data: TabularForecastingData, + backbone: Union[str, Tuple[nn.Module, int]] = "temporal_fusion_transformer", + backbone_kwargs: Optional[Dict] = None, + loss_fn: Optional[Callable] = None, + optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer, str] = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + learning_rate: float = 1e-2, + **task_kwargs ): super().__init__( @@ -62,8 +59,7 @@ def __init__( self.backbone = backbone else: self.backbone = self.backbones.get(backbone)( - tabular_forecasting_data=tabular_forecasting_data, - **backbone_kwargs + tabular_forecasting_data=tabular_forecasting_data, **backbone_kwargs ) self.model = self.backbone @@ -78,14 +74,8 @@ def validation_step(self, batch: Any, batch_idx: int) -> Any: def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: return self.model.configure_optimizers() - - # More hooks to map @classmethod def from_data(cls, tabular_forecasting_data: TabularForecastingData, **kwargs): - return cls( - tabular_forecasting_data=tabular_forecasting_data, - backbone_kwargs={"loss": SMAPE()}, - **kwargs - ) + return cls(tabular_forecasting_data=tabular_forecasting_data, backbone_kwargs={"loss": SMAPE()}, **kwargs) From 75cc620fcdd2c73454db37cb91c38899dbdfe3d7 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 10 Aug 2021 11:07:01 +0100 Subject: [PATCH 06/27] Updates --- flash/tabular/forecasting/backbone.py | 20 ++++++----- flash/tabular/forecasting/model.py | 52 +++++++++++---------------- 2 files changed, 32 insertions(+), 40 deletions(-) diff --git a/flash/tabular/forecasting/backbone.py b/flash/tabular/forecasting/backbone.py index 93a3c31718..7b8c64739a 100644 --- a/flash/tabular/forecasting/backbone.py +++ b/flash/tabular/forecasting/backbone.py @@ -1,14 +1,18 @@ -from pytorch_forecasting import TemporalFusionTransformer - from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _FORECASTING_AVAILABLE from flash.tabular.forecasting import TabularForecastingData +if _FORECASTING_AVAILABLE: + from pytorch_forecasting import TemporalFusionTransformer + + TABULAR_FORECASTING_BACKBONES = FlashRegistry("backbones") -@TABULAR_FORECASTING_BACKBONES(name="temporal_fusion_transformer", namespace="tabular/forecasting") -def temporal_fusion_transformer(tabular_forecasting_data: TabularForecastingData, **kwargs): - return TemporalFusionTransformer.from_dataset( - tabular_forecasting_data.train_dataset.time_series_dataset, - **kwargs - ) +if _FORECASTING_AVAILABLE: + + @TABULAR_FORECASTING_BACKBONES(name="temporal_fusion_transformer", namespace="tabular/forecasting") + def temporal_fusion_transformer(tabular_forecasting_data: TabularForecastingData, **kwargs): + return TemporalFusionTransformer.from_dataset( + tabular_forecasting_data.train_dataset.time_series_dataset, **kwargs + ) diff --git a/flash/tabular/forecasting/model.py b/flash/tabular/forecasting/model.py index bd015f0642..7c5f7ed5aa 100644 --- a/flash/tabular/forecasting/model.py +++ b/flash/tabular/forecasting/model.py @@ -1,12 +1,12 @@ -from typing import Union, Optional, Tuple, Dict, Callable, Type, Any, Mapping, Sequence, List +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union -from pytorch_forecasting import BaseModel, QuantileLoss, SMAPE from torch.optim import Optimizer +from flash import Task from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.process import Deserializer, Postprocess from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _TORCH_AVAILABLE +from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _TORCH_AVAILABLE +from flash.tabular.forecasting import TABULAR_FORECASTING_BACKBONES, TabularForecastingData if _TORCH_AVAILABLE: import torch @@ -16,29 +16,26 @@ else: _LRScheduler = object -from flash import Task, Serializer, Preprocess -from flash.tabular.forecasting import ( - TabularForecastingData, - TABULAR_FORECASTING_BACKBONES -) +if _FORECASTING_AVAILABLE: + from pytorch_forecasting import SMAPE class TabularForecaster(Task): backbones: FlashRegistry = TABULAR_FORECASTING_BACKBONES def __init__( - self, - tabular_forecasting_data: TabularForecastingData, - backbone: Union[str, Tuple[nn.Module, int]] = "temporal_fusion_transformer", - backbone_kwargs: Optional[Dict] = None, - loss_fn: Optional[Callable] = None, - optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer, str] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, - metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, - learning_rate: float = 1e-2, - **task_kwargs + self, + tabular_forecasting_data: TabularForecastingData, + backbone: Union[str, Tuple[nn.Module, int]] = "temporal_fusion_transformer", + backbone_kwargs: Optional[Dict] = None, + loss_fn: Optional[Callable] = None, + optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer, str] = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + learning_rate: float = 1e-2, + **task_kwargs ): super().__init__( @@ -62,8 +59,7 @@ def __init__( self.backbone = backbone else: self.backbone = self.backbones.get(backbone)( - tabular_forecasting_data=tabular_forecasting_data, - **backbone_kwargs + tabular_forecasting_data=tabular_forecasting_data, **backbone_kwargs ) self.model = self.backbone @@ -78,14 +74,6 @@ def validation_step(self, batch: Any, batch_idx: int) -> Any: def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: return self.model.configure_optimizers() - - - # More hooks to map - @classmethod def from_data(cls, tabular_forecasting_data: TabularForecastingData, **kwargs): - return cls( - tabular_forecasting_data=tabular_forecasting_data, - backbone_kwargs={"loss": SMAPE()}, - **kwargs - ) + return cls(tabular_forecasting_data=tabular_forecasting_data, backbone_kwargs={"loss": SMAPE()}, **kwargs) From 5c554d42548185762a3ab28b14cc84c526c35144 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 10 Aug 2021 11:18:12 +0100 Subject: [PATCH 07/27] Updates --- flash/tabular/__init__.py | 2 +- flash/tabular/forecasting/backbone.py | 2 +- flash/tabular/forecasting/model.py | 3 ++- tests/tabular/classification/test_model.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/flash/tabular/__init__.py b/flash/tabular/__init__.py index a3597f8b0f..35d604f636 100644 --- a/flash/tabular/__init__.py +++ b/flash/tabular/__init__.py @@ -1,5 +1,5 @@ from flash.tabular.classification import TabularClassifier, TabularData # noqa: F401 -from flash.tabular.forecasting import ( +from flash.tabular.forecasting.data import ( # noqa: F401 TabularForecastingData, TabularForecastingDataFrameDataSource, TabularForecastingPreprocess, diff --git a/flash/tabular/forecasting/backbone.py b/flash/tabular/forecasting/backbone.py index 7b8c64739a..32858f392d 100644 --- a/flash/tabular/forecasting/backbone.py +++ b/flash/tabular/forecasting/backbone.py @@ -1,6 +1,6 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FORECASTING_AVAILABLE -from flash.tabular.forecasting import TabularForecastingData +from flash.tabular.forecasting.data import TabularForecastingData if _FORECASTING_AVAILABLE: from pytorch_forecasting import TemporalFusionTransformer diff --git a/flash/tabular/forecasting/model.py b/flash/tabular/forecasting/model.py index 7c5f7ed5aa..3a5e5132e1 100644 --- a/flash/tabular/forecasting/model.py +++ b/flash/tabular/forecasting/model.py @@ -6,7 +6,8 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _TORCH_AVAILABLE -from flash.tabular.forecasting import TABULAR_FORECASTING_BACKBONES, TabularForecastingData +from flash.tabular.forecasting.backbone import TABULAR_FORECASTING_BACKBONES +from flash.tabular.forecasting.data import TabularForecastingData if _TORCH_AVAILABLE: import torch diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index b3a0c803f9..c11ab8b065 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -21,8 +21,8 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _TABULAR_AVAILABLE -from flash.tabular import TabularClassifier from flash.tabular.classification.data import TabularData +from flash.tabular.classification.model import TabularClassifier from tests.helpers.utils import _SERVE_TESTING, _TABULAR_TESTING # ======== Mock functions ======== From 3f0225221a6b753a5d6465fd6b3c403513dbfe09 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 10 Aug 2021 11:23:37 +0100 Subject: [PATCH 08/27] Try fix --- flash/tabular/forecasting/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index ae0f729a73..4f7d1a8aa3 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Union from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources -from flash.core.data.process import Deserializer, Postprocess, Preprocess +from flash.core.data.process import Deserializer, Preprocess from flash.core.data.properties import ProcessState from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE, requires_extras @@ -35,7 +35,7 @@ class TimeSeriesDataSetState(ProcessState): """A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, a mapping from class index to label.""" - time_series_dataset: Optional[TimeSeriesDataSet] + time_series_dataset: Optional["TimeSeriesDataSet"] class TabularForecastingDataFrameDataSource(DataSource[DataFrame]): From f6ac5282f61065f193c58905d9c51b082ac6e7ad Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 10 Aug 2021 11:31:10 +0100 Subject: [PATCH 09/27] Updates --- flash/tabular/classification/cli.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/flash/tabular/classification/cli.py b/flash/tabular/classification/cli.py index 63eff2458f..f3d405bc39 100644 --- a/flash/tabular/classification/cli.py +++ b/flash/tabular/classification/cli.py @@ -15,7 +15,8 @@ from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI -from flash.tabular import TabularClassificationData, TabularClassifier +from flash.tabular.classification.data import TabularData +from flash.tabular.classification.model import TabularClassifier __all__ = ["tabular_classification"] @@ -24,10 +25,10 @@ def from_titanic( batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs, -) -> TabularClassificationData: +) -> TabularData: """Downloads and loads the Titanic data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data") - return TabularClassificationData.from_csv( + return TabularData.from_csv( ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], "Fare", target_fields="Survived", @@ -43,7 +44,7 @@ def tabular_classification(): """Classify tabular data.""" cli = FlashCLI( TabularClassifier, - TabularClassificationData, + TabularData, default_datamodule_builder=from_titanic, default_arguments={ "trainer.max_epochs": 3, From 3db1966db8e082d733f5c23b04a08811d11fc01b Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 10 Aug 2021 11:32:47 +0100 Subject: [PATCH 10/27] Rename to TabularClassificationData --- README.md | 4 ++-- docs/source/api/tabular.rst | 2 +- .../reference/tabular_classification.rst | 4 ++-- flash/tabular/__init__.py | 2 +- flash/tabular/classification/__init__.py | 2 +- flash/tabular/classification/cli.py | 8 ++++---- flash/tabular/classification/data.py | 10 +++++----- flash/tabular/forecasting/data.py | 4 ++-- flash_examples/tabular_classification.py | 4 ++-- flash_notebooks/tabular_classification.ipynb | 6 +++--- tests/tabular/classification/test_data.py | 18 +++++++++--------- .../test_data_model_integration.py | 4 ++-- tests/tabular/classification/test_model.py | 4 ++-- 13 files changed, 36 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 32bc103eb1..c822a7b716 100644 --- a/README.md +++ b/README.md @@ -272,13 +272,13 @@ To illustrate, say we want to build a model to predict if a passenger survived o from torchmetrics.classification import Accuracy, Precision, Recall import flash from flash.core.data.utils import download_data -from flash.tabular import TabularClassifier, TabularData +from flash.tabular import TabularClassifier, TabularClassificationData # 1. Download the data download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/") # 2. Load the data -datamodule = TabularData.from_csv( +datamodule = TabularClassificationData.from_csv( ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], "Fare", target_fields="Survived", diff --git a/docs/source/api/tabular.rst b/docs/source/api/tabular.rst index 0752a5ca52..0edf57e59f 100644 --- a/docs/source/api/tabular.rst +++ b/docs/source/api/tabular.rst @@ -38,7 +38,7 @@ __________________ :nosignatures: :template: classtemplate.rst - ~data.TabularData + ~data.TabularClassificationData ~data.TabularDataFrameDataSource ~data.TabularCSVDataSource ~data.TabularDeserializer diff --git a/docs/source/reference/tabular_classification.rst b/docs/source/reference/tabular_classification.rst index 6bb68ba585..c1356adc56 100644 --- a/docs/source/reference/tabular_classification.rst +++ b/docs/source/reference/tabular_classification.rst @@ -29,8 +29,8 @@ The data is provided in CSV files that look like this: 6,0,3,"Moran, Mr. James",male,,0,0,330877,8.4583,,Q ... -Once we've downloaded the data using :func:`~flash.core.data.download_data`, we can create the :class:`~flash.tabular.classification.data.TabularData` from our CSV files using the :func:`~flash.tabular.classification.data.TabularData.from_csv` method. -From :meth:`the API reference `, we need to provide: +Once we've downloaded the data using :func:`~flash.core.data.download_data`, we can create the :class:`~flash.tabular.classification.data.TabularClassificationData` from our CSV files using the :func:`~flash.tabular.classification.data.TabularClassificationData.from_csv` method. +From :meth:`the API reference `, we need to provide: * **cat_cols**- A list of the names of columns that contain categorical data (strings or integers). * **num_cols**- A list of the names of columns that contain numerical continuous data (floats). diff --git a/flash/tabular/__init__.py b/flash/tabular/__init__.py index 35d604f636..9b452478db 100644 --- a/flash/tabular/__init__.py +++ b/flash/tabular/__init__.py @@ -1,4 +1,4 @@ -from flash.tabular.classification import TabularClassifier, TabularData # noqa: F401 +from flash.tabular.classification import TabularClassificationData, TabularClassifier # noqa: F401 from flash.tabular.forecasting.data import ( # noqa: F401 TabularForecastingData, TabularForecastingDataFrameDataSource, diff --git a/flash/tabular/classification/__init__.py b/flash/tabular/classification/__init__.py index 45724db27b..6134277abf 100644 --- a/flash/tabular/classification/__init__.py +++ b/flash/tabular/classification/__init__.py @@ -1,2 +1,2 @@ -from flash.tabular.classification.data import TabularData # noqa: F401 +from flash.tabular.classification.data import TabularClassificationData # noqa: F401 from flash.tabular.classification.model import TabularClassifier # noqa: F401 diff --git a/flash/tabular/classification/cli.py b/flash/tabular/classification/cli.py index f3d405bc39..4abe38e6cc 100644 --- a/flash/tabular/classification/cli.py +++ b/flash/tabular/classification/cli.py @@ -15,7 +15,7 @@ from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI -from flash.tabular.classification.data import TabularData +from flash.tabular.classification.data import TabularClassificationData from flash.tabular.classification.model import TabularClassifier __all__ = ["tabular_classification"] @@ -25,10 +25,10 @@ def from_titanic( batch_size: int = 4, num_workers: Optional[int] = None, **preprocess_kwargs, -) -> TabularData: +) -> TabularClassificationData: """Downloads and loads the Titanic data set.""" download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data") - return TabularData.from_csv( + return TabularClassificationData.from_csv( ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], "Fare", target_fields="Survived", @@ -44,7 +44,7 @@ def tabular_classification(): """Classify tabular data.""" cli = FlashCLI( TabularClassifier, - TabularData, + TabularClassificationData, default_datamodule_builder=from_titanic, default_arguments={ "trainer.max_epochs": 3, diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py index db0a50ed46..c52b4635a1 100644 --- a/flash/tabular/classification/data.py +++ b/flash/tabular/classification/data.py @@ -236,7 +236,7 @@ def uncollate(self, batch: Any) -> Any: return batch -class TabularData(DataModule): +class TabularClassificationData(DataModule): """Data module for tabular tasks.""" preprocess_cls = TabularPreprocess @@ -343,7 +343,7 @@ def from_data_frame( is_regression: bool = False, **preprocess_kwargs: Any, ): - """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. + """Creates a :class:`~flash.tabular.data.TabularClassificationData` object from the given data frames. Args: categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs. @@ -379,7 +379,7 @@ def from_data_frame( Examples:: - data_module = TabularData.from_data_frame( + data_module = TabularClassificationData.from_data_frame( "categorical_input", "numerical_input", "target", @@ -453,7 +453,7 @@ def from_csv( is_regression: bool = False, **preprocess_kwargs: Any, ) -> "DataModule": - """Creates a :class:`~flash.tabular.data.TabularData` object from the given CSV files. + """Creates a :class:`~flash.tabular.data.TabularClassificationData` object from the given CSV files. Args: categorical_fields: The field or fields (columns) in the CSV file containing categorical inputs. @@ -489,7 +489,7 @@ def from_csv( Examples:: - data_module = TabularData.from_csv( + data_module = TabularClassificationData.from_csv( "categorical_input", "numerical_input", "target", diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index 4f7d1a8aa3..afd7d68e5f 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -131,7 +131,7 @@ def from_data_frame( num_workers: Optional[int] = None, **preprocess_kwargs: Any, ): - """Creates a :class:`~flash.tabular.data.TabularData` object from the given data frames. + """Creates a :class:`~flash.tabular.data.TabularClassificationData` object from the given data frames. Args: group_ids: @@ -165,7 +165,7 @@ def from_data_frame( Examples:: - data_module = TabularData.from_data_frame( + data_module = TabularClassificationData.from_data_frame( "categorical_input", "numerical_input", "target", diff --git a/flash_examples/tabular_classification.py b/flash_examples/tabular_classification.py index fa3a2cc23e..9e6b0ab049 100644 --- a/flash_examples/tabular_classification.py +++ b/flash_examples/tabular_classification.py @@ -13,12 +13,12 @@ # limitations under the License. import flash from flash.core.data.utils import download_data -from flash.tabular import TabularClassifier, TabularData +from flash.tabular import TabularClassificationData, TabularClassifier # 1. Create the DataModule download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data") -datamodule = TabularData.from_csv( +datamodule = TabularClassificationData.from_csv( ["Sex", "Age", "SibSp", "Parch", "Ticket", "Cabin", "Embarked"], "Fare", target_fields="Survived", diff --git a/flash_notebooks/tabular_classification.ipynb b/flash_notebooks/tabular_classification.ipynb index 3369d03b45..3cab1959f8 100644 --- a/flash_notebooks/tabular_classification.ipynb +++ b/flash_notebooks/tabular_classification.ipynb @@ -54,7 +54,7 @@ "\n", "import flash\n", "from flash.core.data.utils import download_data\n", - "from flash.tabular import TabularClassifier, TabularData" + "from flash.tabular import TabularClassifier, TabularClassificationData" ] }, { @@ -84,7 +84,7 @@ "### 2. Load the data\n", "Flash Tasks have built-in DataModules that you can use to organize your data. Pass in a train, validation and test folders and Flash will take care of the rest.\n", "\n", - "Creates a TabularData relies on [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html). " + "Creates a TabularClassificationData relies on [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html). " ] }, { @@ -94,7 +94,7 @@ "metadata": {}, "outputs": [], "source": [ - "datamodule = TabularData.from_csv(\n", + "datamodule = TabularClassificationData.from_csv(\n", " [\"Sex\", \"Age\", \"SibSp\", \"Parch\", \"Ticket\", \"Cabin\", \"Embarked\"],\n", " [\"Fare\"],\n", " target_fields=\"Survived\",\n", diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py index c1d210b4e8..9948fc6eab 100644 --- a/tests/tabular/classification/test_data.py +++ b/tests/tabular/classification/test_data.py @@ -23,7 +23,7 @@ if _PANDAS_AVAILABLE: import pandas as pd - from flash.tabular import TabularData + from flash.tabular import TabularClassificationData from flash.tabular.classification.utils import _categorize, _normalize TEST_DF_1 = pd.DataFrame( @@ -73,19 +73,19 @@ def test_embedding_sizes(): self.codes = {"category": [None, "a", "b", "c"]} self.cat_cols = ["category"] # use __get__ to test property with mocked self - es = TabularData.emb_sizes.__get__(self) # pylint: disable=E1101 + es = TabularClassificationData.emb_sizes.__get__(self) # pylint: disable=E1101 assert es == [(4, 16)] self.codes = {} self.cat_cols = [] # use __get__ to test property with mocked self - es = TabularData.emb_sizes.__get__(self) # pylint: disable=E1101 + es = TabularClassificationData.emb_sizes.__get__(self) # pylint: disable=E1101 assert es == [] self.codes = {"large": ["a"] * 100_000, "larger": ["b"] * 1_000_000} self.cat_cols = ["large", "larger"] # use __get__ to test property with mocked self - es = TabularData.emb_sizes.__get__(self) # pylint: disable=E1101 + es = TabularClassificationData.emb_sizes.__get__(self) # pylint: disable=E1101 assert es == [(100_000, 17), (1_000_000, 31)] @@ -94,7 +94,7 @@ def test_tabular_data(tmpdir): train_data_frame = TEST_DF_1.copy() val_data_frame = TEST_DF_2.copy() test_data_frame = TEST_DF_2.copy() - dm = TabularData.from_data_frame( + dm = TabularClassificationData.from_data_frame( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -122,7 +122,7 @@ def test_categorical_target(tmpdir): # change int label to string df["label"] = df["label"].astype(str) - dm = TabularData.from_data_frame( + dm = TabularClassificationData.from_data_frame( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -146,7 +146,7 @@ def test_from_data_frame(tmpdir): train_data_frame = TEST_DF_1.copy() val_data_frame = TEST_DF_2.copy() test_data_frame = TEST_DF_2.copy() - dm = TabularData.from_data_frame( + dm = TabularClassificationData.from_data_frame( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -173,7 +173,7 @@ def test_from_csv(tmpdir): TEST_DF_2.to_csv(val_csv) TEST_DF_2.to_csv(test_csv) - dm = TabularData.from_csv( + dm = TabularClassificationData.from_csv( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", @@ -196,7 +196,7 @@ def test_from_csv(tmpdir): def test_empty_inputs(): train_data_frame = TEST_DF_1.copy() with pytest.raises(RuntimeError): - TabularData.from_data_frame( + TabularClassificationData.from_data_frame( numerical_fields=None, categorical_fields=None, target_fields="label", diff --git a/tests/tabular/classification/test_data_model_integration.py b/tests/tabular/classification/test_data_model_integration.py index 75f44a70ea..3d4875f1dd 100644 --- a/tests/tabular/classification/test_data_model_integration.py +++ b/tests/tabular/classification/test_data_model_integration.py @@ -15,7 +15,7 @@ import pytorch_lightning as pl from flash.core.utilities.imports import _TABULAR_AVAILABLE -from flash.tabular import TabularClassifier, TabularData +from flash.tabular import TabularClassificationData, TabularClassifier from tests.helpers.utils import _TABULAR_TESTING if _TABULAR_AVAILABLE: @@ -37,7 +37,7 @@ def test_classification(tmpdir): train_data_frame = TEST_DF_1.copy() val_data_frame = TEST_DF_1.copy() test_data_frame = TEST_DF_1.copy() - data = TabularData.from_data_frame( + data = TabularClassificationData.from_data_frame( categorical_fields=["category"], numerical_fields=["scalar_a", "scalar_b"], target_fields="label", diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index c11ab8b065..2efe7c316e 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -21,7 +21,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _TABULAR_AVAILABLE -from flash.tabular.classification.data import TabularData +from flash.tabular.classification.data import TabularClassificationData from flash.tabular.classification.model import TabularClassifier from tests.helpers.utils import _SERVE_TESTING, _TABULAR_TESTING @@ -99,7 +99,7 @@ def test_jit(tmpdir): @mock.patch("flash._IS_TESTING", True) def test_serve(): train_data = {"num_col": [1.4, 2.5], "cat_col": ["positive", "negative"], "target": [1, 2]} - datamodule = TabularData.from_data_frame( + datamodule = TabularClassificationData.from_data_frame( "cat_col", "num_col", "target", From f2a8cc16ea720deb96b8f814dfe48d9793220fc9 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 10 Aug 2021 11:36:39 +0100 Subject: [PATCH 11/27] Updates --- flash/tabular/forecasting/__init__.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/flash/tabular/forecasting/__init__.py b/flash/tabular/forecasting/__init__.py index 4bc22bb33e..fb978e7a62 100644 --- a/flash/tabular/forecasting/__init__.py +++ b/flash/tabular/forecasting/__init__.py @@ -1,6 +1,2 @@ -from flash.tabular.forecasting.backbone import TABULAR_FORECASTING_BACKBONES -from flash.tabular.forecasting.data import ( - TabularForecastingData, - TabularForecastingDataFrameDataSource, - TabularForecastingPreprocess, -) +from flash.tabular.forecasting.data import TabularForecastingData # noqa: F401 +from flash.tabular.forecasting.model import TabularForecaster # noqa: F401 From e72d44126d1539a4b17aba56e3209a7bb812ccff Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 10 Aug 2021 11:40:52 +0100 Subject: [PATCH 12/27] Fix embedding sizes --- tests/tabular/classification/test_data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py index 9948fc6eab..b1e9ef3f25 100644 --- a/tests/tabular/classification/test_data.py +++ b/tests/tabular/classification/test_data.py @@ -73,19 +73,19 @@ def test_embedding_sizes(): self.codes = {"category": [None, "a", "b", "c"]} self.cat_cols = ["category"] # use __get__ to test property with mocked self - es = TabularClassificationData.emb_sizes.__get__(self) # pylint: disable=E1101 + es = TabularClassificationData.embedding_sizes.__get__(self) # pylint: disable=E1101 assert es == [(4, 16)] self.codes = {} self.cat_cols = [] # use __get__ to test property with mocked self - es = TabularClassificationData.emb_sizes.__get__(self) # pylint: disable=E1101 + es = TabularClassificationData.embedding_sizes.__get__(self) # pylint: disable=E1101 assert es == [] self.codes = {"large": ["a"] * 100_000, "larger": ["b"] * 1_000_000} self.cat_cols = ["large", "larger"] # use __get__ to test property with mocked self - es = TabularClassificationData.emb_sizes.__get__(self) # pylint: disable=E1101 + es = TabularClassificationData.embedding_sizes.__get__(self) # pylint: disable=E1101 assert es == [(100_000, 17), (1_000_000, 31)] From 739d9a8ef89b4bece0d9369de2f5ea20748d8334 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 10 Aug 2021 18:11:19 +0100 Subject: [PATCH 13/27] Fixes and add example --- flash/tabular/forecasting/backbone.py | 18 ----- flash/tabular/forecasting/backbones.py | 27 +++++++ flash/tabular/forecasting/data.py | 70 ++++++++++++++---- flash/tabular/forecasting/model.py | 58 ++++++--------- flash_examples/tabular_forecasting.py | 99 ++++++++++++++++++++++++++ 5 files changed, 206 insertions(+), 66 deletions(-) delete mode 100644 flash/tabular/forecasting/backbone.py create mode 100644 flash/tabular/forecasting/backbones.py create mode 100644 flash_examples/tabular_forecasting.py diff --git a/flash/tabular/forecasting/backbone.py b/flash/tabular/forecasting/backbone.py deleted file mode 100644 index 32858f392d..0000000000 --- a/flash/tabular/forecasting/backbone.py +++ /dev/null @@ -1,18 +0,0 @@ -from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _FORECASTING_AVAILABLE -from flash.tabular.forecasting.data import TabularForecastingData - -if _FORECASTING_AVAILABLE: - from pytorch_forecasting import TemporalFusionTransformer - - -TABULAR_FORECASTING_BACKBONES = FlashRegistry("backbones") - - -if _FORECASTING_AVAILABLE: - - @TABULAR_FORECASTING_BACKBONES(name="temporal_fusion_transformer", namespace="tabular/forecasting") - def temporal_fusion_transformer(tabular_forecasting_data: TabularForecastingData, **kwargs): - return TemporalFusionTransformer.from_dataset( - tabular_forecasting_data.train_dataset.time_series_dataset, **kwargs - ) diff --git a/flash/tabular/forecasting/backbones.py b/flash/tabular/forecasting/backbones.py new file mode 100644 index 0000000000..d50cea3568 --- /dev/null +++ b/flash/tabular/forecasting/backbones.py @@ -0,0 +1,27 @@ +import functools + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _FORECASTING_AVAILABLE +from flash.tabular.forecasting.data import TabularForecastingData + +if _FORECASTING_AVAILABLE: + from pytorch_forecasting import DecoderMLP, DeepAR, NBeats, RecurrentNetwork, TemporalFusionTransformer + + +TABULAR_FORECASTING_BACKBONES = FlashRegistry("backbones") + + +if _FORECASTING_AVAILABLE: + + def load_torch_forecasting(model, tabular_forecasting_data: TabularForecastingData, **kwargs): + return model.from_dataset(tabular_forecasting_data.train_dataset.time_series_dataset, **kwargs) + + for model, name in zip( + [TemporalFusionTransformer, NBeats, RecurrentNetwork, DeepAR, DecoderMLP], + ["temporal_fusion_transformer", "n_beats", "recurrent_network", "deep_ar", "decoder_mlp"], + ): + TABULAR_FORECASTING_BACKBONES( + functools.partial(load_torch_forecasting, model), + name=name, + namespace="tabular/forecasting", + ) diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index afd7d68e5f..32e39f6d36 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools from dataclasses import dataclass from typing import Any, Callable, Dict, List, Mapping, Optional, Union +from torch.utils.data import DataLoader, Dataset + from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources @@ -41,13 +44,11 @@ class TimeSeriesDataSetState(ProcessState): class TabularForecastingDataFrameDataSource(DataSource[DataFrame]): @requires_extras("tabular") def __init__(self, time_idx: str, target: Union[str, List[str]], group_ids: List[str], **data_source_kwargs: Any): + super().__init__() self.time_idx = time_idx self.target = target self.group_ids = group_ids self.data_source_kwargs = data_source_kwargs - super().__init__() - - self.dataset = None def load_data(self, data: DataFrame, dataset: Optional[Any] = None): if self.training: @@ -55,19 +56,17 @@ def load_data(self, data: DataFrame, dataset: Optional[Any] = None): data, time_idx=self.time_idx, group_ids=self.group_ids, target=self.target, **self.data_source_kwargs ) self.set_state(TimeSeriesDataSetState(dataset.time_series_dataset)) - return dataset.time_series_dataset else: train_time_series_dataset = self.get_state(TimeSeriesDataSetState).time_series_dataset - eval_time_series_dataset = TimeSeriesDataSet.from_dataset( + dataset.time_series_dataset = TimeSeriesDataSet.from_dataset( train_time_series_dataset, data, - min_prediction_idx=train_time_series_dataset.index.time.max() + 1, + predict=True, stop_randomization=True, ) - return eval_time_series_dataset + return dataset.time_series_dataset - @staticmethod - def load_sample(sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: + def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: return {DefaultDataKeys.INPUT: sample[0], DefaultDataKeys.TARGET: sample[1]} @@ -106,9 +105,56 @@ class TabularForecastingData(DataModule): preprocess_cls = TabularForecastingPreprocess - @property - def data_source(self) -> DataSource: - return self._data_source + @staticmethod + def _collate_fn(collate_fn, samples): + samples = [(sample[DefaultDataKeys.INPUT], sample[DefaultDataKeys.TARGET]) for sample in samples] + batch = collate_fn(samples) + return {DefaultDataKeys.INPUT: batch[0], DefaultDataKeys.TARGET: batch[1]} + + def _train_dataloader(self) -> DataLoader: + train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds + time_series_dataset: TimeSeriesDataSet = train_ds.time_series_dataset + result = time_series_dataset.to_dataloader( + train=True, + batch_size=self.batch_size, + ) + collate_fn = functools.partial(self._collate_fn, result.collate_fn) + batch_sampler = result.batch_sampler + return DataLoader( + train_ds, + collate_fn=collate_fn, + batch_sampler=batch_sampler, + num_workers=self.num_workers, + pin_memory=True, + ) + + def _eval_dataloader(self, dataset: Dataset) -> DataLoader: + time_series_dataset: TimeSeriesDataSet = dataset.time_series_dataset + result = time_series_dataset.to_dataloader( + train=False, + batch_size=self.batch_size, + ) + collate_fn = functools.partial(self._collate_fn, result.collate_fn) + batch_sampler = result.batch_sampler + return DataLoader( + dataset, + collate_fn=collate_fn, + batch_sampler=batch_sampler, + num_workers=self.num_workers, + pin_memory=True, + ) + + def _val_dataloader(self) -> DataLoader: + val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds + return self._eval_dataloader(val_ds) + + def _test_dataloader(self) -> DataLoader: + test_ds: Dataset = self._test_ds() if isinstance(self._test_ds, Callable) else self._test_ds + return self._eval_dataloader(test_ds) + + def _predict_dataloader(self) -> DataLoader: + predict_ds: Dataset = self._predict_ds() if isinstance(self._predict_ds, Callable) else self._predict_ds + return self._eval_dataloader(predict_ds) @classmethod def from_data_frame( diff --git a/flash/tabular/forecasting/model.py b/flash/tabular/forecasting/model.py index 3a5e5132e1..2c5c14fe21 100644 --- a/flash/tabular/forecasting/model.py +++ b/flash/tabular/forecasting/model.py @@ -1,25 +1,15 @@ -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, Union -from torch.optim import Optimizer +import torch +import torchmetrics +from torch.optim.lr_scheduler import _LRScheduler from flash import Task from flash.core.data.data_source import DefaultDataKeys from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _TORCH_AVAILABLE -from flash.tabular.forecasting.backbone import TABULAR_FORECASTING_BACKBONES +from flash.tabular.forecasting.backbones import TABULAR_FORECASTING_BACKBONES from flash.tabular.forecasting.data import TabularForecastingData -if _TORCH_AVAILABLE: - import torch - import torchmetrics - from torch import nn - from torch.optim.lr_scheduler import _LRScheduler -else: - _LRScheduler = object - -if _FORECASTING_AVAILABLE: - from pytorch_forecasting import SMAPE - class TabularForecaster(Task): backbones: FlashRegistry = TABULAR_FORECASTING_BACKBONES @@ -27,18 +17,16 @@ class TabularForecaster(Task): def __init__( self, tabular_forecasting_data: TabularForecastingData, - backbone: Union[str, Tuple[nn.Module, int]] = "temporal_fusion_transformer", - backbone_kwargs: Optional[Dict] = None, + backbone: str = "temporal_fusion_transformer", loss_fn: Optional[Callable] = None, optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer, str] = torch.optim.Adam, optimizer_kwargs: Optional[Dict[str, Any]] = None, scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, scheduler_kwargs: Optional[Dict[str, Any]] = None, - metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, - learning_rate: float = 1e-2, - **task_kwargs + metrics: Union[torchmetrics.Metric, List[torchmetrics.Metric]] = None, + learning_rate: float = 3e-2, + **backbone_kwargs ): - super().__init__( model=None, loss_fn=None, @@ -46,35 +34,33 @@ def __init__( optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, - metrics=metrics, + metrics=None, learning_rate=learning_rate, - **task_kwargs ) self.save_hyperparameters() + backbone_kwargs["loss"] = loss_fn + + if metrics is not None and not isinstance(metrics, list): + metrics = [metrics] + backbone_kwargs["logging_metrics"] = metrics + if not backbone_kwargs: backbone_kwargs = {} - if isinstance(backbone, tuple): - self.backbone = backbone - else: - self.backbone = self.backbones.get(backbone)( - tabular_forecasting_data=tabular_forecasting_data, **backbone_kwargs - ) - self.model = self.backbone + self.backbone = self.backbones.get(backbone)( + tabular_forecasting_data=tabular_forecasting_data, **backbone_kwargs + ) def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return self.model.training_step(batch, batch_idx) + return self.backbone.training_step(batch, batch_idx) def validation_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return self.model.validation_step(batch, batch_idx) - - def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: - return self.model.configure_optimizers() + return self.backbone.validation_step(batch, batch_idx) @classmethod def from_data(cls, tabular_forecasting_data: TabularForecastingData, **kwargs): - return cls(tabular_forecasting_data=tabular_forecasting_data, backbone_kwargs={"loss": SMAPE()}, **kwargs) + return cls(tabular_forecasting_data=tabular_forecasting_data, **kwargs) diff --git a/flash_examples/tabular_forecasting.py b/flash_examples/tabular_forecasting.py new file mode 100644 index 0000000000..1f5fe80af6 --- /dev/null +++ b/flash_examples/tabular_forecasting.py @@ -0,0 +1,99 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +from pytorch_forecasting.data import GroupNormalizer +from pytorch_forecasting.data.examples import get_stallion_data + +import flash +from flash.tabular.forecasting import TabularForecaster, TabularForecastingData + +data = get_stallion_data() + +# add time index +data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month +data["time_idx"] -= data["time_idx"].min() + +# add additional features +data["month"] = data.date.dt.month.astype(str).astype("category") # categories have be strings +data["log_volume"] = np.log(data.volume + 1e-8) +data["avg_volume_by_sku"] = data.groupby(["time_idx", "sku"], observed=True).volume.transform("mean") +data["avg_volume_by_agency"] = data.groupby(["time_idx", "agency"], observed=True).volume.transform("mean") + +# we want to encode special days as one variable and thus need to first reverse one-hot encoding +special_days = [ + "easter_day", + "good_friday", + "new_year", + "christmas", + "labor_day", + "independence_day", + "revolution_day_memorial", + "regional_games", + "fifa_u_17_world_cup", + "football_gold_cup", + "beer_capital", + "music_fest", +] +data[special_days] = data[special_days].apply(lambda x: x.map({0: "-", 1: x.name})).astype("category") +data.sample(10, random_state=521) + +max_prediction_length = 6 +max_encoder_length = 24 +training_cutoff = data["time_idx"].max() - max_prediction_length + +datamodule = TabularForecastingData.from_data_frame( + time_idx="time_idx", + target="volume", + group_ids=["agency", "sku"], + min_encoder_length=max_encoder_length // 2, # keep encoder length long (as it is in the validation set) + max_encoder_length=max_encoder_length, + min_prediction_length=1, + max_prediction_length=max_prediction_length, + static_categoricals=["agency", "sku"], + static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], + time_varying_known_categoricals=["special_days", "month"], + variable_groups={"special_days": special_days}, # group of categorical variables can be treated as one variable + time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"], + time_varying_unknown_categoricals=[], + time_varying_unknown_reals=[ + "volume", + "log_volume", + "industry_volume", + "soda_volume", + "avg_max_temp", + "avg_volume_by_agency", + "avg_volume_by_sku", + ], + target_normalizer=GroupNormalizer( + groups=["agency", "sku"], transformation="softplus" + ), # use softplus and normalize by group + add_relative_time_idx=True, + add_target_scales=True, + add_encoder_length=True, + train_data_frame=data[lambda x: x.time_idx <= training_cutoff], + val_data_frame=data, + batch_size=64, +) + +model = TabularForecaster.from_data( + datamodule, + hidden_size=16, + attention_head_size=1, + dropout=0.1, + hidden_continuous_size=8, + output_size=7, +) + +trainer = flash.Trainer(max_epochs=3, limit_train_batches=30) +trainer.fit(model, datamodule=datamodule) From a3aafd00bc0da5393f2931c5e022b76fdc312880 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 20 Sep 2021 22:15:40 +0100 Subject: [PATCH 14/27] Updates --- flash/tabular/forecasting/backbones.py | 14 +++-- flash/tabular/forecasting/data.py | 84 +++++++------------------- flash/tabular/forecasting/model.py | 35 +++++++++-- flash_examples/tabular_forecasting.py | 6 +- 4 files changed, 65 insertions(+), 74 deletions(-) diff --git a/flash/tabular/forecasting/backbones.py b/flash/tabular/forecasting/backbones.py index d50cea3568..a582ec271b 100644 --- a/flash/tabular/forecasting/backbones.py +++ b/flash/tabular/forecasting/backbones.py @@ -2,10 +2,16 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FORECASTING_AVAILABLE -from flash.tabular.forecasting.data import TabularForecastingData if _FORECASTING_AVAILABLE: - from pytorch_forecasting import DecoderMLP, DeepAR, NBeats, RecurrentNetwork, TemporalFusionTransformer + from pytorch_forecasting import ( + DecoderMLP, + DeepAR, + NBeats, + RecurrentNetwork, + TemporalFusionTransformer, + TimeSeriesDataSet, + ) TABULAR_FORECASTING_BACKBONES = FlashRegistry("backbones") @@ -13,8 +19,8 @@ if _FORECASTING_AVAILABLE: - def load_torch_forecasting(model, tabular_forecasting_data: TabularForecastingData, **kwargs): - return model.from_dataset(tabular_forecasting_data.train_dataset.time_series_dataset, **kwargs) + def load_torch_forecasting(model, time_series_dataset: TimeSeriesDataSet, **kwargs): + return model.from_dataset(time_series_dataset, **kwargs) for model, name in zip( [TemporalFusionTransformer, NBeats, RecurrentNetwork, DeepAR, DecoderMLP], diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index 32e39f6d36..eb9f312892 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -11,18 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools +from copy import copy from dataclasses import dataclass from typing import Any, Callable, Dict, List, Mapping, Optional, Union -from torch.utils.data import DataLoader, Dataset - from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources from flash.core.data.process import Deserializer, Preprocess from flash.core.data.properties import ProcessState -from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE, requires_extras +from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE, requires if _PANDAS_AVAILABLE: from pandas.core.frame import DataFrame @@ -34,15 +32,15 @@ @dataclass(unsafe_hash=True, frozen=True) -class TimeSeriesDataSetState(ProcessState): +class TimeSeriesDataSetParametersState(ProcessState): """A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, a mapping from class index to label.""" - time_series_dataset: Optional["TimeSeriesDataSet"] + time_series_dataset_parameters: Optional[Dict[str, Any]] class TabularForecastingDataFrameDataSource(DataSource[DataFrame]): - @requires_extras("tabular") + @requires("tabular") def __init__(self, time_idx: str, target: Union[str, List[str]], group_ids: List[str], **data_source_kwargs: Any): super().__init__() self.time_idx = time_idx @@ -52,19 +50,26 @@ def __init__(self, time_idx: str, target: Union[str, List[str]], group_ids: List def load_data(self, data: DataFrame, dataset: Optional[Any] = None): if self.training: - dataset.time_series_dataset = TimeSeriesDataSet( + time_series_dataset = TimeSeriesDataSet( data, time_idx=self.time_idx, group_ids=self.group_ids, target=self.target, **self.data_source_kwargs ) - self.set_state(TimeSeriesDataSetState(dataset.time_series_dataset)) + parameters = time_series_dataset.get_parameters() + self.set_state(TimeSeriesDataSetParametersState(parameters)) + + # Add some sample data so that we can recreate the `TimeSeriesDataSet` later on + parameters = copy(parameters) + parameters["data_sample"] = data.iloc[[0]] + dataset.parameters = parameters else: - train_time_series_dataset = self.get_state(TimeSeriesDataSetState).time_series_dataset - dataset.time_series_dataset = TimeSeriesDataSet.from_dataset( - train_time_series_dataset, + parameters = self.get_state(TimeSeriesDataSetParametersState).time_series_dataset_parameters + time_series_dataset = TimeSeriesDataSet.from_parameters( + parameters, data, predict=True, stop_randomization=True, ) - return dataset.time_series_dataset + dataset.time_series_dataset = time_series_dataset + return time_series_dataset def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: return {DefaultDataKeys.INPUT: sample[0], DefaultDataKeys.TARGET: sample[1]} @@ -105,56 +110,9 @@ class TabularForecastingData(DataModule): preprocess_cls = TabularForecastingPreprocess - @staticmethod - def _collate_fn(collate_fn, samples): - samples = [(sample[DefaultDataKeys.INPUT], sample[DefaultDataKeys.TARGET]) for sample in samples] - batch = collate_fn(samples) - return {DefaultDataKeys.INPUT: batch[0], DefaultDataKeys.TARGET: batch[1]} - - def _train_dataloader(self) -> DataLoader: - train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds - time_series_dataset: TimeSeriesDataSet = train_ds.time_series_dataset - result = time_series_dataset.to_dataloader( - train=True, - batch_size=self.batch_size, - ) - collate_fn = functools.partial(self._collate_fn, result.collate_fn) - batch_sampler = result.batch_sampler - return DataLoader( - train_ds, - collate_fn=collate_fn, - batch_sampler=batch_sampler, - num_workers=self.num_workers, - pin_memory=True, - ) - - def _eval_dataloader(self, dataset: Dataset) -> DataLoader: - time_series_dataset: TimeSeriesDataSet = dataset.time_series_dataset - result = time_series_dataset.to_dataloader( - train=False, - batch_size=self.batch_size, - ) - collate_fn = functools.partial(self._collate_fn, result.collate_fn) - batch_sampler = result.batch_sampler - return DataLoader( - dataset, - collate_fn=collate_fn, - batch_sampler=batch_sampler, - num_workers=self.num_workers, - pin_memory=True, - ) - - def _val_dataloader(self) -> DataLoader: - val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds - return self._eval_dataloader(val_ds) - - def _test_dataloader(self) -> DataLoader: - test_ds: Dataset = self._test_ds() if isinstance(self._test_ds, Callable) else self._test_ds - return self._eval_dataloader(test_ds) - - def _predict_dataloader(self) -> DataLoader: - predict_ds: Dataset = self._predict_ds() if isinstance(self._predict_ds, Callable) else self._predict_ds - return self._eval_dataloader(predict_ds) + @property + def parameters(self) -> Dict[str, Any]: + return self.train_dataset.parameters @classmethod def from_data_frame( diff --git a/flash/tabular/forecasting/model.py b/flash/tabular/forecasting/model.py index 2c5c14fe21..f52007bc98 100644 --- a/flash/tabular/forecasting/model.py +++ b/flash/tabular/forecasting/model.py @@ -1,3 +1,5 @@ +from copy import copy +from functools import partial from typing import Any, Callable, Dict, List, Optional, Type, Union import torch @@ -6,17 +8,32 @@ from flash import Task from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.states import CollateFn from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE from flash.tabular.forecasting.backbones import TABULAR_FORECASTING_BACKBONES from flash.tabular.forecasting.data import TabularForecastingData +if _PANDAS_AVAILABLE: + from pandas.core.frame import DataFrame + +if _FORECASTING_AVAILABLE: + from pytorch_forecasting import TimeSeriesDataSet + + +class PatchTimeSeriesDataSet(TimeSeriesDataSet): + """Hack to prevent index construction when instantiating model.""" + + def _construct_index(self, data: DataFrame, predict_mode: bool) -> DataFrame: + return DataFrame() + class TabularForecaster(Task): backbones: FlashRegistry = TABULAR_FORECASTING_BACKBONES def __init__( self, - tabular_forecasting_data: TabularForecastingData, + parameters: Dict[str, Any], backbone: str = "temporal_fusion_transformer", loss_fn: Optional[Callable] = None, optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer, str] = torch.optim.Adam, @@ -40,6 +57,10 @@ def __init__( self.save_hyperparameters() + parameters = copy(parameters) + data = parameters.pop("data_sample") + time_series_dataset = PatchTimeSeriesDataSet.from_parameters(parameters, data) + backbone_kwargs["loss"] = loss_fn if metrics is not None and not isinstance(metrics, list): @@ -49,9 +70,9 @@ def __init__( if not backbone_kwargs: backbone_kwargs = {} - self.backbone = self.backbones.get(backbone)( - tabular_forecasting_data=tabular_forecasting_data, **backbone_kwargs - ) + self.backbone = self.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs) + + self.set_state(CollateFn(partial(TabularForecaster._collate_fn, time_series_dataset._collate_fn))) def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) @@ -64,3 +85,9 @@ def validation_step(self, batch: Any, batch_idx: int) -> Any: @classmethod def from_data(cls, tabular_forecasting_data: TabularForecastingData, **kwargs): return cls(tabular_forecasting_data=tabular_forecasting_data, **kwargs) + + @staticmethod + def _collate_fn(collate_fn, samples): + samples = [(sample[DefaultDataKeys.INPUT], sample[DefaultDataKeys.TARGET]) for sample in samples] + batch = collate_fn(samples) + return {DefaultDataKeys.INPUT: batch[0], DefaultDataKeys.TARGET: batch[1]} diff --git a/flash_examples/tabular_forecasting.py b/flash_examples/tabular_forecasting.py index 1f5fe80af6..1beb85361c 100644 --- a/flash_examples/tabular_forecasting.py +++ b/flash_examples/tabular_forecasting.py @@ -86,8 +86,8 @@ batch_size=64, ) -model = TabularForecaster.from_data( - datamodule, +model = TabularForecaster( + datamodule.parameters, hidden_size=16, attention_head_size=1, dropout=0.1, @@ -95,5 +95,5 @@ output_size=7, ) -trainer = flash.Trainer(max_epochs=3, limit_train_batches=30) +trainer = flash.Trainer(max_epochs=30, gradient_clip_val=0.1) trainer.fit(model, datamodule=datamodule) From 157ef3fd87d9e766b0c3b8ae611e3da87e657628 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 22 Sep 2021 22:35:53 +0100 Subject: [PATCH 15/27] Switch to an adapter --- flash/tabular/forecasting/adapters.py | 106 +++++++++++++++++++++++++ flash/tabular/forecasting/backbones.py | 15 ++++ flash/tabular/forecasting/model.py | 92 +++++++-------------- 3 files changed, 151 insertions(+), 62 deletions(-) create mode 100644 flash/tabular/forecasting/adapters.py diff --git a/flash/tabular/forecasting/adapters.py b/flash/tabular/forecasting/adapters.py new file mode 100644 index 0000000000..fcf2efe7c3 --- /dev/null +++ b/flash/tabular/forecasting/adapters.py @@ -0,0 +1,106 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from copy import copy +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Union + +import torchmetrics + +from flash.core.adapter import Adapter +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.states import CollateFn +from flash.core.model import Task +from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE + +if _PANDAS_AVAILABLE: + from pandas.core.frame import DataFrame + +if _FORECASTING_AVAILABLE: + from pytorch_forecasting import TimeSeriesDataSet + + +class PatchTimeSeriesDataSet(TimeSeriesDataSet): + """Hack to prevent index construction when instantiating model.""" + + def _construct_index(self, data: DataFrame, predict_mode: bool) -> DataFrame: + return DataFrame() + + +class PyTorchForecastingAdapter(Adapter): + """The ``PyTorchForecastingAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with PyTorch + Forecasting.""" + + @staticmethod + def _collate_fn(collate_fn, samples): + samples = [(sample[DefaultDataKeys.INPUT], sample[DefaultDataKeys.TARGET]) for sample in samples] + batch = collate_fn(samples) + return {DefaultDataKeys.INPUT: batch[0], DefaultDataKeys.TARGET: batch[1]} + + def __init__(self, backbone, collate_fn): + super().__init__() + + self.backbone = backbone + + self.set_state(CollateFn(partial(PyTorchForecastingAdapter._collate_fn, collate_fn))) + + @classmethod + def from_task( + cls, + task: Task, + parameters: Dict[str, Any], + backbone: str, + loss_fn: Optional[Callable] = None, + metrics: Optional[Union[torchmetrics.Metric, List[torchmetrics.Metric]]] = None, + **backbone_kwargs, + ) -> Adapter: + parameters = copy(parameters) + data = parameters.pop("data_sample") + time_series_dataset = PatchTimeSeriesDataSet.from_parameters(parameters, data) + + backbone_kwargs["loss"] = loss_fn + + if metrics is not None and not isinstance(metrics, list): + metrics = [metrics] + backbone_kwargs["logging_metrics"] = metrics + + if not backbone_kwargs: + backbone_kwargs = {} + + return cls( + task.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs), + time_series_dataset._collate_fn, + ) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return self.backbone.training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return self.backbone.validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> None: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + # PyTorch Forecasting models don't have a `test_step`, so re-use `validation_step` + return self.backbone.validation_step(batch, batch_idx) + + def training_epoch_end(self, outputs) -> None: + self.backbone.training_epoch_end(outputs) + + def validation_epoch_end(self, outputs) -> None: + self.backbone.validation_epoch_end(outputs) + + def test_epoch_end(self, outputs) -> None: + # PyTorch Forecasting models don't have a `test_epoch_end`, so re-use `validation_epoch_end` + self.backbone.validation_epoch_end(outputs) diff --git a/flash/tabular/forecasting/backbones.py b/flash/tabular/forecasting/backbones.py index a582ec271b..e278786932 100644 --- a/flash/tabular/forecasting/backbones.py +++ b/flash/tabular/forecasting/backbones.py @@ -1,7 +1,21 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import functools from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FORECASTING_AVAILABLE +from flash.tabular.forecasting.adapters import PyTorchForecastingAdapter if _FORECASTING_AVAILABLE: from pytorch_forecasting import ( @@ -30,4 +44,5 @@ def load_torch_forecasting(model, time_series_dataset: TimeSeriesDataSet, **kwar functools.partial(load_torch_forecasting, model), name=name, namespace="tabular/forecasting", + adapter=PyTorchForecastingAdapter, ) diff --git a/flash/tabular/forecasting/model.py b/flash/tabular/forecasting/model.py index f52007bc98..0b87faf450 100644 --- a/flash/tabular/forecasting/model.py +++ b/flash/tabular/forecasting/model.py @@ -1,34 +1,28 @@ -from copy import copy -from functools import partial +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, Callable, Dict, List, Optional, Type, Union import torch import torchmetrics from torch.optim.lr_scheduler import _LRScheduler -from flash import Task -from flash.core.data.data_source import DefaultDataKeys -from flash.core.data.states import CollateFn +from flash.core.adapter import AdapterTask from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE from flash.tabular.forecasting.backbones import TABULAR_FORECASTING_BACKBONES -from flash.tabular.forecasting.data import TabularForecastingData -if _PANDAS_AVAILABLE: - from pandas.core.frame import DataFrame -if _FORECASTING_AVAILABLE: - from pytorch_forecasting import TimeSeriesDataSet - - -class PatchTimeSeriesDataSet(TimeSeriesDataSet): - """Hack to prevent index construction when instantiating model.""" - - def _construct_index(self, data: DataFrame, predict_mode: bool) -> DataFrame: - return DataFrame() - - -class TabularForecaster(Task): +class TabularForecaster(AdapterTask): backbones: FlashRegistry = TABULAR_FORECASTING_BACKBONES def __init__( @@ -44,50 +38,24 @@ def __init__( learning_rate: float = 3e-2, **backbone_kwargs ): + + self.save_hyperparameters() + + metadata = self.backbones.get(backbone, with_metadata=True) + adapter = metadata["metadata"]["adapter"].from_task( + self, + parameters=parameters, + backbone=backbone, + loss_fn=loss_fn, + metrics=metrics, + **backbone_kwargs, + ) + super().__init__( - model=None, - loss_fn=None, + adapter, + learning_rate=learning_rate, optimizer=optimizer, optimizer_kwargs=optimizer_kwargs, scheduler=scheduler, scheduler_kwargs=scheduler_kwargs, - metrics=None, - learning_rate=learning_rate, ) - - self.save_hyperparameters() - - parameters = copy(parameters) - data = parameters.pop("data_sample") - time_series_dataset = PatchTimeSeriesDataSet.from_parameters(parameters, data) - - backbone_kwargs["loss"] = loss_fn - - if metrics is not None and not isinstance(metrics, list): - metrics = [metrics] - backbone_kwargs["logging_metrics"] = metrics - - if not backbone_kwargs: - backbone_kwargs = {} - - self.backbone = self.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs) - - self.set_state(CollateFn(partial(TabularForecaster._collate_fn, time_series_dataset._collate_fn))) - - def training_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return self.backbone.training_step(batch, batch_idx) - - def validation_step(self, batch: Any, batch_idx: int) -> Any: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - return self.backbone.validation_step(batch, batch_idx) - - @classmethod - def from_data(cls, tabular_forecasting_data: TabularForecastingData, **kwargs): - return cls(tabular_forecasting_data=tabular_forecasting_data, **kwargs) - - @staticmethod - def _collate_fn(collate_fn, samples): - samples = [(sample[DefaultDataKeys.INPUT], sample[DefaultDataKeys.TARGET]) for sample in samples] - batch = collate_fn(samples) - return {DefaultDataKeys.INPUT: batch[0], DefaultDataKeys.TARGET: batch[1]} From e7bca8e1ea063d49f134c0fb64d5df8f2ec2bb24 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 23 Sep 2021 12:20:07 +0100 Subject: [PATCH 16/27] Small fixes --- flash/tabular/forecasting/adapters.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/flash/tabular/forecasting/adapters.py b/flash/tabular/forecasting/adapters.py index fcf2efe7c3..f2891e6722 100644 --- a/flash/tabular/forecasting/adapters.py +++ b/flash/tabular/forecasting/adapters.py @@ -28,10 +28,15 @@ if _FORECASTING_AVAILABLE: from pytorch_forecasting import TimeSeriesDataSet +else: + TimeSeriesDataSet = object class PatchTimeSeriesDataSet(TimeSeriesDataSet): - """Hack to prevent index construction when instantiating model.""" + """Hack to prevent index construction when instantiating model. + + This enables the ``TimeSeriesDataSet`` to be created from a single row of data. + """ def _construct_index(self, data: DataFrame, predict_mode: bool) -> DataFrame: return DataFrame() @@ -41,12 +46,6 @@ class PyTorchForecastingAdapter(Adapter): """The ``PyTorchForecastingAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with PyTorch Forecasting.""" - @staticmethod - def _collate_fn(collate_fn, samples): - samples = [(sample[DefaultDataKeys.INPUT], sample[DefaultDataKeys.TARGET]) for sample in samples] - batch = collate_fn(samples) - return {DefaultDataKeys.INPUT: batch[0], DefaultDataKeys.TARGET: batch[1]} - def __init__(self, backbone, collate_fn): super().__init__() @@ -54,6 +53,12 @@ def __init__(self, backbone, collate_fn): self.set_state(CollateFn(partial(PyTorchForecastingAdapter._collate_fn, collate_fn))) + @staticmethod + def _collate_fn(collate_fn, samples): + samples = [(sample[DefaultDataKeys.INPUT], sample[DefaultDataKeys.TARGET]) for sample in samples] + batch = collate_fn(samples) + return {DefaultDataKeys.INPUT: batch[0], DefaultDataKeys.TARGET: batch[1]} + @classmethod def from_task( cls, From 86f3bf92def4e7ce988ac8b5be4a313c1e26aa60 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 28 Oct 2021 13:50:30 +0100 Subject: [PATCH 17/27] Add inference error --- flash/tabular/forecasting/adapters.py | 24 +++++++++++++++++------- flash/tabular/forecasting/model.py | 15 +++++---------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/flash/tabular/forecasting/adapters.py b/flash/tabular/forecasting/adapters.py index f2891e6722..b31be8bb84 100644 --- a/flash/tabular/forecasting/adapters.py +++ b/flash/tabular/forecasting/adapters.py @@ -46,13 +46,11 @@ class PyTorchForecastingAdapter(Adapter): """The ``PyTorchForecastingAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with PyTorch Forecasting.""" - def __init__(self, backbone, collate_fn): + def __init__(self, backbone): super().__init__() self.backbone = backbone - self.set_state(CollateFn(partial(PyTorchForecastingAdapter._collate_fn, collate_fn))) - @staticmethod def _collate_fn(collate_fn, samples): samples = [(sample[DefaultDataKeys.INPUT], sample[DefaultDataKeys.TARGET]) for sample in samples] @@ -82,10 +80,15 @@ def from_task( if not backbone_kwargs: backbone_kwargs = {} - return cls( - task.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs), - time_series_dataset._collate_fn, - ) + forecasting_model = task.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs) + + # Attach the required collate function + task.set_state(CollateFn(partial(PyTorchForecastingAdapter._collate_fn, time_series_dataset._collate_fn))) + + # Attach the `forecasting_model` attribute to expose the built-in inference methods from PyTorch Forecasting + task.forecasting_model = forecasting_model + + return cls(forecasting_model) def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) @@ -100,6 +103,13 @@ def test_step(self, batch: Any, batch_idx: int) -> None: # PyTorch Forecasting models don't have a `test_step`, so re-use `validation_step` return self.backbone.validation_step(batch, batch_idx) + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + raise NotImplementedError( + "Flash's inference is not currently supported with backbones provided by PyTorch Forecasting. You can " + "access the PyTorch Forecasting LightningModule directly with the `forecasting_model` attribute of the " + "`TabularForecaster`." + ) + def training_epoch_end(self, outputs) -> None: self.backbone.training_epoch_end(outputs) diff --git a/flash/tabular/forecasting/model.py b/flash/tabular/forecasting/model.py index 0b87faf450..3de2d92f63 100644 --- a/flash/tabular/forecasting/model.py +++ b/flash/tabular/forecasting/model.py @@ -11,14 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Union -import torch import torchmetrics -from torch.optim.lr_scheduler import _LRScheduler from flash.core.adapter import AdapterTask from flash.core.registry import FlashRegistry +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE from flash.tabular.forecasting.backbones import TABULAR_FORECASTING_BACKBONES @@ -30,10 +29,8 @@ def __init__( parameters: Dict[str, Any], backbone: str = "temporal_fusion_transformer", loss_fn: Optional[Callable] = None, - optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer, str] = torch.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, - scheduler_kwargs: Optional[Dict[str, Any]] = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: Union[torchmetrics.Metric, List[torchmetrics.Metric]] = None, learning_rate: float = 3e-2, **backbone_kwargs @@ -55,7 +52,5 @@ def __init__( adapter, learning_rate=learning_rate, optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + lr_scheduler=lr_scheduler, ) From c7967ca5c722b20c9579158693e4062e0ab1163f Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 28 Oct 2021 19:28:31 +0100 Subject: [PATCH 18/27] Add inference and refactor --- flash/core/data/data_pipeline.py | 19 +--- .../pytorch_forecasting/__init__.py | 1 + .../pytorch_forecasting/adapter.py} | 39 ++++---- .../pytorch_forecasting}/backbones.py | 6 +- .../pytorch_forecasting/transforms.py | 30 ++++++ flash/tabular/forecasting/data.py | 1 + flash/tabular/forecasting/model.py | 26 ++++- flash_examples/tabular_forecasting.py | 98 ++++++------------- 8 files changed, 111 insertions(+), 109 deletions(-) create mode 100644 flash/core/integrations/pytorch_forecasting/__init__.py rename flash/{tabular/forecasting/adapters.py => core/integrations/pytorch_forecasting/adapter.py} (74%) rename flash/{tabular/forecasting => core/integrations/pytorch_forecasting}/backbones.py (88%) create mode 100644 flash/core/integrations/pytorch_forecasting/transforms.py diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index f690f0a26c..fd5ee8ef33 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -19,7 +19,6 @@ import torch from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from torch.utils.data import DataLoader, IterableDataset @@ -42,29 +41,19 @@ class DataPipelineState: def __init__(self): self._state: Dict[Type[ProcessState], ProcessState] = {} - self._initialized = False def set_state(self, state: ProcessState): """Add the given :class:`.ProcessState` to the :class:`.DataPipelineState`.""" - if not self._initialized: - self._state[type(state)] = state - else: - rank_zero_warn( - f"Attempted to add a state ({state}) after the data pipeline has already been initialized. This will" - " only have an effect when a new data pipeline is created.", - UserWarning, - ) + self._state[type(state)] = state def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]: """Get the :class:`.ProcessState` of the given type from the :class:`.DataPipelineState`.""" - if state_type in self._state: - return self._state[state_type] - return None + return self._state.get(state_type, None) def __str__(self) -> str: - return f"{self.__class__.__name__}(initialized={self._initialized}, state={self._state})" + return f"{self.__class__.__name__}(state={self._state})" class DataPipeline: @@ -113,13 +102,11 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> :class:`.Postprocess`, and :class:`.Serializer`. Once this has been called, any attempt to add new state will give a warning.""" data_pipeline_state = data_pipeline_state or DataPipelineState() - data_pipeline_state._initialized = False if self.data_source is not None: self.data_source.attach_data_pipeline_state(data_pipeline_state) self._preprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) self._postprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) self._serializer.attach_data_pipeline_state(data_pipeline_state) - data_pipeline_state._initialized = True # TODO: Not sure we need this return data_pipeline_state @property diff --git a/flash/core/integrations/pytorch_forecasting/__init__.py b/flash/core/integrations/pytorch_forecasting/__init__.py new file mode 100644 index 0000000000..33c4b2f483 --- /dev/null +++ b/flash/core/integrations/pytorch_forecasting/__init__.py @@ -0,0 +1 @@ +from flash.core.integrations.pytorch_forecasting.transforms import convert_predictions # noqa: F401 diff --git a/flash/tabular/forecasting/adapters.py b/flash/core/integrations/pytorch_forecasting/adapter.py similarity index 74% rename from flash/tabular/forecasting/adapters.py rename to flash/core/integrations/pytorch_forecasting/adapter.py index b31be8bb84..ab43495494 100644 --- a/flash/tabular/forecasting/adapters.py +++ b/flash/core/integrations/pytorch_forecasting/adapter.py @@ -15,16 +15,18 @@ from functools import partial from typing import Any, Callable, Dict, List, Optional, Union +import torch import torchmetrics +from flash import Task from flash.core.adapter import Adapter +from flash.core.data.batch import default_uncollate from flash.core.data.data_source import DefaultDataKeys from flash.core.data.states import CollateFn -from flash.core.model import Task from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE if _PANDAS_AVAILABLE: - from pandas.core.frame import DataFrame + from pandas import DataFrame if _FORECASTING_AVAILABLE: from pytorch_forecasting import TimeSeriesDataSet @@ -33,7 +35,7 @@ class PatchTimeSeriesDataSet(TimeSeriesDataSet): - """Hack to prevent index construction when instantiating model. + """Hack to prevent index construction or data validation / conversion when instantiating model. This enables the ``TimeSeriesDataSet`` to be created from a single row of data. """ @@ -41,6 +43,9 @@ class PatchTimeSeriesDataSet(TimeSeriesDataSet): def _construct_index(self, data: DataFrame, predict_mode: bool) -> DataFrame: return DataFrame() + def _data_to_tensors(self, data: DataFrame) -> Dict[str, torch.Tensor]: + return {} + class PyTorchForecastingAdapter(Adapter): """The ``PyTorchForecastingAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with PyTorch @@ -80,15 +85,12 @@ def from_task( if not backbone_kwargs: backbone_kwargs = {} - forecasting_model = task.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs) + adapter = cls(task.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs)) # Attach the required collate function - task.set_state(CollateFn(partial(PyTorchForecastingAdapter._collate_fn, time_series_dataset._collate_fn))) - - # Attach the `forecasting_model` attribute to expose the built-in inference methods from PyTorch Forecasting - task.forecasting_model = forecasting_model + adapter.set_state(CollateFn(partial(PyTorchForecastingAdapter._collate_fn, time_series_dataset._collate_fn))) - return cls(forecasting_model) + return adapter def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) @@ -99,17 +101,15 @@ def validation_step(self, batch: Any, batch_idx: int) -> Any: return self.backbone.validation_step(batch, batch_idx) def test_step(self, batch: Any, batch_idx: int) -> None: - batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) - # PyTorch Forecasting models don't have a `test_step`, so re-use `validation_step` - return self.backbone.validation_step(batch, batch_idx) - - def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: raise NotImplementedError( - "Flash's inference is not currently supported with backbones provided by PyTorch Forecasting. You can " - "access the PyTorch Forecasting LightningModule directly with the `forecasting_model` attribute of the " - "`TabularForecaster`." + "Backbones provided by PyTorch Forecasting don't support testing. Use validation instead." ) + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + result = dict(self.backbone(batch[DefaultDataKeys.INPUT])) + result[DefaultDataKeys.INPUT] = default_uncollate(batch[DefaultDataKeys.INPUT]) + return default_uncollate(result) + def training_epoch_end(self, outputs) -> None: self.backbone.training_epoch_end(outputs) @@ -117,5 +117,6 @@ def validation_epoch_end(self, outputs) -> None: self.backbone.validation_epoch_end(outputs) def test_epoch_end(self, outputs) -> None: - # PyTorch Forecasting models don't have a `test_epoch_end`, so re-use `validation_epoch_end` - self.backbone.validation_epoch_end(outputs) + raise NotImplementedError( + "Backbones provided by PyTorch Forecasting don't support testing. Use validation instead." + ) diff --git a/flash/tabular/forecasting/backbones.py b/flash/core/integrations/pytorch_forecasting/backbones.py similarity index 88% rename from flash/tabular/forecasting/backbones.py rename to flash/core/integrations/pytorch_forecasting/backbones.py index e278786932..f0ffc8621f 100644 --- a/flash/tabular/forecasting/backbones.py +++ b/flash/core/integrations/pytorch_forecasting/backbones.py @@ -13,9 +13,9 @@ # limitations under the License. import functools +from flash.core.integrations.pytorch_forecasting.adapter import PyTorchForecastingAdapter from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FORECASTING_AVAILABLE -from flash.tabular.forecasting.adapters import PyTorchForecastingAdapter if _FORECASTING_AVAILABLE: from pytorch_forecasting import ( @@ -28,7 +28,7 @@ ) -TABULAR_FORECASTING_BACKBONES = FlashRegistry("backbones") +PYTORCH_FORECASTING_BACKBONES = FlashRegistry("backbones") if _FORECASTING_AVAILABLE: @@ -40,7 +40,7 @@ def load_torch_forecasting(model, time_series_dataset: TimeSeriesDataSet, **kwar [TemporalFusionTransformer, NBeats, RecurrentNetwork, DeepAR, DecoderMLP], ["temporal_fusion_transformer", "n_beats", "recurrent_network", "deep_ar", "decoder_mlp"], ): - TABULAR_FORECASTING_BACKBONES( + PYTORCH_FORECASTING_BACKBONES( functools.partial(load_torch_forecasting, model), name=name, namespace="tabular/forecasting", diff --git a/flash/core/integrations/pytorch_forecasting/transforms.py b/flash/core/integrations/pytorch_forecasting/transforms.py new file mode 100644 index 0000000000..ce193d0bcc --- /dev/null +++ b/flash/core/integrations/pytorch_forecasting/transforms.py @@ -0,0 +1,30 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Tuple + +from torch.utils.data._utils.collate import default_collate + +from flash.core.data.data_source import DefaultDataKeys + + +def convert_predictions(predictions: List[Dict[str, Any]]) -> Tuple[Dict[str, Any], List]: + # Flatten list if batches were used + if all(isinstance(fl, list) for fl in predictions): + unrolled_predictions = [] + for prediction_batch in predictions: + unrolled_predictions.extend(prediction_batch) + predictions = unrolled_predictions + result = default_collate(predictions) + inputs = result.pop(DefaultDataKeys.INPUT) + return result, inputs diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index eb9f312892..15d87f596f 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -95,6 +95,7 @@ def __init__( DefaultDataSources.DATAFRAME: TabularForecastingDataFrameDataSource(**data_source_kwargs), }, deserializer=deserializer, + default_data_source=DefaultDataSources.DATAFRAME, ) def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: diff --git a/flash/tabular/forecasting/model.py b/flash/tabular/forecasting/model.py index 3de2d92f63..eb46ef820d 100644 --- a/flash/tabular/forecasting/model.py +++ b/flash/tabular/forecasting/model.py @@ -14,25 +14,27 @@ from typing import Any, Callable, Dict, List, Optional, Union import torchmetrics +from pytorch_lightning import LightningModule from flash.core.adapter import AdapterTask +from flash.core.integrations.pytorch_forecasting.adapter import PyTorchForecastingAdapter +from flash.core.integrations.pytorch_forecasting.backbones import PYTORCH_FORECASTING_BACKBONES from flash.core.registry import FlashRegistry from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE -from flash.tabular.forecasting.backbones import TABULAR_FORECASTING_BACKBONES class TabularForecaster(AdapterTask): - backbones: FlashRegistry = TABULAR_FORECASTING_BACKBONES + backbones: FlashRegistry = FlashRegistry("backbones") + PYTORCH_FORECASTING_BACKBONES def __init__( self, parameters: Dict[str, Any], - backbone: str = "temporal_fusion_transformer", + backbone: str, loss_fn: Optional[Callable] = None, optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: Union[torchmetrics.Metric, List[torchmetrics.Metric]] = None, - learning_rate: float = 3e-2, + learning_rate: float = 4e-3, **backbone_kwargs ): @@ -54,3 +56,19 @@ def __init__( optimizer=optimizer, lr_scheduler=lr_scheduler, ) + + @property + def pytorch_forecasting_model(self) -> LightningModule: + """This property provides access to the ``LightningModule`` object that is wrapped by Flash for backbones + provided by PyTorch Forecasting. + + This can be used with + :func:`~flash.core.integrations.pytorch_forecasting.transforms.convert_predictions` to access the visualization + features built in to PyTorch Forecasting. + """ + if not isinstance(self.adapter, PyTorchForecastingAdapter): + raise AttributeError( + "The `pytorch_forecasting_model` attribute can only be accessed for backbones provided by PyTorch " + "Forecasting." + ) + return self.adapter.backbone diff --git a/flash_examples/tabular_forecasting.py b/flash_examples/tabular_forecasting.py index 1beb85361c..e42aed8851 100644 --- a/flash_examples/tabular_forecasting.py +++ b/flash_examples/tabular_forecasting.py @@ -11,89 +11,53 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np -from pytorch_forecasting.data import GroupNormalizer -from pytorch_forecasting.data.examples import get_stallion_data +import torch import flash +from flash.core.utilities.imports import example_requires from flash.tabular.forecasting import TabularForecaster, TabularForecastingData -data = get_stallion_data() +example_requires("tabular") -# add time index -data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month -data["time_idx"] -= data["time_idx"].min() +import pandas as pd # noqa: E402 +from pytorch_forecasting.data import NaNLabelEncoder # noqa: E402 +from pytorch_forecasting.data.examples import generate_ar_data # noqa: E402 -# add additional features -data["month"] = data.date.dt.month.astype(str).astype("category") # categories have be strings -data["log_volume"] = np.log(data.volume + 1e-8) -data["avg_volume_by_sku"] = data.groupby(["time_idx", "sku"], observed=True).volume.transform("mean") -data["avg_volume_by_agency"] = data.groupby(["time_idx", "agency"], observed=True).volume.transform("mean") +# Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html +# 1. Create the DataModule -# we want to encode special days as one variable and thus need to first reverse one-hot encoding -special_days = [ - "easter_day", - "good_friday", - "new_year", - "christmas", - "labor_day", - "independence_day", - "revolution_day_memorial", - "regional_games", - "fifa_u_17_world_cup", - "football_gold_cup", - "beer_capital", - "music_fest", -] -data[special_days] = data[special_days].apply(lambda x: x.map({0: "-", 1: x.name})).astype("category") -data.sample(10, random_state=521) +data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42) +data["static"] = 2 +data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") + +max_prediction_length = 20 -max_prediction_length = 6 -max_encoder_length = 24 training_cutoff = data["time_idx"].max() - max_prediction_length datamodule = TabularForecastingData.from_data_frame( time_idx="time_idx", - target="volume", - group_ids=["agency", "sku"], - min_encoder_length=max_encoder_length // 2, # keep encoder length long (as it is in the validation set) - max_encoder_length=max_encoder_length, - min_prediction_length=1, + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + # only unknown variable is "value" - and N-Beats can also not take any additional variables + time_varying_unknown_reals=["value"], + max_encoder_length=60, max_prediction_length=max_prediction_length, - static_categoricals=["agency", "sku"], - static_reals=["avg_population_2017", "avg_yearly_household_income_2017"], - time_varying_known_categoricals=["special_days", "month"], - variable_groups={"special_days": special_days}, # group of categorical variables can be treated as one variable - time_varying_known_reals=["time_idx", "price_regular", "discount_in_percent"], - time_varying_unknown_categoricals=[], - time_varying_unknown_reals=[ - "volume", - "log_volume", - "industry_volume", - "soda_volume", - "avg_max_temp", - "avg_volume_by_agency", - "avg_volume_by_sku", - ], - target_normalizer=GroupNormalizer( - groups=["agency", "sku"], transformation="softplus" - ), # use softplus and normalize by group - add_relative_time_idx=True, - add_target_scales=True, - add_encoder_length=True, train_data_frame=data[lambda x: x.time_idx <= training_cutoff], val_data_frame=data, - batch_size=64, + batch_size=32, ) -model = TabularForecaster( - datamodule.parameters, - hidden_size=16, - attention_head_size=1, - dropout=0.1, - hidden_continuous_size=8, - output_size=7, -) +# 2. Build the task +model = TabularForecaster(datamodule.parameters, backbone="n_beats", widths=[32, 512], backcast_loss_ratio=0.1) -trainer = flash.Trainer(max_epochs=30, gradient_clip_val=0.1) +# 3. Create the trainer and train the model +trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count(), gradient_clip_val=0.01, limit_train_batches=10) trainer.fit(model, datamodule=datamodule) + +# 4. Generate predictions +predictions = model.predict(data) +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("tabular_forecasting_model.pt") From 7fb852fb4778dd7dd380c07a9e3804cd244f45d4 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 28 Oct 2021 19:46:42 +0100 Subject: [PATCH 19/27] Add interpertation example --- .../tabular_forecasting_interpretable.py | 76 +++++++++++++++++++ flash_examples/tabular_forecasting.py | 2 +- 2 files changed, 77 insertions(+), 1 deletion(-) create mode 100644 flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py diff --git a/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py b/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py new file mode 100644 index 0000000000..5839a0f776 --- /dev/null +++ b/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py @@ -0,0 +1,76 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.core.integrations.pytorch_forecasting import convert_predictions +from flash.core.utilities.imports import example_requires +from flash.tabular.forecasting import TabularForecaster, TabularForecastingData + +example_requires(["tabular", "matplotlib"]) + +import matplotlib.pyplot as plt # noqa: E402 +import pandas as pd # noqa: E402 +from pytorch_forecasting.data import NaNLabelEncoder # noqa: E402 +from pytorch_forecasting.data.examples import generate_ar_data # noqa: E402 + +# Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html +# 1. Create the DataModule + +data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42) +data["static"] = 2 +data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") + +max_prediction_length = 20 + +training_cutoff = data["time_idx"].max() - max_prediction_length + +datamodule = TabularForecastingData.from_data_frame( + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + # only unknown variable is "value" - and N-Beats can also not take any additional variables + time_varying_unknown_reals=["value"], + max_encoder_length=60, + max_prediction_length=max_prediction_length, + train_data_frame=data[lambda x: x.time_idx <= training_cutoff], + val_data_frame=data, + batch_size=32, +) + +# 2. Build the task +model = TabularForecaster(datamodule.parameters, backbone="n_beats", widths=[32, 512], backcast_loss_ratio=0.1) + +# 3. Create the trainer and train the model +trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count(), gradient_clip_val=0.01) +trainer.fit(model, datamodule=datamodule) + +# 4. Generate predictions +predictions = model.predict(data) +print(predictions) + +# Convert predictions +predictions, inputs = convert_predictions(predictions) + +# Plot predictions +for idx in range(10): # plot 10 examples + model.pytorch_forecasting_model.plot_prediction(inputs, predictions, idx=idx, add_loss_to_title=True) + +# Plot interpretation +for idx in range(10): # plot 10 examples + model.pytorch_forecasting_model.plot_interpretation(inputs, predictions, idx=idx) + +# Show the plots! +plt.show() diff --git a/flash_examples/tabular_forecasting.py b/flash_examples/tabular_forecasting.py index e42aed8851..381c8fbe11 100644 --- a/flash_examples/tabular_forecasting.py +++ b/flash_examples/tabular_forecasting.py @@ -52,7 +52,7 @@ model = TabularForecaster(datamodule.parameters, backbone="n_beats", widths=[32, 512], backcast_loss_ratio=0.1) # 3. Create the trainer and train the model -trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count(), gradient_clip_val=0.01, limit_train_batches=10) +trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count(), gradient_clip_val=0.01) trainer.fit(model, datamodule=datamodule) # 4. Generate predictions From 3a9c9ab48985767b890654c7dc5e0e624c8a5acf Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 28 Oct 2021 20:10:32 +0100 Subject: [PATCH 20/27] Fix broken tests --- tests/core/data/test_data_pipeline.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py index b7565903dc..d5853e19a9 100644 --- a/tests/core/data/test_data_pipeline.py +++ b/tests/core/data/test_data_pipeline.py @@ -59,18 +59,9 @@ def test_str(): state.set_state(ProcessState()) assert str(state) == ( - "DataPipelineState(initialized=False, " - "state={: ProcessState()})" + "DataPipelineState(state={: ProcessState()})" ) - @staticmethod - def test_warning(): - state = DataPipelineState() - state._initialized = True - - with pytest.warns(UserWarning, match="data pipeline has already been initialized"): - state.set_state(ProcessState()) - @staticmethod def test_get_state(): state = DataPipelineState() From b7846a3337415fa22f65c57f662117195f791df1 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 29 Oct 2021 13:53:10 +0100 Subject: [PATCH 21/27] Small fixes and add some tests --- .../pytorch_forecasting/adapter.py | 5 - flash/tabular/forecasting/data.py | 38 ++++++-- .../tabular_forecasting_interpretable.py | 1 - flash_examples/tabular_forecasting.py | 1 - tests/tabular/classification/test_data.py | 24 ----- tests/tabular/forecasting/__init__.py | 0 tests/tabular/forecasting/test_data.py | 94 +++++++++++++++++++ tests/tabular/forecasting/test_model.py | 78 +++++++++++++++ 8 files changed, 201 insertions(+), 40 deletions(-) create mode 100644 tests/tabular/forecasting/__init__.py create mode 100644 tests/tabular/forecasting/test_data.py create mode 100644 tests/tabular/forecasting/test_model.py diff --git a/flash/core/integrations/pytorch_forecasting/adapter.py b/flash/core/integrations/pytorch_forecasting/adapter.py index ab43495494..0e48dc6d21 100644 --- a/flash/core/integrations/pytorch_forecasting/adapter.py +++ b/flash/core/integrations/pytorch_forecasting/adapter.py @@ -115,8 +115,3 @@ def training_epoch_end(self, outputs) -> None: def validation_epoch_end(self, outputs) -> None: self.backbone.validation_epoch_end(outputs) - - def test_epoch_end(self, outputs) -> None: - raise NotImplementedError( - "Backbones provided by PyTorch Forecasting don't support testing. Use validation instead." - ) diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index 15d87f596f..07c49b6d5c 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -15,6 +15,8 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Mapping, Optional, Union +from pytorch_lightning.utilities.exceptions import MisconfigurationException + from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources @@ -41,27 +43,43 @@ class TimeSeriesDataSetParametersState(ProcessState): class TabularForecastingDataFrameDataSource(DataSource[DataFrame]): @requires("tabular") - def __init__(self, time_idx: str, target: Union[str, List[str]], group_ids: List[str], **data_source_kwargs: Any): + def __init__( + self, + time_idx: str, + target: Union[str, List[str]], + group_ids: List[str], + parameters: Optional[Dict[str, Any]] = None, + **data_source_kwargs: Any, + ): super().__init__() self.time_idx = time_idx self.target = target self.group_ids = group_ids self.data_source_kwargs = data_source_kwargs + self.set_state(TimeSeriesDataSetParametersState(parameters)) + def load_data(self, data: DataFrame, dataset: Optional[Any] = None): if self.training: time_series_dataset = TimeSeriesDataSet( data, time_idx=self.time_idx, group_ids=self.group_ids, target=self.target, **self.data_source_kwargs ) parameters = time_series_dataset.get_parameters() - self.set_state(TimeSeriesDataSetParametersState(parameters)) # Add some sample data so that we can recreate the `TimeSeriesDataSet` later on - parameters = copy(parameters) parameters["data_sample"] = data.iloc[[0]] + + self.set_state(TimeSeriesDataSetParametersState(parameters)) dataset.parameters = parameters else: - parameters = self.get_state(TimeSeriesDataSetParametersState).time_series_dataset_parameters + parameters = copy(self.get_state(TimeSeriesDataSetParametersState).time_series_dataset_parameters) + if parameters is None: + raise MisconfigurationException( + "Loading data for evaluation or inference requires parameters from the train data. Either " + "construct the train data at the same time as evaluation and inference or provide the train " + "`datamodule.parameters` to `from_data_frame` in the `parameters` argument." + ) + parameters.pop("data_sample") time_series_dataset = TimeSeriesDataSet.from_parameters( parameters, data, @@ -112,15 +130,16 @@ class TabularForecastingData(DataModule): preprocess_cls = TabularForecastingPreprocess @property - def parameters(self) -> Dict[str, Any]: - return self.train_dataset.parameters + def parameters(self) -> Optional[Dict[str, Any]]: + return getattr(self.train_dataset, "parameters", None) @classmethod def from_data_frame( cls, - time_idx: str, - target: Union[str, List[str]], - group_ids: List[str], + time_idx: Optional[str] = None, + target: Optional[Union[str, List[str]]] = None, + group_ids: Optional[List[str]] = None, + parameters: Optional[Dict[str, Any]] = None, train_data_frame: Optional[DataFrame] = None, val_data_frame: Optional[DataFrame] = None, test_data_frame: Optional[DataFrame] = None, @@ -182,6 +201,7 @@ def from_data_frame( time_idx=time_idx, target=target, group_ids=group_ids, + parameters=parameters, data_source=DefaultDataSources.DATAFRAME, train_data=train_data_frame, val_data=val_data_frame, diff --git a/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py b/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py index 5839a0f776..2c5c9ea60a 100644 --- a/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py +++ b/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py @@ -27,7 +27,6 @@ # Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html # 1. Create the DataModule - data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42) data["static"] = 2 data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") diff --git a/flash_examples/tabular_forecasting.py b/flash_examples/tabular_forecasting.py index 381c8fbe11..718cc0aa72 100644 --- a/flash_examples/tabular_forecasting.py +++ b/flash_examples/tabular_forecasting.py @@ -25,7 +25,6 @@ # Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html # 1. Create the DataModule - data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42) data["static"] = 2 data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py index b1e9ef3f25..4f1fc632d8 100644 --- a/tests/tabular/classification/test_data.py +++ b/tests/tabular/classification/test_data.py @@ -89,30 +89,6 @@ def test_embedding_sizes(): assert es == [(100_000, 17), (1_000_000, 31)] -@pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") -def test_tabular_data(tmpdir): - train_data_frame = TEST_DF_1.copy() - val_data_frame = TEST_DF_2.copy() - test_data_frame = TEST_DF_2.copy() - dm = TabularClassificationData.from_data_frame( - categorical_fields=["category"], - numerical_fields=["scalar_a", "scalar_b"], - target_fields="label", - train_data_frame=train_data_frame, - val_data_frame=val_data_frame, - test_data_frame=test_data_frame, - num_workers=0, - batch_size=1, - ) - for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: - data = next(iter(dl)) - (cat, num) = data[DefaultDataKeys.INPUT] - target = data[DefaultDataKeys.TARGET] - assert cat.shape == (1, 1) - assert num.shape == (1, 2) - assert target.shape == (1,) - - @pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") def test_categorical_target(tmpdir): train_data_frame = TEST_DF_1.copy() diff --git a/tests/tabular/forecasting/__init__.py b/tests/tabular/forecasting/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/tabular/forecasting/test_data.py b/tests/tabular/forecasting/test_data.py new file mode 100644 index 0000000000..1c10291d5c --- /dev/null +++ b/tests/tabular/forecasting/test_data.py @@ -0,0 +1,94 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import MagicMock, patch + +import pytest +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from flash.tabular.forecasting import TabularForecastingData +from tests.helpers.utils import _TABULAR_TESTING + + +@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") +@patch("flash.tabular.forecasting.data.TimeSeriesDataSet") +def test_from_data_frame_time_series_data_set_single_call(patch_time_series_data_set): + """Tests that ``TabularForecastingData.from_data_frame`` calls ``TimeSeriesDataSet`` with the expected + parameters when called once with data for all stages.""" + patch_time_series_data_set.return_value.get_parameters.return_value = {"test": None} + + train_data = MagicMock() + val_data = MagicMock() + + TabularForecastingData.from_data_frame( + "time_idx", + "target", + ["series"], + train_data_frame=train_data, + val_data_frame=val_data, + additional_kwarg="test", + ) + + patch_time_series_data_set.assert_called_once_with( + train_data, time_idx="time_idx", group_ids=["series"], target="target", additional_kwarg="test" + ) + + patch_time_series_data_set.from_parameters.assert_called_once_with( + {"test": None}, val_data, predict=True, stop_randomization=True + ) + + +@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") +@patch("flash.tabular.forecasting.data.TimeSeriesDataSet") +def test_from_data_frame_time_series_data_set_multi_call(patch_time_series_data_set): + """Tests that ``TabularForecastingData.from_data_frame`` calls ``TimeSeriesDataSet`` with the expected + parameters when called separately for each stage.""" + patch_time_series_data_set.return_value.get_parameters.return_value = {"test": None} + + train_data = MagicMock() + val_data = MagicMock() + + train_datamodule = TabularForecastingData.from_data_frame( + "time_idx", + "target", + ["series"], + train_data_frame=train_data, + additional_kwarg="test", + ) + + TabularForecastingData.from_data_frame( + val_data_frame=val_data, + parameters=train_datamodule.parameters, + ) + + patch_time_series_data_set.assert_called_once_with( + train_data, time_idx="time_idx", group_ids=["series"], target="target", additional_kwarg="test" + ) + + patch_time_series_data_set.from_parameters.assert_called_once_with( + {"test": None}, val_data, predict=True, stop_randomization=True + ) + + +@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") +def test_from_data_frame_misconfiguration(): + """Tests that a ``MisconfigurationException`` is raised when ``TabularForecastingData`` is constructed without + parameters.""" + with pytest.raises(MisconfigurationException, match="evaluation or inference requires parameters"): + TabularForecastingData.from_data_frame( + "time_idx", + "target", + ["series"], + val_data_frame=MagicMock(), + additional_kwarg="test", + ) diff --git a/tests/tabular/forecasting/test_model.py b/tests/tabular/forecasting/test_model.py new file mode 100644 index 0000000000..073970a3b0 --- /dev/null +++ b/tests/tabular/forecasting/test_model.py @@ -0,0 +1,78 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +import flash +from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_AVAILABLE +from flash.tabular.forecasting import TabularForecaster, TabularForecastingData + +if _TABULAR_AVAILABLE: + from pytorch_forecasting.data import NaNLabelEncoder + from pytorch_forecasting.data.examples import generate_ar_data + +if _PANDAS_AVAILABLE: + import pandas as pd + + +@pytest.fixture +def sample_data(): + data = generate_ar_data(seasonality=10.0, timesteps=100, n_series=2, seed=42) + data["static"] = 2 + data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") + max_prediction_length = 20 + training_cutoff = data["time_idx"].max() - max_prediction_length + return data, training_cutoff, max_prediction_length + + +def test_fast_dev_run_smoke(sample_data): + """Test that fast dev run works with the NBeats example data.""" + data, training_cutoff, max_prediction_length = sample_data + datamodule = TabularForecastingData.from_data_frame( + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + time_varying_unknown_reals=["value"], + max_encoder_length=60, + max_prediction_length=max_prediction_length, + train_data_frame=data[lambda x: x.time_idx <= training_cutoff], + val_data_frame=data, + ) + + model = TabularForecaster(datamodule.parameters, backbone="n_beats", widths=[32, 512], backcast_loss_ratio=0.1) + + trainer = flash.Trainer(max_epochs=1, fast_dev_run=True, gradient_clip_val=0.01) + trainer.fit(model, datamodule=datamodule) + + +def test_testing_raises(sample_data): + """Tests that ``NotImplementedError`` is raised when attempting to perform a test pass.""" + data, training_cutoff, max_prediction_length = sample_data + datamodule = TabularForecastingData.from_data_frame( + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + time_varying_unknown_reals=["value"], + max_encoder_length=60, + max_prediction_length=max_prediction_length, + train_data_frame=data[lambda x: x.time_idx <= training_cutoff], + test_data_frame=data, + ) + + model = TabularForecaster(datamodule.parameters, backbone="n_beats", widths=[32, 512], backcast_loss_ratio=0.1) + trainer = flash.Trainer(max_epochs=1, fast_dev_run=True, gradient_clip_val=0.01) + + with pytest.raises(NotImplementedError, match="Backbones provided by PyTorch Forecasting don't support testing."): + trainer.test(model, datamodule) From b7756b003ae7becb02e9ff3593ea8fc71a9a08e8 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 29 Oct 2021 13:57:36 +0100 Subject: [PATCH 22/27] Updates --- tests/tabular/forecasting/test_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/tabular/forecasting/test_model.py b/tests/tabular/forecasting/test_model.py index 073970a3b0..1e0df945e2 100644 --- a/tests/tabular/forecasting/test_model.py +++ b/tests/tabular/forecasting/test_model.py @@ -16,6 +16,7 @@ import flash from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_AVAILABLE from flash.tabular.forecasting import TabularForecaster, TabularForecastingData +from tests.helpers.utils import _TABULAR_TESTING if _TABULAR_AVAILABLE: from pytorch_forecasting.data import NaNLabelEncoder @@ -35,6 +36,7 @@ def sample_data(): return data, training_cutoff, max_prediction_length +@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") def test_fast_dev_run_smoke(sample_data): """Test that fast dev run works with the NBeats example data.""" data, training_cutoff, max_prediction_length = sample_data @@ -56,6 +58,7 @@ def test_fast_dev_run_smoke(sample_data): trainer.fit(model, datamodule=datamodule) +@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") def test_testing_raises(sample_data): """Tests that ``NotImplementedError`` is raised when attempting to perform a test pass.""" data, training_cutoff, max_prediction_length = sample_data From 6313ffe1ca5436248a115bdc1ae28b777a066331 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 29 Oct 2021 13:58:55 +0100 Subject: [PATCH 23/27] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 58c5b77345..a1d2556d10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added a `TabularForecaster` task based on PyTorch Forecasting ([#647](https://github.com/PyTorchLightning/lightning-flash/pull/647)) + ### Changed ### Fixed From 8976a9011dca3bb47a809a49d0e66a4b169a1af2 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 29 Oct 2021 14:01:01 +0100 Subject: [PATCH 24/27] Add provider --- flash/core/integrations/pytorch_forecasting/backbones.py | 3 ++- flash/core/utilities/providers.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/flash/core/integrations/pytorch_forecasting/backbones.py b/flash/core/integrations/pytorch_forecasting/backbones.py index f0ffc8621f..baba87ac50 100644 --- a/flash/core/integrations/pytorch_forecasting/backbones.py +++ b/flash/core/integrations/pytorch_forecasting/backbones.py @@ -16,6 +16,7 @@ from flash.core.integrations.pytorch_forecasting.adapter import PyTorchForecastingAdapter from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FORECASTING_AVAILABLE +from flash.core.utilities.providers import _PYTORCH_FORECASTING if _FORECASTING_AVAILABLE: from pytorch_forecasting import ( @@ -43,6 +44,6 @@ def load_torch_forecasting(model, time_series_dataset: TimeSeriesDataSet, **kwar PYTORCH_FORECASTING_BACKBONES( functools.partial(load_torch_forecasting, model), name=name, - namespace="tabular/forecasting", + providers=_PYTORCH_FORECASTING, adapter=PyTorchForecastingAdapter, ) diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index a5bb749246..d536154d73 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -46,3 +46,4 @@ def __str__(self): _OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML") _PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo") _VISSL = Provider("Facebook Research/vissl", "https://github.com/facebookresearch/vissl") +_PYTORCH_FORECASTING = Provider("jdb78/PyTorch-Forecasting", "https://github.com/jdb78/pytorch-forecasting") From fb4a5982c05ca9ffa01b2d0bd5b0230d0e7bce66 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 29 Oct 2021 15:46:55 +0100 Subject: [PATCH 25/27] Update flash/core/integrations/pytorch_forecasting/adapter.py Co-authored-by: Jirka Borovec --- flash/core/integrations/pytorch_forecasting/adapter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/flash/core/integrations/pytorch_forecasting/adapter.py b/flash/core/integrations/pytorch_forecasting/adapter.py index 0e48dc6d21..7bf6dceae3 100644 --- a/flash/core/integrations/pytorch_forecasting/adapter.py +++ b/flash/core/integrations/pytorch_forecasting/adapter.py @@ -27,6 +27,8 @@ if _PANDAS_AVAILABLE: from pandas import DataFrame +else: + DataFrame = object if _FORECASTING_AVAILABLE: from pytorch_forecasting import TimeSeriesDataSet From 9c213d6cd2a6201e5e97337b38bed03b711d586c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 29 Oct 2021 15:47:26 +0100 Subject: [PATCH 26/27] Update flash/core/integrations/pytorch_forecasting/adapter.py Co-authored-by: Jirka Borovec --- flash/core/integrations/pytorch_forecasting/adapter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/flash/core/integrations/pytorch_forecasting/adapter.py b/flash/core/integrations/pytorch_forecasting/adapter.py index 7bf6dceae3..f77d0f8e56 100644 --- a/flash/core/integrations/pytorch_forecasting/adapter.py +++ b/flash/core/integrations/pytorch_forecasting/adapter.py @@ -84,8 +84,7 @@ def from_task( metrics = [metrics] backbone_kwargs["logging_metrics"] = metrics - if not backbone_kwargs: - backbone_kwargs = {} + backbone_kwargs = backbone_kwargs or {} adapter = cls(task.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs)) From 3cbd13e089bc3fd005cf6a76099bd1f9bc5426dd Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 29 Oct 2021 15:58:03 +0100 Subject: [PATCH 27/27] Update on comments --- flash/core/integrations/pytorch_forecasting/adapter.py | 1 + .../pytorch_forecasting/tabular_forecasting_interpretable.py | 1 - flash_examples/tabular_forecasting.py | 1 - tests/tabular/forecasting/test_model.py | 1 - 4 files changed, 1 insertion(+), 3 deletions(-) diff --git a/flash/core/integrations/pytorch_forecasting/adapter.py b/flash/core/integrations/pytorch_forecasting/adapter.py index f77d0f8e56..473ecc38bf 100644 --- a/flash/core/integrations/pytorch_forecasting/adapter.py +++ b/flash/core/integrations/pytorch_forecasting/adapter.py @@ -75,6 +75,7 @@ def from_task( **backbone_kwargs, ) -> Adapter: parameters = copy(parameters) + # Remove the single row of data from the parameters to reconstruct the `time_series_dataset` data = parameters.pop("data_sample") time_series_dataset = PatchTimeSeriesDataSet.from_parameters(parameters, data) diff --git a/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py b/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py index 2c5c9ea60a..ec62cb2643 100644 --- a/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py +++ b/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py @@ -28,7 +28,6 @@ # Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html # 1. Create the DataModule data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42) -data["static"] = 2 data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") max_prediction_length = 20 diff --git a/flash_examples/tabular_forecasting.py b/flash_examples/tabular_forecasting.py index 718cc0aa72..836f01fe64 100644 --- a/flash_examples/tabular_forecasting.py +++ b/flash_examples/tabular_forecasting.py @@ -26,7 +26,6 @@ # Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html # 1. Create the DataModule data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42) -data["static"] = 2 data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") max_prediction_length = 20 diff --git a/tests/tabular/forecasting/test_model.py b/tests/tabular/forecasting/test_model.py index 1e0df945e2..bebf8477bc 100644 --- a/tests/tabular/forecasting/test_model.py +++ b/tests/tabular/forecasting/test_model.py @@ -29,7 +29,6 @@ @pytest.fixture def sample_data(): data = generate_ar_data(seasonality=10.0, timesteps=100, n_series=2, seed=42) - data["static"] = 2 data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") max_prediction_length = 20 training_cutoff = data["time_idx"].max() - max_prediction_length