Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add copy_dim to QuantileForecast, change dim method for multivariate data #2352

Merged
merged 5 commits into from
Nov 9, 2022

Conversation

codingWhale13
Copy link
Contributor

Issue #, if available:

Description of changes:
This PR adds copy_dim to the QuantileForecast class, similar to how it is done in SampleForecast

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@codingWhale13 codingWhale13 added the enhancement New feature or request label Oct 10, 2022
@codingWhale13 codingWhale13 mentioned this pull request Oct 10, 2022
@jaheba
Copy link
Contributor

jaheba commented Oct 12, 2022

What is the purpose of copy_dim?

@codingWhale13
Copy link
Contributor Author

What is the purpose of copy_dim?

The purpose is stated in the base class Forecast: "Returns a new Forecast object with only the selected sub-dimension."

@jaheba
Copy link
Contributor

jaheba commented Oct 12, 2022

That's some weird naming.

Should we add a test for this?

@codingWhale13
Copy link
Contributor Author

codingWhale13 commented Oct 18, 2022

While trying to write a test, I found that in the multivariate case, the 3d NumPy arrays for samples/forecast_array (for SampleForecast/QuantileForecast respectively) seem inconsistent:

SampleForecast: (num_samples, prediction_length, target_dim)
QuantileForecast: (num_samples, target_dim, prediction_length).

Before this commit, it was both (num_samples, target_dim, prediction_length) and then it was changed in SampleForecast.

It seems to me like the change for QuantileForecasts slipped through and multivariate QuantileForecasts weren't much used since then.

Please correct me if I'm wrong and if changing the meaning of QuantileForecast.forecast_array in this way breaks anything.

@codingWhale13 codingWhale13 marked this pull request as draft October 18, 2022 12:32
@codingWhale13
Copy link
Contributor Author

codingWhale13 commented Nov 2, 2022

With N = batch axis, T = time axis, C = channel axis (in the context of multivariate forecasts: target_dim), here are some examples of what's used in a few random places of the code:

Even though comparing these different contexts only makes limited sense, it looks like T usually precedes C, to which things have been updated in the SampleForecast a while ago, as mentioned in an above comment.

Long story short, in my view this confirms that changing the dimensions of the multivariate QuantileForecast makes sense.

@lostella
Copy link
Contributor

lostella commented Nov 9, 2022

We changed to the (N, T, C) layout because that's closer to the distribution abstraction: for multivariate distributions, it's over the trailing dimension (C) that data is sampled jointly and density is evaluated. This is the case for example in PyTorch distributions, but also in the MXNet based ones we have in GluonTS.

Indeed, the change you point out just missed the QuantileForecast case. I think it makes sense to change that as you proposed: it will be a breaking change, but hopefully a minor one.

@lostella lostella marked this pull request as ready for review November 9, 2022 10:27
@lostella lostella added the BREAKING This is a breaking change (one of pr required labels) label Nov 9, 2022
@lostella lostella changed the title Add copy_dim to QuantileForecast Add copy_dim to QuantileForecast, change dim method for multivariate data Nov 9, 2022
@lostella lostella merged commit 0e4a488 into awslabs:dev Nov 9, 2022
@codingWhale13 codingWhale13 deleted the improve-forecast branch November 9, 2022 11:36
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
BREAKING This is a breaking change (one of pr required labels) enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants