Update some examples to use new arviz #822
Open: aloctavodia wants to merge 10 commits into main from arviz

Commits (10):
0d2a836 update to use new arviz (aloctavodia)
c70895c add more examples (aloctavodia)
78aaac0 update BEST example (aloctavodia)
12d1990 update per comments (aloctavodia)
e5abefa update BF example (aloctavodia)
576575e update model averaging (aloctavodia)
8936813 update model averaging (aloctavodia)
27a09cb update and modify example (aloctavodia)
3fccc74 update data container example (aloctavodia)
4b3eb8c update data container example (aloctavodia)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,166 @@ | ||
| --- | ||
| jupyter: | ||
|   jupytext: | ||
|     formats: ipynb,md | ||
|     text_representation: | ||
|       extension: .md | ||
|       format_name: markdown | ||
|       format_version: '1.3' | ||
|       jupytext_version: 1.16.1 | ||
|   kernelspec: | ||
|     display_name: Python 3 (ipykernel) | ||
|     language: python | ||
|     name: python3 | ||
| --- | ||
|
|
||
| (bart_heteroscedasticity)= | ||
| # Modeling Heteroscedasticity with BART | ||
|
|
||
| :::{post} January, 2023 | ||
| :tags: BART, regression | ||
| :category: beginner, reference | ||
| :author: Juan Orduz | ||
| ::: | ||
|
|
||
|
|
||
| In this notebook we show how to use BART to model heteroscedasticity, as described in Section 4.1 of the [`pymc-bart`](https://github.com/pymc-devs/pymc-bart) paper {cite:p}`quiroga2022bart`. We use the `marketing` data set provided by the R package `datarium` {cite:p}`kassambara2019datarium`. The idea is to model a marketing channel's contribution to sales as a function of budget. | ||
|
|
||
| ```python | ||
| import os | ||
|
|
||
| import arviz as az | ||
| import matplotlib.pyplot as plt | ||
| import numpy as np | ||
| import pandas as pd | ||
| import pymc as pm | ||
| import pymc_bart as pmb | ||
| ``` | ||
|
|
||
| ```python | ||
| %config InlineBackend.figure_format = "retina" | ||
| az.style.use("arviz-darkgrid") | ||
| plt.rcParams["figure.figsize"] = [10, 6] | ||
| rng = np.random.default_rng(42) | ||
| ``` | ||
|
|
||
| ## Read Data | ||
|
|
||
| ```python | ||
| try: | ||
|     df = pd.read_csv(os.path.join("..", "data", "marketing.csv"), sep=";", decimal=",") | ||
| except FileNotFoundError: | ||
|     df = pd.read_csv(pm.get_data("marketing.csv"), sep=";", decimal=",") | ||
|
|
||
| n_obs = df.shape[0] | ||
|
|
||
| df.head() | ||
| ``` | ||
|
|
||
| ## EDA | ||
|
|
||
| We start by looking at the data. We are going to focus on the *YouTube* channel. | ||
|
|
||
| ```python | ||
| fig, ax = plt.subplots() | ||
| ax.plot(df["youtube"], df["sales"], "o", c="C0") | ||
| ax.set(title="Sales as a function of Youtube budget", xlabel="budget", ylabel="sales"); | ||
| ``` | ||
|
|
||
| We clearly see that both the mean and variance are increasing as a function of budget. One possibility is to manually select an explicit parametrization of these functions, e.g. square root or logarithm. However, in this example we want to learn these functions from the data using a BART model. | ||
|
|
||
|
|
||
| ## Model Specification | ||
|
|
||
| We proceed to prepare the data for modeling. We are going to use the `youtube` budget as the predictor and `sales` as the response. | ||
|
|
||
| ```python | ||
| X = df["youtube"].to_numpy().reshape(-1, 1) | ||
| Y = df["sales"].to_numpy() | ||
| ``` | ||
|
|
||
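| Next, we specify the model. Note that we need only one BART distribution, which can be vectorized to model both the mean and the spread of the likelihood: `shape=(2, n_obs)` gives one row for each. We use a Gamma likelihood because we expect sales to be positive, and we exponentiate both rows so that `mu` and `sigma` are positive as well. | ||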
|
|
||
| ```python | ||
| with pm.Model() as model_marketing_full: | ||
|     # Row 0 of w models log(mu); row 1 models log(sigma). | ||
|     w = pmb.BART("w", X=X, Y=np.log(Y), m=100, shape=(2, n_obs)) | ||
|     y = pm.Gamma("y", mu=pm.math.exp(w[0]), sigma=pm.math.exp(w[1]), observed=Y) | ||
|
|
||
| pm.model_to_graphviz(model=model_marketing_full) | ||
| ``` | ||
|
|
||
| We now fit the model. | ||
|
|
||
| ```python | ||
| with model_marketing_full: | ||
|     idata_marketing_full = pm.sample(2000, random_seed=rng, compute_convergence_checks=False) | ||
|     posterior_predictive_marketing_full = pm.sample_posterior_predictive( | ||
|         trace=idata_marketing_full, random_seed=rng | ||
|     ) | ||
| ``` | ||
|
|
||
| ## Results | ||
|
|
||
| We can now visualize the posterior distribution of the mean together with the posterior predictive distribution of the observations. | ||
|
|
||
| ```python | ||
| posterior_mean = idata_marketing_full.posterior["w"].mean(dim=("chain", "draw"))[0] | ||
|
|
||
| w_hdi = az.hdi(ary=idata_marketing_full, group="posterior", var_names=["w"], hdi_prob=0.5) | ||
|
|
||
| pps = az.extract( | ||
|     posterior_predictive_marketing_full, group="posterior_predictive", var_names=["y"] | ||
| ).T | ||
| ``` | ||
|
|
||
| ```python | ||
| idx = np.argsort(X[:, 0]) | ||
|
|
||
|
|
||
| fig, ax = plt.subplots() | ||
| az.plot_hdi( | ||
|     x=X[:, 0], | ||
|     y=pps, | ||
|     ax=ax, | ||
|     hdi_prob=0.90, | ||
|     fill_kwargs={"alpha": 0.3, "label": r"Observations $90\%$ HDI"}, | ||
| ) | ||
| az.plot_hdi( | ||
|     x=X[:, 0], | ||
|     hdi_data=np.exp(w_hdi["w"].sel(w_dim_0=0)), | ||
|     ax=ax, | ||
|     fill_kwargs={"alpha": 0.6, "label": r"Mean $50\%$ HDI"}, | ||
| ) | ||
| ax.plot(df["youtube"], df["sales"], "o", c="C0", label="Raw Data") | ||
| ax.legend(loc="upper left") | ||
| ax.set( | ||
|     title="Sales as a function of Youtube budget - Posterior Predictive", | ||
|     xlabel="budget", | ||
|     ylabel="sales", | ||
| ); | ||
| ``` | ||
|
|
||
| The fit looks good! Indeed, we see that both the mean and the variance increase as a function of the budget. | ||
|
|
||
|
|
||
| ## Authors | ||
| - Authored by [Juan Orduz](https://juanitorduz.github.io/) in Feb, 2023 | ||
| - Rerun by Osvaldo Martin in Mar, 2023 | ||
| - Rerun by Osvaldo Martin in Nov, 2023 | ||
| - Rerun by Osvaldo Martin in Dec, 2024 | ||
|
|
||
|
|
||
| ## References | ||
| :::{bibliography} | ||
| :filter: docname in docnames | ||
| ::: | ||
|
|
||
|
|
||
| ## Watermark | ||
|
|
||
| ```python | ||
| %load_ext watermark | ||
| %watermark -n -u -v -iv -w -p pytensor | ||
| ``` | ||
|
|
||
| :::{include} ../page_footer.md | ||
| ::: |
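
The Gamma likelihood in the notebook above is parametrized by `mu` and `sigma` rather than by shape and rate. As a minimal sketch of what that parametrization implies, the snippet below converts a mean and standard deviation into shape `alpha = mu**2 / sigma**2` and rate `beta = mu / sigma**2`, then checks the result by simulation. It uses only the Python standard library instead of PyMC, and the helper names `gamma_shape_rate` and `sample_sales` as well as the example values are hypothetical, chosen for illustration only.

```python
import math
import random


def gamma_shape_rate(mu, sigma):
    """Convert a (mean, standard deviation) pair into Gamma (shape, rate)."""
    if mu <= 0 or sigma <= 0:
        raise ValueError("mu and sigma must be positive")
    alpha = (mu / sigma) ** 2  # shape
    beta = mu / sigma**2  # rate
    return alpha, beta


def sample_sales(log_mu, log_sigma, n_draws, seed=0):
    """Draw Gamma samples from a log-scale mean and standard deviation."""
    # Exponentiating mirrors the exp link used in the notebook's model:
    # any real-valued input maps to a positive mu and sigma.
    mu, sigma = math.exp(log_mu), math.exp(log_sigma)
    alpha, beta = gamma_shape_rate(mu, sigma)
    rng = random.Random(seed)
    # random.gammavariate expects shape and *scale*, so scale = 1 / rate.
    return [rng.gammavariate(alpha, 1.0 / beta) for _ in range(n_draws)]


# Illustrative values only: mean 10 and standard deviation 2.
draws = sample_sales(log_mu=math.log(10.0), log_sigma=math.log(2.0), n_draws=20_000)
sample_mean = sum(draws) / len(draws)
sample_sd = math.sqrt(sum((d - sample_mean) ** 2 for d in draws) / len(draws))
```

With these inputs the sample mean and standard deviation should land close to 10 and 2, and every draw is positive, which is the property that motivates the Gamma likelihood for sales.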
My main comment is probably not actionable, but it might be worth keeping in mind for arviz development. I am not sure I can see this, as it is a comparison between this figure and one a few sections higher. Does `plot_ppc_pava` support dict input to show both models? If not, do you think that would be useful to add?
I think the only ones supporting more than one model are `plot_forest` and `plot_dist`. Not sure how useful; it could be OK for `plot_ppc_pava` and maybe `plot_ppc_pit`. Maybe something more general is to have something similar to `combine_plots` but that works for different data, so we can create a single figure with one model per row or column and comparisons are easier to see and present.