Skip to content

Commit

Permalink
Expose aggregation method in ensemble NBEATS, fix forecast shape (#2598)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Jan 30, 2023
1 parent 0554230 commit 0af9010
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/gluonts/mx/model/n_beats/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,15 @@ def predict(
# get the forecast start date
if start_date is None:
start_date = prediction.start_date
output = np.stack(output, axis=0)
output = np.concatenate(output, axis=0)

# aggregating output of different models
# default according to paper is median,
# but we can also make use of not aggregating
if self.aggregation_method == "median":
output = np.median(output, axis=0)
output = np.median(output, axis=0, keepdims=True)
elif self.aggregation_method == "mean":
output = np.mean(output, axis=0)
output = np.mean(output, axis=0, keepdims=True)
else: # "none": do not aggregate
pass

Expand Down Expand Up @@ -311,6 +311,10 @@ class NBEATSEnsembleEstimator(Estimator):
(trend). A list of strings of length 1 or 'num_stacks'.
Default and recommended value for generic mode: ["G"]
Recommended value for interpretable mode: ["T","S"]
aggregation_method
The method by which to aggregate the individual predictions of the
models. Either 'median', 'mean' or 'none', in which case no aggregation
happens. Default is 'median'.
**kwargs
Arguments passed down to the individual estimators.
"""
Expand All @@ -335,6 +339,7 @@ def __init__(
expansion_coefficient_lengths: Optional[List[int]] = None,
sharing: Optional[List[bool]] = None,
stack_types: Optional[List[str]] = None,
aggregation_method: str = "median",
**kwargs,
) -> None:
super().__init__()
Expand Down Expand Up @@ -384,6 +389,7 @@ def __init__(
self.expansion_coefficient_lengths = expansion_coefficient_lengths
self.sharing = sharing
self.stack_types = stack_types
self.aggregation_method = aggregation_method

# Actually instantiate the different models
self.estimators = self._estimator_factory(**kwargs)
Expand Down Expand Up @@ -468,7 +474,11 @@ def _train(
)
)

return NBEATSEnsemblePredictor(self.prediction_length, predictors)
return NBEATSEnsemblePredictor(
self.prediction_length,
predictors,
aggregation_method=self.aggregation_method,
)

def train(
self, training_data: Dataset, validation_data: Optional[Dataset] = None
Expand Down

0 comments on commit 0af9010

Please sign in to comment.