40 changes: 21 additions & 19 deletions pymc/backends/arviz.py
@@ -42,6 +42,26 @@
Var = Any # pylint: disable=invalid-name


def find_observations(model: Optional["Model"]) -> Optional[Dict[str, Var]]:
"""If there are observations available, return them as a dictionary."""
if model is None:
return None

observations = {}
for obs in model.observed_RVs:
aux_obs = getattr(obs.tag, "observations", None)
if aux_obs is not None:
try:
obs_data = extract_obs_data(aux_obs)
observations[obs.name] = obs_data
except TypeError:
warnings.warn(f"Could not extract data from symbolic observation {obs}")
else:
warnings.warn(f"No data for observation {obs}")

return observations
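
For context, a minimal usage sketch of the new module-level helper (the toy model below is illustrative, not part of this diff):

```python
import numpy as np
import pymc as pm

from pymc.backends.arviz import find_observations

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    # observed RVs are what find_observations collects
    pm.Normal("y", mu, 1.0, observed=np.array([0.1, -0.3, 0.2]))

obs = find_observations(model)  # {"y": array([ 0.1, -0.3,  0.2])}
assert find_observations(None) is None
```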


class _DefaultTrace:
"""
Utility for collecting samples into a dictionary.
@@ -196,25 +216,7 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
self.dims = {**model_dims, **self.dims}

self.density_dist_obs = density_dist_obs
self.observations = self.find_observations()

def find_observations(self) -> Optional[Dict[str, Var]]:
"""If there are observations available, return them as a dictionary."""
if self.model is None:
return None
observations = {}
for obs in self.model.observed_RVs:
aux_obs = getattr(obs.tag, "observations", None)
if aux_obs is not None:
try:
obs_data = extract_obs_data(aux_obs)
observations[obs.name] = obs_data
except TypeError:
warnings.warn(f"Could not extract data from symbolic observation {obs}")
else:
warnings.warn(f"No data for observation {obs}")

return observations
self.observations = find_observations(self.model)

def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
"""Split MultiTrace object into posterior and warmup.
62 changes: 58 additions & 4 deletions pymc/sampling_jax.py
@@ -26,7 +26,9 @@
from aesara.link.jax.dispatch import jax_funcify

from pymc import Model, modelcontext
from pymc.aesaraf import compile_rv_inplace, inputvars
from pymc.aesaraf import compile_rv_inplace
from pymc.backends.arviz import find_observations
from pymc.distributions import logpt
from pymc.util import get_default_varnames

warnings.warn("This module is experimental.")
@@ -95,6 +97,39 @@ def logp_fn_wrap(x):
return logp_fn_wrap


# Adapted from ArviZ's NumPyro extractor
def _sample_stats_to_xarray(posterior):
"""Extract sample_stats from NumPyro posterior."""
rename_key = {
"potential_energy": "lp",
"adapt_state.step_size": "step_size",
"num_steps": "n_steps",
"accept_prob": "acceptance_rate",
}
data = {}
for stat, value in posterior.get_extra_fields(group_by_chain=True).items():
if isinstance(value, (dict, tuple)):
continue
name = rename_key.get(stat, stat)
value = value.copy()
data[name] = value
if stat == "num_steps":
data["tree_depth"] = np.log2(value).astype(int) + 1
return data
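
A quick check of the tree_depth arithmetic above (not part of the diff): a NUTS binary tree of depth d takes at most 2**d - 1 leapfrog steps, so depth is recovered as floor(log2(num_steps)) + 1.

```python
import numpy as np

# num_steps -> tree_depth, per the formula in _sample_stats_to_xarray
for n, depth in [(1, 1), (3, 2), (7, 3), (10, 4), (15, 4)]:
    assert np.log2(n).astype(int) + 1 == depth
```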


def _get_log_likelihood(model, samples):
"Compute log-likelihood for all observations"
data = {}
for v in model.observed_RVs:
logp_v = replace_shared_variables([logpt(v)])
fgraph = FunctionGraph(model.value_vars, logp_v, clone=False)
jax_fn = jax_funcify(fgraph)
result = jax.vmap(jax.vmap(jax_fn))(*samples)[0]
Review thread on the vmap call above:

Member: Out of curiosity, would we expect any benefits to jit-compiling this outer vmap?

Member: Would it be possible to use a similar approach with Aesara directly? Here we only loop over observed variables in order to get the pointwise log-likelihood. We had some discussion about this in #4489 but ended up keeping the three nested loops over variables, chains, and draws.

ricardoV94 (Nov 18, 2021): It should be possible, but it requires an Aesara Scan, and at least for small models that was not faster than Python looping when I checked. Here is a notebook that documents some things I tried: https://gist.github.com/ricardoV94/6089a8c46a0e19665f01c79ea04e1cb2
It might be faster if using shared variables...

zaxtax (Nov 18, 2021): No idea; I think the easiest thing to do is just benchmark it. I don't even call optimize_graph on the graph in this function or in the main sample routine. When I run the model in the unit test with the change from

    result = jax.vmap(jax.vmap(jax_fn))(*samples)[0]

to

    result = jax.jit(jax.vmap(jax.vmap(jax_fn)))(*samples)[0]

I don't really get a speed-up until there are millions of samples.

ricardoV94 (Nov 18, 2021):

> I don't even call optimize_graph on the graph in this function or in the main sample routine.

We should definitely call optimize_graph, otherwise the computed logps may not correspond to the ones used during sampling. For instance, we have many optimizations that improve numerical stability, so you might get underflows to -inf for some of the posterior samples (which would never have been accepted by NUTS), and that could break things downstream.

Member:

> It should be possible, but it requires an Aesara Scan, and at least for small models that was not faster than Python looping when I checked.

Then it's probably not worth it. I was under the impression from the conversations in #4489 and on Slack that it would be possible to vectorize/broadcast the operation.

ricardoV94 (Nov 18, 2021): It must be possible, since the vmap above works just fine; I just have no idea how they do it, or how/if you could do it in Aesara. I also wonder whether the vmap works for more complicated models with multivariate distributions and the like.

zaxtax (Nov 18, 2021): Alright, I'm going to make a separate PR for some of this other stuff.

Member: Cool, feel free to tag me if you want me to review; I am not watching PRs. I can already say I won't be able to help with the vectorized log-likelihood work, since I already lost much more time on that than would have been healthy, but I should be able to help with coords and dims.
data[v.name] = result
return data
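
For intuition on the nested vmap discussed above, a self-contained sketch (the per-draw function is made up for illustration): the outer vmap maps over chains and the inner one over draws, so a function written for a single draw is applied across the whole (chains, draws) grid.

```python
import jax
import jax.numpy as jnp

def pointwise_logp(a, log_sigma):
    # stand-in for the jax_funcify-compiled logp of one draw
    sigma = jnp.exp(log_sigma)
    return -0.5 * (a / sigma) ** 2 - jnp.log(sigma)

chains, draws = 2, 5
a = jnp.zeros((chains, draws))
log_sigma = jnp.zeros((chains, draws))

# inner vmap maps over the draw axis, outer vmap over the chain axis
out = jax.vmap(jax.vmap(pointwise_logp))(a, log_sigma)
assert out.shape == (chains, draws)
```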


def sample_numpyro_nuts(
draws=1000,
tune=1000,
@@ -151,9 +186,23 @@ def sample_numpyro_nuts(
map_seed = jax.random.split(seed, chains)

if chains == 1:
pmap_numpyro.run(seed, init_params=init_state, extra_fields=("num_steps",))
init_params = init_state
map_seed = seed
else:
pmap_numpyro.run(map_seed, init_params=init_state_batched, extra_fields=("num_steps",))
init_params = init_state_batched

pmap_numpyro.run(
map_seed,
init_params=init_params,
extra_fields=(
"num_steps",
"potential_energy",
"energy",
"adapt_state.step_size",
"accept_prob",
"diverging",
),
)

raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)

@@ -172,6 +221,11 @@ def sample_numpyro_nuts(
print("Transformation time = ", tic4 - tic3, file=sys.stdout)

posterior = mcmc_samples
az_trace = az.from_dict(posterior=posterior)
az_posterior = az.from_dict(posterior=posterior)

az_obs = az.from_dict(observed_data=find_observations(model))
az_stats = az.from_dict(sample_stats=_sample_stats_to_xarray(pmap_numpyro))
az_ll = az.from_dict(log_likelihood=_get_log_likelihood(model, raw_mcmc_samples))
az_trace = az.concat(az_posterior, az_ll, az_obs, az_stats)

return az_trace
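
An end-to-end sketch of what the assembled InferenceData now carries (the model and sizes here are illustrative assumptions; the group and stat names follow the code above):

```python
import numpy as np
import pymc as pm

from pymc.sampling_jax import sample_numpyro_nuts

data = np.random.normal(10.0, 2.0, size=100)
with pm.Model():
    mu = pm.Normal("mu", 0.0, 10.0)
    pm.Normal("y", mu, 2.0, observed=data)
    idata = sample_numpyro_nuts(tune=200, draws=200, chains=2)

assert "lp" in idata.sample_stats         # renamed from potential_energy
assert "y" in idata.log_likelihood        # pointwise log-likelihood
assert "y" in idata.observed_data         # collected by find_observations
```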
19 changes: 19 additions & 0 deletions pymc/tests/test_sampling_jax.py
@@ -9,6 +9,7 @@
import pymc as pm

from pymc.sampling_jax import (
_get_log_likelihood,
get_jaxified_logp,
replace_shared_variables,
sample_numpyro_nuts,
@@ -61,6 +62,24 @@ def test_deterministic_samples():
assert np.allclose(trace.posterior["b"].values, trace.posterior["a"].values / 2)


def test_get_log_likelihood():
obs = np.random.normal(10, 2, size=100)
obs_at = aesara.shared(obs, borrow=True, name="obs")
with pm.Model() as model:
a = pm.Normal("a", 0, 2)
sigma = pm.HalfNormal("sigma")
b = pm.Normal("b", a, sigma=sigma, observed=obs_at)

trace = pm.sample(tune=10, draws=10, chains=2, random_seed=1322)

b_true = trace.log_likelihood.b.values
a = np.array(trace.posterior.a)
sigma_log_ = np.log(np.array(trace.posterior.sigma))
b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"]

assert np.allclose(b_jax.reshape(-1), b_true.reshape(-1))


def test_replace_shared_variables():
x = aesara.shared(5, name="shared_x")
