Add LongDataset. #2377

Closed
wants to merge 8 commits into from
6 changes: 3 additions & 3 deletions docs/tutorials/data_manipulation/pandasdataframes.md.template
@@ -74,12 +74,12 @@ df.head()
```

After reading the data into a `pandas.DataFrame` we can easily convert
it to `gluonts.dataset.pandas.PandasDataset` and train an estimator and get forecasts.
it into a `Dataset` and train an estimator and get forecasts.

```python
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.pandas import LongDataset

ds = PandasDataset.from_long_dataframe(df, target="target", item_id="item_id")
ds = LongDataset(df, item_id="item_id")
```

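For orientation, here is a hedged sketch of the long-format frame the snippet above assumes (column names are illustrative): observations for all items are stacked vertically, with the datetime information held in the index or in a dedicated column.

```python
import pandas as pd

# hypothetical long-format input: rows for items "A" and "B" are stacked
# on top of each other rather than stored side by side
df = pd.DataFrame(
    {
        "item_id": ["A", "A", "B", "B"],
        "target": [1.0, 2.0, 3.0, 4.0],
    },
    index=pd.to_datetime(["2021-01-01", "2021-01-02"] * 2),
)
```

Per the implementation below, the index-based call `LongDataset(df, item_id="item_id")` relies on the index carrying a frequency; when it does not, pass `timestamp=` and `freq=` explicitly, as the tests further down do.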
190 changes: 147 additions & 43 deletions src/gluonts/dataset/pandas.py
@@ -12,15 +12,19 @@
# permissions and limitations under the License.

from copy import deepcopy
from dataclasses import dataclass, field
from dataclasses import dataclass, field, InitVar
from functools import partial
from operator import methodcaller
from typing import Any, cast, Dict, Iterator, List, Optional, Union

import pandas as pd
from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin
from toolz import valmap
from toolz import valmap, second

from gluonts.dataset.common import DataEntry, ProcessDataEntry
from gluonts.dataset.field_names import FieldName
from gluonts.itertools import Map
from .common import DataEntry, ProcessDataEntry, _as_period
from .field_names import FieldName
from .schema import Translator


@dataclass
@@ -152,35 +156,6 @@ def __iter__(self) -> Iterator[DataEntry]:
def __len__(self) -> int:
return len(self._dataframes)

@classmethod
def from_long_dataframe(
cls, dataframe: pd.DataFrame, item_id: str, **kwargs
) -> "PandasDataset":
"""
Construct ``PandasDataset`` out of a long dataframe.
A long dataframe uses the long format for each variable. Target time
series values, for example, are stacked on top of each other rather
than side-by-side. The same is true for other dynamic or categorical
features.

Parameters
----------
dataframe
pandas.DataFrame containing at least ``timestamp``, ``target`` and
``item_id`` columns.
item_id
Name of the column that, when grouped by, gives the different time
series.
**kwargs
Additional arguments. Same as for the ``PandasDataset`` class.

Returns
-------
PandasDataset
Gluonts dataset based on ``pandas.DataFrame``s.
"""
return cls(dataframes=dict(list(dataframe.groupby(item_id))), **kwargs)


def series_to_dataframe(
series: Union[pd.Series, List[pd.Series], Dict[str, pd.Series]]
Expand Down Expand Up @@ -281,10 +256,12 @@ def prepare_prediction_data(
Remove ``ignore_last_n_targets`` values from ``target`` and
``past_feat_dynamic_real``. Works in univariate and multivariate case.

>>> prepare_prediction_data(
>>> {"target": np.array([1., 2., 3., 4.])}, ignore_last_n_targets=2
>>> )
{'target': array([1., 2.])}
>>> import numpy as np
>>> prepare_prediction_data(
... {"target": np.array([1., 2., 3., 4.])}, ignore_last_n_targets=2
... )
{'target': array([1., 2.])}

"""
entry = deepcopy(dataentry)
for fname in [FieldName.TARGET, FieldName.PAST_FEAT_DYNAMIC_REAL]:
@@ -298,12 +275,139 @@ def is_uniform(index: pd.PeriodIndex) -> bool:
Check if ``index`` contains monotonically increasing periods, evenly spaced
with frequency ``index.freq``.

>>> ts = ["2021-01-01 00:00", "2021-01-01 02:00", "2021-01-01 04:00"]
>>> is_uniform(pd.DatetimeIndex(ts).to_period("2H"))
True
>>> ts = ["2021-01-01 00:00", "2021-01-01 04:00"]
>>> is_uniform(pd.DatetimeIndex(ts).to_period("2H"))
False
>>> ts = ["2021-01-01 00:00", "2021-01-01 02:00", "2021-01-01 04:00"]
>>> is_uniform(pd.DatetimeIndex(ts).to_period("2H"))
True
>>> ts = ["2021-01-01 00:00", "2021-01-01 04:00"]
>>> is_uniform(pd.DatetimeIndex(ts).to_period("2H"))
False

"""
other = pd.period_range(index[0], periods=len(index), freq=index.freq)
return (other == index).all()


def _column_as_start(dct: dict, column: str, freq, unchecked=False) -> dict:
ts = dct.pop(column)

if unchecked:
dct["start"] = _as_period(ts[0], freq=freq)
else:
idx = pd.PeriodIndex(ts, freq=freq)
assert is_uniform(idx)
dct["start"] = idx[0]

return dct


def _drop_index(df, name):
index = df.index
df = df.reset_index(drop=True)
df[name] = index
return df
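
For illustration (not part of the diff): `_column_as_start` pops the timestamp column from an entry dict and replaces it with a single `start` period, validating uniformity unless `unchecked` is set. A minimal sketch, assuming the private helper is importable from this module:

```python
import pandas as pd
from gluonts.dataset.pandas import _column_as_start  # private helper from this PR

entry = {
    "timestamp": ["2021-01-01 00:00", "2021-01-01 01:00"],
    "target": [1.0, 2.0],
}
# pops "timestamp" and stores the first period under "start"
entry = _column_as_start(entry, column="timestamp", freq="H")
print(entry)  # {'target': [1.0, 2.0], 'start': Period('2021-01-01 00:00', 'H')}
```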


@dataclass
class LongDataset:
"""Wrapper for ``pandas.DataFrame`` in long format.

Given a dataframe and an item identifier, this will yield a ``DataEntry``
for each item, by first calling ``dataframe.groupby(item_id)``.

A time column is optional, but if ``timestamp`` is provided, ``freq`` needs
to be provided as well.

``translate`` can be set to rename and stack columns on the fly.

Since the result of the ``groupby`` operation is not guaranteed to be in
the right order, each group is sorted by default. If ``assume_sorted`` is
set to ``True``, sorting is skipped, which improves performance. Note,
however, that if the dataframe is unordered, this will lead to incorrect
behaviour.

To improve performance further, ``unchecked`` can be set to ``True`` to
indicate that the timestamp column is uniform (see ``is_uniform``).
"""

df: pd.DataFrame
item_id: Union[str, List[str]]
timestamp: Optional[str] = None
freq: Optional[str] = None

assume_sorted: bool = False
translate: InitVar[Optional[dict]] = None
Contributor: There’s both translate and translator?

unchecked: bool = False
translator: Optional[Translator] = field(default=None, init=False)
use_index: bool = field(init=False)

def __post_init__(self, translate):
self.use_index = False

if self.timestamp is None:
if isinstance(self.df.index, DatetimeIndexOpsMixin):
if self.freq is None:
assert self.df.index.freq is not None
self.freq = self.df.index.freq

self.use_index = True
self.timestamp = "__timestamp_index__"

elif self.freq is None:
raise ValueError(
"When providing `timestamp`, `freq` needs to be provided too."
)

if translate is not None:
self.translator = Translator.parse(translate, drop=True)
else:
self.translator = None

def _handle_item_id(self, dct):
"""Ensure field "item_id" is a single string."""
if isinstance(self.item_id, list):
dct["item_id"] = ", ".join(
dct.pop(column)[0] for column in self.item_id
)
else:
dct["item_id"] = dct.pop(self.item_id)[0]

return dct

def __iter__(self):
groups = self.df.groupby(self.item_id)
# groups contains tuples of (item_id, df), but we only want df
dataset = Map(second, groups)

if self.use_index:
dataset = Map(
partial(_drop_index, name=self.timestamp),
dataset,
)

if not self.assume_sorted and self.timestamp is not None:
dataset = Map(
methodcaller("sort_values", by=self.timestamp), dataset
)

dataset = Map(methodcaller("to_dict", orient="list"), dataset)
dataset = Map(self._handle_item_id, dataset)

if self.translator is not None:
dataset = Map(self.translator, dataset)
Comment on lines +394 to +395
Contributor: I'm wondering whether this class is taking care of too much now: applying translator here interferes with applying column_as_start below.
Contributor Author: Right, but the problem is that dataframes only have 1D columns, so we need a way to stack them. One could of course do it oneself afterwards, but since this is a more lab-focused setup I think there is an argument for having it all in one place.

# handle start field
if self.timestamp is not None:
dataset = Map(
partial(
_column_as_start,
column=self.timestamp,
freq=self.freq,
unchecked=self.unchecked,
),
dataset,
)

yield from dataset
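# after one successful full pass the timestamp column is known to be
# uniform, so subsequent iterations can skip the check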
self.unchecked = True
Contributor: That's smart


def __len__(self):
return len(self.df.groupby(self.item_id))
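
To make the iteration pipeline above concrete, here is a hedged usage sketch (data and names are made up; the `translate` option is exercised by the test fixture further down):

```python
import pandas as pd
from gluonts.dataset.pandas import LongDataset

df = pd.DataFrame(
    {
        "time": pd.to_datetime(["2021-01-01", "2021-01-02"] * 2),
        "item": ["A", "A", "B", "B"],
        "target": [1.0, 2.0, 3.0, 4.0],
    }
)

ds = LongDataset(df, item_id="item", timestamp="time", freq="1D")

for entry in ds:
    # each entry carries "item_id", "start" (a pd.Period) and the remaining
    # columns as lists, e.g. entry["target"] == [1.0, 2.0] for item "A"
    print(entry["item_id"], entry["start"], entry["target"])
```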
99 changes: 99 additions & 0 deletions src/gluonts/dataset/polars.py
@@ -0,0 +1,99 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from dataclasses import dataclass, field, InitVar
from functools import partial
from operator import methodcaller
from typing import Union, List, Optional

# import pandas as pd
import polars as pl

from gluonts.itertools import Map
from .pandas import _column_as_start
from .schema import Translator


@dataclass
class LongDataset:
df: pl.DataFrame
item_id: Union[str, List[str]]
timestamp: Optional[str] = None
freq: Optional[str] = None

assume_sorted: bool = False
translate: InitVar[Optional[dict]] = None
translator: Optional[Translator] = field(default=None, init=False)

use_partition: bool = False
unchecked: bool = False

def __post_init__(self, translate):
if (self.timestamp is None) != (self.freq is None):
raise ValueError(
"Either both `timestamp` and `freq` have to be "
"provided or neither."
)

if translate is not None:
self.translator = Translator.parse(translate, drop=True)
else:
self.translator = None

def _pop_item_id(self, dct):
if isinstance(self.item_id, list):
dct["item_id"] = ", ".join(
dct.pop(column)[0] for column in self.item_id
)
else:
dct["item_id"] = dct.pop(self.item_id)[0]

return dct

def __iter__(self):
if self.use_partition:
dataset = self.df.partition_by(self.item_id)
else:
dataset = self.df.groupby(self.item_id)

if not self.assume_sorted:
sort_by = [self.item_id]
if self.timestamp is not None:
sort_by.append(self.timestamp)
dataset = Map(methodcaller("sort", by=sort_by), dataset)

dataset = Map(methodcaller("to_dict", as_series=True), dataset)
dataset = Map(self._pop_item_id, dataset)

if self.translator is not None:
dataset = Map(self.translator, dataset)

if self.timestamp is not None:
dataset = Map(
partial(
_column_as_start,
column=self.timestamp,
freq=self.freq,
unchecked=self.unchecked,
),
dataset,
)

yield from dataset

# we iterated over the whole dataset once successfully,
# so there is no need to check again
self.unchecked = True

def __len__(self):
return len(self.df.groupby(self.item_id).count())
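
A matching sketch for the polars variant. Note that `_column_as_start` builds a `pd.PeriodIndex` from whatever it pops, and whether that accepts a polars `Series` directly may depend on library versions, so treat this as illustrative only:

```python
import polars as pl
from gluonts.dataset.polars import LongDataset

df = pl.DataFrame(
    {
        "timestamp": ["2021-01-01", "2021-01-02"] * 2,
        "item_id": ["A", "A", "B", "B"],
        "target": [1.0, 2.0, 3.0, 4.0],
    }
)

ds = LongDataset(df, item_id="item_id", timestamp="timestamp", freq="1D")

for entry in ds:
    print(entry["item_id"], entry["start"])
```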
20 changes: 10 additions & 10 deletions test/dataset/test_pandas.py
@@ -60,14 +60,16 @@ def long_dataframe():

@pytest.fixture
def long_dataset(long_dataframe): # initialized with dict
return pandas.PandasDataset.from_long_dataframe(
dataframe=long_dataframe,
target="target",
return pandas.LongDataset(
long_dataframe,
timestamp="time",
item_id="item",
freq="1H",
feat_dynamic_real=["dyn_real_1"],
feat_static_cat=["stat_cat_1"],
translate=dict(
target="target",
feat_dynamic_real=["dyn_real_1"],
feat_static_cat=["stat_cat_1"],
),
)


@@ -221,9 +223,8 @@ def test_long_csv_3M():
)

with io.StringIO(data) as fp:
ds = pandas.PandasDataset.from_long_dataframe(
ds = pandas.LongDataset(
pd.read_csv(fp),
target="target",
item_id="item_id",
timestamp="timestamp",
freq="3M",
@@ -232,9 +233,8 @@
assert entry["start"].freqstr == "3M"

with io.StringIO(data) as fp:
ds = pandas.PandasDataset.from_long_dataframe(
pd.read_csv(fp, index_col="timestamp"),
target="target",
ds = pandas.LongDataset(
pd.read_csv(fp, index_col="timestamp", parse_dates=True),
item_id="item_id",
freq="3M",
)