sample_posterior_predictive flattens chains and draws #4004
Comments
I believe that there are already options to do this. Have you tried using the |
In my test case, I had 2 chains and 500 draws per chain. With keep_size=False, the output of PPC was I'll look at |
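To make the shape issue concrete: with 2 chains and 500 draws, a flattened posterior predictive array has 1000 rows. The minimal NumPy sketch below uses synthetic data and makes no PyMC3 call. It assumes the flattened rows are ordered chain-major (all draws of chain 0, then all draws of chain 1). That ordering is an assumption about the layout, not something confirmed in this thread. Under it, a plain reshape restores the (chain, draw) axes, provided the chain and draw counts are known.

```python
import numpy as np

n_chains, n_draws = 2, 500

# Synthetic stand-in for a flattened posterior predictive array:
# one row per (chain, draw) pair, 10 observations per row.
# Each value encodes chain * 1000 + draw so the layout is easy to inspect.
chain_idx = np.repeat(np.arange(n_chains), n_draws)  # 0,0,...,0,1,1,...,1
draw_idx = np.tile(np.arange(n_draws), n_chains)     # 0..499, 0..499
flat = (chain_idx * 1000 + draw_idx)[:, None] * np.ones((1, 10))

# ASSUMPTION: rows are chain-major, so reshaping recovers (chain, draw, ...).
restored = flat.reshape(n_chains, n_draws, *flat.shape[1:])

print(flat.shape)      # (1000, 10)
print(restored.shape)  # (2, 500, 10)
```

This only works because the chain and draw counts are supplied from outside the array. The flattened output itself does not carry them.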
Here is my minimal reproducing example:
Here are the relevant lines: https://github.com/pymc-devs/pymc3/blob/master/pymc3/sampling.py#L1598-L1605. The chain information is only retrieved when a trace is passed, not in the dataset case, thus |
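The underlying problem is bookkeeping. Once samples are stacked into a single sample axis, the (chain, draw) sizes have to be recorded separately if they are to be restored afterwards. The sketch below shows that round trip in plain NumPy. `flatten_samples` and `unflatten_samples` are hypothetical helpers, not PyMC3 functions. The dict of arrays stands in for a posterior dataset with leading (chain, draw) dimensions.

```python
import numpy as np


def flatten_samples(samples):
    """Hypothetical helper: collapse (chain, draw, *shape) arrays to
    (chain * draw, *shape) and return the (chain, draw) sizes so the
    operation can be undone later."""
    sizes = {v.shape[:2] for v in samples.values()}
    if len(sizes) != 1:
        raise ValueError("all variables must share the same (chain, draw) sizes")
    n_chains, n_draws = sizes.pop()
    flat = {k: v.reshape(n_chains * n_draws, *v.shape[2:]) for k, v in samples.items()}
    return flat, (n_chains, n_draws)


def unflatten_samples(flat, chain_draw):
    """Hypothetical helper: inverse of flatten_samples, restoring the
    leading (chain, draw) dimensions."""
    n_chains, n_draws = chain_draw
    return {k: v.reshape(n_chains, n_draws, *v.shape[1:]) for k, v in flat.items()}


rng = np.random.default_rng(0)
# Synthetic "posterior" with 2 chains x 500 draws.
posterior = {
    "mu": rng.normal(size=(2, 500)),
    "sigma": rng.gamma(2.0, size=(2, 500, 3)),
}

flat, chain_draw = flatten_samples(posterior)
# Stand-in for per-sample predictive computation on the flattened samples.
ppc_like = {k: v * 2 for k, v in flat.items()}
restored = unflatten_samples(ppc_like, chain_draw)
```

Dropping `chain_draw` on the way through reproduces the flattening described in this issue: the arrays are still valid, but their chain structure is gone.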
@kyleabeauchamp Thanks for checking this: it looks like I can shove these into the tests, so that we can fix and make sure that |
For the case of an arviz trace, fast_sample_posterior_predictive doesn't run at all due to some strong type checking:
For the case of a pymc3 trace, fast_sample_posterior_predictive does seem to respect the (chain, draw) shape:
Probably we didn't get the handling of shape right when we added I would imagine that if we take an |
@michaelosthege while I was debugging this, I was surprised to see that the posterior predictive sampling functions now accept an xarray dataset, but not an |
I also experienced this confusion: in the pymc3 code path you pass the output of |
@kyleabeauchamp I'm going to address this in my fix for this issue. I agree with you: it seems odd to require that extra step if using |
Just debugging a solution to this now...
I have a WIP solution; waiting to see if it passes CI.
I noticed that the results of pm.sample_posterior_predictive are flattened over the (chain, draw) dimensions. AFAIK, this is a lossy transform, because the results are plain NumPy arrays that don't record which chain and draw each sample came from. I wonder if there should be an option to preserve the (chain, draw) shape and return the results as an ArviZ InferenceData object. See also arviz-devs/arviz#1282.
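Why the flattening is lossy: the same flat array can be reshaped into several valid (chain, draw) layouts, and nothing in the array indicates which one matches the sampler run. The small NumPy illustration below uses synthetic data. Its two "chains" are deliberately drawn around different means, as a non-converged run might produce.

```python
import numpy as np

rng = np.random.default_rng(1)

# Two synthetic chains that disagree on purpose: shape (2, 500).
chains = np.stack([rng.normal(0.0, 1.0, 500), rng.normal(5.0, 1.0, 500)])

# What a flattened output looks like: shape (1000,).
flat = chains.reshape(-1)

# Both reshapes succeed; the flat array alone cannot say which is correct.
as_2x500 = flat.reshape(2, 500)
as_4x250 = flat.reshape(4, 250)

print(as_2x500.mean(axis=1).round(1))  # per-chain means: one near 0, one near 5
print(as_4x250.mean(axis=1).round(1))  # four "chains" that never existed
```

Only the (2, 500) layout reproduces the real per-chain summaries. This is why keeping the shape, or storing it as labeled dimensions as an InferenceData object does, matters for chain-level diagnostics.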