Commit bd7d149

jeremiecoullon authored and rlouf committed
Add the MALA sampling algorithm
1 parent f814c65 commit bd7d149

7 files changed: +231 −2 lines changed

blackjax/__init__.py (+2)

```diff
@@ -3,6 +3,7 @@
 from .kernels import (
     adaptive_tempered_smc,
     hmc,
+    mala,
     nuts,
     rmh,
     tempered_smc,
@@ -13,6 +14,7 @@
 
 __all__ = [
     "hmc",  # mcmc
+    "mala",
     "nuts",
     "rmh",
     "window_adaptation",  # mcmc adaptation
```

blackjax/kernels.py (+72)

```diff
@@ -11,6 +11,7 @@
 __all__ = [
     "adaptive_tempered_smc",
     "hmc",
+    "mala",
     "nuts",
     "rmh",
     "tempered_smc",
@@ -222,6 +223,77 @@ def step_fn(rng_key: PRNGKey, state):
         return SamplingAlgorithm(init_fn, step_fn)
 
 
+class mala:
+    """Implements the (basic) user interface for the MALA kernel.
+
+    The general MALA kernel (:meth:`blackjax.mcmc.mala.kernel`, alias
+    `blackjax.mala.kernel`) can be cumbersome to manipulate. Since most users
+    only need to specify the kernel parameters at initialization time, we
+    provide a helper function that specializes the general kernel.
+
+    We also add the general kernel and state generator as attributes to this
+    class so users only need to pass `blackjax.mala` to SMC, adaptation, etc.
+    algorithms.
+
+    Examples
+    --------
+
+    A new MALA kernel can be initialized and used with the following code:
+
+    .. code::
+
+        mala = blackjax.mala(logprob_fn, step_size)
+        state = mala.init(position)
+        new_state, info = mala.step(rng_key, state)
+
+    Kernels are not jit-compiled by default so you will need to do it manually:
+
+    .. code::
+
+        step = jax.jit(mala.step)
+        new_state, info = step(rng_key, state)
+
+    Should you need to, you can always use the base kernel directly:
+
+    .. code::
+
+        kernel = blackjax.mala.kernel()
+        state = blackjax.mala.init(position, logprob_fn)
+        state, info = kernel(rng_key, state, logprob_fn, step_size)
+
+    Parameters
+    ----------
+    logprob_fn
+        The log-probability density function we wish to draw samples from.
+        This is minus the potential function.
+    step_size
+        The value to use for the step size in the Euler-Maruyama
+        discretization of the Langevin diffusion.
+
+    Returns
+    -------
+    A ``SamplingAlgorithm``.
+
+    """
+
+    init = staticmethod(mcmc.mala.init)
+    kernel = staticmethod(mcmc.mala.kernel)
+
+    def __new__(  # type: ignore[misc]
+        cls,
+        logprob_fn: Callable,
+        step_size: float,
+    ) -> SamplingAlgorithm:
+
+        step = cls.kernel()
+
+        def init_fn(position: PyTree):
+            return cls.init(position, logprob_fn)
+
+        def step_fn(rng_key: PRNGKey, state):
+            return step(rng_key, state, logprob_fn, step_size)
+
+        return SamplingAlgorithm(init_fn, step_fn)
+
+
 class nuts:
     """Implements the (basic) user interface for the nuts kernel.
```
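To make the new user-facing interface concrete, here is a minimal end-to-end sketch assembled from the docstring examples above. The two-dimensional Gaussian target, the step size, and the number of iterations are illustrative choices, not part of the commit:

```python
import jax
import jax.numpy as jnp

import blackjax

# Illustrative target: a standard normal in two dimensions.
def logprob_fn(x):
    return -0.5 * jnp.dot(x, x)

mala = blackjax.mala(logprob_fn, step_size=1e-1)

state = mala.init(jnp.zeros(2))
step = jax.jit(mala.step)  # kernels are not jit-compiled by default

rng_key = jax.random.PRNGKey(0)
for _ in range(1_000):
    rng_key, subkey = jax.random.split(rng_key)
    state, info = step(subkey, state)  # info carries acceptance diagnostics
```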

blackjax/mcmc/__init__.py (+2 −2)

```diff
@@ -1,3 +1,3 @@
-from . import hmc, nuts, rmh
+from . import hmc, mala, nuts, rmh
 
-__all__ = ["hmc", "nuts", "rmh"]
+__all__ = ["hmc", "mala", "nuts", "rmh"]
```

blackjax/mcmc/diffusion.py (+39)

```diff
@@ -0,0 +1,39 @@
+"""Solvers for Langevin diffusions."""
+from typing import NamedTuple
+
+import jax
+import jax.numpy as jnp
+
+from blackjax.types import PRNGKey, PyTree
+
+__all__ = ["overdamped_langevin"]
+
+
+class DiffusionState(NamedTuple):
+    position: PyTree
+    logprob_grad: PyTree
+
+
+def generate_gaussian_noise(rng_key: PRNGKey, position):
+    position_flat, unravel_fn = jax.flatten_util.ravel_pytree(position)
+    noise_flat = jax.random.normal(rng_key, shape=jnp.shape(position_flat))
+    return unravel_fn(noise_flat)
+
+
+def overdamped_langevin(logprob_grad_fn):
+    """Euler solver for overdamped Langevin diffusion."""
+
+    def one_step(rng_key, state: DiffusionState, step_size: float, batch: tuple = ()):
+        position, logprob_grad = state
+        noise = generate_gaussian_noise(rng_key, position)
+        position = jax.tree_util.tree_multimap(
+            lambda p, g, n: p + step_size * g + jnp.sqrt(2 * step_size) * n,
+            position,
+            logprob_grad,
+            noise,
+        )
+        logprob_grad = logprob_grad_fn(position, *batch)
+        return DiffusionState(position, logprob_grad)
+
+    return one_step
```
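For reference, `one_step` above implements one Euler-Maruyama step of the overdamped Langevin diffusion, x_{k+1} = x_k + ε ∇log π(x_k) + √(2ε) ξ_k with ξ_k ~ N(0, I). Below is a small sketch of using the solver on its own; the Gaussian target, step size, and iteration count are illustrative. Two version caveats worth hedging: `generate_gaussian_noise` accesses `jax.flatten_util` as an attribute of `jax`, which newer JAX versions require importing explicitly, and `tree_multimap` has since been folded into `jax.tree_util.tree_map`:

```python
import jax
import jax.flatten_util  # explicit import needed in newer JAX versions
import jax.numpy as jnp

from blackjax.mcmc.diffusion import DiffusionState, overdamped_langevin

# Illustrative target: standard normal, whose log-density gradient is -x.
logprob_grad_fn = jax.grad(lambda x: -0.5 * jnp.dot(x, x))

one_step = overdamped_langevin(logprob_grad_fn)

position = jnp.ones(3)
state = DiffusionState(position, logprob_grad_fn(position))

rng_key = jax.random.PRNGKey(0)
for _ in range(100):
    rng_key, subkey = jax.random.split(rng_key)
    state = one_step(subkey, state, step_size=1e-2)
```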

blackjax/mcmc/mala.py (+103)

```diff
@@ -0,0 +1,103 @@
+"""Public API for Metropolis Adjusted Langevin kernels."""
+from typing import Callable, NamedTuple, Tuple
+
+import jax
+import jax.numpy as jnp
+
+from blackjax.mcmc.diffusion import overdamped_langevin
+from blackjax.types import PRNGKey, PyTree
+
+__all__ = ["MALAState", "MALAInfo", "init", "kernel"]
+
+
+class MALAState(NamedTuple):
+    """State of the MALA algorithm.
+
+    The MALA algorithm takes one position of the chain and returns another
+    position. In order to make computations more efficient, we also store
+    the current gradient of the log-probability density.
+
+    """
+
+    position: PyTree
+    logprob_grad: PyTree
+
+
+class MALAInfo(NamedTuple):
+    """Additional information on the MALA transition.
+
+    This additional information can be used for debugging or computing
+    diagnostics.
+
+    acceptance_probability
+        The acceptance probability of the transition.
+    is_accepted
+        Whether the proposed position was accepted or the original position
+        was returned.
+
+    """
+
+    acceptance_probability: float
+    is_accepted: bool
+
+
+def init(position: PyTree, logprob_fn: Callable) -> MALAState:
+    grad_fn = jax.grad(logprob_fn)
+    logprob_grad = grad_fn(position)
+    return MALAState(position, logprob_grad)
+
+
+def kernel():
+    """Build a MALA kernel.
+
+    Returns
+    -------
+    A kernel that takes a PRNG key and the current state of the chain and
+    returns a new state of the chain along with information about the
+    transition.
+
+    """
+
+    def transition_probability(state, new_state, step_size):
+        """Log transition density (up to an additive constant) of moving
+        from `state` to `new_state`."""
+        theta = new_state.position - state.position - step_size * state.logprob_grad
+        return -0.25 * (1.0 / step_size) * jnp.dot(theta, theta)
+
+    def one_step(
+        rng_key: PRNGKey, state: MALAState, logprob_fn: Callable, step_size: float
+    ) -> Tuple[MALAState, MALAInfo]:
+        """Generate a new sample with the MALA kernel.
+
+        TODO expand the docstring.
+
+        """
+        grad_fn = jax.grad(logprob_fn)
+        integrator = overdamped_langevin(grad_fn)
+
+        key_integrator, key_rmh = jax.random.split(rng_key)
+
+        new_state = integrator(key_integrator, state, step_size)
+
+        delta = (
+            logprob_fn(new_state.position)
+            - logprob_fn(state.position)
+            + transition_probability(new_state, state, step_size)
+            - transition_probability(state, new_state, step_size)
+        )
+        delta = jnp.where(jnp.isnan(delta), -jnp.inf, delta)
+        p_accept = jnp.clip(jnp.exp(delta), a_max=1)
+
+        do_accept = jax.random.bernoulli(key_rmh, p_accept)
+
+        new_state = MALAState(*new_state)
+        info = MALAInfo(p_accept, do_accept)
+
+        return jax.lax.cond(
+            do_accept,
+            lambda _: (new_state, info),
+            lambda _: (state, info),
+            operand=None,
+        )
+
+    return one_step
```
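Because the Langevin proposal is asymmetric, `one_step` corrects it with a Metropolis-Hastings ratio built from the forward and reverse log-densities returned by `transition_probability`. Here is a sketch of driving the base kernel directly with `jax.lax.scan`; the target density, step size, and number of draws are illustrative:

```python
import jax
import jax.numpy as jnp

from blackjax.mcmc import mala

def logprob_fn(x):
    # Illustrative standard-normal target.
    return -0.5 * jnp.dot(x, x)

state = mala.init(jnp.zeros(2), logprob_fn)
one_step = mala.kernel()

def scan_step(state, rng_key):
    state, info = one_step(rng_key, state, logprob_fn, step_size=1e-1)
    return state, (state.position, info.acceptance_probability)

keys = jax.random.split(jax.random.PRNGKey(0), 1_000)
_, (positions, p_accept) = jax.lax.scan(scan_step, state, keys)
```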

docs/sampling.rst (+6)

```diff
@@ -8,6 +8,7 @@ Sampling
 
    hmc
    nuts
+   mala
    rmh
    tempered_smc
    adaptive_tempered_smc
@@ -50,6 +51,11 @@ HMC
 
 .. autoclass:: blackjax.hmc
 
+MALA
+~~~~
+
+.. autoclass:: blackjax.mala
+
 NUTS
 ~~~~
 
```

tests/test_sampling.py (+7)

```diff
@@ -121,6 +121,13 @@ def test_linear_regression(self, case, is_mass_matrix_diagonal):
         "num_sampling_steps": 20_000,
         "burnin": 5_000,
     },
+    {
+        "algorithm": blackjax.mala,
+        "initial_position": 1.0,
+        "parameters": {"step_size": 1e-1},
+        "num_sampling_steps": 20_000,
+        "burnin": 2_000,
+    },
 ]
 
 
```
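The new case runs through the suite's shared linear-regression harness; a standalone smoke test in the same spirit might look like the sketch below. The scalar Gaussian target, the inference loop, and the tolerance are illustrative assumptions, not code from the commit:

```python
import jax
import jax.numpy as jnp

import blackjax

def test_mala_univariate_normal():
    # Illustrative scalar standard-normal target.
    logprob_fn = lambda x: -0.5 * jnp.square(x)

    algo = blackjax.mala(logprob_fn, step_size=1e-1)
    state = algo.init(1.0)
    step = jax.jit(algo.step)

    rng_key = jax.random.PRNGKey(42)
    samples = []
    for _ in range(5_000):
        rng_key, subkey = jax.random.split(rng_key)
        state, _ = step(subkey, state)
        samples.append(state.position)

    # Loose check on the post-burnin mean of the chain.
    assert jnp.abs(jnp.mean(jnp.array(samples[2_000:]))) < 0.2
```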
