# Deep Learning Estimator

This RFC proposes ``DeepLearningEstimator``, a base class for
deep-learning-based models.


## Status Quo and Problem Statement

Although our deep-learning-based models behave very similarly with respect to
loading data, each estimator currently has to configure the input processing
pipeline manually. The steps involved are described in the following list; a
sketch of a typical chain follows it.

* Each model has an implicit schema that defines the input data. Currently, we
use `AsNumpyArray` to ensure fields are present in input data and that they
have the correct type and shape.

* Next, some features are added and the input is extended with information such
as missing value indicators.

* We then sample windows out of the processed data, using ``InstanceSplitter``
and ``InstanceSampler``.

* Windows are then grouped into batches, which are then stacked (i.e. turning
  rows into columns). Finally, we convert the input from ``numpy`` arrays to
  framework-specific types (``torch.Tensor``, ``mxnet.ndarray``).
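For illustration, a typical chain wired up by an estimator today looks roughly
as follows; the field names and hyper-parameters are made up for the example:

```py
import numpy as np

from gluonts.transform import (
    AddObservedValuesIndicator,
    AsNumpyArray,
    ExpectedNumInstanceSampler,
    InstanceSplitter,
)

# Rough sketch of today's per-estimator boilerplate.
transform = (
    # implicit schema: assert presence, type, and shape of input fields
    AsNumpyArray(field="target", expected_ndim=1, dtype=np.float32)
    # feature processing: e.g. add a missing-value indicator
    + AddObservedValuesIndicator(
        target_field="target",
        output_field="observed_values",
    )
    # sample training windows out of the processed series
    + InstanceSplitter(
        target_field="target",
        is_pad_field="is_pad",
        start_field="start",
        forecast_start_field="forecast_start",
        instance_sampler=ExpectedNumInstanceSampler(num_instances=1.0),
        past_length=24,
        future_length=12,
    )
)
```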

There are some additional complications: we treat training input as an endless
stream of time series by re-iterating the training dataset over and over
again, and to increase performance we offer the option to cache the input
data. Additionally, there are three kinds of data loaders, one each for
training, validation, and prediction, which all behave differently.
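The endless-stream and caching behaviour can be expressed in terms of helpers
from ``gluonts.itertools``; the function below is only an illustrative sketch:

```py
from gluonts.itertools import Cached, Cyclic

def training_stream(dataset, cache_data: bool = False):
    # Treat training input as an endless stream by cycling over the
    # dataset; optionally cache it in memory first to speed up re-iteration.
    if cache_data:
        dataset = Cached(dataset)
    return Cyclic(dataset)
```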


### Further Limitations

Both ``GluonEstimator`` and ``PyTorchLightningEstimator`` provide a common
interface to implement new estimators in MXNet and PyTorch respectively.

However, both classes only provide a single transformation chain to handle
input data, which therefore has to cover schema validation as well as feature
processing.

Since input time series are represented using plain dictionaries, code
handling them needs a lot of additional information to be able to work on
them. Thus, each estimator currently defines its own ``InstanceSplitter`` and
configures data loading itself.


## Proposal

We can simplify the implementation of estimators using a common way to handle
data loading.

We replace our dictionary-based approach with ``zebras.TimeFrame`` and
``zebras.SplitFrame`` to handle time series data. This has two advantages:
we can use ``zebras.Schema`` to specify the input data for a given estimator,
and implementing transformation steps becomes a lot easier.
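As a hypothetical sketch, declaring the input data for an estimator could then
look like this; the exact field-spec API of ``zebras.Schema`` is an assumption
here, not part of this proposal's text:

```py
import numpy as np
import gluonts.zebras as zb

# Hypothetical: the estimator states its expected input once; the
# field-spec constructors below are placeholders for the real API.
schema = zb.Schema(
    {
        "target": zb.schema.TimeSeries(ndim=1, dtype=np.float32),
        "feat_static_cat": zb.schema.Array(ndim=1, dtype=np.int64),
    }
)
```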

We introduce ``DeepLearningInterface``, which provides a more granular
interface to load data. It requires each derived estimator to provide
information such as ``past_length``, which is used to construct batches of the
correct size. Using that information, it can then provide default
implementations for methods such as ``.training_instance_splitter(...)``.
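Schematically, this inverts the current setup: derived classes declare facts
about their model, and the base class derives the machinery. All names below
are hypothetical:

```py
class DeepLearningInterface:
    # Derived estimators must define this; it is enough for the base
    # class to size windows and batches correctly.
    past_length: int

    def training_instance_splitter(self):
        # Default built from the declared information; the helper name
        # is illustrative, not an existing function.
        return make_instance_splitter(past_length=self.past_length)
```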

To further simplify configuration, we move some options to a ``Settings``
object (dependency injection). This has the advantage that we can change some
training behaviour without touching the estimator code. For example, we would
configure data caching through settings instead of passing these arguments to
the estimator or train methods:

```py
with settings._let(cache_data=True):
    estimator.train(...)
```
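Here, ``_let`` temporarily overrides the given values for the duration of the
``with`` block; GluonTS already ships such a mechanism in
``gluonts.core.settings.Settings``.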

## Implementation

We introduce a new class ``DeepLearningEstimator``:

```py
import gluonts.zebras as zb
from gluonts import maybe


class DeepLearningEstimator:
    def get_schema(self) -> zb.Schema:
        # Declares the input fields this estimator expects.
        raise NotImplementedError

    def training_pipeline(self):
        # Transformation steps applied to training data.
        return []

    def prediction_pipeline(self):
        # By default, prediction re-uses the training pipeline.
        return self.training_pipeline()

    def train_model(self, training_data, validation_data):
        raise NotImplementedError

    # default implementations

    def train(self, training_data, validation_data=None):
        training_data = self.training_dataloader(training_data)
        validation_data = maybe.map(
            validation_data, self.validation_dataloader
        )

        return self.train_model(training_data, validation_data)

    def training_instance_splitter(self):
        ...

    def training_dataloader(self, training_data):
        ...

    def validation_dataloader(self, validation_data):
        ...
```
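To illustrate the division of labour, a concrete estimator would then only
fill in the schema, the pipeline, and the framework-specific training loop,
inheriting ``train`` and the data loaders. The subclass below is a
hypothetical sketch:

```py
class MyEstimator(DeepLearningEstimator):
    past_length = 24    # consumed by the default instance splitter
    future_length = 12

    def get_schema(self) -> zb.Schema:
        # declare the required input fields here
        ...

    def train_model(self, training_data, validation_data):
        # framework-specific training loop (PyTorch, MXNet, ...)
        ...
```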