From 4747f099b161069c0a87550ad0d003ee45b69491 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 27 Oct 2021 02:01:34 +0200 Subject: [PATCH 01/16] tab regression --- flash/core/Regression.py | 60 ++++++++++++++ flash/core/classification.py | 9 +- flash/tabular/classification/model.py | 2 +- flash/tabular/regression/model.py | 114 ++++++++++++++++++++++++++ 4 files changed, 180 insertions(+), 5 deletions(-) create mode 100644 flash/core/Regression.py create mode 100644 flash/tabular/regression/model.py diff --git a/flash/core/Regression.py b/flash/core/Regression.py new file mode 100644 index 0000000000..7d103e75a3 --- /dev/null +++ b/flash/core/Regression.py @@ -0,0 +1,60 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Callable, Mapping, Optional, Sequence, Union + +import torch +import torch.nn.functional as F +import torchmetrics + +from flash.core.data.process import Serializer +from flash.core.model import Task + + +class RegressionMixin: + + @staticmethod + def _build( + loss_fn: Optional[Callable] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + ): + metrics = metrics or torchmetrics.MeanSquaredError() + loss_fn = loss_fn or F.mse_loss + + return metrics, loss_fn + + def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor: + if getattr(self.hparams, "multi_label", False): + return torch.sigmoid(x) + return torch.softmax(x, dim=1) + + +class RegressionTask(Task, RegressionMixin): + def __init__( + self, + *args, + loss_fn: Optional[Callable] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + **kwargs, + ) -> None: + + metrics, loss_fn = RegressionMixin._build(loss_fn, metrics) + + super().__init__( + *args, + loss_fn=loss_fn, + metrics=metrics, + serializer=serializer, + **kwargs, + ) \ No newline at end of file diff --git a/flash/core/classification.py b/flash/core/classification.py index 5dacef2bb8..0d0c647fd6 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -39,8 +39,9 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch. 
class ClassificationMixin: + + @staticmethod def _build( - self, num_classes: Optional[int] = None, loss_fn: Optional[Callable] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, @@ -72,7 +73,7 @@ def __init__( **kwargs, ) -> None: - metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label) + metrics, loss_fn = ClassificationMixin._build(num_classes, loss_fn, metrics, multi_label) super().__init__( *args, @@ -95,7 +96,7 @@ def __init__( **kwargs, ) -> None: - metrics, loss_fn = ClassificationMixin._build(self, num_classes, loss_fn, metrics, multi_label) + metrics, loss_fn = ClassificationMixin._build(num_classes, loss_fn, metrics, multi_label) super().__init__( *args, @@ -257,7 +258,7 @@ def __init__( def serialize( self, sample: Any, - ) -> Union[Classification, Classifications, Dict[str, Any], Dict[str, Any]]: + ) -> Union[Classification, Classifications, Dict[str, Any]]: pred = sample[DefaultDataKeys.PREDS] if isinstance(sample, Dict) else sample pred = torch.tensor(pred) diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 2ed21da6e0..0da74a969f 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -40,7 +40,7 @@ class TabularClassifier(ClassificationTask): package, a custom metric inherenting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. - learning_rate: Learning rate to use for training, defaults to `1e-3` + learning_rate: Learning rate to use for training. multi_label: Whether the targets are multi-label or not. serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. **tabnet_kwargs: Optional additional arguments for the TabNet model, see diff --git a/flash/tabular/regression/model.py b/flash/tabular/regression/model.py new file mode 100644 index 0000000000..66ba9e1dd3 --- /dev/null +++ b/flash/tabular/regression/model.py @@ -0,0 +1,114 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, List, Tuple + +import torch +from torch.nn import functional as F + +from flash.core.Regression import RegressionTask +from flash.core.data.data_source import DefaultDataKeys +from flash.core.utilities.imports import _TABULAR_AVAILABLE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE + +if _TABULAR_AVAILABLE: + from pytorch_tabnet.tab_network import TabNet + + +class TabularRegressor(RegressionTask): + """The ``TabularRegressor`` is a :class:`~flash.Task` for regression tabular data. + + Args: + num_features: Number of columns in table (not including target column). + embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings. 
+ loss_fn: Loss function for training, defaults to cross entropy. + optimizer: Optimizer to use for training. + lr_scheduler: The LR scheduler to use during training. + metrics: Metrics to compute for training and evaluation. Can either be a metric from the `torchmetrics` + package, a custom metric inheriting from `torchmetrics.Metric`, a callable function or a list/dict + containing a combination of the aforementioned. In all cases, each metric needs to have the signature + `metric(preds,target)` and return a single scalar tensor. + learning_rate: Learning rate to use for training. + multi_label: Whether the targets are multi-label or not. + serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. + **tabnet_kwargs: Optional additional arguments for the TabNet model, see + `pytorch_tabnet `_. + """ + + required_extras: str = "tabular" + + def __init__( + self, + num_features: int, + embedding_sizes: List[Tuple[int, int]] = None, + loss_fn: Callable = F.cross_entropy, + optimizer: OPTIMIZER_TYPE = "Adam", + lr_scheduler: LR_SCHEDULER_TYPE = None, + metrics: METRICS_TYPE = None, + learning_rate: float = 1e-2, + serializer: SERIALIZER_TYPE = None, + **tabnet_kwargs, + ): + self.save_hyperparameters() + + cat_dims, cat_emb_dim = zip(*embedding_sizes) if embedding_sizes else ([], []) + model = TabNet( + input_dim=num_features, + output_dim=1, + cat_idxs=list(range(len(embedding_sizes))), + cat_dims=list(cat_dims), + cat_emb_dim=list(cat_emb_dim), + **tabnet_kwargs, + ) + + super().__init__( + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + metrics=metrics, + learning_rate=learning_rate, + serializer=serializer, + ) + + self.save_hyperparameters() + + def forward(self, x_in) -> torch.Tensor: + # TabNet takes single input, x_in is composed of (categorical, numerical) + xs = [] + for x in x_in: + if x.numel(): + xs.append(x) + x = torch.cat(xs, dim=1) + return self.model(x)[0] + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch = batch[DefaultDataKeys.INPUT] + return self(batch) + + @classmethod + def from_data(cls, datamodule, **kwargs) -> "TabularRegressor": + model = cls(datamodule.num_features, datamodule.num_classes, datamodule.embedding_sizes, **kwargs) + return model From c775078449dd9e48167d03750c79160a648d4d51 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Oct 2021 00:03:37 +0000 Subject: [PATCH 02/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash/core/Regression.py | 3 +-- flash/core/classification.py | 1 - flash/tabular/regression/model.py | 2 +- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/flash/core/Regression.py b/flash/core/Regression.py index 7d103e75a3..8cafcb5472 100644 --- a/flash/core/Regression.py +++ b/flash/core/Regression.py @@ -22,7 +22,6 @@ class
RegressionMixin: - @staticmethod def _build( loss_fn: Optional[Callable] = None, @@ -57,4 +56,4 @@ def __init__( metrics=metrics, serializer=serializer, **kwargs, - ) \ No newline at end of file + ) diff --git a/flash/core/classification.py b/flash/core/classification.py index 0d0c647fd6..884dffbdfb 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -39,7 +39,6 @@ def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch. class ClassificationMixin: - @staticmethod def _build( num_classes: Optional[int] = None, diff --git a/flash/tabular/regression/model.py b/flash/tabular/regression/model.py index 66ba9e1dd3..95f1ecf296 100644 --- a/flash/tabular/regression/model.py +++ b/flash/tabular/regression/model.py @@ -16,8 +16,8 @@ import torch from torch.nn import functional as F -from flash.core.Regression import RegressionTask from flash.core.data.data_source import DefaultDataKeys +from flash.core.Regression import RegressionTask from flash.core.utilities.imports import _TABULAR_AVAILABLE from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE From c1d83037f713a71e5e421db6e809621f1559e4b4 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 28 Oct 2021 18:17:47 +0200 Subject: [PATCH 03/16] rename --- flash/core/{Regression.py => regression.py} | 0 flash/tabular/regression/model.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename flash/core/{Regression.py => regression.py} (100%) diff --git a/flash/core/Regression.py b/flash/core/regression.py similarity index 100% rename from flash/core/Regression.py rename to flash/core/regression.py diff --git a/flash/tabular/regression/model.py b/flash/tabular/regression/model.py index 95f1ecf296..7fa6147f0e 100644 --- a/flash/tabular/regression/model.py +++ b/flash/tabular/regression/model.py @@ -17,7 +17,7 @@ from torch.nn import functional as F from flash.core.data.data_source import DefaultDataKeys -from flash.core.Regression import RegressionTask +from flash.core.regression import RegressionTask from flash.core.utilities.imports import _TABULAR_AVAILABLE from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, SERIALIZER_TYPE From a59e48ad74712802db617560a78ad343bb62c609 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 3 Nov 2021 23:19:15 +0100 Subject: [PATCH 04/16] example --- flash/tabular/__init__.py | 2 +- flash/tabular/regression/__init__.py | 1 + flash_examples/tabular_regression.py | 43 ++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 flash_examples/tabular_regression.py diff --git a/flash/tabular/__init__.py b/flash/tabular/__init__.py index 466ca1fc0f..b62af944c0 100644 --- a/flash/tabular/__init__.py +++ b/flash/tabular/__init__.py @@ -5,4 +5,4 @@ TabularForecastingDataFrameDataSource, TabularForecastingPreprocess, ) -from flash.tabular.regression import TabularRegressionData # noqa: F401 +from flash.tabular.regression import TabularRegressionData, TabularRegressor # noqa: F401 diff --git a/flash/tabular/regression/__init__.py b/flash/tabular/regression/__init__.py index a93e599ff0..13f8c8490d 100644 --- a/flash/tabular/regression/__init__.py +++ b/flash/tabular/regression/__init__.py @@ -1 +1,2 @@ from flash.tabular.regression.data import TabularRegressionData # noqa: F401 +from flash.tabular.regression.model import TabularRegressor # noqa: F401 diff --git a/flash_examples/tabular_regression.py b/flash_examples/tabular_regression.py new file mode 100644 index 
0000000000..06577c601b --- /dev/null +++ b/flash_examples/tabular_regression.py @@ -0,0 +1,43 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import flash +from flash.core.data.utils import download_data +from flash.tabular import TabularRegressionData, TabularRegressor + +# 1. Create the DataModule +download_data("https://archive.ics.uci.edu/ml/machine-learning-databases/00560/SeoulBikeData.csv", "./data") + +datamodule = TabularRegressionData.from_csv( + categorical_fields=["Hour", "Temperature(°C)", "Humidity(%)", "Wind speed (m/s)", "Wind speed (m/s)", "Visibility (10m)", "Dew point temperature(°C)", "Solar Radiation (MJ/m2)", "Rainfall(mm)", "Snowfall (cm)"], + numerical_fields=["Seasons", "Holiday", "Functioning Day"], + target_fields="Rented Bike Count", + train_file="data/SeoulBikeData.csv", + val_split=0.1, +) + +# 2. Build the task +model = TabularRegressor.from_data(datamodule) + +# 3. Create the trainer and train the model +trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) +trainer.fit(model, datamodule=datamodule) + +# 4. Generate predictions from a CSV +predictions = model.predict("data/SeoulBikeData.csv") +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("tabular_regression_model.pt") From d44c51f0a5e817064a21714fd5c90ba28e075513 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 3 Nov 2021 22:20:02 +0000 Subject: [PATCH 05/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- flash_examples/tabular_regression.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/flash_examples/tabular_regression.py b/flash_examples/tabular_regression.py index 06577c601b..2e1ff97169 100644 --- a/flash_examples/tabular_regression.py +++ b/flash_examples/tabular_regression.py @@ -21,7 +21,18 @@ download_data("https://archive.ics.uci.edu/ml/machine-learning-databases/00560/SeoulBikeData.csv", "./data") datamodule = TabularRegressionData.from_csv( - categorical_fields=["Hour", "Temperature(°C)", "Humidity(%)", "Wind speed (m/s)", "Wind speed (m/s)", "Visibility (10m)", "Dew point temperature(°C)", "Solar Radiation (MJ/m2)", "Rainfall(mm)", "Snowfall (cm)"], + categorical_fields=[ + "Hour", + "Temperature(°C)", + "Humidity(%)", + "Wind speed (m/s)", + "Wind speed (m/s)", + "Visibility (10m)", + "Dew point temperature(°C)", + "Solar Radiation (MJ/m2)", + "Rainfall(mm)", + "Snowfall (cm)", + ], numerical_fields=["Seasons", "Holiday", "Functioning Day"], target_fields="Rented Bike Count", train_file="data/SeoulBikeData.csv", val_split=0.1, From e4d40be478de120750b6e60cabca6f666aa2d88b Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 3 Nov 2021 23:40:02 +0100 Subject: [PATCH 06/16] example --- flash/tabular/regression/model.py | 2 +- flash_examples/tabular_regression.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash/tabular/regression/model.py
b/flash/tabular/regression/model.py index 7fa6147f0e..dacf16928d 100644 --- a/flash/tabular/regression/model.py +++ b/flash/tabular/regression/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, List, Tuple import torch from torch.nn import functional as F diff --git a/flash_examples/tabular_regression.py b/flash_examples/tabular_regression.py index 2e1ff97169..7daba3f22f 100644 --- a/flash_examples/tabular_regression.py +++ b/flash_examples/tabular_regression.py @@ -21,7 +21,8 @@ download_data("https://archive.ics.uci.edu/ml/machine-learning-databases/00560/SeoulBikeData.csv", "./data") datamodule = TabularRegressionData.from_csv( - categorical_fields=[ + categorical_fields=["Seasons", "Holiday", "Functioning Day"], + numerical_fields=[ "Hour", "Temperature(°C)", "Humidity(%)", "Wind speed (m/s)", "Wind speed (m/s)", "Visibility (10m)", "Dew point temperature(°C)", "Solar Radiation (MJ/m2)", "Rainfall(mm)", "Snowfall (cm)", ], - numerical_fields=["Seasons", "Holiday", "Functioning Day"], target_fields="Rented Bike Count", train_file="data/SeoulBikeData.csv", val_split=0.1, From feae7d29f21a9a110277a660c8b8568485354f6e Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 4 Nov 2021 13:34:38 +0100 Subject: [PATCH 07/16] docs --- docs/source/api/tabular.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/api/tabular.rst b/docs/source/api/tabular.rst index 60b4c18f89..45fbca2fb3 100644 --- a/docs/source/api/tabular.rst +++ b/docs/source/api/tabular.rst @@ -28,6 +28,7 @@ __________ :nosignatures: :template: classtemplate.rst + ~regression.model.TabularRegressor ~regression.data.TabularRegressionData Forecasting From 1c7ac872fc821995a4bb44e16ed91a7a96bbfda7 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 4 Nov 2021 14:24:27 +0100 Subject: [PATCH 08/16] fixing --- flash/core/regression.py | 4 +--- flash/tabular/classification/model.py | 5 +---- flash/tabular/regression/model.py | 5 +---- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/flash/core/regression.py b/flash/core/regression.py index 8cafcb5472..edd351d02f 100644 --- a/flash/core/regression.py +++ b/flash/core/regression.py @@ -33,9 +33,7 @@ def _build( return metrics, loss_fn def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor: - if getattr(self.hparams, "multi_label", False): - return torch.sigmoid(x) - return torch.softmax(x, dim=1) + return x class RegressionTask(Task, RegressionMixin): diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 0da74a969f..a0c0cc7114 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -90,10 +90,7 @@ def __init__( def forward(self, x_in) -> torch.Tensor: # TabNet takes single input, x_in is composed of (categorical, numerical) - xs = [] - for x in x_in: - if x.numel(): - xs.append(x) + xs = [x for x in x_in if x.numel()] x = torch.cat(xs, dim=1) return self.model(x)[0] diff --git a/flash/tabular/regression/model.py b/flash/tabular/regression/model.py index dacf16928d..f5ee149882 100644 --- a/flash/tabular/regression/model.py +++ b/flash/tabular/regression/model.py @@ -85,10 +85,7 @@ def __init__( def forward(self, x_in) -> torch.Tensor: # TabNet takes single input, x_in is composed of (categorical, numerical) - xs = [] - for x in x_in: - if x.numel(): - xs.append(x) + xs = [x for x in x_in if x.numel()] x = torch.cat(xs, dim=1) return
self.model(x)[0] From 832bf10155746291e1ccdbe7efa848608a35bf96 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 4 Nov 2021 13:44:00 +0000 Subject: [PATCH 09/16] Add example to CI --- tests/examples/test_scripts.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 1060e43eb2..eeeb725ee9 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -76,6 +76,10 @@ "tabular_classification.py", marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"), ), + pytest.param( + "tabular_regression.py", + marks=pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed"), + ), pytest.param("template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")), pytest.param( "text_classification.py", From 3c039d2fd4004f3332d2b83b1d90fb091cd5ea70 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 4 Nov 2021 13:51:02 +0000 Subject: [PATCH 10/16] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index feb390c8c9..9b89d957b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a `TabularForecaster` task based on PyTorch Forecasting ([#647](https://github.com/PyTorchLightning/lightning-flash/pull/647)) +- Added a `TabularRegressor` task ([#892](https://github.com/PyTorchLightning/lightning-flash/pull/892)) + ### Changed ### Fixed From bac3f26b9e73417a16894410d2ff542334f7de5e Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 4 Nov 2021 14:54:40 +0100 Subject: [PATCH 11/16] args --- flash/tabular/regression/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/tabular/regression/model.py b/flash/tabular/regression/model.py index f5ee149882..2606f02f05 100644 --- a/flash/tabular/regression/model.py +++ b/flash/tabular/regression/model.py @@ -107,5 +107,5 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A @classmethod def from_data(cls, datamodule, **kwargs) -> "TabularRegressor": - model = cls(datamodule.num_features, datamodule.num_classes, datamodule.embedding_sizes, **kwargs) + model = cls(datamodule.num_features, datamodule.embedding_sizes, **kwargs) return model From 4726284643e01ed9ae191b86df2731893fa05bf5 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 4 Nov 2021 15:10:28 +0100 Subject: [PATCH 12/16] defaults --- flash/tabular/regression/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/tabular/regression/model.py b/flash/tabular/regression/model.py index 2606f02f05..7710332670 100644 --- a/flash/tabular/regression/model.py +++ b/flash/tabular/regression/model.py @@ -51,7 +51,7 @@ def __init__( self, num_features: int, embedding_sizes: List[Tuple[int, int]] = None, - loss_fn: Callable = F.cross_entropy, + loss_fn: Callable = F.mse_loss, optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, @@ -87,7 +87,7 @@ def forward(self, x_in) -> torch.Tensor: # TabNet takes single input, x_in is composed of (categorical, numerical) xs = [x for x in x_in if x.numel()] x = torch.cat(xs, dim=1) - return self.model(x)[0] + return self.model(x)[0].flatten() def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) From 
3b2a8729ac57fb6c096d4dfb3818101a9856ffe6 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 5 Nov 2021 12:31:17 +0000 Subject: [PATCH 13/16] Fix deepcopy bug --- flash/core/model.py | 3 +++ flash_examples/tabular_regression.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 6f87bcb4c3..01b5a1cc1a 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -399,6 +399,9 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: for name, metric in metrics.items(): if isinstance(metric, torchmetrics.metric.Metric): metric(y_hat, y) + # Sometimes _forward_cache is not a leaf and PL attempts to deepcopy it + if not metric._forward_cache.is_leaf: + metric._forward_cache = metric._forward_cache.clone().detach() logs[name] = metric # log the metric itself if it is of type Metric else: logs[name] = metric(y_hat, y) diff --git a/flash_examples/tabular_regression.py b/flash_examples/tabular_regression.py index 7daba3f22f..cf1a4aabf2 100644 --- a/flash_examples/tabular_regression.py +++ b/flash_examples/tabular_regression.py @@ -40,10 +40,10 @@ ) # 2. Build the task -model = TabularRegressor.from_data(datamodule) +model = TabularRegressor.from_data(datamodule, learning_rate=0.1) # 3. Create the trainer and train the model -trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count()) +trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count()) trainer.fit(model, datamodule=datamodule) # 4. Generate predictions from a CSV From 5a5b6a0822a4113cf29d59068859adfa42aa6b31 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 5 Nov 2021 12:32:24 +0000 Subject: [PATCH 14/16] Update comment --- flash/core/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/core/model.py b/flash/core/model.py index 01b5a1cc1a..d71a37300e 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -399,7 +399,7 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: for name, metric in metrics.items(): if isinstance(metric, torchmetrics.metric.Metric): metric(y_hat, y) - # Sometimes _forward_cache is not a leaf and PL attempts to deepcopy it + # Sometimes _forward_cache is not a leaf and PL tries to deepcopy it, so we convert it to a leaf if not metric._forward_cache.is_leaf: metric._forward_cache = metric._forward_cache.clone().detach() logs[name] = metric # log the metric itself if it is of type Metric From aeb4a3c84e8c1769fdad033d54b2da8331402153 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 5 Nov 2021 12:40:46 +0000 Subject: [PATCH 15/16] Update comment --- flash/core/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flash/core/model.py b/flash/core/model.py index d71a37300e..d5dab29073 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -399,7 +399,8 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: for name, metric in metrics.items(): if isinstance(metric, torchmetrics.metric.Metric): metric(y_hat, y) - # Sometimes _forward_cache is not a leaf and PL tries to deepcopy it, so we convert it to a leaf + # PL 1.4.0 -> 1.4.9 tries to deepcopy the metric. + # Sometimes _forward_cache is not a leaf, so we convert it to one. 
if not metric._forward_cache.is_leaf: metric._forward_cache = metric._forward_cache.clone().detach() logs[name] = metric # log the metric itself if it is of type Metric From c58ef2557dfeec0d44b2c54111b6ec2f808bb23f Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 5 Nov 2021 13:02:13 +0000 Subject: [PATCH 16/16] Explicitly check for PL 1.5.0 --- flash/core/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index d5dab29073..85e13aa8de 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -51,7 +51,7 @@ from flash.core.registry import FlashRegistry from flash.core.serve.composition import Composition from flash.core.utilities.apply_func import get_callable_dict -from flash.core.utilities.imports import requires +from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, requires from flash.core.utilities.providers import _HUGGINGFACE from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import ( @@ -401,7 +401,7 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: metric(y_hat, y) # PL 1.4.0 -> 1.4.9 tries to deepcopy the metric. # Sometimes _forward_cache is not a leaf, so we convert it to one. - if not metric._forward_cache.is_leaf: + if not metric._forward_cache.is_leaf and not _PL_GREATER_EQUAL_1_5_0: metric._forward_cache = metric._forward_cache.clone().detach() logs[name] = metric # log the metric itself if it is of type Metric else:
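The clone/detach workaround in the last two patches targets a PyTorch restriction: copy.deepcopy only works on graph-leaf tensors, and a metric's _forward_cache can still be attached to the autograd graph when PL 1.4.x tries to deepcopy the metric. A minimal sketch of that behaviour, using plain PyTorch only (the tensors and names below are illustrative, not part of the patch):

import copy

import torch

# A tensor produced by autograd operations on a requires_grad tensor is not a leaf,
# and copy.deepcopy refuses to copy it.
x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()
print(y.is_leaf)  # False
try:
    copy.deepcopy(y)
except RuntimeError as err:
    print(err)  # only graph-leaf tensors support the deepcopy protocol

# clone().detach() produces a leaf copy that deepcopies cleanly; this is the same
# conversion the patch applies to metric._forward_cache on PL versions before 1.5.0.
y_leaf = y.clone().detach()
print(y_leaf.is_leaf)  # True
copy.deepcopy(y_leaf)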