Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose aggregation method in ensemble NBEATS, fix forecast shape #2598

Merged
merged 5 commits into from
Jan 30, 2023

Conversation

lostella
Copy link
Contributor

Issue #, if available: Fixes #2592

Description of changes:

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this pr with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

@lostella lostella added bug fix (one of pr required labels) pending v0.11.x backport This contains a fix to be backported to the v0.11.x branch pending v0.12.x backport This contains a fix to be backported to the v0.12.x branch labels Jan 27, 2023
@gorold
Copy link
Contributor

gorold commented Jan 29, 2023

output = np.median(output, axis=0)

output = np.mean(output, axis=0)

I believe the above two lines should be output = np.median(output, axis=0, keepdims=True) instead, since output = np.concatenate(output, axis=0) results in an (N, T) array, and SampleForecast takes in an (N, T) array even if N = 1.

Other than that, looks great! Thanks a lot for the fix!

@lostella lostella added this to the v0.12 milestone Jan 30, 2023
@lostella lostella enabled auto-merge (squash) January 30, 2023 13:31
@lostella lostella added the pending v0.10.x backport This contains a fix to be backported to the v0.10.x branch label Jan 30, 2023
@lostella lostella merged commit 0af9010 into awslabs:dev Jan 30, 2023
@lostella lostella deleted the fix-nbeats-ensembling branch January 30, 2023 14:06
lostella added a commit to lostella/gluonts that referenced this pull request Jan 30, 2023
@lostella lostella mentioned this pull request Jan 30, 2023
lostella added a commit to lostella/gluonts that referenced this pull request Jan 30, 2023
@lostella lostella mentioned this pull request Jan 30, 2023
lostella added a commit that referenced this pull request Jan 30, 2023
* Fix: avoid automatic device detection via serialized tensors when deserializing. (#2576)

* Make itertools Map/Filter dataclasses. (#2579)

* serde: Fix encoding of dtypes. (#2586)

* Add assertion to split function ensuring valid windows (#2587)

* Ensure dtype on feat_time in torch DeepAR. (#2596)

* Expose aggregation method in ensemble NBEATS, fix forecast shape (#2598)

* Fix version in requirements to comply with stricter setuptools. (#2604)

Co-authored-by: Lorenzo Stella <[email protected]>

* Add `gluonts.util.safe_extract` (#2606)

Co-authored-by: Jasper <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>

* fix requirements further

* fix style

* remove undesired change

---------

Co-authored-by: Shubham Kapoor <[email protected]>
Co-authored-by: Jasper <[email protected]>
Co-authored-by: MarcelK1102 <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>
Co-authored-by: Jasper <[email protected]>
@lostella lostella removed the pending v0.11.x backport This contains a fix to be backported to the v0.11.x branch label Jan 30, 2023
jaheba pushed a commit to jaheba/gluon-ts that referenced this pull request Jan 31, 2023
@jaheba jaheba mentioned this pull request Jan 31, 2023
lostella added a commit that referenced this pull request Feb 2, 2023
* Add assertion to split function ensuring valid windows (#2587)

* Ensure dtype on feat_time in torch DeepAR. (#2596)

* Move NPTS back to `gluonts.model` (#2597)

* Expose aggregation method in ensemble NBEATS, fix forecast shape (#2598)

* Fix bug with static cardinalities in `PandasDataset` (#2599)

* Expose `weight_decay` in torch TFT estimator class (#2603)

* Fix version in requirements to comply with stricter setuptools. (#2604)

Co-authored-by: Lorenzo Stella <[email protected]>

* Add `gluonts.util.safe_extract` (#2606)

Co-authored-by: Jasper <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>

* Fix incorrect import in `tsbench`, apply latest black (#2613)

* Allow ReduceLROnPlateau to track val_loss when validation set is available (#2614)

---------

Co-authored-by: MarcelK1102 <[email protected]>
Co-authored-by: Jasper <[email protected]>
Co-authored-by: Gerald Woo <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>
Co-authored-by: Jasper <[email protected]>
jaheba pushed a commit that referenced this pull request Feb 6, 2023
* Fix version in requirements to comply with stricter setuptools. (#2604)

Co-authored-by: Lorenzo Stella <[email protected]>

* Backport: Add gluonts.util.safe_extract (#2606)

* Expose aggregation method in ensemble NBEATS, fix forecast shape (#2598)

* Disable Py36 tests, fix version.

* Fixup.

* Cap numpy compatibility in `mxnet` extra requirements (#2506)

* xfail multivariate grouper test

Co-authored-by: Lorenzo Stella <[email protected]>
Co-authored-by: Jasper <[email protected]>

---------

Co-authored-by: Lorenzo Stella <[email protected]>
Co-authored-by: Lorenzo Stella <[email protected]>
@lostella lostella removed pending v0.10.x backport This contains a fix to be backported to the v0.10.x branch pending v0.12.x backport This contains a fix to be backported to the v0.12.x branch labels Feb 7, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug fix (one of pr required labels)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Expose NBEATSEnsemble aggregation_method and wrong dimensions
3 participants