diff --git a/tools/misc/publish_model.py b/tools/misc/publish_model.py index 393721ab06..71fd461e33 100644 --- a/tools/misc/publish_model.py +++ b/tools/misc/publish_model.py @@ -17,9 +17,19 @@ def parse_args(): def process_checkpoint(in_file, out_file): checkpoint = torch.load(in_file, map_location='cpu') + # remove optimizer for smaller file size if 'optimizer' in checkpoint: del checkpoint['optimizer'] + if 'message_hub' in checkpoint: + del checkpoint['message_hub'] + if 'ema_state_dict' in checkpoint: + del checkpoint['ema_state_dict'] + + for key in list(checkpoint['state_dict']): + if key.startswith('data_preprocessor'): + checkpoint['state_dict'].pop(key) + # if it is necessary to remove some sensitive data in checkpoint['meta'], # add the code here. torch.save(checkpoint, out_file)