-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Let plot_posterior_predictive_glm work with inferencedata too #4234
Let plot_posterior_predictive_glm work with inferencedata too #4234
Conversation
Codecov Report
@@ Coverage Diff @@
## master #4234 +/- ##
==========================================
+ Coverage 87.85% 87.95% +0.09%
==========================================
Files 88 88
Lines 14495 14499 +4
==========================================
+ Hits 12734 12752 +18
+ Misses 1761 1747 -14
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @MarcoGorelli, this is a nice start! I think the plotting functions can actually be refactored into one unique function -- see my comments below and feel free to ask if anything is unclear 😉
pymc3/plots/posteriorplot.py
Outdated
def _plot_multitrace(trace, eval, lm, samples, kwargs): | ||
for rand_loc in np.random.randint(0, len(trace), samples): | ||
rand_sample = trace[rand_loc] | ||
plt.plot(eval, lm(eval, rand_sample), **kwargs) | ||
# Make sure to not plot label multiple times | ||
kwargs.pop("label", None) | ||
|
||
plt.title("Posterior predictive") | ||
|
||
def _plot_inferencedata(trace, eval, lm, samples, kwargs): | ||
trace_df = trace.posterior.to_dataframe() | ||
for rand_loc in np.random.randint(0, len(trace_df), samples): | ||
rand_sample = trace_df.iloc[rand_loc] | ||
plt.plot(eval, lm(eval, rand_sample), **kwargs) | ||
# Make sure to not plot label multiple times | ||
kwargs.pop("label", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These two functions have a lot of duplicated lines; I think they can be merged into one by checking if isinstance(trace, MultiTrace)
at the beginning of the function (or just before) and casting the InferenceData
to_array
(I think this is the name of the function but you can check on ArviZ website) instead of to a dataframe.
After that, the handling should be the same as you're dealing with numpy arrays in both cases
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we cast to array
then I think we wouldn't be able to access the different parameters (e.g. 'Intercept'
or 'x'
), which appear in lm
. Each element here is a dict in the multitrace case:
> /home/mgorelli/pymc3-dev/pymc3/plots/posteriorplot.py(61)_plot_multitrace()
-> plt.plot(eval, lm(eval, rand_sample), **kwargs)
(Pdb) type(rand_sample)
<class 'dict'>
(Pdb) rand_sample
{'x': 1.0, 'Intercept': 1.0}
at this point, the only lines they have in common are
plt.plot(eval, lm(eval, rand_sample), **kwargs)
# Make sure to not plot label multiple times
kwargs.pop("label", None)
- the others are slightly different. My reason for making two separate helper functions is that I thought it'd be more readable than a single function with many if/then statements - I'll go with whatever you think is best though 😇
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah right, I forgot the whole trace was given here, and not only trace["y"]
for instance. But then, wouldn't trace.posterior.to_dataframe().to_dict()
get the format we want? That way we'd need only one plotting function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure - I think this'd be slightly more expensive, but arguably it's worth it for the sake of much simpler code
…cogorelli/pymc3 into extend-plot_posterior_predictive_glm
…rior_predictive_glm
@@ -12,18 +12,28 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
|
|||
try: | |||
import matplotlib.pyplot as plt | |||
except ImportError: # mpl is optional |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think mpl is optional anymore - arviz is a required dependency, and mpl is a required dependency of arviz
Without going into the details myself, can you consider to mark the non- |
Not currently, no Is |
I don't think we can our should drop the But I think we should switch the default. |
Sure, but then why should it be marked as deprecated in I may have misunderstood - could you clarify what exactly should throw a deprecation warning? |
Oh, you're right. Unless GLMs work fine with |
Here's the issue: currently, this works
but this doesn't
This PR would let both of the above work. Unless |
…rior_predictive_glm
…coGorelli/pymc3 into extend-plot_posterior_predictive_glm
Is this ready for review @MarcoGorelli ? |
sure! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All good now, thanks for sticking to it @MarcoGorelli !
xref #4215 (comment)
I haven't (yet) made sense of Bambi (though it does look like an awesome project), so for now here's a little PR to slightly extend
plot_posterior_predictive_glm
's functionality - then, the glm notebooks can be re-run withreturn_inferencedata=True
BTW, what's the rule with the copyright header? I see it in some files, but not all