Skip to content

Commit

Permalink
Add copy_dim to QuantileForecast, change dim method for multiva…
Browse files Browse the repository at this point in the history
…riate data (#2352)
  • Loading branch information
codingWhale13 authored Nov 9, 2022
1 parent 2f0a284 commit 0e4a488
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 20 deletions.
56 changes: 36 additions & 20 deletions src/gluonts/model/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,10 +565,9 @@ def mean(self) -> np.ndarray:
"""
Forecast mean.
"""
if self._mean is not None:
return self._mean
else:
return np.mean(self.samples, axis=0)
if self._mean is None:
self._mean = np.mean(self.samples, axis=0)
return self._mean

@property
def mean_ts(self) -> pd.Series:
Expand Down Expand Up @@ -614,17 +613,16 @@ def copy_aggregate(self, agg_fun: Callable) -> "SampleForecast":
)

def dim(self) -> int:
if self._dim is not None:
return self._dim
else:
if self._dim is None:
if len(self.samples.shape) == 2:
# univariate target
# shape: (num_samples, prediction_length)
return 1
self._dim = 1
else:
# multivariate target
# shape: (num_samples, prediction_length, target_dim)
return self.samples.shape[2]
self._dim = self.samples.shape[2]
return self._dim

def as_json_dict(self, config: "Config") -> dict:
result = super().as_json_dict(config)
Expand Down Expand Up @@ -706,7 +704,7 @@ def __init__(
f"The forecast_array (shape={shape} should have the same "
f"length as the forecast_keys (len={len(self.forecast_keys)})."
)
self.prediction_length = shape[-1]
self.prediction_length = shape[1]
self._forecast_dict = {
k: self.forecast_array[i] for i, k in enumerate(self.forecast_keys)
}
Expand Down Expand Up @@ -744,6 +742,25 @@ def quantile(self, inference_quantile: Union[float, str]) -> np.ndarray:
else:
return linear_interpolation(inference_quantile)

def copy_dim(self, dim: int) -> "QuantileForecast":
if len(self.forecast_array.shape) == 2:
forecast_array = self.forecast_array
else:
target_dim = self.forecast_array.shape[2]
assert dim < target_dim, (
f"must set 0 <= dim < target_dim, but got dim={dim},"
f" target_dim={target_dim}"
)
forecast_array = self.forecast_array[:, :, dim]

return QuantileForecast(
forecast_arrays=forecast_array,
start_date=self.start_date,
forecast_keys=self.forecast_keys,
item_id=self.item_id,
info=self.info,
)

@property
def mean(self) -> np.ndarray:
"""
Expand All @@ -755,17 +772,16 @@ def mean(self) -> np.ndarray:
return self.quantile("p50")

def dim(self) -> int:
if self._dim is not None:
return self._dim
else:
if (
len(self.forecast_array.shape) == 2
): # 1D target. shape: (num_samples, prediction_length)
return 1
if self._dim is None:
if len(self.forecast_array.shape) == 2:
# univariate target
# shape: (num_samples, prediction_length)
self._dim = 1
else:
# 2D target. shape: (num_samples, target_dim,
# prediction_length)
return self.forecast_array.shape[1]
# multivariate target
# shape: (num_samples, prediction_length, target_dim)
self._dim = self.forecast_array.shape[2]
return self._dim

def __repr__(self):
return ", ".join(
Expand Down
35 changes: 35 additions & 0 deletions test/model/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@
),
}

MULTIVARIATE_FORECASTS = {
"SampleForecast": SampleForecast(
samples=np.arange(160).reshape(8, 5, 4) / 100,
start_date=START_DATE,
),
"QuantileForecast": QuantileForecast(
forecast_arrays=np.random.normal(size=(np.size(QUANTILES), 7, 3)),
forecast_keys=np.array(QUANTILES, str),
start_date=START_DATE,
),
}


@pytest.mark.parametrize("name", FORECASTS.keys())
def test_Forecast(name):
Expand Down Expand Up @@ -84,6 +96,29 @@ def test_forecast_multivariate(forecast, exp_index):
assert np.all(forecast.index == exp_index)


@pytest.mark.parametrize("name", MULTIVARIATE_FORECASTS.keys())
def test_copy_dim(name):
forecast = MULTIVARIATE_FORECASTS[name]
for dim in range(forecast.dim()):
univariate_forecast = forecast.copy_dim(dim)

assert univariate_forecast.dim() == 1
assert univariate_forecast.start_date == forecast.start_date
assert univariate_forecast.item_id == forecast.item_id
assert univariate_forecast.info == forecast.info

if name == "SampleForecast":
assert np.array_equal(
univariate_forecast.samples,
MULTIVARIATE_FORECASTS[name].samples[:, :, dim],
)
else:
assert np.array_equal(
univariate_forecast.forecast_array,
MULTIVARIATE_FORECASTS[name].forecast_array[:, :, dim],
)


def test_linear_interpolation() -> None:
tol = 1e-7
x_coord = [0.1, 0.5, 0.9]
Expand Down

0 comments on commit 0e4a488

Please sign in to comment.