From 1751fc55847a779836df11ae751227c2e1c20219 Mon Sep 17 00:00:00 2001 From: Ningxin Zheng Date: Sat, 26 Jun 2021 04:56:35 +0000 Subject: [PATCH 1/6] examples --- .../pruning/{ => speedup}/model_speedup.py | 0 .../pruning/speedup/speedup_mobilnetv2.py | 21 ++++++++++ .../pruning/speedup/speedup_nanodet.py | 39 +++++++++++++++++++ .../pruning/speedup/speedup_yolov3.py | 36 +++++++++++++++++ 4 files changed, 96 insertions(+) rename examples/model_compress/pruning/{ => speedup}/model_speedup.py (100%) create mode 100644 examples/model_compress/pruning/speedup/speedup_mobilnetv2.py create mode 100644 examples/model_compress/pruning/speedup/speedup_nanodet.py create mode 100644 examples/model_compress/pruning/speedup/speedup_yolov3.py diff --git a/examples/model_compress/pruning/model_speedup.py b/examples/model_compress/pruning/speedup/model_speedup.py similarity index 100% rename from examples/model_compress/pruning/model_speedup.py rename to examples/model_compress/pruning/speedup/model_speedup.py diff --git a/examples/model_compress/pruning/speedup/speedup_mobilnetv2.py b/examples/model_compress/pruning/speedup/speedup_mobilnetv2.py new file mode 100644 index 0000000000..db819298e9 --- /dev/null +++ b/examples/model_compress/pruning/speedup/speedup_mobilnetv2.py @@ -0,0 +1,21 @@ +import torch +from torchvision.models import mobilenet_v2 +from nni.compression.pytorch import ModelSpeedup +from nni.algorithms.compression.pytorch.pruning import L1FilterPruner + + +model = mobilenet_v2(pretrained=True) +dummy_input = torch.rand(8, 3, 416, 416) + +cfg_list = [{'op_types':['Conv2d'], 'sparsity':0.5}] +pruner = L1FilterPruner(model, cfg_list) +pruner.compress() +pruner.export_model('./model', './mask') +# need call _unwrap_model if you want run the speedup on the same model +pruner._unwrap_model() + +# Speedup the nanodet +ms = ModelSpeedup(model, dummy_input, './mask') +ms.speedup_model() + +model(dummy_input) \ No newline at end of file diff --git a/examples/model_compress/pruning/speedup/speedup_nanodet.py b/examples/model_compress/pruning/speedup/speedup_nanodet.py new file mode 100644 index 0000000000..2036141150 --- /dev/null +++ b/examples/model_compress/pruning/speedup/speedup_nanodet.py @@ -0,0 +1,39 @@ +import torch +from nanodet.model.arch import build_model +from nanodet.util import cfg, load_config + +from nni.compression.pytorch import ModelSpeedup +from nni.algorithms.compression.pytorch.pruning import L1FilterPruner + +""" +NanoDet model can be installed from https://github.com/RangiLyu/nanodet.git +""" + +cfg_path = r"nanodet-RepVGG-A0_416.yml" +load_config(cfg, cfg_path) + +model = build_model(cfg.model) +dummy_input = torch.rand(8, 3, 416, 416) + +op_names = [] +# these three conv layers are followed by reshape-like functions +# that cannot be replaced, so we skip these three conv layers, +# you can also get such layers by `not_safe_to_prune` function +excludes = ['head.gfl_cls.0', 'head.gfl_cls.1', 'head.gfl_cls.2'] +for name, module in model.named_modules(): + if isinstance(module, torch.nn.Conv2d): + if name not in excludes: + op_names.append(name) + +cfg_list = [{'op_types':['Conv2d'], 'sparsity':0.5, 'op_names':op_names}] +pruner = L1FilterPruner(model, cfg_list) +pruner.compress() +pruner.export_model('./model', './mask') +# need call _unwrap_model if you want run the speedup on the same model +pruner._unwrap_model() + +# Speedup the nanodet +ms = ModelSpeedup(model, dummy_input, './mask') +ms.speedup_model() + +model(dummy_input) \ No newline at end of file diff --git a/examples/model_compress/pruning/speedup/speedup_yolov3.py b/examples/model_compress/pruning/speedup/speedup_yolov3.py new file mode 100644 index 0000000000..802074be5f --- /dev/null +++ b/examples/model_compress/pruning/speedup/speedup_yolov3.py @@ -0,0 +1,36 @@ +import torch +from pytorchyolo import models + +from nni.compression.pytorch import ModelSpeedup +from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner +from nni.compression.pytorch.utils import not_safe_to_prune + +# The Yolo can be downloaded at https://github.com/eriklindernoren/PyTorch-YOLOv3.git +prefix = '/home/user/PyTorch-YOLOv3' # replace this path with yours +# Load the YOLO model +model = models.load_model( + "%s/config/yolov3.cfg" % prefix, + "%s/yolov3.weights" % prefix) +model.eval() +dummy_input = torch.rand(8, 3, 320, 320) +model(dummy_input) +# Generate the config list for pruner +# Filter the layers that may not be able to prune +not_safe = not_safe_to_prune(model, dummy_input) +cfg_list = [] +for name, module in model.named_modules(): + if name in not_safe: + continue + if isinstance(module, torch.nn.Conv2d): + cfg_list.append({'op_types':['Conv2d'], 'sparsity':0.6, 'op_names':[name]}) +# Prune the model +pruner = L1FilterPruner(model, cfg_list) +pruner.compress() +pruner.export_model('./model', './mask') +pruner._unwrap_model() +# Speedup the model +ms = ModelSpeedup(model, dummy_input, './mask') + +ms.speedup_model() +model(dummy_input) + From 9a18c7df27cc2f3510923bea69f526bd49998593 Mon Sep 17 00:00:00 2001 From: Ningxin Zheng Date: Sat, 26 Jun 2021 05:30:57 +0000 Subject: [PATCH 2/6] update --- examples/model_compress/pruning/speedup/speedup_nanodet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/model_compress/pruning/speedup/speedup_nanodet.py b/examples/model_compress/pruning/speedup/speedup_nanodet.py index 2036141150..ea9c39f5d7 100644 --- a/examples/model_compress/pruning/speedup/speedup_nanodet.py +++ b/examples/model_compress/pruning/speedup/speedup_nanodet.py @@ -9,7 +9,7 @@ NanoDet model can be installed from https://github.com/RangiLyu/nanodet.git """ -cfg_path = r"nanodet-RepVGG-A0_416.yml" +cfg_path = r"nanodet/config/nanodet-RepVGG-A0_416.yml" load_config(cfg, cfg_path) model = build_model(cfg.model) From 7eefd885a5fd2ba676d749d894163e539ac4dca2 Mon Sep 17 00:00:00 2001 From: Ningxin Zheng Date: Wed, 14 Jul 2021 09:42:36 +0000 Subject: [PATCH 3/6] update docstring --- nni/compression/pytorch/utils/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/nni/compression/pytorch/utils/utils.py b/nni/compression/pytorch/utils/utils.py index 6def03bc91..5e65bda1c9 100644 --- a/nni/compression/pytorch/utils/utils.py +++ b/nni/compression/pytorch/utils/utils.py @@ -70,7 +70,14 @@ def randomize_tensor(tensor, start=1, end=100): def not_safe_to_prune(model, dummy_input): """ - Get the layers that are safe to prune(will not bring the shape conflict). + Get the layers that are not safe to prune(may bring the shape conflict). + For example, if the output tensor of a conv layer is directly followed by + a shape-dependent function(such as reshape/view), then this conv layer + may be not safe to be pruned. Pruning may change the output shape of + this conv layer and result in shape problems. This function find all the + layers that directly followed by the shape-dependent functions(view, reshape, etc). + If you run the inference after the speedup and run into a shape related error, + please exclude the layers returned by this function and try again. Parameters ---------- From e082aefc99ee3e0b4249d4945499fc7f7501f5cb Mon Sep 17 00:00:00 2001 From: Ningxin Zheng Date: Wed, 14 Jul 2021 10:49:07 +0000 Subject: [PATCH 4/6] skip the speedup integration test on windows, too slow --- test/ut/sdk/test_model_speedup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/ut/sdk/test_model_speedup.py b/test/ut/sdk/test_model_speedup.py index d564c3274a..a106df2752 100644 --- a/test/ut/sdk/test_model_speedup.py +++ b/test/ut/sdk/test_model_speedup.py @@ -361,6 +361,9 @@ def test_speedup_integration_big(self): self.speedup_integration(model_list) def speedup_integration(self, model_list, speedup_cfg=None): + # Note: hack trick, may be updated in the future + if 'win' in sys.platform or 'Win'in sys.platform: + print('Skip test_speedup_integration on windows due to memory limit!') Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2] # for model_name in ['vgg16', 'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121', From 80fb98ecbb2b91471381f6e641fe0a2f23846b9e Mon Sep 17 00:00:00 2001 From: Ningxin Zheng Date: Wed, 14 Jul 2021 11:02:21 +0000 Subject: [PATCH 5/6] fix sys.path --- examples/model_compress/pruning/speedup/model_speedup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/model_compress/pruning/speedup/model_speedup.py b/examples/model_compress/pruning/speedup/model_speedup.py index bec053542a..0ac3758344 100644 --- a/examples/model_compress/pruning/speedup/model_speedup.py +++ b/examples/model_compress/pruning/speedup/model_speedup.py @@ -7,7 +7,7 @@ from torchvision import datasets, transforms import sys -sys.path.append('../models') +sys.path.append('../../models') from cifar10.vgg import VGG from mnist.lenet import LeNet From 664aecab88adf0246776cfe63441ceb92df5dbfc Mon Sep 17 00:00:00 2001 From: Ningxin Zheng Date: Wed, 14 Jul 2021 13:28:10 +0000 Subject: [PATCH 6/6] update --- test/ut/sdk/test_model_speedup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/ut/sdk/test_model_speedup.py b/test/ut/sdk/test_model_speedup.py index a106df2752..f61bc93064 100644 --- a/test/ut/sdk/test_model_speedup.py +++ b/test/ut/sdk/test_model_speedup.py @@ -364,6 +364,7 @@ def speedup_integration(self, model_list, speedup_cfg=None): # Note: hack trick, may be updated in the future if 'win' in sys.platform or 'Win'in sys.platform: print('Skip test_speedup_integration on windows due to memory limit!') + return Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2] # for model_name in ['vgg16', 'resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121',