Skip to content

Latest commit

 

History

History
445 lines (361 loc) · 26.8 KB

README_CN.md

File metadata and controls

445 lines (361 loc) · 26.8 KB

Towards Any Structural Pruning

Test Status Tested PyTorch Versions License Downloads Latest Version Open In Colab arXiv

[Documentation & Tutorials] [FAQ]

Torch-Pruning (TP) 是一个用于结构化剪枝的库,具有以下特点:

有关更多技术细节,请参阅我们在CVPR'23的论文:

DepGraph: Towards Any Structural Pruning
Gongfan FangXinyin MaMingli SongMichael Bi MiXinchao Wang
Learning and Vision Lab、National University of Singapore

更新:

主要功能:

联系我们:

如果在使用库或论文时遇到任何问题,请随时提交 Issue
或者加入我们的Discord或微信交流群进行交流:

Table of Contents

Installation

Torch-Pruning兼容PyTorch 1.x和2.x版本。强烈推荐使用PyTorch 2.0。

pip install torch-pruning 

or

git clone https://github.com/VainF/Torch-Pruning.git

Quickstart

这里我们提供一个Torch-Pruning的快速入门。更多详细的解释可以在 Tutorals 中找到。

How It Works

在结构化剪枝中,“组”被定义为深度网络中最小的可移除单元。大多数组由相互依赖的多层组成,需要一起剪枝以保持结果结构的完整性。然而,深度网络的层之间通常存在复杂的依赖关系,使得结构化剪枝成为一项挑战。这项工作通过引入一种自动化机制“DepGraph”来解决这个问题。DepGraph实现了自动化的参数分组,支持各种类型深度网络的剪枝。

A Minimal Example of DepGraph

请确保你的模型已设置为启用AutoGrad,且没有使用 torch.no_grad 或 .requires_grad=False。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. 为resnet18构建依赖图
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. 为model.conv1分组耦合层
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )

# 3. 执行剪枝
if DG.check_pruning_group(group): # 避免完全剪枝,即channels=0
    group.prune()

# 4. 保存与加载
model.zero_grad() # 清除梯度,避免较大的checkpoint
torch.save(model, 'model.pth') # 我们不能使用.state_dict进行存储,这是因为剪枝导致模型结构发生变化。
model = torch.load('model.pth') # 加载剪枝后的模型

上面的示例展示了使用DepGraph的基本剪枝流程。目标层resnet.conv1与多个层耦合,因此在结构化剪枝期间需要同时移除这些层。为了观察剪枝操作的级联效应,我们可以打印这些组并观察一个剪枝操作如何“触发”其他操作。在后续输出中,“A => B”表示剪枝操作“A”触发了剪枝操作“B”。组[0]指的是在DG.get_pruning_group中的剪枝起始位置。有关分组的更多详细信息,请参阅 Wiki - DepGraph & Group.

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------

How to scan all groups (Advanced):

我们可以使用DG.get_all_groups(ignored_layers, root_module_types)来顺序扫描和剪枝所有组。每个组将从一个类型匹配root_module_types参数的层开始。请注意,DG.get_all_groups仅负责分组。如果需要剪枝,我们应使用额外的group.prune(idxs=idxs)来指定剪枝索引。当您想要实现自己的剪枝算法时,可以使用该函数扫描所有组,并移除不需要的通道/维度/参数。

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
    # handle groups in sequential order
    idxs = [2,4,6] # your pruning indices
    group.prune(idxs=idxs)
    print(group)

High-level Pruners

使用DepGraph,我们在该项目中开发了几种高级剪枝器,以简化剪枝过程。通过指定所需的通道剪枝比例,剪枝器将扫描所有可剪枝的组,估算重要性,剪枝整个模型,并使用您自己的训练代码进行微调。有关此过程的详细信息,请参阅 this tutorial,该教程展示了如何从头实现一个 slimming 剪枝器。此外,一个更实际的示例可在 benchmarks/main.py 中找到。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)
example_inputs = torch.randn(1, 3, 224, 224)

# 1. Importance criterion
imp = tp.importance.GroupTaylorImportance() # or GroupNormImportance(p=2), GroupHessianImportance(), etc.

# 2. Initialize a pruner with the model and the importance criterion
ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

pruner = tp.pruner.MetaPruner( # We can always choose MetaPruner if sparse training is not required.
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    # pruning_ratio_dict = {model.conv1: 0.2, model.layer2: 0.8}, # customized pruning ratios for layers or blocks
    ignored_layers=ignored_layers,
)

# 3. Prune & finetune the model
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
if isinstance(imp, tp.importance.GroupTaylorImportance):
    # Taylor expansion requires gradients for importance estimation
    loss = model(example_inputs).sum() # A dummy loss, please replace this line with your loss function and data!
    loss.backward() # before pruner.step()

pruner.step()
macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
# finetune the pruned model here
# finetune(model)
# ...

Global Pruning

通过选择全局剪枝选项(global_pruning=True),剪枝器将根据参数的重要性全局排名,自适应地分配不同层的稀疏度。虽然这种策略可以提供性能优势,但它也可能过度剪枝某些层,导致整体性能显著下降。对于不熟悉的模型/算法,建议从global_pruning=False开始。

Sparse Training

一些剪枝器如 BNScalePrunerGroupNormPruner 支持稀疏训练。这可以通过在训练前插入pruner.update_regularizer(),并在loss.backward()optimizer.step()之间插入pruner.regularize(model)来轻松实现。剪枝器会将正则化梯度累积到.grad中。

for epoch in range(epochs):
    model.train()
    pruner.update_regularizer() # <== initialize regularizer
    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, target)
        loss.backward() # after loss.backward()
        pruner.regularize(model) # <== for sparse training
        optimizer.step() # before optimizer.step()

Interactive Pruning

所有高级剪枝器都支持交互式剪枝。你可以使用pruner.step(interactive=True)方法获取所有待剪枝的组,并通过手动调用group.prune()进行交互式剪枝。该功能可以用于控制或监控剪枝流程,例如手动修改剪枝位置、或者打印被剪枝层的信息。

for i in range(iterative_steps):
    for group in pruner.step(interactive=True): # Warning: groups must be handled sequentially. Do not keep them as a list.
        print(group) 
        # do whatever you like with the group 
        dep, idxs = group[0] # get the idxs
        target_module = dep.target.module # get the root module
        pruning_fn = dep.handler # get the pruning function
        group.prune()
        # group.prune(idxs=[0, 2, 6]) # It is even possible to change the pruning behaviour with the idxs parameter
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...

Soft Pruning

利用interactive=True可以轻松实现软剪枝,它将参数置零而不移除它们。一个示例可以在 tests/test_soft_pruning.py 中找到。

Group-level Pruning

借助DepGraph,可以设计一些“组级”标准来估计整个组的重要性,而不是单个层的重要性。这一特性还可以用于稀疏耦合层,使所有待剪枝的参数一致稀疏。在Torch-Pruning中,所有剪枝器都在组级工作。查看以下结果,了解分组如何提高剪枝性能。

  • Pruning a ResNet50 pre-trained on ImageNet-1K without fine-tuning.
  • Pruning a Vision Transformer pre-trained on ImageNet-1K without fine-tuning.

Modify module attributes or forward function

在某些实现中,模型前向可能依赖于一些静态属性。例如在timm的convformer_s18中,我们有:

class Scale(nn.Module):
    """
    Scale vector by element multiplications.
    """

    def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True):
        super().__init__()
        self.shape = (dim, 1, 1) if use_nchw else (dim,) # static shape, which should be updated after pruning
        self.scale = nn.Parameter(init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        return x * self.scale.view(self.shape) # => x * self.scale.view(-1, 1, 1), this works for pruning

其中forward函数在前向时依赖于self.shape。但在剪枝后,真正的self.shape已经改变,需要手动更新。

Save and Load

Method 1:

以下脚本将整个模型对象(结构+权重)保存为 model.pth

model.zero_grad() # Remove gradients
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model

Method 2 (Experimental Features):

另一种可选方式是使用tp.state_dicttp.load_state_dict从未剪枝的模型重新创建剪枝后的模型。

# save the pruned state_dict, which includes both pruned parameters and modified attributes
state_dict = tp.state_dict(pruned_model) # the pruned model, e.g., a resnet-18-half
torch.save(state_dict, 'pruned.pth')

# create a new model, e.g. resnet18
new_model = resnet18().eval()

# load the pruned state_dict into the unpruned model.
loaded_state_dict = torch.load('pruned.pth', map_location='cpu')
tp.load_state_dict(new_model, state_dict=loaded_state_dict)

参阅 tests/test_serialization.py 了解ViT示例。在此示例中,我们将剪枝模型并修改一些属性如model.hidden_dims

Low-level Pruning Functions

尽管可以使用低级函数手动剪枝你的模型,但由于需要精细管理依赖关系,这种方法可能繁琐且耗时。因此,我们强烈推荐使用前述高级剪枝器来简化剪枝过程。这些剪枝器提供了更便捷、高效的模型剪枝方式。要手动剪枝ResNet-18的model.conv1,剪枝流程应如下所示:

tp.prune_conv_out_channels( model.conv1, idxs=[2,6,9] )

# fix the broken dependencies manually
tp.prune_batchnorm_out_channels( model.bn1, idxs=[2,6,9] )
tp.prune_conv_in_channels( model.layer2[0].conv1, idxs=[2,6,9] )
...

The following pruning functions are available:

'prune_conv_out_channels',
'prune_conv_in_channels',
'prune_depthwise_conv_out_channels',
'prune_depthwise_conv_in_channels',
'prune_batchnorm_out_channels',
'prune_batchnorm_in_channels',
'prune_linear_out_channels',
'prune_linear_in_channels',
'prune_prelu_out_channels',
'prune_prelu_in_channels',
'prune_layernorm_out_channels',
'prune_layernorm_in_channels',
'prune_embedding_out_channels',
'prune_embedding_in_channels',
'prune_parameter_out_channels',
'prune_parameter_in_channels',
'prune_multihead_attention_out_channels',
'prune_multihead_attention_in_channels',
'prune_groupnorm_out_channels',
'prune_groupnorm_in_channels',
'prune_instancenorm_out_channels',
'prune_instancenorm_in_channels',

Customized Layers

请参阅 examples/transformers/prune_hf_swin.py,该示例为自定义模块SwinPatchMerging实现了一个新的剪枝器。一个更简单的示例可以在 tests/test_customized_layer.py 中找到。

Benchmarks

Our results on {ResNet-56 / CIFAR-10 / 2.00x}

Method Base (%) Pruned (%) $\Delta$ Acc (%) Speed Up
NIPS [1] - - -0.03 1.76x
Geometric [2] 93.59 93.26 -0.33 1.70x
Polar [3] 93.80 93.83 +0.03 1.88x
CP [4] 92.80 91.80 -1.00 2.00x
AMC [5] 92.80 91.90 -0.90 2.00x
HRank [6] 93.26 92.17 -0.09 2.00x
SFP [7] 93.59 93.36 +0.23 2.11x
ResRep [8] 93.71 93.71 +0.00 2.12x
Ours-L1 93.53 92.93 -0.60 2.12x
Ours-BN 93.53 93.29 -0.24 2.12x
Ours-Group 93.53 93.77 +0.38 2.13x

Latency

Latency test on ResNet-50, Batch Size=64.

[Iter 0]        Pruning ratio: 0.00,         MACs: 4.12 G,   Params: 25.56 M,        Latency: 45.22 ms +- 0.03 ms
[Iter 1]        Pruning ratio: 0.05,         MACs: 3.68 G,   Params: 22.97 M,        Latency: 46.53 ms +- 0.06 ms
[Iter 2]        Pruning ratio: 0.10,         MACs: 3.31 G,   Params: 20.63 M,        Latency: 43.85 ms +- 0.08 ms
[Iter 3]        Pruning ratio: 0.15,         MACs: 2.97 G,   Params: 18.36 M,        Latency: 41.22 ms +- 0.10 ms
[Iter 4]        Pruning ratio: 0.20,         MACs: 2.63 G,   Params: 16.27 M,        Latency: 39.28 ms +- 0.20 ms
[Iter 5]        Pruning ratio: 0.25,         MACs: 2.35 G,   Params: 14.39 M,        Latency: 34.60 ms +- 0.19 ms
[Iter 6]        Pruning ratio: 0.30,         MACs: 2.02 G,   Params: 12.46 M,        Latency: 33.38 ms +- 0.27 ms
[Iter 7]        Pruning ratio: 0.35,         MACs: 1.74 G,   Params: 10.75 M,        Latency: 31.46 ms +- 0.20 ms
[Iter 8]        Pruning ratio: 0.40,         MACs: 1.50 G,   Params: 9.14 M,         Latency: 29.04 ms +- 0.19 ms
[Iter 9]        Pruning ratio: 0.45,         MACs: 1.26 G,   Params: 7.68 M,         Latency: 27.47 ms +- 0.28 ms
[Iter 10]       Pruning ratio: 0.50,         MACs: 1.07 G,   Params: 6.41 M,         Latency: 20.68 ms +- 0.13 ms
[Iter 11]       Pruning ratio: 0.55,         MACs: 0.85 G,   Params: 5.14 M,         Latency: 20.48 ms +- 0.21 ms
[Iter 12]       Pruning ratio: 0.60,         MACs: 0.67 G,   Params: 4.07 M,         Latency: 18.12 ms +- 0.15 ms
[Iter 13]       Pruning ratio: 0.65,         MACs: 0.53 G,   Params: 3.10 M,         Latency: 15.19 ms +- 0.01 ms
[Iter 14]       Pruning ratio: 0.70,         MACs: 0.39 G,   Params: 2.28 M,         Latency: 13.47 ms +- 0.01 ms
[Iter 15]       Pruning ratio: 0.75,         MACs: 0.29 G,   Params: 1.61 M,         Latency: 10.07 ms +- 0.01 ms
[Iter 16]       Pruning ratio: 0.80,         MACs: 0.18 G,   Params: 1.01 M,         Latency: 8.96 ms +- 0.02 ms
[Iter 17]       Pruning ratio: 0.85,         MACs: 0.10 G,   Params: 0.57 M,         Latency: 7.03 ms +- 0.04 ms
[Iter 18]       Pruning ratio: 0.90,         MACs: 0.05 G,   Params: 0.25 M,         Latency: 5.81 ms +- 0.03 ms
[Iter 19]       Pruning ratio: 0.95,         MACs: 0.01 G,   Params: 0.06 M,         Latency: 5.70 ms +- 0.03 ms
[Iter 20]       Pruning ratio: 1.00,         MACs: 0.01 G,   Params: 0.06 M,         Latency: 5.71 ms +- 0.03 ms

Please refer to benchmarks for more details.

Series of Works

DepGraph: Towards Any Structural Pruning [Project] [Paper]
Gongfan Fang, Xinyin Ma, Mingli Song, Michael Bi Mi, Xinchao Wang
CVPR 2023

LLM-Pruner: On the Structural Pruning of Large Language Models [Project] [arXiv]
Xinyin Ma, Gongfan Fang, Xinchao Wang
NeurIPS 2023

Structural Pruning for Diffusion Models [Project] [arxiv]
Gongfan Fang, Xinyin Ma, Xinchao Wang
NeurIPS 2023

DeepCache: Accelerating Diffusion Models for Free [Project] [Arxiv]
Xinyin Ma, Gongfan Fang, and Xinchao Wang
CVPR 2024

0.1% Data Makes Segment Anything Slim [Project] [Arxiv]
Zigeng Chen, Gongfan Fang, Xinyin Ma, Xinchao Wang
Preprint 2023

Citation

@inproceedings{fang2023depgraph,
  title={Depgraph: Towards any structural pruning},
  author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={16091--16101},
  year={2023}
}