diff --git a/src/optim/FishLeg/fishleg.py b/src/optim/FishLeg/fishleg.py
index 1f049e7..8220f35 100644
--- a/src/optim/FishLeg/fishleg.py
+++ b/src/optim/FishLeg/fishleg.py
@@ -155,7 +155,7 @@ def __setstate__(self, state):
         for s in state_values:
             s["step"] = torch.tensor(float(s["step"]))
 
-    def update_aux(self) -> None:
+    def update_aux(self, log_results=True, u_sample_overwrite=False) -> None:
         """
         Performs a single auxliarary parameter update using Adam. By minimizing
         the following objective:
@@ -175,7 +175,10 @@ def update_aux(self) -> None:
             method = group["method"]
             method_kwargs = group["method_kwargs"]
             precondition_aux = group["precondition_aux"]
-            u_sampling = group["u_sampling"]
+            if u_sample_overwrite:
+                u_sampling = u_sample_overwrite
+            else:
+                u_sampling = group["u_sampling"]
 
         pred_y = self.model(data_x)
         with torch.no_grad():
@@ -185,7 +188,8 @@ def sample_u(g: torch.Tensor) -> torch.Tensor:
             if u_sampling == "gradient":
                 return g
             elif u_sampling == "gaussian":
-                return torch.randn(size=g.shape)
+                u = torch.randn(size=g.shape)
+                return u.to(g.device)
             else:
                 raise NotImplementedError(
                     f"{u_sampling} method of sampling u not implemented yet!"
                 )
@@ -210,7 +214,7 @@ def sample_u(g: torch.Tensor) -> torch.Tensor:
                 u_model.append(u)
 
         u_norm = torch.sqrt(u2)
-        v_norm = torch.sqrt(v2) * u_norm
+        v_norm = torch.sqrt(v2)
 
         # the different methods differ in how they compute Fv_norm
@@ -284,7 +288,7 @@ def _augment_params_by(eps: float):
         v_adj = list(
             map(
                 lambda Fv, v, u: (
-                    (Fv + float(damping) * v.detach() / v_norm) - u / v_norm
+                    (Fv + damping * v.detach() / v_norm) - u / (u_norm * v_norm)
                 ),
                 Fv_norm,
                 v_model,
@@ -320,7 +324,7 @@ def _augment_params_by(eps: float):
         if not precondition_aux:
             aux_loss = surrogate_loss.item()
 
-        if self.writer:
+        if self.writer and log_results:
             self.writer.add_scalar(
                 "AuxLoss/train",
                 aux_loss,
@@ -340,6 +344,24 @@ def _augment_params_by(eps: float):
         self.aux_opt.step()
 
         return surrogate_loss.item()
+    def pretrain_fish(
+        self, iterations: int, pretrain_writer: SummaryWriter or None = None
+    ) -> List:
+        pretrain_losses = []
+        for pre_step in range(1, iterations + 1):
+            loss = self.update_aux(log_results=False, u_sample_overwrite="gaussian")
+
+            pretrain_losses.append(loss)
+
+            if pretrain_writer:
+                pretrain_writer.add_scalar(
+                    "SurrogateLoss/pretrain",
+                    loss,
+                    pre_step,
+                )
+
+        return pretrain_losses
+
     def _init_group(
         self,
         group,
diff --git a/src/optim/FishLeg/layers/fish_batchNorm2d.py b/src/optim/FishLeg/layers/fish_batchNorm2d.py
index 18a7863..a626310 100644
--- a/src/optim/FishLeg/layers/fish_batchNorm2d.py
+++ b/src/optim/FishLeg/layers/fish_batchNorm2d.py
@@ -15,6 +15,7 @@ def __init__(
         momentum: float = 0.1,
         affine: bool = True,
         track_running_stats: bool = True,
+        init_scale: float = 1.0,
         device=None,
         dtype=None,
     ) -> None:
@@ -26,21 +27,21 @@ def __init__(
         self.fishleg_aux = ParameterDict(
             {
                 "L_w": FishAuxParameter(
-                    torch.ones(
-                        (num_features,), device=device
-                    )  # * np.sqrt(init_scale) # TODO: CHECK
+                    torch.ones((num_features,), device=device).mul_(
+                        np.sqrt(init_scale)
+                    )  # TODO: CHECK
                 ),
                 "L_b": FishAuxParameter(
-                    torch.ones(
-                        (num_features,), device=device
-                    )  # * np.sqrt(init_scale)
+                    torch.ones((num_features,), device=device).mul_(
+                        np.sqrt(init_scale)
+                    )
                 ),
             }
         )
 
         self.order = ["weight", "bias"]
 
-    def Qv(self, v: Tuple, full=False):
+    def Qv(self, v: Tuple):
         return (
             torch.square(self.fishleg_aux["L_w"]) * v[0],
             torch.square(self.fishleg_aux["L_b"]) * v[1],
diff --git a/src/optim/FishLeg/layers/fish_conv2d.py b/src/optim/FishLeg/layers/fish_conv2d.py
index 8277b67..2376f67 100644
--- a/src/optim/FishLeg/layers/fish_conv2d.py
+++ b/src/optim/FishLeg/layers/fish_conv2d.py
@@ -116,7 +116,7 @@ def diagQ(self) -> Tensor:
         R = self.fishleg_aux["R"]
         A = self.fishleg_aux["A"]
 
-        diagA = torch.square(torch.reshape(A.T, (-1)))
+        diagA = torch.square(torch.reshape(A.T, (-1,)))
         diag = diagA * torch.kron(
             torch.sum(torch.square(L), dim=0), torch.sum(torch.square(R), dim=0)
         )
diff --git a/src/optim/FishLeg/layers/fish_linear.py b/src/optim/FishLeg/layers/fish_linear.py
index 497a4ac..af50e73 100644
--- a/src/optim/FishLeg/layers/fish_linear.py
+++ b/src/optim/FishLeg/layers/fish_linear.py
@@ -102,4 +102,4 @@ def diagQ(self) -> Tuple:
         diag = diag * torch.square(self.fishleg_aux["A"].T).reshape(-1)
         diag = diag.reshape(L.shape[0], R.shape[0]).T
 
-        return (diag[:, :-1], diag[:, -1]) if self.bias else (diag,)
+        return (diag[:, :-1], diag[:, -1]) if self.bias is not None else (diag,)
diff --git a/src/optim/FishLeg/likelihoods/bernoulli_likelihood.py b/src/optim/FishLeg/likelihoods/bernoulli_likelihood.py
index 2195005..3ecda91 100644
--- a/src/optim/FishLeg/likelihoods/bernoulli_likelihood.py
+++ b/src/optim/FishLeg/likelihoods/bernoulli_likelihood.py
@@ -22,8 +22,8 @@ class BernoulliLikelihood(FishLikelihoodBase):
     def __init__(self, device: str = "cpu") -> None:
         self.device = device
 
-    def nll(self, preds: torch.Tensor, observations: torch.Tensor) -> torch.Tensor:
-        bce = torch.sum(preds * (1.0 - observations) + torch.nn.Softplus()(-preds))
+    def nll(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        bce = torch.sum(preds * (1.0 - targets) + torch.nn.Softplus()(-preds))
         return bce / preds.shape[0]
 
     def draw(self, preds: torch.Tensor) -> torch.Tensor:
diff --git a/src/optim/FishLeg/likelihoods/softmax_likelihood.py b/src/optim/FishLeg/likelihoods/softmax_likelihood.py
index 7ff5ea2..177f58c 100644
--- a/src/optim/FishLeg/likelihoods/softmax_likelihood.py
+++ b/src/optim/FishLeg/likelihoods/softmax_likelihood.py
@@ -15,10 +15,10 @@ def __init__(self, device: str = "cpu") -> None:
 
     def nll(sef, preds: torch.Tensor, observations: torch.Tensor) -> torch.Tensor:
         logits = log_softmax(preds, dim=1)
-        return -torch.mean(torch.sum(logits * observations, dim=1))
+        obs_one_hot = one_hot(observations, num_classes=logits.shape[-1])
+        return -torch.mean(torch.sum(logits * obs_one_hot, dim=1))
 
     def draw(self, preds: torch.Tensor) -> torch.Tensor:
-        # logits = torch.log(preds)
         logits = log_softmax(preds, dim=1)
         dense = Categorical(logits=logits).sample()
-        return one_hot(dense, num_classes=logits.shape[-1])
+        return dense
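Illustrative note on the softmax likelihood change above (a sketch, not code from the repository): nll() now takes dense integer class labels and one-hot encodes them internally, and draw() now returns dense indices, so samples drawn from the model can be passed straight back into nll(). The tensors below are made up for demonstration; only the nll formula mirrors the patched code.

    import torch
    from torch.nn.functional import cross_entropy, log_softmax, one_hot

    preds = torch.randn(4, 10)           # raw logits: batch of 4, 10 classes
    labels = torch.tensor([3, 0, 7, 1])  # dense class indices (new nll() input)

    # What the patched SoftmaxLikelihood.nll computes:
    logits = log_softmax(preds, dim=1)
    obs_one_hot = one_hot(labels, num_classes=logits.shape[-1])
    nll = -torch.mean(torch.sum(logits * obs_one_hot, dim=1))

    # Equivalent to standard cross-entropy on raw logits with mean reduction.
    assert torch.allclose(nll, cross_entropy(preds, labels))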