Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
fix flops counter bug in auto_pruners_torch.py (#3265)
Browse files Browse the repository at this point in the history
  • Loading branch information
linbinskn authored Jan 6, 2021
1 parent 99bc459 commit decb78e
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions examples/model_compress/pruning/auto_pruners_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def get_trained_model_optimizer(args, device, train_loader, val_loader, criterio

if args.save_model:
torch.save(state_dict, os.path.join(args.experiment_data_dir, 'model_trained.pth'))
print('Model trained saved to %s', args.experiment_data_dir)
print('Model trained saved to %s' % args.experiment_data_dir)

return model, optimizer

Expand Down Expand Up @@ -312,7 +312,7 @@ def evaluator(model):
if args.save_model:
pruner.export_model(
os.path.join(args.experiment_data_dir, 'model_masked.pth'), os.path.join(args.experiment_data_dir, 'mask.pth'))
print('Masked model saved to %s', args.experiment_data_dir)
print('Masked model saved to %s' % args.experiment_data_dir)

# model speed up
if args.speed_up:
Expand All @@ -336,7 +336,7 @@ def evaluator(model):
result['performance']['speedup'] = evaluation_result

torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
print('Speed up model saved to %s', args.experiment_data_dir)
print('Speed up model saved to %s' % args.experiment_data_dir)
flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
result['flops']['speedup'] = flops
result['params']['speedup'] = params
Expand Down Expand Up @@ -367,7 +367,7 @@ def evaluator(model):
torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_fine_tuned.pth'))

print('Evaluation result (fine tuned): %s' % best_acc)
print('Fined tuned model saved to %s', args.experiment_data_dir)
print('Fined tuned model saved to %s' % args.experiment_data_dir)
result['performance']['finetuned'] = best_acc

with open(os.path.join(args.experiment_data_dir, 'result.json'), 'w+') as f:
Expand Down

0 comments on commit decb78e

Please sign in to comment.