From 5ee4009db522e5994ade85ac149e2273fd2ffd78 Mon Sep 17 00:00:00 2001 From: Alex Immer Date: Fri, 10 Dec 2021 14:50:55 +0100 Subject: [PATCH] Integrate comment of passive aggressive reviewer --- laplace/baselaplace.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index 46d6d612..47449350 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -885,7 +885,11 @@ def fit(self, train_loader, override=True): X, _ = next(iter(train_loader)) with torch.no_grad(): - self.n_outputs = self.model(X[:1].to(self._device)).shape[-1] + try: + out = self.model(X[:1].to(self._device)) + except (TypeError, AttributeError): + out = self.model(X.to(self._device)) + self.n_outputs = out.shape[-1] setattr(self.model, 'output_size', self.n_outputs) eigenvectors, eigenvalues, loss = self.backend.eig_lowrank(train_loader)