Skip to content

Commit

Permalink
expose aggregation method, fix forecast shape
Browse files Browse the repository at this point in the history
  • Loading branch information
Lorenzo Stella committed Jan 27, 2023
1 parent 85b3dde commit 17c4018
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/gluonts/mx/model/n_beats/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def predict(
# get the forecast start date
if start_date is None:
start_date = prediction.start_date
output = np.stack(output, axis=0)
output = np.concatenate(output, axis=0)

# aggregating output of different models
# default according to paper is median,
Expand Down Expand Up @@ -311,6 +311,10 @@ class NBEATSEnsembleEstimator(Estimator):
(trend). A list of strings of length 1 or 'num_stacks'.
Default and recommended value for generic mode: ["G"]
Recommended value for interpretable mode: ["T","S"]
aggregation_method
The method by which to aggregate the individual predictions of the
models. Either 'median', 'mean' or 'none', in which case no aggregation
happens. Default is 'median'.
**kwargs
Arguments passed down to the individual estimators.
"""
Expand All @@ -335,6 +339,7 @@ def __init__(
expansion_coefficient_lengths: Optional[List[int]] = None,
sharing: Optional[List[bool]] = None,
stack_types: Optional[List[str]] = None,
aggregation_method: str = "median",
**kwargs,
) -> None:
super().__init__()
Expand Down Expand Up @@ -384,6 +389,7 @@ def __init__(
self.expansion_coefficient_lengths = expansion_coefficient_lengths
self.sharing = sharing
self.stack_types = stack_types
self.aggregation_method = aggregation_method

# Actually instantiate the different models
self.estimators = self._estimator_factory(**kwargs)
Expand Down Expand Up @@ -468,7 +474,11 @@ def _train(
)
)

return NBEATSEnsemblePredictor(self.prediction_length, predictors)
return NBEATSEnsemblePredictor(
self.prediction_length,
predictors,
aggregation_method=self.aggregation_method,
)

def train(
self, training_data: Dataset, validation_data: Optional[Dataset] = None
Expand Down

0 comments on commit 17c4018

Please sign in to comment.