Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure basis function modules #63

Merged
merged 4 commits into from
Dec 15, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" A kernel's features and coefficients using quadrature Fourier features (QFF). """
from .gaussian import QuadratureFourierFeatures
ltiao marked this conversation as resolved.
Show resolved Hide resolved

__all__ = ["QuadratureFourierFeatures"]
ltiao marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" A kernel's features and coefficients using quadrature Fourier features (QFF). """
"""
Kernel decomposition into features and coefficients based on Gauss-Christoffel
quadrature aka Gaussian quadrature.
"""

import warnings
from typing import Mapping
from typing import Mapping, Tuple, Type

import tensorflow as tf

Expand All @@ -25,12 +28,20 @@
from gpflow.quadrature.gauss_hermite import ndgh_points_and_weights

from gpflux.layers.basis_functions.fourier_features.base import FourierFeaturesBase
from gpflux.layers.basis_functions.fourier_features.utils import (
QFF_SUPPORTED_KERNELS,
_bases_concat,
)
from gpflux.layers.basis_functions.fourier_features.utils import _bases_concat
from gpflux.types import ShapeType

"""
Kernels supported by :class:`QuadratureFourierFeatures`.

Currently we only support the :class:`gpflow.kernels.SquaredExponential` kernel.
For Matern kernels please use :class:`RandomFourierFeatures`
or :class:`RandomFourierFeaturesCosine`.
"""
QFF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = (
gpflow.kernels.SquaredExponential,
)


class QuadratureFourierFeatures(FourierFeaturesBase):
def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
Expand Down
25 changes: 25 additions & 0 deletions gpflux/layers/basis_functions/fourier_features/random/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" A kernel's features and coefficients using Random Fourier Features (RFF). """

from .base import RandomFourierFeatures, RandomFourierFeaturesCosine
from .orthogonal import OrthogonalRandomFeatures
ltiao marked this conversation as resolved.
Show resolved Hide resolved

__all__ = [
"OrthogonalRandomFeatures",
"RandomFourierFeatures",
"RandomFourierFeaturesCosine",
]
ltiao marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" A kernel's features and coefficients using Random Fourier Features (RFF). """

from typing import Mapping, Optional
from typing import Mapping, Optional, Tuple, Type

import numpy as np
import tensorflow as tf
Expand All @@ -25,17 +23,54 @@

from gpflux.layers.basis_functions.fourier_features.base import FourierFeaturesBase
from gpflux.layers.basis_functions.fourier_features.utils import (
ORF_SUPPORTED_KERNELS,
RFF_SUPPORTED_KERNELS,
_bases_concat,
_bases_cosine,
_ceil_divide,
_matern_number,
_sample_chi,
_sample_students_t,
)
from gpflux.types import ShapeType

"""
Kernels supported by :class:`RandomFourierFeatures`.

You can build RFF for shift-invariant stationary kernels from which you can
sample frequencies from their power spectrum, following Bochner's theorem.
"""
RFF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = (
gpflow.kernels.SquaredExponential,
gpflow.kernels.Matern12,
gpflow.kernels.Matern32,
gpflow.kernels.Matern52,
)


def _sample_students_t(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
    """
    Sample from a (central) Student's t-distribution with `nu` degrees of
    freedom, written as a Gaussian scale mixture:
        BETA ~ Gamma(nu/2, nu/2)  (shape-rate parameterization)
        X ~ Normal(0, 1/BETA)
    so that marginally:
        X ~ StudentsT(nu)

    This agrees with the more commonly used recipe
        Z ~ Chi2(nu) = Gamma(nu/2, 1/2)
        EPSILON ~ Normal(0, 1)
        X = EPSILON * sqrt(nu/Z)
    because Z/nu ~ Gamma(nu/2, nu/2); setting BETA = Z/nu turns
    X ~ Normal(0, nu/Z) into the first form.

    :param nu: Degrees of freedom.
    :param shape: Shape of the standard normal draw.
    :param dtype: Data type passed to both samplers.
    :return: The scaled normal samples. The Gamma draw has shape
        ``shape[:-1] + [1]``, so each Gamma value is broadcast across the
        last axis of the normal draw.
    """
    # EPSILON ~ Normal(0, 1)
    eps = tf.random.normal(shape=shape, dtype=dtype)
    # BETA is drawn with a trailing axis of size 1 (shared along the last axis)
    precision_shape = tf.concat([shape[:-1], [1]], axis=0)
    half_nu = 0.5 * nu
    # BETA ~ Gamma(nu/2, nu/2)
    precision = tf.random.gamma(precision_shape, alpha=half_nu, beta=half_nu, dtype=dtype)
    # EPSILON / sqrt(BETA) ~ StudentsT(nu)
    return tf.math.rsqrt(precision) * eps


class RandomFourierFeaturesBase(FourierFeaturesBase):
def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
Expand Down Expand Up @@ -202,27 +237,3 @@ def _compute_constant(self) -> tf.Tensor:
:return: A tensor with the shape ``[]`` (i.e. a scalar).
"""
return self.rff_constant(self.kernel.variance, output_dim=self.n_components)


class OrthogonalRandomFeatures(RandomFourierFeatures):
    r"""
    Orthogonal random Fourier features (ORF) :cite:p:`yu2016orthogonal` for more
    efficient and accurate kernel approximations than :class:`RandomFourierFeatures`.
    """

    def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
        """
        :param kernel: The kernel to approximate; it must be an instance of one
            of the classes in ``ORF_SUPPORTED_KERNELS``.
        :param n_components: Forwarded unchanged to the parent constructor.
        :param kwargs: Extra keyword arguments forwarded to the parent constructor.
        """
        kernel_is_supported = isinstance(kernel, ORF_SUPPORTED_KERNELS)
        assert kernel_is_supported, "Unsupported Kernel"
        super().__init__(kernel, n_components, **kwargs)

    def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
        """
        Build the weights from stacked orthogonal blocks with Chi-distributed
        row scales.

        :param shape: A pair ``(M, D)``.
        :param dtype: Data type passed to the random samplers.
        :return: The first ``self.n_components`` rows of the stacked
            ``[K*D, D]`` matrix.
        """
        num_features, dim = shape  # M, D
        # K, the smallest integer such that K*D >= M
        num_blocks = _ceil_divide(num_features, dim)

        gaussian = tf.random.normal(shape=(num_blocks, dim, dim), dtype=dtype)
        # Only the orthogonal factor Q is used; R is discarded. Shape [K, D, D]
        orthogonal, _ = tf.linalg.qr(gaussian)

        # One Chi(D) scale per row of every block; shape [K, D]
        scales = _sample_chi(nu=dim, shape=(num_blocks, dim), dtype=dtype)
        # Same as S @ Q with S = diag(scales); shape [K, D, D]
        scaled_blocks = tf.expand_dims(scales, axis=-1) * orthogonal
        # Stack the K blocks on top of each other; shape [K*D, D]
        stacked = tf.reshape(scaled_blocks, shape=(-1, dim))

        # Keep the leading rows only (drops the K*D - M surplus rows)
        return stacked[: self.n_components]
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#
# Copyright (c) 2021 The GPflux Contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Mapping, Optional, Tuple, Type

import numpy as np
import tensorflow as tf

import gpflow
from gpflow.base import DType, TensorType

from gpflux.layers.basis_functions.fourier_features.random.base import RandomFourierFeatures
from gpflux.types import ShapeType

"""
Kernels supported by :class:`OrthogonalRandomFeatures`.

This random matrix sampling scheme only applies to the :class:`gpflow.kernels.SquaredExponential`
kernel.
For Matern kernels please use :class:`RandomFourierFeatures`
or :class:`RandomFourierFeaturesCosine`.
"""
ORF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = (
gpflow.kernels.SquaredExponential,
)


def _sample_chi_squared(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
    """
    Sample from the Chi-squared distribution with `nu` degrees of freedom.

    The draw is taken from a Gamma distribution with concentration ``nu/2``
    and rate ``1/2``; see
    https://mathworld.wolfram.com/Chi-SquaredDistribution.html for the
    relationship between the two distributions.

    :param nu: Degrees of freedom.
    :param shape: Shape argument passed to the Gamma sampler.
    :param dtype: Data type of the samples.
    :return: The sampled tensor.
    """
    concentration = 0.5 * nu
    rate = 0.5
    return tf.random.gamma(shape=shape, alpha=concentration, beta=rate, dtype=dtype)


def _sample_chi(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
    """
    Sample from the Chi distribution with `nu` degrees of freedom.

    The samples are the element-wise square roots of Chi-squared samples
    drawn with the same arguments.

    :param nu: Degrees of freedom.
    :param shape: Shape argument forwarded to the Chi-squared sampler.
    :param dtype: Data type forwarded to the Chi-squared sampler.
    :return: The sampled tensor.
    """
    return tf.sqrt(_sample_chi_squared(nu, shape, dtype))


def _ceil_divide(a: float, b: float) -> int:
"""
Ceiling division. Returns the smallest integer `m` s.t. `m*b >= a`.
"""
return -np.floor_divide(-a, b)


class OrthogonalRandomFeatures(RandomFourierFeatures):
    r"""
    Orthogonal random Fourier features (ORF) :cite:p:`yu2016orthogonal` for more
    efficient and accurate kernel approximations than :class:`RandomFourierFeatures`.
    """

    def __init__(self, kernel: gpflow.kernels.Kernel, n_components: int, **kwargs: Mapping):
        """
        :param kernel: The kernel to approximate; it must be an instance of one
            of the classes in ``ORF_SUPPORTED_KERNELS``.
        :param n_components: Forwarded unchanged to the parent constructor.
        :param kwargs: Extra keyword arguments forwarded to the parent constructor.
        """
        kernel_is_supported = isinstance(kernel, ORF_SUPPORTED_KERNELS)
        assert kernel_is_supported, "Unsupported Kernel"
        super().__init__(kernel, n_components, **kwargs)

    def _weights_init(self, shape: TensorType, dtype: Optional[DType] = None) -> TensorType:
        """
        Build the weights from stacked orthogonal blocks with Chi-distributed
        row scales.

        :param shape: A pair ``(M, D)``.
        :param dtype: Data type passed to the random samplers.
        :return: The first ``self.n_components`` rows of the stacked
            ``[K*D, D]`` matrix.
        """
        num_features, dim = shape  # M, D
        # K, the smallest integer such that K*D >= M
        num_blocks = _ceil_divide(num_features, dim)

        gaussian = tf.random.normal(shape=(num_blocks, dim, dim), dtype=dtype)
        # Only the orthogonal factor Q is used; R is discarded. Shape [K, D, D]
        orthogonal, _ = tf.linalg.qr(gaussian)

        # One Chi(D) scale per row of every block; shape [K, D]
        scales = _sample_chi(nu=dim, shape=(num_blocks, dim), dtype=dtype)
        # Same as S @ Q with S = diag(scales); shape [K, D, D]
        scaled_blocks = tf.expand_dims(scales, axis=-1) * orthogonal
        # Stack the K blocks on top of each other; shape [K*D, D]
        stacked = tf.reshape(scaled_blocks, shape=(-1, dim))

        # Keep the leading rows only (drops the K*D - M surplus rows)
        return stacked[: self.n_components]
97 changes: 1 addition & 96 deletions gpflux/layers/basis_functions/fourier_features/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,51 +16,10 @@
"""
This module provides a set of common utilities for kernel feature decompositions.
"""
from typing import Tuple, Type

import numpy as np
import tensorflow as tf

import gpflow
from gpflow.base import DType, TensorType

from gpflux.types import ShapeType

"""
Kernels supported by :class:`QuadratureFourierFeatures`.

Currently we only support the :class:`gpflow.kernels.SquaredExponential` kernel.
For Matern kernels please use :class:`RandomFourierFeatures`
or :class:`RandomFourierFeaturesCosine`.
"""
QFF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = (
gpflow.kernels.SquaredExponential,
)

"""
Kernels supported by :class:`OrthogonalRandomFeatures`.

This random matrix sampling scheme only applies to the :class:`gpflow.kernels.SquaredExponential`
kernel.
For Matern kernels please use :class:`RandomFourierFeatures`
or :class:`RandomFourierFeaturesCosine`.
"""
ORF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = (
gpflow.kernels.SquaredExponential,
)

"""
Kernels supported by :class:`RandomFourierFeatures`.

You can build RFF for shift-invariant stationary kernels from which you can
sample frequencies from their power spectrum, following Bochner's theorem.
"""
RFF_SUPPORTED_KERNELS: Tuple[Type[gpflow.kernels.Stationary], ...] = (
gpflow.kernels.SquaredExponential,
gpflow.kernels.Matern12,
gpflow.kernels.Matern32,
gpflow.kernels.Matern52,
)
from gpflow.base import TensorType


def _matern_number(kernel: gpflow.kernels.Kernel) -> int:
Expand All @@ -75,53 +34,6 @@ def _matern_number(kernel: gpflow.kernels.Kernel) -> int:
return p


def _sample_chi_squared(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
    """
    Sample from the Chi-squared distribution with `nu` degrees of freedom.

    The draw is taken from a Gamma distribution with concentration ``nu/2``
    and rate ``1/2``; see
    https://mathworld.wolfram.com/Chi-SquaredDistribution.html for the
    relationship between the two distributions.

    :param nu: Degrees of freedom.
    :param shape: Shape argument passed to the Gamma sampler.
    :param dtype: Data type of the samples.
    :return: The sampled tensor.
    """
    concentration = 0.5 * nu
    rate = 0.5
    return tf.random.gamma(shape=shape, alpha=concentration, beta=rate, dtype=dtype)


def _sample_chi(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
    """
    Sample from the Chi distribution with `nu` degrees of freedom.

    The samples are the element-wise square roots of Chi-squared samples
    drawn with the same arguments.

    :param nu: Degrees of freedom.
    :param shape: Shape argument forwarded to the Chi-squared sampler.
    :param dtype: Data type forwarded to the Chi-squared sampler.
    :return: The sampled tensor.
    """
    return tf.sqrt(_sample_chi_squared(nu, shape, dtype))


def _sample_students_t(nu: float, shape: ShapeType, dtype: DType) -> TensorType:
    """
    Sample from a (central) Student's t-distribution with `nu` degrees of
    freedom, written as a Gaussian scale mixture:
        BETA ~ Gamma(nu/2, nu/2)  (shape-rate parameterization)
        X ~ Normal(0, 1/BETA)
    so that marginally:
        X ~ StudentsT(nu)

    This agrees with the more commonly used recipe
        Z ~ Chi2(nu) = Gamma(nu/2, 1/2)
        EPSILON ~ Normal(0, 1)
        X = EPSILON * sqrt(nu/Z)
    because Z/nu ~ Gamma(nu/2, nu/2); setting BETA = Z/nu turns
    X ~ Normal(0, nu/Z) into the first form.

    :param nu: Degrees of freedom.
    :param shape: Shape of the standard normal draw.
    :param dtype: Data type passed to both samplers.
    :return: The scaled normal samples. The Gamma draw has shape
        ``shape[:-1] + [1]``, so each Gamma value is broadcast across the
        last axis of the normal draw.
    """
    # EPSILON ~ Normal(0, 1)
    eps = tf.random.normal(shape=shape, dtype=dtype)
    # BETA is drawn with a trailing axis of size 1 (shared along the last axis)
    precision_shape = tf.concat([shape[:-1], [1]], axis=0)
    half_nu = 0.5 * nu
    # BETA ~ Gamma(nu/2, nu/2)
    precision = tf.random.gamma(precision_shape, alpha=half_nu, beta=half_nu, dtype=dtype)
    # EPSILON / sqrt(BETA) ~ StudentsT(nu)
    return tf.math.rsqrt(precision) * eps


def _bases_cosine(X: TensorType, W: TensorType, b: TensorType) -> TensorType:
"""
Feature map for random Fourier features (RFF) as originally prescribed
Expand All @@ -140,10 +52,3 @@ def _bases_concat(X: TensorType, W: TensorType) -> TensorType:
"""
proj = tf.matmul(X, W, transpose_b=True) # [N, M]
return tf.concat([tf.sin(proj), tf.cos(proj)], axis=-1) # [N, 2M]


def _ceil_divide(a: float, b: float) -> int:
"""
Ceiling division. Returns the smallest integer `m` s.t. `m*b >= a`.
"""
return -np.floor_divide(-a, b)
Loading