Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Tabular regression task and example (#892)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
4 people authored Nov 5, 2021
1 parent d0adc61 commit ba38014
Show file tree
Hide file tree
Showing 11 changed files with 242 additions and 11 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a `TabularForecaster` task based on PyTorch Forecasting ([#647](https://github.com/PyTorchLightning/lightning-flash/pull/647))

- Added a `TabularRegressor` task ([#892](https://github.com/PyTorchLightning/lightning-flash/pull/892))

### Changed

### Fixed
Expand Down
1 change: 1 addition & 0 deletions docs/source/api/tabular.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ __________
:nosignatures:
:template: classtemplate.rst

~regression.model.TabularRegressor
~regression.data.TabularRegressionData

Forecasting
Expand Down
8 changes: 4 additions & 4 deletions flash/core/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.


class ClassificationMixin:
@staticmethod
def _build(
self,
num_classes: Optional[int] = None,
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
**kwargs,
) -> None:

metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label)
metrics, loss_fn = ClassificationMixin._build(num_classes, loss_fn, metrics, multi_label)

super().__init__(
*args,
Expand All @@ -95,7 +95,7 @@ def __init__(
**kwargs,
) -> None:

metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label)
metrics, loss_fn = ClassificationMixin._build(num_classes, loss_fn, metrics, multi_label)

super().__init__(
*args,
Expand Down Expand Up @@ -257,7 +257,7 @@ def __init__(
def serialize(
self,
sample: Any,
) -> Union[Classification, Classifications, Dict[str, Any], Dict[str, Any]]:
) -> Union[Classification, Classifications, Dict[str, Any]]:
pred = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
pred = torch.tensor(pred)

Expand Down
6 changes: 5 additions & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from flash.core.registry import FlashRegistry
from flash.core.serve.composition import Composition
from flash.core.utilities.apply_func import get_callable_dict
from flash.core.utilities.imports import requires
from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, requires
from flash.core.utilities.providers import _HUGGINGFACE
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import (
Expand Down Expand Up @@ -399,6 +399,10 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
for name, metric in metrics.items():
if isinstance(metric, torchmetrics.metric.Metric):
metric(y_hat, y)
# PL 1.4.0 -> 1.4.9 tries to deepcopy the metric.
# Sometimes _forward_cache is not a leaf, so we convert it to one.
if not metric._forward_cache.is_leaf and not _PL_GREATER_EQUAL_1_5_0:
metric._forward_cache = metric._forward_cache.clone().detach()
logs[name] = metric # log the metric itself if it is of type Metric
else:
logs[name] = metric(y_hat, y)
Expand Down
57 changes: 57 additions & 0 deletions flash/core/regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Mapping, Optional, Sequence, Union

import torch
import torch.nn.functional as F
import torchmetrics

from flash.core.data.process import Serializer
from flash.core.model import Task


class RegressionMixin:
@staticmethod
def _build(
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
):
metrics = metrics or torchmetrics.MeanSquaredError()
loss_fn = loss_fn or F.mse_loss

return metrics, loss_fn

def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
return x


class RegressionTask(Task, RegressionMixin):
def __init__(
self,
*args,
loss_fn: Optional[Callable] = None,
metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
**kwargs,
) -> None:

metrics, loss_fn = RegressionMixin._build(loss_fn, metrics)

super().__init__(
*args,
loss_fn=loss_fn,
metrics=metrics,
serializer=serializer,
**kwargs,
)
2 changes: 1 addition & 1 deletion flash/tabular/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
TabularForecastingDataFrameDataSource,
TabularForecastingPreprocess,
)
from flash.tabular.regression import TabularRegressionData # noqa: F401
from flash.tabular.regression import TabularRegressionData, TabularRegressor # noqa: F401
7 changes: 2 additions & 5 deletions flash/tabular/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class TabularClassifier(ClassificationTask):
package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict
containing a combination of the aforementioned. In all cases, each metric needs to have the signature
`metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
learning_rate: Learning rate to use for training, defaults to `1e-3`
learning_rate: Learning rate to use for training.
multi_label: Whether the targets are multi-label or not.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
**tabnet_kwargs: Optional additional arguments for the TabNet model, see
Expand Down Expand Up @@ -90,10 +90,7 @@ def __init__(

def forward(self, x_in) -> torch.Tensor:
# TabNet takes single input, x_in is composed of (categorical, numerical)
xs = []
for x in x_in:
if x.numel():
xs.append(x)
xs = [x for x in x_in if x.numel()]
x = torch.cat(xs, dim=1)
return self.model(x)[0]

Expand Down
1 change: 1 addition & 0 deletions flash/tabular/regression/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from flash.tabular.regression.data import TabularRegressionData # noqa: F401
from flash.tabular.regression.model import TabularRegressor # noqa: F401
111 changes: 111 additions & 0 deletions flash/tabular/regression/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Tuple

import torch
from torch.nn import functional as F

from flash.core.data.data_source import DefaultDataKeys
from flash.core.regression import RegressionTask
from flash.core.utilities.imports import _TABULAR_AVAILABLE
from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE

if _TABULAR_AVAILABLE:
from pytorch_tabnet.tab_network import TabNet


class TabularRegressor(RegressionTask):
"""The ``TabularRegressor`` is a :class:`~flash.Task` for regression tabular data.
Args:
num_features: Number of columns in table (not including target column).
embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings.
loss_fn: Loss function for training, defaults to cross entropy.
optimizer: Optimizer to use for training.
lr_scheduler: The LR scheduler to use during training.
metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics`
package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict
containing a combination of the aforementioned. In all cases, each metric needs to have the signature
`metric(preds,target)` and return a single scalar tensor.
learning_rate: Learning rate to use for training.
multi_label: Whether the targets are multi-label or not.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
**tabnet_kwargs: Optional additional arguments for the TabNet model, see
`pytorch_tabnet <https://dreamquark-ai.github.io/tabnet/_modules/pytorch_tabnet/tab_network.html#TabNet>`_.
"""

required_extras: str = "tabular"

def __init__(
self,
num_features: int,
embedding_sizes: List[Tuple[int, int]] = None,
loss_fn: Callable = F.mse_loss,
optimizer: OPTIMIZER_TYPE = "Adam",
lr_scheduler: LR_SCHEDULER_TYPE = None,
metrics: METRICS_TYPE = None,
learning_rate: float = 1e-2,
serializer: SERIALIZER_TYPE = None,
**tabnet_kwargs,
):
self.save_hyperparameters()

cat_dims, cat_emb_dim = zip(*embedding_sizes) if embedding_sizes else ([], [])
model = TabNet(
input_dim=num_features,
output_dim=1,
cat_idxs=list(range(len(embedding_sizes))),
cat_dims=list(cat_dims),
cat_emb_dim=list(cat_emb_dim),
**tabnet_kwargs,
)

super().__init__(
model=model,
loss_fn=loss_fn,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
metrics=metrics,
learning_rate=learning_rate,
serializer=serializer,
)

self.save_hyperparameters()

def forward(self, x_in) -> torch.Tensor:
# TabNet takes single input, x_in is composed of (categorical, numerical)
xs = [x for x in x_in if x.numel()]
x = torch.cat(xs, dim=1)
return self.model(x)[0].flatten()

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
return super().training_step(batch, batch_idx)

def validation_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
return super().validation_step(batch, batch_idx)

def test_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
return super().test_step(batch, batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
batch = batch[DefaultDataKeys.INPUT]
return self(batch)

@classmethod
def from_data(cls, datamodule, **kwargs) -> "TabularRegressor":
model = cls(datamodule.num_features, datamodule.embedding_sizes, **kwargs)
return model
54 changes: 54 additions & 0 deletions flash_examples/tabular_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularRegressionData, TabularRegressor

# 1. Create the DataModule
download_data("https://archive.ics.uci.edu/ml/machine-learning-databases/00560/SeoulBikeData.csv", "./data")

datamodule = TabularRegressionData.from_csv(
categorical_fields=["Seasons", "Holiday", "Functioning Day"],
numerical_fields=[
"Hour",
"Temperature(�C)",
"Humidity(%)",
"Wind speed (m/s)",
"Wind speed (m/s)",
"Visibility (10m)",
"Dew point temperature(�C)",
"Solar Radiation (MJ/m2)",
"Rainfall(mm)",
"Snowfall (cm)",
],
target_fields="Rented Bike Count",
train_file="data/SeoulBikeData.csv",
val_split=0.1,
)

# 2. Build the task
model = TabularRegressor.from_data(datamodule, learning_rate=0.1)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Generate predictions from a CSV
predictions = model.predict("data/SeoulBikeData.csv")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("tabular_regression_model.pt")
4 changes: 4 additions & 0 deletions tests/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@
"tabular_classification.py",
marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"),
),
pytest.param(
"tabular_regression.py",
marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"),
),
pytest.param("template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")),
pytest.param(
"text_classification.py",
Expand Down

0 comments on commit ba38014

Please sign in to comment.