Skip to content

Commit

Permalink
Enable regression with AsdlHessian
Browse files Browse the repository at this point in the history
  • Loading branch information
aleximmer committed Dec 10, 2021
1 parent 9e2fed6 commit d37f26b
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions laplace/curvature/asdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,6 @@
class AsdlInterface(CurvatureInterface):
"""Interface for asdfghjkl backend.
"""
def __init__(self, model, likelihood, last_layer=False, low_rank=10):
if likelihood != 'classification':
raise ValueError('This backend only supports classification currently.')
self.low_rank = low_rank
super().__init__(model, likelihood, last_layer)

@staticmethod
def jacobians(model, x):
Expand Down Expand Up @@ -137,6 +132,10 @@ def kron(self, X, y, N, **wkwargs) -> [torch.Tensor, Kron]:

class AsdlHessian(AsdlInterface):

def __init__(self, model, likelihood, last_layer=False, low_rank=10):
super().__init__(model, likelihood, last_layer)
self.low_rank = low_rank

@property
def _ggn_type(self):
raise NotImplementedError()
Expand Down Expand Up @@ -164,6 +163,8 @@ class AsdlGGN(AsdlInterface, GGNInterface):
"""Implementation of the `GGNInterface` using asdfghjkl.
"""
def __init__(self, model, likelihood, last_layer=False, stochastic=False):
if likelihood != 'classification':
raise ValueError('This backend only supports classification currently.')
super().__init__(model, likelihood, last_layer)
self.stochastic = stochastic

Expand All @@ -175,6 +176,10 @@ def _ggn_type(self):
class AsdlEF(AsdlInterface, EFInterface):
"""Implementation of the `EFInterface` using asdfghjkl.
"""
def __init__(self, model, likelihood, last_layer=False):
if likelihood != 'classification':
raise ValueError('This backend only supports classification currently.')
super().__init__(model, likelihood, last_layer)

@property
def _ggn_type(self):
Expand Down

0 comments on commit d37f26b

Please sign in to comment.