diff --git a/CHANGELOG.md b/CHANGELOG.md index db2fcb4df2..e5b92c08bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a `TabularForecaster` task based on PyTorch Forecasting ([#647](https://github.com/PyTorchLightning/lightning-flash/pull/647)) +- Added a `TabularRegressor` task ([#892](https://github.com/PyTorchLightning/lightning-flash/pull/892)) + ### Changed ### Fixed diff --git a/docs/source/api/tabular.rst b/docs/source/api/tabular.rst index 60b4c18f89..45fbca2fb3 100644 --- a/docs/source/api/tabular.rst +++ b/docs/source/api/tabular.rst @@ -28,6 +28,7 @@ __________ :nosignatures: :template: classtemplate.rst + ~regression.model.TabularRegressor ~regression.data.TabularRegressionData Forecasting diff --git a/flash/core/classification.py b/flash/core/classification.py index 5dacef2bb8..884dffbdfb 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -39,8 +39,8 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch. class ClassificationMixin: + @staticmethod def _build( - self, num_classes: Optional[int] = None, loss_fn: Optional[Callable] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, @@ -72,7 +72,7 @@ def __init__( **kwargs, ) -> None: - metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label) + metrics, loss_fn = ClassificationMixin._build(num_classes, loss_fn, metrics, multi_label) super().__init__( *args, @@ -95,7 +95,7 @@ def __init__( **kwargs, ) -> None: - metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label) + metrics, loss_fn = ClassificationMixin._build(num_classes, loss_fn, metrics, multi_label) super().__init__( *args, @@ -257,7 +257,7 @@ def __init__( def serialize( self, sample: Any, - ) -> Union[Classification, Classifications, Dict[str, Any], Dict[str, Any]]: + ) -> Union[Classification, Classifications, Dict[str, Any]]: pred = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample pred = torch.tensor(pred) diff --git a/flash/core/model.py b/flash/core/model.py index 6f87bcb4c3..85e13aa8de 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -51,7 +51,7 @@ from flash.core.registry import FlashRegistry from flash.core.serve.composition import Composition from flash.core.utilities.apply_func import get_callable_dict -from flash.core.utilities.imports import requires +from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, requires from flash.core.utilities.providers import _HUGGINGFACE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import ( @@ -399,6 +399,10 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: for name, metric in metrics.items(): if isinstance(metric, torchmetrics.metric.Metric): metric(y_hat, y) + # PL 1.4.0 -> 1.4.9 tries to deepcopy the metric. + # Sometimes _forward_cache is not a leaf, so we convert it to one. + if not metric._forward_cache.is_leaf and not _PL_GREATER_EQUAL_1_5_0: + metric._forward_cache = metric._forward_cache.clone().detach() logs[name] = metric # log the metric itself if it is of type Metric else: logs[name] = metric(y_hat, y) diff --git a/flash/core/regression.py b/flash/core/regression.py new file mode 100644 index 0000000000..edd351d02f --- /dev/null +++ b/flash/core/regression.py @@ -0,0 +1,57 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Mapping, Optional, Sequence, Union + +import torch +import torch.nn.functional as F +import torchmetrics + +from flash.core.data.process import Serializer +from flash.core.model import Task + + +class RegressionMixin: + @staticmethod + def _build( + loss_fn: Optional[Callable] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + ): + metrics = metrics or torchmetrics.MeanSquaredError() + loss_fn = loss_fn or F.mse_loss + + return metrics, loss_fn + + def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class RegressionTask(Task, RegressionMixin): + def __init__( + self, + *args, + loss_fn: Optional[Callable] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + **kwargs, + ) -> None: + + metrics, loss_fn = RegressionMixin._build(loss_fn, metrics) + + super().__init__( + *args, + loss_fn=loss_fn, + metrics=metrics, + serializer=serializer, + **kwargs, + ) diff --git a/flash/tabular/__init__.py b/flash/tabular/__init__.py index 466ca1fc0f..b62af944c0 100644 --- a/flash/tabular/__init__.py +++ b/flash/tabular/__init__.py @@ -5,4 +5,4 @@ TabularForecastingDataFrameDataSource, TabularForecastingPreprocess, ) -from flash.tabular.regression import TabularRegressionData # noqa: F401 +from flash.tabular.regression import TabularRegressionData, TabularRegressor # noqa: F401 diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 2ed21da6e0..a0c0cc7114 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -40,7 +40,7 @@ class TabularClassifier(ClassificationTask): package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. - learning_rate: Learning rate to use for training, defaults to `1e-3` + learning_rate: Learning rate to use for training. multi_label: Whether the targets are multi-label or not. serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. **tabnet_kwargs: Optional additional arguments for the TabNet model, see @@ -90,10 +90,7 @@ def __init__( def forward(self, x_in) -> torch.Tensor: # TabNet takes single input, x_in is composed of (categorical, numerical) - xs = [] - for x in x_in: - if x.numel(): - xs.append(x) + xs = [x for x in x_in if x.numel()] x = torch.cat(xs, dim=1) return self.model(x)[0] diff --git a/flash/tabular/regression/__init__.py b/flash/tabular/regression/__init__.py index a93e599ff0..13f8c8490d 100644 --- a/flash/tabular/regression/__init__.py +++ b/flash/tabular/regression/__init__.py @@ -1 +1,2 @@ from flash.tabular.regression.data import TabularRegressionData # noqa: F401 +from flash.tabular.regression.model import TabularRegressor # noqa: F401 diff --git a/flash/tabular/regression/model.py b/flash/tabular/regression/model.py new file mode 100644 index 0000000000..7710332670 --- /dev/null +++ b/flash/tabular/regression/model.py @@ -0,0 +1,111 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, List, Tuple + +import torch +from torch.nn import functional as F + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.regression import RegressionTask +from flash.core.utilities.imports import _TABULAR_AVAILABLE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE + +if _TABULAR_AVAILABLE: + from pytorch_tabnet.tab_network import TabNet + + +class TabularRegressor(RegressionTask): + """The ``TabularRegressor`` is a :class:`~flash.Task` for regression tabular data. + + Args: + num_features: Number of columns in table (not including target column). + embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings. + loss_fn: Loss function for training, defaults to cross entropy. + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. + metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics` + package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict + containing a combination of the aforementioned. In all cases, each metric needs to have the signature + `metric(preds,target)` and return a single scalar tensor. + learning_rate: Learning rate to use for training. + multi_label: Whether the targets are multi-label or not. + serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. + **tabnet_kwargs: Optional additional arguments for the TabNet model, see + `pytorch_tabnet `_. + """ + + required_extras: str = "tabular" + + def __init__( + self, + num_features: int, + embedding_sizes: List[Tuple[int, int]] = None, + loss_fn: Callable = F.mse_loss, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, + learning_rate: float = 1e-2, + serializer: SERIALIZER_TYPE = None, + **tabnet_kwargs, + ): + self.save_hyperparameters() + + cat_dims, cat_emb_dim = zip(*embedding_sizes) if embedding_sizes else ([], []) + model = TabNet( + input_dim=num_features, + output_dim=1, + cat_idxs=list(range(len(embedding_sizes))), + cat_dims=list(cat_dims), + cat_emb_dim=list(cat_emb_dim), + **tabnet_kwargs, + ) + + super().__init__( + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + metrics=metrics, + learning_rate=learning_rate, + serializer=serializer, + ) + + self.save_hyperparameters() + + def forward(self, x_in) -> torch.Tensor: + # TabNet takes single input, x_in is composed of (categorical, numerical) + xs = [x for x in x_in if x.numel()] + x = torch.cat(xs, dim=1) + return self.model(x)[0].flatten() + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch = batch[DefaultDataKeys.INPUT] + return self(batch) + + @classmethod + def from_data(cls, datamodule, **kwargs) -> "TabularRegressor": + model = cls(datamodule.num_features, datamodule.embedding_sizes, **kwargs) + return model diff --git a/flash_examples/tabular_regression.py b/flash_examples/tabular_regression.py new file mode 100644 index 0000000000..cf1a4aabf2 --- /dev/null +++ b/flash_examples/tabular_regression.py @@ -0,0 +1,54 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.core.data.utils import download_data +from flash.tabular import TabularRegressionData, TabularRegressor + +# 1. Create the DataModule +download_data("https://archive.ics.uci.edu/ml/machine-learning-databases/00560/SeoulBikeData.csv", "./data") + +datamodule = TabularRegressionData.from_csv( + categorical_fields=["Seasons", "Holiday", "Functioning Day"], + numerical_fields=[ + "Hour", + "Temperature(�C)", + "Humidity(%)", + "Wind speed (m/s)", + "Wind speed (m/s)", + "Visibility (10m)", + "Dew point temperature(�C)", + "Solar Radiation (MJ/m2)", + "Rainfall(mm)", + "Snowfall (cm)", + ], + target_fields="Rented Bike Count", + train_file="data/SeoulBikeData.csv", + val_split=0.1, +) + +# 2. Build the task +model = TabularRegressor.from_data(datamodule, learning_rate=0.1) + +# 3. Create the trainer and train the model +trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count()) +trainer.fit(model, datamodule=datamodule) + +# 4. Generate predictions from a CSV +predictions = model.predict("data/SeoulBikeData.csv") +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("tabular_regression_model.pt") diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 1060e43eb2..eeeb725ee9 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -76,6 +76,10 @@ "tabular_classification.py", marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"), ), + pytest.param( + "tabular_regression.py", + marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"), + ), pytest.param("template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")), pytest.param( "text_classification.py",