diff --git a/pymc_extras/__init__.py b/pymc_extras/__init__.py index c675be24..7df570f5 100644 --- a/pymc_extras/__init__.py +++ b/pymc_extras/__init__.py @@ -27,5 +27,3 @@ if len(_log.handlers) == 0: handler = logging.StreamHandler() _log.addHandler(handler) - -__all__ = ["fit", "MarginalModel", "marginalize", "as_model"] diff --git a/pymc_extras/model/marginal/graph_analysis.py b/pymc_extras/model/marginal/graph_analysis.py index 57a74a4f..d4ad98a9 100644 --- a/pymc_extras/model/marginal/graph_analysis.py +++ b/pymc_extras/model/marginal/graph_analysis.py @@ -5,7 +5,6 @@ from pymc import SymbolicRandomVariable from pytensor.compile import SharedVariable -from pytensor.compile.builders import OpFromGraph from pytensor.graph import Constant, Variable, ancestors from pytensor.graph.basic import io_toposort from pytensor.tensor import TensorType, TensorVariable @@ -17,6 +16,8 @@ from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor, get_idx_list from pytensor.tensor.type_other import NoneTypeT +from pymc_extras.model.marginal.distributions import MarginalRV + def static_shape_ancestors(vars): """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph).""" @@ -62,7 +63,7 @@ def find_conditional_dependent_rvs(dependable_rv, all_rvs): def get_support_axes(op) -> tuple[tuple[int, ...], ...]: - if hasattr(op, "support_axes"): + if isinstance(op, MarginalRV): return op.support_axes else: # For vanilla RVs, the support axes are the last ndim_supp @@ -145,7 +146,7 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) output_dims = tuple(None if i == "x" else input_dims[i] for i in node.op.new_order) var_dims[node.outputs[0]] = output_dims - elif (isinstance(node.op, OpFromGraph) and hasattr(node.op, "support_axes")) or ( + elif isinstance(node.op, MarginalRV) or ( isinstance(node.op, SymbolicRandomVariable) and node.op.extended_signature is None ): # MarginalRV and SymbolicRandomVariables without signature are a wild-card, @@ -159,7 +160,7 @@ def _subgraph_batch_dim_connection(var_dims: VAR_DIMS, input_vars, output_vars) ) support_axes = iter(get_support_axes(op)) - if hasattr(op, "support_axes"): + if isinstance(op, MarginalRV): # The first output is the marginalized variable for which we don't compute support axes support_axes = itertools.chain(((),), support_axes) for i, (out, inner_out) in enumerate(zip(node.outputs, inner_outputs)): diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 825ab327..8ad86c10 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -19,6 +19,7 @@ from pymc_extras.statespace.core.representation import PytensorRepresentation from pymc_extras.statespace.filters import ( KalmanSmoother, + SquareRootFilter, StandardFilter, UnivariateFilter, ) @@ -50,6 +51,7 @@ FILTER_FACTORY = { "standard": StandardFilter, "univariate": UnivariateFilter, + "cholesky": SquareRootFilter, } diff --git a/pymc_extras/statespace/filters/__init__.py b/pymc_extras/statespace/filters/__init__.py index b418a45e..f76dea8d 100644 --- a/pymc_extras/statespace/filters/__init__.py +++ b/pymc_extras/statespace/filters/__init__.py @@ -1,5 +1,6 @@ from pymc_extras.statespace.filters.distributions import LinearGaussianStateSpace from pymc_extras.statespace.filters.kalman_filter import ( + SquareRootFilter, StandardFilter, UnivariateFilter, ) @@ -9,5 +10,6 @@ "StandardFilter", "UnivariateFilter", "KalmanSmoother", + "SquareRootFilter", "LinearGaussianStateSpace", ] diff --git a/pymc_extras/statespace/models/utilities.py b/pymc_extras/statespace/models/utilities.py index c5e9428c..6bc22370 100644 --- a/pymc_extras/statespace/models/utilities.py +++ b/pymc_extras/statespace/models/utilities.py @@ -233,8 +233,8 @@ def make_SARIMA_transition_matrix( 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \end{bmatrix} When ARIMA differences and seasonal differences are mixed, the seasonal differences will be written in terms of the - highest ARIMA difference order, and recovery of the level state will require the use of all the ARIMA - differences, as well as the seasonal differences. In addition, the seasonal differences are needed to back out the ARIMA + highest ARIMA difference order, and recovery of the level state will require the use of all the ARIMA differences, + as well as the seasonal differences. In addition, the seasonal differences are needed to back out the ARIMA differences from :math:`x_t^\star`. Here is the differencing block for a SARIMA(0,2,0)x(0,2,0,4) -- the identites of the states is left an exercise for the motivated reader: diff --git a/tests/statespace/test_coord_assignment.py b/tests/statespace/test_coord_assignment.py index dafc4bd9..8e2fea58 100644 --- a/tests/statespace/test_coord_assignment.py +++ b/tests/statespace/test_coord_assignment.py @@ -11,17 +11,12 @@ from pymc_extras.statespace.utils.constants import ( FILTER_OUTPUT_DIMS, FILTER_OUTPUT_NAMES, - JITTER_DEFAULT, - LONG_MATRIX_NAMES, - MISSING_FILL, - SHORT_NAME_TO_LONG, SMOOTHER_OUTPUT_NAMES, TIME_DIM, ) from pymc_extras.statespace.utils.data_tools import ( NO_FREQ_INFO_WARNING, NO_TIME_INDEX_WARNING, - register_data_with_pymc, ) from tests.statespace.utilities.test_helpers import load_nile_test_data diff --git a/tests/statespace/test_distributions.py b/tests/statespace/test_distributions.py index d26bd5e1..e7819428 100644 --- a/tests/statespace/test_distributions.py +++ b/tests/statespace/test_distributions.py @@ -16,11 +16,7 @@ ) from pymc_extras.statespace.utils.constants import ( ALL_STATE_DIM, - JITTER_DEFAULT, - LONG_MATRIX_NAMES, - MISSING_FILL, OBS_STATE_DIM, - SHORT_NAME_TO_LONG, TIME_DIM, ) from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import @@ -41,7 +37,7 @@ filter_names = [ "standard", - # "cholesky", + "cholesky", "univariate", ] diff --git a/tests/statespace/test_kalman_filter.py b/tests/statespace/test_kalman_filter.py index 67bddacc..b55cbcea 100644 --- a/tests/statespace/test_kalman_filter.py +++ b/tests/statespace/test_kalman_filter.py @@ -10,7 +10,7 @@ StandardFilter, UnivariateFilter, ) -from pymc_extras.statespace.filters.kalman_filter import BaseFilter +from pymc_extras.statespace.filters.kalman_filter import BaseFilter, SquareRootFilter from tests.statespace.utilities.shared_fixtures import ( # pylint: disable=unused-import rng, ) @@ -30,17 +30,18 @@ RTOL = 1e-6 if floatX.endswith("64") else 1e-3 standard_inout = initialize_filter(StandardFilter()) -# cholesky_inout = initialize_filter(CholeskyFilter()) +cholesky_inout = initialize_filter(SquareRootFilter()) univariate_inout = initialize_filter(UnivariateFilter()) f_standard = pytensor.function(*standard_inout, on_unused_input="ignore") -# f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore") +f_cholesky = pytensor.function(*cholesky_inout, on_unused_input="ignore") f_univariate = pytensor.function(*univariate_inout, on_unused_input="ignore") -filter_funcs = [f_standard, f_univariate] +filter_funcs = [f_standard, f_cholesky, f_univariate] filter_names = [ "StandardFilter", + "CholeskyFilter", "UnivariateFilter", ] @@ -229,8 +230,8 @@ def test_last_smoother_is_last_filtered(filter_func, output_idx, rng): @pytest.mark.skipif(floatX == "float32", reason="Tests are too sensitive for float32") def test_filters_match_statsmodel_output(filter_func, filter_name, n_missing, rng): fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing) - # if filter_name == "CholeskyFilter": - # P0 = np.linalg.cholesky(P0) + if filter_name == "CholeskyFilter": + P0 = np.linalg.cholesky(P0) inputs = [data, a0, P0, c, d, T, Z, R, H, Q] outputs = filter_func(*inputs) @@ -278,8 +279,8 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob pytest.skip("Univariate filter not stable at half precision without measurement error") fit_sm_mod, [data, a0, P0, c, d, T, Z, R, H, Q] = nile_test_test_helper(rng, n_missing) - # if filter_name == "CholeskyFilter": - # P0 = np.linalg.cholesky(P0) + if filter_name == "CholeskyFilter": + P0 = np.linalg.cholesky(P0) H *= int(obs_noise) inputs = [data, a0, P0, c, d, T, Z, R, H, Q] @@ -301,7 +302,7 @@ def test_all_covariance_matrices_are_PSD(filter_func, filter_name, n_missing, ob @pytest.mark.parametrize( "filter", - [StandardFilter], + [StandardFilter, SquareRootFilter], ids=["standard"], ) def test_kalman_filter_jax(filter): diff --git a/tests/statespace/test_statespace.py b/tests/statespace/test_statespace.py index a85875b0..b9b78dff 100644 --- a/tests/statespace/test_statespace.py +++ b/tests/statespace/test_statespace.py @@ -14,14 +14,8 @@ from pymc_extras.statespace.models.utilities import make_default_coords from pymc_extras.statespace.utils.constants import ( FILTER_OUTPUT_NAMES, - JITTER_DEFAULT, - LONG_MATRIX_NAMES, MATRIX_NAMES, - MISSING_FILL, - NEVER_TIME_VARYING, - SHORT_NAME_TO_LONG, SMOOTHER_OUTPUT_NAMES, - VECTOR_VALUED, ) from tests.statespace.utilities.shared_fixtures import ( rng, diff --git a/tests/statespace/test_statespace_JAX.py b/tests/statespace/test_statespace_JAX.py index 599efd15..9e8d9975 100644 --- a/tests/statespace/test_statespace_JAX.py +++ b/tests/statespace/test_statespace_JAX.py @@ -10,11 +10,7 @@ from pymc_extras.statespace.utils.constants import ( FILTER_OUTPUT_NAMES, - JITTER_DEFAULT, - LONG_MATRIX_NAMES, MATRIX_NAMES, - MISSING_FILL, - SHORT_NAME_TO_LONG, SMOOTHER_OUTPUT_NAMES, ) from tests.statespace.test_statespace import ( # pylint: disable=unused-import diff --git a/tests/statespace/test_structural.py b/tests/statespace/test_structural.py index 0fead4a4..858efadf 100644 --- a/tests/statespace/test_structural.py +++ b/tests/statespace/test_structural.py @@ -20,9 +20,6 @@ ALL_STATE_AUX_DIM, ALL_STATE_DIM, AR_PARAM_DIM, - JITTER_DEFAULT, - LONG_MATRIX_NAMES, - MISSING_FILL, OBS_STATE_AUX_DIM, OBS_STATE_DIM, SHOCK_AUX_DIM, diff --git a/tests/statespace/utilities/test_helpers.py b/tests/statespace/utilities/test_helpers.py index 6a1cae31..c6170f88 100644 --- a/tests/statespace/utilities/test_helpers.py +++ b/tests/statespace/utilities/test_helpers.py @@ -9,9 +9,7 @@ from pymc_extras.statespace.filters.kalman_smoother import KalmanSmoother from pymc_extras.statespace.utils.constants import ( - JITTER_DEFAULT, MATRIX_NAMES, - MISSING_FILL, SHORT_NAME_TO_LONG, ) from tests.statespace.utilities.statsmodel_local_level import LocalLinearTrend diff --git a/tests/test_blackjax_smc.py b/tests/test_blackjax_smc.py index 5735784a..06245d0b 100644 --- a/tests/test_blackjax_smc.py +++ b/tests/test_blackjax_smc.py @@ -20,6 +20,9 @@ from numpy import dtype from xarray.core.utils import Frozen +jax = pytest.importorskip("jax") +pytest.importorskip("blackjax") + from pymc_extras.inference.smc.sampling import ( arviz_from_particles, blackjax_particles_from_pymc_population, @@ -28,9 +31,6 @@ sample_smc_blackjax, ) -jax = pytest.importorskip("jax") -pytest.importorskip("blackjax") - def two_gaussians_model(): n = 4 diff --git a/tests/test_find_map.py b/tests/test_find_map.py index 67d22b98..13a39a4b 100644 --- a/tests/test_find_map.py +++ b/tests/test_find_map.py @@ -1,11 +1,9 @@ -from typing import Literal - import numpy as np import pymc as pm import pytensor.tensor as pt import pytest -from pymc_extras.inference.find_map import find_MAP, scipy_optimize_funcs_from_loss +from pymc_extras.inference.find_map import GradientBackend, find_MAP, scipy_optimize_funcs_from_loss pytest.importorskip("jax") @@ -16,10 +14,6 @@ def rng(): return np.random.default_rng(seed) -# Define GradientBackend type alias -GradientBackend = Literal["jax", "pytensor"] - - @pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str) def test_jax_functions_from_graph(gradient_backend: GradientBackend): x = pt.tensor("x", shape=(2,))