Commit 7be2822: Stan warmup (blackjax-devs#16)

1 parent: c6f75e9

21 files changed: +1567 / -347 lines

.github/workflows/test.yml (+1)

```diff
@@ -25,6 +25,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install .
+        pip install -r requirements-jax.txt
         less requirements-dev.txt | grep pytest | xargs -i -t pip install {}

     - name: Run the tests with pytest
```

Makefile (+4)

```diff
@@ -1,6 +1,10 @@
 PKG_VERSION = $(shell python setup.py --version)

 lint:
+	isort --profile black blackjax tests
+	black blackjax tests
+
+lint_check:
 	isort --profile black --check blackjax tests
 	flake8 blackjax tests --count --ignore=E501,E203,E731,W503 --show-source --statistics
 	black --check blackjax tests
```

README.md (+12 / -4)

````diff
@@ -8,7 +8,7 @@
 BlackJAX is a library of samplers for [JAX](https://github.com/google/jax) that
 works on CPU as well as GPU.

-It is *not* a probabilistic programming library. However it integrates really
+It is *not* a probabilistic programming library. However, it integrates really
 well with PPLs as long as they can provide a (potentially unnormalized)
 log-probability density function compatible with JAX.
@@ -17,6 +17,7 @@ log-probability density function compatible with JAX.
 BlackJAX should appeal to those who:
 - Have a logpdf and just need a sampler;
 - Need more than a general-purpose sampler;
+- Want to sample on GPU;
 - Want to build upon robust elementary blocks for their research;
 - Are building a PPL;
 - Want to learn how sampling algorithms work.
@@ -25,9 +26,16 @@ BlackJAX should appeal to those who:

 ### Installation

-BlackJAX is written in pure Python but depends on XLA via JAX. JAX installation
-is different depending on whether you want GPU support and your CUDA version,
-follow [these instructions](https://github.com/google/jax#installation) to install JAX with the relevant hardware acceleration support.
+BlackJAX is written in pure Python but depends on XLA via JAX. Since the JAX
+installation depends on your CUDA version, BlackJAX does not list JAX as a
+dependency. If you simply want to use JAX on CPU, install it with:
+
+```bash
+pip install jax jaxlib
+```
+
+Follow [these instructions](https://github.com/google/jax#installation) to
+install JAX with the relevant hardware acceleration support.

 Then install BlackJAX
````
blackjax/adaptation/__init__.py (whitespace-only changes)

blackjax/adaptation/mass_matrix.py (new file, +216)

```python
"""Algorithms to adapt the mass matrix used by algorithms in the Hamiltonian
Monte Carlo family to the geometry of the target distribution.

The Stan Manual [1]_ is a very good reference on the automatic tuning of
parameters used in Hamiltonian Monte Carlo.

.. [1] "HMC Algorithm Parameters", Stan Manual
   https://mc-stan.org/docs/2_20/reference-manual/hmc-algorithm-parameters.html
"""
from typing import Callable, NamedTuple, Tuple

import jax
import jax.flatten_util
import jax.numpy as jnp

__all__ = ["mass_matrix_adaptation", "welford_algorithm"]


class WelfordAlgorithmState(NamedTuple):
    """State carried through the Welford algorithm.

    mean
        The running sample mean.
    m2
        The running sum of squared differences from the mean. See the
        documentation of the `welford_algorithm` function for an explanation.
    sample_size
        The number of successive states the previous values have been computed on;
        also the current number of iterations of the algorithm.
    """

    mean: float
    m2: float
    sample_size: int


class MassMatrixAdaptationState(NamedTuple):
    """State carried through the mass matrix adaptation.

    inverse_mass_matrix
        The current value of the inverse mass matrix.
    wc_state
        The current state of the Welford algorithm.
    """

    inverse_mass_matrix: jnp.DeviceArray
    wc_state: WelfordAlgorithmState


def mass_matrix_adaptation(
    is_diagonal_matrix: bool = True,
) -> Tuple[Callable, Callable, Callable]:
    """Adapts the values in the mass matrix by computing the covariance
    between parameters.

    Parameters
    ----------
    is_diagonal_matrix
        When True the algorithm adapts and returns a diagonal mass matrix
        (default); otherwise it adapts and returns a dense mass matrix.

    Returns
    -------
    init
        A function that initializes the state of the mass matrix adaptation.
    update
        A function that updates the state of the mass matrix adaptation.
    final
        A function that computes the inverse mass matrix based on the current
        state.
    """
    wc_init, wc_update, wc_final = welford_algorithm(is_diagonal_matrix)

    def init(n_dims: int) -> MassMatrixAdaptationState:
        """Initialize the mass matrix adaptation.

        Parameters
        ----------
        n_dims
            The number of dimensions of the mass matrix, which corresponds to
            the number of dimensions of the chain position.
        """
        if is_diagonal_matrix:
            inverse_mass_matrix = jnp.ones(n_dims)
        else:
            inverse_mass_matrix = jnp.identity(n_dims)

        wc_state = wc_init(n_dims)

        return MassMatrixAdaptationState(inverse_mass_matrix, wc_state)

    def update(
        mm_state: MassMatrixAdaptationState, position: jnp.DeviceArray
    ) -> MassMatrixAdaptationState:
        """Update the algorithm's state.

        Parameters
        ----------
        mm_state:
            The current state of the mass matrix adaptation.
        position:
            The current position of the chain.
        """
        inverse_mass_matrix, wc_state = mm_state
        position, _ = jax.flatten_util.ravel_pytree(position)
        wc_state = wc_update(wc_state, position)
        return MassMatrixAdaptationState(inverse_mass_matrix, wc_state)

    def final(mm_state: MassMatrixAdaptationState) -> MassMatrixAdaptationState:
        """Final iteration of the mass matrix adaptation.

        In this step we compute the mass matrix from the covariance matrix computed
        by the Welford algorithm, and re-initialize the latter.
        """
        _, wc_state = mm_state
        covariance, count, mean = wc_final(wc_state)

        # Regularize the covariance estimate as Stan does: scale it down and
        # shrink it towards the (scaled) identity.
        scaled_covariance = (count / (count + 5)) * covariance
        shrinkage = 1e-3 * (5 / (count + 5))
        if is_diagonal_matrix:
            inverse_mass_matrix = scaled_covariance + shrinkage
        else:
            inverse_mass_matrix = scaled_covariance + shrinkage * jnp.identity(
                mean.shape[0]
            )

        ndims = jnp.shape(inverse_mass_matrix)[-1]
        new_mm_state = MassMatrixAdaptationState(inverse_mass_matrix, wc_init(ndims))

        return new_mm_state

    return init, update, final


def welford_algorithm(is_diagonal_matrix: bool) -> Tuple[Callable, Callable, Callable]:
    """Welford's online estimator of covariance.

    It is possible to compute the variance of a population of values in an
    on-line fashion to avoid storing intermediate results. The naive recurrence
    relations between the sample mean and variance at one step and the next are
    however not numerically stable.

    Welford's algorithm instead updates the sum of squared differences
    :math:`M_{2,n} = \\sum_{i=1}^n \\left(x_i - \\overline{x}_n\\right)^2`,
    where :math:`\\overline{x}_n` is the current sample mean, using the
    recurrence relations

    .. math::

        M_{2,n} = M_{2,n-1} + (x_n - \\overline{x}_{n-1})(x_n - \\overline{x}_n)

        \\sigma_n^2 = \\frac{M_{2,n}}{n}

    Parameters
    ----------
    is_diagonal_matrix
        When True the algorithm adapts and returns a diagonal mass matrix;
        otherwise it adapts and returns a dense mass matrix.

    Note
    ----
    It might seem pedantic to separate the Welford algorithm from the mass matrix
    adaptation, but this covariance estimator is used in other parts of the library.
    """

    def init(n_dims: int) -> WelfordAlgorithmState:
        """Initialize the covariance estimation.

        When the matrix is diagonal it is sufficient to work with an array that
        contains the diagonal values. Otherwise we need to work with the full matrix.

        Parameters
        ----------
        n_dims: int
            The number of dimensions of the problem, which corresponds to the size
            of the corresponding square mass matrix.
        """
        sample_size = 0
        mean = jnp.zeros(n_dims)
        if is_diagonal_matrix:
            m2 = jnp.zeros(n_dims)
        else:
            m2 = jnp.zeros((n_dims, n_dims))
        return WelfordAlgorithmState(mean, m2, sample_size)

    def update(
        wa_state: WelfordAlgorithmState, value: jnp.DeviceArray
    ) -> WelfordAlgorithmState:
        """Update the m2 value using the new sample.

        Parameters
        ----------
        wa_state:
            The current state of the Welford algorithm.
        value: jax.numpy.DeviceArray, shape (n_dims,)
            The new sample (typically the position of the chain) used to update m2.
        """
        mean, m2, sample_size = wa_state
        sample_size = sample_size + 1

        delta = value - mean
        mean = mean + delta / sample_size
        updated_delta = value - mean
        if is_diagonal_matrix:
            new_m2 = m2 + delta * updated_delta
        else:
            new_m2 = m2 + jnp.outer(updated_delta, delta)

        return WelfordAlgorithmState(mean, new_m2, sample_size)

    def final(
        wa_state: WelfordAlgorithmState,
    ) -> Tuple[jnp.DeviceArray, int, jnp.DeviceArray]:
        mean, m2, sample_size = wa_state
        covariance = m2 / (sample_size - 1)
        return covariance, sample_size, mean

    return init, update, final
```
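As a quick illustration (not part of the commit), here is a minimal sketch of how the three callables returned by `mass_matrix_adaptation` fit together during warmup; the `samples` array is synthetic and merely stands in for positions produced by a sampler:

```python
import jax
import jax.numpy as jnp

from blackjax.adaptation.mass_matrix import mass_matrix_adaptation

# Synthetic "chain positions" with marginal standard deviations (1.0, 2.0, 0.5);
# a real warmup would feed in positions produced by the HMC kernel.
rng_key = jax.random.PRNGKey(0)
samples = jax.random.normal(rng_key, (1000, 3)) * jnp.array([1.0, 2.0, 0.5])

mm_init, mm_update, mm_final = mass_matrix_adaptation(is_diagonal_matrix=True)

state = mm_init(3)
for position in samples:
    state = mm_update(state, position)
state = mm_final(state)

# The diagonal inverse mass matrix approximates the marginal variances,
# up to the Stan-style scaling and shrinkage applied in `final`.
print(state.inverse_mass_matrix)  # roughly [1.0, 4.0, 0.25]
```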

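And a minimal sanity check (again, not part of the commit) that the streaming Welford estimator agrees with the batch mean and covariance:

```python
import jax
import jax.numpy as jnp

from blackjax.adaptation.mass_matrix import welford_algorithm

wc_init, wc_update, wc_final = welford_algorithm(is_diagonal_matrix=False)

x = jax.random.normal(jax.random.PRNGKey(1), (500, 2))

state = wc_init(2)
for value in x:
    state = wc_update(state, value)
covariance, sample_size, mean = wc_final(state)

# Streaming estimates match the batch computations; `final` uses the same
# N - 1 normalization as jnp.cov.
assert sample_size == 500
assert jnp.allclose(mean, x.mean(axis=0), atol=1e-4)
assert jnp.allclose(covariance, jnp.cov(x, rowvar=False), atol=1e-4)
```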