Skip to content

Commit

Permalink
Expose aggregation method in ensemble NBEATS, fix forecast shape (aws…
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored and Jasper Zschiegner committed Jan 31, 2023
1 parent a8a1e5a commit 94b2609
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/gluonts/model/n_beats/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,15 @@ def predict(
# get the forecast start date
if start_date is None:
start_date = prediction.start_date
output = np.stack(output, axis=0)
output = np.concatenate(output, axis=0)

# aggregating output of different models
# default according to paper is median,
# but we can also make use of not aggregating
if self.aggregation_method == "median":
output = np.median(output, axis=0)
output = np.median(output, axis=0, keepdims=True)
elif self.aggregation_method == "mean":
output = np.mean(output, axis=0)
output = np.mean(output, axis=0, keepdims=True)
else: # "none": do not aggregate
pass

Expand Down Expand Up @@ -312,6 +312,10 @@ class NBEATSEnsembleEstimator(Estimator):
(trend). A list of strings of length 1 or 'num_stacks'.
Default and recommended value for generic mode: ["G"]
Recommended value for interpretable mode: ["T","S"]
aggregation_method
The method by which to aggregate the individual predictions of the
models. Either 'median', 'mean' or 'none', in which case no aggregation
happens. Default is 'median'.
**kwargs
Arguments passed down to the individual estimators.
"""
Expand All @@ -336,6 +340,7 @@ def __init__(
expansion_coefficient_lengths: Optional[List[int]] = None,
sharing: Optional[List[bool]] = None,
stack_types: Optional[List[str]] = None,
aggregation_method: str = "median",
**kwargs,
) -> None:
super().__init__()
Expand Down Expand Up @@ -385,6 +390,7 @@ def __init__(
self.expansion_coefficient_lengths = expansion_coefficient_lengths
self.sharing = sharing
self.stack_types = stack_types
self.aggregation_method = aggregation_method

# Actually instantiate the different models
self.estimators = self._estimator_factory(**kwargs)
Expand Down Expand Up @@ -449,4 +455,8 @@ def train(
)
predictors.append(estimator.train(training_data, validation_data))

return NBEATSEnsemblePredictor(self.prediction_length, predictors)
return NBEATSEnsemblePredictor(
self.prediction_length,
predictors,
aggregation_method=self.aggregation_method,
)

0 comments on commit 94b2609

Please sign in to comment.