Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add several speedup examples #3880

Merged
merged 7 commits into from
Jul 15, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions examples/model_compress/pruning/speedup/speedup_mobilnetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
from torchvision.models import mobilenet_v2
from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner


model = mobilenet_v2(pretrained=True)
dummy_input = torch.rand(8, 3, 416, 416)

cfg_list = [{'op_types':['Conv2d'], 'sparsity':0.5}]
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
# need call _unwrap_model if you want run the speedup on the same model
pruner._unwrap_model()

# Speedup the nanodet
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()

model(dummy_input)
39 changes: 39 additions & 0 deletions examples/model_compress/pruning/speedup/speedup_nanodet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch
from nanodet.model.arch import build_model
from nanodet.util import cfg, load_config

from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

"""
NanoDet model can be installed from https://github.com/RangiLyu/nanodet.git
"""

cfg_path = r"nanodet/config/nanodet-RepVGG-A0_416.yml"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a little strange that user has to install model from other repo for running example. If this model is not very complicated, can we add it into our model compression model files so that user can run it directly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to put their code into our repo. but we can provide concrete commands about how to prepare the code in the comment

load_config(cfg, cfg_path)

model = build_model(cfg.model)
dummy_input = torch.rand(8, 3, 416, 416)

op_names = []
# these three conv layers are followed by reshape-like functions
# that cannot be replaced, so we skip these three conv layers,
# you can also get such layers by `not_safe_to_prune` function
excludes = ['head.gfl_cls.0', 'head.gfl_cls.1', 'head.gfl_cls.2']
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
if name not in excludes:
op_names.append(name)

cfg_list = [{'op_types':['Conv2d'], 'sparsity':0.5, 'op_names':op_names}]
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
# need call _unwrap_model if you want run the speedup on the same model
pruner._unwrap_model()

# Speedup the nanodet
ms = ModelSpeedup(model, dummy_input, './mask')
ms.speedup_model()

model(dummy_input)
36 changes: 36 additions & 0 deletions examples/model_compress/pruning/speedup/speedup_yolov3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
from pytorchyolo import models

from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner, LevelPruner
from nni.compression.pytorch.utils import not_safe_to_prune
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved

# The Yolo can be downloaded at https://github.com/eriklindernoren/PyTorch-YOLOv3.git
prefix = '/home/user/PyTorch-YOLOv3' # replace this path with yours
# Load the YOLO model
model = models.load_model(
"%s/config/yolov3.cfg" % prefix,
"%s/yolov3.weights" % prefix)
model.eval()
dummy_input = torch.rand(8, 3, 320, 320)
model(dummy_input)
# Generate the config list for pruner
# Filter the layers that may not be able to prune
not_safe = not_safe_to_prune(model, dummy_input)
cfg_list = []
for name, module in model.named_modules():
if name in not_safe:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we can leverage "exclude" here in the config, @J-shang

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exclude is much more elegant.

continue
if isinstance(module, torch.nn.Conv2d):
cfg_list.append({'op_types':['Conv2d'], 'sparsity':0.6, 'op_names':[name]})
# Prune the model
pruner = L1FilterPruner(model, cfg_list)
pruner.compress()
pruner.export_model('./model', './mask')
pruner._unwrap_model()
# Speedup the model
ms = ModelSpeedup(model, dummy_input, './mask')

ms.speedup_model()
model(dummy_input)

9 changes: 8 additions & 1 deletion nni/compression/pytorch/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,14 @@ def randomize_tensor(tensor, start=1, end=100):

def not_safe_to_prune(model, dummy_input):
"""
Get the layers that are safe to prune(will not bring the shape conflict).
Get the layers that are not safe to prune(may bring the shape conflict).
For example, if the output tensor of a conv layer is directly followed by
a shape-dependent function(such as reshape/view), then this conv layer
may be not safe to be pruned. Pruning may change the output shape of
this conv layer and result in shape problems. This function find all the
layers that directly followed by the shape-dependent functions(view, reshape, etc).
If you run the inference after the speedup and run into a shape related error,
please exclude the layers returned by this function and try again.

Parameters
----------
Expand Down