From 6bf238c84500bc80968fcf9056a733bc36b9d31f Mon Sep 17 00:00:00 2001 From: Oguzhan Buyuksolak Date: Fri, 11 Oct 2024 20:39:24 +0300 Subject: [PATCH] Test with best checkpoint --- train.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index d48a8ba41..04fb6435b 100755 --- a/train.py +++ b/train.py @@ -738,7 +738,9 @@ def flush(self): # Finally run results on the test set if not args.dr: - test(test_loader, model, criterion, [pylogger], args=args) + test(test_loader, model, criterion, [pylogger], args=args, mode="ckpt") + test(test_loader, model, criterion, [pylogger], args=args, mode="best", + ckpt_name=checkpoint_name) if args.copy_output_folder and local_rank <= 0: msglogger.info('Copying output folder to: %s', args.copy_output_folder) @@ -1080,11 +1082,20 @@ def validate(val_loader, model, criterion, loggers, args, epoch=-1, tflogger=Non return _validate(val_loader, model, criterion, loggers, args, epoch, tflogger) -def test(test_loader, model, criterion, loggers, args): +def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=None): """Model Test""" assert msglogger is not None - msglogger.info('--- test ---------------------') - top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args) + if mode == 'ckpt': + msglogger.info('--- test (ckpt) ---------------------') + top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args) + else: + msglogger.info('--- test (best) ---------------------') + if ckpt_name is None: + best_ckpt_path = os.path.join(msglogger.logdir, 'best.pth.tar') + else: + best_ckpt_path = os.path.join(msglogger.logdir, ckpt_name + "_best.pth.tar") + model = apputils.load_lean_checkpoint(model, best_ckpt_path) + top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args) return top1, top5, vloss, mAP