Skip to content

Commit

Permalink
update gluonts.dataset.split code, test, docs
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed Aug 19, 2022
1 parent b47a602 commit 8e75055
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 133 deletions.
274 changes: 160 additions & 114 deletions src/gluonts/dataset/split/splitter.py → src/gluonts/dataset/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,66 @@
Train/test splitter
~~~~~~~~~~~~~~~~~~~
.. testsetup:: *
import pandas as pd
import numpy as np
from gluonts.dataset.split import OffsetSplitter, DateSplitter
whole_dataset = [
{"start": pd.Period("2018-01-01", freq="D"), "target": np.arange(50)},
{"start": pd.Period("2018-01-01", freq="D"), "target": 2*np.arange(50)},
]
This module defines strategies to split a whole dataset into train and test
subsets.
For uniform datasets, where all time-series start and end at the same point in
time `OffsetSplitter` can be used::
time :class:`OffsetSplitter` can be used::
.. testcode::
splitter = OffsetSplitter(prediction_length=24, split_offset=24)
train, test = splitter.split(whole_dataset)
splitter = OffsetSplitter(offset=7)
train, test_template = splitter.split(whole_dataset)
For all other datasets, the more flexible `DateSplitter` can be used::
For all other datasets, the more flexible :class:`DateSplitter` can be used::
.. testcode::
splitter = DateSplitter(
prediction_length=24,
split_date=pd.Period('2018-01-31', freq='D')
)
train, test = splitter.split(whole_dataset)
train, test_template = splitter.split(whole_dataset)
The module also supports rolling splits::
In the above examples, the ``train`` output is a regular ``Dataset`` that can be
used for training purposes; ``test_template`` can generate test instances as
follows::
splitter = DateSplitter(
prediction_length=24,
split_date=pd.Period('2018-01-31', freq='D'),
windows=7
.. testcode::
test_dataset = test_template.generate_instances(
prediction_length=7,
windows=2,
)
train, test = splitter.split(whole_dataset)
The ``windows`` argument controls how many test windows to generate from each
entry in the original dataset. Each window will begin after the split point,
and so will not contain any training data. By default, windows are
non-overlapping, but this can be controlled with the ``distance`` optional
argument.
.. testcode::
test_dataset = test_template.generate_instances(
prediction_length=7,
windows=2,
distance=3, # windows are three time steps apart from each other
)
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import cast, Generator, Iterable, List, Optional, Tuple
from typing import cast, Generator, List, Optional, Tuple

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -195,20 +225,20 @@ def _trim_history(
return item

def split(self, dataset: Dataset):
test_data = TestTemplate(dataset=dataset, splitter=self)
train_data = TrainingDataset(dataset=dataset, splitter=self)

return train_data, test_data
return (
TrainingDataset(dataset=dataset, splitter=self),
TestTemplate(dataset=dataset, splitter=self),
)

def _generate_train_slices(self, items: List[DataEntry]):
for item in map(TimeSeriesSlice.from_data_entry, items):
train = self._train_slice(item)
def _generate_train_slices(self, dataset: Dataset):
for entry in map(TimeSeriesSlice.from_data_entry, dataset):
train = self._train_slice(entry)

yield train.to_data_entry()

def _generate_test_slices(
self,
items: Dataset,
dataset: Dataset,
prediction_length: int,
windows: int = 1,
distance: Optional[int] = None,
Expand All @@ -217,13 +247,13 @@ def _generate_test_slices(
if distance is None:
distance = prediction_length

for item in map(TimeSeriesSlice.from_data_entry, items):
train = self._train_slice(item)
for entry in map(TimeSeriesSlice.from_data_entry, dataset):
train = self._train_slice(entry)

for window in range(windows):
offset = window * distance
test = self._test_slice(
item, prediction_length=prediction_length, offset=offset
entry, prediction_length=prediction_length, offset=offset
)

_check_split_length(
Expand All @@ -250,94 +280,6 @@ def _check_split_length(
assert train_end + prediction_length <= test_end, msg


@dataclass
class TestIterable:
"""
An iterable class used for wrapping test data.
Parameters
----------
dataset:
Whole dataset used for testing.
splitter:
A specific splitter that knows how to slices training and
test data.
kwargs:
Parameters used for generating specific test instances.
See `TestTemplate.generate_instances`
"""

dataset: Dataset
splitter: AbstractBaseSplitter
kwargs: dict

def __iter__(self):
yield from self.splitter._generate_test_slices(
self.dataset,
**self.kwargs,
)


@dataclass
class TestTemplate:
"""
A class used for generating test data.
Parameters
----------
dataset:
Whole dataset used for testing.
splitter:
A specific splitter that knows how to slices training and
test data.
"""

dataset: Dataset
splitter: AbstractBaseSplitter

def generate_instances(
self,
prediction_length: int,
windows=1,
distance=None,
max_history=None,
) -> TestIterable:
"""
Generate an iterator of test dataset, which includes input part and
label part.
Parameters
----------
prediction_length
Length of the prediction interval in test data.
windows
Indicates how many test windows to generate for each original
dataset entry.
distance
This is rather the difference between the start of each test
window generated, for each of the original dataset entries.
max_history
If given, all entries in the *test*-set have a max-length of
`max_history`. This can be used to produce smaller file-sizes.
"""
kwargs = {
"prediction_length": prediction_length,
"windows": windows,
"distance": distance,
"max_history": max_history,
}
return TestIterable(self.dataset, self.splitter, kwargs)


@dataclass
class TrainingDataset:
dataset: Iterable[DataEntry]
splitter: AbstractBaseSplitter

def __iter__(self):
return self.splitter._generate_train_slices(self.dataset)


@dataclass
class OffsetSplitter(AbstractBaseSplitter):
"""
Expand Down Expand Up @@ -414,10 +356,114 @@ def _test_slice(


def split(dataset, *, offset=None, date=None):
# You need to provide `offset` or `date`, but not both
assert (offset is None) != (date is None)
assert (offset is None) != (
date is None
), "You need to provide ``offset`` or ``date``, but not both."
if offset is not None:
splitter = OffsetSplitter(offset)
else:
splitter = DateSplitter(date)
return splitter.split(dataset)


@dataclass
class TestDataset:
"""
An iterable type used for wrapping test data.
Elements of a ``TestDataset`` are pairs ``(input, label)``, where
``input`` is input data for models, while ``label`` is the future
ground truth that models are supposed to predict.
Parameters
----------
dataset:
Whole dataset used for testing.
splitter:
A specific splitter that knows how to slices training and
test data.
prediction_length
Length of the prediction interval in test data.
windows
Indicates how many test windows to generate for each original
dataset entry.
distance
This is rather the difference between the start of each test
window generated, for each of the original dataset entries.
max_history
If given, all entries in the *test*-set have a max-length of
`max_history`. This can be used to produce smaller file-sizes.
"""

dataset: Dataset
splitter: AbstractBaseSplitter
prediction_length: int
windows: int = 1
distance: Optional[int] = None
max_history: Optional[int] = None

def __iter__(self):
yield from self.splitter._generate_test_slices(
dataset=self.dataset,
prediction_length=self.prediction_length,
windows=self.windows,
distance=self.distance,
max_history=self.max_history,
)

@property
def input(self):
"""
Iterable over the ``input`` portion of the test data.
"""
for input, _ in self:
yield input

@property
def label(self):
"""
Iterable over the ``label`` portion of the test data.
"""
for _, label in self:
yield label


@dataclass
class TestTemplate:
"""
A class used for generating test data.
Parameters
----------
dataset:
Whole dataset used for testing.
splitter:
A specific splitter that knows how to slices training and
test data.
"""

dataset: Dataset
splitter: AbstractBaseSplitter

def generate_instances(self, **kwargs) -> TestDataset:
"""
Generate an iterator of test dataset, which includes input part and
label part.
Keyword arguments are the same as for :class:`TestDataset`.
"""
return TestDataset(
self.dataset, self.splitter, **kwargs
)


@dataclass
class TrainingDataset:
dataset: Dataset
splitter: AbstractBaseSplitter

def __iter__(self):
return self.splitter._generate_train_slices(self.dataset)

def __len__(self):
return len(self.dataset)
16 changes: 0 additions & 16 deletions src/gluonts/dataset/split/__init__.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@

from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.dataset.split import DateSplitter, OffsetSplitter, split
from gluonts.dataset.split.splitter import TimeSeriesSlice
from gluonts.dataset.split import (
DateSplitter,
OffsetSplitter,
split,
TimeSeriesSlice,
)


def make_series(data, start="2020", freq="D"):
Expand Down

0 comments on commit 8e75055

Please sign in to comment.