From 860f3e38bb5ce3840a739440660aba51aef43ff6 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 18 Feb 2025 17:40:58 +0000 Subject: [PATCH 01/40] cherry picked files from stale PR #181 branch awq-feature-branch Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/__init__.py | 3 + src/llmcompressor/modifiers/awq/base.py | 720 ++++++++++++++++++ src/llmcompressor/pytorch/utils/helpers.py | 481 ++++++++++++ .../transformers/finetune/data/__init__.py | 1 + .../transformers/finetune/data/pile.py | 45 ++ src/llmcompressor/utils/pytorch/module.py | 20 + tests/llmcompressor/modifiers/awq/__init__.py | 0 .../llmcompressor/modifiers/awq/test_base.py | 28 + .../pytorch/utils/test_helpers.py | 260 +++++++ .../finetune/data/test_registry.py | 2 + tests/llmcompressor/utils/pytorch/__init__.py | 0 .../utils/pytorch/test_module.py | 31 + 12 files changed, 1591 insertions(+) create mode 100644 src/llmcompressor/modifiers/awq/__init__.py create mode 100644 src/llmcompressor/modifiers/awq/base.py create mode 100644 src/llmcompressor/transformers/finetune/data/pile.py create mode 100644 tests/llmcompressor/modifiers/awq/__init__.py create mode 100644 tests/llmcompressor/modifiers/awq/test_base.py create mode 100644 tests/llmcompressor/utils/pytorch/__init__.py create mode 100644 tests/llmcompressor/utils/pytorch/test_module.py diff --git a/src/llmcompressor/modifiers/awq/__init__.py b/src/llmcompressor/modifiers/awq/__init__.py new file mode 100644 index 000000000..8bdc93d14 --- /dev/null +++ b/src/llmcompressor/modifiers/awq/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base import * diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py new file mode 100644 index 000000000..71767b0b4 --- /dev/null +++ b/src/llmcompressor/modifiers/awq/base.py @@ -0,0 +1,720 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from loguru import logger +from torch.nn import Module +from tqdm import tqdm + +from llmcompressor.core import Event, State +from llmcompressor.modifiers import Modifier +from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward +from llmcompressor.pytorch.utils import ( + clear_memory, + pseudo_quantize_tensor, + tensor_forward_with_input_args, +) +from llmcompressor.utils.fsdp.helpers import get_fsdp_parent +from llmcompressor.utils.pytorch.module import ( + get_layer, + get_layers, + get_matching_layer, + get_parent_by_name, +) + +DEFAULT_AWQ_MAPPINGS = [ + [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"], + [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"], + [["re:.*down_proj"], "re:.*up_proj"], +] + +__all__ = ["AWQScale", "AWQMapping", "AWQModifier"] + + +@dataclass +class AWQScale: + """ + Dataclass for storing the input activations of a layer to be smoothed + """ + + inps: Union[List[torch.Tensor], torch.Tensor] + + +@dataclass +class AWQMapping: + """ + Dataclass for storing the mapping between an activation layer and the following + weights that must be balanced during smoothing + + :param smooth_name: name of the activation layer + :param smooth_layer: PyTorch module storing the activation layer + :param balance_layers: list of PyTorch modules that smooth_layer feeds into, must be + balanced to offset the smoothing of smooth_layer + :param balance_names: optional list of names of the balance_layers + :param parent: parent module of the balance_layers + :param parent_name: name of 
the parent module + """ + + smooth_name: str + smooth_layer: Module + balance_layers: List[Module] + balance_names: Optional[List[str]] = None + parent: Optional[Module] = None + parent_name: Optional[str] = None + + +class AWQModifier(Modifier): + """ + Implements the AWQ (Activation-aware Weight Quantization) algorithm, + as described in https://arxiv.org/pdf/2306.00978. The algorithm + significantly reduces quantization error by protecting only 1% + of the most salient weight channels. + + Instead of focusing on the weight values directly, AWQ identifies + salient channels based on the activation distribution. + To further minimize quantization error, the algorithm scales up these + salient channels using an equivalent transformation. The scaling factor + is determined offline by collecting activation statistics. + + Because this modifier manipulates the weights of the model, it can only be used + in one-shot and not during training. Activation ranges are determined by running a + small set of calibration data through the model. + + example recipe: + ```yaml + AWQModifier: + bits: 4 + mappings: [ + [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"], + [["re:.*fc1"], "re:.*final_layer_norm"] + ] + ignore: ["model.decoder.final_layer_norm"] + ``` + + :param mappings: list activation layers to smooth, and which layers to + scale the output such that activations are smoothed. + Each entry of the mapping list should be a list itself, in which the first + entry is a list of layers who share the same input activation (the one to be + smoothed) and the second entry is the layer whose output is scaled to + achieve the smoothing. + If regex is used, it matches layers with the largest overlap in module name. + :param ignore: list of layers to ignore, even if they match a regex in mappings. + It should match the name of layers whose outputs are scaled to achieve + smoothing (the second entry of the mappings list).
+ :param num_calibration_steps: number of samples to use for calibration, or None to + use the whole dataset + :param calibration_function: optional function to use for the forward pass, or None + to use the default tensor_module_forward + :param group_size: number of weights to group together for scaling + :param max_chunk_memory: maximum memory to use for each chunk of input activations + :param bits: number of bits to quantize the weights to + :param symmetric: whether to use symmetric quantization + :param duo_scaling: whether to use duo scaling, which uses both input activations + and weights to determine the scaling factor + :param apply_clip: whether to apply clipping to the weights after scaling + """ + + mappings: List[Tuple] = DEFAULT_AWQ_MAPPINGS + ignore: Optional[List[str]] = None + num_calibration_steps: Optional[int] = None + calibration_function: Optional[Callable] = None + group_size: int = 128 + max_chunk_memory: int = 1024 * 1024 * 1024 + bits: int = 4 + symmetric: bool = True + duo_scaling: bool = True + apply_clip: bool = True + + hooks_: Optional[List] = None + resolved_mappings_: Optional[List] = None + scales_: Optional[Dict] = None + module_kwargs_: Optional[Dict] = None + + def on_initialize_structure(self, state: State, **kwargs): + pass # nothing needed for this modifier + + def on_initialize(self, state: State, **kwargs) -> bool: + """ + Initialize and run AWQ on the given state + + :param state: state to run AWQ on + :return: True on a successful run, False otherwise + """ + if not (self.end is None or self.end == -1): + raise ValueError( + f"{self.__class__.__name__} can only be applied during one-shot. " + f" Expected end to be None or -1, got {self.end}" + ) + if self.start and self.start != -1: + raise ValueError( + f"{self.__class__.__name__} can only be applied during one-shot. " + f"Expected start to be None or -1, got {self.end}" + ) + + self.ignore = [] if not self.ignore else self.ignore + self.resolved_mappings_ = self._resolve_mappings(state.model) + self.scales_ = {} + + calibration_dataloader = state.data.calib + self.hooks_ = [] + + self._get_module_kwargs(state.model, calibration_dataloader) + self._setup_scale_hooks() + self._calibrate(state.model, calibration_dataloader) + self._concat_collected_activations() + self._apply_smoothing(state.model) + + return True + + def on_start(self, state: State, event: Event, **kwargs): + pass + + def on_update(self, state: State, event: Event, **kwargs): + pass + + def on_end(self, state: State, event: Event, **kwargs): + pass + + def on_event(self, state: State, event: Event, **kwargs): + pass + + def on_finalize(self, state: State, **kwargs) -> bool: + """ + Clean up by clearing the scale and mapping data + + :param state: unused + :return: True + """ + if self.scales_ is not None: + self.scales_.clear() + if self.resolved_mappings_ is not None: + self.resolved_mappings_.clear() + + return True + + def _resolve_mappings(self, model: Module) -> List: + """ + Transforms the list of activations to smooth and their corresponding weights + into AWQMapping objects, resolving regular expressions. + + For each activation in the mapping list, we find the corresponding weight to + balance by searching for the longest substring. 
For instance, if our balance + weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we + would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and + repeat for model.layer.1 and so on + """ + resolved_mappings = [] + for to_balance, to_smooth in self.mappings: + to_smooth_layers = get_layers(to_smooth, model) + for layer_name, smooth_layer in to_smooth_layers.items(): + if layer_name not in self.ignore: + balance_layers, balance_names = [], [] + for balance_suffix in to_balance: + # find the submodule that matches the activation layer + balance_name, balance_layer = get_matching_layer( + balance_suffix, layer_name, model + ) + if balance_layer: + balance_layers.append(balance_layer) + balance_names.append(balance_name) + + # each mapping can contain multiple layers to balance, but only + # one layer to smooth + + if len(balance_layers) == 1: + # for single balance layer, parent is the balance layer + parent_name, parent = balance_name, balance_layer + else: + # for multiple balance layers, + # parent of any balance layer is the parent + parent_name, parent = get_parent_by_name( + layer_name=balance_name, model=model + ) + mapping = AWQMapping( + layer_name, + smooth_layer, + balance_layers, + balance_names=balance_names, + parent=parent, + parent_name=parent_name, + ) + resolved_mappings.append(mapping) + return resolved_mappings + + def _setup_scale_hooks(self): + """ + Attach a forward hook to each activation we want to smooth. This allows us to + calculate the dynamic range during calibration + """ + + def create_hook_fn(layer_name): + def hook_fn(module, inp, out): + inp = inp[0] + inp.cpu().detach() + + if layer_name in self.scales_: + self.scales_[layer_name].inps.append(inp) + else: + self.scales_[layer_name] = AWQScale(inps=[inp]) + + return hook_fn + + for mapping in self.resolved_mappings_: + name = mapping.smooth_name + # storing inps to first balance layer + # is enough, as other balance layers + # get the same input + layer = mapping.balance_layers[0] + self.hooks_.append(layer.register_forward_hook(create_hook_fn(name))) + + @torch.no_grad() + def _calibrate(self, model: Module, calibration_dataloader: List): + """ + Catch the output dynamic ranges of each layer that will be smoothed by running + forward passes with calibration_dataloader + """ + class_name = self.__class__.__name__.replace("PyTorch", "") + logger.info( + f"Running {class_name} calibration with " + f"{len(calibration_dataloader)} samples..." 
+ ) + if not calibration_dataloader: + raise ValueError( + "Calibration data loader not set, must populate the calib_data field of" + " CompressionSession to run the AWQ modifier" + ) + + run_calibration_forward( + model, + calibration_dataloader, + self.num_calibration_steps, + self.calibration_function, + ) + + # remove the hooks now that we are done calibrating + for hook in self.hooks_: + hook.remove() + del self.hooks_ + + def _concat_collected_activations(self): + """ + Concatenate the collected activation values from each forward pass into a single + tensor for each layer + + :postcondition: each layer in self.scales_ will have a single tensor containing + all the activation values seen during calibration + """ + for mapping in self.resolved_mappings_: + name = mapping.smooth_name + self.scales_[name].inps = torch.cat(self.scales_[name].inps, dim=0) + + torch.cuda.empty_cache() + + @torch.no_grad() + def _apply_smoothing(self, model: Module): + """ + Calculate the best scaling factors for each layer to smooth activations and + apply the scaling factors to the weights of the next layer to offset the + smoothing + + :param model: model to apply smoothing to + """ + logger.info("Smoothing activation scales...") + for mapping in tqdm(self.resolved_mappings_): + smooth_layer = mapping.smooth_layer + balance_layers = mapping.balance_layers + balance_names = mapping.balance_names + + activations = self.scales_[mapping.smooth_name].inps + + module2inspect = mapping.parent + + # [STEP 1]: Compute per-channel mean of normalised weights + # All layer weights are concatted together + weight = torch.cat([_m.weight for _m in balance_layers], dim=0) + org_shape = weight.shape + # The weights are reshaped to be organised by quantization group + weight = weight.view(-1, self.group_size) + # Calculates the relative magnitude of the weights within + # each of the quantization groups, and rescales each group + # individually so that each group has weights on a 0-1 scale. 
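+ # i.e. w_scale[g, j] = |w[g, j]| / (max_j |w[g, j]| + 1e-6), computed independently for each group (row) of group_size weights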
+ w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6) + # Resizes the rescaled weight matrix back up to its original dimensions + w_scale = w_scale.view(org_shape) + # Gets the average rescaled magnitude for each output channel + w_mean = w_scale.mean(0) + + # [STEP 2]: Compute per-channel mean of the input activation with chunking + # move inp to cpu to avoid memory leak + inp = activations + inp_flat = inp.cpu().abs().view(-1, inp.shape[-1]) + num_elements = inp_flat.size(0) + num_channels = inp_flat.size(1) + element_size_bytes = inp_flat.element_size() * 2 # multiplied by 2 for FP32 + + # Calculate chunk size dynamically based on max_chunk_memory + chunk_size = int( + self.max_chunk_memory // (element_size_bytes * num_channels) + ) + chunk_size = min(chunk_size, num_elements) + + # Use float32 for sum calculation + x_sum = torch.zeros(num_channels, dtype=torch.float32, device=inp.device) + + for i in range(0, num_elements, chunk_size): + end = min(i + chunk_size, num_elements) + chunk_sum = inp_flat[i:end].to(torch.float32).sum(dim=0) + x_sum += chunk_sum.to(inp.device) + + x_mean = (x_sum / num_elements).to(inp.dtype) + + # [STEP 3]: Compute output of module + with torch.no_grad(): + fp16_output = self._forward_input_with_kwargs( + module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_ + ) + + # [STEP 4]: Compute loss + best_scales = self._compute_best_scale( + inp, w_mean, x_mean, module2inspect, balance_layers, fp16_output + ) + + scales = best_scales + + @torch.no_grad() + def smooth(module): + if module in balance_layers: + module.weight.mul_(scales.view(1, -1).to(module.weight.device)) + elif module == smooth_layer: + if module.weight.ndim == 1: + module.weight.div_(scales.to(module.weight.device)) + else: + module.weight.div_(scales.view(-1, 1).to(module.weight.device)) + if hasattr(module, "bias") and module.bias is not None: + module.bias.div_(scales.to(module.bias.device)) + + parent = get_fsdp_parent(mapping.smooth_name, model) + if parent is not None: + parent.apply(smooth) + else: + # if we're not running with FSDP we can apply smoothing directly + for layer in balance_layers: + smooth(layer) + smooth(smooth_layer) + + if self.apply_clip: + clip_list = self._search_best_clip( + balance_layers=balance_layers, + balance_names=balance_names, + input_feat=inp, + ) + + _apply_clip(model, clip_list) + + # clear out allocated smoothing scales + torch.cuda.empty_cache() + + def _compute_best_scale( + self, + x: torch.Tensor, + w_mean: torch.Tensor, + x_mean: torch.Tensor, + module2inspect: torch.nn.Module, + linears2scale: List[torch.nn.Linear], + fp16_output: torch.Tensor, + ): + """ + Compute loss and select best scales + + L(s) = || Q(W * s) (s^-1 * X) - W * X || + Q: weight quantization function | pseudo_quantize_tensor(W * s) + X: inputs from calib dataset | X + W: original weights in FP16 | layer + s: per channel scaling factor | s^-1 * X + """ + n_grid = 20 + history = [] + best_ratio = -1 + best_scales = None + best_error = float("inf") + + org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()} + + device = x.device + x_mean = x_mean.view(-1).to(device) + w_mean = w_mean.view(-1).to(device) + + for ratio in range(n_grid): + # create new scales + ratio = ratio / n_grid + + # NOTE: s^-1 * x is fused here, according to paper + if self.duo_scaling: + scales = (x_mean.pow(ratio) / (w_mean.pow(1 - ratio) + 1e-4)).clamp( + min=1e-4 + ) + else: + scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) + scales = scales / (scales.max() * 
scales.min()).sqrt() + scales_view = scales.view(1, -1).to(device) + + # avoid scaling values that overflow + scales[torch.isinf(scales)] = 1 + scales[torch.isnan(scales)] = 1 + + # Q(W * s) + for fc in linears2scale: + fc.weight.mul_(scales_view) + fc.weight.data = ( + pseudo_quantize_tensor( + w=fc.weight.data, + symmetric=self.symmetric, + bit_width=self.bits, + group_size=self.group_size, + )[0] + / scales_view + ) + + # W * X + int_w_output = self._forward_input_with_kwargs( + module=module2inspect, inputs=x, input_kwargs=self.module_kwargs_ + ) + + # compute mean squared error (L2 norm) + loss = self._compute_loss(fp16_output, int_w_output, device) + + history.append(loss) + if loss < best_error: + best_error = loss + best_ratio = ratio + best_scales = scales.clone() + module2inspect.load_state_dict(org_sd) + + if best_ratio == -1: + logger.debug(history) + raise Exception + + assert torch.isnan(best_scales).sum() == 0, best_scales + + return best_scales.detach().cpu() + + @torch.no_grad() + def _compute_loss( + self, + fp16_output: torch.Tensor, + int_w_output: torch.Tensor, + device: torch.device, + ): + loss = 0.0 + fp16_output_flat = fp16_output.view(-1) + int_w_output_flat = int_w_output.view(-1) + num_elements = fp16_output_flat.size(0) + element_size_bytes = fp16_output.element_size() + + # Calculate chunk size dynamically based on max_chunk_memory + # Divide the max_chunk_memory by twice the element size + chunk_size = self.max_chunk_memory // (element_size_bytes * 2) + chunk_size = min(chunk_size, num_elements) + + # Split the computation into chunks + fp16_chunks = torch.split(fp16_output_flat, chunk_size) + int_w_chunks = torch.split(int_w_output_flat, chunk_size) + + # Compute the loss for each chunk + for fp16_chunk, int_w_chunk in zip(fp16_chunks, int_w_chunks): + chunk_loss = ( + (fp16_chunk.to(device) - int_w_chunk.to(device)) + .float() + .pow(2) + .sum() + .item() + ) + loss += chunk_loss + + # Normalize the loss by the total number of elements + loss /= num_elements + + return loss + + def _get_module_kwargs(self, model, dataloader): + _, modules = next(iter(get_layers("re:.*layers", model).items())) + + samples = [batch["input_ids"] for batch in dataloader] + + samples = torch.cat(samples, dim=0) + + inps = [] + layer_kwargs = {} + + best_device = "cuda" + modules[0] = modules[0].to(best_device) + # self.awq_model.move_embed(self.model, best_device) + + # get input and kwargs to layer 0 + # with_kwargs is only supported in PyTorch 2.0 + # use this Catcher hack for now + class Catcher(torch.nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + # assume first input to forward is hidden states + if len(args) > 0: + hidden_states = args[0] + del args + else: + first_key = list(kwargs.keys())[0] + hidden_states = kwargs.pop(first_key) + + inps.append(hidden_states) + layer_kwargs.update(kwargs) + raise ValueError # early exit to break later inference + + # patch layer 0 to catch input and kwargs + modules[0] = Catcher(modules[0]) + try: + model(samples.to(next(model.parameters()).device)) + except ValueError: # work with early exit + pass + modules[0] = modules[0].module # restore + + # Update the layer kwargs with `prepare_inputs_for_generation` method + # that takes care of everything to avoid unexpected errors. + layer_kwargs = model.prepare_inputs_for_generation(samples, **layer_kwargs) + # Pop the input_ids as they are not needed at all. 
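+ # (remaining kwargs, e.g. attention_mask, are kept and reused as module_kwargs_ for the forward passes during scale search)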
+ layer_kwargs.pop("input_ids") + + del samples + inps = inps[0] + + torch.cuda.empty_cache() + + if layer_kwargs.get("attention_mask") is not None: + layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to( + best_device + ) + + self.module_kwargs_ = layer_kwargs + + def _forward_input_with_kwargs( + self, + module: Module, + inputs: torch.Tensor, + input_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Forward pass with input arguments + + :param module: module to run forward pass on + :param inputs: input tensor to pass to the module + :param input_kwargs: additional arguments to pass to the module + :return: the first output tensor from the forward pass + """ + kwargs = input_kwargs or self.module_kwargs_ or {} + return tensor_forward_with_input_args( + module=module, + inputs=inputs, + input_kwargs=kwargs, + )[0] + + @torch.no_grad() + def _search_best_clip(self, balance_layers, balance_names, input_feat): + clip_list = [] + avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"] + + for name, layer in zip(balance_names, balance_layers): + # due to qk bmm, it is hard to clip precisely + if any([_ in name for _ in avoid_clipping]): + continue + + max_val = self._compute_best_clip(layer.weight, input_feat) + clip_list.append((name, max_val)) + + return clip_list + + @torch.no_grad() + def _compute_best_clip( + self, + w: torch.Tensor, + input_feat: torch.Tensor, + n_grid=20, + max_shrink=0.5, + n_sample_token=512, + ): + assert w.dim() == 2 + org_w_shape = w.shape + # w [co, ci] -> [co, 1, n_group, group size] + # input_feat [n_token, ci] -> [1, n_token, n_group, group size] + group_size = self.group_size if self.group_size > 0 else org_w_shape[1] + input_feat = input_feat.view(-1, input_feat.shape[-1]) + input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size) + + # Compute input feature step size (minimum 1) + step_size = max(1, input_feat.shape[1] // n_sample_token) + input_feat = input_feat[:, ::step_size] + + w = w.reshape(org_w_shape[0], 1, -1, group_size) + + oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64 # prevent OOM + assert org_w_shape[0] % oc_batch_size == 0 + w_all = w + best_max_val_all = [] + + for i_b in range(org_w_shape[0] // oc_batch_size): + w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size] + + org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1 + + best_max_val = org_max_val.clone() + min_errs = torch.ones_like(org_max_val) * 1e9 + input_feat = input_feat.to(w.device) + org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group + + for i_s in range(int(max_shrink * n_grid)): + max_val = org_max_val * (1 - i_s / n_grid) + min_val = -max_val + cur_w = torch.clamp(w, min_val, max_val) + q_w = pseudo_quantize_tensor( + w=cur_w, + symmetric=self.symmetric, + group_size=group_size, + bit_width=self.bits, + )[0] + cur_out = (input_feat * q_w).sum(dim=-1) + + # co, 1, n_group, 1 + err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape) + del cur_w + del cur_out + cur_best_idx = err < min_errs + min_errs[cur_best_idx] = err[cur_best_idx] + best_max_val[cur_best_idx] = max_val[cur_best_idx] + best_max_val_all.append(best_max_val) + + best_max_val = torch.cat(best_max_val_all, dim=0) + + clear_memory(input_feat) + clear_memory(org_out) + + return best_max_val.squeeze(1) + + +@torch.no_grad() +def _apply_clip(module, clip_list: Tuple[str, torch.Tensor]): + """ + Apply clipping to the weights of the given module + + :post-condition: the weights of the module are clipped to the given maximum values + 
:param module: module to apply clipping to + :param clip_list: list of tuples containing the name of the layer and the maximum + value to clip the weights to + """ + for name, max_val in clip_list: + _, layer = get_layer(target=name, module=module) + assert isinstance(layer, torch.nn.Linear) + max_val = max_val.to(layer.weight.device) + org_shape = layer.weight.shape + layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1) + layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val) + layer.weight.data = layer.weight.data.reshape(org_shape) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index d0e497766..6abe47707 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -2,6 +2,10 @@ Utility / helper functions """ +import functools +import gc +import inspect +import os import random from typing import Any, Dict, Iterable, List, Mapping, OrderedDict, Tuple, Union @@ -27,6 +31,18 @@ "get_linear_layers", "get_quantized_layers", "set_deterministic_seeds", + "torch_distributed_zero_first", + "thin_model_from_checkpoint", + "MEMORY_BOUNDED", + "memory_aware_threshold", + "detach", + "adjust_quantization_for_onnx_export", + "get_dependency_order", + "pseudo_quantize_tensor", + "pseudo_dequantize_linear", + "tensor_forward_with_input_args", + "sanitize_kwargs_for_module", + "clear_memory", ] @@ -200,6 +216,147 @@ def tensor_sparsity( return zeros.float() / float(total) +def tensor_density(tens: Tensor, dim: Union[None, int, Iterable[int]] = None) -> Tensor: + """ + :param tens: the tensor to calculate the density for + :param dim: the dimension(s) to split the calculations over; ex, can split over + batch, channels, or combos + :return: the density of the input tens, ie the fraction of numbers that are non zero + """ + density = (tensor_sparsity(tens, dim) - 1.0) * -1.0 + + return density + + +def tensor_sample( + tens: Tensor, + sample_size: int, + dim: Union[None, int, List[int], Tuple[int, ...]] = None, +) -> Tensor: + """ + :param tens: the tensor to grab samples from + :param sample_size: the number of samples to grab overall if dim is not supplied + or per each dim if it is + :param dim: the dimension(s) to split the samples over; + ex, can split over batch, channels, or combos + :return: the sampled tensor + """ + if sample_size < 1: + raise ValueError("improper sample size given of {}".format(sample_size)) + + if dim is None: + indices = tens.new_zeros((sample_size,)).long().random_(0, tens.numel()) + samples = tens.view(-1)[indices] + + return samples + + if isinstance(dim, int): + dim = [dim] + + if max(dim) >= len(tens.shape): + raise ValueError( + "Unsupported dim given of {} in {} for tensor shape {}".format( + max(dim), dim, tens.shape + ) + ) + + if dim != [ind for ind in range(len(dim))]: + # put the desired dimension(s) at the front to sample from + tens = tens.permute( + *dim, *[ind for ind in range(len(tens.shape)) if ind not in dim] + ) + dim = [ind for ind in range(len(dim))] + + if not tens.is_contiguous(): + tens = tens.contiguous() + + num_indices = int(numpy.prod([tens.shape[ind] for ind in range(len(dim))])) + elem_per_ind = int( + numpy.prod([tens.shape[ind] for ind in range(len(dim), len(tens.shape))]) + ) + # create a new tensor with offsets set for each of our elements that we are indexing + indices = tens.new_tensor( + [ind * elem_per_ind for ind in range(num_indices)], dtype=torch.long + ).unsqueeze(1) + # now broadcast it across to the 
total number of elements we should end with + indices = indices * tens.new_ones((num_indices, sample_size), dtype=torch.long) + # finally add in a random number within the available range per index + indices += tens.new_zeros((num_indices, sample_size), dtype=torch.long).random_( + 0, elem_per_ind + ) + # get our samples + samples = tens.view(-1)[indices.view(-1)] + # reshape for the proper dimension + samples = samples.view(*(tens.shape[ind] for ind in dim), sample_size) + + return samples + + +def tensor_list_sparsity(tensors: List[Tensor]) -> float: + """ + :param tensors: the list of tensors to calculate the sparsity for + :return: the total sparsity of all tensors in the list + """ + zeros = 0 + numel = 0 + for tensor in tensors: + zeros += (tensor == 0).sum().item() + numel += tensor.numel() + return float(zeros) / float(numel) + + +def mask_difference(old_mask: Tensor, new_mask: Tensor) -> Tensor: + """ + :param old_mask: the old mask to compare against for calculating the difference + :param new_mask: the new mask to compare with for calculating the difference + :return: a tensor representing the change from the old_mask to the new_mask + specifically values returned as 1.0 are newly unmasked (0.0 => 1.0) + values returned as -1.0 are newly masked (1.0 => 0.0) + values returned as 0.0 had no change in (0.0 => 0.0 or 1.0 => 1.0) + """ + newly_masked = ((old_mask != new_mask) & (new_mask == 0.0)).type(old_mask.type()) + newly_unmasked = ((old_mask != new_mask) & (new_mask == 1.0)).type(old_mask.type()) + + return -1.0 * newly_masked + newly_unmasked + + +def sanitize_kwargs_for_module( + kwargs: Dict[str, Any], module: Module +) -> Dict[str, Any]: + """ + Sanitize the kwargs for a Module by removing any keys that are not + in the signature of the forward method. + :param kwargs: the kwargs to sanitize + :param module: the Module to sanitize the kwargs for + :return: the sanitized kwargs for the callable object + """ + if not isinstance(kwargs, dict): + raise TypeError(f"Expected a dictionary as kwargs, but got {kwargs}") + + allowed_params = inspect.signature(module.forward).parameters + return {key: value for key, value in kwargs.items() if key in allowed_params} + + +def tensor_forward_with_input_args( + module: Module, inputs: Tensor, input_kwargs: Dict[str, Any] +) -> Tensor: + """ + Forward the given inputs through the given module with the given input_kwargs. + This function is a wrapper around tensors_module_forward that ensures that the + input_kwargs are sanitized and passed to the module as keyword arguments during + the forward pass. + :param module: the module to forward the inputs through + :param inputs: the inputs to forward through the module + :param input_kwargs: the keyword arguments to pass to the + module during the forward pass + :return: the output of the module after forwarding the inputs through it + """ + inputs = inputs.to(next(module.parameters()).device) + input_kwargs = sanitize_kwargs_for_module(input_kwargs, module) + + return tensors_module_forward(inputs, functools.partial(module, **input_kwargs)) + + ############################## # # pytorch module helper functions @@ -244,3 +401,327 @@ def set_deterministic_seeds(seed: int = 0): random.seed(seed) torch.manual_seed(seed) torch.backends.cudnn.deterministic = True + + +@contextmanager +def torch_distributed_zero_first(local_rank: Optional[int]): + """ + Decorator to make all processes in distributed training wait for each + local 0 ranked process to do something. 
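+ Example (a minimal sketch; `build_dataset_cache` is a hypothetical helper and torch.distributed is assumed to be initialized):
+     with torch_distributed_zero_first(local_rank):
+         dataset = build_dataset_cache()  # rank 0 runs first; other ranks wait, then load the cached result
+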
+ :param local_rank: the local rank of this process + """ + if local_rank is not None and local_rank not in [-1, 0]: + torch.distributed.barrier() + yield + if local_rank == 0: + torch.distributed.barrier() + + +def thin_model_from_checkpoint(model: Module, state_dict: Dict[str, Any]): + """ + Updates any Linear/Conv/BN layers in the given model to match their + respective shapes in the given state dict. Purpose of compatibility + when loading weight for a model from a checkpoint of the same architecture + but with potentially structured thinning applied. Note that this function + has no guarantees on accuracy, will only resize model parameters for + loading compatibility. All adjustments done in place + + :param model: model to potentially adjust parameter shapes of + :param state_dict: state dict to infer parameter shapes from + """ + first_thinned = True + for param_name, checkpoint_tens in state_dict.items(): + if not param_name.endswith(".weight"): + continue # only deal with weight params of modules + layer_name = param_name[:-7] + layer = get_layer(layer_name, model) + + if not hasattr(layer, "weight") or ( + layer.weight.shape == checkpoint_tens.shape + ): + continue # skip if there is no update to shape + + # quick check that target layer is some flavor of FC/Conv/BN + layer_type = layer.__class__.__name__ + if not ( + "Linear" not in layer_type + or "Conv" not in layer_type + or ("BatchNorm" not in layer_type) + ): + continue + + orig_shape = layer.weight.shape + target_shape = checkpoint_tens.shape + + # update weight param + grad + if len(target_shape) > 1: + layer.weight.data = layer.weight.data[ + : target_shape[0], : target_shape[1], ... + ] + if layer.weight.grad is not None: + layer.weight.grad = layer.weight.grad[ + : target_shape[0], : target_shape[1], ... 
+ ] + else: + layer.weight.data = layer.weight.data[: target_shape[0]] + if layer.weight.grad is not None: + layer.weight.grad = layer.weight.grad[: target_shape[0]] + + # update bias param + grad + if hasattr(layer, "bias") and layer.bias is not None: + # target output channels should be the first dim of target shape + layer.bias.data = layer.bias.data[: target_shape[0]] + if layer.bias.grad is not None: + layer.bias.grad = layer.bias.grad[: target_shape[0]] + + # update layer attributes + if "BatchNorm" in layer_type: + if hasattr(layer, "num_features"): + layer.num_features = layer.weight.size(0) + # BN running mean and var are not stored as Parameters + if hasattr(layer, "running_mean"): + layer.running_mean = torch.zeros_like(layer.running_mean)[ + : target_shape[0] + ] + if hasattr(layer, "running_var"): + layer.running_var = torch.zeros_like(layer.running_var)[ + : target_shape[0] + ] + + if "Linear" in layer_type: + if hasattr(layer, "out_features"): + layer.out_features = layer.weight.shape[0] + if hasattr(layer, "in_features"): + layer.in_features = layer.weight.shape[1] + + if "Conv" in layer_type: + if hasattr(layer, "out_channels"): + layer.out_channels = layer.weight.shape[0] + if hasattr(layer, "in_channels"): + layer.in_channels = layer.weight.shape[1] + if hasattr(layer, "groups") and layer.groups > 1: + layer.groups = layer.weight.shape[0] // layer.weight.shape[1] + + if first_thinned: + logger.info( + "Thinning module layers for compatibility with given state dict:" + ) + first_thinned = False + logger.info( + f"Thinned layer {layer_name} from shape {orig_shape} to " + f"{layer.weight.shape}" + ) + + +############################## +# +# misc pytorch helper functions +# +############################## + + +MEMORY_BOUNDED = "MEMORY_BOUNDED" + + +def memory_aware_threshold(tensor: torch.Tensor, idx: int) -> Tensor: + """ + Finds a threshold at the lookup idx in the most efficient way with available + resources. Will be phased out when GPU-memory overhead of torch.sort reduces, + or when torch.kthvalue becomes faster than torch.sort. + + :param tensor: A tensor to find a k-th smallest value in, where k=idx+1 + :param idx: A lookup index + :return: k-th smallest value from the given tensor, where k=idx+1 + """ + try: + if ( + MEMORY_BOUNDED in os.environ + and os.environ[MEMORY_BOUNDED].lower() == "true" + ): + return torch.kthvalue(tensor.reshape(-1), idx + 1)[0] + else: + return torch.sort(tensor.reshape(-1))[0][idx] + except RuntimeError: + logger.warning( + "Finding threshold from sparsity failed due to lack of memory, " + "will attempt to recover. Consider setting env variable " + f"{MEMORY_BOUNDED}=True in future runs." 
+ ) + torch.cuda.empty_cache() + os.environ[MEMORY_BOUNDED] = "True" + return torch.kthvalue(tensor.view(-1), idx + 1)[0] + + +def detach(x: Union[torch.Tensor, List, Tuple]): + if isinstance(x, torch.Tensor): + return x.detach() + elif isinstance(x, List): + return [detach(e) for e in x] + elif isinstance(x, Tuple): + return tuple([detach(e) for e in x]) + else: + raise ValueError("Unexpected type to detach") + + +def adjust_quantization_for_onnx_export(module: torch.nn.Module) -> torch.nn.Module: + # supported pytorch ranges are int8 or uint8 + allowed_ranges = [(0, 127), (0, 255), (-128, 127)] + fake_quant_modules = [ + m for m in module.modules() if m.__class__.__name__ == "FakeQuantize" + ] + + if _PARSED_TORCH_VERSION >= version.parse("1.12"): + for quant in fake_quant_modules: + # original ranges preserved in quant.quant_min and quant.quant_max + quant_range = ( + quant.activation_post_process.quant_min, + quant.activation_post_process.quant_max, + ) + if quant_range not in allowed_ranges: + if quant_range[0] < 0: # convert signed range to int8 + quant.activation_post_process.quant_min = -128 + quant.activation_post_process.quant_max = 127 + else: # convert unsigned range to uint8 + quant.activation_post_process.quant_min = 0 + quant.activation_post_process.quant_max = 255 + # don't update observer since ranges are artificially modified + quant.observer_enabled[0] = 0 + + else: # backwards compatibility for torch <= 1.11 + for quant in fake_quant_modules: + quant_range = (quant.quant_min, quant.quant_max) + if quant_range not in allowed_ranges: + if quant_range[0] < 0: # convert signed range to int8 + quant.quant_min = -128 + quant.quant_max = 127 + else: # convert unsigned range to uint8 + quant.quant_min = 0 + quant.quant_max = 255 + # don't update observer since ranges are artificially modified + quant.observer_enabled[0] = 0 + + +def get_dependency_order( + layer: Module, subset: Dict, an_input: Tensor, **kwargs +) -> List[str]: + """ + Get a list of a subset of modules in layer ordered by execution order, which honors + the dependencies in the graph + + :param layer: pytorch module to calculate dependencies for + :param subset: subset of modules in the layer to include in the ordering + :param an_input: example input to pass through the layer forward pass, used to + determine execution order + + :return: list of module names in execution order + """ + order = [] + + def exe_input(name): + def _exe_input(_, inp, out): + if name in subset: + order.append(name) + + return _exe_input + + # register a hook for each module of interest, will be triggered in exeuction order + handles = [subset[name].register_forward_hook(exe_input(name)) for name in subset] + layer(an_input, **kwargs) + for h in handles: + h.remove() + return order + + +def swap_modules( + module: torch.nn.Module, submodule_name: str, submodule_to_replace: torch.nn.Module +) -> torch.nn.Module: + """ + Iteratively unfold the submodules of the module according to the submodule_name + to eventually replace the leaf submodule (accessed from the module through the + submodule_name) with the submodule_to_replace. 
+ + E.g + ``` + swap_modules(module=Model, + module_name="layers.0.sublayer", + module_to_replace=ReplaceModule + ) + ``` + this will iteratively traverse through the submodules + 'layers' -> '0' -> to eventually replace 'sublayer' with ReplaceModule + + :param module: the module to replace with the module_to_replace + :param submodule_name: the name of the module to replace + :param submodule_to_replace: the module to replace the module with + :return: the replaced module + """ + parent = module + sections = submodule_name.split(".") + + for sec in sections[:-1]: + parent = parent.__getattr__(sec) + + cur = parent.__getattr__(sections[-1]) + parent.__setattr__(sections[-1], submodule_to_replace) + + return cur + + +def pseudo_quantize_tensor( + w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 +): + org_w_shape = w.shape + if group_size > 0: + assert org_w_shape[-1] % group_size == 0 + w = w.reshape(-1, group_size) + assert w.dim() == 2 + assert torch.isnan(w).sum() == 0 + + if not symmetric: + max_val = w.amax(dim=1, keepdim=True) + min_val = w.amin(dim=1, keepdim=True) + max_int = 2**bit_width - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-5) / max_int + zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) + w = ( + torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros + ) * scales + zeros = zeros.view(org_w_shape[0], -1) + else: + max_val = w.abs().amax(dim=1, keepdim=True) + max_val = max_val.clamp(min=1e-5) + max_int = 2 ** (bit_width - 1) - 1 + min_int = -(2 ** (bit_width - 1)) + scales = max_val / max_int + zeros = None + w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales + + assert torch.isnan(scales).sum() == 0 + assert torch.isnan(w).sum() == 0 + + scales = scales.view(org_w_shape[0], -1) + w = w.reshape(org_w_shape) + + return w, scales, zeros + + +def pseudo_dequantize_linear( + w: torch.Tensor, + scales: torch.Tensor, + zeros: Optional[torch.Tensor] = None, + symmetric: bool = False, +): + # get repeated count + repeat_count = w.weight.data.shape[-1] // scales.shape[-1] + scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape) + + # dequantize + if not symmetric: + zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape) + w = (w.weight.data - zeros) * scales + else: + w = w.weight.data * scales + + return w + diff --git a/src/llmcompressor/transformers/finetune/data/__init__.py b/src/llmcompressor/transformers/finetune/data/__init__.py index 3754908fa..655a85df6 100644 --- a/src/llmcompressor/transformers/finetune/data/__init__.py +++ b/src/llmcompressor/transformers/finetune/data/__init__.py @@ -9,6 +9,7 @@ from .gsm8k import GSM8KDataset from .open_platypus import OpenPlatypusDataset from .peoples_speech import PeoplesSpeech +from .pile import PileEvalDataset from .ptb import PtbDataset from .ultrachat_200k import UltraChatDataset from .wikitext import WikiTextDataset diff --git a/src/llmcompressor/transformers/finetune/data/pile.py b/src/llmcompressor/transformers/finetune/data/pile.py new file mode 100644 index 000000000..b3a99ea0e --- /dev/null +++ b/src/llmcompressor/transformers/finetune/data/pile.py @@ -0,0 +1,45 @@ +from copy import deepcopy +from typing import Optional + +from llmcompressor.transformers.finetune.data import TextGenerationDataset + + +@TextGenerationDataset.register(name="pile_eval") +class PileEvalDataset(TextGenerationDataset): + """ + Child text generation class for the PileEval dataset + :param data_args: configuration settings 
for dataset loading + :param split: split from dataset to load, for instance `test` or `train[:5%]` + :param tokenizer: tokenizer to use on dataset + """ + + def __init__(self, data_args, split, tokenizer): + data_args = deepcopy(data_args) + data_args.dataset = "mit-han-lab/pile-val-backup" + super().__init__( + text_column="text", data_args=data_args, split=split, tokenizer=tokenizer + ) + + def get_raw_dataset(self, cache_dir: Optional[str] = None): + """ + Load the raw dataset from Hugging Face, using cached copy if available. + Additionally reformats the entries to fit the template. + :param cache_dir: disk location to search for cached dataset + :return: the requested dataset + """ + raw_dataset = super().get_raw_dataset(cache_dir=cache_dir) + + def restructure_fn(sample): + sample["text"] = sample["text"].strip() + return sample + + raw_dataset = self.map( + raw_dataset, + function=restructure_fn, + batched=False, + remove_columns=["meta"], + num_proc=self.data_args.preprocessing_num_workers, + load_from_cache_file=not self.data_args.overwrite_cache, + desc="Restructuring Pile Dataset", + ) + return raw_dataset diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py index 8f7eadb53..c980f00c8 100644 --- a/src/llmcompressor/utils/pytorch/module.py +++ b/src/llmcompressor/utils/pytorch/module.py @@ -60,6 +60,7 @@ "get_layers_params", "get_matching_layer", "get_no_split_params", + "get_parent_by_name", ] @@ -338,3 +339,22 @@ def get_no_split_params(module: Module) -> Union[str, List[str]]: if hasattr(model, "_no_split_modules"): return model._no_split_modules return ALL_TARGET + + +def get_parent_by_name(layer_name: str, model: Module) -> Tuple[str, Module]: + """ + Get the parent layer of a layer by name. + :param layer_name: Name of the layer to find the parent of. + :param model: Model to search for the parent layer. + :return: Tuple containing the name of the parent layer + and the parent layer itself. 
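+ Example (with a hypothetical submodule name):
+     >>> parent_name, parent = get_parent_by_name("layers.0.q_proj", model)
+     # returns ("layers.0", <parent module>); for a top-level module, ("", model) is returned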
+ """ + if not any(layer_name == name for name, _ in model.named_modules()): + raise ValueError(f"Layer '{layer_name}' not found in model") + + parent_name_parts = layer_name.split(".")[:-1] + if not parent_name_parts: + return "", model + + parent_name = ".".join(parent_name_parts) + return get_layer(parent_name, model) diff --git a/tests/llmcompressor/modifiers/awq/__init__.py b/tests/llmcompressor/modifiers/awq/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py new file mode 100644 index 000000000..918238718 --- /dev/null +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -0,0 +1,28 @@ +import unittest + +import pytest + +from llmcompressor.modifiers.awq import AWQModifier +from llmcompressor.modifiers.factory import ModifierFactory +from tests.llmcompressor.modifiers.conf import setup_modifier_factory + + +@pytest.mark.unit +class TestAWQIsRegistered(unittest.TestCase): + def setUp(self): + self.kwargs = {} + setup_modifier_factory() + + def test_awq_is_registered(self): + modifier = ModifierFactory.create( + type_="AWQModifier", + allow_experimental=False, + allow_registered=True, + **self.kwargs, + ) + + self.assertIsInstance( + modifier, + AWQModifier, + "PyTorch AWQModifier not registered", + ) diff --git a/tests/llmcompressor/pytorch/utils/test_helpers.py b/tests/llmcompressor/pytorch/utils/test_helpers.py index cc4edfdda..7d844cb00 100644 --- a/tests/llmcompressor/pytorch/utils/test_helpers.py +++ b/tests/llmcompressor/pytorch/utils/test_helpers.py @@ -7,6 +7,17 @@ from torch.nn import Linear, Module, ReLU, Sequential from llmcompressor.pytorch.utils import ( + MEMORY_BOUNDED, + default_device, + get_optim_learning_rate, + mask_difference, + memory_aware_threshold, + sanitize_kwargs_for_module, + set_optim_learning_rate, + tensor_density, + tensor_export, + tensor_forward_with_input_args, + tensor_sample, tensor_sparsity, tensors_module_forward, tensors_to_device, @@ -494,3 +505,252 @@ def test_tensor_sparsity_cuda(tensor, dim, expected_sparsity): sparsity = tensor_sparsity(tensor, dim) assert expected_sparsity.shape == sparsity.shape assert torch.sum((sparsity.detach().cpu() - expected_sparsity).abs()) < 0.001 + + +@pytest.mark.flaky(reruns=2, min_passes=1) +@pytest.mark.skipif( + os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), + reason="Skipping pytorch tests", +) +@pytest.mark.parametrize( + "tensor,dim,expected_density", + [ + (torch.zeros(8, 16), None, torch.tensor(0.0)), + (torch.zeros(8, 16), 0, torch.zeros(8)), + (torch.zeros(8, 16), 1, torch.zeros(16)), + (torch.zeros(8, 16), [0, 1], torch.zeros(8, 16)), + (torch.zeros(8, 16), [1, 0], torch.zeros(16, 8)), + (torch.zeros(8, 16, 32, 8), [3, 1, 2], torch.zeros(8, 16, 32)), + (torch.ones(8, 16), None, torch.tensor(1.0)), + (torch.ones(8, 16), 0, torch.ones(8)), + (torch.ones(8, 16), 1, torch.ones(16)), + (torch.ones(8, 16), [0, 1], torch.ones(8, 16)), + (torch.ones(8, 16), [1, 0], torch.ones(16, 8)), + (torch.ones(8, 16, 32, 8), [3, 1, 2], torch.ones(8, 16, 32)), + (torch.randn(8, 16), None, torch.tensor(1.0)), + (torch.randn(8, 16), 0, torch.ones(8)), + (torch.randn(8, 16), 1, torch.ones(16)), + (torch.randn(8, 16), [0, 1], torch.ones(8, 16)), + (torch.randn(8, 16), [1, 0], torch.ones(16, 8)), + (torch.randn(8, 16, 32, 8), [3, 1, 2], torch.ones(8, 16, 32)), + ( + torch.tensor([10.0, 0.0, 1.0, 3.0, 2.0, 0.0, 8.0, 0.0, 5.0, 0.0]), + None, + torch.tensor(0.6), + ), + ], +) +def test_tensor_density(tensor, 
dim, expected_density): + density = tensor_density(tensor, dim) + assert expected_density.shape == density.shape + assert torch.sum((density - expected_density).abs()) < 0.001 + + +@pytest.mark.flaky(reruns=2, min_passes=1) +@pytest.mark.skipif( + os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), + reason="Skipping pytorch tests", +) +@pytest.mark.parametrize( + "tensor,dim,expected_density", + [ + (torch.zeros(8, 16), None, torch.tensor(0.0)), + (torch.zeros(8, 16, 32, 8), [3, 1, 2], torch.zeros(8, 16, 32)), + (torch.ones(8, 16), None, torch.tensor(1.0)), + (torch.ones(8, 16, 32, 8), [3, 1, 2], torch.ones(8, 16, 32)), + (torch.randn(8, 16), None, torch.tensor(1.0)), + ( + torch.tensor([10.0, 0.0, 1.0, 3.0, 2.0, 0.0, 8.0, 0.0, 5.0, 0.0]), + None, + torch.tensor(0.6), + ), + ], +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda availability") +def test_tensor_density_cuda(tensor, dim, expected_density): + tensor = tensor.to("cuda") + density = tensor_density(tensor, dim) + assert expected_density.shape == density.shape + assert torch.sum((density.detach().cpu() - expected_density).abs()) < 0.001 + + +@pytest.mark.skipif( + os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), + reason="Skipping pytorch tests", +) +@pytest.mark.parametrize( + "tensor,size,dim,expected_shape", + [ + (torch.randn(8, 16), 100, None, [100]), + (torch.randn(8, 16), 100, 0, [8, 100]), + (torch.randn(8, 16), 100, 1, [16, 100]), + (torch.randn(8, 16), 10, [0, 1], [8, 16, 10]), + (torch.randn(8, 16), 10, [1, 0], [16, 8, 10]), + (torch.randn(64, 12, 32, 16), 10, 2, [32, 10]), + (torch.randn(64, 12, 32, 16), 10, [3, 2], [16, 32, 10]), + (torch.randn(64, 12, 32, 16), 10, 1, [12, 10]), + (torch.randn(64, 12, 32, 16), 10, [0, 1], [64, 12, 10]), + ], +) +def test_tensor_sample(tensor, size, dim, expected_shape): + sample = tensor_sample(tensor, size, dim) + assert len(sample.shape) == len(expected_shape) + for s1, s2 in zip(sample.shape, expected_shape): + assert s1 == s2 + + +@pytest.mark.skipif( + os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), + reason="Skipping pytorch tests", +) +@pytest.mark.parametrize( + "tensor,size,dim,expected_shape", + [ + (torch.randn(8, 16), 100, None, [100]), + (torch.randn(8, 16), 100, 0, [8, 100]), + (torch.randn(8, 16), 100, 1, [16, 100]), + (torch.randn(8, 16), 10, [0, 1], [8, 16, 10]), + (torch.randn(8, 16), 10, [1, 0], [16, 8, 10]), + (torch.randn(64, 12, 32, 16), 10, 2, [32, 10]), + (torch.randn(64, 12, 32, 16), 10, [3, 2], [16, 32, 10]), + (torch.randn(64, 12, 32, 16), 10, 1, [12, 10]), + (torch.randn(64, 12, 32, 16), 10, [0, 1], [64, 12, 10]), + ], +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda availability") +def test_tensor_sample_cuda(tensor, size, dim, expected_shape): + tensor = tensor.to("cuda") + sample = tensor_sample(tensor, size, dim) + assert len(sample.shape) == len(expected_shape) + for s1, s2 in zip(sample.shape, expected_shape): + assert s1 == s2 + + +@pytest.mark.skipif( + os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), + reason="Skipping pytorch tests", +) +@pytest.mark.parametrize( + "old_mask,new_mask,expected_diff", + [ + (torch.zeros(8, 8), torch.zeros(8, 8), torch.zeros(8, 8)), + (torch.zeros(8, 8), torch.ones(8, 8), torch.ones(8, 8)), + (torch.ones(8, 8), torch.zeros(8, 8), -1.0 * torch.ones(8, 8)), + (torch.ones(8, 8), torch.ones(8, 8), torch.zeros(8, 8)), + ( + torch.tensor([0.0, 0.0, 1.0, 0.0, 1.0, 1.0]), + torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0, 1.0]), + torch.tensor([0.0, 1.0, -1.0, 0.0, -1.0, 0.0]), + ), + ], +) 
+def test_mask_difference(old_mask, new_mask, expected_diff): + diff = mask_difference(old_mask, new_mask) + assert torch.sum((diff - expected_diff).abs()) < sys.float_info.epsilon + + +@pytest.mark.skipif( + os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), + reason="Skipping pytorch tests", +) +@pytest.mark.parametrize( + "model,state_dict,test_input", + [ + ( + Sequential(Conv2d(3, 16, (1, 1)), BatchNorm2d(16), Conv2d(16, 16, (1, 1))), + { + "0.weight": torch.randn(8, 3, 1, 1), + "0.bias": torch.randn(8), + "1.weight": torch.randn(8), + "1.bias": torch.randn(8), + "1.running_mean": torch.randn(8), + "1.running_var": torch.randn(8), + "2.weight": torch.randn(12, 8, 1, 1), + "2.bias": torch.randn(12), + }, + torch.randn(2, 3, 16, 16), + ), + ( + Sequential(Linear(8, 12), Linear(12, 16)), + { + "0.weight": torch.randn(7, 8), + "0.bias": torch.randn(7), + "1.weight": torch.randn(9, 7), + "1.bias": torch.randn(9), + }, + torch.randn(5, 8), + ), + ], +) +def test_thin_model_from_checkpoint(model, state_dict, test_input): + with pytest.raises(RuntimeError): + model.load_state_dict(state_dict) + + thin_model_from_checkpoint(model, state_dict) + model.load_state_dict(state_dict, strict=True) + assert isinstance(model(test_input), Tensor) + + +@pytest.mark.parametrize( + "tensor,idx", + [ + (torch.rand(1), 0), + (torch.rand(1_000), 123), + (torch.rand(10_000), 4321), + (torch.rand(100_000), 12345), + ], +) +def test_memory_aware_threshold(tensor, idx): + prior_state = os.getenv(MEMORY_BOUNDED) + + dev = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") + tensor = tensor.to(dev) + + os.environ[MEMORY_BOUNDED] = "True" + t1 = memory_aware_threshold(tensor, idx) + os.environ[MEMORY_BOUNDED] = "False" + t2 = memory_aware_threshold(tensor, idx) + assert abs(t1 - t2) < 1e-3 + + if prior_state is not None: + os.environ[MEMORY_BOUNDED] = prior_state + + +class TestSanitizeKwargsForModule: + @pytest.fixture + def module(self): + return Linear(10, 20) + + def test_sanitize_kwargs_for_module_not_dict(self, module): + # Test with kwargs that are not a dictionary + with pytest.raises(TypeError): + sanitize_kwargs_for_module("not a dictionary", module) + + def test_sanitize_kwargs_for_module_not_in_signature(self, module): + # Test with kwargs that are not in the signature of the forward method + kwargs = {"not_in_signature": 123} + sanitized_kwargs = sanitize_kwargs_for_module(kwargs, module) + assert sanitized_kwargs == {} + + def test_sanitize_kwargs_for_module_in_signature(self, module): + # Test with kwargs that are in the signature of the forward method + kwargs = {"input": torch.randn(1, 10)} + sanitized_kwargs = sanitize_kwargs_for_module(kwargs, module) + assert sanitized_kwargs == kwargs + + +class TestTensorForwardWithInputArgs: + @pytest.fixture + def module(self): + return Linear(10, 20) + + def test_tensor_forward_with_input_args(self, module): + # Test with valid inputs and input_kwargs + inputs = torch.randn(1, 10) + input_kwargs = {} + output = tensor_forward_with_input_args(module, inputs, input_kwargs) + assert output.shape == (1, 20) + + # Test with input_kwargs that are not in the signature of the forward method + input_kwargs = {"not_in_signature": 123} + tensor_forward_with_input_args(module, inputs, input_kwargs) diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py index 29895b4a4..c2fa5a1b2 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ 
b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -4,6 +4,7 @@ from llmcompressor.transformers.finetune.data import ( C4Dataset, OpenPlatypusDataset, + PileEvalDataset, TextGenerationDataset, WikiTextDataset, ) @@ -57,3 +58,4 @@ def test_open_platypus_initializes(tiny_llama_tokenizer): assert op_manager.dataset_args.text_column == "text" assert not op_manager.padding assert op_manager.max_seq_length == dataset_args.max_seq_length + diff --git a/tests/llmcompressor/utils/pytorch/__init__.py b/tests/llmcompressor/utils/pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/llmcompressor/utils/pytorch/test_module.py b/tests/llmcompressor/utils/pytorch/test_module.py new file mode 100644 index 000000000..4600377fa --- /dev/null +++ b/tests/llmcompressor/utils/pytorch/test_module.py @@ -0,0 +1,31 @@ +import unittest + +import torch.nn as nn + +from llmcompressor.utils.pytorch import get_parent_by_name + + +class TestGetParentByName(unittest.TestCase): + def setUp(self): + self.model = nn.Sequential( + nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 10), nn.Softmax(dim=1) + ) + + def test_get_parent_by_name(self): + # Test getting the parent of a non-existent layer + with self.assertRaises(ValueError): + get_parent_by_name("non_existent_layer", self.model) + + # Test getting the parent of the first layer + name, parent = get_parent_by_name("0", self.model) + self.assertEqual(parent, self.model) + + # Test getting the parent of a nested layer + nested_model = nn.Sequential( + nn.Linear(10, 20), + nn.Sequential(nn.ReLU(), nn.Linear(20, 10)), + nn.Softmax(dim=1), + ) + name, parent = get_parent_by_name("1.1", nested_model) + self.assertEqual(parent, nested_model[1]) + self.assertEqual(name, "1") From 34fa92a2c68c13ce5ed65b34f9b249dcc35aedb3 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 18 Feb 2025 20:32:53 +0000 Subject: [PATCH 02/40] updated to be compatible with latest, unit tests passing Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 26 ++++-------- src/llmcompressor/pytorch/utils/helpers.py | 1 - .../transformers/finetune/data/pile.py | 42 +++++++------------ .../finetune/data/test_registry.py | 1 - 4 files changed, 21 insertions(+), 49 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 71767b0b4..07ad3093f 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -10,7 +10,6 @@ from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.pytorch.utils import ( - clear_memory, pseudo_quantize_tensor, tensor_forward_with_input_args, ) @@ -83,12 +82,12 @@ class AWQModifier(Modifier): example recipe: ```yaml AWQModifier: - bits: 4 - mappings: [ + bits: 4 + mappings: [ [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"], [["re:.*fc1"], "re:.*final_layer_norm"] - ] - ignore: ["model.decoder.final_layer_norm"] + ] + ignore: ["model.decoder.final_layer_norm"] ``` :param mappings: list activation layers to smooth, and which layers to @@ -166,18 +165,6 @@ def on_initialize(self, state: State, **kwargs) -> bool: return True - def on_start(self, state: State, event: Event, **kwargs): - pass - - def on_update(self, state: State, event: Event, **kwargs): - pass - - def on_end(self, state: State, event: Event, **kwargs): - pass - - def on_event(self, state: State, event: Event, **kwargs): - pass - def 
on_finalize(self, state: State, **kwargs) -> bool: """ Clean up by clearing the scale and mapping data @@ -694,8 +681,9 @@ def _compute_best_clip( best_max_val = torch.cat(best_max_val_all, dim=0) - clear_memory(input_feat) - clear_memory(org_out) + #TODO this appears unneeded, clear_memory removed + # clear_memory(input_feat) + # clear_memory(org_out) return best_max_val.squeeze(1) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 6abe47707..eacd11350 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -42,7 +42,6 @@ "pseudo_dequantize_linear", "tensor_forward_with_input_args", "sanitize_kwargs_for_module", - "clear_memory", ] diff --git a/src/llmcompressor/transformers/finetune/data/pile.py b/src/llmcompressor/transformers/finetune/data/pile.py index b3a99ea0e..4eef5f7eb 100644 --- a/src/llmcompressor/transformers/finetune/data/pile.py +++ b/src/llmcompressor/transformers/finetune/data/pile.py @@ -1,7 +1,11 @@ from copy import deepcopy -from typing import Optional +from typing import TYPE_CHECKING from llmcompressor.transformers.finetune.data import TextGenerationDataset +from llmcompressor.typing import Processor + +if TYPE_CHECKING: + from llmcompressor.args import DatasetArguments @TextGenerationDataset.register(name="pile_eval") @@ -13,33 +17,15 @@ class PileEvalDataset(TextGenerationDataset): :param tokenizer: tokenizer to use on dataset """ - def __init__(self, data_args, split, tokenizer): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) + data_args.text_column = "text" data_args.dataset = "mit-han-lab/pile-val-backup" - super().__init__( - text_column="text", data_args=data_args, split=split, tokenizer=tokenizer - ) - - def get_raw_dataset(self, cache_dir: Optional[str] = None): - """ - Load the raw dataset from Hugging Face, using cached copy if available. - Additionally reformats the entries to fit the template. 
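For reference, the load-and-restructure step handled by this dataset class can be reproduced in isolation with the `datasets` library; the split name below is an assumption and the snippet is illustrative only, not part of the patch.

```python
# Illustrative standalone equivalent of the restructuring step (not patch code).
from datasets import load_dataset

def restructure(sample):
    # trim surrounding whitespace from each record, as restructure_fn does
    sample["text"] = sample["text"].strip()
    return sample

# split name "validation" is assumed; use whatever split the recipe requests
raw = load_dataset("mit-han-lab/pile-val-backup", split="validation")
raw = raw.map(restructure, batched=False, remove_columns=["meta"])
```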
- :param cache_dir: disk location to search for cached dataset - :return: the requested dataset - """ - raw_dataset = super().get_raw_dataset(cache_dir=cache_dir) - - def restructure_fn(sample): - sample["text"] = sample["text"].strip() - return sample + super().__init__(data_args=data_args, split=split, processor=processor) - raw_dataset = self.map( - raw_dataset, - function=restructure_fn, - batched=False, - remove_columns=["meta"], - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, - desc="Restructuring Pile Dataset", - ) - return raw_dataset + def dataset_template(self, sample): + return { + "text": self.processor.apply_chat_template( + sample["text"].strip(), + ), + } diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py index c2fa5a1b2..ce872fba9 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -58,4 +58,3 @@ def test_open_platypus_initializes(tiny_llama_tokenizer): assert op_manager.dataset_args.text_column == "text" assert not op_manager.padding assert op_manager.max_seq_length == dataset_args.max_seq_length - From f67b386552becda49b162ed115bba8a5b17bcbf9 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 18 Feb 2025 21:20:02 +0000 Subject: [PATCH 03/40] switch to using HooksMixin api Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 44 +++++++++++-------- .../modifiers/smoothquant/base.py | 4 +- src/llmcompressor/pytorch/utils/helpers.py | 2 - 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 07ad3093f..9a9155e7b 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -2,11 +2,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch +from compressed_tensors.utils.offload import is_module_offloaded from loguru import logger from torch.nn import Module from tqdm import tqdm -from llmcompressor.core import Event, State +from llmcompressor.core import State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.pytorch.utils import ( @@ -14,6 +15,7 @@ tensor_forward_with_input_args, ) from llmcompressor.utils.fsdp.helpers import get_fsdp_parent +from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import ( get_layer, get_layers, @@ -124,14 +126,10 @@ class AWQModifier(Modifier): duo_scaling: bool = True apply_clip: bool = True - hooks_: Optional[List] = None - resolved_mappings_: Optional[List] = None + resolved_mappings_: Optional[List[AWQMapping]] = None scales_: Optional[Dict] = None module_kwargs_: Optional[Dict] = None - def on_initialize_structure(self, state: State, **kwargs): - pass # nothing needed for this modifier - def on_initialize(self, state: State, **kwargs) -> bool: """ Initialize and run AWQ on the given state @@ -155,7 +153,6 @@ def on_initialize(self, state: State, **kwargs) -> bool: self.scales_ = {} calibration_dataloader = state.data.calib - self.hooks_ = [] self._get_module_kwargs(state.model, calibration_dataloader) self._setup_scale_hooks() @@ -179,7 +176,7 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True - def _resolve_mappings(self, 
model: Module) -> List: + def _resolve_mappings(self, model: Module) -> List[AWQMapping]: """ Transforms the list of activations to smooth and their corresponding weights into AWQMapping objects, resolving regular expressions. @@ -252,7 +249,7 @@ def hook_fn(module, inp, out): # is enough, as other balance layers # get the same input layer = mapping.balance_layers[0] - self.hooks_.append(layer.register_forward_hook(create_hook_fn(name))) + self.register_hook(layer, create_hook_fn(name), "forward") @torch.no_grad() def _calibrate(self, model: Module, calibration_dataloader: List): @@ -271,17 +268,16 @@ def _calibrate(self, model: Module, calibration_dataloader: List): " CompressionSession to run the AWQ modifier" ) - run_calibration_forward( - model, - calibration_dataloader, - self.num_calibration_steps, - self.calibration_function, - ) + with calibration_forward_context(model): + run_calibration_forward( + model, + calibration_dataloader, + self.num_calibration_steps, + self.calibration_function, + ) # remove the hooks now that we are done calibrating - for hook in self.hooks_: - hook.remove() - del self.hooks_ + self.remove_hooks() def _concat_collected_activations(self): """ @@ -370,6 +366,13 @@ def _apply_smoothing(self, model: Module): @torch.no_grad() def smooth(module): + # TODO calls to module._hf_hook.pre_forward(module) and + # module._hf_hook.post_forward(module, None) appear a couple places + # in SmoothQuantModifier, do we need them anywhere else? + offloaded = is_module_offloaded(module) + if offloaded: + module._hf_hook.pre_forward(module) + if module in balance_layers: module.weight.mul_(scales.view(1, -1).to(module.weight.device)) elif module == smooth_layer: @@ -380,6 +383,9 @@ def smooth(module): if hasattr(module, "bias") and module.bias is not None: module.bias.div_(scales.to(module.bias.device)) + if offloaded: + module._hf_hook.post_forward(module, None) + parent = get_fsdp_parent(mapping.smooth_name, model) if parent is not None: parent.apply(smooth) @@ -681,7 +687,7 @@ def _compute_best_clip( best_max_val = torch.cat(best_max_val_all, dim=0) - #TODO this appears unneeded, clear_memory removed + # TODO this appears unneeded, clear_memory removed # clear_memory(input_feat) # clear_memory(org_out) diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index aa3317198..71b9bd9f6 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -105,7 +105,7 @@ class SmoothQuantModifier(Modifier): num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None - resolved_mappings_: Optional[List] = None + resolved_mappings_: Optional[List[SmoothQuantMapping]] = None scales_: Optional[Dict] = None def on_initialize(self, state: State, **kwargs) -> bool: @@ -166,7 +166,7 @@ def _infer_mappings_from_model( ) @handle_mapping_resolution_errors - def _resolve_mappings(self, model: Module) -> List: + def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]: """ Transforms the list of activations to smooth and their corresponding weights into SmoothQuantMapping objects, resolving regular expressions. 
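The bookkeeping this patch moves onto `HooksMixin` reduces to the standard PyTorch forward-hook pattern; a minimal self-contained sketch (illustrative names only, not the `HooksMixin` API) is:

```python
# Sketch of the activation-capture pattern: a forward hook on the first balance
# layer records the inputs that later drive the smoothing-scale search.
import torch

captured = {}

def make_hook(name):
    def hook(module, inputs, output):
        # inputs is a tuple; keep a detached CPU copy of the first positional arg
        captured.setdefault(name, []).append(inputs[0].detach().cpu())
    return hook

layer = torch.nn.Linear(16, 16)
handle = layer.register_forward_hook(make_hook("balance_0"))

with torch.no_grad():
    layer(torch.randn(4, 16))

handle.remove()  # analogous to self.remove_hooks() once calibration is done
print(captured["balance_0"][0].shape)  # torch.Size([4, 16])
```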
diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index eacd11350..3de4f82e4 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -3,7 +3,6 @@ """ import functools -import gc import inspect import os import random @@ -723,4 +722,3 @@ def pseudo_dequantize_linear( w = w.weight.data * scales return w - From f341dc07156a1eac5cfd98eccbe14ac7df15f901 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 18 Feb 2025 22:56:14 +0000 Subject: [PATCH 04/40] pydantic serialization issue fix Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 4 ++++ src/llmcompressor/modifiers/smoothquant/base.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 9a9155e7b..a0608b573 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -4,6 +4,7 @@ import torch from compressed_tensors.utils.offload import is_module_offloaded from loguru import logger +from pydantic import ConfigDict from torch.nn import Module from tqdm import tqdm @@ -115,6 +116,9 @@ class AWQModifier(Modifier): :param apply_clip: whether to apply clipping to the weights after scaling """ + # Allow arbitrary types because AWQMapping has field of type torch.nn.Module + model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) + mappings: List[Tuple] = DEFAULT_AWQ_MAPPINGS ignore: Optional[List[str]] = None num_calibration_steps: Optional[int] = None diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index 71b9bd9f6..845798f07 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -4,6 +4,7 @@ import torch from compressed_tensors.utils.offload import is_module_offloaded from loguru import logger +from pydantic import ConfigDict from torch.nn import Module from llmcompressor.core import State @@ -99,6 +100,9 @@ class SmoothQuantModifier(Modifier): to use the default tensor_module_forward """ + # Allow arbitrary types because AWQMapping has field of type torch.nn.Module + model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) + smoothing_strength: float = 0.5 mappings: Optional[List[Union[Tuple, List]]] = None ignore: Optional[List[str]] = None From ee767524cf5c1e3c30bdf3c7682044dfc596d961 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 19 Feb 2025 18:00:07 +0000 Subject: [PATCH 05/40] switch to accelerate with align_module_device Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 68 +++++++++---------- .../modifiers/smoothquant/base.py | 40 ++++------- 2 files changed, 45 insertions(+), 63 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index a0608b573..ebd5b0936 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from compressed_tensors.utils.offload import is_module_offloaded +from accelerate.utils import align_module_device from loguru import logger from pydantic import ConfigDict from torch.nn import Module @@ -318,7 +318,7 @@ def _apply_smoothing(self, model: Module): # [STEP 1]: Compute per-channel mean of normalised weights # All layer weights are concatted together - weight = 
torch.cat([_m.weight for _m in balance_layers], dim=0) + weight = torch.cat([bl.weight for bl in balance_layers], dim=0) org_shape = weight.shape # The weights are reshaped to be organised by quantization group weight = weight.view(-1, self.group_size) @@ -373,22 +373,18 @@ def smooth(module): # TODO calls to module._hf_hook.pre_forward(module) and # module._hf_hook.post_forward(module, None) appear a couple places # in SmoothQuantModifier, do we need them anywhere else? - offloaded = is_module_offloaded(module) - if offloaded: - module._hf_hook.pre_forward(module) - - if module in balance_layers: - module.weight.mul_(scales.view(1, -1).to(module.weight.device)) - elif module == smooth_layer: - if module.weight.ndim == 1: - module.weight.div_(scales.to(module.weight.device)) - else: - module.weight.div_(scales.view(-1, 1).to(module.weight.device)) - if hasattr(module, "bias") and module.bias is not None: - module.bias.div_(scales.to(module.bias.device)) - - if offloaded: - module._hf_hook.post_forward(module, None) + with align_module_device(module): + if module in balance_layers: + module.weight.mul_(scales.view(1, -1).to(module.weight.device)) + elif module == smooth_layer: + if module.weight.ndim == 1: + module.weight.div_(scales.to(module.weight.device)) + else: + module.weight.div_( + scales.view(-1, 1).to(module.weight.device) + ) + if hasattr(module, "bias") and module.bias is not None: + module.bias.div_(scales.to(module.bias.device)) parent = get_fsdp_parent(mapping.smooth_name, model) if parent is not None: @@ -461,16 +457,17 @@ def _compute_best_scale( # Q(W * s) for fc in linears2scale: - fc.weight.mul_(scales_view) - fc.weight.data = ( - pseudo_quantize_tensor( - w=fc.weight.data, - symmetric=self.symmetric, - bit_width=self.bits, - group_size=self.group_size, - )[0] - / scales_view - ) + with align_module_device(fc): + fc.weight.mul_(scales_view) + fc.weight.data = ( + pseudo_quantize_tensor( + w=fc.weight.data, + symmetric=self.symmetric, + bit_width=self.bits, + group_size=self.group_size, + )[0] + / scales_view + ) # W * X int_w_output = self._forward_input_with_kwargs( @@ -691,10 +688,6 @@ def _compute_best_clip( best_max_val = torch.cat(best_max_val_all, dim=0) - # TODO this appears unneeded, clear_memory removed - # clear_memory(input_feat) - # clear_memory(org_out) - return best_max_val.squeeze(1) @@ -711,8 +704,9 @@ def _apply_clip(module, clip_list: Tuple[str, torch.Tensor]): for name, max_val in clip_list: _, layer = get_layer(target=name, module=module) assert isinstance(layer, torch.nn.Linear) - max_val = max_val.to(layer.weight.device) - org_shape = layer.weight.shape - layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1) - layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val) - layer.weight.data = layer.weight.data.reshape(org_shape) + with align_module_device(layer): + max_val = max_val.to(layer.weight.device) + org_shape = layer.weight.shape + layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1) + layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val) + layer.weight.data = layer.weight.data.reshape(org_shape) diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index 845798f07..1b1e0aee6 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -2,7 +2,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import torch -from compressed_tensors.utils.offload import 
is_module_offloaded +from accelerate.utils import align_module_device from loguru import logger from pydantic import ConfigDict from torch.nn import Module @@ -293,22 +293,16 @@ def _apply_smoothing(self, model: Module): @torch.no_grad() def smooth(module): - offloaded = is_module_offloaded(module) - if offloaded: - module._hf_hook.pre_forward(module) - - if module in balance_layers: - module.weight.mul_(scales.view(1, -1)) - elif module == smooth_layer: - if module.weight.ndim == 1: - module.weight.div_(scales) - else: - module.weight.div_(scales.view(-1, 1)) - if hasattr(module, "bias") and module.bias is not None: - module.bias.div_(scales) - - if offloaded: - module._hf_hook.post_forward(module, None) + with align_module_device(module): + if module in balance_layers: + module.weight.mul_(scales.view(1, -1)) + elif module == smooth_layer: + if module.weight.ndim == 1: + module.weight.div_(scales) + else: + module.weight.div_(scales.view(-1, 1)) + if hasattr(module, "bias") and module.bias is not None: + module.bias.div_(scales) parent = get_fsdp_parent(mapping.smooth_name, model) if parent is not None: @@ -333,15 +327,9 @@ def _calculate_smoothing_scales( # get the channel-wise dynamic range for each layer to be balanced weight_scales = [] for layer in balance_layers: - offloaded = is_module_offloaded(layer) - if offloaded: - layer._hf_hook.pre_forward(layer) - - scale = layer.weight.abs().max(dim=0, keepdim=True)[0] - weight_scales.append(scale) - - if offloaded: - layer._hf_hook.post_forward(layer, None) + with align_module_device(layer): + scale = layer.weight.abs().max(dim=0, keepdim=True)[0] + weight_scales.append(scale) weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0] From 9e415f28693c524bc9846098a3253b643d649484 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 19 Feb 2025 23:11:09 +0000 Subject: [PATCH 06/40] AWQ running but OOMs unless NUM_CALIBRATION_SAMPLES and MAX_SEQUENCE_LENGTH are very low Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 6 +++--- src/llmcompressor/transformers/finetune/data/__init__.py | 1 - src/llmcompressor/transformers/finetune/data/pile.py | 8 +++++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index ebd5b0936..6675291bc 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -158,7 +158,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: calibration_dataloader = state.data.calib - self._get_module_kwargs(state.model, calibration_dataloader) + self._set_module_kwargs(state.model, calibration_dataloader) self._setup_scale_hooks() self._calibrate(state.model, calibration_dataloader) self._concat_collected_activations() @@ -530,7 +530,7 @@ def _compute_loss( return loss - def _get_module_kwargs(self, model, dataloader): + def _set_module_kwargs(self, model, dataloader) -> None: _, modules = next(iter(get_layers("re:.*layers", model).items())) samples = [batch["input_ids"] for batch in dataloader] @@ -575,7 +575,7 @@ def forward(self, *args, **kwargs): # Update the layer kwargs with `prepare_inputs_for_generation` method # that takes care of everything to avoid unexpected errors. - layer_kwargs = model.prepare_inputs_for_generation(samples, **layer_kwargs) + layer_kwargs |= model.prepare_inputs_for_generation(samples, **layer_kwargs) # Pop the input_ids as they are not needed at all. 
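A toy numeric illustration of the channel-wise dynamic range computed in `_calculate_smoothing_scales` above (made-up 2x2 weights, not patch code):

```python
import torch

w1 = torch.tensor([[1.0, -4.0], [2.0, 0.5]])   # balance layer 1
w2 = torch.tensor([[-3.0, 1.0], [0.2, -2.0]])  # balance layer 2

# per-layer |W| max over the output dimension, then combined across layers
scales = [w.abs().max(dim=0, keepdim=True)[0] for w in (w1, w2)]
combined = 2.0 * torch.cat(scales, dim=0).max(dim=0)[0]
print(combined)  # tensor([6., 8.]) -> per-input-channel weight range
```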
layer_kwargs.pop("input_ids") diff --git a/src/llmcompressor/transformers/finetune/data/__init__.py b/src/llmcompressor/transformers/finetune/data/__init__.py index 655a85df6..3754908fa 100644 --- a/src/llmcompressor/transformers/finetune/data/__init__.py +++ b/src/llmcompressor/transformers/finetune/data/__init__.py @@ -9,7 +9,6 @@ from .gsm8k import GSM8KDataset from .open_platypus import OpenPlatypusDataset from .peoples_speech import PeoplesSpeech -from .pile import PileEvalDataset from .ptb import PtbDataset from .ultrachat_200k import UltraChatDataset from .wikitext import WikiTextDataset diff --git a/src/llmcompressor/transformers/finetune/data/pile.py b/src/llmcompressor/transformers/finetune/data/pile.py index 4eef5f7eb..ccdb92056 100644 --- a/src/llmcompressor/transformers/finetune/data/pile.py +++ b/src/llmcompressor/transformers/finetune/data/pile.py @@ -1,6 +1,8 @@ from copy import deepcopy from typing import TYPE_CHECKING +from loguru import logger + from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.typing import Processor @@ -8,10 +10,10 @@ from llmcompressor.args import DatasetArguments -@TextGenerationDataset.register(name="pile_eval") -class PileEvalDataset(TextGenerationDataset): +@TextGenerationDataset.register(name="mit-han-lab/pile-val-backup", alias="pile_val") +class PileValDataset(TextGenerationDataset): """ - Child text generation class for the PileEval dataset + Child text generation class for "The Pile" dataset :param data_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param tokenizer: tokenizer to use on dataset From db767b74a7343518884e228105c7e686f7bb68c4 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 20 Feb 2025 17:15:54 +0000 Subject: [PATCH 07/40] working with larger num_calibration_samples Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 46 ++++++++++++++----------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 6675291bc..5e985559b 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -from accelerate.utils import align_module_device +from compressed_tensors.utils import align_module_device, update_offload_parameter from loguru import logger from pydantic import ConfigDict from torch.nn import Module @@ -158,11 +158,14 @@ def on_initialize(self, state: State, **kwargs) -> bool: calibration_dataloader = state.data.calib - self._set_module_kwargs(state.model, calibration_dataloader) - self._setup_scale_hooks() - self._calibrate(state.model, calibration_dataloader) - self._concat_collected_activations() - self._apply_smoothing(state.model) + # TODO is it ok to wrap the whole model in this context? 
+ # I don't think we ever want gradients or to use kv cache + with calibration_forward_context(state.model): + self._set_module_kwargs(state.model, calibration_dataloader) + self._setup_scale_hooks() + self._calibrate(state.model, calibration_dataloader) + self._concat_collected_activations() + self._apply_smoothing(state.model) return True @@ -272,13 +275,13 @@ def _calibrate(self, model: Module, calibration_dataloader: List): " CompressionSession to run the AWQ modifier" ) - with calibration_forward_context(model): - run_calibration_forward( - model, - calibration_dataloader, - self.num_calibration_steps, - self.calibration_function, - ) + # with calibration_forward_context(model): + run_calibration_forward( + model, + calibration_dataloader, + self.num_calibration_steps, + self.calibration_function, + ) # remove the hooks now that we are done calibrating self.remove_hooks() @@ -356,10 +359,9 @@ def _apply_smoothing(self, model: Module): x_mean = (x_sum / num_elements).to(inp.dtype) # [STEP 3]: Compute output of module - with torch.no_grad(): - fp16_output = self._forward_input_with_kwargs( - module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_ - ) + fp16_output = self._forward_input_with_kwargs( + module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_ + ) # [STEP 4]: Compute loss best_scales = self._compute_best_scale( @@ -459,14 +461,16 @@ def _compute_best_scale( for fc in linears2scale: with align_module_device(fc): fc.weight.mul_(scales_view) - fc.weight.data = ( + update_offload_parameter( + fc, + "weight", pseudo_quantize_tensor( w=fc.weight.data, symmetric=self.symmetric, bit_width=self.bits, group_size=self.group_size, )[0] - / scales_view + / scales_view, ) # W * X @@ -488,7 +492,9 @@ def _compute_best_scale( logger.debug(history) raise Exception - assert torch.isnan(best_scales).sum() == 0, best_scales + assert ( + torch.isnan(best_scales).sum() == 0 + ), f"Nan found in scales: {best_scales}" return best_scales.detach().cpu() From 15a0b16e03a92db0a473e4836793a0c8ab5b3f83 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 20 Feb 2025 18:37:18 +0000 Subject: [PATCH 08/40] fix pile dataset issue Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 5 ++++- src/llmcompressor/transformers/finetune/data/pile.py | 8 +------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 5e985559b..160f2106a 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -25,7 +25,10 @@ ) DEFAULT_AWQ_MAPPINGS = [ - [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"], + [ + ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj", "re:.*o_proj"], + "re:.*input_layernorm", + ], [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"], [["re:.*down_proj"], "re:.*up_proj"], ] diff --git a/src/llmcompressor/transformers/finetune/data/pile.py b/src/llmcompressor/transformers/finetune/data/pile.py index ccdb92056..f420ba2a5 100644 --- a/src/llmcompressor/transformers/finetune/data/pile.py +++ b/src/llmcompressor/transformers/finetune/data/pile.py @@ -1,8 +1,6 @@ from copy import deepcopy from typing import TYPE_CHECKING -from loguru import logger - from llmcompressor.transformers.finetune.data import TextGenerationDataset from llmcompressor.typing import Processor @@ -26,8 +24,4 @@ def __init__(self, data_args: "DatasetArguments", split: str, processor: Process 
super().__init__(data_args=data_args, split=split, processor=processor) def dataset_template(self, sample): - return { - "text": self.processor.apply_chat_template( - sample["text"].strip(), - ), - } + return {"text": sample["text"].strip()} From 91ad7fc72def1c13e2a5683672301d7e11b9e79c Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 24 Feb 2025 15:02:07 -0500 Subject: [PATCH 09/40] updated config dataclasses Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 94 ++++++++++++++++--------- 1 file changed, 60 insertions(+), 34 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 160f2106a..73a6abb51 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -24,15 +24,6 @@ get_parent_by_name, ) -DEFAULT_AWQ_MAPPINGS = [ - [ - ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj", "re:.*o_proj"], - "re:.*input_layernorm", - ], - [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"], - [["re:.*down_proj"], "re:.*up_proj"], -] - __all__ = ["AWQScale", "AWQMapping", "AWQModifier"] @@ -48,8 +39,42 @@ class AWQScale: @dataclass class AWQMapping: """ - Dataclass for storing the mapping between an activation layer and the following - weights that must be balanced during smoothing + Dataclass storing config of activation mappings to smooth + The output activations of smooth_layer are input activations + into the balance_layers + + `AWQMapping`s are resolved into `ResolvedMapping`s, which + retain pointers to the actual `torch.nn.Module`s and additional + metadata at runtime + """ + + smooth_layer: str + balance_layers: list[str] + + +DEFAULT_AWQ_MAPPINGS: list[AWQMapping] = [ + AWQMapping( + "re:.*input_layernorm", + ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], + ), + AWQMapping( + "re:.*post_attention_layernorm", + ["re:.*gate_proj", "re:.*up_proj"], + ), + AWQMapping( + "re:.*up_proj", + ["re:.*down_proj"], + ), + # TODO check with this uncommented + # AWQMapping("re:.*v_proj", ["re:.*o_proj"]), +] + + +@dataclass +class ResolvedMapping: + """ + Dataclass for storing the resolved mappings between an activation layer + and the following weights that must be balanced during smoothing :param smooth_name: name of the activation layer :param smooth_layer: PyTorch module storing the activation layer @@ -89,9 +114,11 @@ class AWQModifier(Modifier): ```yaml AWQModifier: bits: 4 - mappings: [ - [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"], - [["re:.*fc1"], "re:.*final_layer_norm"] + mappings: + - smooth_layer: "re:.*self_attn_layer_norm" + balance_layers: ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"] + - smooth_layer: "re:.*final_layer_norm" + balance_layers: ["re:.*fc1"] ] ignore: ["model.decoder.final_layer_norm"] ``` @@ -119,10 +146,10 @@ class AWQModifier(Modifier): :param apply_clip: whether to apply clipping to the weights after scaling """ - # Allow arbitrary types because AWQMapping has field of type torch.nn.Module + # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) - mappings: List[Tuple] = DEFAULT_AWQ_MAPPINGS + mappings: List[AWQMapping] = DEFAULT_AWQ_MAPPINGS ignore: Optional[List[str]] = None num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None @@ -133,7 +160,7 @@ class AWQModifier(Modifier): duo_scaling: bool = True apply_clip: bool = True - resolved_mappings_: 
Optional[List[AWQMapping]] = None + resolved_mappings_: Optional[List[ResolvedMapping]] = None scales_: Optional[Dict] = None module_kwargs_: Optional[Dict] = None @@ -156,13 +183,11 @@ def on_initialize(self, state: State, **kwargs) -> bool: ) self.ignore = [] if not self.ignore else self.ignore - self.resolved_mappings_ = self._resolve_mappings(state.model) + self.resolved_mappings_ = self._get_resolved_mappings(state.model) self.scales_ = {} calibration_dataloader = state.data.calib - # TODO is it ok to wrap the whole model in this context? - # I don't think we ever want gradients or to use kv cache with calibration_forward_context(state.model): self._set_module_kwargs(state.model, calibration_dataloader) self._setup_scale_hooks() @@ -186,10 +211,10 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True - def _resolve_mappings(self, model: Module) -> List[AWQMapping]: + def _get_resolved_mappings(self, model: Module) -> List[ResolvedMapping]: """ Transforms the list of activations to smooth and their corresponding weights - into AWQMapping objects, resolving regular expressions. + into ResolvedMapping objects, resolving regular expressions. For each activation in the mapping list, we find the corresponding weight to balance by searching for the longest substring. For instance, if our balance @@ -197,13 +222,13 @@ def _resolve_mappings(self, model: Module) -> List[AWQMapping]: would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and repeat for model.layer.1 and so on """ - resolved_mappings = [] - for to_balance, to_smooth in self.mappings: - to_smooth_layers = get_layers(to_smooth, model) + resolved_mappings: list[ResolvedMapping] = [] + for mapping in self.mappings: + to_smooth_layers = get_layers(mapping.smooth_layer, model) for layer_name, smooth_layer in to_smooth_layers.items(): if layer_name not in self.ignore: balance_layers, balance_names = [], [] - for balance_suffix in to_balance: + for balance_suffix in mapping.balance_layers: # find the submodule that matches the activation layer balance_name, balance_layer = get_matching_layer( balance_suffix, layer_name, model @@ -224,15 +249,16 @@ def _resolve_mappings(self, model: Module) -> List[AWQMapping]: parent_name, parent = get_parent_by_name( layer_name=balance_name, model=model ) - mapping = AWQMapping( - layer_name, - smooth_layer, - balance_layers, - balance_names=balance_names, - parent=parent, - parent_name=parent_name, + resolved_mappings.append( + ResolvedMapping( + layer_name, + smooth_layer, + balance_layers, + balance_names=balance_names, + parent=parent, + parent_name=parent_name, + ) ) - resolved_mappings.append(mapping) return resolved_mappings def _setup_scale_hooks(self): From c1c6a6ced9143f3e2fde5b47809a8af689ed1414 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 25 Feb 2025 00:24:09 +0000 Subject: [PATCH 10/40] OOM error resolved Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 73a6abb51..0dbcbc811 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -65,7 +65,7 @@ class AWQMapping: "re:.*up_proj", ["re:.*down_proj"], ), - # TODO check with this uncommented + # TODO this generally results in higher perplexity for llama 2 7B on wikitext # AWQMapping("re:.*v_proj", ["re:.*o_proj"]), ] @@ -269,8 +269,7 @@ def _setup_scale_hooks(self): def 
create_hook_fn(layer_name): def hook_fn(module, inp, out): - inp = inp[0] - inp.cpu().detach() + inp = inp[0].cpu().detach() if layer_name in self.scales_: self.scales_[layer_name].inps.append(inp) @@ -365,8 +364,8 @@ def _apply_smoothing(self, model: Module): # [STEP 2]: Compute per-channel mean of the input activation with chunking # move inp to cpu to avoid memory leak - inp = activations - inp_flat = inp.cpu().abs().view(-1, inp.shape[-1]) + inp = activations.to(weight.device) + inp_flat = activations.cpu().abs().view(-1, inp.shape[-1]) num_elements = inp_flat.size(0) num_channels = inp_flat.size(1) element_size_bytes = inp_flat.element_size() * 2 # multiplied by 2 for FP32 From eb320546fef09b54345ba7308b27c5226a1be93a Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 25 Feb 2025 20:19:50 +0000 Subject: [PATCH 11/40] codereview updates Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 42 +++++++------------------ 1 file changed, 11 insertions(+), 31 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 0dbcbc811..f71feff6f 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -4,7 +4,7 @@ import torch from compressed_tensors.utils import align_module_device, update_offload_parameter from loguru import logger -from pydantic import ConfigDict +from pydantic import ConfigDict, Field from torch.nn import Module from tqdm import tqdm @@ -24,16 +24,7 @@ get_parent_by_name, ) -__all__ = ["AWQScale", "AWQMapping", "AWQModifier"] - - -@dataclass -class AWQScale: - """ - Dataclass for storing the input activations of a layer to be smoothed - """ - - inps: Union[List[torch.Tensor], torch.Tensor] +__all__ = ["AWQMapping", "AWQModifier"] @dataclass @@ -161,8 +152,8 @@ class AWQModifier(Modifier): apply_clip: bool = True resolved_mappings_: Optional[List[ResolvedMapping]] = None - scales_: Optional[Dict] = None - module_kwargs_: Optional[Dict] = None + scales_: Dict[str, torch.Tensor | List[torch.Tensor]] = Field(default_factory=dict) + module_kwargs_: Dict = Field(default_factory=dict) def on_initialize(self, state: State, **kwargs) -> bool: """ @@ -171,20 +162,9 @@ def on_initialize(self, state: State, **kwargs) -> bool: :param state: state to run AWQ on :return: True on a successful run, False otherwise """ - if not (self.end is None or self.end == -1): - raise ValueError( - f"{self.__class__.__name__} can only be applied during one-shot. " - f" Expected end to be None or -1, got {self.end}" - ) - if self.start and self.start != -1: - raise ValueError( - f"{self.__class__.__name__} can only be applied during one-shot. 
" - f"Expected start to be None or -1, got {self.end}" - ) - + self.ignore = [] if not self.ignore else self.ignore self.resolved_mappings_ = self._get_resolved_mappings(state.model) - self.scales_ = {} calibration_dataloader = state.data.calib @@ -272,9 +252,9 @@ def hook_fn(module, inp, out): inp = inp[0].cpu().detach() if layer_name in self.scales_: - self.scales_[layer_name].inps.append(inp) + self.scales_[layer_name].append(inp) else: - self.scales_[layer_name] = AWQScale(inps=[inp]) + self.scales_[layer_name] = [inp] return hook_fn @@ -324,7 +304,7 @@ def _concat_collected_activations(self): """ for mapping in self.resolved_mappings_: name = mapping.smooth_name - self.scales_[name].inps = torch.cat(self.scales_[name].inps, dim=0) + self.scales_[name] = torch.cat(self.scales_[name], dim=0) torch.cuda.empty_cache() @@ -343,7 +323,7 @@ def _apply_smoothing(self, model: Module): balance_layers = mapping.balance_layers balance_names = mapping.balance_names - activations = self.scales_[mapping.smooth_name].inps + activations = self.scales_[mapping.smooth_name] module2inspect = mapping.parent @@ -445,7 +425,7 @@ def _compute_best_scale( module2inspect: torch.nn.Module, linears2scale: List[torch.nn.Linear], fp16_output: torch.Tensor, - ): + ) -> torch.Tensor: """ Compute loss and select best scales @@ -639,7 +619,7 @@ def _forward_input_with_kwargs( :param input_kwargs: additional arguments to pass to the module :return: the first output tensor from the forward pass """ - kwargs = input_kwargs or self.module_kwargs_ or {} + kwargs = input_kwargs or self.module_kwargs_ return tensor_forward_with_input_args( module=module, inputs=inputs, From c7be2776fe3da539a1b7790665d75306040f6068 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 25 Feb 2025 23:02:26 +0000 Subject: [PATCH 12/40] minor touchups Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 46 ++++++++++++++++------ src/llmcompressor/pytorch/utils/helpers.py | 2 +- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index f71feff6f..e773b191d 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -1,3 +1,4 @@ +import inspect from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -57,7 +58,7 @@ class AWQMapping: ["re:.*down_proj"], ), # TODO this generally results in higher perplexity for llama 2 7B on wikitext - # AWQMapping("re:.*v_proj", ["re:.*o_proj"]), + AWQMapping("re:.*v_proj", ["re:.*o_proj"]), ] @@ -141,7 +142,7 @@ class AWQModifier(Modifier): model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) mappings: List[AWQMapping] = DEFAULT_AWQ_MAPPINGS - ignore: Optional[List[str]] = None + ignore: List[str] = [] num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None group_size: int = 128 @@ -151,9 +152,9 @@ class AWQModifier(Modifier): duo_scaling: bool = True apply_clip: bool = True - resolved_mappings_: Optional[List[ResolvedMapping]] = None - scales_: Dict[str, torch.Tensor | List[torch.Tensor]] = Field(default_factory=dict) - module_kwargs_: Dict = Field(default_factory=dict) + resolved_mappings_: List[ResolvedMapping] = [] + scales_: Dict[str, torch.Tensor | List[torch.Tensor]] = {} + module_kwargs_: Dict = {} def on_initialize(self, state: State, **kwargs) -> bool: """ @@ -162,8 +163,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: :param 
state: state to run AWQ on :return: True on a successful run, False otherwise """ - - self.ignore = [] if not self.ignore else self.ignore + self.resolved_mappings_ = self._get_resolved_mappings(state.model) calibration_dataloader = state.data.calib @@ -368,7 +368,12 @@ def _apply_smoothing(self, model: Module): # [STEP 3]: Compute output of module fp16_output = self._forward_input_with_kwargs( - module=module2inspect, inputs=inp, input_kwargs=self.module_kwargs_ + module=module2inspect, + inputs=inp, + input_kwargs=self._sanitize_kwargs(self.module_kwargs_, module2inspect), + ) + fp16_output = fp16_output.clip( + torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max ) # [STEP 4]: Compute loss @@ -380,9 +385,6 @@ def _apply_smoothing(self, model: Module): @torch.no_grad() def smooth(module): - # TODO calls to module._hf_hook.pre_forward(module) and - # module._hf_hook.post_forward(module, None) appear a couple places - # in SmoothQuantModifier, do we need them anywhere else? with align_module_device(module): if module in balance_layers: module.weight.mul_(scales.view(1, -1).to(module.weight.device)) @@ -589,7 +591,7 @@ def forward(self, *args, **kwargs): # Update the layer kwargs with `prepare_inputs_for_generation` method # that takes care of everything to avoid unexpected errors. - layer_kwargs |= model.prepare_inputs_for_generation(samples, **layer_kwargs) + layer_kwargs = model.prepare_inputs_for_generation(samples, **layer_kwargs) # Pop the input_ids as they are not needed at all. layer_kwargs.pop("input_ids") @@ -620,6 +622,7 @@ def _forward_input_with_kwargs( :return: the first output tensor from the forward pass """ kwargs = input_kwargs or self.module_kwargs_ + kwargs = self._sanitize_kwargs(kwargs, module) return tensor_forward_with_input_args( module=module, inputs=inputs, @@ -704,6 +707,25 @@ def _compute_best_clip( return best_max_val.squeeze(1) + def _sanitize_kwargs(self, inputs_kwargs, module): + """ + Remove the arguments that are not supported in the module's + forward pass to avoid breaking behaviour between different versions + of transformers. + + Args: + inputs_kwargs (`dict`): + The input dictionary to pass to the model layer + module (`torch.nn.Module`): + Target module to quantize. 
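The filtering implemented right below is the generic signature-inspection pattern; a standalone illustration with a toy module and made-up kwargs:

```python
import inspect
import torch

linear = torch.nn.Linear(4, 4)
candidate_kwargs = {"input": torch.randn(1, 4), "position_ids": None}

# keep only kwargs that the target forward() actually accepts
accepted = inspect.signature(linear.forward).parameters
filtered = {k: v for k, v in candidate_kwargs.items() if k in accepted}
print(sorted(filtered))  # ['input'] -- unsupported kwargs are silently dropped
```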
+ """ + module_signature = inspect.signature(module.forward).parameters + sanitized_kwargs = {} + for k, v in inputs_kwargs.items(): + if k in module_signature and k != "use_cache": + sanitized_kwargs[k] = v + return sanitized_kwargs + @torch.no_grad() def _apply_clip(module, clip_list: Tuple[str, torch.Tensor]): diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 3de4f82e4..f2678c366 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -675,7 +675,7 @@ def pseudo_quantize_tensor( assert w.dim() == 2 assert torch.isnan(w).sum() == 0 - if not symmetric: + if symmetric: max_val = w.amax(dim=1, keepdim=True) min_val = w.amin(dim=1, keepdim=True) max_int = 2**bit_width - 1 From ab32f21e68fce6a3bbe2c4fc17fde1ff1c4157ad Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 3 Mar 2025 21:56:00 +0000 Subject: [PATCH 13/40] updates from debugging Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 11 ++-- src/llmcompressor/observers/__init__.py | 1 + src/llmcompressor/observers/rtn.py | 58 ++++++++++++++++++++++ src/llmcompressor/pytorch/utils/helpers.py | 7 +-- 4 files changed, 69 insertions(+), 8 deletions(-) create mode 100644 src/llmcompressor/observers/rtn.py diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index e773b191d..796ad65cf 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -49,6 +49,8 @@ class AWQMapping: "re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], ), + # TODO this generally results in higher perplexity for llama 2 7B on wikitext + AWQMapping("re:.*v_proj", ["re:.*o_proj"]), AWQMapping( "re:.*post_attention_layernorm", ["re:.*gate_proj", "re:.*up_proj"], @@ -57,8 +59,6 @@ class AWQMapping: "re:.*up_proj", ["re:.*down_proj"], ), - # TODO this generally results in higher perplexity for llama 2 7B on wikitext - AWQMapping("re:.*v_proj", ["re:.*o_proj"]), ] @@ -148,7 +148,7 @@ class AWQModifier(Modifier): group_size: int = 128 max_chunk_memory: int = 1024 * 1024 * 1024 bits: int = 4 - symmetric: bool = True + symmetric: bool = False duo_scaling: bool = True apply_clip: bool = True @@ -487,6 +487,9 @@ def _compute_best_scale( int_w_output = self._forward_input_with_kwargs( module=module2inspect, inputs=x, input_kwargs=self.module_kwargs_ ) + int_w_output = int_w_output.clip( + torch.finfo(int_w_output.dtype).min, torch.finfo(int_w_output.dtype).max + ) # compute mean squared error (L2 norm) loss = self._compute_loss(fp16_output, int_w_output, device) @@ -598,8 +601,6 @@ def forward(self, *args, **kwargs): del samples inps = inps[0] - torch.cuda.empty_cache() - if layer_kwargs.get("attention_mask") is not None: layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to( best_device diff --git a/src/llmcompressor/observers/__init__.py b/src/llmcompressor/observers/__init__.py index 4c3ee5a88..e16d9d93b 100644 --- a/src/llmcompressor/observers/__init__.py +++ b/src/llmcompressor/observers/__init__.py @@ -5,3 +5,4 @@ from .base import * from .min_max import * from .mse import * +from .rtn import * diff --git a/src/llmcompressor/observers/rtn.py b/src/llmcompressor/observers/rtn.py new file mode 100644 index 000000000..47b6846b6 --- /dev/null +++ b/src/llmcompressor/observers/rtn.py @@ -0,0 +1,58 @@ +from typing import Any, Optional, Tuple + +import torch +from compressed_tensors.quantization.quant_args import 
QuantizationArgs +from compressed_tensors.quantization.utils import calculate_qparams +from compressed_tensors.utils import deprecated + +from llmcompressor.observers.base import Observer +from llmcompressor.pytorch.utils import pseudo_quantize_tensor + +__all__ = ["RoundToNearestObserver"] + + +@Observer.register("rtn") +class RoundToNearestObserver(Observer): + """ + Implements a quantization observer that calculates scale and zero point based on the + minimum and maximum values of the tensor being observed. If averaging_constant is + specified, then the scales are updated using a moving average + """ + + def calculate_qparams( + self, + observed: torch.Tensor, + reduce_dims: Optional[Tuple[int]] = None, + tensor_id: Optional[Any] = None, + ) -> Tuple[torch.FloatTensor, torch.IntTensor]: + """ + Updates the observed min and max using a moving average smoothed by the + averaging_constant. Set the averaging_constant to 1.0 to disable averaging. + + :param observed: observed tensor to calculate quantization parameters for + :param reduce_dims: optional tuple of dimensions to reduce along, + returned scale and zero point will be shaped (1,) along the + reduced dimensions + :param tensor_id: Optional id if different ranges of observed tensors are + passed, useful for sharding tensors by group_size + :return: tuple of scale and zero point derived from the observed tensor + """ + + _, scales, zp = pseudo_quantize_tensor( + observed, + symmetric=self.quantization_args.symmetric, + bit_width=self.quantization_args.num_bits, + group_size=-1, #self.quantization_args.group_size, + ) + return (scales, zp) + + def get_qparams_along_dim( + self, observed: torch.Tensor, dim: int, tensor_id: Optional[Any] = None + ): + """ + Calculate quantization parameters along the specified dimension + """ + reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) + return self.calculate_qparams( + observed, reduce_dims=reduce_dims, tensor_id=tensor_id + ) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index f2678c366..973993585 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -670,12 +670,13 @@ def pseudo_quantize_tensor( ): org_w_shape = w.shape if group_size > 0: - assert org_w_shape[-1] % group_size == 0 + assert org_w_shape[-1] % group_size == 0, f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!" 
w = w.reshape(-1, group_size) assert w.dim() == 2 assert torch.isnan(w).sum() == 0 - if symmetric: + # zero point quantization + if not symmetric: max_val = w.amax(dim=1, keepdim=True) min_val = w.amin(dim=1, keepdim=True) max_int = 2**bit_width - 1 @@ -685,7 +686,7 @@ def pseudo_quantize_tensor( w = ( torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros ) * scales - zeros = zeros.view(org_w_shape[0], -1) + zeros = (zeros- 2**(bit_width-1)).view(org_w_shape[0], -1) else: max_val = w.abs().amax(dim=1, keepdim=True) max_val = max_val.clamp(min=1e-5) From ff857e5b13905586a2bce8bb0b0a31296a8f0e4e Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 4 Mar 2025 18:17:55 +0000 Subject: [PATCH 14/40] styling Signed-off-by: Brian Dellabetta --- src/llmcompressor/pytorch/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 973993585..97754ce92 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -686,7 +686,7 @@ def pseudo_quantize_tensor( w = ( torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros ) * scales - zeros = (zeros- 2**(bit_width-1)).view(org_w_shape[0], -1) + zeros = (zeros - 2**(bit_width-1)).view(org_w_shape[0], -1) else: max_val = w.abs().amax(dim=1, keepdim=True) max_val = max_val.clamp(min=1e-5) From 3e79d377e4e60828b547f6cb3e05a488fdefd39b Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 5 Mar 2025 17:30:50 +0000 Subject: [PATCH 15/40] slightly improved rtn calculate_qparams logic Signed-off-by: Brian Dellabetta --- src/llmcompressor/observers/rtn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/observers/rtn.py b/src/llmcompressor/observers/rtn.py index 47b6846b6..889b03318 100644 --- a/src/llmcompressor/observers/rtn.py +++ b/src/llmcompressor/observers/rtn.py @@ -42,7 +42,7 @@ def calculate_qparams( observed, symmetric=self.quantization_args.symmetric, bit_width=self.quantization_args.num_bits, - group_size=-1, #self.quantization_args.group_size, + group_size=self.quantization_args.group_size or -1, ) return (scales, zp) From 80767ab79e58fd7ebd06360b32b35bd82ca504ce Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 10 Mar 2025 19:37:17 +0000 Subject: [PATCH 16/40] code cleanup Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 212 ++++++------------ src/llmcompressor/observers/__init__.py | 1 - src/llmcompressor/observers/rtn.py | 58 ----- src/llmcompressor/pytorch/utils/helpers.py | 61 ----- .../llmcompressor/modifiers/awq/test_base.py | 2 +- 5 files changed, 71 insertions(+), 263 deletions(-) delete mode 100644 src/llmcompressor/observers/rtn.py diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 796ad65cf..bdf3f8628 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -13,7 +13,6 @@ from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward from llmcompressor.pytorch.utils import ( - pseudo_quantize_tensor, tensor_forward_with_input_args, ) from llmcompressor.utils.fsdp.helpers import get_fsdp_parent @@ -49,7 +48,7 @@ class AWQMapping: "re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], ), - # TODO this generally results in higher perplexity for llama 2 7B on wikitext + # TODO this should only be 
added if v_proj/o_proj shapes match up, should we check during validation and skip if this is not the case? AWQMapping("re:.*v_proj", ["re:.*o_proj"]), AWQMapping( "re:.*post_attention_layernorm", @@ -127,8 +126,6 @@ class AWQModifier(Modifier): smoothing (the second entry of the mappings list). :param num_calibration_steps: number of samples to use for calibration, or None to use the whole dataset - :param calibration_function: optional function to use for the forward pass, or None - to use the default tensor_module_forward :param group_size: number of weights to group together for scaling :param max_chunk_memory: maximum memory to use for each chunk of input activations :param bits: number of bits to quantize the weights to @@ -144,17 +141,15 @@ class AWQModifier(Modifier): mappings: List[AWQMapping] = DEFAULT_AWQ_MAPPINGS ignore: List[str] = [] num_calibration_steps: Optional[int] = None - calibration_function: Optional[Callable] = None group_size: int = 128 max_chunk_memory: int = 1024 * 1024 * 1024 bits: int = 4 symmetric: bool = False duo_scaling: bool = True - apply_clip: bool = True - resolved_mappings_: List[ResolvedMapping] = [] - scales_: Dict[str, torch.Tensor | List[torch.Tensor]] = {} - module_kwargs_: Dict = {} + _resolved_mappings: List[ResolvedMapping] = [] + _scales: Dict[str, torch.Tensor | List[torch.Tensor]] = {} + _module_kwargs: Dict = {} def on_initialize(self, state: State, **kwargs) -> bool: """ @@ -164,7 +159,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: :return: True on a successful run, False otherwise """ - self.resolved_mappings_ = self._get_resolved_mappings(state.model) + self._set_resolved_mappings(state.model) calibration_dataloader = state.data.calib @@ -184,17 +179,18 @@ def on_finalize(self, state: State, **kwargs) -> bool: :param state: unused :return: True """ - if self.scales_ is not None: - self.scales_.clear() - if self.resolved_mappings_ is not None: - self.resolved_mappings_.clear() + if self._scales is not None: + self._scales.clear() + if self._resolved_mappings is not None: + self._resolved_mappings.clear() return True - def _get_resolved_mappings(self, model: Module) -> List[ResolvedMapping]: + def _set_resolved_mappings(self, model: Module) -> None: """ Transforms the list of activations to smooth and their corresponding weights - into ResolvedMapping objects, resolving regular expressions. + into ResolvedMapping objects, resolving regular expressions. + Result is stored in _resolved_mappings. For each activation in the mapping list, we find the corresponding weight to balance by searching for the longest substring. 
For instance, if our balance @@ -239,7 +235,8 @@ def _get_resolved_mappings(self, model: Module) -> List[ResolvedMapping]: parent_name=parent_name, ) ) - return resolved_mappings + self._resolved_mappings = resolved_mappings + return def _setup_scale_hooks(self): """ @@ -251,14 +248,14 @@ def create_hook_fn(layer_name): def hook_fn(module, inp, out): inp = inp[0].cpu().detach() - if layer_name in self.scales_: - self.scales_[layer_name].append(inp) + if layer_name in self._scales: + self._scales[layer_name].append(inp) else: - self.scales_[layer_name] = [inp] + self._scales[layer_name] = [inp] return hook_fn - for mapping in self.resolved_mappings_: + for mapping in self._resolved_mappings: name = mapping.smooth_name # storing inps to first balance layer # is enough, as other balance layers @@ -288,7 +285,6 @@ def _calibrate(self, model: Module, calibration_dataloader: List): model, calibration_dataloader, self.num_calibration_steps, - self.calibration_function, ) # remove the hooks now that we are done calibrating @@ -299,12 +295,12 @@ def _concat_collected_activations(self): Concatenate the collected activation values from each forward pass into a single tensor for each layer - :postcondition: each layer in self.scales_ will have a single tensor containing + :postcondition: each layer in self._scales will have a single tensor containing all the activation values seen during calibration """ - for mapping in self.resolved_mappings_: + for mapping in self._resolved_mappings: name = mapping.smooth_name - self.scales_[name] = torch.cat(self.scales_[name], dim=0) + self._scales[name] = torch.cat(self._scales[name], dim=0) torch.cuda.empty_cache() @@ -318,12 +314,11 @@ def _apply_smoothing(self, model: Module): :param model: model to apply smoothing to """ logger.info("Smoothing activation scales...") - for mapping in tqdm(self.resolved_mappings_): + for mapping in tqdm(self._resolved_mappings): smooth_layer = mapping.smooth_layer balance_layers = mapping.balance_layers - balance_names = mapping.balance_names - activations = self.scales_[mapping.smooth_name] + activations = self._scales[mapping.smooth_name] module2inspect = mapping.parent @@ -370,7 +365,7 @@ def _apply_smoothing(self, model: Module): fp16_output = self._forward_input_with_kwargs( module=module2inspect, inputs=inp, - input_kwargs=self._sanitize_kwargs(self.module_kwargs_, module2inspect), + input_kwargs=self._sanitize_kwargs(self._module_kwargs, module2inspect), ) fp16_output = fp16_output.clip( torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max @@ -407,15 +402,6 @@ def smooth(module): smooth(layer) smooth(smooth_layer) - if self.apply_clip: - clip_list = self._search_best_clip( - balance_layers=balance_layers, - balance_names=balance_names, - input_feat=inp, - ) - - _apply_clip(model, clip_list) - # clear out allocated smoothing scales torch.cuda.empty_cache() @@ -432,7 +418,7 @@ def _compute_best_scale( Compute loss and select best scales L(s) = || Q(W * s) (s^-1 * X) - W * X || - Q: weight quantization function | pseudo_quantize_tensor(W * s) + Q: weight quantization function | _pseudo_quantize_tensor(W * s) X: inputs from calib dataset | X W: original weights in FP16 | layer s: per channel scaling factor | s^-1 * X @@ -461,7 +447,7 @@ def _compute_best_scale( else: scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) scales = scales / (scales.max() * scales.min()).sqrt() - scales_view = scales.view(1, -1).to(device) + _scalesview = scales.view(1, -1).to(device) # avoid scaling values that overflow 
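To make the objective above concrete, here is a toy end-to-end version of the grid search over `ratio`, using a crude symmetric round-to-nearest stand-in for the quantizer (made-up shapes, not the project's `_pseudo_quantize_tensor`):

```python
import torch

def quantize_rtn(w, bits=4):
    # crude symmetric RTN fake-quantizer used only for this illustration
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / qmax
    return torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale

torch.manual_seed(0)
W = torch.randn(8, 16)        # one balance layer's weights
X = torch.randn(32, 16)       # captured input activations
x_mean = X.abs().mean(dim=0)  # per-input-channel activation magnitude

best = (None, float("inf"))
for ratio in [i / 20 for i in range(20)]:
    s = x_mean.pow(ratio).clamp(min=1e-4)
    s = s / (s.max() * s.min()).sqrt()       # normalization as in the else branch
    out = (X / s) @ quantize_rtn(W * s).t()  # Q(W * s) applied to s^-1 * X
    loss = ((X @ W.t() - out) ** 2).mean().item()
    if loss < best[1]:
        best = (ratio, loss)
print(best)  # ratio minimizing || Q(W*s)(s^-1 X) - W X ||^2
```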
scales[torch.isinf(scales)] = 1 @@ -470,22 +456,22 @@ def _compute_best_scale( # Q(W * s) for fc in linears2scale: with align_module_device(fc): - fc.weight.mul_(scales_view) + fc.weight.mul_(_scalesview) update_offload_parameter( fc, "weight", - pseudo_quantize_tensor( + _pseudo_quantize_tensor( w=fc.weight.data, symmetric=self.symmetric, bit_width=self.bits, group_size=self.group_size, )[0] - / scales_view, + / _scalesview, ) # W * X int_w_output = self._forward_input_with_kwargs( - module=module2inspect, inputs=x, input_kwargs=self.module_kwargs_ + module=module2inspect, inputs=x, input_kwargs=self._module_kwargs ) int_w_output = int_w_output.clip( torch.finfo(int_w_output.dtype).min, torch.finfo(int_w_output.dtype).max @@ -606,7 +592,7 @@ def forward(self, *args, **kwargs): best_device ) - self.module_kwargs_ = layer_kwargs + self._module_kwargs = layer_kwargs def _forward_input_with_kwargs( self, @@ -622,7 +608,7 @@ def _forward_input_with_kwargs( :param input_kwargs: additional arguments to pass to the module :return: the first output tensor from the forward pass """ - kwargs = input_kwargs or self.module_kwargs_ + kwargs = input_kwargs or self._module_kwargs kwargs = self._sanitize_kwargs(kwargs, module) return tensor_forward_with_input_args( module=module, @@ -630,84 +616,6 @@ def _forward_input_with_kwargs( input_kwargs=kwargs, )[0] - @torch.no_grad() - def _search_best_clip(self, balance_layers, balance_names, input_feat): - clip_list = [] - avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"] - - for name, layer in zip(balance_names, balance_layers): - # due to qk bmm, it is hard to clip precisely - if any([_ in name for _ in avoid_clipping]): - continue - - max_val = self._compute_best_clip(layer.weight, input_feat) - clip_list.append((name, max_val)) - - return clip_list - - @torch.no_grad() - def _compute_best_clip( - self, - w: torch.Tensor, - input_feat: torch.Tensor, - n_grid=20, - max_shrink=0.5, - n_sample_token=512, - ): - assert w.dim() == 2 - org_w_shape = w.shape - # w [co, ci] -> [co, 1, n_group, group size] - # input_feat [n_token, ci] -> [1, n_token, n_group, group size] - group_size = self.group_size if self.group_size > 0 else org_w_shape[1] - input_feat = input_feat.view(-1, input_feat.shape[-1]) - input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size) - - # Compute input feature step size (minimum 1) - step_size = max(1, input_feat.shape[1] // n_sample_token) - input_feat = input_feat[:, ::step_size] - - w = w.reshape(org_w_shape[0], 1, -1, group_size) - - oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64 # prevent OOM - assert org_w_shape[0] % oc_batch_size == 0 - w_all = w - best_max_val_all = [] - - for i_b in range(org_w_shape[0] // oc_batch_size): - w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size] - - org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1 - - best_max_val = org_max_val.clone() - min_errs = torch.ones_like(org_max_val) * 1e9 - input_feat = input_feat.to(w.device) - org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group - - for i_s in range(int(max_shrink * n_grid)): - max_val = org_max_val * (1 - i_s / n_grid) - min_val = -max_val - cur_w = torch.clamp(w, min_val, max_val) - q_w = pseudo_quantize_tensor( - w=cur_w, - symmetric=self.symmetric, - group_size=group_size, - bit_width=self.bits, - )[0] - cur_out = (input_feat * q_w).sum(dim=-1) - - # co, 1, n_group, 1 - err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape) - del cur_w - del cur_out - cur_best_idx = err < 
min_errs - min_errs[cur_best_idx] = err[cur_best_idx] - best_max_val[cur_best_idx] = max_val[cur_best_idx] - best_max_val_all.append(best_max_val) - - best_max_val = torch.cat(best_max_val_all, dim=0) - - return best_max_val.squeeze(1) - def _sanitize_kwargs(self, inputs_kwargs, module): """ Remove the arguments that are not supported in the module's @@ -728,22 +636,42 @@ def _sanitize_kwargs(self, inputs_kwargs, module): return sanitized_kwargs -@torch.no_grad() -def _apply_clip(module, clip_list: Tuple[str, torch.Tensor]): - """ - Apply clipping to the weights of the given module - :post-condition: the weights of the module are clipped to the given maximum values - :param module: module to apply clipping to - :param clip_list: list of tuples containing the name of the layer and the maximum - value to clip the weights to - """ - for name, max_val in clip_list: - _, layer = get_layer(target=name, module=module) - assert isinstance(layer, torch.nn.Linear) - with align_module_device(layer): - max_val = max_val.to(layer.weight.device) - org_shape = layer.weight.shape - layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1) - layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val) - layer.weight.data = layer.weight.data.reshape(org_shape) +def _pseudo_quantize_tensor( + w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 +): + org_w_shape = w.shape + if group_size > 0: + assert org_w_shape[-1] % group_size == 0, f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!" + w = w.reshape(-1, group_size) + assert w.dim() == 2 + assert torch.isnan(w).sum() == 0 + + # zero point quantization + if not symmetric: + max_val = w.amax(dim=1, keepdim=True) + min_val = w.amin(dim=1, keepdim=True) + max_int = 2**bit_width - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-5) / max_int + zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) + w = ( + torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros + ) * scales + zeros = (zeros - 2**(bit_width-1)).view(org_w_shape[0], -1) + else: + max_val = w.abs().amax(dim=1, keepdim=True) + max_val = max_val.clamp(min=1e-5) + max_int = 2 ** (bit_width - 1) - 1 + min_int = -(2 ** (bit_width - 1)) + scales = max_val / max_int + zeros = None + w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales + + assert torch.isnan(scales).sum() == 0 + assert torch.isnan(w).sum() == 0 + + scales = scales.view(org_w_shape[0], -1) + w = w.reshape(org_w_shape) + + return w, scales, zeros diff --git a/src/llmcompressor/observers/__init__.py b/src/llmcompressor/observers/__init__.py index e16d9d93b..4c3ee5a88 100644 --- a/src/llmcompressor/observers/__init__.py +++ b/src/llmcompressor/observers/__init__.py @@ -5,4 +5,3 @@ from .base import * from .min_max import * from .mse import * -from .rtn import * diff --git a/src/llmcompressor/observers/rtn.py b/src/llmcompressor/observers/rtn.py deleted file mode 100644 index 889b03318..000000000 --- a/src/llmcompressor/observers/rtn.py +++ /dev/null @@ -1,58 +0,0 @@ -from typing import Any, Optional, Tuple - -import torch -from compressed_tensors.quantization.quant_args import QuantizationArgs -from compressed_tensors.quantization.utils import calculate_qparams -from compressed_tensors.utils import deprecated - -from llmcompressor.observers.base import Observer -from llmcompressor.pytorch.utils import pseudo_quantize_tensor - -__all__ = ["RoundToNearestObserver"] - - -@Observer.register("rtn") -class 
RoundToNearestObserver(Observer): - """ - Implements a quantization observer that calculates scale and zero point based on the - minimum and maximum values of the tensor being observed. If averaging_constant is - specified, then the scales are updated using a moving average - """ - - def calculate_qparams( - self, - observed: torch.Tensor, - reduce_dims: Optional[Tuple[int]] = None, - tensor_id: Optional[Any] = None, - ) -> Tuple[torch.FloatTensor, torch.IntTensor]: - """ - Updates the observed min and max using a moving average smoothed by the - averaging_constant. Set the averaging_constant to 1.0 to disable averaging. - - :param observed: observed tensor to calculate quantization parameters for - :param reduce_dims: optional tuple of dimensions to reduce along, - returned scale and zero point will be shaped (1,) along the - reduced dimensions - :param tensor_id: Optional id if different ranges of observed tensors are - passed, useful for sharding tensors by group_size - :return: tuple of scale and zero point derived from the observed tensor - """ - - _, scales, zp = pseudo_quantize_tensor( - observed, - symmetric=self.quantization_args.symmetric, - bit_width=self.quantization_args.num_bits, - group_size=self.quantization_args.group_size or -1, - ) - return (scales, zp) - - def get_qparams_along_dim( - self, observed: torch.Tensor, dim: int, tensor_id: Optional[Any] = None - ): - """ - Calculate quantization parameters along the specified dimension - """ - reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim) - return self.calculate_qparams( - observed, reduce_dims=reduce_dims, tensor_id=tensor_id - ) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 97754ce92..4e3fb6a5a 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -37,8 +37,6 @@ "detach", "adjust_quantization_for_onnx_export", "get_dependency_order", - "pseudo_quantize_tensor", - "pseudo_dequantize_linear", "tensor_forward_with_input_args", "sanitize_kwargs_for_module", ] @@ -664,62 +662,3 @@ def swap_modules( return cur - -def pseudo_quantize_tensor( - w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 -): - org_w_shape = w.shape - if group_size > 0: - assert org_w_shape[-1] % group_size == 0, f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!" 
- w = w.reshape(-1, group_size) - assert w.dim() == 2 - assert torch.isnan(w).sum() == 0 - - # zero point quantization - if not symmetric: - max_val = w.amax(dim=1, keepdim=True) - min_val = w.amin(dim=1, keepdim=True) - max_int = 2**bit_width - 1 - min_int = 0 - scales = (max_val - min_val).clamp(min=1e-5) / max_int - zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) - w = ( - torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros - ) * scales - zeros = (zeros - 2**(bit_width-1)).view(org_w_shape[0], -1) - else: - max_val = w.abs().amax(dim=1, keepdim=True) - max_val = max_val.clamp(min=1e-5) - max_int = 2 ** (bit_width - 1) - 1 - min_int = -(2 ** (bit_width - 1)) - scales = max_val / max_int - zeros = None - w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales - - assert torch.isnan(scales).sum() == 0 - assert torch.isnan(w).sum() == 0 - - scales = scales.view(org_w_shape[0], -1) - w = w.reshape(org_w_shape) - - return w, scales, zeros - - -def pseudo_dequantize_linear( - w: torch.Tensor, - scales: torch.Tensor, - zeros: Optional[torch.Tensor] = None, - symmetric: bool = False, -): - # get repeated count - repeat_count = w.weight.data.shape[-1] // scales.shape[-1] - scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape) - - # dequantize - if not symmetric: - zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape) - w = (w.weight.data - zeros) * scales - else: - w = w.weight.data * scales - - return w diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 918238718..3fff33e23 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -24,5 +24,5 @@ def test_awq_is_registered(self): self.assertIsInstance( modifier, AWQModifier, - "PyTorch AWQModifier not registered", + "AWQModifier not registered", ) From d352bcf2f4b970b9433eab9580a083074a7b8ac9 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 10 Mar 2025 19:39:30 +0000 Subject: [PATCH 17/40] rename smoothquant private vars Signed-off-by: Brian Dellabetta --- .../modifiers/smoothquant/base.py | 36 +++++++++---------- .../logarithmic_equalization/test_pytorch.py | 6 ++-- .../modifiers/smoothquant/test_pytorch.py | 6 ++-- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index 1b1e0aee6..037fe1219 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -109,8 +109,8 @@ class SmoothQuantModifier(Modifier): num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None - resolved_mappings_: Optional[List[SmoothQuantMapping]] = None - scales_: Optional[Dict] = None + _resolved_mappings: Optional[List[SmoothQuantMapping]] = None + _scales: Optional[Dict] = None def on_initialize(self, state: State, **kwargs) -> bool: """ @@ -132,8 +132,8 @@ def on_initialize(self, state: State, **kwargs) -> bool: self.ignore = [] if not self.ignore else self.ignore self.mappings = self._infer_mappings_from_model(state.model) - self.resolved_mappings_ = self._resolve_mappings(state.model) - self.scales_ = {} + self._resolved_mappings = self._resolve_mappings(state.model) + self._scales = {} calibration_dataloader = state.data.calib @@ -150,10 +150,10 @@ def on_finalize(self, state: State, **kwargs) -> bool: :param state: unused :return: True """ - if self.scales_ is 
not None: - self.scales_.clear() - if self.resolved_mappings_ is not None: - self.resolved_mappings_.clear() + if self._scales is not None: + self._scales.clear() + if self._resolved_mappings is not None: + self._resolved_mappings.clear() return True @@ -219,21 +219,21 @@ def hook_fn(module, inp, out): latest_mins = torch.min(out, dim=0)[0] latest_maxes = torch.max(out, dim=0)[0] - if layer_name in self.scales_: - self.scales_[layer_name].min_channel_vals = torch.minimum( - self.scales_[layer_name].min_channel_vals, latest_mins + if layer_name in self._scales: + self._scales[layer_name].min_channel_vals = torch.minimum( + self._scales[layer_name].min_channel_vals, latest_mins ) - self.scales_[layer_name].max_channel_vals = torch.maximum( - self.scales_[layer_name].max_channel_vals, latest_maxes + self._scales[layer_name].max_channel_vals = torch.maximum( + self._scales[layer_name].max_channel_vals, latest_maxes ) else: - self.scales_[layer_name] = SmoothQuantScale( + self._scales[layer_name] = SmoothQuantScale( min_channel_vals=latest_mins, max_channel_vals=latest_maxes ) return hook_fn - for mapping in self.resolved_mappings_: + for mapping in self._resolved_mappings: name = mapping.smooth_name layer = mapping.smooth_layer self.register_hook(layer, create_hook_fn(name), "forward") @@ -278,10 +278,10 @@ def _apply_smoothing(self, model: Module): This modifies the weights of the model in-place. """ logger.info("Smoothing activation scales...") - for mapping in self.resolved_mappings_: + for mapping in self._resolved_mappings: activation_scales = ( # get dynamic range for each activation channel - self.scales_[mapping.smooth_name].max_channel_vals - - self.scales_[mapping.smooth_name].min_channel_vals + self._scales[mapping.smooth_name].max_channel_vals + - self._scales[mapping.smooth_name].min_channel_vals ) smooth_layer = mapping.smooth_layer balance_layers = mapping.balance_layers diff --git a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py index d485c0637..e84f66e83 100644 --- a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py @@ -21,11 +21,11 @@ def test_successful_map(self): modifier = LogarithmicEqualizationModifier(mappings=mappings) modifier.ignore = [] - modifier.resolved_mappings_ = modifier._resolve_mappings(self.state.model) + modifier._resolved_mappings = modifier._resolve_mappings(self.state.model) - self.assertEqual(len(modifier.resolved_mappings_), len(mappings)) + self.assertEqual(len(modifier._resolved_mappings), len(mappings)) - mapping = modifier.resolved_mappings_[0] + mapping = modifier._resolved_mappings[0] self.assertEqual(mapping.smooth_name, mappings[0][1]) self.assertIsInstance(mapping.smooth_layer, Linear) self.assertIsInstance(mapping.balance_layers[0], Linear) diff --git a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py index 7977c4546..cbb60f030 100644 --- a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py @@ -19,11 +19,11 @@ def test_successful_map(self): modifier = SmoothQuantModifier(mappings=mappings) modifier.ignore = [] - modifier.resolved_mappings_ = modifier._resolve_mappings(self.state.model) + modifier._resolved_mappings = 
modifier._resolve_mappings(self.state.model) - self.assertEqual(len(modifier.resolved_mappings_), len(mappings)) + self.assertEqual(len(modifier._resolved_mappings), len(mappings)) - mapping = modifier.resolved_mappings_[0] + mapping = modifier._resolved_mappings[0] self.assertEqual(mapping.smooth_name, mappings[0][1]) self.assertIsInstance(mapping.smooth_layer, Linear) self.assertIsInstance(mapping.balance_layers[0], Linear) From 7ed2e723c3b66ed00e0777b4b2915eb5817b40a0 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 19 Mar 2025 09:35:50 -0500 Subject: [PATCH 18/40] squashed codereview updates for rebase Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 44 ++++++++++++------- src/llmcompressor/pytorch/utils/helpers.py | 1 - .../transformers/finetune/data/pile.py | 27 ------------ .../finetune/data/test_registry.py | 1 - 4 files changed, 28 insertions(+), 45 deletions(-) delete mode 100644 src/llmcompressor/transformers/finetune/data/pile.py diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index bdf3f8628..18d2c4541 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -1,24 +1,21 @@ import inspect from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Union import torch from compressed_tensors.utils import align_module_device, update_offload_parameter from loguru import logger -from pydantic import ConfigDict, Field +from pydantic import ConfigDict from torch.nn import Module from tqdm import tqdm from llmcompressor.core import State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.utils.pytorch_helpers import run_calibration_forward -from llmcompressor.pytorch.utils import ( - tensor_forward_with_input_args, -) +from llmcompressor.pytorch.utils import tensor_forward_with_input_args from llmcompressor.utils.fsdp.helpers import get_fsdp_parent from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import ( - get_layer, get_layers, get_matching_layer, get_parent_by_name, @@ -48,7 +45,8 @@ class AWQMapping: "re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], ), - # TODO this should only be added if v_proj/o_proj shapes match up, should we check during validation and skip if this is not the case? + # TODO this should only be added if v_proj/o_proj shapes match up + # should we check during validation and skip if this is not the case? AWQMapping("re:.*v_proj", ["re:.*o_proj"]), AWQMapping( "re:.*post_attention_layernorm", @@ -148,7 +146,7 @@ class AWQModifier(Modifier): duo_scaling: bool = True _resolved_mappings: List[ResolvedMapping] = [] - _scales: Dict[str, torch.Tensor | List[torch.Tensor]] = {} + _scales: Dict[str, Union[torch.Tensor, List[torch.Tensor]]] = {} _module_kwargs: Dict = {} def on_initialize(self, state: State, **kwargs) -> bool: @@ -189,7 +187,7 @@ def on_finalize(self, state: State, **kwargs) -> bool: def _set_resolved_mappings(self, model: Module) -> None: """ Transforms the list of activations to smooth and their corresponding weights - into ResolvedMapping objects, resolving regular expressions. + into ResolvedMapping objects, resolving regular expressions. Result is stored in _resolved_mappings. 
For each activation in the mapping list, we find the corresponding weight to @@ -385,13 +383,25 @@ def smooth(module): module.weight.mul_(scales.view(1, -1).to(module.weight.device)) elif module == smooth_layer: if module.weight.ndim == 1: - module.weight.div_(scales.to(module.weight.device)) + update_offload_parameter( + module, + "weight", + module.weight.div(scales.to(module.weight.device)), + ) else: - module.weight.div_( - scales.view(-1, 1).to(module.weight.device) + update_offload_parameter( + module, + "weight", + module.weight.div( + scales.view(-1, 1).to(module.weight.device) + ), ) if hasattr(module, "bias") and module.bias is not None: - module.bias.div_(scales.to(module.bias.device)) + update_offload_parameter( + module, + "bias", + module.bias.div(scales.to(module.bias.device)), + ) parent = get_fsdp_parent(mapping.smooth_name, model) if parent is not None: @@ -636,13 +646,15 @@ def _sanitize_kwargs(self, inputs_kwargs, module): return sanitized_kwargs - def _pseudo_quantize_tensor( w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1 ): org_w_shape = w.shape if group_size > 0: - assert org_w_shape[-1] % group_size == 0, f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({group_size})!" + assert org_w_shape[-1] % group_size == 0, ( + f"org_w_shape ({org_w_shape[-1]}) must be a multiple " + + f"of group_size ({group_size})!" + ) w = w.reshape(-1, group_size) assert w.dim() == 2 assert torch.isnan(w).sum() == 0 @@ -658,7 +670,7 @@ def _pseudo_quantize_tensor( w = ( torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros ) * scales - zeros = (zeros - 2**(bit_width-1)).view(org_w_shape[0], -1) + zeros = (zeros - 2 ** (bit_width - 1)).view(org_w_shape[0], -1) else: max_val = w.abs().amax(dim=1, keepdim=True) max_val = max_val.clamp(min=1e-5) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index 4e3fb6a5a..f56a1ceab 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -661,4 +661,3 @@ def swap_modules( parent.__setattr__(sections[-1], submodule_to_replace) return cur - diff --git a/src/llmcompressor/transformers/finetune/data/pile.py b/src/llmcompressor/transformers/finetune/data/pile.py deleted file mode 100644 index f420ba2a5..000000000 --- a/src/llmcompressor/transformers/finetune/data/pile.py +++ /dev/null @@ -1,27 +0,0 @@ -from copy import deepcopy -from typing import TYPE_CHECKING - -from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.typing import Processor - -if TYPE_CHECKING: - from llmcompressor.args import DatasetArguments - - -@TextGenerationDataset.register(name="mit-han-lab/pile-val-backup", alias="pile_val") -class PileValDataset(TextGenerationDataset): - """ - Child text generation class for "The Pile" dataset - :param data_args: configuration settings for dataset loading - :param split: split from dataset to load, for instance `test` or `train[:5%]` - :param tokenizer: tokenizer to use on dataset - """ - - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.text_column = "text" - data_args.dataset = "mit-han-lab/pile-val-backup" - super().__init__(data_args=data_args, split=split, processor=processor) - - def dataset_template(self, sample): - return {"text": sample["text"].strip()} diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py 
b/tests/llmcompressor/transformers/finetune/data/test_registry.py index ce872fba9..29895b4a4 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -4,7 +4,6 @@ from llmcompressor.transformers.finetune.data import ( C4Dataset, OpenPlatypusDataset, - PileEvalDataset, TextGenerationDataset, WikiTextDataset, ) From ea41fe5942c2a4356f1e50069505cc51f180a50a Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 19 Mar 2025 09:56:54 -0500 Subject: [PATCH 19/40] cleanup fixes from rebase Signed-off-by: Brian Dellabetta --- src/llmcompressor/pytorch/utils/helpers.py | 380 +----------------- .../pytorch/utils/test_helpers.py | 218 ---------- 2 files changed, 2 insertions(+), 596 deletions(-) diff --git a/src/llmcompressor/pytorch/utils/helpers.py b/src/llmcompressor/pytorch/utils/helpers.py index f56a1ceab..1eee9c034 100644 --- a/src/llmcompressor/pytorch/utils/helpers.py +++ b/src/llmcompressor/pytorch/utils/helpers.py @@ -4,7 +4,6 @@ import functools import inspect -import os import random from typing import Any, Dict, Iterable, List, Mapping, OrderedDict, Tuple, Union @@ -27,18 +26,11 @@ "tensors_to_precision", "tensors_module_forward", "tensor_sparsity", + "tensor_forward_with_input_args", + "sanitize_kwargs_for_module", "get_linear_layers", "get_quantized_layers", "set_deterministic_seeds", - "torch_distributed_zero_first", - "thin_model_from_checkpoint", - "MEMORY_BOUNDED", - "memory_aware_threshold", - "detach", - "adjust_quantization_for_onnx_export", - "get_dependency_order", - "tensor_forward_with_input_args", - "sanitize_kwargs_for_module", ] @@ -212,110 +204,6 @@ def tensor_sparsity( return zeros.float() / float(total) -def tensor_density(tens: Tensor, dim: Union[None, int, Iterable[int]] = None) -> Tensor: - """ - :param tens: the tensor to calculate the density for - :param dim: the dimension(s) to split the calculations over; ex, can split over - batch, channels, or combos - :return: the density of the input tens, ie the fraction of numbers that are non zero - """ - density = (tensor_sparsity(tens, dim) - 1.0) * -1.0 - - return density - - -def tensor_sample( - tens: Tensor, - sample_size: int, - dim: Union[None, int, List[int], Tuple[int, ...]] = None, -) -> Tensor: - """ - :param tens: the tensor to grab samples from - :param sample_size: the number of samples to grab overall if dim is not supplied - or per each dim if it is - :param dim: the dimension(s) to split the samples over; - ex, can split over batch, channels, or combos - :return: the sampled tensor - """ - if sample_size < 1: - raise ValueError("improper sample size given of {}".format(sample_size)) - - if dim is None: - indices = tens.new_zeros((sample_size,)).long().random_(0, tens.numel()) - samples = tens.view(-1)[indices] - - return samples - - if isinstance(dim, int): - dim = [dim] - - if max(dim) >= len(tens.shape): - raise ValueError( - "Unsupported dim given of {} in {} for tensor shape {}".format( - max(dim), dim, tens.shape - ) - ) - - if dim != [ind for ind in range(len(dim))]: - # put the desired dimension(s) at the front to sample from - tens = tens.permute( - *dim, *[ind for ind in range(len(tens.shape)) if ind not in dim] - ) - dim = [ind for ind in range(len(dim))] - - if not tens.is_contiguous(): - tens = tens.contiguous() - - num_indices = int(numpy.prod([tens.shape[ind] for ind in range(len(dim))])) - elem_per_ind = int( - numpy.prod([tens.shape[ind] for ind in range(len(dim), len(tens.shape))]) - ) - # 
create a new tensor with offsets set for each of our elements that we are indexing - indices = tens.new_tensor( - [ind * elem_per_ind for ind in range(num_indices)], dtype=torch.long - ).unsqueeze(1) - # now broadcast it across to the total number of elements we should end with - indices = indices * tens.new_ones((num_indices, sample_size), dtype=torch.long) - # finally add in a random number within the available range per index - indices += tens.new_zeros((num_indices, sample_size), dtype=torch.long).random_( - 0, elem_per_ind - ) - # get our samples - samples = tens.view(-1)[indices.view(-1)] - # reshape for the proper dimension - samples = samples.view(*(tens.shape[ind] for ind in dim), sample_size) - - return samples - - -def tensor_list_sparsity(tensors: List[Tensor]) -> float: - """ - :param tensors: the list of tensors to calculate the sparsity for - :return: the total sparsity of all tensors in the list - """ - zeros = 0 - numel = 0 - for tensor in tensors: - zeros += (tensor == 0).sum().item() - numel += tensor.numel() - return float(zeros) / float(numel) - - -def mask_difference(old_mask: Tensor, new_mask: Tensor) -> Tensor: - """ - :param old_mask: the old mask to compare against for calculating the difference - :param new_mask: the new mask to compare with for calculating the difference - :return: a tensor representing the change from the old_mask to the new_mask - specifically values returned as 1.0 are newly unmasked (0.0 => 1.0) - values returned as -1.0 are newly masked (1.0 => 0.0) - values returned as 0.0 had no change in (0.0 => 0.0 or 1.0 => 1.0) - """ - newly_masked = ((old_mask != new_mask) & (new_mask == 0.0)).type(old_mask.type()) - newly_unmasked = ((old_mask != new_mask) & (new_mask == 1.0)).type(old_mask.type()) - - return -1.0 * newly_masked + newly_unmasked - - def sanitize_kwargs_for_module( kwargs: Dict[str, Any], module: Module ) -> Dict[str, Any]: @@ -397,267 +285,3 @@ def set_deterministic_seeds(seed: int = 0): random.seed(seed) torch.manual_seed(seed) torch.backends.cudnn.deterministic = True - - -@contextmanager -def torch_distributed_zero_first(local_rank: Optional[int]): - """ - Decorator to make all processes in distributed training wait for each - local 0 ranked process to do something. - :param local_rank: the local rank of this process - """ - if local_rank is not None and local_rank not in [-1, 0]: - torch.distributed.barrier() - yield - if local_rank == 0: - torch.distributed.barrier() - - -def thin_model_from_checkpoint(model: Module, state_dict: Dict[str, Any]): - """ - Updates any Linear/Conv/BN layers in the given model to match their - respective shapes in the given state dict. Purpose of compatibility - when loading weight for a model from a checkpoint of the same architecture - but with potentially structured thinning applied. Note that this function - has no guarantees on accuracy, will only resize model parameters for - loading compatibility. 
All adjustments done in place - - :param model: model to potentially adjust parameter shapes of - :param state_dict: state dict to infer parameter shapes from - """ - first_thinned = True - for param_name, checkpoint_tens in state_dict.items(): - if not param_name.endswith(".weight"): - continue # only deal with weight params of modules - layer_name = param_name[:-7] - layer = get_layer(layer_name, model) - - if not hasattr(layer, "weight") or ( - layer.weight.shape == checkpoint_tens.shape - ): - continue # skip if there is no update to shape - - # quick check that target layer is some flavor of FC/Conv/BN - layer_type = layer.__class__.__name__ - if not ( - "Linear" not in layer_type - or "Conv" not in layer_type - or ("BatchNorm" not in layer_type) - ): - continue - - orig_shape = layer.weight.shape - target_shape = checkpoint_tens.shape - - # update weight param + grad - if len(target_shape) > 1: - layer.weight.data = layer.weight.data[ - : target_shape[0], : target_shape[1], ... - ] - if layer.weight.grad is not None: - layer.weight.grad = layer.weight.grad[ - : target_shape[0], : target_shape[1], ... - ] - else: - layer.weight.data = layer.weight.data[: target_shape[0]] - if layer.weight.grad is not None: - layer.weight.grad = layer.weight.grad[: target_shape[0]] - - # update bias param + grad - if hasattr(layer, "bias") and layer.bias is not None: - # target output channels should be the first dim of target shape - layer.bias.data = layer.bias.data[: target_shape[0]] - if layer.bias.grad is not None: - layer.bias.grad = layer.bias.grad[: target_shape[0]] - - # update layer attributes - if "BatchNorm" in layer_type: - if hasattr(layer, "num_features"): - layer.num_features = layer.weight.size(0) - # BN running mean and var are not stored as Parameters - if hasattr(layer, "running_mean"): - layer.running_mean = torch.zeros_like(layer.running_mean)[ - : target_shape[0] - ] - if hasattr(layer, "running_var"): - layer.running_var = torch.zeros_like(layer.running_var)[ - : target_shape[0] - ] - - if "Linear" in layer_type: - if hasattr(layer, "out_features"): - layer.out_features = layer.weight.shape[0] - if hasattr(layer, "in_features"): - layer.in_features = layer.weight.shape[1] - - if "Conv" in layer_type: - if hasattr(layer, "out_channels"): - layer.out_channels = layer.weight.shape[0] - if hasattr(layer, "in_channels"): - layer.in_channels = layer.weight.shape[1] - if hasattr(layer, "groups") and layer.groups > 1: - layer.groups = layer.weight.shape[0] // layer.weight.shape[1] - - if first_thinned: - logger.info( - "Thinning module layers for compatibility with given state dict:" - ) - first_thinned = False - logger.info( - f"Thinned layer {layer_name} from shape {orig_shape} to " - f"{layer.weight.shape}" - ) - - -############################## -# -# misc pytorch helper functions -# -############################## - - -MEMORY_BOUNDED = "MEMORY_BOUNDED" - - -def memory_aware_threshold(tensor: torch.Tensor, idx: int) -> Tensor: - """ - Finds a threshold at the lookup idx in the most efficient way with available - resources. Will be phased out when GPU-memory overhead of torch.sort reduces, - or when torch.kthvalue becomes faster than torch.sort. 
- - :param tensor: A tensor to find a k-th smallest value in, where k=idx+1 - :param idx: A lookup index - :return: k-th smallest value from the given tensor, where k=idx+1 - """ - try: - if ( - MEMORY_BOUNDED in os.environ - and os.environ[MEMORY_BOUNDED].lower() == "true" - ): - return torch.kthvalue(tensor.reshape(-1), idx + 1)[0] - else: - return torch.sort(tensor.reshape(-1))[0][idx] - except RuntimeError: - logger.warning( - "Finding threshold from sparsity failed due to lack of memory, " - "will attempt to recover. Consider setting env variable " - f"{MEMORY_BOUNDED}=True in future runs." - ) - torch.cuda.empty_cache() - os.environ[MEMORY_BOUNDED] = "True" - return torch.kthvalue(tensor.view(-1), idx + 1)[0] - - -def detach(x: Union[torch.Tensor, List, Tuple]): - if isinstance(x, torch.Tensor): - return x.detach() - elif isinstance(x, List): - return [detach(e) for e in x] - elif isinstance(x, Tuple): - return tuple([detach(e) for e in x]) - else: - raise ValueError("Unexpected type to detach") - - -def adjust_quantization_for_onnx_export(module: torch.nn.Module) -> torch.nn.Module: - # supported pytorch ranges are int8 or uint8 - allowed_ranges = [(0, 127), (0, 255), (-128, 127)] - fake_quant_modules = [ - m for m in module.modules() if m.__class__.__name__ == "FakeQuantize" - ] - - if _PARSED_TORCH_VERSION >= version.parse("1.12"): - for quant in fake_quant_modules: - # original ranges preserved in quant.quant_min and quant.quant_max - quant_range = ( - quant.activation_post_process.quant_min, - quant.activation_post_process.quant_max, - ) - if quant_range not in allowed_ranges: - if quant_range[0] < 0: # convert signed range to int8 - quant.activation_post_process.quant_min = -128 - quant.activation_post_process.quant_max = 127 - else: # convert unsigned range to uint8 - quant.activation_post_process.quant_min = 0 - quant.activation_post_process.quant_max = 255 - # don't update observer since ranges are artificially modified - quant.observer_enabled[0] = 0 - - else: # backwards compatibility for torch <= 1.11 - for quant in fake_quant_modules: - quant_range = (quant.quant_min, quant.quant_max) - if quant_range not in allowed_ranges: - if quant_range[0] < 0: # convert signed range to int8 - quant.quant_min = -128 - quant.quant_max = 127 - else: # convert unsigned range to uint8 - quant.quant_min = 0 - quant.quant_max = 255 - # don't update observer since ranges are artificially modified - quant.observer_enabled[0] = 0 - - -def get_dependency_order( - layer: Module, subset: Dict, an_input: Tensor, **kwargs -) -> List[str]: - """ - Get a list of a subset of modules in layer ordered by execution order, which honors - the dependencies in the graph - - :param layer: pytorch module to calculate dependencies for - :param subset: subset of modules in the layer to include in the ordering - :param an_input: example input to pass through the layer forward pass, used to - determine execution order - - :return: list of module names in execution order - """ - order = [] - - def exe_input(name): - def _exe_input(_, inp, out): - if name in subset: - order.append(name) - - return _exe_input - - # register a hook for each module of interest, will be triggered in exeuction order - handles = [subset[name].register_forward_hook(exe_input(name)) for name in subset] - layer(an_input, **kwargs) - for h in handles: - h.remove() - return order - - -def swap_modules( - module: torch.nn.Module, submodule_name: str, submodule_to_replace: torch.nn.Module -) -> torch.nn.Module: - """ - Iteratively unfold the 
submodules of the module according to the submodule_name - to eventually replace the leaf submodule (accessed from the module through the - submodule_name) with the submodule_to_replace. - - E.g - ``` - swap_modules(module=Model, - module_name="layers.0.sublayer", - module_to_replace=ReplaceModule - ) - ``` - this will iteratively traverse through the submodules - 'layers' -> '0' -> to eventually replace 'sublayer' with ReplaceModule - - :param module: the module to replace with the module_to_replace - :param submodule_name: the name of the module to replace - :param submodule_to_replace: the module to replace the module with - :return: the replaced module - """ - parent = module - sections = submodule_name.split(".") - - for sec in sections[:-1]: - parent = parent.__getattr__(sec) - - cur = parent.__getattr__(sections[-1]) - parent.__setattr__(sections[-1], submodule_to_replace) - - return cur diff --git a/tests/llmcompressor/pytorch/utils/test_helpers.py b/tests/llmcompressor/pytorch/utils/test_helpers.py index 7d844cb00..9e2bb373f 100644 --- a/tests/llmcompressor/pytorch/utils/test_helpers.py +++ b/tests/llmcompressor/pytorch/utils/test_helpers.py @@ -7,17 +7,8 @@ from torch.nn import Linear, Module, ReLU, Sequential from llmcompressor.pytorch.utils import ( - MEMORY_BOUNDED, - default_device, - get_optim_learning_rate, - mask_difference, - memory_aware_threshold, sanitize_kwargs_for_module, - set_optim_learning_rate, - tensor_density, - tensor_export, tensor_forward_with_input_args, - tensor_sample, tensor_sparsity, tensors_module_forward, tensors_to_device, @@ -507,215 +498,6 @@ def test_tensor_sparsity_cuda(tensor, dim, expected_sparsity): assert torch.sum((sparsity.detach().cpu() - expected_sparsity).abs()) < 0.001 -@pytest.mark.flaky(reruns=2, min_passes=1) -@pytest.mark.skipif( - os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), - reason="Skipping pytorch tests", -) -@pytest.mark.parametrize( - "tensor,dim,expected_density", - [ - (torch.zeros(8, 16), None, torch.tensor(0.0)), - (torch.zeros(8, 16), 0, torch.zeros(8)), - (torch.zeros(8, 16), 1, torch.zeros(16)), - (torch.zeros(8, 16), [0, 1], torch.zeros(8, 16)), - (torch.zeros(8, 16), [1, 0], torch.zeros(16, 8)), - (torch.zeros(8, 16, 32, 8), [3, 1, 2], torch.zeros(8, 16, 32)), - (torch.ones(8, 16), None, torch.tensor(1.0)), - (torch.ones(8, 16), 0, torch.ones(8)), - (torch.ones(8, 16), 1, torch.ones(16)), - (torch.ones(8, 16), [0, 1], torch.ones(8, 16)), - (torch.ones(8, 16), [1, 0], torch.ones(16, 8)), - (torch.ones(8, 16, 32, 8), [3, 1, 2], torch.ones(8, 16, 32)), - (torch.randn(8, 16), None, torch.tensor(1.0)), - (torch.randn(8, 16), 0, torch.ones(8)), - (torch.randn(8, 16), 1, torch.ones(16)), - (torch.randn(8, 16), [0, 1], torch.ones(8, 16)), - (torch.randn(8, 16), [1, 0], torch.ones(16, 8)), - (torch.randn(8, 16, 32, 8), [3, 1, 2], torch.ones(8, 16, 32)), - ( - torch.tensor([10.0, 0.0, 1.0, 3.0, 2.0, 0.0, 8.0, 0.0, 5.0, 0.0]), - None, - torch.tensor(0.6), - ), - ], -) -def test_tensor_density(tensor, dim, expected_density): - density = tensor_density(tensor, dim) - assert expected_density.shape == density.shape - assert torch.sum((density - expected_density).abs()) < 0.001 - - -@pytest.mark.flaky(reruns=2, min_passes=1) -@pytest.mark.skipif( - os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), - reason="Skipping pytorch tests", -) -@pytest.mark.parametrize( - "tensor,dim,expected_density", - [ - (torch.zeros(8, 16), None, torch.tensor(0.0)), - (torch.zeros(8, 16, 32, 8), [3, 1, 2], torch.zeros(8, 16, 32)), - (torch.ones(8, 16), 
None, torch.tensor(1.0)), - (torch.ones(8, 16, 32, 8), [3, 1, 2], torch.ones(8, 16, 32)), - (torch.randn(8, 16), None, torch.tensor(1.0)), - ( - torch.tensor([10.0, 0.0, 1.0, 3.0, 2.0, 0.0, 8.0, 0.0, 5.0, 0.0]), - None, - torch.tensor(0.6), - ), - ], -) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda availability") -def test_tensor_density_cuda(tensor, dim, expected_density): - tensor = tensor.to("cuda") - density = tensor_density(tensor, dim) - assert expected_density.shape == density.shape - assert torch.sum((density.detach().cpu() - expected_density).abs()) < 0.001 - - -@pytest.mark.skipif( - os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), - reason="Skipping pytorch tests", -) -@pytest.mark.parametrize( - "tensor,size,dim,expected_shape", - [ - (torch.randn(8, 16), 100, None, [100]), - (torch.randn(8, 16), 100, 0, [8, 100]), - (torch.randn(8, 16), 100, 1, [16, 100]), - (torch.randn(8, 16), 10, [0, 1], [8, 16, 10]), - (torch.randn(8, 16), 10, [1, 0], [16, 8, 10]), - (torch.randn(64, 12, 32, 16), 10, 2, [32, 10]), - (torch.randn(64, 12, 32, 16), 10, [3, 2], [16, 32, 10]), - (torch.randn(64, 12, 32, 16), 10, 1, [12, 10]), - (torch.randn(64, 12, 32, 16), 10, [0, 1], [64, 12, 10]), - ], -) -def test_tensor_sample(tensor, size, dim, expected_shape): - sample = tensor_sample(tensor, size, dim) - assert len(sample.shape) == len(expected_shape) - for s1, s2 in zip(sample.shape, expected_shape): - assert s1 == s2 - - -@pytest.mark.skipif( - os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), - reason="Skipping pytorch tests", -) -@pytest.mark.parametrize( - "tensor,size,dim,expected_shape", - [ - (torch.randn(8, 16), 100, None, [100]), - (torch.randn(8, 16), 100, 0, [8, 100]), - (torch.randn(8, 16), 100, 1, [16, 100]), - (torch.randn(8, 16), 10, [0, 1], [8, 16, 10]), - (torch.randn(8, 16), 10, [1, 0], [16, 8, 10]), - (torch.randn(64, 12, 32, 16), 10, 2, [32, 10]), - (torch.randn(64, 12, 32, 16), 10, [3, 2], [16, 32, 10]), - (torch.randn(64, 12, 32, 16), 10, 1, [12, 10]), - (torch.randn(64, 12, 32, 16), 10, [0, 1], [64, 12, 10]), - ], -) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda availability") -def test_tensor_sample_cuda(tensor, size, dim, expected_shape): - tensor = tensor.to("cuda") - sample = tensor_sample(tensor, size, dim) - assert len(sample.shape) == len(expected_shape) - for s1, s2 in zip(sample.shape, expected_shape): - assert s1 == s2 - - -@pytest.mark.skipif( - os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), - reason="Skipping pytorch tests", -) -@pytest.mark.parametrize( - "old_mask,new_mask,expected_diff", - [ - (torch.zeros(8, 8), torch.zeros(8, 8), torch.zeros(8, 8)), - (torch.zeros(8, 8), torch.ones(8, 8), torch.ones(8, 8)), - (torch.ones(8, 8), torch.zeros(8, 8), -1.0 * torch.ones(8, 8)), - (torch.ones(8, 8), torch.ones(8, 8), torch.zeros(8, 8)), - ( - torch.tensor([0.0, 0.0, 1.0, 0.0, 1.0, 1.0]), - torch.tensor([0.0, 1.0, 0.0, 0.0, 0.0, 1.0]), - torch.tensor([0.0, 1.0, -1.0, 0.0, -1.0, 0.0]), - ), - ], -) -def test_mask_difference(old_mask, new_mask, expected_diff): - diff = mask_difference(old_mask, new_mask) - assert torch.sum((diff - expected_diff).abs()) < sys.float_info.epsilon - - -@pytest.mark.skipif( - os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), - reason="Skipping pytorch tests", -) -@pytest.mark.parametrize( - "model,state_dict,test_input", - [ - ( - Sequential(Conv2d(3, 16, (1, 1)), BatchNorm2d(16), Conv2d(16, 16, (1, 1))), - { - "0.weight": torch.randn(8, 3, 1, 1), - "0.bias": torch.randn(8), - "1.weight": 
torch.randn(8), - "1.bias": torch.randn(8), - "1.running_mean": torch.randn(8), - "1.running_var": torch.randn(8), - "2.weight": torch.randn(12, 8, 1, 1), - "2.bias": torch.randn(12), - }, - torch.randn(2, 3, 16, 16), - ), - ( - Sequential(Linear(8, 12), Linear(12, 16)), - { - "0.weight": torch.randn(7, 8), - "0.bias": torch.randn(7), - "1.weight": torch.randn(9, 7), - "1.bias": torch.randn(9), - }, - torch.randn(5, 8), - ), - ], -) -def test_thin_model_from_checkpoint(model, state_dict, test_input): - with pytest.raises(RuntimeError): - model.load_state_dict(state_dict) - - thin_model_from_checkpoint(model, state_dict) - model.load_state_dict(state_dict, strict=True) - assert isinstance(model(test_input), Tensor) - - -@pytest.mark.parametrize( - "tensor,idx", - [ - (torch.rand(1), 0), - (torch.rand(1_000), 123), - (torch.rand(10_000), 4321), - (torch.rand(100_000), 12345), - ], -) -def test_memory_aware_threshold(tensor, idx): - prior_state = os.getenv(MEMORY_BOUNDED) - - dev = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") - tensor = tensor.to(dev) - - os.environ[MEMORY_BOUNDED] = "True" - t1 = memory_aware_threshold(tensor, idx) - os.environ[MEMORY_BOUNDED] = "False" - t2 = memory_aware_threshold(tensor, idx) - assert abs(t1 - t2) < 1e-3 - - if prior_state is not None: - os.environ[MEMORY_BOUNDED] = prior_state - - class TestSanitizeKwargsForModule: @pytest.fixture def module(self): From 433bb2b330c05b26715a89af3e4c1a020bf26b9d Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 19 Mar 2025 10:12:19 -0500 Subject: [PATCH 20/40] awq mappings registry Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/__init__.py | 1 + src/llmcompressor/modifiers/awq/base.py | 63 +------------------ src/llmcompressor/modifiers/awq/mappings.py | 69 +++++++++++++++++++++ 3 files changed, 73 insertions(+), 60 deletions(-) create mode 100644 src/llmcompressor/modifiers/awq/mappings.py diff --git a/src/llmcompressor/modifiers/awq/__init__.py b/src/llmcompressor/modifiers/awq/__init__.py index 8bdc93d14..c68517377 100644 --- a/src/llmcompressor/modifiers/awq/__init__.py +++ b/src/llmcompressor/modifiers/awq/__init__.py @@ -1,3 +1,4 @@ # flake8: noqa from .base import * +from .mappings import * diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 18d2c4541..157f30307 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -1,5 +1,4 @@ import inspect -from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union import torch @@ -21,65 +20,9 @@ get_parent_by_name, ) -__all__ = ["AWQMapping", "AWQModifier"] +from .mappings import AWQ_MAPPING_REGISTRY, AWQMapping, ResolvedMapping - -@dataclass -class AWQMapping: - """ - Dataclass storing config of activation mappings to smooth - The output activations of smooth_layer are input activations - into the balance_layers - - `AWQMapping`s are resolved into `ResolvedMapping`s, which - retain pointers to the actual `torch.nn.Module`s and additional - metadata at runtime - """ - - smooth_layer: str - balance_layers: list[str] - - -DEFAULT_AWQ_MAPPINGS: list[AWQMapping] = [ - AWQMapping( - "re:.*input_layernorm", - ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], - ), - # TODO this should only be added if v_proj/o_proj shapes match up - # should we check during validation and skip if this is not the case? 
- AWQMapping("re:.*v_proj", ["re:.*o_proj"]), - AWQMapping( - "re:.*post_attention_layernorm", - ["re:.*gate_proj", "re:.*up_proj"], - ), - AWQMapping( - "re:.*up_proj", - ["re:.*down_proj"], - ), -] - - -@dataclass -class ResolvedMapping: - """ - Dataclass for storing the resolved mappings between an activation layer - and the following weights that must be balanced during smoothing - - :param smooth_name: name of the activation layer - :param smooth_layer: PyTorch module storing the activation layer - :param balance_layers: list of PyTorch modules that smooth_layer feeds into, must be - balanced to offset the smoothing of smooth_layer - :param balance_names: optional list of names of the balance_layers - :param parent: parent module of the balance_layers - :param parent_name: name of the parent module - """ - - smooth_name: str - smooth_layer: Module - balance_layers: List[Module] - balance_names: Optional[List[str]] = None - parent: Optional[Module] = None - parent_name: Optional[str] = None +__all__ = ["AWQModifier"] class AWQModifier(Modifier): @@ -136,7 +79,7 @@ class AWQModifier(Modifier): # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) - mappings: List[AWQMapping] = DEFAULT_AWQ_MAPPINGS + mappings: List[AWQMapping] = AWQ_MAPPING_REGISTRY["Llama"] ignore: List[str] = [] num_calibration_steps: Optional[int] = None group_size: int = 128 diff --git a/src/llmcompressor/modifiers/awq/mappings.py b/src/llmcompressor/modifiers/awq/mappings.py new file mode 100644 index 000000000..9ff615a6b --- /dev/null +++ b/src/llmcompressor/modifiers/awq/mappings.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional + +from torch.nn import Module + +__all__ = ["AWQMapping", "AWQ_MAPPING_REGISTRY"] + + +@dataclass +class AWQMapping: + """ + Dataclass storing config of activation mappings to smooth + The output activations of smooth_layer are input activations + into the balance_layers + + `AWQMapping`s are resolved into `ResolvedMapping`s, which + retain pointers to the actual `torch.nn.Module`s and additional + metadata at runtime + """ + + smooth_layer: str + balance_layers: list[str] + + +AWQ_MAPPING_REGISTRY: Dict[str, list[AWQMapping]] = { + "Llama": [ + AWQMapping( + "re:.*input_layernorm", + ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], + ), + # TODO this should only be added if v_proj/o_proj shapes match up + # should we check during validation and skip if this is not the case? 
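One way the check this TODO asks about could look, as a hedged sketch (the helper name and the skip-instead-of-raise behavior are assumptions, not part of this patch): the `v_proj` to `o_proj` mapping is only well-defined when the output channels of `v_proj` correspond one-to-one with the input channels of `o_proj`, which grouped-query-attention models violate because their KV projections are narrower.

```python
import torch.nn as nn


def _v_proj_o_proj_compatible(v_proj: nn.Linear, o_proj: nn.Linear) -> bool:
    # scaling v_proj outputs and folding the inverse scale into o_proj inputs
    # requires a channel-for-channel correspondence between the two layers
    return v_proj.out_features == o_proj.in_features


# GQA-style shapes (fewer KV heads than attention heads) should skip the mapping
v_proj = nn.Linear(4096, 1024)  # num_kv_heads * head_dim
o_proj = nn.Linear(4096, 4096)  # num_heads * head_dim -> hidden_size
assert not _v_proj_o_proj_compatible(v_proj, o_proj)
```

Wired into `_set_resolved_mappings`, such a check could log and skip the pair rather than failing partway through smoothing.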
+ AWQMapping("re:.*v_proj", ["re:.*o_proj"]), + AWQMapping( + "re:.*post_attention_layernorm", + ["re:.*gate_proj", "re:.*up_proj"], + ), + AWQMapping( + "re:.*up_proj", + ["re:.*down_proj"], + ), + ], + "Qwen": [ + # TODO add Qwen mappings + ], +} + + +@dataclass +class ResolvedMapping: + """ + Dataclass for storing the resolved mappings between an activation layer + and the following weights that must be balanced during smoothing + + :param smooth_name: name of the activation layer + :param smooth_layer: PyTorch module storing the activation layer + :param balance_layers: list of PyTorch modules that smooth_layer feeds into, must be + balanced to offset the smoothing of smooth_layer + :param balance_names: optional list of names of the balance_layers + :param parent: parent module of the balance_layers + :param parent_name: name of the parent module + """ + + smooth_name: str + smooth_layer: Module + balance_layers: List[Module] + balance_names: Optional[List[str]] = None + parent: Optional[Module] = None + parent_name: Optional[str] = None From 2519643c316aa89e6cdebecc9d3a24ffb0534161 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 19 Mar 2025 17:02:24 +0000 Subject: [PATCH 21/40] remove empty_cache calls Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 157f30307..48b04a380 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -243,8 +243,6 @@ def _concat_collected_activations(self): name = mapping.smooth_name self._scales[name] = torch.cat(self._scales[name], dim=0) - torch.cuda.empty_cache() - @torch.no_grad() def _apply_smoothing(self, model: Module): """ @@ -355,9 +353,6 @@ def smooth(module): smooth(layer) smooth(smooth_layer) - # clear out allocated smoothing scales - torch.cuda.empty_cache() - def _compute_best_scale( self, x: torch.Tensor, From 3b9b8139fadb540834054b0d76951edfb854b82a Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 19 Mar 2025 22:09:34 +0000 Subject: [PATCH 22/40] resolve attention module forward missing attention_mask input Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 48b04a380..302494aa8 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -576,11 +576,23 @@ def _sanitize_kwargs(self, inputs_kwargs, module): module (`torch.nn.Module`): Target module to quantize. """ - module_signature = inspect.signature(module.forward).parameters + params = inspect.signature(module.forward).parameters sanitized_kwargs = {} for k, v in inputs_kwargs.items(): - if k in module_signature and k != "use_cache": + if k in params and k != "use_cache": sanitized_kwargs[k] = v + # In case forward pass has optional dependencies that don't default to None. 
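# (they are filled with None below so the calibration forward pass does not raise a missing-argument TypeError)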
+ # This is the case for `LlamaAttention.forward` which has input + # `attention_mask: Optional[torch.Tensor],` (with no `= None` default) + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L269 + for k, v in params.items(): + if ( + getattr(v.annotation, "_name", "") == "Optional" + and k not in sanitized_kwargs + and k != "use_cache" + ): + sanitized_kwargs[k] = None + return sanitized_kwargs From 698b0578709633e311cf89ee3310b6b8dc19e2bc Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 20 Mar 2025 14:14:41 +0000 Subject: [PATCH 23/40] improve order of check for optional kwargs setting to None Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 302494aa8..d2713ff89 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -587,9 +587,9 @@ def _sanitize_kwargs(self, inputs_kwargs, module): # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L269 for k, v in params.items(): if ( - getattr(v.annotation, "_name", "") == "Optional" - and k not in sanitized_kwargs + k not in sanitized_kwargs and k != "use_cache" + and getattr(v.annotation, "_name", "") == "Optional" ): sanitized_kwargs[k] = None From 7b9d85e79aed2dc417e0ece0b389768d111e9665 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 21 Mar 2025 20:43:06 +0000 Subject: [PATCH 24/40] run awq one shot example Signed-off-by: Brian Dellabetta --- examples/awq/awq_one_shot.py | 167 +++++++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 examples/awq/awq_one_shot.py diff --git a/examples/awq/awq_one_shot.py b/examples/awq/awq_one_shot.py new file mode 100644 index 000000000..958ae5ec7 --- /dev/null +++ b/examples/awq/awq_one_shot.py @@ -0,0 +1,167 @@ +# local/awq/AWQ/scripts/awq_one_shot.py + +from transformers import AutoTokenizer, AutoModelForCausalLM +import lm_eval + +# MODEL_ID= "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +MODEL_ID = "meta-llama/Llama-2-7b-hf" +# MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" + +# TODO add Qwen mappings +# MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" + +DATASET_ID = "mit-han-lab/pile-val-backup" +DATASET_SPLIT = "validation" +NUM_CALIBRATION_SAMPLES = 256 +MAX_SEQUENCE_LENGTH = 512 + +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) + +### +### 1) LLMCOMPRESSOR quantize +### + + +def run_llmc_awq() -> AutoModelForCausalLM: + OUTPUT_DIR = MODEL_ID.split("/")[-1] + f"-llmc-awq-{NUM_CALIBRATION_SAMPLES}" + from llmcompressor.modifiers.awq import AWQModifier + from llmcompressor.modifiers.smoothquant import SmoothQuantModifier + from llmcompressor.modifiers.quantization import QuantizationModifier + from llmcompressor import oneshot + from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + QuantizationStrategy, + QuantizationType, + ) + + recipe = [ + AWQModifier(bits=4, apply_clip=False, symmetric=True), + QuantizationModifier( + ignore=["lm_head"], + config_groups={ + "group_0": QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.INT, + dynamic=False, + symmetric=True, + # strategy=QuantizationStrategy.CHANNEL, + strategy=QuantizationStrategy.GROUP, + group_size=128, + ), + ) + }, + ), + ] + + model = 
AutoModelForCausalLM.from_pretrained( + MODEL_ID, device_map="auto", torch_dtype="auto" + ) + + def get_calib_dataset_manual(tokenizer): + + from datasets import load_dataset + + ds = load_dataset( + DATASET_ID, + split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*100}]", + ) + + def preprocess(example): + return { + "input_ids": tokenizer.encode(example["text"].strip())[ + :MAX_SEQUENCE_LENGTH + ] + } + + ds = ( + ds.shuffle(seed=42) + .map(preprocess, remove_columns=ds.column_names) + .filter(lambda example: len(example["input_ids"]) >= MAX_SEQUENCE_LENGTH) + .select(range(NUM_CALIBRATION_SAMPLES)) + ) + + return ds + + oneshot( + model=model, + # dataset = get_calib_dataset(tokenizer=tokenizer), + dataset=get_calib_dataset_manual(tokenizer=tokenizer), + recipe=recipe, + # save_compressed=True, + output_dir=OUTPUT_DIR, + # overwrite_output_dir=True, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + ) + + print("Done! model saved to", OUTPUT_DIR) + + return model, OUTPUT_DIR + + +### +### 2) AUTOAWQ quantize +### + + +def run_auto_awq() -> AutoModelForCausalLM: + OUTPUT_DIR = ( + MODEL_ID.split("/")[-1] + f"-auto-awq-{NUM_CALIBRATION_SAMPLES}-quant-only" + ) + from awq import AutoAWQForCausalLM + + # Load model + model = AutoAWQForCausalLM.from_pretrained(MODEL_ID, device_map="cuda:0") + + # Quantize + model.quantize( + tokenizer, + apply_clip=False, + quant_config={ + "zero_point": True, + "q_group_size": 128, + "w_bit": 4, + "version": "GEMM", + }, + ) + model = model.model.to("cuda:0") + + # Save quantized model + # model.save_quantized(OUTPUT_DIR) + # model=AutoAWQForCausalLM.from_pretrained(OUTPUT_DIR, device_map="cuda:0").model + + return model, OUTPUT_DIR + + +### +### EVAL +### + +import os + +# lm_eval --model vllm is failing for me if using V1 +os.environ["VLLM_USE_V1"] = "0" + +# print("RUNNING AUTOAWQ") +# model, OUTPUT_DIR = run_auto_awq() + +print("RUNNING LLMCAWQ") +model, OUTPUT_DIR = run_llmc_awq() + +results = lm_eval.simple_evaluate( + model="vllm", + model_args={ + # "pretrained": MODEL_ID, + "pretrained": OUTPUT_DIR, + "add_bos_token": True, + "dtype": "bfloat16", + "gpu_memory_utilization": 0.5, + }, + tasks=["wikitext", "gsm8k"], + num_fewshot=5, + batch_size=8, +) +print("DONE", results["results"]) From ab962ce0b386ed2b668bfb6e03d08bc0230732c6 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 26 Mar 2025 21:52:45 +0000 Subject: [PATCH 25/40] clean up awq_one_shot example Signed-off-by: Brian Dellabetta --- examples/awq/awq_one_shot.py | 63 ++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 36 deletions(-) diff --git a/examples/awq/awq_one_shot.py b/examples/awq/awq_one_shot.py index 958ae5ec7..ffdf80164 100644 --- a/examples/awq/awq_one_shot.py +++ b/examples/awq/awq_one_shot.py @@ -1,15 +1,12 @@ -# local/awq/AWQ/scripts/awq_one_shot.py +import os -from transformers import AutoTokenizer, AutoModelForCausalLM import lm_eval +from transformers import AutoModelForCausalLM, AutoTokenizer # MODEL_ID= "TinyLlama/TinyLlama-1.1B-Chat-v1.0" MODEL_ID = "meta-llama/Llama-2-7b-hf" # MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" -# TODO add Qwen mappings -# MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" - DATASET_ID = "mit-han-lab/pile-val-backup" DATASET_SPLIT = "validation" NUM_CALIBRATION_SAMPLES = 256 @@ -17,17 +14,13 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) -### -### 1) LLMCOMPRESSOR quantize -### +# +# 1) LLMCOMPRESSOR quantize +# def run_llmc_awq() -> 
AutoModelForCausalLM: OUTPUT_DIR = MODEL_ID.split("/")[-1] + f"-llmc-awq-{NUM_CALIBRATION_SAMPLES}" - from llmcompressor.modifiers.awq import AWQModifier - from llmcompressor.modifiers.smoothquant import SmoothQuantModifier - from llmcompressor.modifiers.quantization import QuantizationModifier - from llmcompressor import oneshot from compressed_tensors.quantization import ( QuantizationArgs, QuantizationScheme, @@ -35,8 +28,12 @@ def run_llmc_awq() -> AutoModelForCausalLM: QuantizationType, ) + from llmcompressor import oneshot + from llmcompressor.modifiers.awq import AWQModifier + from llmcompressor.modifiers.quantization import QuantizationModifier + recipe = [ - AWQModifier(bits=4, apply_clip=False, symmetric=True), + AWQModifier(bits=4, apply_clip=False, symmetric=False), QuantizationModifier( ignore=["lm_head"], config_groups={ @@ -46,8 +43,7 @@ def run_llmc_awq() -> AutoModelForCausalLM: num_bits=4, type=QuantizationType.INT, dynamic=False, - symmetric=True, - # strategy=QuantizationStrategy.CHANNEL, + symmetric=False, strategy=QuantizationStrategy.GROUP, group_size=128, ), @@ -60,8 +56,7 @@ def run_llmc_awq() -> AutoModelForCausalLM: MODEL_ID, device_map="auto", torch_dtype="auto" ) - def get_calib_dataset_manual(tokenizer): - + def get_calib_dataset(tokenizer): from datasets import load_dataset ds = load_dataset( @@ -87,12 +82,9 @@ def preprocess(example): oneshot( model=model, - # dataset = get_calib_dataset(tokenizer=tokenizer), - dataset=get_calib_dataset_manual(tokenizer=tokenizer), + dataset=get_calib_dataset(tokenizer=tokenizer), recipe=recipe, - # save_compressed=True, output_dir=OUTPUT_DIR, - # overwrite_output_dir=True, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) @@ -102,9 +94,9 @@ def preprocess(example): return model, OUTPUT_DIR -### -### 2) AUTOAWQ quantize -### +# +# 2) AUTOAWQ quantize +# def run_auto_awq() -> AutoModelForCausalLM: @@ -130,31 +122,30 @@ def run_auto_awq() -> AutoModelForCausalLM: model = model.model.to("cuda:0") # Save quantized model - # model.save_quantized(OUTPUT_DIR) - # model=AutoAWQForCausalLM.from_pretrained(OUTPUT_DIR, device_map="cuda:0").model + model.save_quantized(OUTPUT_DIR) return model, OUTPUT_DIR -### -### EVAL -### - -import os - -# lm_eval --model vllm is failing for me if using V1 -os.environ["VLLM_USE_V1"] = "0" +# +# RUN +# # print("RUNNING AUTOAWQ") # model, OUTPUT_DIR = run_auto_awq() - print("RUNNING LLMCAWQ") model, OUTPUT_DIR = run_llmc_awq() +# +# EVAL +# + +# NOTE: lm_eval --model vllm is failing with vllm==0.8.1 if using V1 +os.environ["VLLM_USE_V1"] = "0" + results = lm_eval.simple_evaluate( model="vllm", model_args={ - # "pretrained": MODEL_ID, "pretrained": OUTPUT_DIR, "add_bos_token": True, "dtype": "bfloat16", From da16def93a5a5aa327461edd4f5e18a84630bda4 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 31 Mar 2025 16:48:08 +0000 Subject: [PATCH 26/40] rename bits to num_bits Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index d2713ff89..af7d02799 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -84,7 +84,7 @@ class AWQModifier(Modifier): num_calibration_steps: Optional[int] = None group_size: int = 128 max_chunk_memory: int = 1024 * 1024 * 1024 - bits: int = 4 + num_bits: int = 4 symmetric: bool = False duo_scaling: bool = True @@ 
-411,7 +411,7 @@ def _compute_best_scale( _pseudo_quantize_tensor( w=fc.weight.data, symmetric=self.symmetric, - bit_width=self.bits, + bit_width=self.num_bits, group_size=self.group_size, )[0] / _scalesview, From c7e274f0627dcb76844bba5bc143c62a4f048626 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 1 Apr 2025 20:18:55 +0000 Subject: [PATCH 27/40] added TODOs Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 1 + src/llmcompressor/modifiers/awq/mappings.py | 9 ++++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index af7d02799..070748620 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -25,6 +25,7 @@ __all__ = ["AWQModifier"] +# TODO (Brian INFERENG-531) Add support for offloaded models class AWQModifier(Modifier): """ Implements the AWQ (Activation-Weighted Quantization) algorithm, diff --git a/src/llmcompressor/modifiers/awq/mappings.py b/src/llmcompressor/modifiers/awq/mappings.py index 9ff615a6b..022707019 100644 --- a/src/llmcompressor/modifiers/awq/mappings.py +++ b/src/llmcompressor/modifiers/awq/mappings.py @@ -28,8 +28,8 @@ class AWQMapping: "re:.*input_layernorm", ["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], ), - # TODO this should only be added if v_proj/o_proj shapes match up - # should we check during validation and skip if this is not the case? + # TODO (Brian INFERENG-530) when resolving, only add + # if v_proj/o_proj shapes match up AWQMapping("re:.*v_proj", ["re:.*o_proj"]), AWQMapping( "re:.*post_attention_layernorm", @@ -40,9 +40,8 @@ class AWQMapping: ["re:.*down_proj"], ), ], - "Qwen": [ - # TODO add Qwen mappings - ], + # TODO (Brian INFERENG-529) Add Qwen mappings + # "Qwen": [ ], } From 90c266c3a054a0a08be7ac7a56648da78f06cd85 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 10 Apr 2025 21:11:23 +0000 Subject: [PATCH 28/40] update example file Signed-off-by: Brian Dellabetta --- examples/awq/awq_one_shot.py | 211 ++++++++++++++++------------------- 1 file changed, 95 insertions(+), 116 deletions(-) diff --git a/examples/awq/awq_one_shot.py b/examples/awq/awq_one_shot.py index ffdf80164..a37d5e494 100644 --- a/examples/awq/awq_one_shot.py +++ b/examples/awq/awq_one_shot.py @@ -1,148 +1,127 @@ -import os - import lm_eval +from lm_eval.utils import make_table from transformers import AutoModelForCausalLM, AutoTokenizer +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationScheme, + QuantizationStrategy, + QuantizationType, +) +from llmcompressor import oneshot +from llmcompressor.modifiers.awq import AWQModifier +from llmcompressor.modifiers.quantization import QuantizationModifier -# MODEL_ID= "TinyLlama/TinyLlama-1.1B-Chat-v1.0" -MODEL_ID = "meta-llama/Llama-2-7b-hf" -# MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" +# This example demonstrates how to: +# 1) Run the `llm-compressor` implementation of AWQ +# 2) Compare it against the original AutoAWQ implementation available +# at https://github.com/casper-hansen/AutoAWQ +# 3) Evaluate the compressed model with the lm_eval framework +MODEL_ID = "meta-llama/Llama-2-7b-hf" DATASET_ID = "mit-han-lab/pile-val-backup" DATASET_SPLIT = "validation" NUM_CALIBRATION_SAMPLES = 256 MAX_SEQUENCE_LENGTH = 512 - -tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) +OUTPUT_DIR = MODEL_ID.split("/")[-1] + f"-awq-{NUM_CALIBRATION_SAMPLES}" # -# 1) LLMCOMPRESSOR quantize +# 
1) Run LLM Compressor AWQ implementation # +recipe = [ + AWQModifier(bits=4, apply_clip=False, symmetric=False), + QuantizationModifier( + ignore=["lm_head"], + config_groups={ + "group_0": QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.INT, + dynamic=False, + symmetric=False, + strategy=QuantizationStrategy.GROUP, + group_size=128, + ), + ) + }, + ), +] -def run_llmc_awq() -> AutoModelForCausalLM: - OUTPUT_DIR = MODEL_ID.split("/")[-1] + f"-llmc-awq-{NUM_CALIBRATION_SAMPLES}" - from compressed_tensors.quantization import ( - QuantizationArgs, - QuantizationScheme, - QuantizationStrategy, - QuantizationType, - ) - - from llmcompressor import oneshot - from llmcompressor.modifiers.awq import AWQModifier - from llmcompressor.modifiers.quantization import QuantizationModifier - - recipe = [ - AWQModifier(bits=4, apply_clip=False, symmetric=False), - QuantizationModifier( - ignore=["lm_head"], - config_groups={ - "group_0": QuantizationScheme( - targets=["Linear"], - weights=QuantizationArgs( - num_bits=4, - type=QuantizationType.INT, - dynamic=False, - symmetric=False, - strategy=QuantizationStrategy.GROUP, - group_size=128, - ), - ) - }, - ), - ] - - model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, device_map="auto", torch_dtype="auto" - ) +model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, device_map="auto", torch_dtype="auto" +) +tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) - def get_calib_dataset(tokenizer): - from datasets import load_dataset - - ds = load_dataset( - DATASET_ID, - split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*100}]", - ) - - def preprocess(example): - return { - "input_ids": tokenizer.encode(example["text"].strip())[ - :MAX_SEQUENCE_LENGTH - ] - } - - ds = ( - ds.shuffle(seed=42) - .map(preprocess, remove_columns=ds.column_names) - .filter(lambda example: len(example["input_ids"]) >= MAX_SEQUENCE_LENGTH) - .select(range(NUM_CALIBRATION_SAMPLES)) - ) - - return ds - - oneshot( - model=model, - dataset=get_calib_dataset(tokenizer=tokenizer), - recipe=recipe, - output_dir=OUTPUT_DIR, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - ) - print("Done! 
model saved to", OUTPUT_DIR) +def get_calib_dataset(tokenizer): + from datasets import load_dataset - return model, OUTPUT_DIR + ds = load_dataset( + DATASET_ID, + split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES*100}]", + ) + def preprocess(example): + return { + "input_ids": tokenizer.encode(example["text"].strip())[:MAX_SEQUENCE_LENGTH] + } -# -# 2) AUTOAWQ quantize -# + ds = ( + ds.shuffle(seed=42) + .map(preprocess, remove_columns=ds.column_names) + .filter(lambda example: len(example["input_ids"]) >= MAX_SEQUENCE_LENGTH) + .select(range(NUM_CALIBRATION_SAMPLES)) + ) + return ds -def run_auto_awq() -> AutoModelForCausalLM: - OUTPUT_DIR = ( - MODEL_ID.split("/")[-1] + f"-auto-awq-{NUM_CALIBRATION_SAMPLES}-quant-only" - ) - from awq import AutoAWQForCausalLM - - # Load model - model = AutoAWQForCausalLM.from_pretrained(MODEL_ID, device_map="cuda:0") - - # Quantize - model.quantize( - tokenizer, - apply_clip=False, - quant_config={ - "zero_point": True, - "q_group_size": 128, - "w_bit": 4, - "version": "GEMM", - }, - ) - model = model.model.to("cuda:0") - # Save quantized model - model.save_quantized(OUTPUT_DIR) +oneshot( + model=model, + dataset=get_calib_dataset(tokenizer=tokenizer), + recipe=recipe, + output_dir=OUTPUT_DIR, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, +) - return model, OUTPUT_DIR +print("Done! model saved to", OUTPUT_DIR) # -# RUN +# 2) Or run original AutoAWQ implementation (requires `pip install autoawq`) # +# OUTPUT_DIR = ( +# MODEL_ID.split("/")[-1] + f"-auto-awq-{NUM_CALIBRATION_SAMPLES}-quant-only" +# ) +# from awq import AutoAWQForCausalLM + +# # Load model +# model = AutoAWQForCausalLM.from_pretrained(MODEL_ID, device_map="cuda:0") +# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) + +# # Quantize +# model.quantize( +# tokenizer, +# apply_clip=False, +# quant_config={ +# "zero_point": True, +# "q_group_size": 128, +# "w_bit": 4, +# "version": "GEMM", +# }, +# ) +# model = model.model.to("cuda:0") + +# # Save quantized model +# model.save_quantized(OUTPUT_DIR) -# print("RUNNING AUTOAWQ") -# model, OUTPUT_DIR = run_auto_awq() -print("RUNNING LLMCAWQ") -model, OUTPUT_DIR = run_llmc_awq() # -# EVAL +# 3) Evaluate model on wikitext perplexity # -# NOTE: lm_eval --model vllm is failing with vllm==0.8.1 if using V1 -os.environ["VLLM_USE_V1"] = "0" - results = lm_eval.simple_evaluate( model="vllm", model_args={ @@ -151,8 +130,8 @@ def run_auto_awq() -> AutoModelForCausalLM: "dtype": "bfloat16", "gpu_memory_utilization": 0.5, }, - tasks=["wikitext", "gsm8k"], + tasks=["wikitext"], num_fewshot=5, batch_size=8, ) -print("DONE", results["results"]) +print(make_table(results)) From 7306de76f3994f3f6e71ef4fa924c6fc13eb58ee Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 10 Apr 2025 21:51:45 +0000 Subject: [PATCH 29/40] revise get_parent_by_name test Signed-off-by: Brian Dellabetta --- .../utils/pytorch/test_module.py | 63 +++++++++++-------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/tests/llmcompressor/utils/pytorch/test_module.py b/tests/llmcompressor/utils/pytorch/test_module.py index 4600377fa..b91906cd4 100644 --- a/tests/llmcompressor/utils/pytorch/test_module.py +++ b/tests/llmcompressor/utils/pytorch/test_module.py @@ -1,31 +1,44 @@ -import unittest +import pytest import torch.nn as nn from llmcompressor.utils.pytorch import get_parent_by_name -class TestGetParentByName(unittest.TestCase): - def setUp(self): - self.model = nn.Sequential( - nn.Linear(10, 20), 
nn.ReLU(), nn.Linear(20, 10), nn.Softmax(dim=1) - ) - - def test_get_parent_by_name(self): - # Test getting the parent of a non-existent layer - with self.assertRaises(ValueError): - get_parent_by_name("non_existent_layer", self.model) - - # Test getting the parent of the first layer - name, parent = get_parent_by_name("0", self.model) - self.assertEqual(parent, self.model) - - # Test getting the parent of a nested layer - nested_model = nn.Sequential( - nn.Linear(10, 20), - nn.Sequential(nn.ReLU(), nn.Linear(20, 10)), - nn.Softmax(dim=1), - ) - name, parent = get_parent_by_name("1.1", nested_model) - self.assertEqual(parent, nested_model[1]) - self.assertEqual(name, "1") +@pytest.fixture +def example_nested_module() -> str: + return nn.Sequential( + nn.Linear(10, 20), + nn.Sequential(nn.ReLU(), nn.Linear(20, 10)), + nn.Sequential(nn.SiLU(), nn.Linear(20, 10)), + nn.Softmax(dim=1), + ) + + +@pytest.mark.unit +def test_get_parent_by_name(example_nested_module): + + # Test getting the parent of the first layer + name, parent = get_parent_by_name("0", example_nested_module) + assert parent == example_nested_module + + # Test getting the parent of a nested layer + name, parent = get_parent_by_name("1.0", example_nested_module) + assert parent == example_nested_module[1] + assert name == "1" + + name, parent = get_parent_by_name("1.1", example_nested_module) + assert parent == example_nested_module[1] + assert name == "1" + + name, parent = get_parent_by_name("2.0", example_nested_module) + assert parent == example_nested_module[2] + assert name == "2" + + name, parent = get_parent_by_name("2.1", example_nested_module) + assert parent == example_nested_module[2] + assert name == "2" + + # Test getting the parent of a non-existent layer + with pytest.raises(ValueError): + get_parent_by_name("non_existent_layer", example_nested_module) From 99cf589afdc01a0760109032a62157a99ac3a07b Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 10 Apr 2025 22:23:19 +0000 Subject: [PATCH 30/40] revert smoothquant changes Signed-off-by: Brian Dellabetta --- .../modifiers/smoothquant/base.py | 82 ++++++++++--------- 1 file changed, 45 insertions(+), 37 deletions(-) diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index 037fe1219..aa3317198 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -2,9 +2,8 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import torch -from accelerate.utils import align_module_device +from compressed_tensors.utils.offload import is_module_offloaded from loguru import logger -from pydantic import ConfigDict from torch.nn import Module from llmcompressor.core import State @@ -100,17 +99,14 @@ class SmoothQuantModifier(Modifier): to use the default tensor_module_forward """ - # Allow arbitrary types because AWQMapping has field of type torch.nn.Module - model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True) - smoothing_strength: float = 0.5 mappings: Optional[List[Union[Tuple, List]]] = None ignore: Optional[List[str]] = None num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None - _resolved_mappings: Optional[List[SmoothQuantMapping]] = None - _scales: Optional[Dict] = None + resolved_mappings_: Optional[List] = None + scales_: Optional[Dict] = None def on_initialize(self, state: State, **kwargs) -> bool: """ @@ -132,8 +128,8 @@ def on_initialize(self, state: State, **kwargs) -> bool: 
self.ignore = [] if not self.ignore else self.ignore self.mappings = self._infer_mappings_from_model(state.model) - self._resolved_mappings = self._resolve_mappings(state.model) - self._scales = {} + self.resolved_mappings_ = self._resolve_mappings(state.model) + self.scales_ = {} calibration_dataloader = state.data.calib @@ -150,10 +146,10 @@ def on_finalize(self, state: State, **kwargs) -> bool: :param state: unused :return: True """ - if self._scales is not None: - self._scales.clear() - if self._resolved_mappings is not None: - self._resolved_mappings.clear() + if self.scales_ is not None: + self.scales_.clear() + if self.resolved_mappings_ is not None: + self.resolved_mappings_.clear() return True @@ -170,7 +166,7 @@ def _infer_mappings_from_model( ) @handle_mapping_resolution_errors - def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]: + def _resolve_mappings(self, model: Module) -> List: """ Transforms the list of activations to smooth and their corresponding weights into SmoothQuantMapping objects, resolving regular expressions. @@ -219,21 +215,21 @@ def hook_fn(module, inp, out): latest_mins = torch.min(out, dim=0)[0] latest_maxes = torch.max(out, dim=0)[0] - if layer_name in self._scales: - self._scales[layer_name].min_channel_vals = torch.minimum( - self._scales[layer_name].min_channel_vals, latest_mins + if layer_name in self.scales_: + self.scales_[layer_name].min_channel_vals = torch.minimum( + self.scales_[layer_name].min_channel_vals, latest_mins ) - self._scales[layer_name].max_channel_vals = torch.maximum( - self._scales[layer_name].max_channel_vals, latest_maxes + self.scales_[layer_name].max_channel_vals = torch.maximum( + self.scales_[layer_name].max_channel_vals, latest_maxes ) else: - self._scales[layer_name] = SmoothQuantScale( + self.scales_[layer_name] = SmoothQuantScale( min_channel_vals=latest_mins, max_channel_vals=latest_maxes ) return hook_fn - for mapping in self._resolved_mappings: + for mapping in self.resolved_mappings_: name = mapping.smooth_name layer = mapping.smooth_layer self.register_hook(layer, create_hook_fn(name), "forward") @@ -278,10 +274,10 @@ def _apply_smoothing(self, model: Module): This modifies the weights of the model in-place. 
""" logger.info("Smoothing activation scales...") - for mapping in self._resolved_mappings: + for mapping in self.resolved_mappings_: activation_scales = ( # get dynamic range for each activation channel - self._scales[mapping.smooth_name].max_channel_vals - - self._scales[mapping.smooth_name].min_channel_vals + self.scales_[mapping.smooth_name].max_channel_vals + - self.scales_[mapping.smooth_name].min_channel_vals ) smooth_layer = mapping.smooth_layer balance_layers = mapping.balance_layers @@ -293,16 +289,22 @@ def _apply_smoothing(self, model: Module): @torch.no_grad() def smooth(module): - with align_module_device(module): - if module in balance_layers: - module.weight.mul_(scales.view(1, -1)) - elif module == smooth_layer: - if module.weight.ndim == 1: - module.weight.div_(scales) - else: - module.weight.div_(scales.view(-1, 1)) - if hasattr(module, "bias") and module.bias is not None: - module.bias.div_(scales) + offloaded = is_module_offloaded(module) + if offloaded: + module._hf_hook.pre_forward(module) + + if module in balance_layers: + module.weight.mul_(scales.view(1, -1)) + elif module == smooth_layer: + if module.weight.ndim == 1: + module.weight.div_(scales) + else: + module.weight.div_(scales.view(-1, 1)) + if hasattr(module, "bias") and module.bias is not None: + module.bias.div_(scales) + + if offloaded: + module._hf_hook.post_forward(module, None) parent = get_fsdp_parent(mapping.smooth_name, model) if parent is not None: @@ -327,9 +329,15 @@ def _calculate_smoothing_scales( # get the channel-wise dynamic range for each layer to be balanced weight_scales = [] for layer in balance_layers: - with align_module_device(layer): - scale = layer.weight.abs().max(dim=0, keepdim=True)[0] - weight_scales.append(scale) + offloaded = is_module_offloaded(layer) + if offloaded: + layer._hf_hook.pre_forward(layer) + + scale = layer.weight.abs().max(dim=0, keepdim=True)[0] + weight_scales.append(scale) + + if offloaded: + layer._hf_hook.post_forward(layer, None) weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0] From 64690a8f4dc943ce99ed7d5d0854c3265ce76175 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 10 Apr 2025 22:34:12 +0000 Subject: [PATCH 31/40] revert smoothquant changes Signed-off-by: Brian Dellabetta --- .../modifiers/logarithmic_equalization/test_pytorch.py | 6 +++--- .../pytorch/modifiers/smoothquant/test_pytorch.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py index e84f66e83..d485c0637 100644 --- a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py @@ -21,11 +21,11 @@ def test_successful_map(self): modifier = LogarithmicEqualizationModifier(mappings=mappings) modifier.ignore = [] - modifier._resolved_mappings = modifier._resolve_mappings(self.state.model) + modifier.resolved_mappings_ = modifier._resolve_mappings(self.state.model) - self.assertEqual(len(modifier._resolved_mappings), len(mappings)) + self.assertEqual(len(modifier.resolved_mappings_), len(mappings)) - mapping = modifier._resolved_mappings[0] + mapping = modifier.resolved_mappings_[0] self.assertEqual(mapping.smooth_name, mappings[0][1]) self.assertIsInstance(mapping.smooth_layer, Linear) self.assertIsInstance(mapping.balance_layers[0], Linear) diff --git 
a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py index cbb60f030..7977c4546 100644 --- a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py @@ -19,11 +19,11 @@ def test_successful_map(self): modifier = SmoothQuantModifier(mappings=mappings) modifier.ignore = [] - modifier._resolved_mappings = modifier._resolve_mappings(self.state.model) + modifier.resolved_mappings_ = modifier._resolve_mappings(self.state.model) - self.assertEqual(len(modifier._resolved_mappings), len(mappings)) + self.assertEqual(len(modifier.resolved_mappings_), len(mappings)) - mapping = modifier._resolved_mappings[0] + mapping = modifier.resolved_mappings_[0] self.assertEqual(mapping.smooth_name, mappings[0][1]) self.assertIsInstance(mapping.smooth_layer, Linear) self.assertIsInstance(mapping.balance_layers[0], Linear) From cb6f840fbd7ab4d5abe1c60674263bd7b73b357c Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 10 Apr 2025 22:42:48 +0000 Subject: [PATCH 32/40] sanitize_kwargs cleanup Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 65 +++++++++++++------------ 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 070748620..f303fa546 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -305,7 +305,7 @@ def _apply_smoothing(self, model: Module): fp16_output = self._forward_input_with_kwargs( module=module2inspect, inputs=inp, - input_kwargs=self._sanitize_kwargs(self._module_kwargs, module2inspect), + input_kwargs=_sanitize_kwargs(self._module_kwargs, module2inspect), ) fp16_output = fp16_output.clip( torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max @@ -558,43 +558,44 @@ def _forward_input_with_kwargs( :return: the first output tensor from the forward pass """ kwargs = input_kwargs or self._module_kwargs - kwargs = self._sanitize_kwargs(kwargs, module) + kwargs = _sanitize_kwargs(kwargs, module) return tensor_forward_with_input_args( module=module, inputs=inputs, input_kwargs=kwargs, )[0] - def _sanitize_kwargs(self, inputs_kwargs, module): - """ - Remove the arguments that are not supported in the module's - forward pass to avoid breaking behaviour between different versions - of transformers. - - Args: - inputs_kwargs (`dict`): - The input dictionary to pass to the model layer - module (`torch.nn.Module`): - Target module to quantize. - """ - params = inspect.signature(module.forward).parameters - sanitized_kwargs = {} - for k, v in inputs_kwargs.items(): - if k in params and k != "use_cache": - sanitized_kwargs[k] = v - # In case forward pass has optional dependencies that don't default to None. - # This is the case for `LlamaAttention.forward` which has input - # `attention_mask: Optional[torch.Tensor],` (with no `= None` default) - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L269 - for k, v in params.items(): - if ( - k not in sanitized_kwargs - and k != "use_cache" - and getattr(v.annotation, "_name", "") == "Optional" - ): - sanitized_kwargs[k] = None - - return sanitized_kwargs + +def _sanitize_kwargs(inputs_kwargs, module): + """ + Remove the arguments that are not supported in the module's + forward pass to avoid breaking behaviour between different versions + of transformers. 
+ + Args: + inputs_kwargs (`dict`): + The input dictionary to pass to the model layer + module (`torch.nn.Module`): + Target module to quantize. + """ + params = inspect.signature(module.forward).parameters + sanitized_kwargs = {} + for k, v in inputs_kwargs.items(): + if k in params and k != "use_cache": + sanitized_kwargs[k] = v + # In case forward pass has optional dependencies that don't default to None. + # This is the case for `LlamaAttention.forward` which has input + # `attention_mask: Optional[torch.Tensor],` (with no `= None` default) + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L269 + for k, v in params.items(): + if ( + k not in sanitized_kwargs + and k != "use_cache" + and getattr(v.annotation, "_name", "") == "Optional" + ): + sanitized_kwargs[k] = None + + return sanitized_kwargs def _pseudo_quantize_tensor( From 018b255a6d39adc1b94d54eb6e536c44b8fb16bc Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 14 Apr 2025 19:57:20 +0000 Subject: [PATCH 33/40] remove deprecated AWQModifier apply_clip Signed-off-by: Brian Dellabetta --- examples/awq/awq_one_shot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/awq/awq_one_shot.py b/examples/awq/awq_one_shot.py index a37d5e494..e8b380401 100644 --- a/examples/awq/awq_one_shot.py +++ b/examples/awq/awq_one_shot.py @@ -29,7 +29,7 @@ # recipe = [ - AWQModifier(bits=4, apply_clip=False, symmetric=False), + AWQModifier(bits=4, symmetric=False), QuantizationModifier( ignore=["lm_head"], config_groups={ From 39a4745de02e2e97f8a301e333a65fa40abb2556 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 14 Apr 2025 20:17:52 +0000 Subject: [PATCH 34/40] PR revision Signed-off-by: Brian Dellabetta --- .../llmcompressor/modifiers/awq/test_base.py | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 3fff33e23..682c2fc83 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -1,5 +1,3 @@ -import unittest - import pytest from llmcompressor.modifiers.awq import AWQModifier @@ -8,21 +6,15 @@ @pytest.mark.unit -class TestAWQIsRegistered(unittest.TestCase): - def setUp(self): - self.kwargs = {} - setup_modifier_factory() +class test_awq_is_registered: + """Ensure AWQModifier is registered in ModifierFactory""" + + setup_modifier_factory() - def test_awq_is_registered(self): - modifier = ModifierFactory.create( - type_="AWQModifier", - allow_experimental=False, - allow_registered=True, - **self.kwargs, - ) + modifier = ModifierFactory.create( + type_="AWQModifier", + allow_experimental=False, + allow_registered=True, + ) - self.assertIsInstance( - modifier, - AWQModifier, - "AWQModifier not registered", - ) + assert isinstance(modifier, AWQModifier), "AWQModifier not registered" From 58c396880ee83a8f1eec50bb39a213919e2b37e3 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Mon, 14 Apr 2025 22:48:10 +0000 Subject: [PATCH 35/40] add lifecycle to docstring Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index f303fa546..2eb70a15f 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -56,6 +56,21 @@ class 
AWQModifier(Modifier):
         ignore: ["model.decoder.final_layer_norm"]
     ```
 
+    Lifecycle:
+        - on_initialize
+            - resolve mappings
+            - capture input activations to balance layers
+                - register hook to capture inputs and offload to cpu
+                - run calibration dataset through, to capture inputs
+                - clear hooks
+                - concatenate activations across all batches
+            - apply smoothing
+                - find best smoothing scale for each smoothing layer
+                - apply
+                - move to next smoothing layer
+        - on_finalize
+            - clear resolved mappings and captured activations
+
     :param mappings: list activation layers to smooth, and which layers to
         scale the output such that activations are smoothed.
         Each entry of the mapping list should be a list itself, in which the first
@@ -66,15 +81,12 @@ class AWQModifier(Modifier):
     :param ignore: list of layers to ignore, even if they match a regex in mappings.
         It should match the name of layers whose outputs are scaled to achieve
         smoothing (the second entry of the mappings list).
-    :param num_calibration_steps: number of samples to use for calibration, or None to
-        use the whole dataset
     :param group_size: number of weights to group together for scaling
     :param max_chunk_memory: maximum memory to use for each chunk of input activations
     :param bits: number of bits to quantize the weights to
     :param symmetric: whether to use symmetric quantization
     :param duo_scaling: whether to use duo scaling, which uses both input activations
         and weights to determine the scaling factor
-    :param apply_clip: whether to apply clipping to the weights after scaling
     """
 
     # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
@@ -82,7 +94,6 @@ class AWQModifier(Modifier):
 
     mappings: List[AWQMapping] = AWQ_MAPPING_REGISTRY["Llama"]
     ignore: List[str] = []
-    num_calibration_steps: Optional[int] = None
     group_size: int = 128
     max_chunk_memory: int = 1024 * 1024 * 1024
     num_bits: int = 4

From 2980c057177a910aaa2294539aa5a3cb0a19039b Mon Sep 17 00:00:00 2001
From: Brian Dellabetta
Date: Tue, 15 Apr 2025 13:38:07 +0000
Subject: [PATCH 36/40] update docstring

Signed-off-by: Brian Dellabetta
---
 src/llmcompressor/modifiers/awq/base.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 2eb70a15f..dba4c5f78 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -33,11 +33,11 @@ class AWQModifier(Modifier):
     significantly reduces quantization error by protecting only 1%
     of the most salient weight channels.
 
-    Instead of focusing on the weight values directly, AWQ identifies
-    salient channels based on the activation distribution.
-    To further minimize quantization error, the algorithm scales up these
-    salient channels using an equivalent transformation. The scaling factor
-    is determined offline by collecting activation statistics
+    Instead of relying on raw weight values, AWQ identifies important channels by
+    analyzing activation patterns, focusing on the channels in the weight tensor that
+    are most responsive to the input. To reduce quantization error, it scales these
+    channels in a way that preserves the model's original behavior, using scaling
+    factors computed offline from activation statistics.
 
     Because this modifier manipulates the weights of the model, it can only be used in
     in one-shot and not during training.
Activation ranges are determined by running a From f7ece22b0b036289aa0d2c88b9e8f3d12b80ced7 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 15 Apr 2025 13:52:02 +0000 Subject: [PATCH 37/40] remove comment Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index dba4c5f78..b79231784 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -507,7 +507,6 @@ def _set_module_kwargs(self, model, dataloader) -> None: best_device = "cuda" modules[0] = modules[0].to(best_device) - # self.awq_model.move_embed(self.model, best_device) # get input and kwargs to layer 0 # with_kwargs is only supported in PyTorch 2.0 From 6e7468b55c0bd849be66e0a0a9af15d1b448d799 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Wed, 16 Apr 2025 17:05:19 +0000 Subject: [PATCH 38/40] rearrange so it's clear when hooks are removed Signed-off-by: Brian Dellabetta --- src/llmcompressor/modifiers/awq/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index b79231784..d5ae65b64 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -120,6 +120,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: self._set_module_kwargs(state.model, calibration_dataloader) self._setup_scale_hooks() self._calibrate(state.model, calibration_dataloader) + self.remove_hooks() self._concat_collected_activations() self._apply_smoothing(state.model) @@ -240,9 +241,6 @@ def _calibrate(self, model: Module, calibration_dataloader: List): self.num_calibration_steps, ) - # remove the hooks now that we are done calibrating - self.remove_hooks() - def _concat_collected_activations(self): """ Concatenate the collected activation values from each forward pass into a single From 3beba8906d471ac4728d82866d873681126d3834 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Thu, 17 Apr 2025 20:39:18 +0000 Subject: [PATCH 39/40] style fixes Signed-off-by: Brian Dellabetta --- examples/awq/awq_one_shot.py | 5 +++-- src/llmcompressor/utils/pytorch/module.py | 1 + tests/llmcompressor/utils/pytorch/test_module.py | 2 -- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/awq/awq_one_shot.py b/examples/awq/awq_one_shot.py index e8b380401..e66287fb3 100644 --- a/examples/awq/awq_one_shot.py +++ b/examples/awq/awq_one_shot.py @@ -1,12 +1,13 @@ import lm_eval -from lm_eval.utils import make_table -from transformers import AutoModelForCausalLM, AutoTokenizer from compressed_tensors.quantization import ( QuantizationArgs, QuantizationScheme, QuantizationStrategy, QuantizationType, ) +from lm_eval.utils import make_table +from transformers import AutoModelForCausalLM, AutoTokenizer + from llmcompressor import oneshot from llmcompressor.modifiers.awq import AWQModifier from llmcompressor.modifiers.quantization import QuantizationModifier diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py index a70a847b1..1bb3e3f70 100644 --- a/src/llmcompressor/utils/pytorch/module.py +++ b/src/llmcompressor/utils/pytorch/module.py @@ -343,6 +343,7 @@ def get_no_split_params(model: PreTrainedModel) -> Union[str, List[str]]: return no_split_modules + def get_parent_by_name(layer_name: str, model: Module) -> Tuple[str, Module]: """ Get the parent layer of 
a layer by name. diff --git a/tests/llmcompressor/utils/pytorch/test_module.py b/tests/llmcompressor/utils/pytorch/test_module.py index b91906cd4..22763aba9 100644 --- a/tests/llmcompressor/utils/pytorch/test_module.py +++ b/tests/llmcompressor/utils/pytorch/test_module.py @@ -1,5 +1,4 @@ import pytest - import torch.nn as nn from llmcompressor.utils.pytorch import get_parent_by_name @@ -17,7 +16,6 @@ def example_nested_module() -> str: @pytest.mark.unit def test_get_parent_by_name(example_nested_module): - # Test getting the parent of the first layer name, parent = get_parent_by_name("0", example_nested_module) assert parent == example_nested_module From d1d37663b6c33944c50345575ac763449fd55fa9 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Fri, 18 Apr 2025 19:02:41 +0000 Subject: [PATCH 40/40] revisions from codereview Signed-off-by: Brian Dellabetta --- examples/awq/awq_one_shot.py | 42 +++---------------------- src/llmcompressor/modifiers/awq/base.py | 17 +++++----- 2 files changed, 14 insertions(+), 45 deletions(-) diff --git a/examples/awq/awq_one_shot.py b/examples/awq/awq_one_shot.py index e66287fb3..cbe2d77d1 100644 --- a/examples/awq/awq_one_shot.py +++ b/examples/awq/awq_one_shot.py @@ -14,16 +14,14 @@ # This example demonstrates how to: # 1) Run the `llm-compressor` implementation of AWQ -# 2) Compare it against the original AutoAWQ implementation available -# at https://github.com/casper-hansen/AutoAWQ -# 3) Evaluate the compressed model with the lm_eval framework +# 2) Evaluate the compressed model with the lm_eval framework -MODEL_ID = "meta-llama/Llama-2-7b-hf" +MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct" DATASET_ID = "mit-han-lab/pile-val-backup" DATASET_SPLIT = "validation" NUM_CALIBRATION_SAMPLES = 256 MAX_SEQUENCE_LENGTH = 512 -OUTPUT_DIR = MODEL_ID.split("/")[-1] + f"-awq-{NUM_CALIBRATION_SAMPLES}" +OUTPUT_DIR = MODEL_ID.split("/")[-1] + "-awq-asym" # # 1) Run LLM Compressor AWQ implementation @@ -89,38 +87,8 @@ def preprocess(example): print("Done! 
model saved to", OUTPUT_DIR) - -# -# 2) Or run original AutoAWQ implementation (requires `pip install autoawq`) -# -# OUTPUT_DIR = ( -# MODEL_ID.split("/")[-1] + f"-auto-awq-{NUM_CALIBRATION_SAMPLES}-quant-only" -# ) -# from awq import AutoAWQForCausalLM - -# # Load model -# model = AutoAWQForCausalLM.from_pretrained(MODEL_ID, device_map="cuda:0") -# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) - -# # Quantize -# model.quantize( -# tokenizer, -# apply_clip=False, -# quant_config={ -# "zero_point": True, -# "q_group_size": 128, -# "w_bit": 4, -# "version": "GEMM", -# }, -# ) -# model = model.model.to("cuda:0") - -# # Save quantized model -# model.save_quantized(OUTPUT_DIR) - - # -# 3) Evaluate model on wikitext perplexity +# 2) Evaluate model on wikitext perplexity # results = lm_eval.simple_evaluate( @@ -133,6 +101,6 @@ def preprocess(example): }, tasks=["wikitext"], num_fewshot=5, - batch_size=8, + batch_size="auto", ) print(make_table(results)) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index d5ae65b64..f196d9358 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -59,6 +59,7 @@ class AWQModifier(Modifier): Lifecycle: - on_initialize - resolve mappings + - capture kwargs needed for forward passes into modules - capture input activations to balance layers - register hook to capture inputs and offload to cpu - run calibration dataset through, to capture inputs @@ -114,14 +115,16 @@ def on_initialize(self, state: State, **kwargs) -> bool: self._set_resolved_mappings(state.model) - calibration_dataloader = state.data.calib + with calibration_forward_context(state.model): + self._set_module_kwargs(state.model, state.data.calib) + + self._setup_scale_hooks() + with calibration_forward_context(state.model): + self._calibrate(state.model, state.data.calib) + self.remove_hooks() + self._concat_collected_activations() with calibration_forward_context(state.model): - self._set_module_kwargs(state.model, calibration_dataloader) - self._setup_scale_hooks() - self._calibrate(state.model, calibration_dataloader) - self.remove_hooks() - self._concat_collected_activations() self._apply_smoothing(state.model) return True @@ -234,11 +237,9 @@ def _calibrate(self, model: Module, calibration_dataloader: List): " CompressionSession to run the AWQ modifier" ) - # with calibration_forward_context(model): run_calibration_forward( model, calibration_dataloader, - self.num_calibration_steps, ) def _concat_collected_activations(self):