Skip to content

Commit

Permalink
Fix coords in constant_data issue #5046 (#5062)
Browse files Browse the repository at this point in the history
* Add test case for constant_data coords issue #5046
* Don't replace local coords variable inside iterator

Closes #5046
  • Loading branch information
michaelosthege authored Oct 9, 2021
1 parent a3cc81c commit c06c2f4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def dict_to_dataset(
for name, vals in data.items():
vals = np.atleast_1d(vals)
val_dims = dims.get(name)
val_dims, coords = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords)
coords = {key: xr.IndexVariable((key,), data=coords[key]) for key in val_dims}
out_data[name] = xr.DataArray(vals, dims=val_dims, coords=coords)
val_dims, crds = generate_dims_coords(vals.shape, name, dims=val_dims, coords=coords)
crds = {key: xr.IndexVariable((key,), data=crds[key]) for key in val_dims}
out_data[name] = xr.DataArray(vals, dims=val_dims, coords=crds)
return xr.Dataset(data_vars=out_data, attrs=make_attrs(attrs=attrs, library=library))


Expand Down
30 changes: 30 additions & 0 deletions pymc/tests/test_idata_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,36 @@ def test_multivariate_observations(self):
assert "direction" not in idata.log_likelihood.dims
assert "direction" in idata.observed_data.dims

def test_constant_data_coords_issue_5046(self):
"""This is a regression test against a bug where a local coords variable was overwritten."""
dims = {
"alpha": ["backwards"],
"bravo": ["letters", "yesno"],
}
coords = {
"backwards": np.arange(17)[::-1],
"letters": list("ABCDEFGHIJK"),
"yesno": ["yes", "no"],
}
data = {
name: np.random.uniform(size=[len(coords[dn]) for dn in dnames])
for name, dnames in dims.items()
}

for k in data:
assert len(data[k].shape) == len(dims[k])

ds = pm.backends.arviz.dict_to_dataset(
data=data,
library=pm,
coords=coords,
dims=dims,
default_dims=[],
index_origin=0,
)
for dname, cvals in coords.items():
np.testing.assert_array_equal(ds[dname].values, cvals)


class TestPyMCWarmupHandling:
@pytest.mark.parametrize("save_warmup", [False, True])
Expand Down

0 comments on commit c06c2f4

Please sign in to comment.