Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error in Making a Custom Statespace Model.ipynb #335

Open
spring-haru opened this issue Apr 17, 2024 · 0 comments
Open

Error in Making a Custom Statespace Model.ipynb #335

spring-haru opened this issue Apr 17, 2024 · 0 comments

Comments

@spring-haru
Copy link

This is about the example notebook `Making a Custom Statespace Model.ipynb`. When executing the following code cell

ar3 = AutoRegressiveThree()  # custom statespace model class; presumably defined in an earlier notebook cell (not shown)
data = np.full((100, 1), np.nan)  # 100 x 1 array of all-NaN observations
with pm.Model() as pymc_mod:  # build_statespace_graph looks up these variables by name and substitutes them for ar3's placeholders (strict=True)
    x0 = pm.Deterministic(
        "x0",
        pt.zeros(
            3,
        ),
    )
    P0 = pm.Deterministic("P0", pt.eye(3) * 10)  # 3x3 identity scaled by 10

    ar_params = pm.Normal("ar_params", sigma=0.25, shape=(3,))
    sigma_x = pm.Exponential("sigma_x", 1)  # NOTE(review): scalar here, but the TypeError below reports ar3's sigma_x placeholder is Vector(float64, shape=(1,)); shapes must agree

    ar3.build_statespace_graph(data=data, mode="JAX")  # raises the TypeError shown below

the following `TypeError` is raised:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[19], line 15
     12 ar_params = pm.Normal("ar_params", sigma=0.25, shape=(3,))
     13 sigma_x = pm.Exponential("sigma_x", 1)
---> 15 ar3.build_statespace_graph(data=data, mode="JAX")

File ~/.venv/pymc-ss/lib/python3.12/site-packages/pymc_experimental/statespace/core/statespace.py:860, in PyMCStateSpace.build_statespace_graph(self, data, register_data, mode, missing_fill_value, cov_jitter, save_kalman_filter_outputs_in_idata)
    813 """
    814 Given a parameter vector `theta`, constructs the full computational graph describing the state space model and
    815 the associated log probability of the data. Hidden states and log probabilities are computed via the Kalman
   (...)
    856     should not be necessary for the majority of users.
    857 """
    858 pm_mod = modelcontext(None)
--> 860 self._insert_random_variables()
    861 self._insert_data_variables()
    863 obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)

File ~/.venv/pymc-ss/lib/python3.12/site-packages/pymc_experimental/statespace/core/statespace.py:680, in PyMCStateSpace._insert_random_variables(self)
    677 matrices = list(self._unpack_statespace_with_placeholders())
    679 replacement_dict = {var: pymc_model[name] for name, var in self._name_to_variable.items()}
--> 680 self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True)

File ~/.venv/pymc-ss/lib/python3.12/site-packages/pytensor/graph/replace.py:201, in graph_replace(outputs, replace, strict)
    193             raise ValueError(f"{key} is not a part of graph")
    195 sorted_replacements = sorted(
    196     fg_replace.items(),
    197     # sort based on the fg toposort, if a variable has no owner, it goes first
    198     key=partial(toposort_key, fg, toposort),
    199     reverse=True,
    200 )
--> 201 fg.replace_all(sorted_replacements, import_missing=True)
    202 if as_list:
    203     return list(fg.outputs)

File ~/.venv/pymc-ss/lib/python3.12/site-packages/pytensor/graph/fg.py:519, in FunctionGraph.replace_all(self, pairs, **kwargs)
    517 """Replace variables in the `FunctionGraph` according to ``(var, new_var)`` pairs in a list."""
    518 for var, new_var in pairs:
--> 519     self.replace(var, new_var, **kwargs)

File ~/.venv/pymc-ss/lib/python3.12/site-packages/pytensor/graph/fg.py:483, in FunctionGraph.replace(self, var, new_var, reason, verbose, import_missing)
    478 if verbose:
    479     print(
    480         f"rewriting: rewrite {reason} replaces {var} of {var.owner} with {new_var} of {new_var.owner}"
    481     )
--> 483 new_var = var.type.filter_variable(new_var, allow_convert=True)
    485 if var not in self.variables:
    486     # TODO: Raise an actual exception here.
    487     # Old comment:
   (...)
    491     # multiple-output ops
    492     # raise ValueError()
    493     return

File ~/.venv/pymc-ss/lib/python3.12/site-packages/pytensor/tensor/type.py:278, in TensorType.filter_variable(self, other, allow_convert)
    275     if other2 is not None:
    276         return other2
--> 278 raise TypeError(
    279     f"Cannot convert Type {other.type} "
    280     f"(of Variable {other}) into Type {self}. "
    281     f"You can try to manually convert {other} into a {self}."
    282 )

TypeError: Cannot convert Type Scalar(float64, shape=()) (of Variable sigma_x) into Type Vector(float64, shape=(1,)). You can try to manually convert sigma_x into a Vector(float64, shape=(1,)).

Environment:

pymc 5.13.1
pymc-experimental 0.1.0
jax 0.4.26
jaxlib 0.4.26
numpy 1.26.4
numpyro 0.14.0
python 3.12

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant