From 8e02a6329c824454745f91899b5e7e1d549a2ac9 Mon Sep 17 00:00:00 2001 From: Tau <674106399@qq.com> Date: Fri, 10 Mar 2023 17:09:32 +0800 Subject: [PATCH] [Feature] Remove ema and message_hub when publishing models (#2036) --- tools/misc/publish_model.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/tools/misc/publish_model.py b/tools/misc/publish_model.py index 393721ab06..4a8338fdbd 100644 --- a/tools/misc/publish_model.py +++ b/tools/misc/publish_model.py @@ -4,6 +4,9 @@ from datetime import date import torch +from mmengine.logging import print_log +from mmengine.utils import digit_version +from mmengine.utils.dl_utils import TORCH_VERSION def parse_args(): @@ -11,18 +14,37 @@ def parse_args(): description='Process a checkpoint to be published') parser.add_argument('in_file', help='input checkpoint filename') parser.add_argument('out_file', help='output checkpoint filename') + parser.add_argument( + '--save-keys', + nargs='+', + type=str, + default=['meta', 'state_dict'], + help='keys to save in published checkpoint (default: meta state_dict)') args = parser.parse_args() return args -def process_checkpoint(in_file, out_file): +def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']): checkpoint = torch.load(in_file, map_location='cpu') - # remove optimizer for smaller file size - if 'optimizer' in checkpoint: - del checkpoint['optimizer'] + + # only keep `meta` and `state_dict` for smaller file size + ckpt_keys = list(checkpoint.keys()) + for k in ckpt_keys: + if k not in save_keys: + print_log( + f'Key `{k}` will be removed because it is not in ' + f'save_keys. If you want to keep it, ' + f'please set --save-keys.', + logger='current') + checkpoint.pop(k, None) + # if it is necessary to remove some sensitive data in checkpoint['meta'], # add the code here. - torch.save(checkpoint, out_file) + + if digit_version(TORCH_VERSION) >= digit_version('1.6.0'): + torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False) + else: + torch.save(checkpoint, out_file) sha = subprocess.check_output(['sha256sum', out_file]).decode() if out_file.endswith('.pth'): out_file_name = out_file[:-4] @@ -36,7 +58,7 @@ def process_checkpoint(in_file, out_file): def main(): args = parse_args() - process_checkpoint(args.in_file, args.out_file) + process_checkpoint(args.in_file, args.out_file, args.save_keys) if __name__ == '__main__':