
Commit

add adversarial step
chris-santiago committed Dec 4, 2023
1 parent b37a793 commit 2180b18
Showing 2 changed files with 84 additions and 55 deletions.
met/conf/model/base.yaml (18 changes: 15 additions & 3 deletions)
@@ -4,11 +4,23 @@
# See https://stackoverflow.com/questions/71438040/overwriting-hydra-configuration-groups-from-cli/71439510#71439510
defaults:
  - /optimizer@optimizer: adam
  - /scheduler@scheduler: cyclic
  - /scheduler@scheduler: plateau

name:
name: MET

nn:
  _target_:
  _target_: met.models.met.MET
  num_embeddings: 784
  embedding_dim: 64
  p_mask: 0.70
  n_head: 1
  num_encoder_layers: 6
  num_decoder_layers: 1
  dim_feedforward: 64
  dropout: 0.1
  adver_steps: 2
  lr_perturb: 0.0001
  eps: 12
  lam: 1.0
  loss_func:
    _target_: torch.nn.MSELoss
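
Editorial note: a minimal sketch of how a Hydra config like this one is typically composed and instantiated. The met.models.met.MET target and the field names come from the diff above; the config group name, root config name, and the cfg.model.nn path are assumptions for illustration, not part of the commit.

# Hedged sketch: compose the config and instantiate the model it describes.
# "model=base", config_name="config", and cfg.model.nn are assumed names.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path="met/conf", version_base=None):
    cfg = compose(config_name="config", overrides=["model=base"])
    # Recursive instantiation also builds the nested loss_func (torch.nn.MSELoss).
    model = instantiate(cfg.model.nn)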
met/models/met.py (121 changes: 69 additions & 52 deletions)
@@ -1,23 +1,12 @@
from typing import Optional

import numpy as np
import torch
from torch import nn as nn

from met.models.base import BaseModule


def mask_tensor(x, pct_mask: float = 0.7, dim=1):
n_row, n_col = x.shape
n_masked = int((pct_mask * 100 * n_col) // 100)
idx = torch.stack([torch.randperm(n_col) for _ in range(n_row)])
masked_idx, _ = idx[:, :n_masked].sort(dim=dim)
unmasked_idx, _ = idx[:, n_masked:].sort(dim=dim)
unmasked_x = torch.zeros(n_row, unmasked_idx.shape[dim])
for i in range(n_row):
unmasked_x += x[i][unmasked_idx[i]]
return unmasked_x, unmasked_idx, masked_idx
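
# Editorial sketch (not part of this commit): the mask_tensor helper removed above splits
# each row's columns into masked/unmasked index sets, but builds unmasked_x with `+=` in a
# Python loop, which sums rows rather than selecting values per row. A gather-based variant
# expresses the presumable intent; the name and signature here are illustrative only.
def mask_tensor_sketch(x: torch.Tensor, pct_mask: float = 0.7, dim: int = 1):
    n_row, n_col = x.shape
    n_masked = int(pct_mask * n_col)
    idx = torch.stack([torch.randperm(n_col) for _ in range(n_row)])
    masked_idx, _ = idx[:, :n_masked].sort(dim=dim)    # columns hidden from the encoder
    unmasked_idx, _ = idx[:, n_masked:].sort(dim=dim)  # columns the encoder sees
    unmasked_x = torch.gather(x, dim, unmasked_idx)    # per-row selection of visible values
    return unmasked_x, unmasked_idx, masked_idx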


class MET(BaseModule):
def __init__(
self,
@@ -29,6 +18,10 @@ def __init__(
num_decoder_layers: int = 1,
dim_feedforward: int = 64,
dropout: float = 0.1,
adver_steps: int = 2,
lr_perturb: float = 1e-4,
eps: int = 12,
lam: float = 1.0,
loss_func: nn.Module = nn.MSELoss(),
optim: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
@@ -42,6 +35,12 @@ def __init__(
self.num_decoder_layers = num_decoder_layers
self.dim_feedforward = dim_feedforward
self.dropout = dropout
self.adver_steps = adver_steps
self.lr_perturb = lr_perturb
self.eps = eps
self.lam = lam

self.automatic_optimization = False

# Subtract 1 from desired embedding dim to account for token
self.embedding = nn.Embedding(
@@ -54,61 +53,79 @@ def __init__(
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=True,
)
self.transformer_head = nn.Sequential(nn.Linear(embedding_dim, 1), nn.Flatten())
# self.transformer_head = nn.Linear(embedding_dim, 1)

def embed_inputs(self, x, idx):
# Transformer inputs are concat of original tokens and position embed
embd = self.embedding(idx)
return torch.concat([x.unsqueeze(-1), embd], dim=-1)

self.transformer_head = nn.Linear(embedding_dim, 1)
def forward(self, unmasked_x, unmasked_idx, masked_x, masked_idx):
unmasked_inputs = self.embed_inputs(unmasked_x, unmasked_idx)
masked_inputs = self.embed_inputs(masked_x, masked_idx)

# Recovering embeddings for all original features/cols (784 in MNIST example)
all_embd = torch.concat([unmasked_inputs, masked_inputs], dim=1)
outputs = self.transformer(unmasked_inputs, all_embd)
x_hat = self.transformer_head(outputs)
return x_hat
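
# Editorial shape sketch (illustrative; uses the MNIST-style sizes from base.yaml and an
# assumed batch size of 32). With p_mask=0.70, roughly 235 of 784 columns stay visible:
#   unmasked_x:   (32, 235)   raw values of the visible columns
#   unmasked_idx: (32, 235) -> nn.Embedding -> (32, 235, 63), per the "subtract 1" comment;
#                 concatenated with the value in embed_inputs -> (32, 235, 64)
#   masked_x/masked_idx give the ~549 hidden columns the same treatment -> (32, 549, 64)
#   all_embd:     (32, 784, 64)  visible + masked tokens, the decoder's target sequence
#   transformer(unmasked_inputs, all_embd) with batch_first=True -> (32, 784, 64)
#   transformer_head (Linear(64, 1) + Flatten) -> x_hat: (32, 784), one value per column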

def adversarial_step(self, unmasked_x, unmasked_idx, masked_x, masked_idx, original):
opt = self.optimizers()
unmasked_x.retain_grad()
perturbed_recon = self(unmasked_x, unmasked_idx, masked_x, masked_idx)
perturbed_loss = -self.loss_func(original, perturbed_recon) # grad ascent
opt.zero_grad()
self.manual_backward(perturbed_loss, retain_graph=True)
# !! Interesting-- if this is enabled it causes in-place errors w/grad calcs !!
# I believe it modifies the `recon_loss` grad that's done before this
# opt.step()

# Constrain h
h = unmasked_x + self.lr_perturb * (unmasked_x.grad / torch.norm(unmasked_x.grad))
alpha = (torch.norm(h) * int(torch.norm(h) < self.eps)) + (
self.eps * int(torch.norm(h) >= self.eps)
)
h_adj = alpha * (h / torch.norm(h))
return unmasked_x + h_adj
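
# Editorial note (illustrative, not part of the commit): the alpha expression above is an
# L2 clip, alpha = min(||h||, eps), so h_adj = h whenever ||h|| < eps and otherwise h is
# rescaled to norm eps. An equivalent one-line sketch:
#     h_adj = h * (torch.clamp(torch.norm(h), max=self.eps) / torch.norm(h))
# e.g. with eps = 12, a step with ||h|| = 30 is shrunk to norm 12, while ||h|| = 5 passes through.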

def training_step(self, batch, idx):
original = batch[0]
opt = self.optimizers()
# inputs: (batch, cols)
# mask: (batch, rows, mask_cols)
# embeds: (mask_cols, embed_dim)
# new inputs: (batch, mask_cols, embed_dim)
x, unmasked_idx, masked_idx = mask_tensor(batch[0], pct_mask=self.p_mask)
unmasked_x, unmasked_idx, masked_x, masked_idx, original, _ = batch

# Transformer inputs are concat of original tokens and position embed
unmasked_embd = self.embedding(unmasked_idx)
unmasked_inputs = torch.concat([x.unsqueeze(-1), unmasked_embd], dim=-1)

# Need a constant tensor of masked inputs to learn embed params
masked_embd = self.embedding(masked_idx)
masked_inputs = torch.concat(
[torch.ones_like(masked_idx).unsqueeze(-1), masked_embd], dim=-1
)

# Recovering embeddings for all original features/cols (784 in MNIST example)
all_embd = torch.concat([unmasked_inputs, masked_inputs], dim=1)

outputs = self.transformer(unmasked_inputs, all_embd)
recon = self.transformer_head(outputs).squeeze()
# standard reconstruction loss
recon = self(unmasked_x, unmasked_idx, masked_x, masked_idx)
recon_loss = self.loss_func(original, recon)
self.log("recon-loss", recon_loss, on_step=True, on_epoch=True, prog_bar=True, logger=False)

# TODO adversarial loss (make for loop)
h = torch.normal(torch.zeros_like(x), torch.ones_like(x)) / torch.sqrt(
torch.tensor(x.shape[1])
)
# for _ in range(adver_steps):
perturbed = x + h
perturbed_inputs = torch.concat([perturbed.unsqueeze(-1), unmasked_embd], dim=-1)
perturbed_outputs = self.transformer(perturbed_inputs, all_embd)
perturbed_recon = self.transformer_head(perturbed_outputs).squeeze()
perturbed_loss = self.loss_func(original, perturbed_recon)
# TODO gradient ascent...but how? Separate optimizer? Can I just add gradient?
# TODO do I need a separate optimizer? (https://discuss.pytorch.org/t/pytorch-equivalant-of-tensorflow-gradienttape/74915)
# TODO https://discuss.pytorch.org/t/gradient-ascent-and-gradient-modification-modifying-optimizer-instead-of-grad-weight/62777/2
# TODO https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html#gradient-accumulation
h += perturbed_loss.grad
adver_loss = self.loss_func(original, perturbed_recon)
self.log("adver-loss", adver_loss, on_step=True, on_epoch=True, prog_bar=True, logger=False)

total_loss = recon_loss + adver_loss
# adversarial loss
h = torch.normal(0, 1, size=unmasked_x.shape, device=self.device, requires_grad=True)
perturbed_x = (unmasked_x.clone() + h) / np.sqrt(original.shape[-1])
for i in range(self.adver_steps):
h = self.adversarial_step(perturbed_x, unmasked_idx, masked_x, masked_idx, original)
perturbed_x = h.clone()
adv_recon = self(perturbed_x, unmasked_idx, masked_x, masked_idx)
adv_loss = self.loss_func(original, adv_recon)
self.log("adv-loss", adv_loss, on_step=True, on_epoch=True, prog_bar=True, logger=False)

total_loss = recon_loss + adv_loss * self.lam
opt.zero_grad()
self.manual_backward(total_loss)
opt.step()
self.log("total-loss", total_loss, on_step=True, on_epoch=True, prog_bar=True, logger=False)

metrics = {
"reconstruction-loss": recon_loss,
"adversarial-loss": adver_loss,
"train-total-loss": total_loss,
"adversarial-loss": adv_loss,
"train-loss": total_loss,
}
self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=False, logger=True)
return total_loss

# return total_loss
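
Editorial note: training_step now unpacks a six-element batch (unmasked_x, unmasked_idx, masked_x, masked_idx, original, and an unused element) instead of masking inside the step, and, because automatic_optimization is False, it drives opt.zero_grad(), self.manual_backward(), and opt.step() itself so the adversarial inner loop can run its own backward pass first. The dataset side is not part of this diff; below is a minimal per-sample sketch of how such a batch tuple could be produced (the helper name, the constant placeholder for masked values, and the dummy label are assumptions):

# Hedged sketch of a masking transform that yields the tuple training_step expects;
# not the repository's actual data pipeline.
def mask_sample(row: torch.Tensor, p_mask: float = 0.70):
    n_col = row.shape[-1]
    n_masked = int(p_mask * n_col)
    perm = torch.randperm(n_col)
    masked_idx, _ = perm[:n_masked].sort()
    unmasked_idx, _ = perm[n_masked:].sort()
    unmasked_x = row[unmasked_idx]          # visible values, aligned with unmasked_idx
    masked_x = torch.ones(n_masked)         # constant placeholder tokens for hidden columns
    label = torch.tensor(0)                 # unused by training_step (assumed placeholder)
    return unmasked_x, unmasked_idx, masked_x, masked_idx, row, label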
