From 0e746ebd419cb4a7505e1699003f3f14935b6cc2 Mon Sep 17 00:00:00 2001
From: Akshay Kulkarni
Date: Sun, 31 May 2020 12:24:37 +0530
Subject: [PATCH 1/2] added initial atkd files (part 1)

---
 .../experiments/attention_transfer_kd.py     | 76 +++++++++++++++++++
 image_classification/experiments/trainer.py  | 24 +++---
 2 files changed, 91 insertions(+), 9 deletions(-)
 create mode 100644 image_classification/experiments/attention_transfer_kd.py

diff --git a/image_classification/experiments/attention_transfer_kd.py b/image_classification/experiments/attention_transfer_kd.py
new file mode 100644
index 0000000..f6187da
--- /dev/null
+++ b/image_classification/experiments/attention_transfer_kd.py
@@ -0,0 +1,76 @@
+import warnings
+warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")
+from comet_ml import Experiment
+from fastai.vision import *
+import torch
+import argparse
+import os
+from image_classification.arguments import get_args
+from image_classification.datasets.dataset import get_dataset
+from image_classification.utils.utils import *
+from image_classification.models.custom_resnet import *
+from trainer import *
+
+
+args = get_args(description='Attention Transfer KD', mode='train')
+expt = 'attention-kd'
+
+torch.manual_seed(args.seed)
+if args.gpu != 'cpu':
+    args.gpu = int(args.gpu)
+    torch.cuda.set_device(args.gpu)
+    torch.cuda.manual_seed(args.seed)
+
+hyper_params = {
+    "dataset": args.dataset,
+    "model": args.model,
+    "stage": 0,
+    "num_classes": 10,
+    "batch_size": 64,
+    "num_epochs": args.epoch,
+    "learning_rate": 1e-4,
+    "seed": args.seed,
+    "percentage": args.percentage,
+    "gpu": args.gpu,
+}
+
+data = get_dataset(dataset=hyper_params['dataset'],
+                   batch_size=hyper_params['batch_size'],
+                   percentage=args.percentage)
+
+learn, net = get_model(hyper_params['model'], hyper_params['dataset'], data, teach=True)
+learn.model, net = learn.model.to(args.gpu), net.to(args.gpu)
+
+teacher = learn.model
+
+sf_student, sf_teacher = get_features(net, teacher, experiment=expt)
+
+project_name = expt + '-' + hyper_params['model'] + '-' + hyper_params['dataset']
+experiment = Experiment(api_key="1jNZ1sunRoAoI2TyremCNnYLO", project_name = project_name, workspace="akshaykvnit")
+experiment.log_parameters(hyper_params)
+
+optimizer = torch.optim.Adam(net.parameters(), lr = hyper_params["learning_rate"])
+loss_function2 = nn.MSELoss()
+loss_function = nn.CrossEntropyLoss()
+savename = get_savename(hyper_params, experiment=expt)
+best_val_acc = 0
+
+for epoch in range(hyper_params['num_epochs']):
+    student, train_loss, val_loss, val_acc, best_val_acc = train(
+        net,
+        teacher,
+        data,
+        sf_teacher,
+        sf_student,
+        loss_function,
+        loss_function2,
+        optimizer,
+        hyper_params,
+        epoch,
+        savename,
+        best_val_acc,
+        expt=expt
+    )
+    experiment.log_metric("train_loss", train_loss)
+    experiment.log_metric("val_loss", val_loss)
+    experiment.log_metric("val_acc", val_acc * 100)
diff --git a/image_classification/experiments/trainer.py b/image_classification/experiments/trainer.py
index ffcceb6..1394202 100644
--- a/image_classification/experiments/trainer.py
+++ b/image_classification/experiments/trainer.py
@@ -5,7 +5,7 @@ from image_classification.utils.utils import *
 
 
 
-def train(student, teacher, data, sf_teacher, sf_student, loss_function, loss_function2, optimizer, hyper_params, epoch, savename, best_val_acc):
+def train(student, teacher, data, sf_teacher, sf_student, loss_function, loss_function2, optimizer, hyper_params, epoch, savename, best_val_acc, expt=None):
     loop = tqdm(data.train_dl)
     max_val_acc = best_val_acc
     gpu = hyper_params['gpu']
@@ -33,6 +33,12 @@ def train(student, teacher, data, sf_teacher, sf_student, loss_function, loss_fu
         # stage training (and assuming sf_teacher and sf_student are given)
         elif loss_function2 is None:
             loss = loss_function(sf_student[hyper_params['stage']].features, sf_teacher[hyper_params['stage']].features)
+        # attention transfer KD
+        elif expt == 'attention-kd':
+            loss = loss_function(y_pred, labels)
+            for k in range(4):
+                loss += loss_function2(at(sf_student[k].features), at(sf_teacher[k].features))
+            loss /= 5
         # 2 loss functions and student and teacher are given -> simultaneous training
         else:
             loss = loss_function(y_pred, labels)
@@ -64,41 +70,41 @@ def train(student, teacher, data, sf_teacher, sf_student, loss_function, loss_fu
         else:
             images = torch.autograd.Variable(images).float()
             labels = torch.autograd.Variable(labels)
-        
+
         y_pred = student(images)
         if teacher is not None:
             _ = teacher(images)
-        
+
         # classifier training
         if teacher is None:
             loss = loss_function(y_pred, labels)
             y_pred = F.log_softmax(y_pred, dim = 1)
             _, pred_ind = torch.max(y_pred, 1)
-            
+
             total += labels.size(0)
             correct += (pred_ind == labels).sum().item()
 
         # stage training
         elif loss_function2 is None:
             loss = loss_function(sf_student[hyper_params['stage']].features, sf_teacher[hyper_params['stage']].features)
-        # simultaneous training
+        # simultaneous training or attention KD
        else:
             loss = loss_function(y_pred, labels)
             y_pred = F.log_softmax(y_pred, dim = 1)
             _, pred_ind = torch.max(y_pred, 1)
-            
+
             total += labels.size(0)
             correct += (pred_ind == labels).sum().item()
 
         val.append(loss.item())
 
-    
+
     val_loss = (sum(val) / len(val))
     if total > 0:
         val_acc = correct / total
     else:
         val_acc = None
-    
+
     # classifier training
     if teacher is None:
         if (val_acc * 100) > max_val_acc :
@@ -111,7 +117,7 @@ def train(student, teacher, data, sf_teacher, sf_student, loss_function, loss_fu
             print(f'lower valid loss obtained: {val_loss}')
             max_val_acc = val_loss
             torch.save(student.state_dict(), savename)
-    # simultaneous training
+    # simultaneous training or attention kd
     else:
         if (val_acc * 100) > max_val_acc :
             print(f'higher valid acc obtained: {val_acc * 100}')

From b54e8baa1fd321a6444167ed0558fb93c761971d Mon Sep 17 00:00:00 2001
From: Akshay Kulkarni
Date: Sun, 31 May 2020 12:25:26 +0530
Subject: [PATCH 2/2] added initial atkd files (part 2)

---
 image_classification/utils/utils.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/image_classification/utils/utils.py b/image_classification/utils/utils.py
index 53c7706..96740a5 100644
--- a/image_classification/utils/utils.py
+++ b/image_classification/utils/utils.py
@@ -22,6 +22,9 @@ def get_features(student, teacher, experiment):
     elif experiment == 'traditional-kd':
         sf_teacher = [SaveFeatures(m) for m in [teacher[0][5]]]
         sf_student = [SaveFeatures(m) for m in [student.layer2]]
+    elif experiment == 'attention-kd':
+        sf_teacher = [SaveFeatures(m) for m in [teacher[0][4], teacher[0][5], teacher[0][6], teacher[0][7]]]
+        sf_student = [SaveFeatures(m) for m in [student.layer1, student.layer2, student.layer3, student.layer4]]
 
     return sf_student, sf_teacher
 
@@ -58,7 +61,7 @@ def freeze_student(model, hyper_params, experiment):
 
 
 def get_savename(hyper_params, experiment):
-    assert experiment in ['stagewise-kd', 'traditional-kd', 'simultaneous-kd', 'no-teacher']
+    assert experiment in ['stagewise-kd', 'traditional-kd', 'simultaneous-kd', 'attention-kd', 'no-teacher']
     dsize = 'full_data' if hyper_params['percentage'] is None else f"less_data{str(hyper_params['percentage'])}"
 
 
@@ -122,3 +125,7 @@ def get_accuracy(dataloader, net):
         correct += (pred_ind == labels).sum().item()
 
     return (correct / total)
+
+
+def at(x):
+    return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))
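
Note for reviewers: the at helper added at the bottom of utils.py is the activation-based attention map from Zagoruyko & Komodakis, "Paying More Attention to Attention": it squares a feature map, averages over the channel dimension, flattens the spatial grid, and L2-normalises per sample. Because channels are averaged out, student and teacher maps stay comparable even when their channel counts differ, as long as the spatial sizes match. Below is a minimal, self-contained sketch of the loss computed in the new 'attention-kd' branch of trainer.py; the tensor shapes and the reuse of a single dummy stage pair are illustrative stand-ins, not the project's models or hooks.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def at(x):
        # square, average over channels, flatten, L2-normalise per sample
        return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))

    # dummy feature maps for one stage; channel counts differ, spatial sizes match
    student_feat = torch.randn(8, 64, 32, 32)   # (batch, channels, H, W)
    teacher_feat = torch.randn(8, 256, 32, 32)

    ce, mse = nn.CrossEntropyLoss(), nn.MSELoss()
    logits, labels = torch.randn(8, 10), torch.randint(0, 10, (8,))

    loss = ce(logits, labels)
    for _ in range(4):  # the patch loops over 4 student/teacher stage pairs
        loss = loss + mse(at(student_feat), at(teacher_feat))
    loss = loss / 5     # 1 cross-entropy term + 4 attention terms

The loss /= 5 in the patch thus gives the cross-entropy term and each attention term an equal weight of 0.2, a simple average rather than the separately tuned attention weight used in the original paper.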