diff --git a/CHANGELOG.md b/CHANGELOG.md index 989bfcc60e..b5cacbf704 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added a `TabularForecaster` task based on PyTorch Forecasting ([#647](https://github.com/PyTorchLightning/lightning-flash/pull/647)) + ### Changed ### Fixed diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index f690f0a26c..fd5ee8ef33 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -19,7 +19,6 @@ import torch from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from torch.utils.data import DataLoader, IterableDataset @@ -42,29 +41,19 @@ class DataPipelineState: def __init__(self): self._state: Dict[Type[ProcessState], ProcessState] = {} - self._initialized = False def set_state(self, state: ProcessState): """Add the given :class:`.ProcessState` to the :class:`.DataPipelineState`.""" - if not self._initialized: - self._state[type(state)] = state - else: - rank_zero_warn( - f"Attempted to add a state ({state}) after the data pipeline has already been initialized. This will" - " only have an effect when a new data pipeline is created.", - UserWarning, - ) + self._state[type(state)] = state def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]: """Get the :class:`.ProcessState` of the given type from the :class:`.DataPipelineState`.""" - if state_type in self._state: - return self._state[state_type] - return None + return self._state.get(state_type, None) def __str__(self) -> str: - return f"{self.__class__.__name__}(initialized={self._initialized}, state={self._state})" + return f"{self.__class__.__name__}(state={self._state})" class DataPipeline: @@ -113,13 +102,11 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> :class:`.Postprocess`, and :class:`.Serializer`. Once this has been called, any attempt to add new state will give a warning.""" data_pipeline_state = data_pipeline_state or DataPipelineState() - data_pipeline_state._initialized = False if self.data_source is not None: self.data_source.attach_data_pipeline_state(data_pipeline_state) self._preprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) self._postprocess_pipeline.attach_data_pipeline_state(data_pipeline_state) self._serializer.attach_data_pipeline_state(data_pipeline_state) - data_pipeline_state._initialized = True # TODO: Not sure we need this return data_pipeline_state @property diff --git a/flash/core/integrations/pytorch_forecasting/__init__.py b/flash/core/integrations/pytorch_forecasting/__init__.py new file mode 100644 index 0000000000..33c4b2f483 --- /dev/null +++ b/flash/core/integrations/pytorch_forecasting/__init__.py @@ -0,0 +1 @@ +from flash.core.integrations.pytorch_forecasting.transforms import convert_predictions # noqa: F401 diff --git a/flash/core/integrations/pytorch_forecasting/adapter.py b/flash/core/integrations/pytorch_forecasting/adapter.py new file mode 100644 index 0000000000..473ecc38bf --- /dev/null +++ b/flash/core/integrations/pytorch_forecasting/adapter.py @@ -0,0 +1,119 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from copy import copy +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +import torchmetrics + +from flash import Task +from flash.core.adapter import Adapter +from flash.core.data.batch import default_uncollate +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.states import CollateFn +from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE + +if _PANDAS_AVAILABLE: + from pandas import DataFrame +else: + DataFrame = object + +if _FORECASTING_AVAILABLE: + from pytorch_forecasting import TimeSeriesDataSet +else: + TimeSeriesDataSet = object + + +class PatchTimeSeriesDataSet(TimeSeriesDataSet): + """Hack to prevent index construction or data validation / conversion when instantiating model. + + This enables the ``TimeSeriesDataSet`` to be created from a single row of data. + """ + + def _construct_index(self, data: DataFrame, predict_mode: bool) -> DataFrame: + return DataFrame() + + def _data_to_tensors(self, data: DataFrame) -> Dict[str, torch.Tensor]: + return {} + + +class PyTorchForecastingAdapter(Adapter): + """The ``PyTorchForecastingAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with PyTorch + Forecasting.""" + + def __init__(self, backbone): + super().__init__() + + self.backbone = backbone + + @staticmethod + def _collate_fn(collate_fn, samples): + samples = [(sample[DefaultDataKeys.INPUT], sample[DefaultDataKeys.TARGET]) for sample in samples] + batch = collate_fn(samples) + return {DefaultDataKeys.INPUT: batch[0], DefaultDataKeys.TARGET: batch[1]} + + @classmethod + def from_task( + cls, + task: Task, + parameters: Dict[str, Any], + backbone: str, + loss_fn: Optional[Callable] = None, + metrics: Optional[Union[torchmetrics.Metric, List[torchmetrics.Metric]]] = None, + **backbone_kwargs, + ) -> Adapter: + parameters = copy(parameters) + # Remove the single row of data from the parameters to reconstruct the `time_series_dataset` + data = parameters.pop("data_sample") + time_series_dataset = PatchTimeSeriesDataSet.from_parameters(parameters, data) + + backbone_kwargs["loss"] = loss_fn + + if metrics is not None and not isinstance(metrics, list): + metrics = [metrics] + backbone_kwargs["logging_metrics"] = metrics + + backbone_kwargs = backbone_kwargs or {} + + adapter = cls(task.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs)) + + # Attach the required collate function + adapter.set_state(CollateFn(partial(PyTorchForecastingAdapter._collate_fn, time_series_dataset._collate_fn))) + + return adapter + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return self.backbone.training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return self.backbone.validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> None: + raise NotImplementedError( + "Backbones provided by PyTorch Forecasting don't support testing. Use validation instead." + ) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + result = dict(self.backbone(batch[DefaultDataKeys.INPUT])) + result[DefaultDataKeys.INPUT] = default_uncollate(batch[DefaultDataKeys.INPUT]) + return default_uncollate(result) + + def training_epoch_end(self, outputs) -> None: + self.backbone.training_epoch_end(outputs) + + def validation_epoch_end(self, outputs) -> None: + self.backbone.validation_epoch_end(outputs) diff --git a/flash/core/integrations/pytorch_forecasting/backbones.py b/flash/core/integrations/pytorch_forecasting/backbones.py new file mode 100644 index 0000000000..baba87ac50 --- /dev/null +++ b/flash/core/integrations/pytorch_forecasting/backbones.py @@ -0,0 +1,49 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools + +from flash.core.integrations.pytorch_forecasting.adapter import PyTorchForecastingAdapter +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _FORECASTING_AVAILABLE +from flash.core.utilities.providers import _PYTORCH_FORECASTING + +if _FORECASTING_AVAILABLE: + from pytorch_forecasting import ( + DecoderMLP, + DeepAR, + NBeats, + RecurrentNetwork, + TemporalFusionTransformer, + TimeSeriesDataSet, + ) + + +PYTORCH_FORECASTING_BACKBONES = FlashRegistry("backbones") + + +if _FORECASTING_AVAILABLE: + + def load_torch_forecasting(model, time_series_dataset: TimeSeriesDataSet, **kwargs): + return model.from_dataset(time_series_dataset, **kwargs) + + for model, name in zip( + [TemporalFusionTransformer, NBeats, RecurrentNetwork, DeepAR, DecoderMLP], + ["temporal_fusion_transformer", "n_beats", "recurrent_network", "deep_ar", "decoder_mlp"], + ): + PYTORCH_FORECASTING_BACKBONES( + functools.partial(load_torch_forecasting, model), + name=name, + providers=_PYTORCH_FORECASTING, + adapter=PyTorchForecastingAdapter, + ) diff --git a/flash/core/integrations/pytorch_forecasting/transforms.py b/flash/core/integrations/pytorch_forecasting/transforms.py new file mode 100644 index 0000000000..ce193d0bcc --- /dev/null +++ b/flash/core/integrations/pytorch_forecasting/transforms.py @@ -0,0 +1,30 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Tuple + +from torch.utils.data._utils.collate import default_collate + +from flash.core.data.data_source import DefaultDataKeys + + +def convert_predictions(predictions: List[Dict[str, Any]]) -> Tuple[Dict[str, Any], List]: + # Flatten list if batches were used + if all(isinstance(fl, list) for fl in predictions): + unrolled_predictions = [] + for prediction_batch in predictions: + unrolled_predictions.extend(prediction_batch) + predictions = unrolled_predictions + result = default_collate(predictions) + inputs = result.pop(DefaultDataKeys.INPUT) + return result, inputs diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index e420117e35..53fdc66e26 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -70,6 +70,7 @@ def _compare_version(package: str, op, version) -> bool: _PANDAS_AVAILABLE = _module_available("pandas") _SKLEARN_AVAILABLE = _module_available("sklearn") _TABNET_AVAILABLE = _module_available("pytorch_tabnet") +_FORECASTING_AVAILABLE = _module_available("pytorch_forecasting") _KORNIA_AVAILABLE = _module_available("kornia") _COCO_AVAILABLE = _module_available("pycocotools") _TIMM_AVAILABLE = _module_available("timm") @@ -126,7 +127,7 @@ class Image: _DATASETS_AVAILABLE, ] ) -_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE +_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE and _FORECASTING_AVAILABLE _VIDEO_AVAILABLE = _TORCHVISION_AVAILABLE and _PIL_AVAILABLE and _PYTORCHVIDEO_AVAILABLE and _KORNIA_AVAILABLE _IMAGE_AVAILABLE = all( [ diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index a5bb749246..d536154d73 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -46,3 +46,4 @@ def __str__(self): _OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML") _PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo") _VISSL = Provider("Facebook Research/vissl", "https://github.com/facebookresearch/vissl") +_PYTORCH_FORECASTING = Provider("jdb78/PyTorch-Forecasting", "https://github.com/jdb78/pytorch-forecasting") diff --git a/flash/tabular/__init__.py b/flash/tabular/__init__.py index 22698efc99..466ca1fc0f 100644 --- a/flash/tabular/__init__.py +++ b/flash/tabular/__init__.py @@ -1,3 +1,8 @@ from flash.tabular.classification import TabularClassificationData, TabularClassifier # noqa: F401 from flash.tabular.data import TabularData # noqa: F401 +from flash.tabular.forecasting.data import ( # noqa: F401 + TabularForecastingData, + TabularForecastingDataFrameDataSource, + TabularForecastingPreprocess, +) from flash.tabular.regression import TabularRegressionData # noqa: F401 diff --git a/flash/tabular/classification/cli.py b/flash/tabular/classification/cli.py index 6787b1c8d6..fe3f8a9ae3 100644 --- a/flash/tabular/classification/cli.py +++ b/flash/tabular/classification/cli.py @@ -14,7 +14,8 @@ from flash.core.data.utils import download_data from flash.core.utilities.flash_cli import FlashCLI -from flash.tabular import TabularClassificationData, TabularClassifier +from flash.tabular.classification.data import TabularClassificationData +from flash.tabular.classification.model import TabularClassifier __all__ = ["tabular_classification"] diff --git a/flash/tabular/forecasting/__init__.py b/flash/tabular/forecasting/__init__.py new file mode 100644 index 0000000000..fb978e7a62 --- /dev/null +++ b/flash/tabular/forecasting/__init__.py @@ -0,0 +1,2 @@ +from flash.tabular.forecasting.data import TabularForecastingData # noqa: F401 +from flash.tabular.forecasting.model import TabularForecaster # noqa: F401 diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py new file mode 100644 index 0000000000..07c49b6d5c --- /dev/null +++ b/flash/tabular/forecasting/data.py @@ -0,0 +1,220 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from copy import copy +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Mapping, Optional, Union + +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.process import Deserializer, Preprocess +from flash.core.data.properties import ProcessState +from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE, requires + +if _PANDAS_AVAILABLE: + from pandas.core.frame import DataFrame +else: + DataFrame = object + +if _FORECASTING_AVAILABLE: + from pytorch_forecasting import TimeSeriesDataSet + + +@dataclass(unsafe_hash=True, frozen=True) +class TimeSeriesDataSetParametersState(ProcessState): + """A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, a mapping from class index to + label.""" + + time_series_dataset_parameters: Optional[Dict[str, Any]] + + +class TabularForecastingDataFrameDataSource(DataSource[DataFrame]): + @requires("tabular") + def __init__( + self, + time_idx: str, + target: Union[str, List[str]], + group_ids: List[str], + parameters: Optional[Dict[str, Any]] = None, + **data_source_kwargs: Any, + ): + super().__init__() + self.time_idx = time_idx + self.target = target + self.group_ids = group_ids + self.data_source_kwargs = data_source_kwargs + + self.set_state(TimeSeriesDataSetParametersState(parameters)) + + def load_data(self, data: DataFrame, dataset: Optional[Any] = None): + if self.training: + time_series_dataset = TimeSeriesDataSet( + data, time_idx=self.time_idx, group_ids=self.group_ids, target=self.target, **self.data_source_kwargs + ) + parameters = time_series_dataset.get_parameters() + + # Add some sample data so that we can recreate the `TimeSeriesDataSet` later on + parameters["data_sample"] = data.iloc[[0]] + + self.set_state(TimeSeriesDataSetParametersState(parameters)) + dataset.parameters = parameters + else: + parameters = copy(self.get_state(TimeSeriesDataSetParametersState).time_series_dataset_parameters) + if parameters is None: + raise MisconfigurationException( + "Loading data for evaluation or inference requires parameters from the train data. Either " + "construct the train data at the same time as evaluation and inference or provide the train " + "`datamodule.parameters` to `from_data_frame` in the `parameters` argument." + ) + parameters.pop("data_sample") + time_series_dataset = TimeSeriesDataSet.from_parameters( + parameters, + data, + predict=True, + stop_randomization=True, + ) + dataset.time_series_dataset = time_series_dataset + return time_series_dataset + + def load_sample(self, sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any: + return {DefaultDataKeys.INPUT: sample[0], DefaultDataKeys.TARGET: sample[1]} + + +class TabularForecastingPreprocess(Preprocess): + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + deserializer: Optional[Deserializer] = None, + **data_source_kwargs: Any, + ): + self.data_source_kwargs = data_source_kwargs + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.DATAFRAME: TabularForecastingDataFrameDataSource(**data_source_kwargs), + }, + deserializer=deserializer, + default_data_source=DefaultDataSources.DATAFRAME, + ) + + def get_state_dict(self, strict: bool = False) -> Dict[str, Any]: + return {**self.transforms, **self.data_source_kwargs} + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = True) -> "Preprocess": + return cls(**state_dict) + + +class TabularForecastingData(DataModule): + """Data module for tabular tasks.""" + + preprocess_cls = TabularForecastingPreprocess + + @property + def parameters(self) -> Optional[Dict[str, Any]]: + return getattr(self.train_dataset, "parameters", None) + + @classmethod + def from_data_frame( + cls, + time_idx: Optional[str] = None, + target: Optional[Union[str, List[str]]] = None, + group_ids: Optional[List[str]] = None, + parameters: Optional[Dict[str, Any]] = None, + train_data_frame: Optional[DataFrame] = None, + val_data_frame: Optional[DataFrame] = None, + test_data_frame: Optional[DataFrame] = None, + predict_data_frame: Optional[DataFrame] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ): + """Creates a :class:`~flash.tabular.data.TabularClassificationData` object from the given data frames. + + Args: + group_ids: + target: + time_idx: + train_data_frame: The pandas ``DataFrame`` containing the training data. + val_data_frame: The pandas ``DataFrame`` containing the validation data. + test_data_frame: The pandas ``DataFrame`` containing the testing data. + predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting. + train_transform: The dictionary of transforms to use during training which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + val_transform: The dictionary of transforms to use during validation which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + test_transform: The dictionary of transforms to use during testing which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + predict_transform: The dictionary of transforms to use during predicting which maps + :class:`~flash.core.data.process.Preprocess` hook names to callable transforms. + data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the + :class:`~flash.core.data.data_module.DataModule`. + preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the + :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls`` + will be constructed and used. + val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used + if ``preprocess = None``. + + Returns: + The constructed data module. + + Examples:: + + data_module = TabularClassificationData.from_data_frame( + "categorical_input", + "numerical_input", + "target", + train_data_frame=train_data, + ) + """ + + return cls.from_data_source( + time_idx=time_idx, + target=target, + group_ids=group_ids, + parameters=parameters, + data_source=DefaultDataSources.DATAFRAME, + train_data=train_data_frame, + val_data=val_data_frame, + test_data=test_data_frame, + predict_data=predict_data_frame, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) diff --git a/flash/tabular/forecasting/model.py b/flash/tabular/forecasting/model.py new file mode 100644 index 0000000000..eb46ef820d --- /dev/null +++ b/flash/tabular/forecasting/model.py @@ -0,0 +1,74 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, List, Optional, Union + +import torchmetrics +from pytorch_lightning import LightningModule + +from flash.core.adapter import AdapterTask +from flash.core.integrations.pytorch_forecasting.adapter import PyTorchForecastingAdapter +from flash.core.integrations.pytorch_forecasting.backbones import PYTORCH_FORECASTING_BACKBONES +from flash.core.registry import FlashRegistry +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE + + +class TabularForecaster(AdapterTask): + backbones: FlashRegistry = FlashRegistry("backbones") + PYTORCH_FORECASTING_BACKBONES + + def __init__( + self, + parameters: Dict[str, Any], + backbone: str, + loss_fn: Optional[Callable] = None, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: Union[torchmetrics.Metric, List[torchmetrics.Metric]] = None, + learning_rate: float = 4e-3, + **backbone_kwargs + ): + + self.save_hyperparameters() + + metadata = self.backbones.get(backbone, with_metadata=True) + adapter = metadata["metadata"]["adapter"].from_task( + self, + parameters=parameters, + backbone=backbone, + loss_fn=loss_fn, + metrics=metrics, + **backbone_kwargs, + ) + + super().__init__( + adapter, + learning_rate=learning_rate, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + ) + + @property + def pytorch_forecasting_model(self) -> LightningModule: + """This property provides access to the ``LightningModule`` object that is wrapped by Flash for backbones + provided by PyTorch Forecasting. + + This can be used with + :func:`~flash.core.integrations.pytorch_forecasting.transforms.convert_predictions` to access the visualization + features built in to PyTorch Forecasting. + """ + if not isinstance(self.adapter, PyTorchForecastingAdapter): + raise AttributeError( + "The `pytorch_forecasting_model` attribute can only be accessed for backbones provided by PyTorch " + "Forecasting." + ) + return self.adapter.backbone diff --git a/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py b/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py new file mode 100644 index 0000000000..ec62cb2643 --- /dev/null +++ b/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py @@ -0,0 +1,74 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.core.integrations.pytorch_forecasting import convert_predictions +from flash.core.utilities.imports import example_requires +from flash.tabular.forecasting import TabularForecaster, TabularForecastingData + +example_requires(["tabular", "matplotlib"]) + +import matplotlib.pyplot as plt # noqa: E402 +import pandas as pd # noqa: E402 +from pytorch_forecasting.data import NaNLabelEncoder # noqa: E402 +from pytorch_forecasting.data.examples import generate_ar_data # noqa: E402 + +# Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html +# 1. Create the DataModule +data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42) +data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") + +max_prediction_length = 20 + +training_cutoff = data["time_idx"].max() - max_prediction_length + +datamodule = TabularForecastingData.from_data_frame( + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + # only unknown variable is "value" - and N-Beats can also not take any additional variables + time_varying_unknown_reals=["value"], + max_encoder_length=60, + max_prediction_length=max_prediction_length, + train_data_frame=data[lambda x: x.time_idx <= training_cutoff], + val_data_frame=data, + batch_size=32, +) + +# 2. Build the task +model = TabularForecaster(datamodule.parameters, backbone="n_beats", widths=[32, 512], backcast_loss_ratio=0.1) + +# 3. Create the trainer and train the model +trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count(), gradient_clip_val=0.01) +trainer.fit(model, datamodule=datamodule) + +# 4. Generate predictions +predictions = model.predict(data) +print(predictions) + +# Convert predictions +predictions, inputs = convert_predictions(predictions) + +# Plot predictions +for idx in range(10): # plot 10 examples + model.pytorch_forecasting_model.plot_prediction(inputs, predictions, idx=idx, add_loss_to_title=True) + +# Plot interpretation +for idx in range(10): # plot 10 examples + model.pytorch_forecasting_model.plot_interpretation(inputs, predictions, idx=idx) + +# Show the plots! +plt.show() diff --git a/flash_examples/tabular_forecasting.py b/flash_examples/tabular_forecasting.py new file mode 100644 index 0000000000..836f01fe64 --- /dev/null +++ b/flash_examples/tabular_forecasting.py @@ -0,0 +1,61 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.core.utilities.imports import example_requires +from flash.tabular.forecasting import TabularForecaster, TabularForecastingData + +example_requires("tabular") + +import pandas as pd # noqa: E402 +from pytorch_forecasting.data import NaNLabelEncoder # noqa: E402 +from pytorch_forecasting.data.examples import generate_ar_data # noqa: E402 + +# Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html +# 1. Create the DataModule +data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42) +data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") + +max_prediction_length = 20 + +training_cutoff = data["time_idx"].max() - max_prediction_length + +datamodule = TabularForecastingData.from_data_frame( + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + # only unknown variable is "value" - and N-Beats can also not take any additional variables + time_varying_unknown_reals=["value"], + max_encoder_length=60, + max_prediction_length=max_prediction_length, + train_data_frame=data[lambda x: x.time_idx <= training_cutoff], + val_data_frame=data, + batch_size=32, +) + +# 2. Build the task +model = TabularForecaster(datamodule.parameters, backbone="n_beats", widths=[32, 512], backcast_loss_ratio=0.1) + +# 3. Create the trainer and train the model +trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count(), gradient_clip_val=0.01) +trainer.fit(model, datamodule=datamodule) + +# 4. Generate predictions +predictions = model.predict(data) +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("tabular_forecasting_model.pt") diff --git a/requirements/datatype_tabular.txt b/requirements/datatype_tabular.txt index 521a5930a3..bbd5720096 100644 --- a/requirements/datatype_tabular.txt +++ b/requirements/datatype_tabular.txt @@ -1,2 +1,3 @@ pytorch-tabnet==3.1 scikit-learn +pytorch-forecasting diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py index b7565903dc..d5853e19a9 100644 --- a/tests/core/data/test_data_pipeline.py +++ b/tests/core/data/test_data_pipeline.py @@ -59,18 +59,9 @@ def test_str(): state.set_state(ProcessState()) assert str(state) == ( - "DataPipelineState(initialized=False, " - "state={<class 'flash.core.data.properties.ProcessState'>: ProcessState()})" + "DataPipelineState(state={<class 'flash.core.data.properties.ProcessState'>: ProcessState()})" ) - @staticmethod - def test_warning(): - state = DataPipelineState() - state._initialized = True - - with pytest.warns(UserWarning, match="data pipeline has already been initialized"): - state.set_state(ProcessState()) - @staticmethod def test_get_state(): state = DataPipelineState() diff --git a/tests/tabular/classification/test_data.py b/tests/tabular/classification/test_data.py index b1e9ef3f25..4f1fc632d8 100644 --- a/tests/tabular/classification/test_data.py +++ b/tests/tabular/classification/test_data.py @@ -89,30 +89,6 @@ def test_embedding_sizes(): assert es == [(100_000, 17), (1_000_000, 31)] -@pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") -def test_tabular_data(tmpdir): - train_data_frame = TEST_DF_1.copy() - val_data_frame = TEST_DF_2.copy() - test_data_frame = TEST_DF_2.copy() - dm = TabularClassificationData.from_data_frame( - categorical_fields=["category"], - numerical_fields=["scalar_a", "scalar_b"], - target_fields="label", - train_data_frame=train_data_frame, - val_data_frame=val_data_frame, - test_data_frame=test_data_frame, - num_workers=0, - batch_size=1, - ) - for dl in [dm.train_dataloader(), dm.val_dataloader(), dm.test_dataloader()]: - data = next(iter(dl)) - (cat, num) = data[DefaultDataKeys.INPUT] - target = data[DefaultDataKeys.TARGET] - assert cat.shape == (1, 1) - assert num.shape == (1, 2) - assert target.shape == (1,) - - @pytest.mark.skipif(not _PANDAS_AVAILABLE, reason="pandas is required") def test_categorical_target(tmpdir): train_data_frame = TEST_DF_1.copy() diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index e7ee5e9f5d..2efe7c316e 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -21,7 +21,8 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.utilities.imports import _TABULAR_AVAILABLE -from flash.tabular import TabularClassificationData, TabularClassifier +from flash.tabular.classification.data import TabularClassificationData +from flash.tabular.classification.model import TabularClassifier from tests.helpers.utils import _SERVE_TESTING, _TABULAR_TESTING # ======== Mock functions ======== diff --git a/tests/tabular/forecasting/__init__.py b/tests/tabular/forecasting/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/tabular/forecasting/test_data.py b/tests/tabular/forecasting/test_data.py new file mode 100644 index 0000000000..1c10291d5c --- /dev/null +++ b/tests/tabular/forecasting/test_data.py @@ -0,0 +1,94 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import MagicMock, patch + +import pytest +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from flash.tabular.forecasting import TabularForecastingData +from tests.helpers.utils import _TABULAR_TESTING + + +@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") +@patch("flash.tabular.forecasting.data.TimeSeriesDataSet") +def test_from_data_frame_time_series_data_set_single_call(patch_time_series_data_set): + """Tests that ``TabularForecastingData.from_data_frame`` calls ``TimeSeriesDataSet`` with the expected + parameters when called once with data for all stages.""" + patch_time_series_data_set.return_value.get_parameters.return_value = {"test": None} + + train_data = MagicMock() + val_data = MagicMock() + + TabularForecastingData.from_data_frame( + "time_idx", + "target", + ["series"], + train_data_frame=train_data, + val_data_frame=val_data, + additional_kwarg="test", + ) + + patch_time_series_data_set.assert_called_once_with( + train_data, time_idx="time_idx", group_ids=["series"], target="target", additional_kwarg="test" + ) + + patch_time_series_data_set.from_parameters.assert_called_once_with( + {"test": None}, val_data, predict=True, stop_randomization=True + ) + + +@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") +@patch("flash.tabular.forecasting.data.TimeSeriesDataSet") +def test_from_data_frame_time_series_data_set_multi_call(patch_time_series_data_set): + """Tests that ``TabularForecastingData.from_data_frame`` calls ``TimeSeriesDataSet`` with the expected + parameters when called separately for each stage.""" + patch_time_series_data_set.return_value.get_parameters.return_value = {"test": None} + + train_data = MagicMock() + val_data = MagicMock() + + train_datamodule = TabularForecastingData.from_data_frame( + "time_idx", + "target", + ["series"], + train_data_frame=train_data, + additional_kwarg="test", + ) + + TabularForecastingData.from_data_frame( + val_data_frame=val_data, + parameters=train_datamodule.parameters, + ) + + patch_time_series_data_set.assert_called_once_with( + train_data, time_idx="time_idx", group_ids=["series"], target="target", additional_kwarg="test" + ) + + patch_time_series_data_set.from_parameters.assert_called_once_with( + {"test": None}, val_data, predict=True, stop_randomization=True + ) + + +@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") +def test_from_data_frame_misconfiguration(): + """Tests that a ``MisconfigurationException`` is raised when ``TabularForecastingData`` is constructed without + parameters.""" + with pytest.raises(MisconfigurationException, match="evaluation or inference requires parameters"): + TabularForecastingData.from_data_frame( + "time_idx", + "target", + ["series"], + val_data_frame=MagicMock(), + additional_kwarg="test", + ) diff --git a/tests/tabular/forecasting/test_model.py b/tests/tabular/forecasting/test_model.py new file mode 100644 index 0000000000..bebf8477bc --- /dev/null +++ b/tests/tabular/forecasting/test_model.py @@ -0,0 +1,80 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +import flash +from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_AVAILABLE +from flash.tabular.forecasting import TabularForecaster, TabularForecastingData +from tests.helpers.utils import _TABULAR_TESTING + +if _TABULAR_AVAILABLE: + from pytorch_forecasting.data import NaNLabelEncoder + from pytorch_forecasting.data.examples import generate_ar_data + +if _PANDAS_AVAILABLE: + import pandas as pd + + +@pytest.fixture +def sample_data(): + data = generate_ar_data(seasonality=10.0, timesteps=100, n_series=2, seed=42) + data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") + max_prediction_length = 20 + training_cutoff = data["time_idx"].max() - max_prediction_length + return data, training_cutoff, max_prediction_length + + +@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") +def test_fast_dev_run_smoke(sample_data): + """Test that fast dev run works with the NBeats example data.""" + data, training_cutoff, max_prediction_length = sample_data + datamodule = TabularForecastingData.from_data_frame( + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + time_varying_unknown_reals=["value"], + max_encoder_length=60, + max_prediction_length=max_prediction_length, + train_data_frame=data[lambda x: x.time_idx <= training_cutoff], + val_data_frame=data, + ) + + model = TabularForecaster(datamodule.parameters, backbone="n_beats", widths=[32, 512], backcast_loss_ratio=0.1) + + trainer = flash.Trainer(max_epochs=1, fast_dev_run=True, gradient_clip_val=0.01) + trainer.fit(model, datamodule=datamodule) + + +@pytest.mark.skipif(not _TABULAR_TESTING, reason="Tabular libraries aren't installed.") +def test_testing_raises(sample_data): + """Tests that ``NotImplementedError`` is raised when attempting to perform a test pass.""" + data, training_cutoff, max_prediction_length = sample_data + datamodule = TabularForecastingData.from_data_frame( + time_idx="time_idx", + target="value", + categorical_encoders={"series": NaNLabelEncoder().fit(data.series)}, + group_ids=["series"], + time_varying_unknown_reals=["value"], + max_encoder_length=60, + max_prediction_length=max_prediction_length, + train_data_frame=data[lambda x: x.time_idx <= training_cutoff], + test_data_frame=data, + ) + + model = TabularForecaster(datamodule.parameters, backbone="n_beats", widths=[32, 512], backcast_loss_ratio=0.1) + trainer = flash.Trainer(max_epochs=1, fast_dev_run=True, gradient_clip_val=0.01) + + with pytest.raises(NotImplementedError, match="Backbones provided by PyTorch Forecasting don't support testing."): + trainer.test(model, datamodule)