Skip to content

Commit

Permalink
Test with best checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
oguzhanbsolak committed Oct 11, 2024
1 parent 5f467b7 commit 6bf238c
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,9 @@ def flush(self):

# Finally run results on the test set
if not args.dr:
test(test_loader, model, criterion, [pylogger], args=args)
test(test_loader, model, criterion, [pylogger], args=args, mode="ckpt")
test(test_loader, model, criterion, [pylogger], args=args, mode="best",
ckpt_name=checkpoint_name)

if args.copy_output_folder and local_rank <= 0:
msglogger.info('Copying output folder to: %s', args.copy_output_folder)
Expand Down Expand Up @@ -1080,11 +1082,20 @@ def validate(val_loader, model, criterion, loggers, args, epoch=-1, tflogger=Non
return _validate(val_loader, model, criterion, loggers, args, epoch, tflogger)


def test(test_loader, model, criterion, loggers, args):
def test(test_loader, model, criterion, loggers, args, mode='ckpt', ckpt_name=None):
"""Model Test"""
assert msglogger is not None
msglogger.info('--- test ---------------------')
top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args)
if mode == 'ckpt':
msglogger.info('--- test (ckpt) ---------------------')
top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args)
else:
msglogger.info('--- test (best) ---------------------')
if ckpt_name is None:
best_ckpt_path = os.path.join(msglogger.logdir, 'best.pth.tar')
else:
best_ckpt_path = os.path.join(msglogger.logdir, ckpt_name + "_best.pth.tar")
model = apputils.load_lean_checkpoint(model, best_ckpt_path)
top1, top5, vloss, mAP = _validate(test_loader, model, criterion, loggers, args)

return top1, top5, vloss, mAP

Expand Down

0 comments on commit 6bf238c

Please sign in to comment.