From 2180b18abcdf6c5c097c00b0a0c1bb2082b660c0 Mon Sep 17 00:00:00 2001
From: chris-santiago
Date: Sun, 3 Dec 2023 22:11:55 -0500
Subject: [PATCH] add adversarial step

---
 met/conf/model/base.yaml |  18 +++++-
 met/models/met.py        | 121 ++++++++++++++++++++++-----------
 2 files changed, 84 insertions(+), 55 deletions(-)

diff --git a/met/conf/model/base.yaml b/met/conf/model/base.yaml
index 688b1c0..e930f25 100644
--- a/met/conf/model/base.yaml
+++ b/met/conf/model/base.yaml
@@ -4,11 +4,23 @@
 # See https://stackoverflow.com/questions/71438040/overwriting-hydra-configuration-groups-from-cli/71439510#71439510
 defaults:
   - /optimizer@optimizer: adam
-  - /scheduler@scheduler: cyclic
+  - /scheduler@scheduler: plateau
 
-name:
+name: MET
 nn:
-  _target_:
+  _target_: met.models.met.MET
+  num_embeddings: 784
+  embedding_dim: 64
+  p_mask: 0.70
+  n_head: 1
+  num_encoder_layers: 6
+  num_decoder_layers: 1
+  dim_feedforward: 64
+  dropout: 0.1
+  adver_steps: 2
+  lr_perturb: 0.0001
+  eps: 12
+  lam: 1.0
 
 loss_func:
   _target_: torch.nn.MSELoss
\ No newline at end of file
diff --git a/met/models/met.py b/met/models/met.py
index 884f285..e63894c 100644
--- a/met/models/met.py
+++ b/met/models/met.py
@@ -1,23 +1,12 @@
 from typing import Optional
 
+import numpy as np
 import torch
 from torch import nn as nn
 
 from met.models.base import BaseModule
 
 
-def mask_tensor(x, pct_mask: float = 0.7, dim=1):
-    n_row, n_col = x.shape
-    n_masked = int((pct_mask * 100 * n_col) // 100)
-    idx = torch.stack([torch.randperm(n_col) for _ in range(n_row)])
-    masked_idx, _ = idx[:, :n_masked].sort(dim=dim)
-    unmasked_idx, _ = idx[:, n_masked:].sort(dim=dim)
-    unmasked_x = torch.zeros(n_row, unmasked_idx.shape[dim])
-    for i in range(n_row):
-        unmasked_x += x[i][unmasked_idx[i]]
-    return unmasked_x, unmasked_idx, masked_idx
-
-
 class MET(BaseModule):
     def __init__(
         self,
@@ -29,6 +18,10 @@ def __init__(
         num_decoder_layers: int = 1,
         dim_feedforward: int = 64,
         dropout: float = 0.1,
+        adver_steps: int = 2,
+        lr_perturb: float = 1e-4,
+        eps: int = 12,
+        lam: float = 1.0,
         loss_func: nn.Module = nn.MSELoss(),
         optim: Optional[torch.optim.Optimizer] = None,
         scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
@@ -42,6 +35,12 @@ def __init__(
         self.num_decoder_layers = num_decoder_layers
         self.dim_feedforward = dim_feedforward
         self.dropout = dropout
+        self.adver_steps = adver_steps
+        self.lr_perturb = lr_perturb
+        self.eps = eps
+        self.lam = lam
+
+        self.automatic_optimization = False
 
         # Subtract 1 from desired embedding dim to account for token
         self.embedding = nn.Embedding(
@@ -54,61 +53,79 @@ def __init__(
             num_decoder_layers=num_decoder_layers,
             dim_feedforward=dim_feedforward,
             dropout=dropout,
+            batch_first=True,
         )
+        self.transformer_head = nn.Sequential(nn.Linear(embedding_dim, 1), nn.Flatten())
+        # self.transformer_head = nn.Linear(embedding_dim, 1)
+
+    def embed_inputs(self, x, idx):
+        # Transformer inputs are concat of original tokens and position embed
+        embd = self.embedding(idx)
+        return torch.concat([x.unsqueeze(-1), embd], dim=-1)
 
-        self.transformer_head = nn.Linear(embedding_dim, 1)
+    def forward(self, unmasked_x, unmasked_idx, masked_x, masked_idx):
+        unmasked_inputs = self.embed_inputs(unmasked_x, unmasked_idx)
+        masked_inputs = self.embed_inputs(masked_x, masked_idx)
+
+        # Recovering embeddings for all original features/cols (784 in MNIST example)
+        all_embd = torch.concat([unmasked_inputs, masked_inputs], dim=1)
+        outputs = self.transformer(unmasked_inputs, all_embd)
+        x_hat = self.transformer_head(outputs)
+        return x_hat
+
+    def adversarial_step(self, unmasked_x, unmasked_idx, masked_x, masked_idx, original):
+        opt = self.optimizers()
+        unmasked_x.retain_grad()
+        perturbed_recon = self(unmasked_x, unmasked_idx, masked_x, masked_idx)
+        perturbed_loss = -self.loss_func(original, perturbed_recon)  # grad ascent
+        opt.zero_grad()
+        self.manual_backward(perturbed_loss, retain_graph=True)
+        # !! Interesting-- if this is enabled it causes in-place errors w/grad calcs !!
+        # I believe it modifies the `recon_loss` grad that's done before this
+        # opt.step()
+
+        # Constrain h
+        h = unmasked_x + self.lr_perturb * (unmasked_x.grad / torch.norm(unmasked_x.grad))
+        alpha = (torch.norm(h) * int(torch.norm(h) < self.eps)) + (
+            self.eps * int(torch.norm(h) >= self.eps)
+        )
+        h_adj = alpha * (h / torch.norm(h))
+        return unmasked_x + h_adj
 
     def training_step(self, batch, idx):
-        original = batch[0]
+        opt = self.optimizers()
         # inputs: (batch, cols)
         # mask: (batch, rows, mask_cols)
         # embeds: (mask_cols, embed_dim)
         # new inputs: (batch, mask_cols, embed_dim)
-        x, unmasked_idx, masked_idx = mask_tensor(batch[0], pct_mask=self.p_mask)
+        unmasked_x, unmasked_idx, masked_x, masked_idx, original, _ = batch
 
-        # Transformer inputs are concat of original tokens and position embed
-        unmasked_embd = self.embedding(unmasked_idx)
-        unmasked_inputs = torch.concat([x.unsqueeze(-1), unmasked_embd], dim=-1)
-
-        # Need a constant tensor of masked inputs to learn embed params
-        masked_embd = self.embedding(masked_idx)
-        masked_inputs = torch.concat(
-            [torch.ones_like(masked_idx).unsqueeze(-1), masked_embd], dim=-1
-        )
-
-        # Recovering embeddings for all original features/cols (784 in MNIST example)
-        all_embd = torch.concat([unmasked_inputs, masked_inputs], dim=1)
-
-        outputs = self.transformer(unmasked_inputs, all_embd)
-        recon = self.transformer_head(outputs).squeeze()
+        # standard reconstruction loss
+        recon = self(unmasked_x, unmasked_idx, masked_x, masked_idx)
         recon_loss = self.loss_func(original, recon)
         self.log("recon-loss", recon_loss, on_step=True, on_epoch=True, prog_bar=True, logger=False)
 
-        # TODO adversarial loss (make for loop)
-        h = torch.normal(torch.zeros_like(x), torch.ones_like(x)) / torch.sqrt(
-            torch.tensor(x.shape[1])
-        )
-        # for _ in range(adver_steps):
-        perturbed = x + h
-        perturbed_inputs = torch.concat([perturbed.unsqueeze(-1), unmasked_embd], dim=-1)
-        perturbed_outputs = self.transformer(perturbed_inputs, all_embd)
-        perturbed_recon = self.transformer_head(perturbed_outputs).squeeze()
-        perturbed_loss = self.loss_func(original, perturbed_recon)
-        # TODO gradient ascent...but how? Separate optimizer? Can I just add gradient?
-        # TODO do I need a separate optimzer? (https://discuss.pytorch.org/t/pytorch-equivalant-of-tensorflow-gradienttape/74915)
-        # TODO https://discuss.pytorch.org/t/gradient-ascent-and-gradient-modification-modifying-optimizer-instead-of-grad-weight/62777/2
-        # TODO https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html#gradient-accumulation
-        h += perturbed_loss.grad
-        adver_loss = self.loss_func(original, perturbed_recon)
-        self.log("adver-loss", adver_loss, on_step=True, on_epoch=True, prog_bar=True, logger=False)
-
-        total_loss = recon_loss + adver_loss
+        # adversarial loss
+        h = torch.normal(0, 1, size=unmasked_x.shape, device=self.device, requires_grad=True)
+        perturbed_x = (unmasked_x.clone() + h) / np.sqrt(original.shape[-1])
+        for i in range(self.adver_steps):
+            h = self.adversarial_step(perturbed_x, unmasked_idx, masked_x, masked_idx, original)
+            perturbed_x = h.clone()
+        adv_recon = self(perturbed_x, unmasked_idx, masked_x, masked_idx)
+        adv_loss = self.loss_func(original, adv_recon)
+        self.log("adv-loss", adv_loss, on_step=True, on_epoch=True, prog_bar=True, logger=False)
+
+        total_loss = recon_loss + adv_loss * self.lam
+        opt.zero_grad()
+        self.manual_backward(total_loss)
+        opt.step()
        self.log("total-loss", total_loss, on_step=True, on_epoch=True, prog_bar=True, logger=False)
 
         metrics = {
             "reconstruction-loss": recon_loss,
-            "adversarial-loss": adver_loss,
-            "train-total-loss": total_loss,
+            "adversarial-loss": adv_loss,
+            "train-loss": total_loss,
         }
         self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=False, logger=True)
-        return total_loss
+
+        # return total_loss
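
The perturbation that `adversarial_step` builds is gradient ascent on the reconstruction loss, constrained to an L2 ball of radius `eps`. Below is a minimal standalone sketch of the same idea in plain PyTorch; `adversarial_perturb` and `model` are illustrative names, not part of the repo, and `model` stands in for the MET forward pass with the index tensors already bound (e.g. via a lambda or functools.partial). The patch itself runs the backward pass through Lightning's `manual_backward`, since `automatic_optimization` is disabled.

    import torch


    def adversarial_perturb(model, loss_func, x, original, steps=2, lr_perturb=1e-4, eps=12.0):
        """Norm-constrained gradient ascent on the reconstruction loss (sketch only)."""
        # Random start, scaled by sqrt(n_features) as in `training_step`.
        perturbed = (x + torch.randn_like(x)) / x.shape[-1] ** 0.5
        for _ in range(steps):
            perturbed = perturbed.detach().requires_grad_(True)
            loss = loss_func(original, model(perturbed))
            (grad,) = torch.autograd.grad(loss, perturbed)  # ascent direction on the input
            h = perturbed + lr_perturb * grad / torch.norm(grad)
            alpha = torch.clamp(torch.norm(h), max=eps)     # same role as `alpha` in the patch
            perturbed = perturbed + alpha * h / torch.norm(h)
        return perturbed.detach()

The returned tensor plays the role of `perturbed_x` in `training_step`: its reconstruction loss is weighted by `lam` and added to the standard reconstruction loss before the single manual optimizer step.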
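
On the config side, the updated `model/base.yaml` fills in the `nn` group so Hydra can build the module directly from `_target_`. A hypothetical sketch of loading it is shown here, assuming the usual layout where this group is composed under a `model` package by a primary config; the config name, paths, and the `loss_func` override are assumptions, not the repo's actual entry point.

    import hydra
    from omegaconf import DictConfig


    @hydra.main(config_path="met/conf", config_name="config", version_base=None)
    def main(cfg: DictConfig) -> None:
        # Build the loss first, then hand it to the model's constructor.
        loss = hydra.utils.instantiate(cfg.model.loss_func)            # torch.nn.MSELoss
        model = hydra.utils.instantiate(cfg.model.nn, loss_func=loss)  # met.models.met.MET
        print(model)


    if __name__ == "__main__":
        main()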