
Commit 46f5915

committed
perform a grid search over more choices of activation
1 parent 890c0a6 commit 46f5915
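The message refers to sweeping the number of active (top-k) latents. As a purely hypothetical sketch, using the HyperParams container defined in the trainer file below (the committed file shown here pins activate=32), such a grid could be enumerated like this; the grid points are illustrative, not taken from the repository:

# Hypothetical sketch only: the commit trains fixed configurations per file.
for activate in (16, 32, 64):
    params = HyperParams(features=768, expandBy=256, activate=activate, relevant=8)
    print("candidate configuration:", params)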

File tree

8 files changed, +1110 -0 lines changed


.gitignore

+1
@@ -1 +1,2 @@
+wandb
 **/__pycache__

source/model/kld_x256_k32.py

+243
@@ -0,0 +1,243 @@
import torch
import wandb
import torch.amp as amp
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Tuple, Dict
from attrs import define
from pathlib import Path
from source import console
from rich.progress import Progress
from source.model import workspace
from source.interface import SAE
from source.dataset.msMarco import MsMarcoDataset
from source.embedding.bgeBase import BgeBaseEmbedding as Embedding


class Model(nn.Module, SAE):

    def __init__(self, features: int, expandBy: int) -> None:
        super().__init__()
        self.encoder = nn.Linear(features, features * expandBy)
        self.decoder = nn.Linear(features * expandBy, features)

    def forwardEncoder(self, x: Tensor, activate: int) -> Tensor:
        xbar = x - self.decoder.bias
        a = self.encoder.forward(xbar)
        pack = torch.topk(a, activate)
        f = torch.zeros_like(a)
        f.scatter_(1, pack.indices, F.relu(pack.values))
        return f

    def forwardDecoder(self, f: Tensor) -> Tensor:
        xhat = self.decoder.forward(f)
        return xhat

    def forward(self, x: Tensor, activate: int) -> Tuple[Tensor, Tensor]:
        f = self.forwardEncoder(x, activate)
        xhat = self.forwardDecoder(f)
        return f, xhat
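A minimal usage sketch (not part of the commit), assuming the project's SAE interface is importable and using the Trainer defaults below (features=768, expandBy=256, activate=32); the random inputs stand in for BGE-base embeddings:

import torch

model = Model(features=768, expandBy=256)
x = torch.randn(4, 768)                   # placeholder batch of embedding vectors
f, xhat = model.forward(x, activate=32)   # f: sparse codes, xhat: reconstruction
assert (f != 0).sum(dim=1).le(32).all()   # top-k keeps at most 32 active latents per row
assert xhat.shape == x.shape              # reconstruction lives back in embedding space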

@define
class HyperParams:
    features: int
    expandBy: int
    activate: int
    relevant: int


@define
class TrainParams:
    batchSize: int
    numEpochs: int
    learnRate: float

class Trainer:

    hyperParams = HyperParams(features=768, expandBy=256, activate=32, relevant=8)
    trainParams = TrainParams(batchSize=512, numEpochs=128, learnRate=1e-3)

    def __init__(self) -> None:
        self.name = Path(__file__).stem
        self.workDir = Path(workspace, self.name, "snapshot")
        self.workDir.parent.mkdir(mode=0o770, exist_ok=True)
        self.workDir.mkdir(mode=0o770, exist_ok=True)
        self.dataset = MsMarcoDataset()
        expandBy = self.hyperParams.expandBy
        self.model = nn.DataParallel(Model(Embedding.size, expandBy).cuda())
        learnRate = self.trainParams.learnRate
        self.optimizer = optim.Adam(self.model.parameters(), lr=learnRate)
        self.scaler = amp.GradScaler()
        numEpochs = self.trainParams.numEpochs
        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=numEpochs)
        wandb.init(project="interpret", entity="haok", name=self.name)
        wandb.save("source/**/*.py", policy="now")

    def trainLoss(
        self, qry: Tensor, docs: Tensor, qryHat: Tensor, docsHat: Tensor
    ) -> Dict[str, Tensor]:
        loss = dict()
        loss["Train.MSE"] = torch.tensor(0.0, requires_grad=True)
        loss["Train.MSE"] = loss["Train.MSE"] + F.mse_loss(qryHat, qry)
        loss["Train.MSE"] = loss["Train.MSE"] + F.mse_loss(docsHat, docs)
        loss["Train.KLD"] = torch.tensor(0.0, requires_grad=True)
        buf = torch.exp(
            torch.matmul(
                qry.unsqueeze(1),
                docs.transpose(1, 2),
            ).squeeze(1)
        )
        bufSum = buf.sum(dim=1)
        bufHat = torch.exp(
            torch.matmul(
                qryHat.unsqueeze(1),
                docsHat.transpose(1, 2),
            ).squeeze(1)
        )
        bufHatSum = bufHat.sum(dim=1)
        for i in range(qry.size(0)):
            for j in range(docs.size(1)):
                tar = buf[i, j] / (buf[i, j] + bufSum[i])
                ins = torch.log(bufHat[i, j] / (bufHat[i, j] + bufHatSum[i]))
                loss["Train.KLD"] = loss["Train.KLD"] + F.kl_div(
                    ins, tar, reduction="batchmean"
                )
        return loss
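For each query i and candidate document j, the loop above forms the target probability exp(q_i·d_ij) / (exp(q_i·d_ij) + Σ_k exp(q_i·d_ik)) from the original embeddings and matches the same quantity computed from the reconstructions. A minimal sketch of an equivalent vectorized form, essentially what validateLoss below does; the unsqueeze(1) is added here so the per-query sums broadcast over the document axis:

tar = buf / (buf + bufSum.unsqueeze(1))                      # target probabilities from originals
ins = torch.log(bufHat / (bufHat + bufHatSum.unsqueeze(1)))  # log-probabilities from reconstructions
kld = F.kl_div(ins, tar, reduction="batchmean")              # one call instead of a Python double loop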

    def trainStep(self, qry: Tensor, docs: Tensor) -> Dict[str, Tensor]:
        self.optimizer.zero_grad()
        qry = qry.to(self.model.device_ids[0])
        docs = docs.to(self.model.device_ids[0])
        activate = self.hyperParams.activate
        with amp.autocast("cuda"):
            _, qryHat = self.model.forward(qry, activate)
            _, docsHat = self.model.forward(docs.view(-1, docs.size(-1)), activate)
            assert isinstance(qryHat, Tensor) and isinstance(docsHat, Tensor)
            loss = self.trainLoss(qry, docs, qryHat, docsHat.view(docs.size()))
        self.scaler.scale(sum(loss.values())).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        return loss

    def trainIter(self, i: int) -> Dict[str, float]:
        iterLoss = dict()
        relevant = self.hyperParams.relevant
        batchSize = self.trainParams.batchSize
        numEpochs = self.trainParams.numEpochs
        self.model.train()
        with Progress(console=console) as progress:
            T = progress.add_task(
                f"[{i:>3}/{numEpochs}]", total=self.dataset.getMixLen("Train")
            )
            for qry, docs in self.dataset.mixEmbIter(
                Embedding, "Train", relevant, batchSize, 8, True
            ):
                loss = self.trainStep(qry, docs)
                progress.advance(T, qry.size(0))
                for key, val in loss.items():
                    iterLoss[key] = iterLoss.get(key, 0.0) + val.item()
            progress.stop_task(T)
        numBatches = self.dataset.getMixLen("Train") // batchSize
        for key in iterLoss.keys():
            iterLoss[key] /= numBatches
        self.scheduler.step()
        return iterLoss

    def validateLoss(
        self, qry: Tensor, docs: Tensor, qryHat: Tensor, docsHat: Tensor
    ) -> Dict[str, Tensor]:
        loss = dict()
        loss["Validate.MSE"] = torch.tensor(0.0)
        loss["Validate.MSE"] = loss["Validate.MSE"] + F.mse_loss(qryHat, qry)
        loss["Validate.MSE"] = loss["Validate.MSE"] + F.mse_loss(docsHat, docs)
        loss["Validate.KLD"] = torch.tensor(0.0)
        buf = torch.exp(
            torch.matmul(
                qry.unsqueeze(1),
                docs.transpose(1, 2),
            ).squeeze(1)
        )
        bufSum = buf.sum(dim=1)
        bufHat = torch.exp(
            torch.matmul(
                qryHat.unsqueeze(1),
                docsHat.transpose(1, 2),
            ).squeeze(1)
        )
        bufHatSum = bufHat.sum(dim=1)
        # compute KLD in a vectorized manner
        tar = buf / (buf + bufSum)
        ins = torch.log(bufHat / (bufHat + bufHatSum))
        loss["Validate.KLD"] += F.kl_div(ins, tar, reduction="batchmean")
        return loss

    def validateStep(self, qry: Tensor, docs: Tensor) -> Dict[str, Tensor]:
        qry = qry.to(self.model.device_ids[0])
        docs = docs.to(self.model.device_ids[0])
        activate = self.hyperParams.activate
        _, qryHat = self.model.forward(qry, activate)
        _, docsHat = self.model.forward(docs.view(-1, docs.size(-1)), activate)
        assert isinstance(qryHat, Tensor) and isinstance(docsHat, Tensor)
        loss = self.validateLoss(qry, docs, qryHat, docsHat.view(docs.size()))
        return loss

    @torch.inference_mode()
    def validateIter(self, i: int) -> Dict[str, float]:
        iterLoss = dict()
        relevant = self.hyperParams.relevant
        batchSize = self.trainParams.batchSize
        numEpochs = self.trainParams.numEpochs
        self.model.eval()
        with Progress(console=console) as progress:
            T = progress.add_task(
                f"[{i:>3}/{numEpochs}]", total=self.dataset.getMixLen("Validate")
            )
            for qry, docs in self.dataset.mixEmbIter(
                Embedding, "Validate", relevant, batchSize, 8, True
            ):
                loss = self.validateStep(qry, docs)
                progress.advance(T, qry.size(0))
                for key, val in loss.items():
                    iterLoss[key] = iterLoss.get(key, 0.0) + val.item()
            progress.stop_task(T)
        numBatches = self.dataset.getMixLen("Validate") // batchSize
        for key in iterLoss.keys():
            iterLoss[key] /= numBatches
        self.scheduler.step()
        return iterLoss

    def run(self):
        minLoss = float("inf")
        numEpochs = self.trainParams.numEpochs
        for i in range(1, numEpochs + 1):
            trainLoss = self.trainIter(i)
            validateLoss = self.validateIter(i)
            if sum(validateLoss.values()) <= minLoss:
                minLoss, state = sum(trainLoss.values()), dict()
                state["model"] = self.model.module.state_dict()
                state["optimizer"] = self.optimizer.state_dict()
                state["scheduler"] = self.scheduler.state_dict()
                state["scaler"] = self.scaler.state_dict()
                state["epoch"], state["minLoss"] = i, minLoss
                torch.save(state, Path(self.workDir, f"{i:03}.pth"))
                globs = sorted(self.workDir.glob("*.pth"), reverse=True)
                while len(globs) > 3:
                    globs.pop().unlink()
            health = dict()
            health.update(trainLoss)
            health.update(validateLoss)
            health["LR"] = self.optimizer.param_groups[0]["lr"]
            wandb.log(health)
            for key, val in health.items():
                console.log(f"{key:>12}={val:.7f}")
        wandb.finish()


if __name__ == "__main__":
    T = Trainer()
    T.run()
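A minimal sketch (not part of the commit) of restoring one of the snapshots run() writes; the filename below is a placeholder for any of the up-to-three .pth files kept in the snapshot directory:

import torch

state = torch.load("snapshot/042.pth", map_location="cuda")   # placeholder filename
model = Model(features=768, expandBy=256).cuda()
model.load_state_dict(state["model"])      # weights were saved from model.module, so keys match Model
print(state["epoch"], state["minLoss"])    # bookkeeping stored alongside the weights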
