Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix checkpointing bug in tabular forecasting (#995)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Nov 23, 2021
1 parent 19cf911 commit 5f11d8f
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where additional `DataModule` keyword arguments could not be configured with Flash Zero for some tasks ([#994](https://github.com/PyTorchLightning/lightning-flash/pull/994))

- Fixed a bug where the TabularForecaster would not work with some versions of pandas ([#995](https://github.com/PyTorchLightning/lightning-flash/pull/995))

### Removed

- Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))
Expand Down
2 changes: 1 addition & 1 deletion flash/core/integrations/pytorch_forecasting/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def from_task(
) -> Adapter:
parameters = copy(parameters)
# Remove the single row of data from the parameters to reconstruct the `time_series_dataset`
data = parameters.pop("data_sample")
data = DataFrame.from_dict(parameters.pop("data_sample"))
time_series_dataset = PatchTimeSeriesDataSet.from_parameters(parameters, data)

backbone_kwargs["loss"] = loss_fn
Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/forecasting/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def load_data(
parameters = time_series_dataset.get_parameters()

# Add some sample data so that we can recreate the `TimeSeriesDataSet` later on
parameters["data_sample"] = data.iloc[[0]]
parameters["data_sample"] = data.iloc[[0]].to_dict()

self.set_state(TimeSeriesDataSetParametersState(parameters))
self.parameters = parameters
Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/forecasting/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
metrics: Union[torchmetrics.Metric, List[torchmetrics.Metric]] = None,
learning_rate: float = 4e-3,
):
self.save_hyperparameters(ignore="parameters")
self.save_hyperparameters()

if backbone_kwargs is None:
backbone_kwargs = {}
Expand Down

0 comments on commit 5f11d8f

Please sign in to comment.