You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[19], line 15
12 ar_params = pm.Normal("ar_params", sigma=0.25, shape=(3,))
13 sigma_x = pm.Exponential("sigma_x", 1)
---> 15 ar3.build_statespace_graph(data=data, mode="JAX")
File ~/.venv/pymc-ss/lib/python3.12/site-packages/pymc_experimental/statespace/core/statespace.py:860, in PyMCStateSpace.build_statespace_graph(self, data, register_data, mode, missing_fill_value, cov_jitter, save_kalman_filter_outputs_in_idata)
813 """
814 Given a parameter vector `theta`, constructs the full computational graph describing the state space model and
815 the associated log probability of the data. Hidden states and log probabilities are computed via the Kalman
(...)
856 should not be necessary for the majority of users.
857 """
858 pm_mod = modelcontext(None)
--> 860 self._insert_random_variables()
861 self._insert_data_variables()
863 obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None)
File ~/.venv/pymc-ss/lib/python3.12/site-packages/pymc_experimental/statespace/core/statespace.py:680, in PyMCStateSpace._insert_random_variables(self)
677 matrices = list(self._unpack_statespace_with_placeholders())
679 replacement_dict = {var: pymc_model[name] for name, var in self._name_to_variable.items()}
--> 680 self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True)
File ~/.venv/pymc-ss/lib/python3.12/site-packages/pytensor/graph/replace.py:201, in graph_replace(outputs, replace, strict)
193 raise ValueError(f"{key} is not a part of graph")
195 sorted_replacements = sorted(
196 fg_replace.items(),
197 # sort based on the fg toposort, if a variable has no owner, it goes first
198 key=partial(toposort_key, fg, toposort),
199 reverse=True,
200 )
--> 201 fg.replace_all(sorted_replacements, import_missing=True)
202 if as_list:
203 return list(fg.outputs)
File ~/.venv/pymc-ss/lib/python3.12/site-packages/pytensor/graph/fg.py:519, in FunctionGraph.replace_all(self, pairs, **kwargs)
517 """Replace variables in the `FunctionGraph` according to ``(var, new_var)`` pairs in a list."""
518 for var, new_var in pairs:
--> 519 self.replace(var, new_var, **kwargs)
File ~/.venv/pymc-ss/lib/python3.12/site-packages/pytensor/graph/fg.py:483, in FunctionGraph.replace(self, var, new_var, reason, verbose, import_missing)
478 if verbose:
479 print(
480 f"rewriting: rewrite {reason} replaces {var} of {var.owner} with {new_var} of {new_var.owner}"
481 )
--> 483 new_var = var.type.filter_variable(new_var, allow_convert=True)
485 if var not in self.variables:
486 # TODO: Raise an actual exception here.
487 # Old comment:
(...)
491 # multiple-output ops
492 # raise ValueError()
493 return
File ~/.venv/pymc-ss/lib/python3.12/site-packages/pytensor/tensor/type.py:278, in TensorType.filter_variable(self, other, allow_convert)
275 if other2 is not None:
276 return other2
--> 278 raise TypeError(
279 f"Cannot convert Type {other.type} "
280 f"(of Variable {other}) into Type {self}. "
281 f"You can try to manually convert {other} into a {self}."
282 )
TypeError: Cannot convert Type Scalar(float64, shape=()) (of Variable sigma_x) into Type Vector(float64, shape=(1,)). You can try to manually convert sigma_x into a Vector(float64, shape=(1,)).
This is about the example notebook `Making a Custom Statespace Model.ipynb`. When executing the code cell, the following error occurred:
pymc 5.13.1
pymc-experimental 0.1.0
jax 0.4.26
jaxlib 0.4.26
numpy 1.26.4
numpyro 0.14.0
python 3.12
The text was updated successfully, but these errors were encountered: