diff --git a/pymc_experimental/statespace/core/statespace.py b/pymc_experimental/statespace/core/statespace.py index e71ce51e..52af2491 100644 --- a/pymc_experimental/statespace/core/statespace.py +++ b/pymc_experimental/statespace/core/statespace.py @@ -1855,7 +1855,16 @@ def forecast( ) start = time_index[-1] - scenario = self._validate_scenario_data(scenario, verbose=verbose) + if not isinstance(scenario, dict): + if len(self.data_names) > 1: + raise ValueError( + "Model needs more than one exogenous data to do forecasting. In this case, you must " + "pass a dictionary of scenario data." + ) + [data_name] = self.data_names + scenario = {data_name: scenario} + + scenario: dict = self._validate_scenario_data(scenario, verbose=verbose) self._validate_forecast_args( time_index=time_index, @@ -1917,19 +1926,14 @@ def forecast( for data_name in self.data_names } - subbed_matrices = graph_replace(matrices, replace=sub_dict, strict=True) - [ - setattr(matrix, "name", name) - for name, matrix in zip(MATRIX_NAMES[2:], subbed_matrices) - ] - else: - subbed_matrices = matrices + matrices = graph_replace(matrices, replace=sub_dict, strict=True) + [setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)] _ = LinearGaussianStateSpace( "forecast", x0, P0, - *subbed_matrices, + *matrices, steps=len(forecast_index[:-1]), dims=dims, mode=self._fit_mode,