diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 4c710f0787..3e2129702d 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -222,10 +222,12 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray: aelem = arbitrary_element(get_from) self.ndraws = aelem.shape[0] - self.coords = {} if coords is None else coords - if hasattr(self.model, "coords"): - self.coords = {**self.model.coords, **self.coords} - self.coords = {key: value for key, value in self.coords.items() if value is not None} + self.coords = {**self.model.coords, **(coords or {})} + self.coords = { + cname: np.array(cvals) if isinstance(cvals, tuple) else cvals + for cname, cvals in self.coords.items() + if cvals is not None + } self.dims = {} if dims is None else dims if hasattr(self.model, "RV_dims"): diff --git a/pymc/model.py b/pymc/model.py index 295cb6c92c..240921e99e 100644 --- a/pymc/model.py +++ b/pymc/model.py @@ -871,7 +871,7 @@ def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]: return self._RV_dims @property - def coords(self) -> Dict[str, Union[Sequence, None]]: + def coords(self) -> Dict[str, Union[Tuple, None]]: """Coordinate values for model dimensions.""" return self._coords @@ -1096,8 +1096,12 @@ def add_coord( raise ValueError( f"The `length` passed for the '{name}' coord must be an Aesara Variable or None." ) + if values is not None: + # Conversion to a tuple ensures that the coordinate values are immutable. + # Also, unlike numpy arrays, there's tuple.index(...), which is handy to work with.
+ values = tuple(values) if name in self.coords: - if not values.equals(self.coords[name]): + if not np.array_equal(values, self.coords[name]): raise ValueError(f"Duplicate and incompatible coordinate: {name}.") else: self._coords[name] = values diff --git a/pymc/tests/test_data_container.py b/pymc/tests/test_data_container.py index 4e79b5a48b..e6506a1c9d 100644 --- a/pymc/tests/test_data_container.py +++ b/pymc/tests/test_data_container.py @@ -287,12 +287,12 @@ def test_explicit_coords(self): pm.Data("observations", data, dims=("rows", "columns")) assert "rows" in pmodel.coords - assert pmodel.coords["rows"] == ["R1", "R2", "R3", "R4", "R5"] + assert pmodel.coords["rows"] == ("R1", "R2", "R3", "R4", "R5") assert "rows" in pmodel.dim_lengths assert isinstance(pmodel.dim_lengths["rows"], ScalarSharedVariable) assert pmodel.dim_lengths["rows"].eval() == 5 assert "columns" in pmodel.coords - assert pmodel.coords["columns"] == ["C1", "C2", "C3", "C4", "C5", "C6", "C7"] + assert pmodel.coords["columns"] == ("C1", "C2", "C3", "C4", "C5", "C6", "C7") assert pmodel.RV_dims == {"observations": ("rows", "columns")} assert "columns" in pmodel.dim_lengths assert isinstance(pmodel.dim_lengths["columns"], ScalarSharedVariable) diff --git a/pymc/tests/test_idata_conversion.py b/pymc/tests/test_idata_conversion.py index 06c520ec19..465169ae13 100644 --- a/pymc/tests/test_idata_conversion.py +++ b/pymc/tests/test_idata_conversion.py @@ -12,7 +12,11 @@ import pymc as pm -from pymc.backends.arviz import predictions_to_inference_data, to_inference_data +from pymc.backends.arviz import ( + InferenceDataConverter, + predictions_to_inference_data, + to_inference_data, +) @pytest.fixture(scope="module") @@ -598,6 +602,45 @@ def test_constant_data_coords_issue_5046(self): for dname, cvals in coords.items(): np.testing.assert_array_equal(ds[dname].values, cvals) + def test_issue_5043_autoconvert_coord_values(self): + coords = { + "city": pd.Series(["Bonn", "Berlin"]), + } + with 
pm.Model(coords=coords) as pmodel: + # The model tracks coord values as (immutable) tuples + assert isinstance(pmodel.coords["city"], tuple) + pm.Normal("x", dims="city") + mtrace = pm.sample( + return_inferencedata=False, + compute_convergence_checks=False, + step=pm.Metropolis(), + cores=1, + tune=7, + draws=15, + ) + # The converter must convert the coord values to numpy arrays + # because tuples as coordinate values cause problems with xarray. + converter = InferenceDataConverter(trace=mtrace) + assert isinstance(converter.coords["city"], np.ndarray) + converter.to_inference_data() + + # We're not automatically converting things other than tuples, + # so advanced use cases remain supported at the InferenceData level. + # They just can't be used as early as model construction. + converter = InferenceDataConverter( + trace=mtrace, + coords={ + "city": pd.MultiIndex.from_tuples( + [ + ("Bonn", 53111), + ("Berlin", 10178), + ], + names=["name", "zipcode"], + ) + }, + ) + assert isinstance(converter.coords["city"], pd.MultiIndex) + class TestPyMCWarmupHandling: @pytest.mark.parametrize("save_warmup", [False, True])