diff --git a/botorch/utils/probability/linalg.py b/botorch/utils/probability/linalg.py index 331b59913d..88434ef8c4 100644 --- a/botorch/utils/probability/linalg.py +++ b/botorch/utils/probability/linalg.py @@ -125,7 +125,7 @@ def update_(self, eps: float = 1e-10) -> None: rank1 = L[..., i + 1 :, i : i + 1].clone() rank1 = (rank1 * rank1.transpose(-1, -2)).tril() L[..., i + 1 :, i + 1 :] = L[..., i + 1 :, i + 1 :].clone() - rank1 - L[Lii <= i * eps, i:, i] = 0 # numerical stability clause + L[..., i:, i][Lii <= i * eps] = 0 # numerical stability clause self.step += 1 def pivot_(self, pivot: LongTensor) -> None: diff --git a/test/utils/probability/test_mvnxpb.py b/test/utils/probability/test_mvnxpb.py index c478cb74ce..c96e946fc6 100644 --- a/test/utils/probability/test_mvnxpb.py +++ b/test/utils/probability/test_mvnxpb.py @@ -11,7 +11,7 @@ from copy import deepcopy from functools import partial -from itertools import count +from itertools import count, product from typing import Any from unittest.mock import patch @@ -179,6 +179,29 @@ def _estimator(samples, bounds): self.assertAllClose(est, prob, rtol=0, atol=atol) + def test_solve_batch(self): + ndim = 3 + batch_shape = (3, 4) + with torch.random.fork_rng(): + torch.random.manual_seed(next(self.seed_generator)) + bounds = self.gen_bounds(ndim, batch_shape, bound_range=(-5.0, +5.0)) + sqrt_cov = self.gen_covariances(ndim, batch_shape, as_sqrt=True) + + cov = sqrt_cov @ sqrt_cov.mT + + batched_solver = MVNXPB(cov, bounds) + batched_solver.solve() + + # solution for each individual batch element is the same as + # that of the entire batch + for idx in product(*map(range, batch_shape)): + solver = MVNXPB(cov[tuple(idx)], bounds[tuple(idx)]) + solver.solve() + self.assertAlmostEqual( + batched_solver.log_prob[tuple(idx)].item(), + solver.log_prob.item(), + ) + def test_augment(self): r"""Test `augment`.""" with torch.random.fork_rng():