
Tabular regression task and example #892

Merged: 19 commits, Nov 5, 2021
Changes from 6 commits
8 changes: 4 additions & 4 deletions flash/core/classification.py
@@ -39,8 +39,8 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.


 class ClassificationMixin:
+    @staticmethod
     def _build(
-        self,
         num_classes: Optional[int] = None,
         loss_fn: Optional[Callable] = None,
         metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
@@ -72,7 +72,7 @@ def __init__(
         **kwargs,
     ) -> None:

-        metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label)
+        metrics, loss_fn = ClassificationMixin._build(num_classes, loss_fn, metrics, multi_label)

         super().__init__(
             *args,
@@ -95,7 +95,7 @@ def __init__(
         **kwargs,
     ) -> None:

-        metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label)
+        metrics, loss_fn = ClassificationMixin._build(num_classes, loss_fn, metrics, multi_label)

         super().__init__(
             *args,
@@ -257,7 +257,7 @@ def __init__(
     def serialize(
         self,
         sample: Any,
-    ) -> Union[Classification, Classifications, Dict[str, Any], Dict[str, Any]]:
+    ) -> Union[Classification, Classifications, Dict[str, Any]]:
         pred = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample
         pred = torch.tensor(pred)

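Side note: a minimal sketch (not part of this PR) of how the now-static helper would be called after this change; the argument values are made up for illustration.

# Illustrative only: with @staticmethod, callers no longer pass the instance.
from flash.core.classification import ClassificationMixin

metrics, loss_fn = ClassificationMixin._build(10, None, None, False)  # num_classes, loss_fn, metrics, multi_label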
59 changes: 59 additions & 0 deletions flash/core/regression.py
@@ -0,0 +1,59 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Mapping, Optional, Sequence, Union

import torch
import torch.nn.functional as F
import torchmetrics

from flash.core.data.process import Serializer
from flash.core.model import Task


class RegressionMixin:
    @staticmethod
    def _build(
        loss_fn: Optional[Callable] = None,
        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
    ):
        metrics = metrics or torchmetrics.MeanSquaredError()
        loss_fn = loss_fn or F.mse_loss

        return metrics, loss_fn

    def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
        if getattr(self.hparams, "multi_label", False):
            return torch.sigmoid(x)
        return torch.softmax(x, dim=1)


class RegressionTask(Task, RegressionMixin):
    def __init__(
        self,
        *args,
        loss_fn: Optional[Callable] = None,
        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
        **kwargs,
    ) -> None:

        metrics, loss_fn = RegressionMixin._build(loss_fn, metrics)

        super().__init__(
            *args,
            loss_fn=loss_fn,
            metrics=metrics,
            serializer=serializer,
            **kwargs,
        )
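For orientation, a minimal sketch (not part of the diff) of using the new RegressionTask directly; the toy linear model and feature count are assumptions.

# Sketch under assumptions: with no loss_fn/metrics given, the task should fall
# back to F.mse_loss and torchmetrics.MeanSquaredError via RegressionMixin._build.
import torch
from torch import nn

from flash.core.regression import RegressionTask

task = RegressionTask(model=nn.Linear(16, 1))  # hypothetical 16-feature regressor
preds = task(torch.randn(4, 16))               # Task.forward should delegate to the wrapped model
print(preds.shape)                             # expected: torch.Size([4, 1])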
2 changes: 1 addition & 1 deletion flash/tabular/__init__.py
@@ -5,4 +5,4 @@
     TabularForecastingDataFrameDataSource,
     TabularForecastingPreprocess,
 )
-from flash.tabular.regression import TabularRegressionData  # noqa: F401
+from flash.tabular.regression import TabularRegressionData, TabularRegressor  # noqa: F401
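With this re-export, the regressor should be importable from the top-level tabular package alongside the data class:

# Illustrative import, mirroring the updated __init__ above.
from flash.tabular import TabularRegressionData, TabularRegressor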
2 changes: 1 addition & 1 deletion flash/tabular/classification/model.py
@@ -40,7 +40,7 @@ class TabularClassifier(ClassificationTask):
             package, a custom metric inheriting from `torchmetrics.Metric`, a callable function or a list/dict
             containing a combination of the aforementioned. In all cases, each metric needs to have the signature
             `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
-        learning_rate: Learning rate to use for training, defaults to `1e-3`
+        learning_rate: Learning rate to use for training.
         multi_label: Whether the targets are multi-label or not.
         serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
         **tabnet_kwargs: Optional additional arguments for the TabNet model, see
1 change: 1 addition & 0 deletions flash/tabular/regression/__init__.py
@@ -1 +1,2 @@
 from flash.tabular.regression.data import TabularRegressionData  # noqa: F401
+from flash.tabular.regression.model import TabularRegressor  # noqa: F401
114 changes: 114 additions & 0 deletions flash/tabular/regression/model.py
@@ -0,0 +1,114 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List, Tuple

import torch
from torch.nn import functional as F

from flash.core.data.data_source import DefaultDataKeys
from flash.core.regression import RegressionTask
from flash.core.utilities.imports import _TABULAR_AVAILABLE
from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE

if _TABULAR_AVAILABLE:
    from pytorch_tabnet.tab_network import TabNet


class TabularRegressor(RegressionTask):
    """The ``TabularRegressor`` is a :class:`~flash.Task` for regression on tabular data.

    Args:
        num_features: Number of columns in table (not including target column).
        embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings.
        loss_fn: Loss function for training, defaults to mean squared error.
        optimizer: Optimizer to use for training.
        lr_scheduler: The LR scheduler to use during training.
        metrics: Metrics to compute for training and evaluation. Can either be a metric from the `torchmetrics`
            package, a custom metric inheriting from `torchmetrics.Metric`, a callable function or a list/dict
            containing a combination of the aforementioned. In all cases, each metric needs to have the signature
            `metric(preds,target)` and return a single scalar tensor.
        learning_rate: Learning rate to use for training.
        serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
        **tabnet_kwargs: Optional additional arguments for the TabNet model, see
            `pytorch_tabnet <https://dreamquark-ai.github.io/tabnet/_modules/pytorch_tabnet/tab_network.html#TabNet>`_.
    """

    required_extras: str = "tabular"

    def __init__(
        self,
        num_features: int,
        embedding_sizes: List[Tuple[int, int]] = None,
        loss_fn: Callable = F.mse_loss,
        optimizer: OPTIMIZER_TYPE = "Adam",
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        metrics: METRICS_TYPE = None,
        learning_rate: float = 1e-2,
        serializer: SERIALIZER_TYPE = None,
        **tabnet_kwargs,
    ):
        self.save_hyperparameters()

        cat_dims, cat_emb_dim = zip(*embedding_sizes) if embedding_sizes else ([], [])
        model = TabNet(
            input_dim=num_features,
            output_dim=1,
            cat_idxs=list(range(len(embedding_sizes))),
            cat_dims=list(cat_dims),
            cat_emb_dim=list(cat_emb_dim),
            **tabnet_kwargs,
        )

        super().__init__(
            model=model,
            loss_fn=loss_fn,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            learning_rate=learning_rate,
            serializer=serializer,
        )

    def forward(self, x_in) -> torch.Tensor:
        # TabNet takes a single input; x_in is composed of (categorical, numerical)
        xs = []
        for x in x_in:
            if x.numel():
                xs.append(x)
        x = torch.cat(xs, dim=1)
        return self.model(x)[0]

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
        return super().training_step(batch, batch_idx)

    def validation_step(self, batch: Any, batch_idx: int) -> Any:
        batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
        return super().validation_step(batch, batch_idx)

    def test_step(self, batch: Any, batch_idx: int) -> Any:
        batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
        return super().test_step(batch, batch_idx)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        batch = batch[DefaultDataKeys.INPUT]
        return self(batch)

    @classmethod
    def from_data(cls, datamodule, **kwargs) -> "TabularRegressor":
        model = cls(datamodule.num_features, datamodule.embedding_sizes, **kwargs)
        return model
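For readers unfamiliar with the TabNet wiring, a hedged construction sketch (not taken from this PR); the column counts and embedding sizes are invented, and the `tabular` extra (pytorch-tabnet) is assumed to be installed.

# Sketch under assumptions: 2 categorical columns (4 and 7 distinct values,
# embedded into 3 and 5 dims) plus 8 numerical columns, so num_features=10.
from flash.tabular import TabularRegressor

model = TabularRegressor(
    num_features=10,
    embedding_sizes=[(4, 3), (7, 5)],
    learning_rate=1e-2,
)
# forward() expects a (categorical, numerical) pair of tensors, as prepared by the
# training/validation steps above; empty tensors are skipped before concatenation.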
54 changes: 54 additions & 0 deletions flash_examples/tabular_regression.py
@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

import flash
from flash.core.data.utils import download_data
from flash.tabular import TabularRegressionData, TabularRegressor

# 1. Create the DataModule
download_data("https://archive.ics.uci.edu/ml/machine-learning-databases/00560/SeoulBikeData.csv", "./data")

datamodule = TabularRegressionData.from_csv(
    categorical_fields=["Seasons", "Holiday", "Functioning Day"],
    numerical_fields=[
        "Hour",
        "Temperature(°C)",
        "Humidity(%)",
        "Wind speed (m/s)",
        "Visibility (10m)",
        "Dew point temperature(°C)",
        "Solar Radiation (MJ/m2)",
        "Rainfall(mm)",
        "Snowfall (cm)",
    ],
    target_fields="Rented Bike Count",
    train_file="data/SeoulBikeData.csv",
    val_split=0.1,
)

# 2. Build the task
model = TabularRegressor.from_data(datamodule)

# 3. Create the trainer and train the model
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(model, datamodule=datamodule)

# 4. Generate predictions from a CSV
predictions = model.predict("data/SeoulBikeData.csv")
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("tabular_regression_model.pt")
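As a possible follow-up (not shown in the example), the saved checkpoint could be restored with the standard Lightning loading hook:

# Illustrative only; reuses the checkpoint path from step 5 above.
from flash.tabular import TabularRegressor

model = TabularRegressor.load_from_checkpoint("tabular_regression_model.pt")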