Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Lorenzo Stella committed Jan 29, 2023
1 parent 701328f commit e8bfc9a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/gluonts/mx/model/n_beats/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,9 @@ def predict(
# default according to paper is median,
# but we can also make use of not aggregating
if self.aggregation_method == "median":
output = np.median(output, axis=0)
output = np.median(output, axis=0, keepdims=True)
elif self.aggregation_method == "mean":
output = np.mean(output, axis=0)
output = np.mean(output, axis=0, keepdims=True)
else: # "none": do not aggregate
pass

Expand Down

0 comments on commit e8bfc9a

Please sign in to comment.