Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Tabular regression task and example #892

Merged
merged 19 commits into from
Nov 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a `TabularForecaster` task based on PyTorch Forecasting ([#647](https://github.com/PyTorchLightning/lightning-flash/pull/647))

- Added a `TabularRegressor` task ([#892](https://github.com/PyTorchLightning/lightning-flash/pull/892))

### Changed

### Fixed
Expand Down
1 change: 1 addition & 0 deletions docs/source/api/tabular.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ __________
:nosignatures:
:template: classtemplate.rst

~regression.model.TabularRegressor
~regression.data.TabularRegressionData

Forecasting
Expand Down
8 changes: 4 additions & 4 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.


class ClassificationMixin:
@staticmethod
def _build(
self,
num_classes: Optional[int] = None,
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
**kwargs,
) -> None:

metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label)
metrics, loss_fn = ClassificationMixin._build(num_classes, loss_fn, metrics, multi_label)

super().__init__(
*args,
Expand All @@ -95,7 +95,7 @@ def __init__(
**kwargs,
) -> None:

metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label)
metrics, loss_fn = ClassificationMixin._build(num_classes, loss_fn, metrics, multi_label)

super().__init__(
*args,
Expand Down Expand Up @@ -257,7 +257,7 @@ def __init__(
def serialize(
self,
sample: Any,
) -> Union[Classification, Classifications, Dict[str, Any], Dict[str, Any]]:
) -> Union[Classification, Classifications, Dict[str, Any]]:
pred = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
pred = torch.tensor(pred)

Expand Down
6 changes: 5 additions & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from flash.core.registry import FlashRegistry
from flash.core.serve.composition import Composition
from flash.core.utilities.apply_func import get_callable_dict
from flash.core.utilities.imports import requires
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, requires
from flash.core.utilities.providers import _HUGGINGFACE
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import (
Expand Down Expand Up @@ -399,6 +399,10 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
for name, metric in metrics.items():
if isinstance(metric, torchmetrics.metric.Metric):
metric(y_hat, y)
# PL 1.4.0 -> 1.4.9 tries to deepcopy the metric.
# Sometimes _forward_cache is not a leaf, so we convert it to one.
if not metric._forward_cache.is_leaf and not _PL_GREATER_EQUAL_1_5_0:
metric._forward_cache = metric._forward_cache.clone().detach()
logs[name] = metric # log the metric itself if it is of type Metric
else:
logs[name] = metric(y_hat, y)
Expand Down
57 changes: 57 additions & 0 deletions flash/core/regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Mapping, Optional, Sequence, Union

import torch
import torch.nn.functional as F
import torchmetrics

from flash.core.data.process import Serializer
from flash.core.model import Task


class RegressionMixin:
@staticmethod
def _build(
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
):
metrics = metrics or torchmetrics.MeanSquaredError()
loss_fn = loss_fn or F.mse_loss

return metrics, loss_fn

def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
return x


class RegressionTask(Task, RegressionMixin):
def __init__(
self,
*args,
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
**kwargs,
) -> None:

metrics, loss_fn = RegressionMixin._build(loss_fn, metrics)

super().__init__(
*args,
loss_fn=loss_fn,
metrics=metrics,
serializer=serializer,
**kwargs,
)
2 changes: 1 addition & 1 deletion flash/tabular/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
TabularForecastingDataFrameDataSource,
TabularForecastingPreprocess,
)
from flash.tabular.regression import TabularRegressionData # noqa: F401
from flash.tabular.regression import TabularRegressionData, TabularRegressor # noqa: F401
7 changes: 2 additions & 5 deletions flash/tabular/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class TabularClassifier(ClassificationTask):
package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict
containing a combination of the aforementioned. In all cases, each metric needs to have the signature
`metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training, defaults to `1e-3`
learning_rate: Learning rate to use for training.
multi_label: Whether the targets are multi-label or not.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
**tabnet_kwargs: Optional additional arguments for the TabNet model, see
Expand Down Expand Up @@ -90,10 +90,7 @@ def __init__(

def forward(self, x_in) -> torch.Tensor:
# TabNet takes single input, x_in is composed of (categorical, numerical)
xs = []
for x in x_in:
if x.numel():
xs.append(x)
xs = [x for x in x_in if x.numel()]
x = torch.cat(xs, dim=1)
return self.model(x)[0]

Expand Down
1 change: 1 addition & 0 deletions flash/tabular/regression/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from flash.tabular.regression.data import TabularRegressionData # noqa: F401
from flash.tabular.regression.model import TabularRegressor # noqa: F401
111 changes: 111 additions & 0 deletions flash/tabular/regression/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Tuple

import torch
from torch.nn import functional as F

from flash.core.data.data_source import DefaultDataKeys
from flash.core.regression import RegressionTask
from flash.core.utilities.imports import _TABULAR_AVAILABLE
from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE

if _TABULAR_AVAILABLE:
from pytorch_tabnet.tab_network import TabNet


class TabularRegressor(RegressionTask):
"""The ``TabularRegressor`` is a :class:`~flash.Task` for regression tabular data.

Args:
num_features: Number of columns in table (not including target column).
embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings.
loss_fn: Loss function for training, defaults to cross entropy.
optimizer: Optimizer to use for training.
lr_scheduler: The LR scheduler to use during training.
metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics`
package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict
containing a combination of the aforementioned. In all cases, each metric needs to have the signature
`metric(preds,target)` and return a single scalar tensor.
learning_rate: Learning rate to use for training.
multi_label: Whether the targets are multi-label or not.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
**tabnet_kwargs: Optional additional arguments for the TabNet model, see
`pytorch_tabnet <https://dreamquark-ai.github.io/tabnet/_modules/pytorch_tabnet/tab_network.html#TabNet>`_.
"""

required_extras: str = "tabular"

def __init__(
self,
num_features: int,
embedding_sizes: List[Tuple[int, int]] = None,
loss_fn: Callable = F.mse_loss,
optimizer: OPTIMIZER_TYPE = "Adam",
lr_scheduler: LR_SCHEDULER_TYPE = None,
metrics: METRICS_TYPE = None,
learning_rate: float = 1e-2,
serializer: SERIALIZER_TYPE = None,
**tabnet_kwargs,
):
self.save_hyperparameters()

cat_dims, cat_emb_dim = zip(*embedding_sizes) if embedding_sizes else ([], [])
model = TabNet(
Borda marked this conversation as resolved.
Show resolved Hide resolved
input_dim=num_features,
output_dim=1,
cat_idxs=list(range(len(embedding_sizes))),
cat_dims=list(cat_dims),
cat_emb_dim=list(cat_emb_dim),
**tabnet_kwargs,
)

super().__init__(
model=model,
loss_fn=loss_fn,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
metrics=metrics,
learning_rate=learning_rate,
serializer=serializer,
)

self.save_hyperparameters()

def forward(self, x_in) -> torch.Tensor:
# TabNet takes single input, x_in is composed of (categorical, numerical)
xs = [x for x in x_in if x.numel()]
x = torch.cat(xs, dim=1)
return self.model(x)[0].flatten()

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
return super().training_step(batch, batch_idx)

def validation_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
return super().validation_step(batch, batch_idx)

def test_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
return super().test_step(batch, batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
batch = batch[DefaultDataKeys.INPUT]
return self(batch)

@classmethod
def from_data(cls, datamodule, **kwargs) -> "TabularRegressor":
model = cls(datamodule.num_features, datamodule.embedding_sizes, **kwargs)
return model
54 changes: 54 additions & 0 deletions flash_examples/tabular_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularRegressionData, TabularRegressor

# 1. Create the DataModule
download_data("https://archive.ics.uci.edu/ml/machine-learning-databases/00560/SeoulBikeData.csv", "./data")

datamodule = TabularRegressionData.from_csv(
categorical_fields=["Seasons", "Holiday", "Functioning Day"],
numerical_fields=[
"Hour",
"Temperature(�C)",
"Humidity(%)",
"Wind speed (m/s)",
"Wind speed (m/s)",
"Visibility (10m)",
"Dew point temperature(�C)",
"Solar Radiation (MJ/m2)",
"Rainfall(mm)",
"Snowfall (cm)",
],
target_fields="Rented Bike Count",
train_file="data/SeoulBikeData.csv",
val_split=0.1,
)

# 2. Build the task
model = TabularRegressor.from_data(datamodule, learning_rate=0.1)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Generate predictions from a CSV
predictions = model.predict("data/SeoulBikeData.csv")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("tabular_regression_model.pt")
4 changes: 4 additions & 0 deletions tests/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@
"tabular_classification.py",
marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"),
),
pytest.param(
"tabular_regression.py",
marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"),
),
pytest.param("template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")),
pytest.param(
"text_classification.py",
Expand Down