From c2dff33b92601ae1df662fa5f8020a54c5ebe7e3 Mon Sep 17 00:00:00 2001
From: Motsepe-Jr
Date: Sat, 3 Jun 2023 16:03:56 +0200
Subject: [PATCH 1/2] Sparse GPT Algorithm

This is the SparseGPT code, based on the IST-DASLab project. I followed the same coding principles as used in the lit-llama GPTQ code. I created a file, lit_llama/sparsification.py, which implements the SparseGPT algorithm, and a script, sparsify/sparsegpt.py, which runs the algorithm on the model in the checkpoint_path. This is my first contribution to the project; if I missed any housekeeping steps, I apologize in advance.

Key Notes:
1. The source code of SparseGPT consists of a quantization algorithm similar to GPTQ; however, I removed this code because we already have GPTQ in the lit-llama source code.
2. I'm still on the waiting list for the LLaMA 7B weights.
---
 lit_llama/sparsification.py | 130 ++++++++++++++++++
 sparsify/sparsegpt.py | 255 ++++++++++++++++++++++++++++++++++++
 2 files changed, 385 insertions(+)
 create mode 100644 lit_llama/sparsification.py
 create mode 100644 sparsify/sparsegpt.py

diff --git a/lit_llama/sparsification.py b/lit_llama/sparsification.py
new file mode 100644
index 00000000..409d0084
--- /dev/null
+++ b/lit_llama/sparsification.py
@@ -0,0 +1,130 @@
+# This adapts the SparseGPT process: https://github.com/IST-DASLab/sparsegpt
+# E. Frantar et al., SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot, https://arxiv.org/abs/2301.00774
+# portions copyright by the authors, licensed under the Apache License 2.0
+
+
+import torch
+import os
+from contextlib import contextmanager
+import warnings
+import math
+
+
+class SparseGPT:
+
+    def __init__(
+        self,
+        linear_module,
+        sparsity,
+        prunen=0,
+        prunem=0,
+        blocksize=128,
+        percdamp=.01,
+
+    ):
+        assert isinstance(linear_module, torch.nn.Linear)
+
+        self.linear_module = linear_module
+        self.dev = self.linear_module.weight.device
+        self.rows = linear_module.weight.shape[0]
+        self.columns = linear_module.weight.shape[1]
+        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
+        self.nsamples = 0
+        self.block_size = blocksize
+        self.sparsity = sparsity
+        self.perdamp = percdamp
+        self.prunen = prunen
+        self.prunem = prunem
+
+    def collect_input_stats(self, _1, inp, _2):
+        inp = inp[0].detach()
+        self.last_inp = inp
+        if len(inp.shape) == 2:
+            inp = inp.unsqueeze(0)
+        tmp = inp.shape[0]
+        if len(inp.shape) == 3:
+            inp = inp.reshape((-1, inp.shape[-1]))
+        inp = inp.t()
+        self.H *= self.nsamples / (self.nsamples + tmp)
+        self.nsamples += tmp
+        inp = math.sqrt(2 / self.nsamples) * inp.float()
+        self.H += inp.matmul(inp.t())
+
+    def sparsify(self):
+        W = self.linear_module.weight.detach().to(dtype=torch.float, copy=True)
+
+        H = self.H
+        del self.H
+        dead = torch.diag(H) == 0
+        H[dead, dead] = 1
+        W[:, dead] = 0
+
+
+        Losses = torch.zeros_like(W)
+        Q = torch.zeros_like(W)
+
+        damp = self.percdamp * torch.mean(torch.diag(H))
+        diag = torch.arange(self.columns, device=self.dev)
+        H[diag, diag] += damp
+        H = torch.linalg.cholesky(H)
+        H = torch.cholesky_inverse(H)
+        H = torch.linalg.cholesky(H, upper=True)
+        Hinv = H
+
+        mask = None
+
+        for i1 in range(0, self.columns, self.blocksize):
+            i2 = min(i1 + self.blocksize, self.columns)
+            count = i2 - i1
+
+            W1 = W[:, i1:i2].clone()
+            Q1 = torch.zeros_like(W1)
+            Err1 = torch.zeros_like(W1)
+            Losses1 = torch.zeros_like(W1)
+            Hinv1 = Hinv[i1:i2, i1:i2]
+
+
+            if prunen == 0:
+                if mask is not None:
+                    mask1 = mask[:, i1:i2]
+                else:
+                    tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
+                    thresh = 
torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)] + mask1 = tmp <= thresh + else: + mask1 = torch.zeros_like(W1) == 1 + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if prunen != 0 and i % prunem == 0: + tmp = W1[:, i:(i + prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1))) ** 2 + mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True) + + q = w.clone() + q[mask1[:, i]] = 0 + + + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d ** 2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses += torch.sum(Losses1, 1) / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + pruned_weights = Q.reshape(self.linear_module.weight.shape).to( + self.linear_module.weight.data.dtype + ) + + self.linear_module.weight.data = pruned_weights + del pruned_weights + error = torch.sum(Losses).item() + + return error + diff --git a/sparsify/sparsegpt.py b/sparsify/sparsegpt.py new file mode 100644 index 00000000..f2a20506 --- /dev/null +++ b/sparsify/sparsegpt.py @@ -0,0 +1,255 @@ +# This adapts SparseGPT process: https://github.com/IST-DASLab/sparsegpt +# E. Frantar et al SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot, https://arxiv.org/abs/2301.00774 +# portions copyright by the authors licensed under the Apache License 2.0 + + +import gc +import sys +import time +from pathlib import Path +from typing import Optional + +import torch +from datasets import load_dataset + + +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_llama import LLaMA, Tokenizer +from lit_llama.sparsification import SparseGPT + +from lit_llama.utils import EmptyInitOnDevice, llama_model_lookup + + +def get_sample_data(): + traindata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, + split="train", + ) + # heuristic for the data size? 
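+    # (assumption: 1000 random C4 documents comfortably cover the
+    # n_samples * block_size = 128 * 2048 = 262,144 tokens consumed in main())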
+    txt = "\n".join(
+        traindata[i]["text"] for i in torch.randperm(len(traindata))[:1000].tolist()
+    )
+    return txt
+
+@torch.no_grad()
+def llama_blockwise_sparsification(
+    model,
+    sample_inputs,
+    working_device,
+    *,
+    sparsity=0,
+    prunen=0,
+    prunem=0,
+
+):
+
+    print('Getting inputs for the first block')
+    model.transformer.wte.to(working_device)
+    sample_inputs = sample_inputs.to(working_device)
+    inps = model.transformer.wte(sample_inputs)
+    model.transformer.wte.to("cpu")
+    torch.cuda.empty_cache()
+
+    rope_cache = model.build_rope_cache(sample_inputs)
+    mask_cache = model.build_mask_cache(sample_inputs)
+
+    print('Starting to sparsify blocks')
+    outs = torch.zeros_like(inps)
+
+
+    submodules_to_process = [
+        "attn.c_attn",
+        "attn.c_proj",
+        "mlp.c_fc1",
+        "mlp.c_fc2",
+        "mlp.c_proj",
+    ]
+
+
+    for i, block in enumerate(model.transformer.h):
+
+        block.to(working_device)
+
+        for name in submodules_to_process:
+            print(i, name, end=" ")
+            t0 = time.perf_counter()
+            print("collecting stats", end=" ")
+            sys.stdout.flush()
+            module = block.get_submodule(name)
+
+            sparsegpt = SparseGPT(
+                module,
+                sparsity=sparsity,
+                prunen=prunen,
+                prunem=prunem,
+            )
+
+            handle = model.lm_head.register_forward_hook(sparsegpt.collect_input_stats)
+
+            for j in range(inps.size(0)):
+                outs[j : j + 1], _ = block(
+                    inps[j : j + 1],
+                    rope=rope_cache,
+                    mask=mask_cache,
+                    max_seq_length=model.config.block_size
+                )
+
+            handle.remove()
+
+            error = sparsegpt.sparsify()
+
+            del sparsegpt
+            gc.collect()
+            torch.cuda.empty_cache()
+            t1 = time.perf_counter()
+            print(f"time {int(t1 - t0 + 0.5)}s sparsification error {error:.1f}")
+
+
+        for j in range(inps.size(0)):
+            outs[j : j + 1], _ = block(
+                inps[j : j + 1],
+                rope=rope_cache,
+                mask=mask_cache,
+                max_seq_length=model.config.block_size
+            )
+
+        block.cpu()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        inps, outs = outs, inps
+
+    model.transformer.ln_f.to(working_device)
+    for j in range(inps.size(0)):
+        outs[j : j + 1] = model.transformer.ln_f(inps[j : j + 1])
+    model.transformer.ln_f.to("cpu")
+
+    # the normalized outputs will be the inputs to the LM head
+    inps, outs = outs, inps
+
+    model.lm_head.to(working_device)
+    sparsegpt = SparseGPT(
+        model.lm_head,
+        sparsity=sparsity,
+        prunen=prunen,
+        prunem=prunem,
+    )
+
+    # During the forward pass, the collect_input_stats function collects input statistics and updates the Hessian matrix.
+    handle = model.lm_head.register_forward_hook(sparsegpt.collect_input_stats)
+    for j in range(inps.size(0)):
+        model.lm_head(inps[j : j + 1])
+    handle.remove()
+    # After the forward pass, the sparsify function can be called to perform the sparsification based on the collected statistics.
+    error = sparsegpt.sparsify()
+    model.lm_head.to("cpu")
+
+def main(
+    *,
+    checkpoint_path: Path = Path("checkpoints/lit-llama/7B/lit-llama.pth"),
+    output_path: Optional[Path] = None,
+    tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"),
+    n_samples: int = 128,
+    dtype: str = "float32",
+    sparsity: int = 0,
+    prunem: int = 0,
+    prunen: int = 0
+) -> None:
+    """
+    Sparsifies a pre-trained LLaMA model with the SparseGPT algorithm.
+
+    Args:
+        checkpoint_path: The checkpoint path to load.
+        output_path: Path to write the sparsified model's state dict to.
+        tokenizer_path: The tokenizer path to load.
+        n_samples: Number of example inputs to use for statistics (default: 128).
+        dtype: The dtype to use to load the model.
+        sparsity: Target sparsity (the fraction of weights to prune).
+        prunem: M for N:M pruning.
+        prunen: N for N:M pruning.
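+
+    Example (a hypothetical invocation; adjust paths to your setup):
+        python sparsify/sparsegpt.py --checkpoint_path checkpoints/lit-llama/7B/lit-llama.pth --sparsity 0.1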
+ """ + assert checkpoint_path.is_file() + assert tokenizer_path.is_file() + if output_path is None: + output_path = checkpoint_path.parent / "llama-gpt-sparsified.pth" + assert output_path.parent.is_dir() and (not output_path.exists() or output_path.is_file()) + + device = "cuda" + + dt = getattr(torch, dtype, None) + if not isinstance(dt, torch.dtype): + raise ValueError(f"{dtype} is not a valid dtype.") + dtype = dt + + + # we avoid loading the entire model on the GPU and do this block by block + with EmptyInitOnDevice( + device="cpu", + dtype=dtype, + ): + print("Loading model ...", file=sys.stderr) + t0 = time.time() + checkpoint = torch.load(checkpoint_path) + name = llama_model_lookup(checkpoint) + model = LLaMA.from_name(name) + model.load_state_dict(checkpoint) + print(f"Time to load model: {time.time() - t0:.02f} seconds.", file=sys.stderr) + + model.eval() + + tokenizer = Tokenizer(tokenizer_path) + + test_string = get_sample_data() + encoded_text = tokenizer.encode( + test_string, + bos=True, + eos=False, + ) + + block_size = 2048 + # truncate the text and reshape to batch by sequence length + encoded_text = encoded_text[: n_samples * block_size].reshape(n_samples, block_size) + + t0 = time.perf_counter() + llama_blockwise_sparsification( + model=model, + sample_inputs=encoded_text, + working_device=device, + sparsity=sparsity, + prunen=prunen, + prunem=prunem + ) + t = time.perf_counter() - t0 + + print( + f"\n\nTime for sparsification: {t:.02f} sec total", + file=sys.stderr, + ) + print( + f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB", + file=sys.stderr, + ) + + torch.save(model.state_dict(), output_path) + + +if __name__ == "__main__": + from jsonargparse import CLI + + torch.set_float32_matmul_precision("high") + CLI(main) + + + + + + + + + + + From a492ad0a841ba1625a77fc6cd425cda3667dac91 Mon Sep 17 00:00:00 2001 From: Motsepe-Jr Date: Thu, 8 Jun 2023 21:34:16 +0200 Subject: [PATCH 2/2] sparseGPT This is the sparseGPT code based on IST-DASLab project. I followed the same coding principles as used in the lit-llama gptq code. I created a file called sparsification which is the algorithm for SparseGPT and a folder called sparsify/sparsegpt.py to run the algorithm on the model in the checkpoint_path. This is my first contribution to the project, If I missed some household admin I apologize in advance. Assuming you have a model under checkpoints/open-llama/7B you can run this command: python sparsify/sparsegpt.py --checkpoint_path checkpoints/lit-llama/7B/lit-llama.pth Key Notes: 0. I used half n_samples (128-->64) due to memory requirement 1. The SparseGPT paper was evaluated on models trained not using the chinchilla scaling law (Therefore my hypothesis is that some of the weights of those model were not useful, hence they were able to prune 50%). With Llama I only used 0.1 target sparsity.) 2. The source code of SparseGPT consist of a quantization algorithm similar to GPTQ, however, I removed this code because we already have GPTQ in the lit-llama source code. If you would like me to include it, it is also okay I can include GPTQ under sparseGPT code. 
Before you commit, please also test from your side, and let me know if you want me to fix any bugs or integrate a specific feature. Thanks!
---
 lit_llama/sparsification.py | 184 ++++++++++++++++++------------------
 sparsify/sparsegpt.py | 12 ++-
 2 files changed, 98 insertions(+), 98 deletions(-)

diff --git a/lit_llama/sparsification.py b/lit_llama/sparsification.py
index 409d0084..bea386a0 100644
--- a/lit_llama/sparsification.py
+++ b/lit_llama/sparsification.py
@@ -19,7 +19,7 @@ def __init__(
         prunen=0,
         prunem=0,
         blocksize=128,
-        percdamp=.01,
+        percdamp=0.01,
 
     ):
         assert isinstance(linear_module, torch.nn.Linear)
@@ -30,101 +30,99 @@ def __init__(
         self.columns = linear_module.weight.shape[1]
         self.H = torch.zeros((self.columns, self.columns), device=self.dev)
         self.nsamples = 0
-        self.block_size = blocksize
+        self.blocksize = blocksize
         self.sparsity = sparsity
-        self.perdamp = percdamp
+        self.percdamp = percdamp
         self.prunen = prunen
         self.prunem = prunem
 
-    def collect_input_stats(self, _1, inp, _2):
-        inp = inp[0].detach()
-        self.last_inp = inp
-        if len(inp.shape) == 2:
-            inp = inp.unsqueeze(0)
-        tmp = inp.shape[0]
-        if len(inp.shape) == 3:
-            inp = inp.reshape((-1, inp.shape[-1]))
-        inp = inp.t()
-        self.H *= self.nsamples / (self.nsamples + tmp)
-        self.nsamples += tmp
-        inp = math.sqrt(2 / self.nsamples) * inp.float()
-        self.H += inp.matmul(inp.t())
-
-    def sparsify(self):
-        W = self.linear_module.weight.detach().to(dtype=torch.float, copy=True)
-
-        H = self.H
-        del self.H
-        dead = torch.diag(H) == 0
-        H[dead, dead] = 1
-        W[:, dead] = 0
-
-
-        Losses = torch.zeros_like(W)
-        Q = torch.zeros_like(W)
-
-        damp = self.percdamp * torch.mean(torch.diag(H))
-        diag = torch.arange(self.columns, device=self.dev)
-        H[diag, diag] += damp
-        H = torch.linalg.cholesky(H)
-        H = torch.cholesky_inverse(H)
-        H = torch.linalg.cholesky(H, upper=True)
-        Hinv = H
-
-        mask = None
-
-        for i1 in range(0, self.columns, self.blocksize):
-            i2 = min(i1 + self.blocksize, self.columns)
-            count = i2 - i1
-
-            W1 = W[:, i1:i2].clone()
-            Q1 = torch.zeros_like(W1)
-            Err1 = torch.zeros_like(W1)
-            Losses1 = torch.zeros_like(W1)
-            Hinv1 = Hinv[i1:i2, i1:i2]
-
-
-            if prunen == 0:
-                if mask is not None:
-                    mask1 = mask[:, i1:i2]
-                else:
-                    tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
-                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
-                    mask1 = tmp <= thresh
+    def collect_input_stats(self, _1, inp, _2):
+        inp = inp[0].detach()
+        self.last_inp = inp
+        if len(inp.shape) == 2:
+            inp = inp.unsqueeze(0)
+        tmp = inp.shape[0]
+        if len(inp.shape) == 3:
+            inp = inp.reshape((-1, inp.shape[-1]))
+        inp = inp.t()
+        self.H *= self.nsamples / (self.nsamples + tmp)
+        self.nsamples += tmp
+        inp = math.sqrt(2 / self.nsamples) * inp.float()
+        self.H += inp.matmul(inp.t())
+
+    def sparsify(self):
+
+        W = self.linear_module.weight.detach().to(dtype=torch.float, copy=True)
+        H = self.H
+        del self.H
+        dead = torch.diag(H) == 0
+        H[dead, dead] = 1
+        W[:, dead] = 0
+
+        Losses = torch.zeros_like(W)
+        Q = torch.zeros_like(W)
+
+        damp = self.percdamp * torch.mean(torch.diag(H))
+        diag = torch.arange(self.columns, device=self.dev)
+        H[diag, diag] += damp
+        H = torch.linalg.cholesky(H)
+        H = torch.cholesky_inverse(H)
+        H = torch.linalg.cholesky(H, upper=True)
+        Hinv = H
+
+        mask = None
+
+        for i1 in range(0, self.columns, self.blocksize):
+            i2 = min(i1 + self.blocksize, self.columns)
+            count = i2 - i1
+
+            W1 = W[:, i1:i2].clone()
+            Q1 = torch.zeros_like(W1)
+            Err1 = torch.zeros_like(W1)
+            Losses1 = torch.zeros_like(W1)
+            Hinv1 = 
Hinv[i1:i2, i1:i2] + + + if self.prunen == 0: + if mask is not None: + mask1 = mask[:, i1:i2] else: - mask1 = torch.zeros_like(W1) == 1 - - for i in range(count): - w = W1[:, i] - d = Hinv1[i, i] - - if prunen != 0 and i % prunem == 0: - tmp = W1[:, i:(i + prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1))) ** 2 - mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True) - - q = w.clone() - q[mask1[:, i]] = 0 - - - Q1[:, i] = q - Losses1[:, i] = (w - q) ** 2 / d ** 2 - - err1 = (w - q) / d - W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - Err1[:, i] = err1 - - Q[:, i1:i2] = Q1 - Losses += torch.sum(Losses1, 1) / 2 - - W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - - pruned_weights = Q.reshape(self.linear_module.weight.shape).to( - self.linear_module.weight.data.dtype - ) - - self.linear_module.weight.data = pruned_weights - del pruned_weights - error = torch.sum(Losses).item() - - return error + tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2 + thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * self.sparsity)] + mask1 = tmp <= thresh + else: + mask1 = torch.zeros_like(W1) == 1 + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if self.prunen != 0 and i % self.prunem == 0: + tmp = W1[:, i:(i + self.prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + self.prunem)].reshape((1, -1))) ** 2 + mask1.scatter_(1, i + torch.topk(tmp, self.prunen, dim=1, largest=False)[1], True) + + q = w.clone() + q[mask1[:, i]] = 0 + + + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d ** 2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + pruned_weights = Q.reshape(self.linear_module.weight.shape).to( + self.linear_module.weight.data.dtype + ) + + # set the linear module weights to pruned weights + self.linear_module.weight.data = pruned_weights + error = torch.sum(Losses).item() + return error diff --git a/sparsify/sparsegpt.py b/sparsify/sparsegpt.py index f2a20506..036c1a80 100644 --- a/sparsify/sparsegpt.py +++ b/sparsify/sparsegpt.py @@ -12,7 +12,6 @@ import torch from datasets import load_dataset - wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) @@ -41,7 +40,7 @@ def llama_blockwise_sparsification( sample_inputs, working_device, *, - sparsity=0, + sparsity, prunen=0, prunem=0, @@ -87,8 +86,8 @@ def llama_blockwise_sparsification( prunen=prunen, prunem=prunem, ) - - handle = model.lm_head.register_forward_hook(sparsegpt.collect_input_stats) + + handle = module.register_forward_hook(sparsegpt.collect_input_stats) for j in range(inps.size(0)): outs[j : j + 1], _ = block( @@ -100,6 +99,9 @@ def llama_blockwise_sparsification( handle.remove() + + print("sparsifying", end=" ") + sys.stdout.flush() error = sparsegpt.sparsify() del sparsegpt @@ -155,7 +157,7 @@ def main( tokenizer_path: Path = Path("checkpoints/lit-llama/tokenizer.model"), n_samples: int = 128, dtype: str = "float32", - sparsity: int = 0, + sparsity: float = 0.1, prunem: int = 0, prunen: int = 0 ) -> None:
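For anyone testing the PR, a quick sanity check after running the script (a sketch; the filename below is the script's default output path, and the state dict is assumed to map parameter names to weight tensors):

    import torch

    sd = torch.load("checkpoints/lit-llama/7B/llama-gpt-sparsified.pth")
    zeros = sum((v == 0).sum().item() for v in sd.values())
    total = sum(v.numel() for v in sd.values())
    print(f"realized sparsity: {zeros / total:.2%}")

With --sparsity 0.1 the printed value should land somewhat below 0.10, since only the Linear submodules listed in submodules_to_process (plus the LM head) are pruned, while embeddings and norms are left dense.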