From 2934ebfb3440c5a6f40d53bea73b8b613d50a8ad Mon Sep 17 00:00:00 2001 From: Sebastian Ober Date: Mon, 2 May 2022 11:13:54 +0100 Subject: [PATCH] Different samples for each output --- gpflux/sampling/sample.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/gpflux/sampling/sample.py b/gpflux/sampling/sample.py index 26eaa333..5b7b8c36 100644 --- a/gpflux/sampling/sample.py +++ b/gpflux/sampling/sample.py @@ -156,12 +156,12 @@ def _efficient_sample_matheron_rule( :param whiten: Determines the parameterisation of the inducing variables. """ L = tf.shape(kernel.feature_coefficients)[0] # num eigenfunctions # noqa: F841 + M, P = tf.shape(q_mu)[0], tf.shape(q_mu)[1] # num inducing, num output heads prior_weights = tf.sqrt(kernel.feature_coefficients) * tf.random.normal( - tf.shape(kernel.feature_coefficients), dtype=default_float() - ) # [L, 1] + (L, P), dtype=default_float() + ) # [L, P] - M, P = tf.shape(q_mu)[0], tf.shape(q_mu)[1] # num inducing, num output heads u_sample_noise = tf.matmul( q_sqrt, tf.random.normal((P, M, 1), dtype=default_float()), # [P, M, M] # [P, M, 1] @@ -175,8 +175,8 @@ def _efficient_sample_matheron_rule( u_sample = tf.matmul(Luu, u_sample) # [M, P] phi_Z = kernel.feature_functions(inducing_variable.Z) # [M, L] - weight_space_prior_Z = phi_Z @ prior_weights # [M, 1] - diff = u_sample - weight_space_prior_Z # [M, P] -- using implicit broadcasting + weight_space_prior_Z = phi_Z @ prior_weights # [M, P] + diff = u_sample - weight_space_prior_Z # [M, P] v = compute_A_inv_b(Kmm, diff) # [M, P] tf.debugging.assert_equal(tf.shape(v), [M, P]) @@ -188,11 +188,11 @@ def __call__(self, X: TensorType) -> tf.Tensor: """ N = tf.shape(X)[0] phi_X = kernel.feature_functions(X) # [N, L] - weight_space_prior_X = phi_X @ prior_weights # [N, 1] + weight_space_prior_X = phi_X @ prior_weights # [N, P] Knm = tf.linalg.matrix_transpose(Kuf(inducing_variable, kernel, X)) # [N, M] function_space_update_X = Knm @ v # [N, P] - tf.debugging.assert_equal(tf.shape(weight_space_prior_X), [N, 1]) + tf.debugging.assert_equal(tf.shape(weight_space_prior_X), [N, P]) tf.debugging.assert_equal(tf.shape(function_space_update_X), [N, P]) return weight_space_prior_X + function_space_update_X # [N, P]