From 2433156c9a1404167aebb2f4ee2d18625b3c38fb Mon Sep 17 00:00:00 2001
From: Louis Tiao
Date: Fri, 25 Apr 2025 11:06:37 -0700
Subject: [PATCH] Fix batch computation in Pivoted Cholesky (#2823)

Summary:
Pull Request resolved: https://github.com/pytorch/botorch/pull/2823

## Context

Resolves issue https://github.com/pytorch/botorch/issues/2819, where `PivotedCholesky.update_` breaks when there is more than a single batch dimension.

## Changes

Updates the numerical-stability clause in `update_` so that its boolean indexing also covers cases where `len(batch_shape) > 1` (a standalone sketch of the indexing pattern follows the diff).

Reviewed By: saitcakmak

Differential Revision: D72906531
---
 botorch/utils/probability/linalg.py   |  2 +-
 test/utils/probability/test_mvnxpb.py | 25 ++++++++++++++++++++++++-
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/botorch/utils/probability/linalg.py b/botorch/utils/probability/linalg.py
index 331b59913d..88434ef8c4 100644
--- a/botorch/utils/probability/linalg.py
+++ b/botorch/utils/probability/linalg.py
@@ -125,7 +125,7 @@ def update_(self, eps: float = 1e-10) -> None:
         rank1 = L[..., i + 1 :, i : i + 1].clone()
         rank1 = (rank1 * rank1.transpose(-1, -2)).tril()
         L[..., i + 1 :, i + 1 :] = L[..., i + 1 :, i + 1 :].clone() - rank1
-        L[Lii <= i * eps, i:, i] = 0  # numerical stability clause
+        L[..., i:, i][Lii <= i * eps] = 0  # numerical stability clause
         self.step += 1

     def pivot_(self, pivot: LongTensor) -> None:
diff --git a/test/utils/probability/test_mvnxpb.py b/test/utils/probability/test_mvnxpb.py
index c478cb74ce..c96e946fc6 100644
--- a/test/utils/probability/test_mvnxpb.py
+++ b/test/utils/probability/test_mvnxpb.py
@@ -11,7 +11,7 @@

 from copy import deepcopy
 from functools import partial
-from itertools import count
+from itertools import count, product
 from typing import Any
 from unittest.mock import patch

@@ -179,6 +179,29 @@ def _estimator(samples, bounds):

            self.assertAllClose(est, prob, rtol=0, atol=atol)

+    def test_solve_batch(self):
+        ndim = 3
+        batch_shape = (3, 4)
+        with torch.random.fork_rng():
+            torch.random.manual_seed(next(self.seed_generator))
+            bounds = self.gen_bounds(ndim, batch_shape, bound_range=(-5.0, +5.0))
+            sqrt_cov = self.gen_covariances(ndim, batch_shape, as_sqrt=True)
+
+        cov = sqrt_cov @ sqrt_cov.mT
+
+        batched_solver = MVNXPB(cov, bounds)
+        batched_solver.solve()
+
+        # solution for each individual batch element is the same as
+        # that of the entire batch
+        for idx in product(*map(range, batch_shape)):
+            solver = MVNXPB(cov[tuple(idx)], bounds[tuple(idx)])
+            solver.solve()
+            self.assertAlmostEqual(
+                batched_solver.log_prob[tuple(idx)].item(),
+                solver.log_prob.item(),
+            )
+
     def test_augment(self):
         r"""Test `augment`."""
         with torch.random.fork_rng():
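
Below is a minimal, standalone sketch (not part of the patch) of the indexing pattern used by the fix, with illustrative shapes; the names `L` and `Lii` mirror those in `PivotedCholesky.update_`. The key point is that `L[..., i:, i]` is produced by basic indexing and is therefore a view of `L`, so applying the boolean batch mask to that view writes through to `L` no matter how many batch dimensions there are.

```python
import torch

# Illustrative shapes only; `L` stands in for the working Cholesky factor and
# `Lii` for the current pivot values, as in `PivotedCholesky.update_`.
batch_shape, n, i, eps = (3, 4), 5, 2, 1e-10

L = torch.randn(*batch_shape, n, n).tril()  # shape: batch_shape + (n, n)
Lii = L[..., i, i].clone()                  # shape: batch_shape

mask = Lii <= i * eps  # boolean mask over the batch dimensions only

# `L[..., i:, i]` uses basic indexing (ellipsis, slice, integer), so it is a
# view of shape batch_shape + (n - i,). Masking that view zeroes the tail of
# the i-th column for the flagged batch elements, for any len(batch_shape).
L[..., i:, i][mask] = 0

assert torch.all(L[..., i:, i][mask] == 0)
```

The previous form passed the batch mask as the leading index in a single indexing expression together with the trailing slice and integer; per the linked issue, that broke once `len(batch_shape) > 1`, whereas masking the already-sliced view keeps the batch mask independent of the trailing dimensions.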