Implementation of various neural network pruning methods in PyTorch.

Krishnkant-Swarnkar/Pytorch-pruning

Pytorch-Pruning

This repository contains implementations of neural network weight pruning methods.

Available Pruning Methods

Pruning has three stages:

  1. Training
  2. Pruning:
    • OneShotPruning
  3. Retraining
    • No Retraining (none)
    • FineTuning (fine-tuning)
    • Learning Rate Rewinding (lr-rewinding)
    • Weight Rewinding (weight-rewinding)
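
The pruning and retraining stages above can be sketched roughly as follows. This is a minimal illustration, not the code in this repository: it assumes that one-shot pruning means global magnitude pruning (the smallest-magnitude weights are removed in a single step), and that the retraining options follow the usual meaning of these terms in the pruning literature. The names global_magnitude_masks, apply_masks and retraining_plan are hypothetical helpers introduced only for this example.

```python
import torch
import torch.nn as nn


def global_magnitude_masks(model, sparsity):
    """Build 0/1 masks that drop the `sparsity` fraction of smallest-magnitude
    weights, ranked jointly over every weight matrix or kernel in the model."""
    # Assumption: only parameters with more than one dimension are pruned,
    # so biases and normalization parameters stay dense.
    weights = {n: p.detach() for n, p in model.named_parameters() if p.dim() > 1}
    scores = torch.cat([w.abs().flatten() for w in weights.values()])
    num_pruned = int(sparsity * scores.numel())
    if num_pruned == 0:
        return {n: torch.ones_like(w) for n, w in weights.items()}
    # The num_pruned-th smallest magnitude is the cut-off; ties are pruned too.
    threshold = torch.kthvalue(scores, num_pruned).values
    return {n: (w.abs() > threshold).float() for n, w in weights.items()}


def apply_masks(model, masks):
    """Zero the pruned weights in place. Calling this again after each
    optimizer step during retraining keeps pruned weights at zero."""
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in masks:
                param.mul_(masks[name])


def retraining_plan(strategy, lr_schedule, rewind_epoch, finetune_epochs):
    """Return (starting weights, per-epoch learning rates) for retraining.

    lr_schedule[i] is the learning rate used in epoch i of the original run.
    """
    if strategy == "none":
        return "final", []  # keep the pruned weights as they are
    if strategy == "fine-tuning":
        # Continue from the trained weights at the last (smallest) learning rate.
        return "final", [lr_schedule[-1]] * finetune_epochs
    if strategy == "lr-rewinding":
        # Keep the trained weights, replay the schedule from rewind_epoch.
        return "final", lr_schedule[rewind_epoch:]
    if strategy == "weight-rewinding":
        # Reset surviving weights to their rewind_epoch values, replay the schedule.
        return "rewound", lr_schedule[rewind_epoch:]
    raise ValueError(f"unknown retraining strategy: {strategy}")


# Usage on a toy model (hypothetical sizes and schedule).
model = nn.Sequential(nn.Linear(20, 30), nn.ReLU(), nn.Linear(30, 5))
masks = global_magnitude_masks(model, sparsity=0.5)
apply_masks(model, masks)
start, lrs = retraining_plan("lr-rewinding", [0.1, 0.1, 0.01, 0.001],
                             rewind_epoch=1, finetune_epochs=2)
```

In this sketch the three retraining options differ only in which weights they start from and which learning rates they use; the masks stay fixed throughout retraining.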

Usage

Install the following Python libraries:

  • torch>=1.6.0
  • torchvision>=0.7.0
  • numpy
  • tqdm
  • livelossplot

Then run:

$ mkdir saved_models

To get the pruned models you need:

  1. A model (torch.nn.Module)
    • whose forward() takes its input as a single argument, "batch"
    • whose forward() returns the loss by default, and returns the prediction when the get_prediction argument is set to True
  2. Torch Data Loaders (train_dataloader, val_dataloader)
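
A model and data loaders matching the requirements above might look like the sketch below. It is illustrative only: TinyClassifier and make_loader are hypothetical names, the data is random noise shaped like MNIST, and the example assumes that each batch is an (inputs, targets) pair and that "the prediction" means predicted class indices. The models and data utilities in this repository may make different choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class TinyClassifier(nn.Module):
    """Hypothetical model that follows the forward() contract described above."""

    def __init__(self, in_features=784, num_classes=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, batch, get_prediction=False):
        # Assumption: the single "batch" argument is an (inputs, targets) pair.
        inputs, targets = batch
        logits = self.net(inputs)
        if get_prediction:
            # Assumption: predictions are class indices, not raw logits.
            return logits.argmax(dim=1)
        return F.cross_entropy(logits, targets)


def make_loader(num_samples, batch_size=32, shuffle=False):
    """Build a DataLoader over random MNIST-shaped tensors (placeholder data)."""
    images = torch.randn(num_samples, 1, 28, 28)
    labels = torch.randint(0, 10, (num_samples,))
    return DataLoader(TensorDataset(images, labels),
                      batch_size=batch_size, shuffle=shuffle)


train_dataloader = make_loader(256, shuffle=True)
val_dataloader = make_loader(64)

model = TinyClassifier()
batch = next(iter(train_dataloader))
loss = model(batch)                              # scalar loss by default
predictions = model(batch, get_prediction=True)  # one class index per sample
```

Returning the loss from forward() lets a training or pruning loop call loss.backward() without knowing which loss function the model uses.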

See the Jupyter notebooks (.ipynb files) for worked examples of the pruning workflow.

This repository includes implementations of the following models: lenet, resnet20, resnet32, resnet44, resnet56, resnet110 and resnet1202, as well as functions that create data loaders for MNIST and CIFAR-10.

Directory Structure

.
|--- models
|    |--- __init__.py
|    |--- lenet.py
|    |--- resnet.py
|--- pruning
|    |--- unstructured
|    |    |--- __init__.py
|    |    |--- one_shot_pruning.py
|    |    |--- iterative_pruning.py
|    |--- structured
|    |    |--- __init__.py
|--- utils
|    |--- __init__.py
|    |--- data_utils.py
|    |--- train_utils.py
|--- Lisence
|--- README.md
|--- LISENCE

Todo

  1. Write code for weight distribution visualizations for each layer.
  2. Write code for iterative pruning.
  3. Write code for ADMM sparsity regularization.
