"""Algorithms to adapt the mass matrix used by algorithms in the Hamiltonian
Monte Carlo family to the current geometry.

The Stan Manual [1]_ is a very good reference on the automatic tuning of
parameters used in Hamiltonian Monte Carlo.

.. [1] "HMC Algorithm Parameters", Stan Manual
   https://mc-stan.org/docs/2_20/reference-manual/hmc-algorithm-parameters.html
"""
from typing import Callable, NamedTuple, Tuple

import jax
import jax.flatten_util
import jax.numpy as jnp

__all__ = ["mass_matrix_adaptation", "welford_algorithm"]


class WelfordAlgorithmState(NamedTuple):
    """State carried through the Welford algorithm.

    mean
        The running sample mean.
    m2
        The running sum of squared differences from the current mean. See the
        documentation of the `welford_algorithm` function for an explanation.
    sample_size
        The number of successive states the previous values have been computed on;
        also the current number of iterations of the algorithm.
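
    Examples
    --------
    A sketch of a freshly initialized state for a 2-dimensional chain, as the
    `init` function returned by `welford_algorithm` builds it in the diagonal
    case::

        state = WelfordAlgorithmState(
            mean=jnp.zeros(2), m2=jnp.zeros(2), sample_size=0
        )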
    """

    mean: jnp.DeviceArray
    m2: jnp.DeviceArray
    sample_size: int


class MassMatrixAdaptationState(NamedTuple):
    """State carried through the mass matrix adaptation.

    inverse_mass_matrix
        The current value of the inverse mass matrix.
    wc_state
        The current state of the Welford algorithm.
    """

    inverse_mass_matrix: jnp.DeviceArray
    wc_state: WelfordAlgorithmState


def mass_matrix_adaptation(
    is_diagonal_matrix: bool = True,
) -> Tuple[Callable, Callable, Callable]:
    """Adapts the values in the mass matrix by computing the covariance
    between parameters.

    Parameters
    ----------
    is_diagonal_matrix
        When True the algorithm adapts and returns a diagonal mass matrix
        (default), otherwise it adapts and returns a dense mass matrix.

    Returns
    -------
    init
        A function that initializes the state of the mass matrix adaptation.
    update
        A function that updates the state of the mass matrix adaptation.
    final
        A function that computes the inverse mass matrix based on the current
        state.
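
    Examples
    --------
    A minimal usage sketch. The warmup loop and the `warmup_positions`
    iterable below are illustrative only, not this module's adaptation
    schedule::

        init, update, final = mass_matrix_adaptation(is_diagonal_matrix=True)
        mm_state = init(n_dims=2)
        for position in warmup_positions:  # hypothetical stream of warmup samples
            mm_state = update(mm_state, position)
        mm_state = final(mm_state)
        inverse_mass_matrix = mm_state.inverse_mass_matrix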
    """
    wc_init, wc_update, wc_final = welford_algorithm(is_diagonal_matrix)

    def init(n_dims: int) -> MassMatrixAdaptationState:
        """Initialize the mass matrix adaptation.

        Parameters
        ----------
        n_dims
            The number of dimensions of the mass matrix, which corresponds to
            the number of dimensions of the chain position.
        """
        if is_diagonal_matrix:
            inverse_mass_matrix = jnp.ones(n_dims)
        else:
            inverse_mass_matrix = jnp.identity(n_dims)

        wc_state = wc_init(n_dims)

        return MassMatrixAdaptationState(inverse_mass_matrix, wc_state)

    def update(
        mm_state: MassMatrixAdaptationState, position: jnp.DeviceArray
    ) -> MassMatrixAdaptationState:
        """Update the algorithm's state.

        Parameters
        ----------
        mm_state
            The current state of the mass matrix adaptation.
        position
            The current position of the chain.
        """
        inverse_mass_matrix, wc_state = mm_state
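        # Flatten the position, which may be a pytree, into a single 1-D array
        # so the covariance estimator sees one vector per sample.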
        position, _ = jax.flatten_util.ravel_pytree(position)
        wc_state = wc_update(wc_state, position)
        return MassMatrixAdaptationState(inverse_mass_matrix, wc_state)

    def final(mm_state: MassMatrixAdaptationState) -> MassMatrixAdaptationState:
        """Final iteration of the mass matrix adaptation.

        In this step we compute the mass matrix from the covariance matrix computed
        by the Welford algorithm, and re-initialize the latter.
        """
        _, wc_state = mm_state
        covariance, count, mean = wc_final(wc_state)

        # Regularize the covariance matrix, following Stan's warmup.
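        # The estimate is shrunk towards the identity:
        #     Sigma_reg = (n / (n + 5)) * Sigma + 1e-3 * (5 / (n + 5)) * I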
        scaled_covariance = (count / (count + 5)) * covariance
        shrinkage = 1e-3 * (5 / (count + 5))
        if is_diagonal_matrix:
            inverse_mass_matrix = scaled_covariance + shrinkage
        else:
            inverse_mass_matrix = scaled_covariance + shrinkage * jnp.identity(
                mean.shape[0]
            )

        ndims = jnp.shape(inverse_mass_matrix)[-1]
        new_mm_state = MassMatrixAdaptationState(inverse_mass_matrix, wc_init(ndims))

        return new_mm_state

    return init, update, final


def welford_algorithm(is_diagonal_matrix: bool) -> Tuple[Callable, Callable, Callable]:
    """Welford's online estimator of covariance.

    It is possible to compute the variance of a population of values in an
    online fashion to avoid storing intermediate results. The naive recurrence
    relations between the sample mean and variance at one step and the next
    are, however, not numerically stable.

    Welford's algorithm instead uses the sum of squared differences

    .. math:: M_{2,n} = \\sum_{i=1}^n \\left(x_i - \\overline{x}_n\\right)^2,

    where :math:`x_n` is the :math:`n`-th sample and :math:`\\overline{x}_n`
    the current mean, together with the recurrence relationships

    .. math::
        M_{2,n} = M_{2,n-1} + (x_n - \\overline{x}_{n-1})(x_n - \\overline{x}_n)

        \\sigma_n^2 = \\frac{M_{2,n}}{n}

    Parameters
    ----------
    is_diagonal_matrix
        When True the algorithm adapts and returns a diagonal mass matrix,
        otherwise it adapts and returns a dense mass matrix.

    Note
    ----
    It might seem pedantic to separate the Welford algorithm from the mass
    matrix adaptation, but this covariance estimator is used in other parts of
    the library.
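
    Examples
    --------
    A small sketch of the estimator on a handful of two-dimensional samples;
    the numbers are made up for illustration::

        init, update, final = welford_algorithm(is_diagonal_matrix=True)
        state = init(n_dims=2)
        for x in ([1.0, 2.0], [2.0, 4.0], [3.0, 6.0]):
            state = update(state, jnp.asarray(x))
        covariance, sample_size, mean = final(state)
        # mean is [2., 4.] and the diagonal covariance estimate is [1., 4.]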
    """

    def init(n_dims: int) -> WelfordAlgorithmState:
        """Initialize the covariance estimation.

        When the matrix is diagonal it is sufficient to work with an array that
        contains the diagonal values. Otherwise we need to work with the matrix
        in full.

        Parameters
        ----------
        n_dims: int
            The number of dimensions of the problem, which corresponds to the
            size of the corresponding square mass matrix.
        """
        sample_size = 0
        mean = jnp.zeros(n_dims)
        if is_diagonal_matrix:
            m2 = jnp.zeros(n_dims)
        else:
            m2 = jnp.zeros((n_dims, n_dims))
        return WelfordAlgorithmState(mean, m2, sample_size)

    def update(
        wa_state: WelfordAlgorithmState, value: jnp.DeviceArray
    ) -> WelfordAlgorithmState:
        """Update the M2 matrix using the new value.

        Parameters
        ----------
        wa_state
            The current state of the Welford algorithm.
        value: jax.numpy.DeviceArray, shape (n_dims,)
            The new sample (typically the position of the chain) used to
            update m2.
        """
        mean, m2, sample_size = wa_state
        sample_size = sample_size + 1

        delta = value - mean
        mean = mean + delta / sample_size
        updated_delta = value - mean
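        # In the diagonal case only per-dimension squared deviations are
        # accumulated; in the dense case the full outer product is, which
        # yields the complete covariance matrix.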
        if is_diagonal_matrix:
            new_m2 = m2 + delta * updated_delta
        else:
            new_m2 = m2 + jnp.outer(updated_delta, delta)

        return WelfordAlgorithmState(mean, new_m2, sample_size)

    def final(
        wa_state: WelfordAlgorithmState,
    ) -> Tuple[jnp.DeviceArray, int, jnp.DeviceArray]:
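        """Return the unbiased covariance estimate, the number of samples it
        was computed on, and the running mean.
        """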
        mean, m2, sample_size = wa_state
        covariance = m2 / (sample_size - 1)
        return covariance, sample_size, mean

    return init, update, final