Skip to content

Commit

Permalink
Use iv.num_inducing instead of len(iv), for compatibility with future…
Browse files Browse the repository at this point in the history
… GPflow. (#66)
  • Loading branch information
jesnie authored Feb 2, 2022
1 parent 26e5640 commit 7449437
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion gpflux/layers/gp_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def __init__(
)

num_inducing, self.num_latent_gps = (
len(inducing_variable),
inducing_variable.num_inducing,
num_latent_gps,
)

Expand Down
2 changes: 1 addition & 1 deletion gpflux/runtime_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ def verify_compatibility(
f"the number of separate independent inducing_variables ({latent_inducing_points})"
)

num_inducing_points = len(inducing_variable) # currently the same for each dim
num_inducing_points = inducing_variable.num_inducing # currently the same for each dim
return num_inducing_points, num_latent_gps
8 changes: 4 additions & 4 deletions tests/gpflux/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_construct_inducing_separate_independent_custom_list(z_init):
assert isinstance(moiv, SeparateIndependentInducingVariables)
assert isinstance(moiv, MultioutputInducingVariables)
for i, iv in enumerate(moiv.inducing_variable_list):
assert len(iv) == num_inducing[i]
assert iv.num_inducing == num_inducing[i]


@pytest.mark.parametrize("z_init", [True, False])
Expand All @@ -116,7 +116,7 @@ def test_construct_inducing_separate_independent_duplicates(z_init):
assert isinstance(moiv, SeparateIndependentInducingVariables)
assert isinstance(moiv, MultioutputInducingVariables)
for iv in moiv.inducing_variable_list:
assert len(iv) == num_inducing
assert iv.num_inducing == num_inducing


@pytest.mark.parametrize("z_init", [True, False])
Expand All @@ -136,7 +136,7 @@ def test_construct_inducing_shared_independent_duplicates(z_init):

assert isinstance(moiv, SharedIndependentInducingVariables)
assert isinstance(moiv, MultioutputInducingVariables)
assert len(moiv.inducing_variable) == num_inducing
assert moiv.inducing_variable.num_inducing == num_inducing


def test_construct_mean_function_Identity():
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_construct_gp_layer():
# inducing variable
assert isinstance(layer.inducing_variable, SharedIndependentInducingVariables)
assert isinstance(layer.inducing_variable.inducing_variable, InducingPoints)
assert len(layer.inducing_variable.inducing_variable) == num_inducing
assert layer.inducing_variable.inducing_variable.num_inducing == num_inducing

# mean function
assert isinstance(layer.mean_function, gpflow.mean_functions.Zero)
Expand Down

0 comments on commit 7449437

Please sign in to comment.