Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for whitening in sampling #26

Merged
merged 1 commit into from
Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/notebooks/efficient_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
kernel_with_features,
inducing_variable,
num_data,
whiten=False,
whiten=True,
num_latent_gps=1,
mean_function=gpflow.mean_functions.Zero(),
)
Expand Down
2 changes: 1 addition & 1 deletion gpflux/experiment_support/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
#
from typing import List, Optional, Sequence

import numpy as np
import matplotlib.pyplot as plt
import numpy as np

from gpflow.base import TensorType

Expand Down
14 changes: 6 additions & 8 deletions gpflux/sampling/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def _efficient_sample_conditional_gaussian(
Most costly implementation for obtaining a consistent GP sample.
However, this method can be used for any kernel.
"""
assert not whiten, "Currently only whiten=False is supported"

class SampleConditional(Sample):
# N_old is 0 at first, we then start keeping track of past evaluation points.
Expand Down Expand Up @@ -155,13 +154,7 @@ def _efficient_sample_matheron_rule(
:param q_mu: A tensor with the shape ``[M, P]``.
:param q_sqrt: A tensor with the shape ``[P, M, M]``.
:param whiten: Determines the parameterisation of the inducing variables.
If True, ``p(u) = N(0, I)``, otherwise ``p(u) = N(0, Kuu)``.
.. note:: Currently, only *whiten* equals ``False`` is supported.
"""
# TODO(VD): allow for both whiten=True and False, currently only support False.
# Remember u = Luu v, with Kuu = Luu Luu^T and p(v) = N(0, I)
# so that p(u) = N(0, Luu Luu^T) = N(0, Kuu).
assert not whiten, "Currently only whiten=False is supported"
L = tf.shape(kernel.feature_coefficients)[0] # num eigenfunctions # noqa: F841

prior_weights = tf.sqrt(kernel.feature_coefficients) * tf.random.normal(
Expand All @@ -173,9 +166,14 @@ def _efficient_sample_matheron_rule(
q_sqrt,
tf.random.normal((P, M, 1), dtype=default_float()), # [P, M, M] # [P, M, 1]
) # [P, M, 1]
u_sample = q_mu + tf.linalg.matrix_transpose(u_sample_noise[..., 0]) # [M, P]
Kmm = Kuu(inducing_variable, kernel, jitter=default_jitter()) # [M, M]
tf.debugging.assert_equal(tf.shape(Kmm), [M, M])
u_sample = q_mu + tf.linalg.matrix_transpose(u_sample_noise[..., 0]) # [M, P]

if whiten:
Luu = tf.linalg.cholesky(Kmm) # [M, M]
u_sample = tf.matmul(Luu, u_sample) # [M, P]

phi_Z = kernel.feature_functions(inducing_variable.Z) # [M, L]
weight_space_prior_Z = phi_Z @ prior_weights # [M, 1]
diff = u_sample - weight_space_prior_Z # [M, P] -- using implicit broadcasting
Expand Down
13 changes: 9 additions & 4 deletions tests/gpflux/sampling/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def _inducing_variable_fixture():
return gpflow.inducing_variables.InducingPoints(Z)


@pytest.fixture(name="whiten", params=[True, False])
def _whiten_fixture(request):
    """Parametrized fixture exposing ``whiten`` as both True and False.

    Each test that takes a ``whiten`` argument runs twice, once per
    parameterisation of the inducing variables (whitened and unwhitened).
    """
    return request.param


def _get_qmu_qsqrt(kernel, inducing_variable):
"""Returns q_mu and q_sqrt for a kernel and inducing_variable"""
Z = inducing_variable.Z.numpy()
Expand All @@ -45,7 +50,7 @@ def _get_qmu_qsqrt(kernel, inducing_variable):
return q_mu, q_sqrt


def test_conditional_sample(kernel, inducing_variable):
def test_conditional_sample(kernel, inducing_variable, whiten):
"""Smoke and consistency test for efficient sampling using MVN Conditioning"""
q_mu, q_sqrt = _get_qmu_qsqrt(kernel, inducing_variable)

Expand All @@ -54,7 +59,7 @@ def test_conditional_sample(kernel, inducing_variable):
kernel,
q_mu,
q_sqrt=1e-3 * tf.convert_to_tensor(q_sqrt[np.newaxis]),
whiten=False,
whiten=whiten,
)

X = np.linspace(-1, 1, 100).reshape(-1, 1)
Expand All @@ -70,7 +75,7 @@ def test_conditional_sample(kernel, inducing_variable):
)


def test_wilson_efficient_sample(kernel, inducing_variable):
def test_wilson_efficient_sample(kernel, inducing_variable, whiten):
"""Smoke and consistency test for efficient sampling using Wilson"""
eigenfunctions = RandomFourierFeatures(kernel, 100, dtype=default_float())
eigenvalues = np.ones((100, 1), dtype=default_float())
Expand All @@ -83,7 +88,7 @@ def test_wilson_efficient_sample(kernel, inducing_variable):
kernel2,
q_mu,
q_sqrt=1e-3 * tf.convert_to_tensor(q_sqrt[np.newaxis]),
whiten=False,
whiten=whiten,
)

X = np.linspace(-1, 0, 100).reshape(-1, 1)
Expand Down