
Attempting to learn models with multidimensional inputs leads to an error. #27

Closed
clwgg opened this issue Apr 26, 2021 · 6 comments
Labels: bug

clwgg commented Apr 26, 2021

Thanks a lot for making this exciting project public! I'm not 100% sure if what I'm reporting is a bug or if this simply isn't supposed to work in GPflux, but here we go:

Describe the bug
Attempting to learn models with multidimensional inputs leads to an error.

To reproduce
First, the setup of a toy example, along with a GPflow SVGP-based version that works as expected:

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import gpflow
import gpflux
from gpflow.utilities import print_summary, set_trainable

tf.keras.backend.set_floatx("float64")
tf.get_logger().setLevel("INFO")

grid = np.meshgrid(np.linspace(0, np.pi*2, 20),
                   np.linspace(0, np.pi*2, 20))
X = np.column_stack(tuple(map(np.ravel, grid)))
Y = (np.sin(X[:, 0]) * np.sin(X[:, 1]))[:, None]

plt.contourf(grid[0], grid[1], Y.reshape(grid[0].shape))
plt.title("DATA")
plt.show()

num_data = len(X)
num_inducing = 10
output_dim = Y.shape[1]

kernel = (gpflow.kernels.SquaredExponential(active_dims=[0]) *
          gpflow.kernels.SquaredExponential(active_dims=[1]))
inducing_variable = gpflow.inducing_variables.InducingPoints(
    X[np.random.choice(X.shape[0], size=num_inducing, replace=False),:].copy()
)

#---------- SVGP
svgp = gpflow.models.SVGP(kernel, gpflow.likelihoods.Gaussian(), inducing_variable,
                          num_latent_gps=output_dim, num_data=num_data)
set_trainable(svgp.q_mu, False)
set_trainable(svgp.q_sqrt, False)
variational_params = [(svgp.q_mu, svgp.q_sqrt)]
natgrad_opt = gpflow.optimizers.NaturalGradient(gamma=0.1)
adam_opt = tf.optimizers.Adam(0.01)
minibatch_size = 10
train_dataset = tf.data.Dataset.from_tensor_slices(
    (X, Y)).repeat().shuffle(num_data)
iter_train = iter(train_dataset.batch(minibatch_size))
objective = svgp.training_loss_closure(iter_train, compile=True)

@tf.function
def optim_step():
    natgrad_opt.minimize(objective, var_list=variational_params)
    adam_opt.minimize(objective, svgp.trainable_variables)

for i in range(100):
    optim_step()
elbo = -objective().numpy()
print(f"it: {i} of dual-optimizer... elbo: {elbo}")


atgrid = np.meshgrid(np.linspace(0, np.pi*2, 40),
                     np.linspace(0, np.pi*2, 40))
atX = np.column_stack(tuple(map(np.ravel, atgrid)))

mean, var = svgp.predict_f(atX)
plt.contourf(atgrid[0], atgrid[1], mean.numpy().reshape(atgrid[0].shape))
plt.title("SVGP")
plt.show()

And here is a single-layer DGP with GPflux:

#---------- DEEPGP
gp_layer = gpflux.layers.GPLayer(
    kernel, inducing_variable, num_data=num_data, num_latent_gps=output_dim
)

likelihood_layer = gpflux.layers.LikelihoodLayer(gpflow.likelihoods.Gaussian(0.1))

single_layer_dgp = gpflux.models.DeepGP([gp_layer], likelihood_layer)
model = single_layer_dgp.as_training_model()
model.compile(tf.optimizers.Adam(0.01))

log = model.fit({"inputs": X, "targets": Y}, epochs=100, verbose=1)

which throws the following error when reaching the last line of the example:

ValueError: in user code:

    venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:805 train_function  *
        return step_function(self, iterator)
    venv/lib/python3.7/site-packages/gpflux/layers/gp_layer.py:277 call  *
        outputs = super().call(inputs, *args, **kwargs)
    venv/lib/python3.7/site-packages/tensorflow_probability/python/layers/distribution_layer.py:252 call  **
        inputs, *args, **kwargs)
    venv/lib/python3.7/site-packages/tensorflow/python/keras/layers/core.py:917 call
        result = self.function(inputs, **kwargs)
    venv/lib/python3.7/site-packages/tensorflow_probability/python/layers/distribution_layer.py:172 _fn
        d = make_distribution_fn(*fargs, **fkwargs)
    venv/lib/python3.7/site-packages/gpflux/layers/gp_layer.py:328 _make_distribution_fn
        return tfp.distributions.MultivariateNormalDiag(loc=mean, scale_diag=tf.sqrt(cov))
    <decorator-gen-394>:2 __init__
        
    venv/lib/python3.7/site-packages/tensorflow_probability/python/distributions/distribution.py:298 wrapped_init
        default_init(self_, *args, **kwargs)
    venv/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py:538 new_func
        return func(*args, **kwargs)
    venv/lib/python3.7/site-packages/tensorflow_probability/python/distributions/mvn_diag.py:252 __init__
        name=name)
    <decorator-gen-322>:2 __init__
        
    venv/lib/python3.7/site-packages/tensorflow_probability/python/distributions/distribution.py:298 wrapped_init
        default_init(self_, *args, **kwargs)
    venv/lib/python3.7/site-packages/tensorflow_probability/python/distributions/mvn_linear_operator.py:190 __init__
        loc, scale)
    venv/lib/python3.7/site-packages/tensorflow_probability/python/internal/distribution_util.py:136 shapes_from_loc_and_scale
        'of `loc` ({}).'.format(event_size_, loc_event_size_))

    ValueError: Event size of `scale` (1) could not be broadcast up to that of `loc` (2).

Expected behaviour
I expected this not to throw an error and to produce an (at least qualitatively) similar result to the SVGP implementation, but again, I'm not sure whether this expectation is justified.

System information

  • OS: Linux, kernel 5.4.112-1
  • Python version: 3.7.5
  • GPflux version: 0.1.0 from pip
  • TensorFlow version: 2.4.1
  • GPflow version: 2.1.5
clwgg added the bug label on Apr 26, 2021
st-- (Member) commented Apr 26, 2021

That's a bit surprising; GPflux should work just fine with multidimensional inputs. GPflux is designed around multi-output kernels. There is a check built into GPLayer, but I just noticed that it is disabled by default. Could you add a verbose=True keyword argument to the GPLayer constructor and see whether that prints a warning "Could not verify the compatibility ..."? Also have a look at the gpflux.helpers.construct_* functions; you should be able to use those to construct a compatible pair of kernel and inducing variable objects.
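
Something like this should do it (a minimal sketch, reusing the constructor arguments from your reproduction and only adding the flag):

gp_layer = gpflux.layers.GPLayer(
    kernel, inducing_variable, num_data=num_data, num_latent_gps=output_dim,
    verbose=True,  # print a warning if kernel/inducing-variable compatibility cannot be verified
)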

clwgg (Author) commented Apr 26, 2021

It does indeed print the warning when run with verbose=True! I'll look into the helpers and get back to you.

(Also, this is totally tangential, but the Slack invite link in the GPflux README is no longer active -- just wanted to let you know in case this isn't intentional.)

st-- (Member) commented Apr 26, 2021

(Also, this is totally tangential, but the Slack invite link in the GPflux README is no longer active -- just wanted to let you know in case this isn't intentional.)

Yeah, Slack now expires all invite links after 30 days, so we need to update it regularly. You can find the link (valid until the end of May) in this PR: #28

clwgg (Author) commented Apr 26, 2021

OK, I have now wrapped the construction of the kernel and inducing points in their respective construct helpers, like so:

input_dim = X.shape[1]
kernel = gpflux.helpers.construct_basic_kernel(
    (gpflow.kernels.SquaredExponential(active_dims=[0]) *
     gpflow.kernels.SquaredExponential(active_dims=[1])),
    output_dim=output_dim)
inducing_variable = gpflux.helpers.construct_basic_inducing_variables(
    num_inducing, input_dim, output_dim, share_variables=True,
    z_init=X[np.random.choice(X.shape[0],
                              size=num_inducing,
                              replace=False),:].copy())

Now the warning from GPLayer (with verbose=True) disappears, but the error I reported above persists. The same error also occurs if I don't use a kernel product with active_dims, like so:

kernel = gpflux.helpers.construct_basic_kernel(
    gpflow.kernels.SquaredExponential(),
    output_dim=output_dim)

vdutor (Member) commented Apr 27, 2021

Hi @clwgg, thank you very much for raising this issue. I have to agree that the error message returned by GPflux is not clear at all. The problem is that GPflux uses an Identity mean function by default. Given your 2D input, the Identity mean function turns the GPLayer into a 2D-output model, which clashes with your 1D targets. The problem can be solved by simply setting a Zero mean function:

gp_layer = gpflux.layers.GPLayer(
    kernel, inducing_variable, num_data=num_data, num_latent_gps=output_dim,
    mean_function=gpflow.mean_functions.Zero()
)
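
To illustrate the clash (a minimal sketch, just evaluating the default mean function on 2D dummy inputs):

import numpy as np
import gpflow

mean_function = gpflow.mean_functions.Identity()
# Identity maps [N, 2] inputs to [N, 2] outputs, which cannot broadcast
# against the [N, 1] targets, hence the "Event size" error above.
print(mean_function(np.zeros((5, 2))).shape)  # (5, 2)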

I'll open a PR to make the error messaging more informative.

vdutor closed this as completed Apr 27, 2021
clwgg (Author) commented Apr 27, 2021

Wonderful, thank you @vdutor !
