Skip to content

Commit

Permalink
Wrap user scenarios in dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Aug 23, 2024
1 parent f92782c commit eaf285c
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions pymc_experimental/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -1855,7 +1855,16 @@ def forecast(
)
start = time_index[-1]

scenario = self._validate_scenario_data(scenario, verbose=verbose)
if not isinstance(scenario, dict):
if len(self.data_names) > 1:
raise ValueError(
"Model needs more than one exogenous data to do forecasting. In this case, you must "
"pass a dictionary of scenario data."
)
[data_name] = self.data_names
scenario = {data_name: scenario}

scenario: dict = self._validate_scenario_data(scenario, verbose=verbose)

self._validate_forecast_args(
time_index=time_index,
Expand Down Expand Up @@ -1917,19 +1926,14 @@ def forecast(
for data_name in self.data_names
}

subbed_matrices = graph_replace(matrices, replace=sub_dict, strict=True)
[
setattr(matrix, "name", name)
for name, matrix in zip(MATRIX_NAMES[2:], subbed_matrices)
]
else:
subbed_matrices = matrices
matrices = graph_replace(matrices, replace=sub_dict, strict=True)
[setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)]

_ = LinearGaussianStateSpace(
"forecast",
x0,
P0,
*subbed_matrices,
*matrices,
steps=len(forecast_index[:-1]),
dims=dims,
mode=self._fit_mode,
Expand Down

0 comments on commit eaf285c

Please sign in to comment.