From 31b4b54dec956f574f4ec42d67d36b1f06a1b8c6 Mon Sep 17 00:00:00 2001 From: Tau-J <674106399@qq.com> Date: Thu, 9 Mar 2023 19:00:53 +0800 Subject: [PATCH 1/2] remove ema and message_hub --- tools/misc/publish_model.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tools/misc/publish_model.py b/tools/misc/publish_model.py index 393721ab06..71fd461e33 100644 --- a/tools/misc/publish_model.py +++ b/tools/misc/publish_model.py @@ -17,9 +17,19 @@ def parse_args(): def process_checkpoint(in_file, out_file): checkpoint = torch.load(in_file, map_location='cpu') + # remove optimizer for smaller file size if 'optimizer' in checkpoint: del checkpoint['optimizer'] + if 'message_hub' in checkpoint: + del checkpoint['message_hub'] + if 'ema_state_dict' in checkpoint: + del checkpoint['ema_state_dict'] + + for key in list(checkpoint['state_dict']): + if key.startswith('data_preprocessor'): + checkpoint['state_dict'].pop(key) + # if it is necessary to remove some sensitive data in checkpoint['meta'], # add the code here. torch.save(checkpoint, out_file) From b7b372b55c9e20324a043a7b1fef65a76abc5927 Mon Sep 17 00:00:00 2001 From: Tau-J <674106399@qq.com> Date: Thu, 9 Mar 2023 19:35:58 +0800 Subject: [PATCH 2/2] update publish script --- tools/misc/publish_model.py | 40 ++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/tools/misc/publish_model.py b/tools/misc/publish_model.py index 71fd461e33..4a8338fdbd 100644 --- a/tools/misc/publish_model.py +++ b/tools/misc/publish_model.py @@ -4,6 +4,9 @@ from datetime import date import torch +from mmengine.logging import print_log +from mmengine.utils import digit_version +from mmengine.utils.dl_utils import TORCH_VERSION def parse_args(): @@ -11,28 +14,37 @@ def parse_args(): description='Process a checkpoint to be published') parser.add_argument('in_file', help='input checkpoint filename') parser.add_argument('out_file', help='output checkpoint filename') + parser.add_argument( + '--save-keys', + nargs='+', + type=str, + default=['meta', 'state_dict'], + help='keys to save in published checkpoint (default: meta state_dict)') args = parser.parse_args() return args -def process_checkpoint(in_file, out_file): +def process_checkpoint(in_file, out_file, save_keys=['meta', 'state_dict']): checkpoint = torch.load(in_file, map_location='cpu') - # remove optimizer for smaller file size - if 'optimizer' in checkpoint: - del checkpoint['optimizer'] - if 'message_hub' in checkpoint: - del checkpoint['message_hub'] - if 'ema_state_dict' in checkpoint: - del checkpoint['ema_state_dict'] - - for key in list(checkpoint['state_dict']): - if key.startswith('data_preprocessor'): - checkpoint['state_dict'].pop(key) + # only keep `meta` and `state_dict` for smaller file size + ckpt_keys = list(checkpoint.keys()) + for k in ckpt_keys: + if k not in save_keys: + print_log( + f'Key `{k}` will be removed because it is not in ' + f'save_keys. If you want to keep it, ' + f'please set --save-keys.', + logger='current') + checkpoint.pop(k, None) # if it is necessary to remove some sensitive data in checkpoint['meta'], # add the code here. - torch.save(checkpoint, out_file) + + if digit_version(TORCH_VERSION) >= digit_version('1.6.0'): + torch.save(checkpoint, out_file, _use_new_zipfile_serialization=False) + else: + torch.save(checkpoint, out_file) sha = subprocess.check_output(['sha256sum', out_file]).decode() if out_file.endswith('.pth'): out_file_name = out_file[:-4] @@ -46,7 +58,7 @@ def process_checkpoint(in_file, out_file): def main(): args = parse_args() - process_checkpoint(args.in_file, args.out_file) + process_checkpoint(args.in_file, args.out_file, args.save_keys) if __name__ == '__main__':