Add feature/pretrain fish #29

Open · wants to merge 6 commits into main
34 changes: 28 additions & 6 deletions src/optim/FishLeg/fishleg.py
@@ -155,7 +155,7 @@ def __setstate__(self, state):
         for s in state_values:
             s["step"] = torch.tensor(float(s["step"]))

-    def update_aux(self) -> None:
+    def update_aux(self, log_results=True, u_sample_overwrite=False) -> None:
         """
         Performs a single auxiliary parameter update
         using Adam, by minimizing the following objective:
@@ -175,7 +175,10 @@ def update_aux(self) -> None:
             method = group["method"]
             method_kwargs = group["method_kwargs"]
             precondition_aux = group["precondition_aux"]
-            u_sampling = group["u_sampling"]
+            if u_sample_overwrite:
+                u_sampling = u_sample_overwrite
+            else:
+                u_sampling = group["u_sampling"]

             pred_y = self.model(data_x)
             with torch.no_grad():
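A note on the overwrite semantics: `u_sample_overwrite` defaults to `False` and is tested for truthiness, so any non-empty string such as "gaussian" wins over the per-group setting. A minimal standalone sketch of that logic (function name assumed, not part of the PR):

    # Sketch of the overwrite semantics as implemented above: a falsy
    # overwrite (the default False) falls back to the per-group setting.
    def select_u_sampling(group_value, overwrite=False):
        return overwrite if overwrite else group_value

    assert select_u_sampling("gradient") == "gradient"
    assert select_u_sampling("gradient", "gaussian") == "gaussian"

Using `None` as the sentinel would be the more conventional Python idiom, but `False` behaves identically here since the only meaningful overwrite values are non-empty strings.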
@@ -185,7 +188,8 @@ def sample_u(g: torch.Tensor) -> torch.Tensor:
                 if u_sampling == "gradient":
                     return g
                 elif u_sampling == "gaussian":
-                    return torch.randn(size=g.shape)
+                    u = torch.randn(size=g.shape)
+                    return u.to(g.device)
                 else:
                     raise NotImplementedError(
                         f"{u_sampling} method of sampling u not implemented yet!"
@@ -210,7 +214,7 @@ def sample_u(g: torch.Tensor) -> torch.Tensor:
                 u_model.append(u)

             u_norm = torch.sqrt(u2)
-            v_norm = torch.sqrt(v2) * u_norm
+            v_norm = torch.sqrt(v2)

             # the different methods differ in how they compute Fv_norm

@@ -284,7 +288,7 @@ def _augment_params_by(eps: float):
             v_adj = list(
                 map(
                     lambda Fv, v, u: (
-                        (Fv + float(damping) * v.detach() / v_norm) - u / v_norm
+                        (Fv + damping * v.detach() / v_norm) - u / (u_norm * v_norm)
                     ),
                     Fv_norm,
                     v_model,
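Reading this hunk together with the `v_norm` change above, if I follow the algebra correctly: the correction term on u is unchanged, u / (‖u‖ ‖v‖) in both the old factored form (where `v_norm` carried the extra factor ‖u‖) and the new explicit form, while the damping term becomes λ v / ‖v‖ instead of λ v / (‖u‖ ‖v‖), i.e. v is now normalised by its own norm alone. The `float(damping)` cast was redundant and is dropped.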
@@ -320,7 +324,7 @@ def _augment_params_by(eps: float):
             if not precondition_aux:
                 aux_loss = surrogate_loss.item()

-            if self.writer:
+            if self.writer and log_results:
                 self.writer.add_scalar(
                     "AuxLoss/train",
                     aux_loss,
@@ -340,6 +344,24 @@ def _augment_params_by(eps: float):
         self.aux_opt.step()
         return surrogate_loss.item()

+    def pretrain_fish(
+        self, iterations: int, pretrain_writer: Optional[SummaryWriter] = None
+    ) -> List:
+        pretrain_losses = []
+        for pre_step in range(1, iterations + 1):
+            loss = self.update_aux(log_results=False, u_sample_overwrite="gaussian")
+
+            pretrain_losses.append(loss)
+
+            if pretrain_writer:
+                pretrain_writer.add_scalar(
+                    "SurrogateLoss/pretrain",
+                    loss,
+                    pre_step,
+                )
+
+        return pretrain_losses
+
     def _init_group(
         self,
         group,
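For context, a hypothetical usage sketch of the new method; the FishLeg constructor arguments, dataloader, and run directory are assumptions, not part of this PR:

    # Hypothetical usage: pretrain the auxiliary parameters with Gaussian
    # u-samples before the main optimisation loop, logging to TensorBoard.
    from torch.utils.tensorboard import SummaryWriter

    opt = FishLeg(model, aux_loader, likelihood, lr=0.02)  # assumed signature
    writer = SummaryWriter("runs/fishleg_pretrain")
    losses = opt.pretrain_fish(500, pretrain_writer=writer)
    print(f"surrogate loss: {losses[0]:.3f} -> {losses[-1]:.3f}")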
15 changes: 8 additions & 7 deletions src/optim/FishLeg/layers/fish_batchNorm2d.py
@@ -15,6 +15,7 @@ def __init__(
         momentum: float = 0.1,
         affine: bool = True,
         track_running_stats: bool = True,
+        init_scale: float = 1.0,
         device=None,
         dtype=None,
     ) -> None:
@@ -26,21 +27,21 @@ def __init__(
         self.fishleg_aux = ParameterDict(
             {
                 "L_w": FishAuxParameter(
-                    torch.ones(
-                        (num_features,), device=device
-                    )  # * np.sqrt(init_scale) # TODO: CHECK
+                    torch.ones((num_features,), device=device).mul_(
+                        np.sqrt(init_scale)
+                    )  # TODO: CHECK
                 ),
                 "L_b": FishAuxParameter(
-                    torch.ones(
-                        (num_features,), device=device
-                    )  # * np.sqrt(init_scale)
+                    torch.ones((num_features,), device=device).mul_(
+                        np.sqrt(init_scale)
+                    )
                 ),
             }
         )

         self.order = ["weight", "bias"]

-    def Qv(self, v: Tuple, full=False):
+    def Qv(self, v: Tuple):
         return (
             torch.square(self.fishleg_aux["L_w"]) * v[0],
             torch.square(self.fishleg_aux["L_b"]) * v[1],
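A quick sanity check on what `init_scale` does here, reduced to the diagonal Qv defined just above (standalone sketch):

    import torch

    # Sketch: with L_w initialised to sqrt(init_scale) * ones, the diagonal
    # preconditioner Qv = L_w^2 * v starts out as init_scale * v.
    init_scale, num_features = 4.0, 8
    L_w = torch.ones(num_features).mul_(init_scale ** 0.5)
    v = torch.randn(num_features)
    assert torch.allclose(torch.square(L_w) * v, init_scale * v)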
2 changes: 1 addition & 1 deletion src/optim/FishLeg/layers/fish_conv2d.py
@@ -116,7 +116,7 @@ def diagQ(self) -> Tensor:
         R = self.fishleg_aux["R"]
         A = self.fishleg_aux["A"]

-        diagA = torch.square(torch.reshape(A.T, (-1)))
+        diagA = torch.square(torch.reshape(A.T, (-1,)))
         diag = diagA * torch.kron(
             torch.sum(torch.square(L), dim=0), torch.sum(torch.square(R), dim=0)
         )
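The one-character change is the whole fix: parentheses alone do not make a tuple in Python, so the old call passed a bare int where `torch.reshape` expects a shape tuple.

    # (-1) is the integer -1; (-1,) is the intended one-element shape tuple.
    assert (-1) == -1
    assert (-1,) == tuple([-1])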
2 changes: 1 addition & 1 deletion src/optim/FishLeg/layers/fish_linear.py
@@ -102,4 +102,4 @@ def diagQ(self) -> Tuple:
         diag = diag * torch.square(self.fishleg_aux["A"].T).reshape(-1)

         diag = diag.reshape(L.shape[0], R.shape[0]).T
-        return (diag[:, :-1], diag[:, -1]) if self.bias else (diag,)
+        return (diag[:, :-1], diag[:, -1]) if self.bias is not None else (diag,)
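This matters because `if self.bias` invokes `Tensor.__bool__`, which raises for tensors with more than one element, and for a single-element bias it would not raise but would silently treat a zero bias as absent. A minimal reproduction (hypothetical, outside the layer):

    import torch

    bias = torch.zeros(3)      # e.g. a registered bias parameter
    try:
        bool(bias)             # what `if self.bias:` does under the hood
    except RuntimeError as err:
        print(err)             # "Boolean value of Tensor with more than one element is ambiguous"
    print(bias is not None)    # True: the unambiguous presence check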
4 changes: 2 additions & 2 deletions src/optim/FishLeg/likelihoods/bernoulli_likelihood.py
@@ -22,8 +22,8 @@ class BernoulliLikelihood(FishLikelihoodBase):
     def __init__(self, device: str = "cpu") -> None:
         self.device = device

-    def nll(self, preds: torch.Tensor, observations: torch.Tensor) -> torch.Tensor:
-        bce = torch.sum(preds * (1.0 - observations) + torch.nn.Softplus()(-preds))
+    def nll(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        bce = torch.sum(preds * (1.0 - targets) + torch.nn.Softplus()(-preds))
         return bce / preds.shape[0]

     def draw(self, preds: torch.Tensor) -> torch.Tensor:
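The rename does not change the maths: since softplus(x) = x + softplus(-x), the expression `preds * (1 - targets) + softplus(-preds)` is exactly binary cross-entropy with logits, summed over elements. A quick equivalence check (standalone sketch):

    import torch
    import torch.nn.functional as F

    # Sketch: the softplus form used by nll() matches BCE-with-logits.
    preds = torch.randn(16, 1)                  # logits
    targets = torch.randint(0, 2, (16, 1)).float()
    nll = torch.sum(preds * (1.0 - targets) + F.softplus(-preds)) / preds.shape[0]
    ref = F.binary_cross_entropy_with_logits(preds, targets, reduction="sum") / preds.shape[0]
    assert torch.allclose(nll, ref, atol=1e-6)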
6 changes: 3 additions & 3 deletions src/optim/FishLeg/likelihoods/softmax_likelihood.py
@@ -15,10 +15,10 @@ def __init__(self, device: str = "cpu") -> None:

     def nll(self, preds: torch.Tensor, observations: torch.Tensor) -> torch.Tensor:
         logits = log_softmax(preds, dim=1)
-        return -torch.mean(torch.sum(logits * observations, dim=1))
+        obs_one_hot = one_hot(observations, num_classes=logits.shape[-1])
+        return -torch.mean(torch.sum(logits * obs_one_hot, dim=1))

     def draw(self, preds: torch.Tensor) -> torch.Tensor:
-        # logits = torch.log(preds)
         logits = log_softmax(preds, dim=1)
         dense = Categorical(logits=logits).sample()
-        return one_hot(dense, num_classes=logits.shape[-1])
+        return dense
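With observations now passed as integer class indices rather than pre-one-hot vectors, `nll` reduces to standard cross-entropy, and `draw` returns labels in the same index format. A consistency sketch (standalone, shapes assumed):

    import torch
    import torch.nn.functional as F

    # Sketch: the one-hot NLL above agrees with F.cross_entropy on indices.
    preds = torch.randn(8, 5)                    # logits over 5 classes
    obs = torch.randint(0, 5, (8,))              # integer class labels
    logits = F.log_softmax(preds, dim=1)
    obs_one_hot = F.one_hot(obs, num_classes=5)
    nll = -torch.mean(torch.sum(logits * obs_one_hot, dim=1))
    assert torch.allclose(nll, F.cross_entropy(preds, obs))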