Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Add TabularForecaster task based on PyTorch Forecasting (#647)
Browse files Browse the repository at this point in the history
* Revert "Added TabularRegressionData extending TabularData (#574)"

This reverts commit c318e4a

* added DataModule, PreProcess, DataSource for TabularForecasting

* added TABULAR_FORECASTING_BACKBONES

* [WIP] added model.py in tabular forecasting

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Updates

* Updates

* Try fix

* Updates

* Rename to TabularClassificationData

* Updates

* Fix embedding sizes

* Fixes and add example

* Updates

* Switch to an adapter

* Small fixes

* Add inference error

* Add inference and refactor

* Add interpretation example

* Fix broken tests

* Small fixes and add some tests

* Updates

* Update CHANGELOG.md

* Add provider

* Update flash/core/integrations/pytorch_forecasting/adapter.py

Co-authored-by: Jirka Borovec <[email protected]>

* Update flash/core/integrations/pytorch_forecasting/adapter.py

Co-authored-by: Jirka Borovec <[email protected]>

* Update on comments

Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
5 people authored Oct 29, 2021
1 parent ddd4c94 commit 9927853
Show file tree
Hide file tree
Showing 22 changed files with 823 additions and 53 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a `TabularForecaster` task based on PyTorch Forecasting ([#647](https://github.com/PyTorchLightning/lightning-flash/pull/647))

### Changed

### Fixed
Expand Down
19 changes: 3 additions & 16 deletions flash/core/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import torch
from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from torch.utils.data import DataLoader, IterableDataset
Expand All @@ -42,29 +41,19 @@ class DataPipelineState:

def __init__(self):
self._state: Dict[Type[ProcessState], ProcessState] = {}
self._initialized = False

def set_state(self, state: ProcessState):
"""Add the given :class:`.ProcessState` to the :class:`.DataPipelineState`."""

if not self._initialized:
self._state[type(state)] = state
else:
rank_zero_warn(
f"Attempted to add a state ({state}) after the data pipeline has already been initialized. This will"
" only have an effect when a new data pipeline is created.",
UserWarning,
)
self._state[type(state)] = state

def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]:
"""Get the :class:`.ProcessState` of the given type from the :class:`.DataPipelineState`."""

if state_type in self._state:
return self._state[state_type]
return None
return self._state.get(state_type, None)

def __str__(self) -> str:
return f"{self.__class__.__name__}(initialized={self._initialized}, state={self._state})"
return f"{self.__class__.__name__}(state={self._state})"


class DataPipeline:
Expand Down Expand Up @@ -113,13 +102,11 @@ def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) ->
:class:`.Postprocess`, and :class:`.Serializer`. Once this has been called, any attempt to add new state will
give a warning."""
data_pipeline_state = data_pipeline_state or DataPipelineState()
data_pipeline_state._initialized = False
if self.data_source is not None:
self.data_source.attach_data_pipeline_state(data_pipeline_state)
self._preprocess_pipeline.attach_data_pipeline_state(data_pipeline_state)
self._postprocess_pipeline.attach_data_pipeline_state(data_pipeline_state)
self._serializer.attach_data_pipeline_state(data_pipeline_state)
data_pipeline_state._initialized = True # TODO: Not sure we need this
return data_pipeline_state

@property
Expand Down
1 change: 1 addition & 0 deletions flash/core/integrations/pytorch_forecasting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from flash.core.integrations.pytorch_forecasting.transforms import convert_predictions # noqa: F401
119 changes: 119 additions & 0 deletions flash/core/integrations/pytorch_forecasting/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import copy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torchmetrics

from flash import Task
from flash.core.adapter import Adapter
from flash.core.data.batch import default_uncollate
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.states import CollateFn
from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE

if _PANDAS_AVAILABLE:
from pandas import DataFrame
else:
DataFrame = object

if _FORECASTING_AVAILABLE:
from pytorch_forecasting import TimeSeriesDataSet
else:
TimeSeriesDataSet = object


class PatchTimeSeriesDataSet(TimeSeriesDataSet):
    """A ``TimeSeriesDataSet`` with index construction and data validation / conversion stubbed out.

    Both overridden hooks hand back empty containers instead of processing ``data``. This is a hack that lets
    the dataset be created from a single row of data when all that is needed is to instantiate the model.
    """

    def _construct_index(self, data: DataFrame, predict_mode: bool) -> DataFrame:
        # ``data`` and ``predict_mode`` are deliberately ignored: no index is built.
        empty_index = DataFrame()
        return empty_index

    def _data_to_tensors(self, data: DataFrame) -> Dict[str, torch.Tensor]:
        # ``data`` is deliberately ignored: nothing is validated or converted to tensors.
        return dict()


class PyTorchForecastingAdapter(Adapter):
    """The ``PyTorchForecastingAdapter`` is an :class:`~flash.core.adapter.Adapter` for integrating with PyTorch
    Forecasting.

    The adapter wraps a backbone and converts between batches given as dictionaries keyed by
    ``DefaultDataKeys.INPUT`` / ``DefaultDataKeys.TARGET`` and the ``(input, target)`` tuples that are handed to
    the backbone's own step methods.
    """

    def __init__(self, backbone):
        super().__init__()

        self.backbone = backbone

    @staticmethod
    def _collate_fn(collate_fn, samples):
        """Collate dictionary samples with a collate function that works on ``(input, target)`` tuples.

        Args:
            collate_fn: The wrapped collate function; it receives a list of ``(input, target)`` tuples and its
                result is indexed at positions ``0`` and ``1``.
            samples: Samples indexable by ``DefaultDataKeys.INPUT`` and ``DefaultDataKeys.TARGET``.

        Returns:
            A dictionary with the collated input under ``DefaultDataKeys.INPUT`` and the collated target under
            ``DefaultDataKeys.TARGET``.
        """
        samples = [(sample[DefaultDataKeys.INPUT], sample[DefaultDataKeys.TARGET]) for sample in samples]
        batch = collate_fn(samples)
        return {DefaultDataKeys.INPUT: batch[0], DefaultDataKeys.TARGET: batch[1]}

    @classmethod
    def from_task(
        cls,
        task: Task,
        parameters: Dict[str, Any],
        backbone: str,
        loss_fn: Optional[Callable] = None,
        metrics: Optional[Union[torchmetrics.Metric, List[torchmetrics.Metric]]] = None,
        **backbone_kwargs,
    ) -> Adapter:
        """Create the adapter, its backbone and the matching collate function for the given task.

        Args:
            task: The task whose ``backbones`` registry is used to look up ``backbone``.
            parameters: Parameters passed to ``PatchTimeSeriesDataSet.from_parameters``. Must contain a
                ``"data_sample"`` entry, which is removed from a copy and used as the data. The caller's
                dictionary is not modified.
            backbone: Name of the backbone to fetch from ``task.backbones``.
            loss_fn: Forwarded to the backbone as ``loss`` (also when it is ``None``).
            metrics: A single metric or a list of metrics, forwarded to the backbone as ``logging_metrics``.
                A single metric is wrapped in a list first.
            **backbone_kwargs: Any further keyword arguments for the backbone.

        Returns:
            The adapter wrapping the instantiated backbone, with a ``CollateFn`` state attached.

        Raises:
            KeyError: If ``parameters`` has no ``"data_sample"`` entry.
        """
        # Shallow copy so that popping ``data_sample`` does not alter the dictionary owned by the caller.
        parameters = copy(parameters)
        # Remove the single row of data from the parameters to reconstruct the `time_series_dataset`
        data = parameters.pop("data_sample")
        time_series_dataset = PatchTimeSeriesDataSet.from_parameters(parameters, data)

        backbone_kwargs["loss"] = loss_fn

        if metrics is not None and not isinstance(metrics, list):
            metrics = [metrics]
        backbone_kwargs["logging_metrics"] = metrics

        adapter = cls(task.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs))

        # Attach the required collate function
        adapter.set_state(CollateFn(partial(PyTorchForecastingAdapter._collate_fn, time_series_dataset._collate_fn)))

        return adapter

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        """Run the backbone's training step on the batch converted to an ``(input, target)`` tuple."""
        batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
        return self.backbone.training_step(batch, batch_idx)

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        """Run the backbone's validation step on the batch converted to an ``(input, target)`` tuple."""
        batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
        return self.backbone.validation_step(batch, batch_idx)

    def test_step(self, batch: Any, batch_idx: int) -> None:
        """Testing is not supported; always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Backbones provided by PyTorch Forecasting don't support testing. Use validation instead."
        )

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Call the backbone on the batch input and return the un-collated result.

        The un-collated input is stored in the result under ``DefaultDataKeys.INPUT`` before the whole result is
        un-collated, so the inputs travel alongside the backbone output.
        """
        result = dict(self.backbone(batch[DefaultDataKeys.INPUT]))
        result[DefaultDataKeys.INPUT] = default_uncollate(batch[DefaultDataKeys.INPUT])
        return default_uncollate(result)

    def training_epoch_end(self, outputs) -> None:
        """Forward the epoch outputs to the backbone's ``training_epoch_end``."""
        self.backbone.training_epoch_end(outputs)

    def validation_epoch_end(self, outputs) -> None:
        """Forward the epoch outputs to the backbone's ``validation_epoch_end``."""
        self.backbone.validation_epoch_end(outputs)
49 changes: 49 additions & 0 deletions flash/core/integrations/pytorch_forecasting/backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools

from flash.core.integrations.pytorch_forecasting.adapter import PyTorchForecastingAdapter
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _FORECASTING_AVAILABLE
from flash.core.utilities.providers import _PYTORCH_FORECASTING

if _FORECASTING_AVAILABLE:
from pytorch_forecasting import (
DecoderMLP,
DeepAR,
NBeats,
RecurrentNetwork,
TemporalFusionTransformer,
TimeSeriesDataSet,
)


PYTORCH_FORECASTING_BACKBONES = FlashRegistry("backbones")


if _FORECASTING_AVAILABLE:

    def load_torch_forecasting(model, time_series_dataset: TimeSeriesDataSet, **kwargs):
        """Build ``model`` through its ``from_dataset`` constructor.

        ``time_series_dataset`` and any extra keyword arguments are forwarded unchanged.
        """
        return model.from_dataset(time_series_dataset, **kwargs)

    # Register each backbone under its name. The model class is bound up front, so the registered
    # callable only needs the dataset and the keyword arguments.
    for model, name in (
        (TemporalFusionTransformer, "temporal_fusion_transformer"),
        (NBeats, "n_beats"),
        (RecurrentNetwork, "recurrent_network"),
        (DeepAR, "deep_ar"),
        (DecoderMLP, "decoder_mlp"),
    ):
        PYTORCH_FORECASTING_BACKBONES(
            functools.partial(load_torch_forecasting, model),
            name=name,
            providers=_PYTORCH_FORECASTING,
            adapter=PyTorchForecastingAdapter,
        )
30 changes: 30 additions & 0 deletions flash/core/integrations/pytorch_forecasting/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Tuple

from torch.utils.data._utils.collate import default_collate

from flash.core.data.data_source import DefaultDataKeys


def convert_predictions(predictions: List[Dict[str, Any]]) -> Tuple[Dict[str, Any], List]:
    """Collate predictions into a single result and split the inputs out of it.

    Args:
        predictions: Either a flat list of per-sample predictions, or a list of batches where every batch is
            itself a list of per-sample predictions.

    Returns:
        A tuple ``(result, inputs)``: the collated predictions with the ``DefaultDataKeys.INPUT`` entry removed,
        and that removed entry.
    """
    is_batched = all(isinstance(item, list) for item in predictions)
    if is_batched:
        # Batches were used: flatten them into one list of samples before collating.
        predictions = [sample for batch in predictions for sample in batch]
    collated = default_collate(predictions)
    inputs = collated.pop(DefaultDataKeys.INPUT)
    return collated, inputs
3 changes: 2 additions & 1 deletion flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def _compare_version(package: str, op, version) -> bool:
_PANDAS_AVAILABLE = _module_available("pandas")
_SKLEARN_AVAILABLE = _module_available("sklearn")
_TABNET_AVAILABLE = _module_available("pytorch_tabnet")
_FORECASTING_AVAILABLE = _module_available("pytorch_forecasting")
_KORNIA_AVAILABLE = _module_available("kornia")
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
Expand Down Expand Up @@ -126,7 +127,7 @@ class Image:
_DATASETS_AVAILABLE,
]
)
_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE
_TABULAR_AVAILABLE = _TABNET_AVAILABLE and _PANDAS_AVAILABLE and _FORECASTING_AVAILABLE
_VIDEO_AVAILABLE = _TORCHVISION_AVAILABLE and _PIL_AVAILABLE and _PYTORCHVIDEO_AVAILABLE and _KORNIA_AVAILABLE
_IMAGE_AVAILABLE = all(
[
Expand Down
1 change: 1 addition & 0 deletions flash/core/utilities/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ def __str__(self):
_OPEN3D_ML = Provider("Intelligent Systems Lab Org/Open3D-ML", "https://github.com/isl-org/Open3D-ML")
_PYTORCHVIDEO = Provider("Facebook Research/PyTorchVideo", "https://github.com/facebookresearch/pytorchvideo")
_VISSL = Provider("Facebook Research/vissl", "https://github.com/facebookresearch/vissl")
_PYTORCH_FORECASTING = Provider("jdb78/PyTorch-Forecasting", "https://github.com/jdb78/pytorch-forecasting")
5 changes: 5 additions & 0 deletions flash/tabular/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from flash.tabular.classification import TabularClassificationData, TabularClassifier # noqa: F401
from flash.tabular.data import TabularData # noqa: F401
from flash.tabular.forecasting.data import ( # noqa: F401
TabularForecastingData,
TabularForecastingDataFrameDataSource,
TabularForecastingPreprocess,
)
from flash.tabular.regression import TabularRegressionData # noqa: F401
3 changes: 2 additions & 1 deletion flash/tabular/classification/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from flash.core.data.utils import download_data
from flash.core.utilities.flash_cli import FlashCLI
from flash.tabular import TabularClassificationData, TabularClassifier
from flash.tabular.classification.data import TabularClassificationData
from flash.tabular.classification.model import TabularClassifier

__all__ = ["tabular_classification"]

Expand Down
2 changes: 2 additions & 0 deletions flash/tabular/forecasting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from flash.tabular.forecasting.data import TabularForecastingData # noqa: F401
from flash.tabular.forecasting.model import TabularForecaster # noqa: F401
Loading

0 comments on commit 9927853

Please sign in to comment.