Merge pull request #65 from AlexImmer/low-rank

Low rank LA

aleximmer authored Dec 10, 2021
2 parents b926f21 + 5ee4009 commit 0915472

Showing 5 changed files with 154 additions and 17 deletions.
4 changes: 2 additions & 2 deletions laplace/__init__.py
@@ -7,14 +7,14 @@
REGRESSION = 'regression'
CLASSIFICATION = 'classification'

from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace
from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace
from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace
from laplace.laplace import Laplace
from laplace.marglik_training import marglik_training

__all__ = ['Laplace', # direct access to all Laplace classes via unified interface
'BaseLaplace', 'ParametricLaplace', # base-class and its (first-level) subclasses
'FullLaplace', 'KronLaplace', 'DiagLaplace', # all-weights
'FullLaplace', 'KronLaplace', 'DiagLaplace', 'LowRankLaplace', # all-weights
'LLLaplace', # base-class last-layer
'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace', # last-layer
'marglik_training'] # methods
102 changes: 100 additions & 2 deletions laplace/baselaplace.py
@@ -1,4 +1,5 @@
from math import sqrt, pi, log
from laplace.curvature.asdl import AsdlHessian
import numpy as np
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters
@@ -327,8 +328,6 @@ class ParametricLaplace(BaseLaplace):

def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
prior_mean=0., temperature=1., backend=BackPackGGN, backend_kwargs=None):
assert backend in [BackPackGGN, BackPackEF, AsdlGGN, AsdlEF], \
'GGN or EF backends required in ParametricLaplace.'
super().__init__(model, likelihood, sigma_noise, prior_precision,
prior_mean, temperature, backend, backend_kwargs)
try:
@@ -842,6 +841,105 @@ def prior_precision(self, prior_precision):
raise ValueError('Prior precision for Kron either scalar or per-layer.')


class LowRankLaplace(ParametricLaplace):
"""Laplace approximation with low-rank log likelihood Hessian (approximation).
The low-rank matrix is represented by an eigendecomposition (vecs, values).
Based on the chosen `backend`, either a true Hessian or, for example, a GGN
approximation can be used.
The posterior precision is computed as
\\( P = V diag(l) V^T + P_0.\\)
For sampling, computing the functional variance, and computing the log determinant,
algebraic tricks are used to reduce the cost of inversion to that of a
\\(K \\times K\\) matrix when the rank is K.
See `BaseLaplace` for the full interface.
"""
_key = ('all', 'lowrank')
def __init__(self, model, likelihood, sigma_noise=1, prior_precision=1, prior_mean=0,
temperature=1, backend=AsdlHessian, backend_kwargs=None):
super().__init__(model, likelihood, sigma_noise=sigma_noise,
prior_precision=prior_precision, prior_mean=prior_mean,
temperature=temperature, backend=backend, backend_kwargs=backend_kwargs)

def _init_H(self):
pass

@property
def V(self):
(U, l), prior_prec_diag = self.posterior_precision
return U / prior_prec_diag.reshape(-1, 1)

@property
def Kinv(self):
(U, l), _ = self.posterior_precision
return torch.inverse(torch.diag(1 / l) + U.T @ self.V)

def fit(self, train_loader, override=True):
# override fit since the eigendecomposition of the Hessian is not additive across batches
if not override:
# LowRankLA cannot be updated since the eigenvalue representation is not additive
raise ValueError('LowRank LA does not support updating.')

self.model.eval()
self.mean = parameters_to_vector(self.model.parameters()).detach()

X, _ = next(iter(train_loader))
with torch.no_grad():
try:
out = self.model(X[:1].to(self._device))
except (TypeError, AttributeError):
out = self.model(X.to(self._device))
self.n_outputs = out.shape[-1]
setattr(self.model, 'output_size', self.n_outputs)

eigenvectors, eigenvalues, loss = self.backend.eig_lowrank(train_loader)
self.H = (eigenvectors, eigenvalues)
self.loss = loss

self.n_data = len(train_loader.dataset)

@property
def posterior_precision(self):
"""Return correctly scaled posterior precision that would be constructed
as H[0] @ diag(H[1]) @ H[0].T + self.prior_precision_diag.
Returns
-------
H : tuple(eigenvectors, eigenvalues)
scaled self.H with temperature and loss factors.
prior_precision_diag : torch.Tensor
diagonal prior precision with one entry per parameter, to be added to H.
"""
return (self.H[0], self._H_factor * self.H[1]), self.prior_precision_diag

def functional_variance(self, Jacs):
prior_var = torch.einsum('ncp,nkp->nck', Jacs / self.prior_precision_diag, Jacs)
Jacs_V = torch.einsum('ncp,pl->ncl', Jacs, self.V)
info_gain = torch.einsum('ncl,nkl->nck', Jacs_V @ self.Kinv, Jacs_V)
return prior_var - info_gain

def sample(self, n_samples):
samples = torch.randn(self.n_params, n_samples)
d = self.prior_precision_diag
Vs = self.V * d.sqrt().reshape(-1, 1)
VtV = Vs.T @ Vs
Ik = torch.eye(len(VtV))
A = torch.linalg.cholesky(VtV)
B = torch.linalg.cholesky(VtV + Ik)
A_inv = torch.inverse(A)
C = torch.inverse(A_inv.T @ (B - Ik) @ A_inv)
Kern_inv = torch.inverse(torch.inverse(C) + Vs.T @ Vs)
dinv_sqrt = (d).sqrt().reshape(-1, 1)
prior_sample = dinv_sqrt * samples
gain_sample = dinv_sqrt * Vs @ Kern_inv @ (Vs.T @ samples)
return self.mean + (prior_sample - gain_sample).T

@property
def log_det_posterior_precision(self):
(U, l), prior_prec_diag = self.posterior_precision
return l.log().sum() + prior_prec_diag.log().sum() - torch.logdet(self.Kinv)
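
The `functional_variance` and `log_det_posterior_precision` implementations above rely on the Woodbury identity and the matrix determinant lemma, so that only a dense \\(K \times K\\) matrix (the one behind `Kinv`) plus the diagonal prior ever needs to be inverted or log-determined. The following is a self-contained numerical sketch of those two identities; all names, sizes, and data are illustrative and not taken from the library.

import torch

torch.manual_seed(0)
torch.set_default_dtype(torch.float64)

P, K = 50, 5                               # illustrative: number of parameters and rank
U = torch.linalg.qr(torch.randn(P, K)).Q   # orthonormal "eigenvectors"
l = torch.rand(K) + 0.1                    # positive "eigenvalues"
p0 = torch.rand(P) + 0.5                   # diagonal prior precision

precision = U @ torch.diag(l) @ U.T + torch.diag(p0)
V = U / p0.reshape(-1, 1)                  # P_0^{-1} U, the analogue of LowRankLaplace.V
Kmat = torch.diag(1 / l) + U.T @ V         # K x K matrix; its inverse is the analogue of Kinv

# Woodbury identity: (P_0 + U diag(l) U^T)^{-1} = P_0^{-1} - P_0^{-1} U Kmat^{-1} U^T P_0^{-1}
cov_lowrank = torch.diag(1 / p0) - V @ torch.inverse(Kmat) @ V.T
assert torch.allclose(cov_lowrank, torch.inverse(precision))

# Matrix determinant lemma: logdet(P) = sum(log l) + sum(log p_0) + logdet(Kmat)
assert torch.allclose(l.log().sum() + p0.log().sum() + torch.logdet(Kmat),
                      torch.logdet(precision))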


class DiagLaplace(ParametricLaplace):
"""Laplace approximation with diagonal log likelihood Hessian approximation
and hence posterior precision.
44 changes: 39 additions & 5 deletions laplace/curvature/asdl.py
@@ -3,22 +3,21 @@
import torch

from asdfghjkl import FISHER_EXACT, FISHER_MC, COV
from asdfghjkl import SHAPE_KRON, SHAPE_DIAG
from asdfghjkl import SHAPE_KRON, SHAPE_DIAG, SHAPE_FULL
from asdfghjkl import fisher_for_cross_entropy
from asdfghjkl.hessian import hessian_eigenvalues, hessian_for_loss
from asdfghjkl.gradient import batch_gradient

from laplace.curvature import CurvatureInterface, GGNInterface, EFInterface
from laplace.matrix import Kron
from laplace.utils import _is_batchnorm

EPS = 1e-6


class AsdlInterface(CurvatureInterface):
"""Interface for asdfghjkl backend.
"""
def __init__(self, model, likelihood, last_layer=False):
if likelihood != 'classification':
raise ValueError('This backend only supports classification currently.')
super().__init__(model, likelihood, last_layer)

@staticmethod
def jacobians(model, x):
@@ -131,10 +130,41 @@ def kron(self, X, y, N, **wkwargs) -> [torch.Tensor, Kron]:
return self.factor * loss, self.factor * kron


class AsdlHessian(AsdlInterface):

def __init__(self, model, likelihood, last_layer=False, low_rank=10):
super().__init__(model, likelihood, last_layer)
self.low_rank = low_rank

@property
def _ggn_type(self):
raise NotImplementedError()

def full(self, x, y, **kwargs):
hessian_for_loss(self.model, self.lossfunc, SHAPE_FULL, x, y)
H = self._model.hessian.data
loss = self.lossfunc(self.model(x), y).detach()
return self.factor * loss, self.factor * H

def eig_lowrank(self, data_loader):
# compute truncated eigendecomposition of the Hessian, only keep eigvals > EPS
eigvals, eigvecs = hessian_eigenvalues(self.model, self.lossfunc, data_loader,
top_n=self.low_rank, max_iters=self.low_rank*10)
eigvals = torch.from_numpy(np.array(eigvals))
mask = (eigvals > EPS)
eigvecs = torch.stack([torch.cat([p.flatten() for p in params])
for params in eigvecs], dim=1)[:, mask]
eigvals = eigvals[mask].to(eigvecs.dtype).to(eigvecs.device)
loss = sum([self.lossfunc(self.model(x).detach(), y) for x, y in data_loader])
return eigvecs, self.factor * eigvals, self.factor * loss
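
`eig_lowrank` keeps at most `low_rank` eigenpairs and discards eigenvalues below `EPS`; the returned pair represents the rank-K approximation H ≈ W diag(l) W^T that `LowRankLaplace.fit` stores as `self.H`. A toy sketch of the same truncation applied to a small dense symmetric matrix (standalone, with illustrative names and sizes):

import torch

torch.manual_seed(0)
A = torch.randn(30, 30, dtype=torch.double)
H = A @ A.T / 30                          # stand-in for a dense, symmetric PSD Hessian

evals, evecs = torch.linalg.eigh(H)       # eigenvalues in ascending order
low_rank, EPS = 10, 1e-6
evals, evecs = evals[-low_rank:], evecs[:, -low_rank:]   # keep the top-`low_rank` eigenpairs
mask = evals > EPS                                       # drop numerically zero directions
evecs, evals = evecs[:, mask], evals[mask]

H_lowrank = evecs @ torch.diag(evals) @ evecs.T
# the spectral-norm error of this truncation is the largest discarded eigenvalue
print(torch.linalg.matrix_norm(H - H_lowrank, ord=2).item())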


class AsdlGGN(AsdlInterface, GGNInterface):
"""Implementation of the `GGNInterface` using asdfghjkl.
"""
def __init__(self, model, likelihood, last_layer=False, stochastic=False):
if likelihood != 'classification':
raise ValueError('This backend only supports classification currently.')
super().__init__(model, likelihood, last_layer)
self.stochastic = stochastic

@@ -146,6 +176,10 @@ def _ggn_type(self):
class AsdlEF(AsdlInterface, EFInterface):
"""Implementation of the `EFInterface` using asdfghjkl.
"""
def __init__(self, model, likelihood, last_layer=False):
if likelihood != 'classification':
raise ValueError('This backend only supports classification currently.')
super().__init__(model, likelihood, last_layer)

@property
def _ggn_type(self):
2 changes: 1 addition & 1 deletion laplace/laplace.py
@@ -12,7 +12,7 @@ def Laplace(model, likelihood, subset_of_weights='last_layer', hessian_structure
likelihood : {'classification', 'regression'}
subset_of_weights : {'last_layer', 'all'}, default='last_layer'
subset of weights to consider for inference
hessian_structure : {'diag', 'kron', 'full'}, default='kron'
hessian_structure : {'diag', 'kron', 'full', 'lowrank'}, default='kron'
structure of the Hessian approximation
Returns
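
For illustration, a minimal usage sketch of the new `lowrank` structure through this unified interface, assuming a trained PyTorch classifier `model` and a `train_loader` (both placeholders, not defined in this commit):

from laplace import Laplace

# 'all' + 'lowrank' dispatches to LowRankLaplace, which defaults to the AsdlHessian backend
la = Laplace(model, 'classification',
             subset_of_weights='all', hessian_structure='lowrank')
la.fit(train_loader)                     # computes the truncated Hessian eigendecomposition
(U, l), p0 = la.posterior_precision      # eigenvectors, scaled eigenvalues, diagonal prior precision
theta = la.sample(n_samples=10)          # posterior weight samples of shape (10, n_params)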
19 changes: 12 additions & 7 deletions tests/test_baselaplace.py
@@ -1,6 +1,4 @@
from math import sqrt
from _pytest.mark import param
from laplace.matrix import KronDecomposed
import pytest
from itertools import product
import numpy as np
@@ -12,13 +10,15 @@
from torch.utils.data import DataLoader, TensorDataset
from torch.distributions import Normal, Categorical

from laplace.laplace import Laplace, FullLaplace, KronLaplace, DiagLaplace
from laplace.laplace import FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace
from laplace.matrix import KronDecomposed
from tests.utils import jacobians_naive


torch.manual_seed(240)
torch.set_default_tensor_type(torch.DoubleTensor)
flavors = [FullLaplace, KronLaplace, DiagLaplace]
flavors = [FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace]
online_flavors = [FullLaplace, KronLaplace, DiagLaplace]


@pytest.fixture
@@ -221,6 +221,9 @@ def test_laplace_functionality(laplace, lh, model, reg_loader, class_loader):
lml = lml - 1/2 * theta @ prior_prec @ theta
if laplace == DiagLaplace:
log_det_post_prec = lap.posterior_precision.log().sum()
elif laplace == LowRankLaplace:
(U, l), p0 = lap.posterior_precision
log_det_post_prec = (U @ torch.diag(l) @ U.T + p0.diag()).logdet()
else:
log_det_post_prec = lap.posterior_precision.logdet()
lml = lml + 1/2 * (prior_prec.logdet() - log_det_post_prec)
@@ -241,6 +244,9 @@ def test_laplace_functionality(laplace, lh, model, reg_loader, class_loader):
Sigma = lap.posterior_covariance
elif laplace == KronLaplace:
Sigma = lap.posterior_precision.to_matrix(exponent=-1)
elif laplace == LowRankLaplace:
(U, l), p0 = lap.posterior_precision
Sigma = (U @ torch.diag(l) @ U.T + p0.diag()).inverse()
elif laplace == DiagLaplace:
Sigma = torch.diag(lap.posterior_variance)
Js, f = jacobians_naive(model, loader.dataset.tensors[0])
@@ -249,7 +255,7 @@ def test_laplace_functionality(laplace, lh, model, reg_loader, class_loader):
assert torch.allclose(true_f_var, comp_f_var, rtol=1e-4)


@pytest.mark.parametrize('laplace', flavors)
@pytest.mark.parametrize('laplace', online_flavors)
def test_overriding_fit(laplace, model, reg_loader):
lap = laplace(model, 'regression', sigma_noise=0.3, prior_precision=0.7)
lap.fit(reg_loader)
@@ -270,7 +276,7 @@ def test_overriding_fit(laplace, model, reg_loader):
assert lap.n_data == len(reg_loader.dataset)


@pytest.mark.parametrize('laplace', flavors)
@pytest.mark.parametrize('laplace', online_flavors)
def test_online_fit(laplace, model, reg_loader):
lap = laplace(model, 'regression', sigma_noise=0.3, prior_precision=0.7)
lap.fit(reg_loader)
@@ -308,7 +314,6 @@ def test_log_prob_kron(model, class_loader):
posterior = Normal(loc=lap.mean, scale=sqrt(1/0.24))
assert torch.allclose(lap.log_prob(theta), posterior.log_prob(theta).sum())
lap.fit(class_loader)
print(type(lap.H), type(lap.H_facs), lap._H_factor)
posterior = MultivariateNormal(loc=lap.mean, precision_matrix=lap.posterior_precision.to_matrix())
assert torch.allclose(lap.log_prob(theta), posterior.log_prob(theta))
