Quadrature Fourier features #56

Merged 23 commits · Oct 22, 2021 · Showing changes from 19 commits
8 changes: 4 additions & 4 deletions docs/notebooks/efficient_posterior_sampling.py
@@ -62,7 +62,7 @@

that combines both feature approximations and exact kernel evaluations from the Matheron function and weight space approximation formulas above.

-The subsequent experiments demonstrate the qualitative efficiency of the hybrid rule when compared to the vanilla Matheron weight space approximation, in terms of the Wasserstein distance to the exact posterior GP. To conduct these experiments, the required classes in `gpflux` are `RandomFourierFeatures`, to approximate a stationary kernel with finitely many random Fourier features $\phi_d(\cdot)$ according to Bochner's theorem and following Rahimi and Recht "Random features for large-scale kernel machines" (NeurIPS, 2007), and `KernelWithFeatureDecomposition`, to approximate a kernel with a specified set of feature functions.
+The subsequent experiments demonstrate the qualitative efficiency of the hybrid rule when compared to the vanilla Matheron weight space approximation, in terms of the Wasserstein distance to the exact posterior GP. To conduct these experiments, the required classes in `gpflux` are `RandomFourierFeaturesCosine`, to approximate a stationary kernel with finitely many random Fourier features $\phi_d(\cdot)$ according to Bochner's theorem and following Rahimi and Recht "Random features for large-scale kernel machines" (NeurIPS, 2007), and `KernelWithFeatureDecomposition`, to approximate a kernel with a specified set of feature functions.
"""

# %%
@@ -79,7 +79,7 @@
from gpflow.kernels import RBF, Matern52
from gpflow.models import GPR

-from gpflux.layers.basis_functions.random_fourier_features import RandomFourierFeatures
+from gpflux.layers.basis_functions.fourier_features import RandomFourierFeaturesCosine
from gpflux.sampling.kernel_with_feature_decomposition import KernelWithFeatureDecomposition

# %% [markdown]
@@ -253,9 +253,9 @@ def conduct_experiment(num_input_dimensions, num_train_samples, num_features):
    exact_kernel = kernel_class(lengthscales=lengthscale)

    # weight space approximated kernel
-    feature_functions = RandomFourierFeatures(
+    feature_functions = RandomFourierFeaturesCosine(
        kernel=kernel_class(lengthscales=lengthscale),
-        output_dim=num_features,
+        n_components=num_features,
        dtype=default_float(),
    )
    feature_coefficients = np.ones((num_features, 1), dtype=default_float())
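Taken together, the updated notebook builds the weight-space approximated kernel as follows (a minimal sketch outside the diff, assuming an RBF kernel with the notebook's lengthscale of 0.1; `num_features` and `approximate_kernel` are illustrative names):

```python
import numpy as np
from gpflow.config import default_float
from gpflow.kernels import RBF

from gpflux.layers.basis_functions.fourier_features import RandomFourierFeaturesCosine
from gpflux.sampling.kernel_with_feature_decomposition import KernelWithFeatureDecomposition

num_features = 1000  # illustrative value

# Note the renamed constructor argument: `n_components` replaces the old `output_dim`.
feature_functions = RandomFourierFeaturesCosine(
    kernel=RBF(lengthscales=0.1),
    n_components=num_features,
    dtype=default_float(),
)
# Unit coefficients, since RFF approximates the kernel as Phi @ Phi^T.
feature_coefficients = np.ones((num_features, 1), dtype=default_float())
approximate_kernel = KernelWithFeatureDecomposition(
    kernel=None, feature_functions=feature_functions, feature_coefficients=feature_coefficients
)
```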
4 changes: 2 additions & 2 deletions docs/notebooks/efficient_sampling.py
@@ -36,7 +36,7 @@

from gpflow.config import default_float

-from gpflux.layers.basis_functions.random_fourier_features import RandomFourierFeatures
+from gpflux.layers.basis_functions.fourier_features import RandomFourierFeaturesCosine
from gpflux.sampling import KernelWithFeatureDecomposition
from gpflux.models.deep_gp import sample_dgp

@@ -71,7 +71,7 @@
gpflow.utilities.set_trainable(inducing_variable, False)

num_rff = 1000
-eigenfunctions = RandomFourierFeatures(kernel, num_rff, dtype=default_float())
+eigenfunctions = RandomFourierFeaturesCosine(kernel, num_rff, dtype=default_float())
eigenvalues = np.ones((num_rff, 1), dtype=default_float())
kernel_with_features = KernelWithFeatureDecomposition(kernel, eigenfunctions, eigenvalues)

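As a quick sanity check of the decomposition used above, k(x, x') ≈ Σ_i λ_i φ_i(x) φ_i(x') with unit "eigenvalues" here, one could evaluate the weighted outer product of the features (an illustrative sketch reusing `kernel`, `eigenfunctions` and `eigenvalues` from the snippet above):

```python
import numpy as np
import tensorflow as tf

X = np.linspace(-1.0, 1.0, 5).reshape(-1, 1)
Phi = eigenfunctions(X)  # [5, num_rff]
# Weighted outer product of the features reconstructs the kernel matrix.
K_approx = tf.matmul(Phi * eigenvalues.T, Phi, transpose_b=True)
K_exact = kernel(X)
print(np.max(np.abs(K_approx.numpy() - K_exact.numpy())))  # small; RFF error decays as O(M^-1/2)
```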
12 changes: 6 additions & 6 deletions docs/notebooks/weight_space_approximation.py
@@ -42,7 +42,7 @@

The advantage of expressing a Gaussian process in weight space is that functions are represented as weight vectors $\textbf{w}$ (rather than actual functions $f(\cdot)$) from which samples can be obtained a priori without knowing where the function should be evaluated. When expressing a Gaussian process in function space view the latter is not possible, i.e. a function $f(\cdot)$ cannot be sampled without knowing where to evaluate the function, namely at $\{X_{n^\star}^\star\}_{n^\star=1,...,N^\star}$. Weight space approximated Gaussian processes therefore hold the potential to sample efficiently from Gaussian process posteriors, which is desirable in vanilla supervised learning but also in domains such as Bayesian optimisation or model-based reinforcement learning.

-In the following example, we compare a weight space approximated GPR model (WSA model) with both a proper GPR model and a sparse variational Gaussian Process model (SVGP). GPR models and SVGP models are implemented in `gpflow`, but the two necessary ingredients for building the WSA model are part of `gpflux`: these are random Fourier feature functions via the `RandomFourierFeatures` class, and approximate kernels based on Bochner's theorem (or any other theorem that approximates a kernel with a finite number of feature functions, e.g. Mercer) via the `KernelWithFeatureDecomposition` class.
+In the following example, we compare a weight space approximated GPR model (WSA model) with both a proper GPR model and a sparse variational Gaussian Process model (SVGP). GPR models and SVGP models are implemented in `gpflow`, but the two necessary ingredients for building the WSA model are part of `gpflux`: these are random Fourier feature functions via the `RandomFourierFeaturesCosine` class, and approximate kernels based on Bochner's theorem (or any other theorem that approximates a kernel with a finite number of feature functions, e.g. Mercer) via the `KernelWithFeatureDecomposition` class.
"""

# %%
@@ -61,7 +61,7 @@
from gpflow.likelihoods import Gaussian
from gpflow.inducing_variables import InducingPoints

-from gpflux.layers.basis_functions.random_fourier_features import RandomFourierFeatures
+from gpflux.layers.basis_functions.fourier_features import RandomFourierFeaturesCosine
from gpflux.sampling.kernel_with_feature_decomposition import KernelWithFeatureDecomposition

# %% [markdown]
@@ -77,7 +77,7 @@
# experiment parameters that are the same for both sets of experiments
X_interval = [0.14, 0.5] # interval where training points live
lengthscale = 0.1 # lengthscale for the kernel (which is not learned in all experiments, the kernel variance is 1)
-number_of_basis_functions = 2000  # number of basis functions for weight-space approximated kernels
+number_of_features = 2000  # number of basis functions for weight-space approximated kernels
noise_variance = 1e-3 # noise variance of the likelihood (which is not learned in all experiments)
number_of_test_samples = 1024 # number of evaluation points for prediction
number_of_function_samples = (
@@ -264,12 +264,12 @@ def optimize_model_with_scipy(model):
)

# create exact GPR model with weight-space approximated kernel (WSA model)
-feature_functions = RandomFourierFeatures(
+feature_functions = RandomFourierFeaturesCosine(
    kernel=kernel_class(lengthscales=lengthscale),
-    output_dim=number_of_basis_functions,
+    n_components=number_of_features,
    dtype=default_float(),
)
-feature_coefficients = np.ones((number_of_basis_functions, 1), dtype=default_float())
+feature_coefficients = np.ones((number_of_features, 1), dtype=default_float())
kernel = KernelWithFeatureDecomposition(
    kernel=None, feature_functions=feature_functions, feature_coefficients=feature_coefficients
)
5 changes: 1 addition & 4 deletions gpflux/layers/basis_functions/__init__.py
@@ -14,8 +14,5 @@
# limitations under the License.
#
"""
-A kernel's features for efficient sampling, used by
-:class:`gpflux.sampling.KernelWithFeatureDecomposition`
+Basis functions.
"""

-from gpflux.layers.basis_functions.random_fourier_features import RandomFourierFeatures
27 changes: 27 additions & 0 deletions gpflux/layers/basis_functions/fourier_features/__init__.py
@@ -0,0 +1,27 @@
#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
A kernel's features for efficient sampling, used by
:class:`gpflux.sampling.KernelWithFeatureDecomposition`
"""

from gpflux.layers.basis_functions.fourier_features.quadrature import QuadratureFourierFeatures
from gpflux.layers.basis_functions.fourier_features.random import (
    RandomFourierFeatures,
    RandomFourierFeaturesCosine,
)

__all__ = ["QuadratureFourierFeatures", "RandomFourierFeatures", "RandomFourierFeaturesCosine"]
101 changes: 101 additions & 0 deletions gpflux/layers/basis_functions/fourier_features/base.py
@@ -0,0 +1,101 @@
#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" Shared functionality for stationary kernel basis functions. """

from abc import ABC, abstractmethod
from typing import Mapping

import tensorflow as tf

import gpflow
from gpflow.base import TensorType

from gpflux.types import ShapeType


class FourierFeaturesBase(ABC, tf.keras.layers.Layer):
    """Base class for layers that evaluate feature decompositions of stationary kernels."""

    def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
        """
        :param kernel: kernel to approximate using a set of random features.
        :param n_components: total number of basis functions used to approximate
            the kernel.
        """
        super(FourierFeaturesBase, self).__init__(**kwargs)
        self.kernel = kernel
        # M: number of features (Monte Carlo samples for RFF,
        # quadrature nodes per input dimension for QFF)
        self.n_components = n_components
        if kwargs.get("input_dim", None):
            self._input_dim = kwargs["input_dim"]
            self.build(tf.TensorShape([self._input_dim]))
        else:
            self._input_dim = None

    def call(self, inputs: TensorType) -> tf.Tensor:
        """
        Evaluate the basis functions at ``inputs``.

        :param inputs: The evaluation points, a tensor with the shape ``[N, D]``.

        :return: A tensor with the shape ``[N, M]``.
        """
        X = tf.divide(inputs, self.kernel.lengthscales)  # [N, D]
        const = self._compute_constant()
        bases = self._compute_bases(X)
        output = const * bases
        tf.ensure_shape(output, self.compute_output_shape(inputs.shape))
        return output

    def compute_output_shape(self, input_shape: ShapeType) -> tf.TensorShape:
        """
        Computes the output shape of the layer.
        See `tf.keras.layers.Layer.compute_output_shape()
        <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#compute_output_shape>`_.
        """
        # TODO: Keras docs say "If the layer has not been built, this method
        # will call `build` on the layer." -- do we need to do so?
        tensor_shape = tf.TensorShape(input_shape).with_rank(2)
        output_dim = self._compute_output_dim(input_shape)
        return tensor_shape[:-1].concatenate(output_dim)

    def get_config(self) -> Mapping:
        """
        Returns the config of the layer.
        See `tf.keras.layers.Layer.get_config()
        <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#get_config>`_.
        """
        config = super(FourierFeaturesBase, self).get_config()
        config.update(
            {"kernel": self.kernel, "n_components": self.n_components, "input_dim": self._input_dim}
        )

        return config

    @abstractmethod
    def _compute_output_dim(self, input_shape: ShapeType) -> int:
        pass

    @abstractmethod
    def _compute_constant(self) -> tf.Tensor:
        """
        Compute the normalizing constant for the basis functions.
        """
        pass

    @abstractmethod
    def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
        """
        Compute the basis functions at the (lengthscale-scaled) inputs.
        """
        pass
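For orientation, here is a hypothetical minimal subclass (not part of the PR) illustrating the contract `FourierFeaturesBase` imposes: `call()` is fixed as `constant * bases`, and a concrete class only supplies the three `_compute_*` pieces. All names below are illustrative.

```python
import numpy as np
import tensorflow as tf
from gpflow.base import TensorType

from gpflux.layers.basis_functions.fourier_features.base import FourierFeaturesBase
from gpflux.types import ShapeType


class ToyCosineFeatures(FourierFeaturesBase):
    def build(self, input_shape: ShapeType) -> None:
        input_dim = input_shape[-1]
        # Fixed, evenly spaced frequencies -- purely illustrative.
        frequencies = np.linspace(0.1, 1.0, self.n_components * input_dim)
        self.W = tf.Variable(
            frequencies.reshape(input_dim, self.n_components), trainable=False
        )  # [D, M]
        super().build(input_shape)

    def _compute_output_dim(self, input_shape: ShapeType) -> int:
        return self.n_components

    def _compute_constant(self) -> tf.Tensor:
        # The usual RFF-style normalizer, sqrt(2 * variance / M).
        return tf.sqrt(2.0 * self.kernel.variance / self.n_components)

    def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
        return tf.cos(tf.matmul(inputs, self.W))  # [N, M]
```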
71 changes: 71 additions & 0 deletions gpflux/layers/basis_functions/fourier_features/quadrature.py
@@ -0,0 +1,71 @@
#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" A kernel's features and coefficients using quadrature Fourier features (QFF). """

import warnings
from typing import Mapping

import tensorflow as tf

import gpflow
from gpflow.base import TensorType
from gpflow.quadrature.gauss_hermite import ndgh_points_and_weights

from gpflux.layers.basis_functions.fourier_features.base import FourierFeaturesBase
from gpflux.layers.basis_functions.fourier_features.utils import (
    QFF_SUPPORTED_KERNELS,
    _bases_concat,
)
from gpflux.types import ShapeType


class QuadratureFourierFeatures(FourierFeaturesBase):
    def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
        assert isinstance(kernel, QFF_SUPPORTED_KERNELS), "Unsupported Kernel"
        if tf.reduce_any(tf.less(kernel.lengthscales, 1e-1)):
            # Short lengthscales make the integrand oscillate quickly relative
            # to the Gauss-Hermite weight, so a fixed number of quadrature
            # nodes covers the spectrum poorly.
            warnings.warn(
                "Quadrature Fourier feature approximation of kernels "
                "with small lengthscales can lead to unexpected behaviors!"
            )
        super(QuadratureFourierFeatures, self).__init__(kernel, n_components, **kwargs)

    def build(self, input_shape: ShapeType) -> None:
        """
        Creates the variables of the layer.
        See `tf.keras.layers.Layer.build()
        <https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#build>`_.
        """
        input_dim = input_shape[-1]
        abscissa_value, omegas_value = ndgh_points_and_weights(
            dim=input_dim, n_gh=self.n_components
        )
        omegas_value = tf.squeeze(omegas_value, axis=-1)

        # Quadrature node points
        self.abscissa = tf.Variable(initial_value=abscissa_value, trainable=False)  # (M^D, D)
        # Gauss-Hermite weights
        self.factors = tf.Variable(initial_value=omegas_value, trainable=False)  # (M^D,)
        super(QuadratureFourierFeatures, self).build(input_shape)

    def _compute_output_dim(self, input_shape: ShapeType) -> int:
        # M nodes per input dimension, each contributing a cosine and a sine
        # feature, hence 2 * M**D outputs.
        input_dim = input_shape[-1]
        return 2 * self.n_components ** input_dim

    def _compute_constant(self) -> tf.Tensor:
        # One sqrt(variance * quadrature weight) factor per node, tiled across
        # the cosine and sine blocks.
        return tf.tile(tf.sqrt(self.kernel.variance * self.factors), multiples=[2])

    def _compute_bases(self, inputs: TensorType) -> tf.Tensor:
        return _bases_concat(inputs, self.abscissa)
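A quick usage sketch (illustrative, not part of the PR): in D dimensions the quadrature grid has `n_components**D` nodes, and each node contributes a cosine and a sine feature, so the output dimension is `2 * n_components**D`.

```python
import numpy as np
import gpflow
from gpflow.config import default_float

from gpflux.layers.basis_functions.fourier_features import QuadratureFourierFeatures

# SquaredExponential is QFF-supported; lengthscale kept above the 1e-1 warning threshold.
kernel = gpflow.kernels.SquaredExponential(lengthscales=0.5)
features = QuadratureFourierFeatures(kernel, n_components=16, dtype=default_float())

X = np.linspace(0.0, 1.0, 7).reshape(-1, 1)  # D = 1
print(features(X).shape)  # (7, 32), i.e. (7, 2 * 16**1)
```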