Skip to content

Commit

Permalink
simplify and fix date splitting, update and consolidate tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Lorenzo Stella committed Dec 16, 2022
1 parent 3231271 commit 0f926f1
Show file tree
Hide file tree
Showing 2 changed files with 256 additions and 83 deletions.
55 changes: 38 additions & 17 deletions src/gluonts/dataset/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,25 @@
from typing import Generator, Optional, Tuple

import pandas as pd
from pandas.tseries.frequencies import to_offset

from gluonts.dataset import Dataset, DataEntry
from gluonts.dataset.field_names import FieldName


def periods_between(
start: pd.Period,
end: pd.Period,
) -> int:
"""
Counts how many periods fit between ``start`` and ``end``
(inclusive).
The frequency is taken from ``start``.
"""
return ((end - start).n // to_offset(start.freq).n) + 1


def to_positive_slice(slice_: slice, length: int) -> slice:
"""
Return an equivalent slice with positive bounds, given the
Expand Down Expand Up @@ -277,7 +291,7 @@ class OffsetSplitter(AbstractBaseSplitter):
offset: int

def training_entry(self, entry: DataEntry) -> DataEntry:
return TimeSeriesSlice(entry)[: self.offset]
return slice_data_entry(entry, slice(None, self.offset))

def test_pair(
self, entry: DataEntry, prediction_length: int, offset: int = 0
Expand All @@ -286,17 +300,19 @@ def test_pair(
if self.offset < 0 and offset_ >= 0:
offset_ += len(entry)
if offset_ + prediction_length:
return (
TimeSeriesSlice(entry, prediction_length)[:offset_],
TimeSeriesSlice(entry, prediction_length)[
offset_ : offset_ + prediction_length
],
)
input_slice = slice(None, offset_)
label_slice = slice(offset_, offset_ + prediction_length)
else:
return (
TimeSeriesSlice(entry, prediction_length)[:offset_],
TimeSeriesSlice(entry, prediction_length)[offset_:],
)
input_slice = slice(None, offset_)
label_slice = slice(offset_, None)
return (
slice_data_entry(
entry, input_slice, prediction_length=prediction_length
),
slice_data_entry(
entry, label_slice, prediction_length=prediction_length
),
)


@dataclass
Expand All @@ -317,17 +333,22 @@ class DateSplitter(AbstractBaseSplitter):
date: pd.Period

def training_entry(self, entry: DataEntry) -> DataEntry:
return TimeSeriesSlice(entry)[: self.date]
length = periods_between(entry["start"], self.date)
return slice_data_entry(entry, slice(None, length))

def test_pair(
self, entry: DataEntry, prediction_length: int, offset: int = 0
) -> Tuple[DataEntry, DataEntry]:
date = self.date.asfreq(entry[FieldName.START].freq)
base = periods_between(entry["start"], self.date)
input_slice = slice(None, base + offset)
label_slice = slice(base + offset, base + offset + prediction_length)
return (
TimeSeriesSlice(entry, prediction_length)[: date + offset],
TimeSeriesSlice(entry, prediction_length)[
date + (offset + 1) : date + (prediction_length + offset)
],
slice_data_entry(
entry, input_slice, prediction_length=prediction_length
),
slice_data_entry(
entry, label_slice, prediction_length=prediction_length
),
)


Expand Down
Loading

0 comments on commit 0f926f1

Please sign in to comment.