Skip to content

Commit

Permalink
Rebase from main and run new pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Jul 28, 2024
1 parent 2394448 commit 1fb5536
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
8 changes: 4 additions & 4 deletions pymc_experimental/statespace/models/ETS.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Sequence
from collections.abc import Sequence
from typing import Any

import numpy as np
import pytensor.tensor as pt
Expand Down Expand Up @@ -159,7 +160,6 @@ def __init__(
filter_type: str = "standard",
verbose: bool = True,
):

if order is not None:
if len(order) != 3 or any(not isinstance(o, str) for o in order):
raise ValueError("Order must be a tuple of three strings.")
Expand Down Expand Up @@ -405,14 +405,14 @@ def make_symbolic_graph(self) -> None:
self.ssm["design"] = Z

# Set up the state covariance matrix
state_cov_idx = ("state_cov",) + np.diag_indices(self.k_posdef)
state_cov_idx = ("state_cov", *np.diag_indices(self.k_posdef))
state_cov = self.make_and_register_variable(
"sigma_state", shape=() if self.k_posdef == 1 else (self.k_posdef,), dtype=floatX
)
self.ssm[state_cov_idx] = state_cov**2

if self.measurement_error:
obs_cov_idx = ("obs_cov",) + np.diag_indices(self.k_endog)
obs_cov_idx = ("obs_cov", *np.diag_indices(self.k_endog))
obs_cov = self.make_and_register_variable(
"sigma_obs", shape=() if self.k_endog == 1 else (self.k_endog,), dtype=floatX
)
Expand Down
5 changes: 3 additions & 2 deletions tests/statespace/test_ETS.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytensor
import pytest
import statsmodels.api as sm

from numpy.testing import assert_allclose
from pytensor.graph.basic import explicit_graph_inputs
from scipy import linalg
Expand All @@ -20,11 +21,11 @@ def data():
def tests_invalid_order_raises():
# Order must be length 3
with pytest.raises(ValueError, match="Order must be a tuple of three strings"):
BayesianETS(order=("A", "N")) # noqa
BayesianETS(order=("A", "N"))

# Order must be strings
with pytest.raises(ValueError, match="Order must be a tuple of three strings"):
BayesianETS(order=(2, 1, 1)) # noqa
BayesianETS(order=(2, 1, 1))

# Only additive errors allowed
with pytest.raises(ValueError, match="Only additive errors are supported"):
Expand Down

0 comments on commit 1fb5536

Please sign in to comment.