Official implementation of Shears: Unstructured Sparsity with Neural Low-rank Adapter Search. 🔥
This repo contains the code for Shears, a practical solution that produces efficient models fine-tuned for specific downstream tasks in real-world applications. Please refer to our paper for more details.
⚠️ Please see the more recent work SQFT. Its SparsePEFT strategy is not only comparable to Shears, but it can also merge adapters into sparse models without losing sparsity.
- [2024.04.18] Shears V1 paper has been released (link) and accepted by NAACL 2024 (Industry Track). 📚
- [2024.04.11] Release training and inference code for Shears V1. 🎉
We have released several models fine-tuned with Shears. Find them in the Table below:
Efficiency comparison (using LLaMA-13B on the math reasoning tasks as an example):
Method | Sparsity | Non-zero Parameters | Accuracy |
---|---|---|---|
LoRA | - | 13.0B | 51.1 |
Shears | 50% | 6.7B | 50.9 |
By incorporating elastic LoRA adapters into a sparsified base model, Shears can fine-tune a language model without sacrificing the sparsity obtained from the original model weights. This produces sparse models with improvements or minor drops in accuracy and a fraction of the cost compared to other approaches. The increase in sparsity can result in a significant speedup when using runtimes that take advantage of these patterns.
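To make the sparsity argument concrete, here is a minimal toy sketch in pure Python. It is not Shears code; the matrix sizes, the checkerboard sparsity pattern, and all names are illustrative assumptions. It shows that a LoRA update kept separate from the base weight leaves the base zeros untouched, whereas naively adding the dense low-rank update into the base weight removes them, even though both forms compute the same output.

```python
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def sparsity(w):
    """Return the fraction of entries that are exactly zero."""
    flat = [v for row in w for v in row]
    return sum(1 for v in flat if v == 0.0) / len(flat)


random.seed(0)
d_out, d_in, rank = 8, 8, 2  # toy sizes, chosen only for illustration

# Toy sparse base weight: every other entry is zero (50% unstructured sparsity).
base = [[0.0 if (i + j) % 2 == 0 else random.uniform(0.1, 1.0) for j in range(d_in)]
        for i in range(d_out)]

# Toy LoRA factors: the update is delta = B @ A, with B (d_out x rank), A (rank x d_in).
lora_b = [[random.uniform(0.1, 1.0) for _ in range(rank)] for _ in range(d_out)]
lora_a = [[random.uniform(0.1, 1.0) for _ in range(d_in)] for _ in range(rank)]
delta = matmul(lora_b, lora_a)

# Naive merge: adding the dense update fills in the zeros of the base weight.
merged = [[w + d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(base, delta)]

# Unmerged forward pass: y = W x + B (A x). The base weight stays sparse.
x = [[random.uniform(-1.0, 1.0)] for _ in range(d_in)]  # column vector
base_out = matmul(base, x)
lora_out = matmul(lora_b, matmul(lora_a, x))
y_unmerged = [b[0] + l[0] for b, l in zip(base_out, lora_out)]
y_merged = [row[0] for row in matmul(merged, x)]

print(f"base sparsity:   {sparsity(base):.2f}")    # 0.50
print(f"merged sparsity: {sparsity(merged):.2f}")  # 0.00
```

Both forward passes agree up to floating-point rounding, so the choice between them affects only how the weights are stored and executed, not the function computed.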
Overall, Shears provides a simple, effective, and general pipeline that users can extend to their own scenarios and tasks, including other modalities such as audio and video. Feel free to try Shears on any downstream task with any model.
The following script installs Shears from scratch:
pip install virtualenv
virtualenv shears-env
source shears-env/bin/activate
# install pytorch
pip install torch==2.1.2
# install dependencies
bash install.sh
Note: Please ignore the whitespace issues when applying the patch and running install.sh.
The following code shows an example of loading our trained Shears model:
from transformers import AutoModelForCausalLM
from peft import PeftModel
base_model = AutoModelForCausalLM.from_pretrained("IntelLabs/shears-mpt-7b-50-base")
model = PeftModel.from_pretrained(base_model, "IntelLabs/shears-mpt-7b-50-gsm8k-heuristic-adapter")
Below is an example of generating the instruction-following responses for some math reasoning samples:
python example_math.py --base_model_path IntelLabs/shears-mpt-7b-50-base --adapter_model_path IntelLabs/shears-mpt-7b-50-gsm8k-heuristic-adapter
Before fine-tuning, Shears uses Wanda, a simple but effective pruning approach, to sparsify the language model. The resulting sparse model serves as the frozen base model for adapter training. Clone the Wanda repo:
git clone https://github.com/locuslab/wanda.git && cd wanda && git checkout 8e8fc87 && cd ..
Below is an example command that prunes LLaMA-7B with Wanda to 50% unstructured sparsity (takes about five minutes).
python wanda/main.py \
--model yahma/llama-7b-hf \
--prune_method wanda \
--sparsity_ratio 0.5 \
--sparsity_type unstructured \
--save wanda_out \
--save_model <path to sparse model>
- --model: The identifier for the model on the Hugging Face model hub or a local path.
- --sparsity_ratio: Specifies the fraction of weights to be pruned (0.5 corresponds to 50% sparsity).
- --save_model: Specifies the directory where the pruned language model will be stored.
For further details, refer to Wanda. You can also skip this step and use our released sparsified models (see the Base Model column of the model table). Note that the sparsification step can be replaced by any other sparsification (or even quantization) algorithm. Feel free to try other approaches for the base model.
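As a rough illustration of what the pruning step does, the sketch below implements the Wanda criterion as described in the Wanda paper: each weight is scored by its magnitude times the L2 norm of the corresponding input feature over a calibration set, and the lowest-scoring weights within each output row are zeroed. This is a pure-Python toy, not the code in wanda/main.py; the function name wanda_prune and the tiny matrices are assumptions made for illustration.

```python
import math


def wanda_prune(weight, activations, sparsity_ratio):
    """Zero the lowest-scoring weights in each output row.

    weight:      list of rows, shape (d_out, d_in)
    activations: calibration inputs, shape (n_tokens, d_in)
    """
    d_in = len(weight[0])
    # L2 norm of each input feature across the calibration tokens.
    feat_norm = [math.sqrt(sum(tok[j] ** 2 for tok in activations)) for j in range(d_in)]
    n_prune = int(d_in * sparsity_ratio)
    pruned = []
    for row in weight:
        # Score of each weight: |w_ij| * ||x_j||_2.
        scores = [abs(w) * feat_norm[j] for j, w in enumerate(row)]
        # Indices of the n_prune smallest scores in this row.
        drop = set(sorted(range(d_in), key=lambda j: scores[j])[:n_prune])
        pruned.append([0.0 if j in drop else w for j, w in enumerate(row)])
    return pruned


weight = [[0.5, -0.1, 0.3, 0.9],
          [-0.2, 0.8, 0.05, -0.4]]
# The second input feature has large activations, the third has small ones.
activations = [[1.0, 10.0, 0.1, 1.0],
               [2.0, 10.0, 0.1, 1.0]]

pruned = wanda_prune(weight, activations, sparsity_ratio=0.5)
print(pruned)  # [[0.0, -0.1, 0.0, 0.9], [0.0, 0.8, 0.0, -0.4]]
```

In the first row, plain magnitude pruning would remove -0.1 and 0.3. The activation-aware score keeps -0.1 because its input feature is large, and removes 0.5 instead.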
Taking the unified math reasoning training as an example, please download the 10K instruction-following math reasoning training data (link) from LLM-Adapters.
Example command to train the super-adapter of the pruned LLaMA-7B using Shears:
python run_math.py \
--dataset_path math_10k.json \
--model_name_or_path <path to sparse model> \
--do_train \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--num_train_epochs 3 \
--learning_rate 3e-4 \
--warmup_steps 100 \
--optim adamw_torch \
--fp16 \
--output_dir <path to super-adapter> \
--logging_steps 20 \
--save_strategy epoch \
--save_total_limit 2 \
--lora \
--lora_r 32 \
--lora_alpha 64 \
--lora_dropout 0.1 \
--target_modules q_proj,k_proj,v_proj,up_proj,down_proj \
--nncf_config nncf_config/nncf_shears_llama.json
--nncf_config indicates the NNCF configuration, including the search space for elastic adapters.
To implement the elastic adapter, we apply the BootstrapNAS feature supported in OpenVINO™ NNCF, which provides a suite of compression algorithms for neural network optimization.
The NNCF configuration details are in nncf_config.md.
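The idea of an elastic adapter can be sketched without NNCF. The toy code below assumes a weight-sharing scheme in which a sub-adapter of rank r reuses the first r rows of A and the first r columns of B from the maximal-rank adapter. This is an illustrative assumption rather than the NNCF implementation, and the rank choices [4, 3, 2] and all function names are hypothetical; the actual search space is defined in the NNCF configuration file.

```python
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def sub_adapter(lora_a, lora_b, rank):
    """Slice a rank-`rank` sub-adapter out of the shared maximal-rank factors."""
    a_sub = lora_a[:rank]                    # first `rank` rows of A
    b_sub = [row[:rank] for row in lora_b]   # first `rank` columns of B
    return a_sub, b_sub


def num_params(a, b):
    """Count the entries of both low-rank factors."""
    return sum(len(r) for r in a) + sum(len(r) for r in b)


random.seed(0)
d_in, d_out, max_rank = 6, 4, 4
rank_choices = [4, 3, 2]  # hypothetical search space for a single adapter

# Shared super-adapter weights: A is (max_rank x d_in), B is (d_out x max_rank).
lora_a = [[random.uniform(-1, 1) for _ in range(d_in)] for _ in range(max_rank)]
lora_b = [[random.uniform(-1, 1) for _ in range(max_rank)] for _ in range(d_out)]

for rank in rank_choices:
    a_sub, b_sub = sub_adapter(lora_a, lora_b, rank)
    delta = matmul(b_sub, a_sub)  # the update always has shape (d_out x d_in)
    print(f"rank={rank}: delta is {len(delta)}x{len(delta[0])}, "
          f"adapter params={num_params(a_sub, b_sub)}")
```

Because every sub-adapter is a slice of the same tensors, training the super-adapter with varying active ranks trains all sub-adapters at once, and any of them can be selected afterwards without retraining.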
After training, the super-adapter is saved in the directory given by --output_dir (referred to below as ADAPTER_MODEL_PATH).
All evaluation datasets can be downloaded from LLM-Adapters.
Place them into the directory datasets/.
git clone https://github.com/AGI-Edgerunners/LLM-Adapters.git
mv LLM-Adapters/dataset/ datasets/
Example command to evaluate the trained super-adapter (heuristic subnetwork):
python run_math.py \
--model_name_or_path <path to sparse model> \
--lora \
--lora_weights <path to super-adapter> \
--do_test \
--output_dir <path to results> \
--nncf_config nncf_config/nncf_shears_llama.json
The same command can also be used to test a released model, for example:
python run_math.py \
--model_name_or_path <path to sparse model> \
--lora \
--lora_weights IntelLabs/shears-llama-7b-50-math-super-adapter \
--do_test \
--output_dir <path to results> \
--nncf_config nncf_config/nncf_shears_llama.json
Note that the torch version we used in our experiments is 1.12.1+cu113, and the results might vary with different versions.
Please refer to running_commands for all commands related to reproducing the paper's results.
- LLaMA with Math Reasoning tasks
Model | Sparsity | GSM8K | AQuA | MAWPS | SVAMP | Average |
---|---|---|---|---|---|---|
LLaMA-7B-LoRA | - | 37.5 | 18.9 | 79.0 | 52.1 | 46.9 |
LLaMA-7B-Shears | 40% | 36.8 | 19.7 | 83.2 | 47.7 | 46.9 |
LLaMA-7B-Shears | 50% | 36.1 | 22.0 | 78.6 | 44.5 | 45.3 |
LLaMA-13B-LoRA | - | 47.5 | 18.5 | 83.6 | 54.6 | 51.1 |
LLaMA-13B-Shears | 40% | 48.3 | 21.3 | 83.2 | 55.2 | 52.0 |
LLaMA-13B-Shears | 50% | 45.1 | 22.0 | 83.2 | 53.3 | 50.9 |
- LLaMA with Commonsense Reasoning tasks
Model | Sparsity | BoolQ | PIQA | SIQA | HellaSwag | WinoG | ARC-e | ARC-c | OBQA | Average |
---|---|---|---|---|---|---|---|---|---|---|
ChatGPT | - | 73.1 | 85.4 | 68.5 | 78.5 | 66.1 | 89.8 | 79.9 | 74.8 | 77.0 |
LLaMA-7B-LoRA | - | 68.9 | 80.7 | 77.4 | 78.1 | 78.8 | 77.8 | 61.3 | 74.8 | 74.7 |
LLaMA-7B-Shears | 40% | 67.0 | 79.9 | 76.7 | 80.1 | 78.6 | 76.9 | 62.3 | 77.8 | 74.9 |
LLaMA-7B-Shears | 50% | 67.3 | 79.1 | 77.5 | 73.3 | 77.7 | 74.4 | 57.9 | 72.8 | 72.5 |
- MPT with GSM8K
Sparsity | 0% | 40% | 50% | 60% | 70% |
---|---|---|---|---|---|
Accuracy | 36.1 | 35.7 | 33.4 | 30.4 | 22.8 |
To support exploration of the super-network trained with Shears, we provide an illustrative notebook, search/load_and_explore_supernet.ipynb. The notebook demonstrates how to load a Shears super-network directly and extract various subnetworks.
This lets users apply their own search algorithms and evaluation metrics to extract subnetworks tailored to their specific requirements.
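As a hedged sketch of what a custom search might look like, the code below runs a random search over sub-adapter configurations. Everything here is a stand-in: the rank choices are hypothetical, the configuration is simplified to one rank per target module (a real configuration may vary per layer), the "middle choice" heuristic is an assumption, and toy_evaluate replaces what would in practice be activating the subnetwork in the super-network and measuring validation accuracy. None of these names come from the Shears or NNCF APIs.

```python
import random

# Module names taken from the --target_modules flag of the training command.
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "up_proj", "down_proj"]
RANK_CHOICES = [32, 24, 16]  # hypothetical search space


def sample_config(rng):
    """Pick one rank per target module."""
    return {module: rng.choice(RANK_CHOICES) for module in TARGET_MODULES}


def heuristic_config():
    """Assumed heuristic: take the middle rank choice for every module."""
    middle = sorted(RANK_CHOICES)[len(RANK_CHOICES) // 2]
    return {module: middle for module in TARGET_MODULES}


def toy_evaluate(config):
    """Stand-in objective; replace with real validation accuracy of the subnetwork."""
    return -abs(sum(config.values()) - 120)


def random_search(evaluate, n_trials, seed=0):
    """Return the best configuration found among n_trials random samples."""
    rng = random.Random(seed)
    best_config, best_score = None, float("-inf")
    for _ in range(n_trials):
        config = sample_config(rng)
        score = evaluate(config)
        if score > best_score:
            best_config, best_score = config, score
    return best_config, best_score


best_config, best_score = random_search(toy_evaluate, n_trials=50)
print("best sampled config:", best_config, "score:", best_score)
print("heuristic config:   ", heuristic_config(), "score:", toy_evaluate(heuristic_config()))
```

Any search strategy with the same interface (propose a configuration, score it) can be swapped in for random_search, for example an evolutionary search that trades accuracy against adapter size.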
After training, the weights of the super-adapter are stored in ADAPTER_MODEL_PATH, and different sub-networks can be activated using NNCF. However, once the desired sub-network has been identified, activating it through NNCF is no longer necessary. What is needed instead is a clean, pruned, directly loadable sub-network. For this purpose, we provide a function to extract and save any sub-adapter (see the end of search/load_and_explore_supernet.ipynb). Below is an example that obtains the heuristic sub-adapter of a trained super-adapter:
import os

from peft import PeftModel
from transformers import AutoModelForCausalLM

from search.supernet import ShearsSuperNet
from utils.utils import load_nncf_config

# BASE_MODEL_PATH, ADAPTER_MODEL_PATH, and NNCF_CONFIG are placeholders to be set by the user.
# Load the sparsified base model and attach the trained super-adapter.
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_PATH, trust_remote_code=True)
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_PATH)

# Build the super-network from the NNCF configuration (which defines the elastic search space).
nncf_config = load_nncf_config(NNCF_CONFIG, num_hidden_layers=model.config.num_hidden_layers)
supernet = ShearsSuperNet.from_checkpoint(model, nncf_config, supernet_elasticity_path=None, supernet_weights_path=None)

# Activate the heuristic sub-network, then extract and save it as a standalone sub-adapter.
supernet.activate_heuristic_subnet()
supernet.extract_and_save_active_sub_adapter(super_adapter_dir=ADAPTER_MODEL_PATH, sub_adapter_dir=os.path.join(ADAPTER_MODEL_PATH, "heuristic_adapter"))
We have released some examples of extracted heuristic sub-adapters; see the table of released models.
If you find our Shears code and paper helpful, please cite:
@inproceedings{munoz-etal-2024-shears,
title = "Shears: Unstructured Sparsity with Neural Low-rank Adapter Search",
author = "Mu{\~n}oz, J. Pablo and
Yuan, Jinjie and
Jain, Nilesh",
editor = "Yang, Yi and
Davani, Aida and
Sil, Avi and
Kumar, Anoop",
booktitle = "Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 6: Industry Track)",
month = jun,
year = "2024",
address = "Mexico City, Mexico",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2024.naacl-industry.34",
doi = "10.18653/v1/2024.naacl-industry.34",
pages = "395--405",
}
This work benefits from the following repositories:
- LLaMA: https://github.com/facebookresearch/llama
- MPT: https://www.mosaicml.com/mpt
- Transformers: https://github.com/huggingface/transformers
- PEFT: https://github.com/huggingface/peft
- LLM-Adapters: https://github.com/AGI-Edgerunners/LLM-Adapters
- NNCF: https://github.com/openvinotoolkit/nncf
- BootstrapNAS: https://github.com/IntelLabs/Hardware-Aware-Automated-Machine-Learning/blob/main/BootstrapNAS
- Wanda: https://github.com/locuslab/wanda