Add LongDataset. #2377
@@ -12,15 +12,19 @@
 # permissions and limitations under the License.

 from copy import deepcopy
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, InitVar
+from functools import partial
+from operator import methodcaller
 from typing import Any, cast, Dict, Iterator, List, Optional, Union

 import pandas as pd
+from pandas.core.indexes.datetimelike import DatetimeIndexOpsMixin
-from toolz import valmap
+from toolz import valmap, second

-from gluonts.dataset.common import DataEntry, ProcessDataEntry
-from gluonts.dataset.field_names import FieldName
+from gluonts.itertools import Map
+from .common import DataEntry, ProcessDataEntry, _as_period
+from .field_names import FieldName
+from .schema import Translator


 @dataclass
|
@@ -152,35 +156,6 @@ def __iter__(self) -> Iterator[DataEntry]:
     def __len__(self) -> int:
         return len(self._dataframes)

-    @classmethod
-    def from_long_dataframe(
-        cls, dataframe: pd.DataFrame, item_id: str, **kwargs
-    ) -> "PandasDataset":
-        """
-        Construct ``PandasDataset`` out of a long dataframe.
-        A long dataframe uses the long format for each variable. Target time
-        series values, for example, are stacked on top of each other rather
-        than side-by-side. The same is true for other dynamic or categorical
-        features.
-
-        Parameters
-        ----------
-        dataframe
-            pandas.DataFrame containing at least ``timestamp``, ``target`` and
-            ``item_id`` columns.
-        item_id
-            Name of the column that, when grouped by, gives the different time
-            series.
-        **kwargs
-            Additional arguments. Same as of PandasDataset class.
-
-        Returns
-        -------
-        PandasDataset
-            Gluonts dataset based on ``pandas.DataFrame``s.
-        """
-        return cls(dataframes=dict(list(dataframe.groupby(item_id))), **kwargs)


 def series_to_dataframe(
     series: Union[pd.Series, List[pd.Series], Dict[str, pd.Series]]
|
@@ -281,10 +256,12 @@ def prepare_prediction_data(
     Remove ``ignore_last_n_targets`` values from ``target`` and
     ``past_feat_dynamic_real``. Works in univariate and multivariate case.

-    >>> prepare_prediction_data(
-    >>>     {"target": np.array([1., 2., 3., 4.])}, ignore_last_n_targets=2
-    >>> )
-    {'target': array([1., 2.])}
+    >>> import numpy as np
+    >>> prepare_prediction_data(
+    ...     {"target": np.array([1., 2., 3., 4.])}, ignore_last_n_targets=2
+    ... )
+    {'target': array([1., 2.])}

     """
     entry = deepcopy(dataentry)
     for fname in [FieldName.TARGET, FieldName.PAST_FEAT_DYNAMIC_REAL]:
|
@@ -298,12 +275,139 @@ def is_uniform(index: pd.PeriodIndex) -> bool:
     Check if ``index`` contains monotonically increasing periods, evenly
     spaced with frequency ``index.freq``.

     >>> ts = ["2021-01-01 00:00", "2021-01-01 02:00", "2021-01-01 04:00"]
     >>> is_uniform(pd.DatetimeIndex(ts).to_period("2H"))
     True
     >>> ts = ["2021-01-01 00:00", "2021-01-01 04:00"]
     >>> is_uniform(pd.DatetimeIndex(ts).to_period("2H"))
     False

     """
     other = pd.period_range(index[0], periods=len(index), freq=index.freq)
     return (other == index).all()
+
+
+def _column_as_start(dct: dict, column: str, freq, unchecked=False) -> dict:
+    ts = dct.pop(column)
+
+    if unchecked:
+        dct["start"] = _as_period(ts[0], freq=freq)
+    else:
+        idx = pd.PeriodIndex(ts, freq=freq)
+        assert is_uniform(idx)
+        dct["start"] = idx[0]
+
+    return dct
+
+
+def _drop_index(df, name):
+    index = df.index
+    df = df.reset_index(drop=True)
+    df[name] = index
+    return df
+
+
+@dataclass
+class LongDataset:
+    """Wrapper for ``pandas.DataFrame`` in long format.
+
+    Given a dataframe and an item identifier, this will yield a ``DataEntry``
+    for each item, by first calling ``dataframe.groupby(item_id)``.
+
+    A time column is optional, but if ``timestamp`` is provided, ``freq``
+    needs to be provided as well.
+
+    ``translate`` can be set to rename and stack columns on the fly.
+
+    Since the result of the ``groupby`` operation is not guaranteed to be in
+    the right order, each group is sorted by default. If ``assume_sorted`` is
+    set to ``True``, sorting is skipped, which improves performance. Note,
+    however, that if the dataframe is unordered, this leads to incorrect
+    behaviour.
+
+    To improve performance further, ``unchecked`` can be set to ``True`` to
+    indicate that the timestamp column is uniform (see ``is_uniform``).
+    """
+
+    df: pd.DataFrame
+    item_id: Union[str, List[str]]
+    timestamp: Optional[str] = None
+    freq: Optional[str] = None
+
+    assume_sorted: bool = False
+    translate: InitVar[Optional[dict]] = None
+    unchecked: bool = False
+    translator: Optional[Translator] = field(default=None, init=False)
+    use_index: bool = field(init=False)
+
+    def __post_init__(self, translate):
+        self.use_index = False
+
+        if self.timestamp is None:
+            if isinstance(self.df.index, DatetimeIndexOpsMixin):
+                if self.freq is None:
+                    assert self.df.index.freq is not None
+                    self.freq = self.df.index.freq
+
+                self.use_index = True
+                self.timestamp = "__timestamp_index__"
+
+        elif self.freq is None:
+            raise ValueError(
+                "When providing `timestamp`, `freq` needs to be provided too."
+            )
+
+        if translate is not None:
+            self.translator = Translator.parse(translate, drop=True)
+        else:
+            self.translator = None
+
+    def _handle_item_id(self, dct):
+        """Ensure the "item_id" field is a single string."""
+        if isinstance(self.item_id, list):
+            dct["item_id"] = ", ".join(
+                dct.pop(column)[0] for column in self.item_id
+            )
+        else:
+            dct["item_id"] = dct.pop(self.item_id)[0]
+
+        return dct
+
+    def __iter__(self):
+        groups = self.df.groupby(self.item_id)
+        # `groups` yields tuples of (item_id, df), but we only want the df
+        dataset = Map(second, groups)
+
+        if self.use_index:
+            dataset = Map(
+                partial(_drop_index, name=self.timestamp),
+                dataset,
+            )
+
+        if not self.assume_sorted and self.timestamp is not None:
+            dataset = Map(
+                methodcaller("sort_values", by=self.timestamp), dataset
+            )
+
+        dataset = Map(methodcaller("to_dict", orient="list"), dataset)
+        dataset = Map(self._handle_item_id, dataset)
+
+        if self.translator is not None:
+            dataset = Map(self.translator, dataset)
Comment on lines +394 to +395: I'm wondering whether this class is taking care of too much now: applying …

Reply: Right, but the problem is that dataframes only have 1D columns, so we need a way to stack columns. One could of course do it themselves afterwards, but since this is a more lab-focused setup, I think there is an argument for having it all in one place.
+        # handle start field
+        if self.timestamp is not None:
+            dataset = Map(
+                partial(
+                    _column_as_start,
+                    column=self.timestamp,
+                    freq=self.freq,
+                    unchecked=self.unchecked,
+                ),
+                dataset,
+            )
+
+        yield from dataset
+        self.unchecked = True
Comment on `self.unchecked = True`: That's smart.
+
+    def __len__(self):
+        return len(self.df.groupby(self.item_id))
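For orientation, a minimal usage sketch of the class above; the constructor arguments are the dataclass fields from the diff, while the frame and its values are invented:

import pandas as pd

# hypothetical long-format frame: one row per (item, timestamp) pair
df = pd.DataFrame({
    "item_id": ["A", "A", "A", "B", "B", "B"],
    "timestamp": pd.to_datetime(
        ["2021-01-01", "2021-01-02", "2021-01-03"] * 2
    ),
    "target": [1.0, 2.0, 3.0, 10.0, 20.0, 30.0],
})

dataset = LongDataset(df, item_id="item_id", timestamp="timestamp", freq="D")

for entry in dataset:
    # each entry is a dict with "item_id", "start" and "target"
    print(entry["item_id"], entry["start"], entry["target"])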
New file:

@@ -0,0 +1,99 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+from dataclasses import dataclass, field, InitVar
+from functools import partial
+from operator import methodcaller
+from typing import Union, List, Optional
+
+# import pandas as pd
+import polars as pl
+
+from gluonts.itertools import Map
+from .pandas import _column_as_start
+from .schema import Translator
+
+
+@dataclass
+class LongDataset:
+    df: pl.DataFrame
+    item_id: Union[str, List[str]]
+    timestamp: Optional[str] = None
+    freq: Optional[str] = None
+
+    assume_sorted: bool = False
+    translate: InitVar[Optional[dict]] = None
+    translator: Optional[Translator] = field(default=None, init=False)
+
+    use_partition: bool = False
+    unchecked: bool = False
+
+    def __post_init__(self, translate):
+        if (self.timestamp is None) != (self.freq is None):
+            raise ValueError(
+                "Either both `timestamp` and `freq` have to be "
+                "provided or neither."
+            )
+
+        if translate is not None:
+            self.translator = Translator.parse(translate, drop=True)
+        else:
+            self.translator = None
+
+    def _pop_item_id(self, dct):
+        if isinstance(self.item_id, list):
+            dct["item_id"] = ", ".join(
+                dct.pop(column)[0] for column in self.item_id
+            )
+        else:
+            dct["item_id"] = dct.pop(self.item_id)[0]
+
+        return dct
+
+    def __iter__(self):
+        if self.use_partition:
+            dataset = self.df.partition_by(self.item_id)
+        else:
+            dataset = self.df.groupby(self.item_id)
+
+        if not self.assume_sorted:
+            sort_by = [self.item_id]
+            if self.timestamp is not None:
+                sort_by.append(self.timestamp)
+            dataset = Map(methodcaller("sort", by=sort_by), dataset)
+
+        dataset = Map(methodcaller("to_dict", as_series=True), dataset)
+        dataset = Map(self._pop_item_id, dataset)
+
+        if self.translator is not None:
+            dataset = Map(self.translator, dataset)
+
+        if self.timestamp is not None:
+            dataset = Map(
+                partial(
+                    _column_as_start,
+                    column=self.timestamp,
+                    freq=self.freq,
+                    unchecked=self.unchecked,
+                ),
+                dataset,
+            )
+
+        yield from dataset
+
+        # we iterated over the whole dataset once without problems,
+        # so there is no need to re-check on subsequent passes
+        self.unchecked = True
+
+    def __len__(self):
+        return len(self.df.groupby(self.item_id).count())
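As with the pandas variant, a brief usage sketch may help. This assumes a polars version where ``DataFrame.groupby`` still exists (newer releases renamed it to ``group_by``); the frame and its values are invented:

import polars as pl

# hypothetical long-format polars frame, mirroring the pandas example
df = pl.DataFrame({
    "item_id": ["A", "A", "B", "B"],
    "timestamp": ["2021-01-01", "2021-01-02", "2021-01-01", "2021-01-02"],
    "target": [1.0, 2.0, 10.0, 20.0],
})

dataset = LongDataset(df, item_id="item_id", timestamp="timestamp", freq="D")

for entry in dataset:
    # "target" is a polars Series; "start" is a pandas Period
    print(entry["item_id"], entry["start"], entry["target"])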
Comment: There's both `translate` and `translator`?
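For what it's worth, the duality the comment asks about follows the standard dataclasses ``InitVar`` pattern: ``translate`` is an init-only argument that is consumed in ``__post_init__``, and ``translator`` is the stored, parsed result. A generic sketch of that pattern (the names here are made up, not from the PR):

from dataclasses import dataclass, field, InitVar
from typing import Optional

@dataclass
class Example:
    # accepted by __init__, but never stored on the instance
    spec: InitVar[Optional[dict]] = None
    # the stored field, derived from `spec` during initialization
    parsed: Optional[dict] = field(default=None, init=False)

    def __post_init__(self, spec):
        # `spec` is only available here, as an extra argument
        self.parsed = dict(spec) if spec is not None else None

ex = Example(spec={"a": 1})
print(ex.parsed)   # {'a': 1}
# print(ex.spec)   # AttributeError: InitVar is not stored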