import torch
import wandb
import torch.amp as amp
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Tuple, Dict
from attrs import define
from pathlib import Path
from source import console
from rich.progress import Progress
from source.model import workspace
from source.interface import SAE
from source.dataset.msMarco import MsMarcoDataset
from source.embedding.bgeBase import BgeBaseEmbedding as Embedding


class Model(nn.Module, SAE):
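    """
    A TopK sparse autoencoder: the encoder expands `features` inputs to
    `features * expandBy` latents, only the `activate` largest latents are
    kept per example, and the decoder reconstructs the input from them.
    """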

    def __init__(self, features: int, expandBy: int) -> None:
        super().__init__()
        self.encoder = nn.Linear(features, features * expandBy)
        self.decoder = nn.Linear(features * expandBy, features)

    def forwardEncoder(self, x: Tensor, activate: int) -> Tensor:
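        # Subtract the decoder bias so the encoder sees a centered input
        # (a common pre-bias step in TopK SAEs), keep only the `activate`
        # largest pre-activations per row, and scatter their ReLU values
        # into an otherwise-zero code vector.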
        xbar = x - self.decoder.bias
        a = self.encoder.forward(xbar)
        pack = torch.topk(a, activate)
        f = torch.zeros_like(a)
        f.scatter_(1, pack.indices, F.relu(pack.values))
        return f

    def forwardDecoder(self, f: Tensor) -> Tensor:
        xhat = self.decoder.forward(f)
        return xhat

    def forward(self, x: Tensor, activate: int) -> Tuple[Tensor, Tensor]:
        f = self.forwardEncoder(x, activate)
        xhat = self.forwardDecoder(f)
        return f, xhat


@define
class HyperParams:
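    """
    features: width of the input embedding (768 for BGE-base).
    expandBy: dictionary expansion factor of the SAE.
    activate: number of latents kept by the TopK activation.
    relevant: documents paired with each query for the similarity (KLD) loss.
    """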
    features: int
    expandBy: int
    activate: int
    relevant: int


@define
class TrainParams:
    batchSize: int
    numEpochs: int
    learnRate: float


class Trainer:
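    """
    Trains the SAE on MS MARCO query and document embeddings with a
    reconstruction (MSE) loss plus a KL-divergence loss that preserves
    query-document similarity, logging metrics to wandb and keeping the
    best snapshots on disk.
    """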

    hyperParams = HyperParams(features=768, expandBy=256, activate=32, relevant=8)
    trainParams = TrainParams(batchSize=512, numEpochs=128, learnRate=1e-3)

    def __init__(self) -> None:
        self.name = Path(__file__).stem
        self.workDir = Path(workspace, self.name, "snapshot")
        self.workDir.parent.mkdir(mode=0o770, exist_ok=True)
        self.workDir.mkdir(mode=0o770, exist_ok=True)
        self.dataset = MsMarcoDataset()
        expandBy = self.hyperParams.expandBy
        self.model = nn.DataParallel(Model(Embedding.size, expandBy).cuda())
        learnRate = self.trainParams.learnRate
        self.optimizer = optim.Adam(self.model.parameters(), lr=learnRate)
        self.scaler = amp.GradScaler()
        numEpochs = self.trainParams.numEpochs
        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=numEpochs)
        wandb.init(project="interpret", entity="haok", name=self.name)
        wandb.save("source/**/*.py", policy="now")

    def trainLoss(
        self, qry: Tensor, docs: Tensor, qryHat: Tensor, docsHat: Tensor
    ) -> Dict[str, Tensor]:
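        # MSE: how well the query and document embeddings are reconstructed.
        # KLD: per query, compares the softmax-like distribution over its
        # documents induced by dot-product similarity before and after
        # reconstruction, so the relative document similarities are preserved.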
        loss = dict()
        loss["Train.MSE"] = torch.tensor(0.0, requires_grad=True)
        loss["Train.MSE"] = loss["Train.MSE"] + F.mse_loss(qryHat, qry)
        loss["Train.MSE"] = loss["Train.MSE"] + F.mse_loss(docsHat, docs)
        loss["Train.KLD"] = torch.tensor(0.0, requires_grad=True)
        buf = torch.exp(
            torch.matmul(
                qry.unsqueeze(1),
                docs.transpose(1, 2),
            ).squeeze(1)
        )
        bufSum = buf.sum(dim=1)
        bufHat = torch.exp(
            torch.matmul(
                qryHat.unsqueeze(1),
                docsHat.transpose(1, 2),
            ).squeeze(1)
        )
        bufHatSum = bufHat.sum(dim=1)
        for i in range(qry.size(0)):
            for j in range(docs.size(1)):
                tar = buf[i, j] / (buf[i, j] + bufSum[i])
                ins = torch.log(bufHat[i, j] / (bufHat[i, j] + bufHatSum[i]))
                loss["Train.KLD"] = loss["Train.KLD"] + F.kl_div(
                    ins, tar, reduction="batchmean"
                )
        return loss

    def trainStep(self, qry: Tensor, docs: Tensor) -> Dict[str, Tensor]:
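        # One optimization step: forward the query batch and the flattened
        # document batch under autocast, then backprop the summed losses
        # through the AMP GradScaler.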
        self.optimizer.zero_grad()
        qry = qry.to(self.model.device_ids[0])
        docs = docs.to(self.model.device_ids[0])
        activate = self.hyperParams.activate
        with amp.autocast("cuda"):
            _, qryHat = self.model.forward(qry, activate)
            _, docsHat = self.model.forward(docs.view(-1, docs.size(-1)), activate)
            assert isinstance(qryHat, Tensor) and isinstance(docsHat, Tensor)
            loss = self.trainLoss(qry, docs, qryHat, docsHat.view(docs.size()))
        self.scaler.scale(sum(loss.values())).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        return loss

    def trainIter(self, i: int) -> Dict[str, float]:
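        # One training epoch over the "Train" mix: accumulates each loss term,
        # averages over full batches, and advances the cosine LR schedule once
        # per epoch.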
        iterLoss = dict()
        relevant = self.hyperParams.relevant
        batchSize = self.trainParams.batchSize
        numEpochs = self.trainParams.numEpochs
        self.model.train()
        with Progress(console=console) as progress:
            T = progress.add_task(
                f"[{i:>3}/{numEpochs}]", total=self.dataset.getMixLen("Train")
            )
            for qry, docs in self.dataset.mixEmbIter(
                Embedding, "Train", relevant, batchSize, 8, True
            ):
                loss = self.trainStep(qry, docs)
                progress.advance(T, qry.size(0))
                for key, val in loss.items():
                    iterLoss[key] = iterLoss.get(key, 0.0) + val.item()
            progress.stop_task(T)
        numBatches = self.dataset.getMixLen("Train") // batchSize
        for key in iterLoss.keys():
            iterLoss[key] /= numBatches
        self.scheduler.step()
        return iterLoss

    def validateLoss(
        self, qry: Tensor, docs: Tensor, qryHat: Tensor, docsHat: Tensor
    ) -> Dict[str, Tensor]:
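        # Mirrors trainLoss, but without gradient bookkeeping and with the
        # per-query KLD computed in one vectorized call.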
        loss = dict()
        loss["Validate.MSE"] = torch.tensor(0.0)
        loss["Validate.MSE"] = loss["Validate.MSE"] + F.mse_loss(qryHat, qry)
        loss["Validate.MSE"] = loss["Validate.MSE"] + F.mse_loss(docsHat, docs)
        loss["Validate.KLD"] = torch.tensor(0.0)
        buf = torch.exp(
            torch.matmul(
                qry.unsqueeze(1),
                docs.transpose(1, 2),
            ).squeeze(1)
        )
        bufSum = buf.sum(dim=1)
        bufHat = torch.exp(
            torch.matmul(
                qryHat.unsqueeze(1),
                docsHat.transpose(1, 2),
            ).squeeze(1)
        )
        bufHatSum = bufHat.sum(dim=1)
        # compute KLD in a vectorized manner; unsqueeze the per-query sums so
        # they broadcast over the document axis of the (batch, docs) buffers
        tar = buf / (buf + bufSum.unsqueeze(1))
        ins = torch.log(bufHat / (bufHat + bufHatSum.unsqueeze(1)))
        loss["Validate.KLD"] += F.kl_div(ins, tar, reduction="batchmean")
        return loss

    def validateStep(self, qry: Tensor, docs: Tensor) -> Dict[str, Tensor]:
        qry = qry.to(self.model.device_ids[0])
        docs = docs.to(self.model.device_ids[0])
        activate = self.hyperParams.activate
        _, qryHat = self.model.forward(qry, activate)
        _, docsHat = self.model.forward(docs.view(-1, docs.size(-1)), activate)
        assert isinstance(qryHat, Tensor) and isinstance(docsHat, Tensor)
        loss = self.validateLoss(qry, docs, qryHat, docsHat.view(docs.size()))
        return loss

    @torch.inference_mode()
    def validateIter(self, i: int) -> Dict[str, float]:
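        # Validation epoch over the "Validate" mix: runs under
        # torch.inference_mode() with the model in eval mode, mirroring
        # trainIter without optimizer or scheduler updates.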
        iterLoss = dict()
        relevant = self.hyperParams.relevant
        batchSize = self.trainParams.batchSize
        numEpochs = self.trainParams.numEpochs
        self.model.eval()
        with Progress(console=console) as progress:
            T = progress.add_task(
                f"[{i:>3}/{numEpochs}]", total=self.dataset.getMixLen("Validate")
            )
            for qry, docs in self.dataset.mixEmbIter(
                Embedding, "Validate", relevant, batchSize, 8, True
            ):
                loss = self.validateStep(qry, docs)
                progress.advance(T, qry.size(0))
                for key, val in loss.items():
                    iterLoss[key] = iterLoss.get(key, 0.0) + val.item()
            progress.stop_task(T)
        numBatches = self.dataset.getMixLen("Validate") // batchSize
        for key in iterLoss.keys():
            iterLoss[key] /= numBatches
        return iterLoss

    def run(self):
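        # Full training loop: after each epoch, snapshot the model whenever the
        # validation loss sets a new minimum, keep only the three most recent
        # snapshots, and log losses plus the current learning rate.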
        minLoss = float("inf")
        numEpochs = self.trainParams.numEpochs
        for i in range(1, numEpochs + 1):
            trainLoss = self.trainIter(i)
            validateLoss = self.validateIter(i)
            if sum(validateLoss.values()) <= minLoss:
                minLoss, state = sum(validateLoss.values()), dict()
                state["model"] = self.model.module.state_dict()
                state["optimizer"] = self.optimizer.state_dict()
                state["scheduler"] = self.scheduler.state_dict()
                state["scaler"] = self.scaler.state_dict()
                state["epoch"], state["minLoss"] = i, minLoss
                torch.save(state, Path(self.workDir, f"{i:03}.pth"))
                globs = sorted(self.workDir.glob("*.pth"), reverse=True)
                while len(globs) > 3:
                    globs.pop().unlink()
            health = dict()
            health.update(trainLoss)
            health.update(validateLoss)
            health["LR"] = self.optimizer.param_groups[0]["lr"]
            wandb.log(health)
            for key, val in health.items():
                console.log(f"{key:>12}={val:.7f}")
        wandb.finish()


if __name__ == "__main__":
    T = Trainer()
    T.run()