-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
100 lines (85 loc) · 3.17 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import torch
from tqdm import tqdm
from dataset import get_datasets
from BLIP_models.blip import blip_decoder
from config import Config
def _evaluate_split(model, dataloader, device, split, out_file):
    """Run one evaluation pass and return the NaN-ignoring mean loss.

    For every batch this generates a prediction with ``model.generate``,
    appends the first ground-truth / predicted pair of the batch to
    ``out_file`` and records the value returned by ``model(image, prompt_raw)``
    (used as the loss, exactly as in the training loop).

    Args:
        model: Model exposing ``generate(...)`` and a forward call that
            returns a scalar loss tensor.
        dataloader: Iterable yielding ``(image, prompt_raw)`` batches.
        device: Device string the image batch is moved to.
        split: Label written in front of each record (e.g. ``"val"``).
        out_file: Open, writable text file receiving the predictions.

    Returns:
        ``np.nanmean`` of the per-batch losses.

    The caller is expected to have put the model in eval mode and to have
    disabled gradient tracking.
    """
    losses = []
    for image, prompt_raw in tqdm(dataloader, ncols=60):
        image = image.to(device)
        # Human-readable prompt, logged for qualitative inspection.
        pred_prompt = model.generate(
            image,
            sample=True,
            min_length=1,
            max_length=Config.max_tokenizer_length,
        )
        out_file.write(
            f"({split}) TRUE: {prompt_raw[0]}\n"
            f"({split}) PRED: {pred_prompt[0]}\n\n"
        )
        losses.append(model(image, prompt_raw).item())
    return np.nanmean(losses)


def main():
    """Fine-tune the BLIP decoder and evaluate it after every epoch.

    Each of the 50 epochs trains on the training dataloader, then writes
    validation and test predictions to ``predictions/epoch_<n>.txt``, prints
    the three mean losses and saves the weights to
    ``checkpoints/epoch_<n>.pth``.
    """
    import os

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Mixed precision is only meaningful on CUDA; gate it explicitly instead
    # of relying on PyTorch to warn and disable it on CPU-only machines.
    use_amp = device == "cuda"

    # Create the output directories up front so a missing folder cannot
    # abort the run after a full epoch of training has already been spent.
    os.makedirs("predictions", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    train_dataloader, val_dataloader, test_dataloader = get_datasets(
        train_files=Config.train_datafiles,
        val_files=Config.val_datafiles,
        test_files=Config.test_datafiles,
        val_limit=250,
    )

    model = blip_decoder(
        pretrained=Config.pretrained,
        image_size=Config.image_size,
        vit="base",
        vit_grad_ckpt=False,
        vit_ckpt_layer=0,
        prompt="",
        max_tokenizer_length=Config.max_tokenizer_length,
    ).to(device)

    # Freeze some of the vision encoder layers: walk the parameters in
    # iteration order and stop at the first one whose name contains
    # "blocks.11"; that parameter and all later ones stay trainable.
    # NOTE(review): nesting reconstructed from unindented source - confirm
    # that the freeze statement belongs to the loop body after the break.
    for name, param in model.visual_encoder.named_parameters():
        if "blocks.11" in name:
            break
        param.requires_grad = False

    optimizer = torch.optim.AdamW(
        params=model.parameters(),
        lr=1e-05,
        weight_decay=0.05,
    )
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for epoch in range(50):
        # ---- training ----
        losses = []
        model.train()
        for image, prompt_raw in tqdm(train_dataloader, ncols=60):
            optimizer.zero_grad()
            image = image.to(device)
            with torch.autocast(
                device_type="cuda", dtype=torch.float16, enabled=use_amp
            ):
                loss = model(image, prompt_raw)
            # Backward and the optimizer step run outside the autocast
            # region; the scaler guards against fp16 gradient underflow.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            losses.append(loss.item())
        train_loss = np.nanmean(losses)

        # ---- evaluation ----
        model.eval()
        with torch.no_grad(), open(
            f"predictions/epoch_{epoch}.txt", "w", encoding="utf-8"
        ) as f:
            val_loss = _evaluate_split(model, val_dataloader, device, "val", f)
            test_loss = _evaluate_split(
                model, test_dataloader, device, "test", f
            )

        print(
            f"train: {round(train_loss, 4)}\nval: {round(val_loss, 4)}\ntest: {round(test_loss, 4)}"
        )
        torch.save(model.state_dict(), f"checkpoints/epoch_{epoch}.pth")


if __name__ == "__main__":
    main()