2SSP: A Two-Stage Framework for Structured Pruning of LLMs

Fabrizio Sandri, Elia Cunegatti, Giovanni Iacca

arXiv: https://arxiv.org/abs/2501.17771

Abstract. We propose a novel Two-Stage framework for Structured Pruning (2SSP) of Large Language Models (LLMs), which combines two different pruning strategies, namely Width and Depth Pruning. The first stage (Width Pruning) removes entire neurons, hence their corresponding rows and columns, aiming to preserve the connectivity among the pruned structures in the intermediate state of the Feed-Forward Networks in each Transformer block. Neurons are selected based on an importance score measuring the impact of each neuron on the output magnitude. The second stage (Depth Pruning), instead, removes entire Attention submodules, through an iterative process that removes the Attention submodules with the minimum impact on a given metric of interest (in our case, perplexity). We also propose a novel mechanism to balance the sparsity rates of the two stages w.r.t. the desired global sparsity. We test 2SSP on four LLM families and three sparsity rates (25%, 37.5%, and 50%), measuring the resulting perplexity on three language modeling datasets as well as the performance on six downstream tasks. Our method consistently outperforms five state-of-the-art competitors across the three language modeling datasets and the six downstream tasks, with up to a two-order-of-magnitude gain in pruning time.
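
To make the two stages concrete, the sketch below illustrates the core idea in PyTorch. It assumes a LLaMA-style gated FFN (gate_proj, up_proj, down_proj) and a caller-supplied perplexity evaluator; all names and mechanics here are illustrative assumptions, not the repository's actual API.

import torch
import torch.nn.functional as F

def ffn_neuron_importance(mlp, hidden_states):
    # Stage 1 score (assumed form): average output magnitude of each
    # intermediate FFN neuron on a small calibration batch
    # (hidden_states: batch x seq x hidden).
    with torch.no_grad():
        acts = F.silu(mlp.gate_proj(hidden_states)) * mlp.up_proj(hidden_states)
    return acts.abs().mean(dim=(0, 1))  # one score per intermediate neuron

def width_prune_ffn(mlp, scores, sparsity):
    # Remove the least important neurons: their rows in gate_proj/up_proj and
    # the matching columns in down_proj, preserving the block's connectivity.
    n_keep = int(scores.numel() * (1.0 - sparsity))
    keep = scores.topk(n_keep).indices.sort().values
    mlp.gate_proj.weight.data = mlp.gate_proj.weight.data[keep, :]
    mlp.up_proj.weight.data = mlp.up_proj.weight.data[keep, :]
    mlp.down_proj.weight.data = mlp.down_proj.weight.data[:, keep]

def depth_prune_attention(n_blocks, n_remove, perplexity_fn):
    # Stage 2: greedily remove, one at a time, the Attention submodule whose
    # removal increases perplexity the least. perplexity_fn(removed) is assumed
    # to evaluate the model with the Attention submodules in `removed` bypassed.
    removed = set()
    for _ in range(n_remove):
        best = min((i for i in range(n_blocks) if i not in removed),
                   key=lambda i: perplexity_fn(removed | {i}))
        removed.add(best)
    return removed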

Installation

We recommend setting up the environment using either a Conda environment or a Python virtual environment. Our experiments were conducted using Python 3.11.7.
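
For example, a Conda environment matching this setup can be created as follows (the environment name 2ssp is arbitrary):

conda create -n 2ssp python=3.11.7
conda activate 2ssp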

Clone the 2SSP repository and install its dependencies:

git clone https://github.com/FabrizioSandri/2SSP.git
cd 2SSP
pip install -r requirements.txt

Install Language Model Evaluation Harness (required for downstream task evaluation):

cd lm_harness
pip install -e ./

Usage

usage: main.py [-h] --model MODEL [--seed SEED] [--cache_dir CACHE_DIR] [--dense]
               [--pruning_method {2ssp,window_based,shortgpt,blockpruner,evopress,slicegpt}]
               [--sparsity_rate SPARSITY_RATE] [--main_table_results] [--evaluate_inference] [--evaluate_downstream]
               [--evaluate_perplexity] [--evaluate_qualitative] [--local_datasets] [--ablation]
               [--logging {DEBUG,INFO,WARNING,ERROR,CRITICAL}]

Pruning of transformer models

options:
  -h, --help            show this help message and exit
  --model MODEL         Specify the model's name or path to be pruned
  --seed SEED           Set a seed for reproducibility (default: 0)
  --cache_dir CACHE_DIR
                        Path to a directory in which a downloaded pretrained model should be cached. This option is not
                        supported when --pruning_method=slicegpt
  --dense               Load the original dense model without pruning
  --pruning_method {2ssp,window_based,shortgpt,blockpruner,evopress,slicegpt}
                        Specify the pruning method to apply
  --sparsity_rate SPARSITY_RATE
                        A floating-point value ranging from 0.0 to 1.0 that determines the target sparsity level for
                        pruning. If set to -1, pruning is performed at all sparsity levels from 0.0 to 1.0 with a step size
                        of 1/N. A value of -2 applies pruning at the predefined sparsity levels of 25%, 37.5%, and 50%.
  --main_table_results  Generate results for the main results table in the paper (Table 1)
  --evaluate_inference  Measure the model's inference time
  --evaluate_downstream
                        Perform downstream task evaluation at 37.5% sparsity
  --evaluate_perplexity
                        Evaluate perplexity on WikiText-2 only
  --evaluate_qualitative
                        Generate qualitative results
  --local_datasets      Use local datasets stored in the './data/' folder
  --ablation            Run the ablation study experiments
  --logging {DEBUG,INFO,WARNING,ERROR,CRITICAL}
                        Set the logging level (default: INFO)

Examples

  • Evaluate the dense model's perplexity:

    python 2ssp/main.py --model=meta-llama/Llama-2-7b-hf --dense --evaluate_perplexity
  • Prune at 50% sparsity with 2SSP and evaluate perplexity:

    python 2ssp/main.py --model=meta-llama/Llama-2-7b-hf --pruning_method=2ssp --sparsity_rate=0.5 --evaluate_perplexity
  • Prune at 50% sparsity with ShortGPT and evaluate perplexity:

    python 2ssp/main.py --model=meta-llama/Llama-2-7b-hf --pruning_method=shortgpt --sparsity_rate=0.5 --evaluate_perplexity
  • Generate main table results (Table 1) for 2SSP at 25%, 37.5%, and 50% sparsity on Mistral:

    python 2ssp/main.py --model=mistralai/Mistral-7B-v0.3 --pruning_method=2ssp --sparsity_rate=-2 --main_table_results
  • Evaluate downstream tasks at 37.5% sparsity:

    python 2ssp/main.py --model=meta-llama/Llama-2-7b-hf --pruning_method=2ssp --sparsity_rate=0.375 --evaluate_downstream
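  • Sweep all sparsity levels from 0.0 to 1.0 in steps of 1/N, as described in the --sparsity_rate help above (pairing with --evaluate_perplexity is an illustrative choice):

    python 2ssp/main.py --model=meta-llama/Llama-2-7b-hf --pruning_method=2ssp --sparsity_rate=-1 --evaluate_perplexity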

Supported Pruning Methods

The methods selectable via --pruning_method are:

  • 2ssp: our two-stage structured pruning method
  • window_based: window-based depth pruning baseline
  • shortgpt: ShortGPT
  • blockpruner: BlockPruner
  • evopress: EvoPress
  • slicegpt: SliceGPT

Acknowledgments

Our implementation includes code sourced from the repositories of the baseline methods listed above.

For more details, refer to the documentation or the associated research paper.

Citation

If you find this work useful, please consider citing:

@article{sandri20252SSP,
      title={2SSP: A Two-Stage Framework for Structured Pruning of LLMs}, 
      author={Fabrizio Sandri and Elia Cunegatti and Giovanni Iacca},
      journal={arXiv preprint arXiv:2501.17771},
      url={https://arxiv.org/abs/2501.17771},
      year={2025}
}
