Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let plot_posterior_predictive_glm work with inferencedata too #4234

Merged

Conversation

MarcoGorelli
Copy link
Contributor

@MarcoGorelli MarcoGorelli commented Nov 19, 2020

xref #4215 (comment)

BTW, extending it to work with InferenceData and time series data would be awesome 🤩 So the best may be to add it to Bambi instead

I haven't (yet) made sense of Bambi (though it does look like an awesome project), so for now here's a little PR to slightly extend plot_posterior_predictive_glm's functionality - then, the glm notebooks can be re-run with return_inferencedata=True


BTW, what's the rule with the copyright header? I see it in some files, but not all

RELEASE-NOTES.md Outdated Show resolved Hide resolved
@codecov
Copy link

codecov bot commented Nov 20, 2020

Codecov Report

Merging #4234 (2a38f67) into master (b7b145d) will increase coverage by 0.09%.
The diff coverage is 88.88%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #4234      +/-   ##
==========================================
+ Coverage   87.85%   87.95%   +0.09%     
==========================================
  Files          88       88              
  Lines       14495    14499       +4     
==========================================
+ Hits        12734    12752      +18     
+ Misses       1761     1747      -14     
Impacted Files Coverage Δ
pymc3/plots/posteriorplot.py 95.65% <88.88%> (+74.59%) ⬆️

Copy link
Contributor

@AlexAndorra AlexAndorra left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @MarcoGorelli, this is a nice start! I think the plotting functions can actually be refactored into one unique function -- see my comments below and feel free to ask if anything is unclear 😉

Comment on lines 57 to 71
def _plot_multitrace(trace, eval, lm, samples, kwargs):
for rand_loc in np.random.randint(0, len(trace), samples):
rand_sample = trace[rand_loc]
plt.plot(eval, lm(eval, rand_sample), **kwargs)
# Make sure to not plot label multiple times
kwargs.pop("label", None)

plt.title("Posterior predictive")

def _plot_inferencedata(trace, eval, lm, samples, kwargs):
trace_df = trace.posterior.to_dataframe()
for rand_loc in np.random.randint(0, len(trace_df), samples):
rand_sample = trace_df.iloc[rand_loc]
plt.plot(eval, lm(eval, rand_sample), **kwargs)
# Make sure to not plot label multiple times
kwargs.pop("label", None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two functions have a lot of duplicated lines; I think they can be merged into one by checking if isinstance(trace, MultiTrace) at the beginning of the function (or just before) and casting the InferenceData to_array (I think this is the name of the function but you can check on ArviZ website) instead of to a dataframe.
After that, the handling should be the same as you're dealing with numpy arrays in both cases

Copy link
Contributor Author

@MarcoGorelli MarcoGorelli Nov 20, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we cast to array then I think we wouldn't be able to access the different parameters (e.g. 'Intercept' or 'x'), which appear in lm. Each element here is a dict in the multitrace case:

> /home/mgorelli/pymc3-dev/pymc3/plots/posteriorplot.py(61)_plot_multitrace()
-> plt.plot(eval, lm(eval, rand_sample), **kwargs)
(Pdb) type(rand_sample)
<class 'dict'>
(Pdb) rand_sample
{'x': 1.0, 'Intercept': 1.0}

at this point, the only lines they have in common are

        plt.plot(eval, lm(eval, rand_sample), **kwargs)
        # Make sure to not plot label multiple times
        kwargs.pop("label", None)
  • the others are slightly different. My reason for making two separate helper functions is that I thought it'd be more readable than a single function with many if/then statements - I'll go with whatever you think is best though 😇

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah right, I forgot the whole trace was given here, and not only trace["y"] for instance. But then, wouldn't trace.posterior.to_dataframe().to_dict() get the format we want? That way we'd need only one plotting function

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure - I think this'd be slightly more expensive, but arguably it's worth it for the sake of much simpler code

pymc3/tests/test_plots.py Show resolved Hide resolved
pymc3/tests/test_plots.py Show resolved Hide resolved
@@ -12,18 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.

try:
import matplotlib.pyplot as plt
except ImportError: # mpl is optional
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think mpl is optional anymore - arviz is a required dependency, and mpl is a required dependency of arviz

@michaelosthege
Copy link
Member

Without going into the details myself, can you consider to mark the non-InferenceData based API as deprecated?
GLMs should work just fine with pm.sample(return_inferencedata=True), right?

@michaelosthege michaelosthege added this to the vNext (3.11.0) milestone Dec 15, 2020
@MarcoGorelli
Copy link
Contributor Author

GLMs should work just fine with pm.sample(return_inferencedata=True), right?

Not currently, no

Is pm.sample no longer going to return MultiTrace objects at all, or is that just no longer going to be the default?

@michaelosthege
Copy link
Member

GLMs should work just fine with pm.sample(return_inferencedata=True), right?

Not currently, no

Is pm.sample no longer going to return MultiTrace objects at all, or is that just no longer going to be the default?

I don't think we can our should drop the MultiTrace option as long as it's still used internally.
If at some point we implement an xarray-backend, we should do it.

But I think we should switch the default.

@MarcoGorelli
Copy link
Contributor Author

Sure, but then why should it be marked as deprecated in plot_posterior_predictive_glm if it's not going to be removed in the future?

I may have misunderstood - could you clarify what exactly should throw a deprecation warning?

@michaelosthege
Copy link
Member

Sure, but then why should it be marked as deprecated in plot_posterior_predictive_glm if it's not going to be removed in the future?

I may have misunderstood - could you clarify what exactly should throw a deprecation warning?

Oh, you're right. Unless GLMs work fine with InferenceData, we can't take away MultiTrace support from the plotting.

@MarcoGorelli
Copy link
Contributor Author

Here's the issue: currently, this works

with glm_model:
    trace = pm.sample()
pm.plot_posterior_predictive_glm(trace)

but this doesn't

with glm_model:
    trace = pm.sample(return_inferencedata=True)
pm.plot_posterior_predictive_glm(trace)

This PR would let both of the above work. Unless pm.sample(return_inferencedata=False) will be deprecated, I don't think we need to deprecate plot_posterior_predictive_glm taking a MultiTrace

@AlexAndorra
Copy link
Contributor

Is this ready for review @MarcoGorelli ?

@MarcoGorelli
Copy link
Contributor Author

sure!

Copy link
Contributor

@AlexAndorra AlexAndorra left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All good now, thanks for sticking to it @MarcoGorelli !

@AlexAndorra AlexAndorra merged commit 4e5edd5 into pymc-devs:master Dec 17, 2020
@MarcoGorelli MarcoGorelli deleted the extend-plot_posterior_predictive_glm branch December 17, 2020 15:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants