From 3cbd13e089bc3fd005cf6a76099bd1f9bc5426dd Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 29 Oct 2021 15:58:03 +0100 Subject: [PATCH] Update on comments --- flash/core/integrations/pytorch_forecasting/adapter.py | 1 + .../pytorch_forecasting/tabular_forecasting_interpretable.py | 1 - flash_examples/tabular_forecasting.py | 1 - tests/tabular/forecasting/test_model.py | 1 - 4 files changed, 1 insertion(+), 3 deletions(-) diff --git a/flash/core/integrations/pytorch_forecasting/adapter.py b/flash/core/integrations/pytorch_forecasting/adapter.py index f77d0f8e56..473ecc38bf 100644 --- a/flash/core/integrations/pytorch_forecasting/adapter.py +++ b/flash/core/integrations/pytorch_forecasting/adapter.py @@ -75,6 +75,7 @@ def from_task( **backbone_kwargs, ) -> Adapter: parameters = copy(parameters) + # Remove the single row of data from the parameters to reconstruct the `time_series_dataset` data = parameters.pop("data_sample") time_series_dataset = PatchTimeSeriesDataSet.from_parameters(parameters, data) diff --git a/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py b/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py index 2c5c9ea60a..ec62cb2643 100644 --- a/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py +++ b/flash_examples/integrations/pytorch_forecasting/tabular_forecasting_interpretable.py @@ -28,7 +28,6 @@ # Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html # 1. Create the DataModule data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42) -data["static"] = 2 data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") max_prediction_length = 20 diff --git a/flash_examples/tabular_forecasting.py b/flash_examples/tabular_forecasting.py index 718cc0aa72..836f01fe64 100644 --- a/flash_examples/tabular_forecasting.py +++ b/flash_examples/tabular_forecasting.py @@ -26,7 +26,6 @@ # Example based on this tutorial: https://pytorch-forecasting.readthedocs.io/en/latest/tutorials/ar.html # 1. Create the DataModule data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100, seed=42) -data["static"] = 2 data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") max_prediction_length = 20 diff --git a/tests/tabular/forecasting/test_model.py b/tests/tabular/forecasting/test_model.py index 1e0df945e2..bebf8477bc 100644 --- a/tests/tabular/forecasting/test_model.py +++ b/tests/tabular/forecasting/test_model.py @@ -29,7 +29,6 @@ @pytest.fixture def sample_data(): data = generate_ar_data(seasonality=10.0, timesteps=100, n_series=2, seed=42) - data["static"] = 2 data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D") max_prediction_length = 20 training_cutoff = data["time_idx"].max() - max_prediction_length