-
Notifications
You must be signed in to change notification settings - Fork 416
/
train_single_gpu_flash_atten.py
83 lines (69 loc) · 2.48 KB
/
train_single_gpu_flash_atten.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import ssl

import torch
from torchvision import transforms
from torchvision.datasets import CIFAR100

# SECURITY NOTE(review): this monkeypatch turns off HTTPS certificate
# verification process-wide for standard-library clients (http.client /
# urllib) that do not supply their own SSL context -- presumably including the
# torchvision CIFAR-100 download, and possibly other downloads made during the
# run. Data fetched that way can be tampered with in transit. Prefer fixing
# the local CA certificate bundle and deleting this line.
ssl._create_default_https_context = ssl._create_unverified_context

from flagai.trainer import Trainer
from flagai.auto_model.auto_loader import AutoLoader

lr = 2e-5
n_epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env_type = "pytorch"

# Single Trainer shared by the whole script: collate_fn reads ``trainer.fp16``
# and the __main__ block calls ``trainer.train``.
trainer = Trainer(
    env_type=env_type,
    experiment_name="vit-cifar100-single_gpu",
    batch_size=150,
    num_gpus=1,
    gradient_accumulation_steps=1,
    lr=lr,
    weight_decay=1e-5,
    epochs=n_epochs,
    log_interval=10,
    eval_interval=1000,
    load_dir=None,
    pytorch_device=device,
    save_dir="checkpoints_vit_cifar100_single_gpu",
    save_interval=1000,
    num_checkpoints=1,
    # Previously fp16 was forced on even when ``device`` fell back to CPU,
    # where PyTorch's float16 kernel coverage has historically been
    # incomplete/slow. Enable half precision only on CUDA; GPU runs behave
    # exactly as before, and collate_fn follows this flag automatically.
    fp16=device.type == "cuda",
)
def build_cifar(root="/home/ldwang/Downloads/data", download=True):
    """Build the CIFAR-100 train/test datasets used by this script.

    Both splits are resized from 32x32 to 224x224 (presumably to match the
    224-pixel input of ``vit-base-p16-224``) and normalized. The training
    split is also augmented with a padded random crop and the torchvision
    AutoAugment CIFAR-10 policy.

    Args:
        root: Directory that holds, or will receive, the CIFAR-100 data. The
            default is the original author's local path, kept so that
            existing ``build_cifar()`` calls behave exactly as before. Pass
            your own path on other machines.
        download: Forwarded to ``CIFAR100``. When true, the archive is
            fetched into ``root`` if it is not already present.

    Returns:
        tuple: ``(train_dataset, test_dataset)``.
    """
    # NOTE(review): these mean/std values are the ones commonly quoted for
    # CIFAR-10, not CIFAR-100. Confirm the intended normalization, e.g. the
    # CIFAR-100 statistics or whatever the pretrained ViT expects.
    # Normalize is stateless, so both pipelines can share one instance.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(224),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = CIFAR100(root=root, train=True, download=download, transform=transform_train)
    test_dataset = CIFAR100(root=root, train=False, download=download, transform=transform_test)
    return train_dataset, test_dataset
def collate_fn(batch):
    """Turn a list of ``(image, label)`` samples into the batch dict passed to the trainer.

    Images are stacked along a new leading dimension. They are cast to half
    precision when the module-level ``trainer`` has ``fp16`` set, presumably
    to match fp16 training (confirm against flagai's Trainer). Labels are
    packed into an int64 tensor.

    Returns:
        dict: ``{"images": <stacked image tensor>, "labels": <int64 tensor>}``.
    """
    stacked = torch.stack([sample[0] for sample in batch])
    images = stacked.half() if trainer.fp16 else stacked
    targets = torch.tensor([sample[1] for sample in batch]).long()
    return {"images": images, "labels": targets}
def validate(logits, labels, meta=None):
    """Compute top-1 accuracy for a batch of predictions.

    Args:
        logits: Score tensor. The predicted class of each row is the index
            of its maximum along dim 1 (assumed shape (N, num_classes);
            confirm against the trainer).
        labels: Ground-truth class indices, one per row of ``logits``.
        meta: Ignored. Presumably present to fit the trainer's metric
            callback signature.

    Returns:
        float: Number of correct predictions divided by ``labels.size(0)``.
    """
    predictions = logits.max(dim=1).indices
    hits = (predictions == labels).sum().item()
    return hits / labels.size(0)
if __name__ == '__main__':
    # Load the "vit-base-p16-224" classification model through flagai's
    # AutoLoader, configured with num_classes=100 for CIFAR-100.
    model = AutoLoader(
        task_name="classification",
        model_name="vit-base-p16-224",
        num_classes=100,
    ).get_model()

    train_dataset, val_dataset = build_cifar()

    # optimizer=None presumably lets the Trainer build its own optimizer from
    # the lr / weight_decay it was configured with (confirm in flagai).
    # validate is reported under the "accuracy" metric name.
    trainer.train(
        model,
        optimizer=None,
        train_dataset=train_dataset,
        valid_dataset=val_dataset,
        metric_methods=[["accuracy", validate]],
        collate_fn=collate_fn,
    )