Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix checkpointing bug in tabular forecasting #995

Merged
merged 3 commits into from
Nov 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where additional `DataModule` keyword arguments could not be configured with Flash Zero for some tasks ([#994](https://github.com/PyTorchLightning/lightning-flash/pull/994))

- Fixed a bug where the TabularForecaster would not work with some versions of pandas ([#995](https://github.com/PyTorchLightning/lightning-flash/pull/995))

### Removed

- Removed `OutputMapping` ([#939](https://github.com/PyTorchLightning/lightning-flash/pull/939))
Expand Down
2 changes: 1 addition & 1 deletion flash/core/integrations/pytorch_forecasting/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def from_task(
) -> Adapter:
parameters = copy(parameters)
# Remove the single row of data from the parameters to reconstruct the `time_series_dataset`
data = parameters.pop("data_sample")
data = DataFrame.from_dict(parameters.pop("data_sample"))
time_series_dataset = PatchTimeSeriesDataSet.from_parameters(parameters, data)

backbone_kwargs["loss"] = loss_fn
Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/forecasting/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def load_data(
parameters = time_series_dataset.get_parameters()

# Add some sample data so that we can recreate the `TimeSeriesDataSet` later on
parameters["data_sample"] = data.iloc[[0]]
parameters["data_sample"] = data.iloc[[0]].to_dict()

self.set_state(TimeSeriesDataSetParametersState(parameters))
self.parameters = parameters
Expand Down
2 changes: 1 addition & 1 deletion flash/tabular/forecasting/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
metrics: Union[torchmetrics.Metric, List[torchmetrics.Metric]] = None,
learning_rate: float = 4e-3,
):
self.save_hyperparameters(ignore="parameters")
self.save_hyperparameters()

if backbone_kwargs is None:
backbone_kwargs = {}
Expand Down