diff --git a/src/gluonts/ext/rotbaum/_model.py b/src/gluonts/ext/rotbaum/_model.py index eb6b345ffa..94c8fb2300 100644 --- a/src/gluonts/ext/rotbaum/_model.py +++ b/src/gluonts/ext/rotbaum/_model.py @@ -20,7 +20,7 @@ import gc from collections import defaultdict -from gluonts.core.component import validated +from gluonts.core.component import equals, validated class QRF: @@ -125,6 +125,13 @@ def _create_xgboost_model(model_params: Optional[dict] = None): } return xgboost.sklearn.XGBModel(**model_params) + def __eq__(self, that): + """ + Two QRX instances are considered equal if they have the same + constructor arguments. + """ + return equals(self, that) + def fit( self, x_train: Union[pd.DataFrame, List], diff --git a/src/gluonts/ext/rotbaum/_predictor.py b/src/gluonts/ext/rotbaum/_predictor.py index 482f4adbd1..b713a5e8d5 100644 --- a/src/gluonts/ext/rotbaum/_predictor.py +++ b/src/gluonts/ext/rotbaum/_predictor.py @@ -13,12 +13,14 @@ import concurrent.futures import logging +import pickle from itertools import chain from typing import Iterator, List, Optional from toolz import first import numpy as np import pandas as pd +from pathlib import Path from itertools import compress from gluonts.core.component import validated @@ -337,6 +339,31 @@ def predict( item_id=ts.get("item_id"), ) + def serialize(self, path: Path) -> None: + """ + This function calls parent class serialize() in order to serialize + the class name, version information and constuctor arguments. It + persists the tree predictor by pickling the model list that is + generated when pickling the TreePredictor. + """ + super().serialize(path) + with (path / "predictor.pkl").open("wb") as f: + pickle.dump(self.model_list, f) + + @classmethod + def deserialize(cls, path: Path, **kwargs) -> "TreePredictor": + """ + This function loads and returns the serialized model. It loads + the predictor class with the serialized arguments. It then loads + the trained model list by reading the pickle file. + """ + + predictor = super().deserialize(path) + assert isinstance(predictor, cls) + with (path / "predictor.pkl").open("rb") as f: + predictor.model_list = pickle.load(f) + return predictor + def explain( self, importance_type: str = "gain", percentage: bool = True ) -> ExplanationResult: diff --git a/src/gluonts/ext/rotbaum/_preprocess.py b/src/gluonts/ext/rotbaum/_preprocess.py index b8e59e500b..1718ebf3bb 100644 --- a/src/gluonts/ext/rotbaum/_preprocess.py +++ b/src/gluonts/ext/rotbaum/_preprocess.py @@ -448,9 +448,12 @@ def make_features(self, time_series: Dict, starting_index: int) -> List: end_index = starting_index + self.context_window_size if starting_index < 0: prefix = [None] * abs(starting_index) + time_series_window = time_series["target"] else: prefix = [] - time_series_window = time_series["target"][starting_index:end_index] + time_series_window = time_series["target"][ + starting_index:end_index + ] only_lag_features, transform_dict = self._pre_transform( time_series_window, self.subtract_mean, self.count_nans ) @@ -460,7 +463,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List: if self.use_feat_static_real else [] ) - if self.cardinality: + if ( + self.cardinality + and time_series.get("feat_static_cat", None) is not None + ): feat_static_cat = ( self.encode_one_hot_all(time_series["feat_static_cat"]) if self.one_hot_encode @@ -473,10 +479,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List: list( chain( *[ - list(ent[0]) + list(ent[1].values()) + prefix + list(ent[0]) + list(ent[1].values()) for ent in [ self._pre_transform( - ts[starting_index:end_index], + ts if prefix else ts[starting_index:end_index], self.subtract_mean, self.count_nans, ) diff --git a/test/ext/rotbaum/test_model.py b/test/ext/rotbaum/test_model.py index f4feaad2d9..51869034c7 100644 --- a/test/ext/rotbaum/test_model.py +++ b/test/ext/rotbaum/test_model.py @@ -11,10 +11,11 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. - +from pathlib import Path import pytest +import tempfile -from gluonts.ext.rotbaum import TreeEstimator +from gluonts.ext.rotbaum import TreeEstimator, TreePredictor @pytest.fixture() @@ -33,5 +34,20 @@ def test_accuracy(accuracy_test, hyperparameters, quantiles): accuracy_test(TreeEstimator, hyperparameters, accuracy=0.20) -def test_serialize(serialize_test, hyperparameters): - serialize_test(TreeEstimator, hyperparameters) +def test_serialize(serialize_test, hyperparameters, dsinfo): + forecaster = TreeEstimator.from_hyperparameters( + freq=dsinfo.freq, + **{ + "prediction_length": dsinfo.prediction_length, + "num_parallel_samples": dsinfo.num_parallel_samples, + }, + **hyperparameters, + ) + + predictor_act = forecaster.train(dsinfo.train_ds) + + with tempfile.TemporaryDirectory() as temp_dir: + predictor_act.serialize(Path(temp_dir)) + predictor_exp = TreePredictor.deserialize(Path(temp_dir)) + assert predictor_act == predictor_exp + assert predictor_act.model_list == predictor_exp.model_list diff --git a/test/ext/rotbaum/test_rotbaum_smoke.py b/test/ext/rotbaum/test_rotbaum_smoke.py index 2634644660..93d1e96dd5 100644 --- a/test/ext/rotbaum/test_rotbaum_smoke.py +++ b/test/ext/rotbaum/test_rotbaum_smoke.py @@ -12,10 +12,12 @@ # permissions and limitations under the License. import pytest +import numpy as np -from gluonts.ext.rotbaum import TreeEstimator +from gluonts.ext.rotbaum import TreeEstimator, TreePredictor from gluonts.testutil.dummy_datasets import make_dummy_datasets_with_features +from gluonts.dataset.common import ListDataset # TODO: Add support for categorical and dynamic features. @@ -59,3 +61,68 @@ def test_rotbaum_smoke(datasets): predictor = estimator.train(dataset_train) forecasts = list(predictor.predict(dataset_test)) assert len(forecasts) == len(dataset_test) + + +def test_short_history_item_pred(): + prediction_length = 7 + freq = "D" + + dataset = ListDataset( + data_iter=[ + { + "start": "2017-10-11", + "item_id": "item_1", + "target": np.array( + [ + 1.0, + 9.0, + 2.0, + 0.0, + 0.0, + 1.0, + 5.0, + 3.0, + 4.0, + 2.0, + 0.0, + 0.0, + 1.0, + 6.0, + ] + ), + "feat_static_cat": np.array([0.0, 0.0], dtype=float), + "past_feat_dynamic_real": np.array( + [ + [1.0222e06 for i in range(14)], + [750.0 for i in range(14)], + ] + ), + }, + { + "start": "2017-10-11", + "item_id": "item_2", + "target": np.array([7.0, 0.0, 0.0, 23.0, 13.0]), + "feat_static_cat": np.array([0.0, 1.0], dtype=float), + "past_feat_dynamic_real": np.array( + [[0 for i in range(5)], [750.0 for i in range(5)]] + ), + }, + ], + freq=freq, + ) + + predictor = TreePredictor( + freq=freq, + prediction_length=prediction_length, + quantiles=[0.1, 0.5, 0.9], + max_n_datapts=50000, + method="QuantileRegression", + use_past_feat_dynamic_real=True, + use_feat_dynamic_real=False, + use_feat_dynamic_cat=False, + use_feat_static_real=False, + cardinality="auto", + ) + predictor = predictor.train(dataset) + forecasts = list(predictor.predict(dataset)) + assert forecasts[1].quantile(0.5).shape[0] == prediction_length