diff --git a/docs/notebooks/efficient_sampling.py b/docs/notebooks/efficient_sampling.py index d59b36cd..79420344 100644 --- a/docs/notebooks/efficient_sampling.py +++ b/docs/notebooks/efficient_sampling.py @@ -87,7 +87,7 @@ kernel_with_features, inducing_variable, num_data, - whiten=False, + whiten=True, num_latent_gps=1, mean_function=gpflow.mean_functions.Zero(), ) diff --git a/gpflux/experiment_support/plotting.py b/gpflux/experiment_support/plotting.py index accfc44d..bf656b5e 100644 --- a/gpflux/experiment_support/plotting.py +++ b/gpflux/experiment_support/plotting.py @@ -15,8 +15,8 @@ # from typing import List, Optional, Sequence -import numpy as np import matplotlib.pyplot as plt +import numpy as np from gpflow.base import TensorType diff --git a/gpflux/sampling/sample.py b/gpflux/sampling/sample.py index 6cd07381..26eaa333 100644 --- a/gpflux/sampling/sample.py +++ b/gpflux/sampling/sample.py @@ -95,7 +95,6 @@ def _efficient_sample_conditional_gaussian( Most costly implementation for obtaining a consistent GP sample. However, this method can be used for any kernel. """ - assert not whiten, "Currently only whiten=False is supported" class SampleConditional(Sample): # N_old is 0 at first, we then start keeping track of past evaluation points. @@ -155,13 +154,7 @@ def _efficient_sample_matheron_rule( :param q_mu: A tensor with the shape ``[M, P]``. :param q_sqrt: A tensor with the shape ``[P, M, M]``. :param whiten: Determines the parameterisation of the inducing variables. - If True, ``p(u) = N(0, I)``, otherwise ``p(u) = N(0, Kuu)``. - .. note:: Currenly, only *whiten* equals ``False`` is supported. """ - # TODO(VD): allow for both whiten=True and False, currently only support False. - # Remember u = Luu v, with Kuu = Luu Luu^T and p(v) = N(0, I) - # so that p(u) = N(0, Luu Luu^T) = N(0, Kuu). - assert not whiten, "Currently only whiten=False is supported" L = tf.shape(kernel.feature_coefficients)[0] # num eigenfunctions # noqa: F841 prior_weights = tf.sqrt(kernel.feature_coefficients) * tf.random.normal( @@ -173,9 +166,14 @@ def _efficient_sample_matheron_rule( q_sqrt, tf.random.normal((P, M, 1), dtype=default_float()), # [P, M, M] # [P, M, 1] ) # [P, M, 1] - u_sample = q_mu + tf.linalg.matrix_transpose(u_sample_noise[..., 0]) # [M, P] Kmm = Kuu(inducing_variable, kernel, jitter=default_jitter()) # [M, M] tf.debugging.assert_equal(tf.shape(Kmm), [M, M]) + u_sample = q_mu + tf.linalg.matrix_transpose(u_sample_noise[..., 0]) # [M, P] + + if whiten: + Luu = tf.linalg.cholesky(Kmm) # [M, M] + u_sample = tf.matmul(Luu, u_sample) # [M, P] + phi_Z = kernel.feature_functions(inducing_variable.Z) # [M, L] weight_space_prior_Z = phi_Z @ prior_weights # [M, 1] diff = u_sample - weight_space_prior_Z # [M, P] -- using implicit broadcasting diff --git a/tests/gpflux/sampling/test_sample.py b/tests/gpflux/sampling/test_sample.py index e92379fa..50096d01 100644 --- a/tests/gpflux/sampling/test_sample.py +++ b/tests/gpflux/sampling/test_sample.py @@ -36,6 +36,11 @@ def _inducing_variable_fixture(): return gpflow.inducing_variables.InducingPoints(Z) +@pytest.fixture(name="whiten", params=[True, False]) +def _whiten_fixture(request): + return request.param + + def _get_qmu_qsqrt(kernel, inducing_variable): """Returns q_mu and q_sqrt for a kernel and inducing_variable""" Z = inducing_variable.Z.numpy() @@ -45,7 +50,7 @@ def _get_qmu_qsqrt(kernel, inducing_variable): return q_mu, q_sqrt -def test_conditional_sample(kernel, inducing_variable): +def test_conditional_sample(kernel, inducing_variable, whiten): """Smoke and consistency test for efficient sampling using MVN Conditioning""" q_mu, q_sqrt = _get_qmu_qsqrt(kernel, inducing_variable) @@ -54,7 +59,7 @@ def test_conditional_sample(kernel, inducing_variable): kernel, q_mu, q_sqrt=1e-3 * tf.convert_to_tensor(q_sqrt[np.newaxis]), - whiten=False, + whiten=whiten, ) X = np.linspace(-1, 1, 100).reshape(-1, 1) @@ -70,7 +75,7 @@ def test_conditional_sample(kernel, inducing_variable): ) -def test_wilson_efficient_sample(kernel, inducing_variable): +def test_wilson_efficient_sample(kernel, inducing_variable, whiten): """Smoke and consistency test for efficient sampling using Wilson""" eigenfunctions = RandomFourierFeatures(kernel, 100, dtype=default_float()) eigenvalues = np.ones((100, 1), dtype=default_float()) @@ -83,7 +88,7 @@ def test_wilson_efficient_sample(kernel, inducing_variable): kernel2, q_mu, q_sqrt=1e-3 * tf.convert_to_tensor(q_sqrt[np.newaxis]), - whiten=False, + whiten=whiten, ) X = np.linspace(-1, 0, 100).reshape(-1, 1)