|
| 1 | +"""Public API for Metropolis Adjusted Langevin kernels.""" |
| 2 | +from typing import Callable, NamedTuple, Tuple |
| 3 | + |
| 4 | +import jax |
| 5 | +import jax.numpy as jnp |
| 6 | + |
| 7 | +from blackjax.mcmc.diffusion import overdamped_langevin |
| 8 | +from blackjax.types import PRNGKey, PyTree |
| 9 | + |
| 10 | +__all__ = ["MALAState", "MALAInfo", "init", "kernel"] |
| 11 | + |
| 12 | + |
| 13 | +class MALAState(NamedTuple): |
| 14 | + """State of the MALA algorithm. |
| 15 | +
|
| 16 | + The MALA algorithm takes one position of the chain and returns another |
| 17 | + position. In order to make computations more efficient, we also store |
| 18 | + the current log-probability density as well as the current gradient of the |
| 19 | + log-probability density. |
| 20 | +
|
| 21 | + """ |
| 22 | + |
| 23 | + position: PyTree |
| 24 | + logprob_grad: PyTree |
| 25 | + |
| 26 | + |
| 27 | +class MALAInfo(NamedTuple): |
| 28 | + """Additional information on the MALA transition. |
| 29 | +
|
| 30 | + This additional information can be used for debugging or computing |
| 31 | + diagnostics. |
| 32 | +
|
| 33 | + acceptance_probability |
| 34 | + The acceptance probability of the transition. |
| 35 | + is_accepted |
| 36 | + Whether the proposed position was accepted or the original position |
| 37 | + was returned. |
| 38 | +
|
| 39 | + """ |
| 40 | + |
| 41 | + acceptance_probability: float |
| 42 | + is_accepted: bool |
| 43 | + |
| 44 | + |
| 45 | +def init(position: PyTree, logprob_fn: Callable) -> MALAState: |
| 46 | + grad_fn = jax.grad(logprob_fn) |
| 47 | + logprob_grad = grad_fn(position) |
| 48 | + return MALAState(position, logprob_grad) |
| 49 | + |
| 50 | + |
| 51 | +def kernel(): |
| 52 | + """Build a MALA kernel. |
| 53 | +
|
| 54 | + Returns |
| 55 | + ------- |
| 56 | + A kernel that takes a rng_key and a Pytree that contains the current state |
| 57 | + of the chain and that returns a new state of the chain along with |
| 58 | + information about the transition. |
| 59 | +
|
| 60 | + """ |
| 61 | + |
| 62 | + def transition_probability(state, new_state, step_size): |
| 63 | + """Transition probability to go from `state` to `new_state`""" |
| 64 | + theta = new_state.position - state.position - step_size * state.logprob_grad |
| 65 | + return -0.25 * (1.0 / step_size) * jnp.dot(theta, theta) |
| 66 | + |
| 67 | + def one_step( |
| 68 | + rng_key: PRNGKey, state: MALAState, logprob_fn: Callable, step_size: float |
| 69 | + ) -> Tuple[MALAState, MALAInfo]: |
| 70 | + """Generate a new sample with the MALA kernel. |
| 71 | +
|
| 72 | + TODO expand the docstring. |
| 73 | +
|
| 74 | + """ |
| 75 | + grad_fn = jax.grad(logprob_fn) |
| 76 | + integrator = overdamped_langevin(grad_fn) |
| 77 | + |
| 78 | + key_integrator, key_rmh = jax.random.split(rng_key) |
| 79 | + |
| 80 | + new_state = integrator(key_integrator, state, step_size) |
| 81 | + |
| 82 | + delta = ( |
| 83 | + logprob_fn(new_state.position) |
| 84 | + - logprob_fn(state.position) |
| 85 | + + transition_probability(new_state, state, step_size) |
| 86 | + - transition_probability(state, new_state, step_size) |
| 87 | + ) |
| 88 | + delta = jnp.where(jnp.isnan(delta), -jnp.inf, delta) |
| 89 | + p_accept = jnp.clip(jnp.exp(delta), a_max=1) |
| 90 | + |
| 91 | + do_accept = jax.random.bernoulli(key_rmh, p_accept) |
| 92 | + |
| 93 | + new_state = MALAState(*new_state) |
| 94 | + info = MALAInfo(p_accept, do_accept) |
| 95 | + |
| 96 | + return jax.lax.cond( |
| 97 | + do_accept, |
| 98 | + lambda _: (new_state, info), |
| 99 | + lambda _: (state, info), |
| 100 | + operand=None, |
| 101 | + ) |
| 102 | + |
| 103 | + return one_step |
0 commit comments