Describe the bug
I am getting these weird graphs in TensorBoard. Training worked fine when I was calling model.cuda() manually, but the graphs broke when I switched to letting Lightning handle device placement with gpus = 1 and distributed_backend = None.
I have posted this graph below:
My Trainer and LightningModule code is as follows:
Trainer:
"""This file runs the main training/val loop, etc... using Lightning Trainer """frompytorch_lightningimportTrainerfromargparseimportArgumentParserfromresearch_seed.baselines.kd_baseline.kd_baselineimportKD_Cifarfrompytorch_lightning.loggingimportTestTubeLoggerdefmain(hparams):
# init modulemodel=KD_Cifar(hparams)
logger=TestTubeLogger(
save_dir=hparams.save_dir,
version=hparams.version# An existing version with a saved checkpoint
)
# most basic trainer, uses good defaultsifhparams.gpus>1:
dist='ddp'else:
dist=None# most basic trainer, uses good defaultstrainer=Trainer(
max_nb_epochs=hparams.epochs,
gpus=hparams.gpus,
nb_gpu_nodes=hparams.nodes,
early_stop_callback=None,
logger=logger,
default_save_path=hparams.save_dir,
distributed_backend=dist,
)
trainer.fit(model)
if__name__=='__main__':
parser=ArgumentParser(add_help=False)
parser.add_argument('--epochs', default=100, type=int, help='number of total epochs to run')
parser.add_argument('--gpus', type=int, default=1)
parser.add_argument('--nodes', type=int, default=1)
parser.add_argument('--save-dir', type=str, default='./lightning_logs')
parser.add_argument('--version', type=int, required=True, help="version number for experiment")
# give the module a chance to add own params# good practice to define LightningModule speficic params in the moduleparser=KD_Cifar.add_model_specific_args(parser)
# parse paramshparams=parser.parse_args()
main(hparams)
Lightning Module
"""This file defines the core research contribution """importosimporttorchfromtorch.nnimportfunctionalasFimporttorch.nnasnnimporttorchvisionfromtorch.utils.dataimportDataLoaderimporttorchvision.transformsastransformsfromargparseimportArgumentParserfromresearch_seed.baselines.model.model_factoryimportcreate_cnn_model, is_resnetimporttorch.optimasoptimimportpytorch_lightningasplimportnumpyasnpfromcollectionsimportOrderedDictdefstr2bool(v):
ifv.lower() in ('yes', 'true', 't', 'y', '1'):
returnTrueelse:
returnFalsedefload_model_chk(model, path):
chkp=torch.load(path)
new_state_dict=OrderedDict()
fork, vinchkp['state_dict'].items():
name=k[6:] # remove `model.`new_state_dict[name] =vmodel.load_state_dict(new_state_dict)
returnmodelclassKD_Cifar(pl.LightningModule):
def__init__(self, hparams):
super(KD_Cifar, self).__init__()
# not the best model...self.hparams=hparamsself.student=create_cnn_model(hparams.student_model, dataset=hparams.dataset)
self.teacher=create_cnn_model(hparams.teacher_model, dataset=hparams.dataset)
# Loading from checkpointself.teacher=load_model_chk(self.teacher, hparams.path_to_teacher)
self.teacher.eval()
self.student.train()
self.criterion=nn.CrossEntropyLoss()
self.train_step=0self.train_num_correct=0self.val_step=0self.val_num_correct=0defloss_fn_kd(self, outputs, labels, teacher_outputs):
""" Credits: https://github.com/peterliht/knowledge-distillation-pytorch/blob/e4c40132fed5a45e39a6ef7a77b15e5d389186f8/model/net.py#L100 Compute the knowledge-distillation (KD) loss given outputs, labels. "Hyperparameters": temperature and alpha NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher and student expects the input tensor to be log probabilities! See Issue #2 """alpha=self.hparams.alphaT=self.hparams.temperatureloss=nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
F.softmax(teacher_outputs/T, dim=1)) * (alpha*T*T) + \
F.cross_entropy(outputs, labels) * (1.-alpha)
returnlossdefforward(self, x, mode):
ifmode=='student':
returnself.student(x)
elifmode=='teacher':
returnself.teacher(x)
else:
raiseValueError("mode should be teacher or student")
deftraining_step(self, batch, batch_idx):
x, y=batchy_teacher=self.forward(x, 'teacher')
y_student=self.forward(x, 'student')
loss=self.loss_fn_kd(y_student, y, y_teacher)
pred=y_student.data.max(1, keepdim=True)[1]
self.train_step+=x.size(0)
self.train_num_correct+=pred.eq(y.data.view_as(pred)).cpu().sum()
return {
'loss': loss,
'log' : {
'train_loss' : loss.item(),
'train_accuracy': float(self.train_num_correct*100/self.train_step),
}
}
defvalidation_step(self, batch, batch_idx):
self.student.eval()
x, y=batchy_hat=self.forward(x, 'student')
val_loss=self.criterion(y_hat, y)
pred=y_hat.data.max(1, keepdim=True)[1]
self.val_step+=x.size(0)
self.val_num_correct+=pred.eq(y.data.view_as(pred)).cpu().sum()
return {
'val_loss': val_loss
}
defvalidation_end(self, outputs):
# OPTIONALavg_loss=torch.stack([x['val_loss'] forxinoutputs]).mean()
log_metrics= {
'val_avg_loss': avg_loss.item(),
'val_accuracy': float(self.val_num_correct*100/self.val_step)
}
self.scheduler.step(np.around(avg_loss.item(),2))
# reset logging stuffself.train_step=0self.train_num_correct=0self.val_step=0self.val_num_correct=0# back to trainingself.student.train()
return {'val_loss': avg_loss, 'log': log_metrics}
defconfigure_optimizers(self):
# REQUIRED# can return multiple optimizers and learning_rate schedulersifself.hparams.optim=='adam':
optimizer=torch.optim.Adam(self.student.parameters(), lr=self.hparams.learning_rate)
elifself.hparams.optim=='sgd':
optimizer=torch.optim.SGD(self.student.parameters(), nesterov=True, momentum=self.hparams.momentum,
weight_decay=self.hparams.weight_decay, lr=self.hparams.learning_rate)
else:
raiseValueError('No such optimizer, please use adam or sgd')
self.scheduler=optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min',patience=5,factor=0.5,verbose=True)
returnoptimizer@pl.data_loaderdeftrain_dataloader(self):
ifself.hparams.dataset=='cifar10'orself.hparams.dataset=='cifar100':
transform_train=transforms.Compose([
transforms.Pad(4, padding_mode="reflect"),
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
else:
raiseValueError('Dataset not supported !')
trainset=torchvision.datasets.CIFAR10(root=self.hparams.dataset_dir, train=True,
download=True, transform=transform_train)
ifself.hparams.gpus>1:
dist_sampler=torch.utils.data.distributed.DistributedSampler(trainset)
else:
dist_sampler=NonereturnDataLoader(trainset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, sampler=dist_sampler)
@pl.data_loaderdefval_dataloader(self):
ifself.hparams.dataset=='cifar10'orself.hparams.dataset=='cifar100':
transform_test=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
else:
raiseValueError('Dataset not supported !')
valset=torchvision.datasets.CIFAR10(root=self.hparams.dataset_dir, train=False,
download=True, transform=transform_test)
ifself.hparams.gpus>1:
dist_sampler=torch.utils.data.distributed.DistributedSampler(valset)
else:
dist_sampler=NonereturnDataLoader(valset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, sampler=dist_sampler)
@pl.data_loaderdeftest_dataloader(self):
ifself.hparams.dataset=='cifar10'orself.hparams.dataset=='cifar100':
transform_test=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
else:
raiseValueError('Dataset not supported !')
testset=torchvision.datasets.CIFAR10(root=self.hparams.dataset_dir, train=False,
download=True, transform=transform_test)
ifself.hparams.gpus>1:
dist_sampler=torch.utils.data.distributed.DistributedSampler(testset)
else:
dist_sampler=NonereturnDataLoader(testset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers, sampler=dist_sampler)
@staticmethoddefadd_model_specific_args(parent_parser):
""" Specify the hyperparams for this LightningModule """# MODEL specificparser=ArgumentParser(parents=[parent_parser])
parser.add_argument('--dataset', default='cifar10', type=str, help='dataset. can be either cifar10 or cifar100')
parser.add_argument('--batch-size', default=128, type=int, help='batch_size')
parser.add_argument('--learning-rate', default=0.001, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='SGD momentum')
parser.add_argument('--weight-decay', default=1e-4, type=float, help='SGD weight decay (default: 1e-4)')
parser.add_argument('--dataset-dir', default='./data', type=str, help='dataset directory')
parser.add_argument('--optim', default='adam', type=str, help='Optimizer')
parser.add_argument('--num-workers', default=4, type=float, help='Num workers for data loader')
parser.add_argument('--student-model', default='resnet8', type=str, help='teacher student name')
parser.add_argument('--teacher-model', default='resnet110', type=str, help='teacher student name')
parser.add_argument('--path-to-teacher', default='', type=str, help='teacher chkp path')
parser.add_argument('--temperature', default=10, type=float, help='Temperature for knowledge distillation')
parser.add_argument('--alpha', default=0.7, type=float, help='Alpha for knowledge distillation')
returnparserIwouldbeverygratefulifsomeonecantellmewhatImdoingwrong.
Cool setup - we also do complicated model loading/weight-transferring setups.
The first thing that comes to mind is that you may be resetting the step counter when loading the student, so when the student starts training the curves jump back in 'time'.
For your exact scenario, are you sure you're loading from the very latest checkpoint?
Are you doing any sort of global step counting agnostic of epochs? If you can pass that into the logger, you would always have consistency no matter what checkpoint you load from.
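For illustration, here is a minimal sketch of what such an epoch-agnostic step counter could look like. The class, attribute, and metric names are made up, and `log_metrics(metrics, step=...)` is the generic Lightning logger call; the exact signature may differ between Lightning versions.

```python
import torch
import pytorch_lightning as pl


class StepConsistentModule(pl.LightningModule):
    """Hypothetical sketch: log against a step counter that never resets."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 10)
        self.samples_seen = 0  # assumed counter; grows across epochs and restarts

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        self.samples_seen += x.size(0)
        # Pass the explicit step to the logger so the TensorBoard x-axis
        # stays monotonic no matter which checkpoint training resumed from.
        self.logger.log_metrics({'train_loss': loss.item()}, step=self.samples_seen)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)
```

If you also persist that counter in the checkpoint (for example as a registered buffer), the curves line up even after resuming from an older checkpoint.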
FWIW - try restarting the TensorBoard server. I have found that it frequently has a 'hangover' from the prior run that gives graph results reminiscent of yours. Restarting and clearing the log files always clears it up for me.
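If clearing the logs helps, a small hypothetical helper like the one below (not from this thread; the `version_<n>` directory layout is an assumption, adjust it to your logger's actual output) removes the stale event files for one run before you restart TensorBoard on the same logdir.

```python
import shutil
from pathlib import Path


def clear_stale_logs(save_dir: str, version: int) -> None:
    """Delete old event files for one experiment version (assumed layout)."""
    run_dir = Path(save_dir) / f"version_{version}"
    if run_dir.exists():
        shutil.rmtree(run_dir)  # TensorBoard then starts from a clean slate


# Example: wipe version 0 under the save directory used above, then restart
# TensorBoard pointed at the same directory.
clear_stale_logs("./lightning_logs", version=0)
```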