Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 9, 2024
1 parent 607e0d5 commit 17d8d19
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 48 deletions.
17 changes: 8 additions & 9 deletions notebooks/Making a Custom Statespace Model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
"\n",
"numpyro.set_host_device_count(4)\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import arviz as az\n",
"\n",
"from pymc_experimental.statespace.core.statespace import PyMCStateSpace\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pymc as pm\n",
"import pytensor.tensor as pt\n",
"import pymc as pm"
"\n",
"from pymc_experimental.statespace.core.statespace import PyMCStateSpace"
]
},
{
Expand Down Expand Up @@ -1092,7 +1092,7 @@
],
"source": [
"az.plot_posterior(\n",
" idata, var_names=[\"ar_params\", \"sigma_x\"], ref_val=true_ar.tolist() + [true_sigma_x]\n",
" idata, var_names=[\"ar_params\", \"sigma_x\"], ref_val=[*true_ar.tolist(), true_sigma_x]\n",
");"
]
},
Expand Down Expand Up @@ -1169,13 +1169,12 @@
"metadata": {},
"outputs": [],
"source": [
"from pymc_experimental.statespace.models.utilities import make_default_coords\n",
"from pymc_experimental.statespace.utils.constants import (\n",
" ALL_STATE_DIM,\n",
" ALL_STATE_AUX_DIM,\n",
" OBS_STATE_DIM,\n",
" ALL_STATE_DIM,\n",
" SHOCK_DIM,\n",
")\n",
"from pymc_experimental.statespace.models.utilities import make_default_coords\n",
"\n",
"\n",
"class AutoRegressiveThree(PyMCStateSpace):\n",
Expand Down
17 changes: 8 additions & 9 deletions notebooks/SARMA Example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,19 @@
"\n",
"numpyro.set_host_device_count(8)\n",
"\n",
"import pymc as pm\n",
"from pytensor import tensor as pt\n",
"\n",
"import arviz as az\n",
"import statsmodels.api as sm\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"from scipy import stats\n",
"import pymc as pm\n",
"import statsmodels.api as sm\n",
"\n",
"import pymc_experimental.statespace as pmss\n",
"from pymc.model.transform.optimization import freeze_dims_and_data\n",
"from pytensor import tensor as pt\n",
"from pytensor.link.jax.dispatch import jax_funcify\n",
"from pytensor.tensor.nlinalg import KroneckerProduct\n",
"from pymc.model.transform.optimization import freeze_dims_and_data\n",
"\n",
"import pymc_experimental.statespace as pmss\n",
"\n",
"\n",
"@jax_funcify.register(KroneckerProduct)\n",
Expand Down Expand Up @@ -2582,8 +2581,8 @@
"source": [
"fig, ax = plt.subplots()\n",
"post = az.extract(post_pred).map(np.exp)\n",
"hdi = az.hdi(post_pred.map(np.exp))[f\"predicted_posterior_observed\"]\n",
"post[f\"predicted_posterior_observed\"].isel(observed_state=0).mean(dim=\"sample\").plot.line(\n",
"hdi = az.hdi(post_pred.map(np.exp))[\"predicted_posterior_observed\"]\n",
"post[\"predicted_posterior_observed\"].isel(observed_state=0).mean(dim=\"sample\").plot.line(\n",
" x=\"time\", ax=ax, add_legend=False, label=\"Posterior Mean, Predicted\"\n",
")\n",
"ax.fill_between(\n",
Expand Down
20 changes: 10 additions & 10 deletions notebooks/Structural Timeseries Modeling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,18 @@
"import sys\n",
"\n",
"sys.path.append(\"..\")\n",
"from pymc_experimental.statespace import structural as st\n",
"from pymc_experimental.statespace.utils.constants import SHORT_NAME_TO_LONG, MATRIX_NAMES\n",
"import matplotlib.pyplot as plt\n",
"import pymc as pm\n",
"import arviz as az\n",
"import pytensor\n",
"import pytensor.tensor as pt\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pymc as pm\n",
"import pytensor.tensor as pt\n",
"\n",
"from patsy import dmatrix\n",
"\n",
"from pymc_experimental.statespace import structural as st\n",
"from pymc_experimental.statespace.utils.constants import SHORT_NAME_TO_LONG\n",
"\n",
"plt.rcParams.update(\n",
" {\n",
" \"figure.figsize\": (14, 4),\n",
Expand Down Expand Up @@ -61,15 +62,14 @@
},
"outputs": [],
"source": [
"from pymc_experimental.statespace.filters.kalman_filter import StandardFilter\n",
"from pymc_experimental.statespace.filters.kalman_smoother import KalmanSmoother\n",
"from pymc.pytensorf import compile_pymc, inputvars\n",
"\n",
"from pymc_experimental.statespace.filters.distributions import LinearGaussianStateSpace\n",
"from pymc.pytensorf import inputvars, compile_pymc\n",
"\n",
"\n",
"def make_numpy_function(mod):\n",
" mod = mod.build(verbose=False)\n",
" data = pt.matrix(\"data\", shape=(None, 1))\n",
" pt.matrix(\"data\", shape=(None, 1))\n",
" steps = pt.iscalar(\"steps\")\n",
" x0, _, c, d, T, Z, R, H, Q = mod._unpack_statespace_with_placeholders()\n",
" sequence_names = [x.name for x in [c, d] if x.ndim == 2]\n",
Expand Down
16 changes: 8 additions & 8 deletions notebooks/VARMAX Example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@
"\n",
"numpyro.set_host_device_count(8)\n",
"\n",
"import sys\n",
"\n",
"import arviz as az\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import statsmodels.api as sm\n",
"import pandas as pd\n",
"\n",
"import pymc as pm\n",
"import pytensor.tensor as pt\n",
"import arviz as az\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import sys\n",
"import statsmodels.api as sm\n",
"\n",
"sys.path.append(\"..\")\n",
"import pymc_experimental.statespace as pmss\n",
"import re\n",
"\n",
"import pymc_experimental.statespace as pmss\n",
"\n",
"config = {\n",
" \"figure.figsize\": [12.0, 4.0],\n",
" \"figure.dpi\": 72.0 * 2,\n",
Expand Down Expand Up @@ -679,7 +679,7 @@
" new_labels = []\n",
" for label in axis.yaxis.get_majorticklabels():\n",
" old_text = \"[\" + label.get_text().split(\"[\")[-1]\n",
" labels = eval(re.sub(\"([\\d\\w]+)\", '\"\\g<1>\"', old_text))\n",
" labels = eval(re.sub(r\"([\\d\\w]+)\", r'\"\\g<1>\"', old_text))\n",
" lag, other_var = labels\n",
" new_text = f\"L{lag}.{other_var}\"\n",
" new_labels.append(new_text)\n",
Expand Down
Loading

0 comments on commit 17d8d19

Please sign in to comment.