From d4cf16ac1aa19394d900ff37c108b03a99d50adc Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Sat, 17 Dec 2022 09:24:10 +0100 Subject: [PATCH] Fix `DateSplitter` for multiples of base frequencies (#2500) --- src/gluonts/dataset/split.py | 55 ++++--- test/dataset/test_split.py | 284 +++++++++++++++++++++++++++-------- 2 files changed, 256 insertions(+), 83 deletions(-) diff --git a/src/gluonts/dataset/split.py b/src/gluonts/dataset/split.py index dc23f51078..78e2fcd6fb 100644 --- a/src/gluonts/dataset/split.py +++ b/src/gluonts/dataset/split.py @@ -76,11 +76,25 @@ from typing import Generator, Optional, Tuple import pandas as pd +from pandas.tseries.frequencies import to_offset from gluonts.dataset import Dataset, DataEntry from gluonts.dataset.field_names import FieldName +def periods_between( + start: pd.Period, + end: pd.Period, +) -> int: + """ + Counts how many periods fit between ``start`` and ``end`` + (inclusive). + + The frequency is taken from ``start``. + """ + return ((end - start).n // to_offset(start.freq).n) + 1 + + def to_positive_slice(slice_: slice, length: int) -> slice: """ Return an equivalent slice with positive bounds, given the @@ -277,7 +291,7 @@ class OffsetSplitter(AbstractBaseSplitter): offset: int def training_entry(self, entry: DataEntry) -> DataEntry: - return TimeSeriesSlice(entry)[: self.offset] + return slice_data_entry(entry, slice(None, self.offset)) def test_pair( self, entry: DataEntry, prediction_length: int, offset: int = 0 @@ -286,17 +300,19 @@ def test_pair( if self.offset < 0 and offset_ >= 0: offset_ += len(entry) if offset_ + prediction_length: - return ( - TimeSeriesSlice(entry, prediction_length)[:offset_], - TimeSeriesSlice(entry, prediction_length)[ - offset_ : offset_ + prediction_length - ], - ) + input_slice = slice(None, offset_) + label_slice = slice(offset_, offset_ + prediction_length) else: - return ( - TimeSeriesSlice(entry, prediction_length)[:offset_], - TimeSeriesSlice(entry, prediction_length)[offset_:], - ) + input_slice = slice(None, offset_) + label_slice = slice(offset_, None) + return ( + slice_data_entry( + entry, input_slice, prediction_length=prediction_length + ), + slice_data_entry( + entry, label_slice, prediction_length=prediction_length + ), + ) @dataclass @@ -317,17 +333,22 @@ class DateSplitter(AbstractBaseSplitter): date: pd.Period def training_entry(self, entry: DataEntry) -> DataEntry: - return TimeSeriesSlice(entry)[: self.date] + length = periods_between(entry["start"], self.date) + return slice_data_entry(entry, slice(None, length)) def test_pair( self, entry: DataEntry, prediction_length: int, offset: int = 0 ) -> Tuple[DataEntry, DataEntry]: - date = self.date.asfreq(entry[FieldName.START].freq) + base = periods_between(entry["start"], self.date) + input_slice = slice(None, base + offset) + label_slice = slice(base + offset, base + offset + prediction_length) return ( - TimeSeriesSlice(entry, prediction_length)[: date + offset], - TimeSeriesSlice(entry, prediction_length)[ - date + (offset + 1) : date + (prediction_length + offset) - ], + slice_data_entry( + entry, input_slice, prediction_length=prediction_length + ), + slice_data_entry( + entry, label_slice, prediction_length=prediction_length + ), ) diff --git a/test/dataset/test_split.py b/test/dataset/test_split.py index 6f80674908..d35c6a68d7 100644 --- a/test/dataset/test_split.py +++ b/test/dataset/test_split.py @@ -15,11 +15,10 @@ import pandas as pd import pytest -from gluonts.dataset.common import ListDataset from gluonts.dataset.field_names import FieldName from gluonts.dataset.split import ( - DateSplitter, OffsetSplitter, + periods_between, split, TimeSeriesSlice, ) @@ -58,47 +57,78 @@ def test_time_series_slice(): ).all() -def test_split_mult_freq(): - splitter = DateSplitter( - date=pd.Period("2021-01-01", "2h"), - ) - - splitter.split( - [ - { - "item_id": "1", - "target": pd.Series([0, 1, 2]), - "start": pd.Period("2021-01-01", freq="2H"), - } - ] - ) +@pytest.mark.parametrize( + "start, end, count", + [ + ( + pd.Period("2021-03-04", freq="2D"), + pd.Period("2021-03-05", freq="2D"), + 1, + ), + ( + pd.Period("2021-03-04", freq="2D"), + pd.Period("2021-03-08", freq="2D"), + 3, + ), + ( + pd.Period("2021-03-03 23:00", freq="30T"), + pd.Period("2021-03-04 03:29", freq="30T"), + 9, + ), + ( + pd.Period("2015-04-07 00:00", freq="30T"), + pd.Period("2015-04-07 09:31", "30T"), + 20, + ), + ( + pd.Period("2015-04-07 00:00", freq="30T"), + pd.Period("2015-04-08 16:10", freq="30T"), + 81, + ), + ( + pd.Period("2021-01-01 00", freq="2H"), + pd.Period("2021-01-01 08", "2H"), + 5, + ), + ( + pd.Period("2021-01-01 00", freq="2H"), + pd.Period("2021-01-01 11", "2H"), + 6, + ), + ], +) +def test_periods_between(start, end, count): + assert count == periods_between(start, end) def test_negative_offset_splitter(): - dataset = ListDataset( - [ - {"item_id": 0, "start": "2021-03-04", "target": [1.0] * 100}, - {"item_id": 1, "start": "2021-03-04", "target": [2.0] * 50}, - ], - freq="D", - ) + dataset = [ + { + "item_id": 0, + "start": pd.Period("2021-03-04", freq="D"), + "target": np.ones(shape=(100,)), + }, + { + "item_id": 1, + "start": pd.Period("2021-03-04", freq="D"), + "target": 2 * np.ones(shape=(50,)), + }, + ] - splitter = OffsetSplitter(offset=-7).split(dataset) + train, test_gen = OffsetSplitter(offset=-7).split(dataset) - assert [len(t["target"]) for t in splitter[0]] == [93, 43] + assert [len(t["target"]) for t in train] == [93, 43] assert [ len(t["target"]) + len(s["target"]) - for t, s in splitter[1].generate_instances(prediction_length=7) + for t, s in test_gen.generate_instances(prediction_length=7) ] == [100, 50] - rolling_splitter = OffsetSplitter(offset=-21).split(dataset) + train, test_gen = OffsetSplitter(offset=-21).split(dataset) - assert [len(t["target"]) for t in rolling_splitter[0]] == [79, 29] + assert [len(t["target"]) for t in train] == [79, 29] assert [ len(t["target"]) + len(s["target"]) - for t, s in rolling_splitter[1].generate_instances( - prediction_length=7, windows=3 - ) + for t, s in test_gen.generate_instances(prediction_length=7, windows=3) ] == [ 86, 93, @@ -155,41 +185,38 @@ def check_training_validation( @pytest.mark.parametrize( "dataset", [ - ListDataset( - [ - { - "item_id": 0, - "start": "2021-03-04", - "target": [1.0] * 365, - "feat_dynamic_real": [[2.0] * 365], - }, - { - "item_id": 1, - "start": "2021-03-04", - "target": [2.0] * 265, - "feat_dynamic_real": [[3.0] * 265], - }, - ], - freq="D", - ), - ListDataset( - [ - { - "item_id": 0, - "start": "2021-03-04", - "target": [[1.0] * 365, [10.0] * 365], - "feat_dynamic_real": [[2.0] * 365], - }, - { - "item_id": 1, - "start": "2021-03-04", - "target": [[2.0] * 265, [20.0] * 265], - "feat_dynamic_real": [[3.0] * 265], - }, - ], - one_dim_target=False, - freq="D", - ), + [ + { + "item_id": 0, + "start": pd.Period("2021-03-04", freq="D"), + "target": np.ones(shape=(365,)), + "feat_dynamic_real": 2 * np.ones(shape=(1, 365)), + }, + { + "item_id": 1, + "start": pd.Period("2021-03-04", freq="D"), + "target": 2 * np.ones(shape=(265,)), + "feat_dynamic_real": 3 * np.ones(shape=(1, 265)), + }, + ], + [ + { + "item_id": 0, + "start": pd.Period("2021-03-04", freq="D"), + "target": np.stack( + [np.ones(shape=(365,)), 10 * np.ones(shape=(365,))] + ), + "feat_dynamic_real": 2 * np.ones(shape=(1, 365)), + }, + { + "item_id": 1, + "start": pd.Period("2021-03-04", freq="D"), + "target": np.stack( + [2 * np.ones(shape=(265,)), 20 * np.ones(shape=(265,))] + ), + "feat_dynamic_real": 3 * np.ones(shape=(1, 265)), + }, + ], ], ) @pytest.mark.parametrize( @@ -238,3 +265,128 @@ def test_split(dataset, date, offset, windows, distance, max_history): offset=offset, ) k += 1 + + +@pytest.mark.parametrize( + "entry, offset, prediction_length, test_label_start", + [ + ( + { + "start": pd.Period("2015-04-07 00:00:00", freq="30T"), + "target": np.random.randn(100), + }, + 20, + 6, + pd.Period("2015-04-07 10:00:00", freq="30T"), + ), + ( + { + "start": pd.Period("2015-04-07 00:00:00", freq="30T"), + "target": np.random.randn(100), + }, + -20, + 6, + pd.Period("2015-04-08 16:00:00", freq="30T"), + ), + ], +) +def test_split_offset( + entry, + offset, + prediction_length, + test_label_start, +): + training_dataset, test_template = split([entry], offset=offset) + + training_entry = next(iter(training_dataset)) + test_input, test_label = next( + iter( + test_template.generate_instances( + prediction_length=prediction_length + ) + ) + ) + + if offset < 0: + training_size = (len(entry["target"]) + offset,) + else: + training_size = (offset,) + + assert training_entry["start"] == entry["start"] + assert training_entry["target"].shape == training_size + + assert test_input["start"] == entry["start"] + assert test_input["target"].shape == training_size + + assert test_label["start"] == test_label_start + assert test_label["target"].shape == (prediction_length,) + + +@pytest.mark.parametrize( + "entry, date, prediction_length, training_size", + [ + ( + { + "start": pd.Period("2015-04-07 00:00:00", freq="30T"), + "target": np.random.randn(100), + }, + pd.Period("2015-04-07 09:30", "30T"), + 6, + (20,), + ), + ( + { + "start": pd.Period("2015-04-07 00:00:00", freq="30T"), + "target": np.random.randn(100), + }, + pd.Period("2015-04-08 16:00:00", freq="30T"), + 6, + (81,), + ), + ( + { + "start": pd.Period("2021-01-01 00", freq="2H"), + "target": np.arange(10), + }, + pd.Period("2021-01-01 08", "2h"), + 2, + (5,), + ), + ( + { + "start": pd.Period("2021-01-01 00", freq="2H"), + "target": np.arange(10), + }, + pd.Period("2021-01-01 11", "2h"), + 2, + (6,), + ), + ], +) +def test_split_date( + entry, + date, + prediction_length, + training_size, +): + training_dataset, test_template = split([entry], date=date) + + training_entry = next(iter(training_dataset)) + test_input, test_label = next( + iter( + test_template.generate_instances( + prediction_length=prediction_length + ) + ) + ) + + assert training_entry["start"] == entry["start"] + assert training_entry["target"].shape == training_size + + assert test_input["start"] == entry["start"] + assert test_input["target"].shape == training_size + + assert test_label["start"] == test_input["start"] + len( + test_input["target"] + ) + assert test_label["target"].shape == (prediction_length,)