From b9bea3c7bd08480c6d38b5a7f8908cfc268f203d Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 7 Feb 2025 19:13:13 +0000 Subject: [PATCH 01/13] WIP Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/obcq/base.py | 5 +- .../modifiers/obcq/sgpt_mixin.py | 77 +++++++++++-------- .../modifiers/pruning/wanda/base.py | 7 +- 3 files changed, 54 insertions(+), 35 deletions(-) diff --git a/src/llmcompressor/modifiers/obcq/base.py b/src/llmcompressor/modifiers/obcq/base.py index 3941563069..1329f17de8 100644 --- a/src/llmcompressor/modifiers/obcq/base.py +++ b/src/llmcompressor/modifiers/obcq/base.py @@ -8,7 +8,7 @@ update_offload_parameter, ) from loguru import logger -from pydantic import PrivateAttr +from pydantic import Field, PrivateAttr from llmcompressor.core import State from llmcompressor.modifiers import Modifier @@ -85,7 +85,8 @@ class SparseGPTModifier(SparsityModifierMixin, Modifier): # data pipeline arguments sequential_update: Optional[bool] = False # deprecated sequential_targets: Union[str, List[str], None] = None - targets: Union[str, List[str], None] = None # alias sequential_targets + targets: Union[str, List[str]] = ["Linear"] + ignore: List[str] = Field(default_factory=list) # private variables _prune_n: Optional[int] = PrivateAttr(default=None) diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py index c3cf585fc7..721ef3fd4a 100644 --- a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py +++ b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py @@ -1,4 +1,5 @@ import warnings +from abc import abstractmethod from collections import defaultdict from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union @@ -6,7 +7,7 @@ import numpy import torch from loguru import logger -from pydantic import Field, field_validator, model_validator +from pydantic import Field, PrivateAttr, field_validator, model_validator from llmcompressor.core import State from llmcompressor.modifiers import Modifier @@ -19,7 +20,7 @@ from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, - get_prunable_layers, + match_layers_params, ) @@ -34,9 +35,15 @@ class SparsityModifierMixin(HooksMixin): # data pipeline arguments sequential_update: Optional[bool] = False # deprecated sequential_targets: Union[str, List[str], None] = None - targets: Union[str, List[str], None] = None # alias sequential_targets + targets: Union[str, List[str]] = ["Linear"] ignore: List[str] = Field(default_factory=list) + # private variables + _prune_n: Optional[int] = PrivateAttr(default=None) + _prune_m: Optional[int] = PrivateAttr(default=None) + _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) + _module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) + @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: if not value: @@ -62,14 +69,12 @@ def validate_sparsity_profile(cls, value: Optional[str]) -> bool: return value @model_validator(mode="after") - def validate_model_after(model: "Modifier") -> "Modifier": + def validate_model_after(model: "SparsityModifierMixin") -> "Modifier": sparsity = model.sparsity profile = model.sparsity_profile owl_m = model.owl_m owl_lmbda = model.owl_lmbda mask_structure = model.mask_structure - targets = model.targets - sequential_targets = model.sequential_targets if profile == "owl" and ((owl_m is not None) ^ (owl_lmbda is not None)): raise ValueError("Must provide both `owl_m` and `owl_lmbda` or neither") @@ -80,27 +85,31 @@ def validate_model_after(model: "Modifier") -> "Modifier": if owl_m is not None and sparsity is not None: raise ValueError("Cannot provide both sparsity and owl parameters") - if targets is not None: - if sequential_targets is not None: - raise ValueError("Cannot use both `targets` and `sequential_targets`") - model.sequential_targets = targets - model.targets = None - model._prune_n, model._prune_m = model._split_mask_structure(mask_structure) return model + @abstractmethod + def calibrate_module( + self, + module: torch.nn.Module, + args: Tuple[torch.Tensor, ...], + _output: torch.Tensor, + ): + raise NotImplementedError() + def on_initialize(self, state: "State", **kwargs) -> bool: """ Initialize and run the OBCQ algorithm on the current state :param state: session state storing input model and calibration data """ - model = state.model - dataloader = state.data.calib + model: torch.nn.Module = state.model + dataloader: torch.utils.data.DataLoader = state.data.calib # infer module and sequential targets self.sequential_targets = self._infer_sequential_targets(model) + layers = get_layers(self.sequential_targets, model) # infer layer sparsities if self.sparsity_profile == "owl": @@ -108,10 +117,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool: "Using OWL to infer target layer-wise sparsities from " f"{len(dataloader) if dataloader else 0} calibration samples..." ) - self.sparsity = self._infer_owl_layer_sparsity() + self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader) # get layers and validate sparsity - layers = get_layers(self.sequential_targets, model) if isinstance(self.sparsity, (list, dict)) and len(layers) != len( self.sparsity ): @@ -121,18 +129,21 @@ def on_initialize(self, state: "State", **kwargs) -> bool: ) # register hooks - for index, (name, layer) in enumerate(layers.items()): + #target_modules = match_layers_params(self.targets, model) + for index, (layer_name, layer) in enumerate(layers.items()): if isinstance(self.sparsity, dict): - layer_sparsity = self.sparsity[name] + layer_sparsity = self.sparsity[layer_name] elif isinstance(self.sparsity, list): layer_sparsity = self.sparsity[index] else: layer_sparsity = self.sparsity - for name, module in get_prunable_layers(layer).items(): - self._module_names[module] = name - self._module_sparsities[module] = layer_sparsity - self.register_hook(module, self.calibrate_module, "forward") + # TODO: match module or param + for name, module in layer.named_modules(prefix=layer_name): + if module in target_modules.values(): + self._module_names[module] = name + self._module_sparsities[module] = layer_sparsity + self.register_hook(module, self.calibrate_module, "forward") # infer and run pipeline model_name = state.model.__class__.__name__ @@ -177,8 +188,6 @@ def on_initialize(self, state: "State", **kwargs) -> bool: run_basic(state.model, state.data.calib, self) return True - return True - def _infer_sequential_targets( self, model: torch.nn.Module ) -> Union[str, List[str]]: @@ -188,15 +197,23 @@ def _infer_sequential_targets( return [self.sequential_targets] return self.sequential_targets - def _infer_owl_layer_sparsity(self, activations): + def _infer_owl_layer_sparsity( + self, + model: torch.nn.Module, + layers: Dict[str, torch.nn.Module], + dataloader: torch.utils.data.DataLoader, + ) -> Dict[str, float]: + activations = self._get_activations(model, dataloader) groups = {} - for name, layer in self.compressible_layers_.items(): - prunable_layers = get_prunable_layers(layer) + + target_modules = match_layers_params(self.targets, model) + for layer_name, layer in layers.items(): z = [ - m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0) - for n, m in prunable_layers.items() + module.weight.abs() * activations[f"{layer_name}.{name}"].unsqueeze(0) + for name, module in layer.named_modules.items() + if module in target_modules.values() ] - groups[name] = torch.cat([item.flatten().cpu() for item in z]) + groups[layer_name] = torch.cat([item.flatten().cpu() for item in z]) del activations diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py index fb3696933e..291e5ae481 100644 --- a/src/llmcompressor/modifiers/pruning/wanda/base.py +++ b/src/llmcompressor/modifiers/pruning/wanda/base.py @@ -7,7 +7,7 @@ update_offload_parameter, ) from loguru import logger -from pydantic import PrivateAttr +from pydantic import PrivateAttr, Field from llmcompressor.core import State from llmcompressor.modifiers import Modifier @@ -69,8 +69,9 @@ class WandaPruningModifier(SparsityModifierMixin, Modifier): # data pipeline arguments sequential_update: Optional[bool] = False # deprecated - sequential_targets: Union[str, List[str], None] = None - targets: Union[str, List[str], None] = None # alias sequential_targets + sequential_targets: Union[str, List[str]] = None + targets: Union[str, List[str], None] = ["Linear"] + ignore: List[str] = Field(default_factory=list) # private variables _prune_n: Optional[int] = PrivateAttr(default=None) From 440212113905eac9a5ed36530740b0c0a204b4ec Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 7 Feb 2025 20:19:08 +0000 Subject: [PATCH 02/13] implement sequential pipeline hack Signed-off-by: Kyle Sayers --- .../modifiers/obcq/sgpt_mixin.py | 7 +++--- .../modifiers/pruning/wanda/base.py | 2 +- src/llmcompressor/pipelines/basic/pipeline.py | 6 ++--- .../pipelines/layer_sequential/helpers.py | 23 ++++++++++++++++++- .../pipelines/layer_sequential/pipeline.py | 7 +++++- .../oneshot_configs/recipes/recipe.yaml | 1 - .../oneshot_configs/tiny_stories_conf1.yaml | 1 - 7 files changed, 36 insertions(+), 11 deletions(-) diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py index 721ef3fd4a..6d002cb295 100644 --- a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py +++ b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py @@ -20,6 +20,7 @@ from llmcompressor.utils.pytorch.module import ( get_layers, get_no_split_params, + get_prunable_layers, match_layers_params, ) @@ -129,7 +130,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: ) # register hooks - #target_modules = match_layers_params(self.targets, model) + target_modules = match_layers_params(self.targets, model) for index, (layer_name, layer) in enumerate(layers.items()): if isinstance(self.sparsity, dict): layer_sparsity = self.sparsity[layer_name] @@ -139,9 +140,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool: layer_sparsity = self.sparsity # TODO: match module or param - for name, module in layer.named_modules(prefix=layer_name): + for name, module in get_prunable_layers(layer).items(): if module in target_modules.values(): - self._module_names[module] = name + self._module_names[module] = f"{layer_name}.{name}" self._module_sparsities[module] = layer_sparsity self.register_hook(module, self.calibrate_module, "forward") diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py index 291e5ae481..f4ef5e224c 100644 --- a/src/llmcompressor/modifiers/pruning/wanda/base.py +++ b/src/llmcompressor/modifiers/pruning/wanda/base.py @@ -7,7 +7,7 @@ update_offload_parameter, ) from loguru import logger -from pydantic import PrivateAttr, Field +from pydantic import Field, PrivateAttr from llmcompressor.core import State from llmcompressor.modifiers import Modifier diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 61d6e28ce9..13a1c9454c 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -40,6 +40,6 @@ def run_pipeline( batch = tensors_to_device(batch, model_device) model(**batch) - # TODO: replace with a lifecycle event - if callback_modifier: - callback_modifier.on_sequential_batch_end() + # TODO: replace with a lifecycle event + if callback_modifier: + callback_modifier.on_sequential_batch_end() diff --git a/src/llmcompressor/pipelines/layer_sequential/helpers.py b/src/llmcompressor/pipelines/layer_sequential/helpers.py index 06e7e5b3ba..2bf943fcc1 100644 --- a/src/llmcompressor/pipelines/layer_sequential/helpers.py +++ b/src/llmcompressor/pipelines/layer_sequential/helpers.py @@ -15,7 +15,12 @@ from llmcompressor.pytorch.utils.helpers import tensors_to_device from llmcompressor.utils.helpers import calibration_forward_context -__all__ = ["match_modules", "capture_first_layer_intermediates", "to_next_layer_kwargs"] +__all__ = [ + "match_modules", + "capture_first_layer_intermediates", + "to_next_layer_kwargs", + "maybe_inject_pos_embeddings", +] def match_modules(model: Module, target_names: List[str]) -> List[Module]: @@ -126,3 +131,19 @@ class EarlyStopException(Exception): _args: Tuple[Any, ...] _kwargs: Dict[str, Any] + + +def maybe_inject_pos_embeddings( + output: Dict[str, Any], + next_layer: Module, + inputs: Dict[str, Any], +) -> Dict[str, Any]: + signature = inspect.signature(next_layer.forward) + if ( + "position_embeddings" in signature.parameters.keys() + and "position_embeddings" in inputs + and "position_embeddings" not in output + ): + output["position_embeddings"] = inputs["position_embeddings"] + + return output diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index cef100e2f8..c0ae0b620a 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -9,6 +9,7 @@ from llmcompressor.pipelines.layer_sequential.helpers import ( capture_first_layer_intermediates, match_modules, + maybe_inject_pos_embeddings, to_next_layer_kwargs, ) from llmcompressor.utils.helpers import calibration_forward_context @@ -79,6 +80,10 @@ def run_pipeline( output = layer(**inputs) if layer_index < num_layers - 1: - output = to_next_layer_kwargs(output, layers[layer_index + 1]) + next_layer = layers[layer_index + 1] + output = to_next_layer_kwargs(output, next_layer) + # HACK: accommodate models which pass position embeddings + output = maybe_inject_pos_embeddings(output, next_layer, inputs) + intermediates.delete(batch_index) intermediates.update(batch_index, output) diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml index c5bf782d49..54239b3b46 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml @@ -3,7 +3,6 @@ test_stage: SparseGPTModifier: sparsity: 0.5 block_size: 128 - sequential_update: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml index 39f9d65762..7b795ba8e7 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml @@ -9,7 +9,6 @@ recipe: | SparseGPTModifier: sparsity: 0.5 block_size: 128 - sequential_update: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file From 1479dc1982ef43604ac0ab78f630a4f9d3ca78e5 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 7 Feb 2025 18:15:14 -0500 Subject: [PATCH 03/13] unwrap wrapping decorators Signed-off-by: Kyle Sayers --- .../pipelines/sequential/helpers.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 4945ba01e4..9fcf80eda0 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -1,7 +1,7 @@ import inspect from collections import deque from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Set +from typing import Any, Callable, Dict, List, Set, Union from compressed_tensors import has_offloaded_params from compressed_tensors.quantization import find_name_or_class_matches @@ -125,13 +125,24 @@ def create_arg(self, a: Any) -> Argument: return self.create_node("call_function", a.__class__, (), kwargs) else: - return super().create_arg(a) + return super(HFTracer, self).create_arg(a) def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: - return module in skip_trace_modules or super().is_leaf_module( + return module in skip_trace_modules or super(HFTracer, self).is_leaf_module( module, module_qualified_name ) + def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph: + if isinstance(root, Module): + root = root.forward + + # unwrap any decorators that may have altered the function signature, + # for example `deprecate_kwarg` added by transformers + while hasattr(root, "__wrapped__"): + root = root.__wrapped__ + + return super(HFTracer, self).trace(root, *args, **kwargs) + return SequentialTracer() From 4f5cb617147237ceb54abc03132471af85580a10 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 7 Feb 2025 18:37:14 -0500 Subject: [PATCH 04/13] docstrings and comments Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/layer_sequential/helpers.py | 8 ++++++++ src/llmcompressor/pipelines/layer_sequential/pipeline.py | 1 - src/llmcompressor/pipelines/sequential/helpers.py | 8 ++++---- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/pipelines/layer_sequential/helpers.py b/src/llmcompressor/pipelines/layer_sequential/helpers.py index 2bf943fcc1..3760ee99d9 100644 --- a/src/llmcompressor/pipelines/layer_sequential/helpers.py +++ b/src/llmcompressor/pipelines/layer_sequential/helpers.py @@ -138,6 +138,14 @@ def maybe_inject_pos_embeddings( next_layer: Module, inputs: Dict[str, Any], ) -> Dict[str, Any]: + """ + As of https://github.com/huggingface/transformers/pull/34858, positional embeddings + must be passed into each decoder call as kwargs + + :param output: output of the previous layer + :param next_layer: next layer to call + :param inputs: inputs to next layer + """ signature = inspect.signature(next_layer.forward) if ( "position_embeddings" in signature.parameters.keys() diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index c0ae0b620a..9f8adbce4f 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -82,7 +82,6 @@ def run_pipeline( if layer_index < num_layers - 1: next_layer = layers[layer_index + 1] output = to_next_layer_kwargs(output, next_layer) - # HACK: accommodate models which pass position embeddings output = maybe_inject_pos_embeddings(output, next_layer, inputs) intermediates.delete(batch_index) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 9fcf80eda0..189a74db39 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -136,10 +136,10 @@ def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph: if isinstance(root, Module): root = root.forward - # unwrap any decorators that may have altered the function signature, - # for example `deprecate_kwarg` added by transformers - while hasattr(root, "__wrapped__"): - root = root.__wrapped__ + # due to a bug in Tracer.create_args_for_root (likely _patch_function args), + # must unwrap function wrappers prior to tracing, for example + # `deprecate_kwarg` decorator added by transformers + root = inspect.unwrap(root) return super(HFTracer, self).trace(root, *args, **kwargs) From 9c33a066e07760d61b10018a4fd3d8e1fb5bf14f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 7 Feb 2025 20:39:02 -0500 Subject: [PATCH 05/13] cleanup Signed-off-by: Kyle Sayers --- .../pipelines/sequential/helpers.py | 24 ++++++++++++------- src/llmcompressor/utils/helpers.py | 10 ++++++++ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py index 189a74db39..629990be9a 100644 --- a/src/llmcompressor/pipelines/sequential/helpers.py +++ b/src/llmcompressor/pipelines/sequential/helpers.py @@ -13,7 +13,7 @@ from transformers.utils.fx import HFTracer from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.utils.helpers import calibration_forward_context +from llmcompressor.utils.helpers import calibration_forward_context, preserve_attr __all__ = ["trace_subgraphs", "Subgraph"] @@ -114,6 +114,7 @@ def get_tracer( :param sequential_targets: modules which are sequential targets :param ignore: modules which are ignored """ + # TODO: redefine skip_trace_modules to all non-ancestors of sequential_targets offloaded_modules = set(m for m in model.modules() if has_offloaded_params(m)) skip_trace_modules = sequential_targets | offloaded_modules | ignore @@ -125,23 +126,28 @@ def create_arg(self, a: Any) -> Argument: return self.create_node("call_function", a.__class__, (), kwargs) else: - return super(HFTracer, self).create_arg(a) + return super().create_arg(a) def is_leaf_module(self, module: Module, module_qualified_name: str) -> bool: - return module in skip_trace_modules or super(HFTracer, self).is_leaf_module( + return module in skip_trace_modules or super().is_leaf_module( module, module_qualified_name ) def trace(self, root: Union[Module, Callable], *args, **kwargs) -> Graph: if isinstance(root, Module): - root = root.forward + with preserve_attr(type(root), "forward"): + # due to a bug in Tracer.create_args_for_root (_patch_function), + # we must unwrap function wrappers prior to tracing, for example + # the `deprecate_kwarg` by transformers which wraps forward - # due to a bug in Tracer.create_args_for_root (likely _patch_function args), - # must unwrap function wrappers prior to tracing, for example - # `deprecate_kwarg` decorator added by transformers - root = inspect.unwrap(root) + # we override the class method because the + # class method is the one being traced + type(root).forward = inspect.unwrap(type(root).forward) - return super(HFTracer, self).trace(root, *args, **kwargs) + return super().trace(root, *args, **kwargs) + + else: + return super().trace(root, *args, **kwargs) return SequentialTracer() diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index ad4d884b24..097bcf1cbd 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -65,6 +65,7 @@ "DisableKVCache", "DisableQuantization", "calibration_forward_context", + "preserve_attr", ] @@ -1115,3 +1116,12 @@ def calibration_forward_context(model: torch.nn.Module): DisableQuantization(model), ): yield + + +@contextlib.contextmanager +def preserve_attr(base: object, attr: str): + value = getattr(base, attr) + try: + yield + finally: + setattr(base, attr, value) From 2b2c01a54131f262b9e40bb2496dd7d6d7b5ac74 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 7 Feb 2025 20:57:45 -0500 Subject: [PATCH 06/13] fix matching Signed-off-by: Kyle Sayers --- .../modifiers/obcq/sgpt_mixin.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py index 6d002cb295..6c534cc852 100644 --- a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py +++ b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py @@ -2,7 +2,7 @@ from abc import abstractmethod from collections import defaultdict from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union import numpy import torch @@ -130,7 +130,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: ) # register hooks - target_modules = match_layers_params(self.targets, model) + target_modules = self._get_target_modules(model) for index, (layer_name, layer) in enumerate(layers.items()): if isinstance(self.sparsity, dict): layer_sparsity = self.sparsity[layer_name] @@ -139,9 +139,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool: else: layer_sparsity = self.sparsity - # TODO: match module or param for name, module in get_prunable_layers(layer).items(): - if module in target_modules.values(): + if module in target_modules: self._module_names[module] = f"{layer_name}.{name}" self._module_sparsities[module] = layer_sparsity self.register_hook(module, self.calibrate_module, "forward") @@ -198,6 +197,15 @@ def _infer_sequential_targets( return [self.sequential_targets] return self.sequential_targets + def _get_target_modules(self, model: torch.nn.Module) -> Set[torch.nn.Module]: + target_layers = match_layers_params(self.targets, model) + return set().union( + *( + get_prunable_layers(target_layer).values() + for target_layer in target_layers.values() + ) + ) + def _infer_owl_layer_sparsity( self, model: torch.nn.Module, @@ -205,16 +213,15 @@ def _infer_owl_layer_sparsity( dataloader: torch.utils.data.DataLoader, ) -> Dict[str, float]: activations = self._get_activations(model, dataloader) - groups = {} - target_modules = match_layers_params(self.targets, model) - for layer_name, layer in layers.items(): + groups = {} + for name, layer in layers.items(): + prunable_layers = get_prunable_layers(layer) z = [ - module.weight.abs() * activations[f"{layer_name}.{name}"].unsqueeze(0) - for name, module in layer.named_modules.items() - if module in target_modules.values() + m.weight.abs() * activations[f"{name}.{n}"].unsqueeze(0) + for n, m in prunable_layers.items() ] - groups[layer_name] = torch.cat([item.flatten().cpu() for item in z]) + groups[name] = torch.cat([item.flatten().cpu() for item in z]) del activations From 3c3da1b3d1a2a834fa53a8cfa5f177c2bcece6f1 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 7 Feb 2025 21:14:49 -0500 Subject: [PATCH 07/13] closer match to original behavior Signed-off-by: Kyle Sayers --- .../modifiers/obcq/sgpt_mixin.py | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py index 6c534cc852..d9816d61e5 100644 --- a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py +++ b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py @@ -2,7 +2,7 @@ from abc import abstractmethod from collections import defaultdict from functools import partial -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy import torch @@ -21,7 +21,6 @@ get_layers, get_no_split_params, get_prunable_layers, - match_layers_params, ) @@ -111,6 +110,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: # infer module and sequential targets self.sequential_targets = self._infer_sequential_targets(model) layers = get_layers(self.sequential_targets, model) + target_layers = get_layers(self.targets, model) # layers containing targets # infer layer sparsities if self.sparsity_profile == "owl": @@ -130,8 +130,7 @@ def on_initialize(self, state: "State", **kwargs) -> bool: ) # register hooks - target_modules = self._get_target_modules(model) - for index, (layer_name, layer) in enumerate(layers.items()): + for index, (layer_name, layer) in enumerate(target_layers.items()): if isinstance(self.sparsity, dict): layer_sparsity = self.sparsity[layer_name] elif isinstance(self.sparsity, list): @@ -140,10 +139,9 @@ def on_initialize(self, state: "State", **kwargs) -> bool: layer_sparsity = self.sparsity for name, module in get_prunable_layers(layer).items(): - if module in target_modules: - self._module_names[module] = f"{layer_name}.{name}" - self._module_sparsities[module] = layer_sparsity - self.register_hook(module, self.calibrate_module, "forward") + self._module_names[module] = f"{layer_name}.{name}" + self._module_sparsities[module] = layer_sparsity + self.register_hook(module, self.calibrate_module, "forward") # infer and run pipeline model_name = state.model.__class__.__name__ @@ -197,15 +195,6 @@ def _infer_sequential_targets( return [self.sequential_targets] return self.sequential_targets - def _get_target_modules(self, model: torch.nn.Module) -> Set[torch.nn.Module]: - target_layers = match_layers_params(self.targets, model) - return set().union( - *( - get_prunable_layers(target_layer).values() - for target_layer in target_layers.values() - ) - ) - def _infer_owl_layer_sparsity( self, model: torch.nn.Module, From f6a96b3c231fffac4f4f7c34e5ce0bcbd1ddf65f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 7 Feb 2025 21:17:45 -0500 Subject: [PATCH 08/13] add ignore Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/obcq/sgpt_mixin.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py index d9816d61e5..ad55c852c1 100644 --- a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py +++ b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py @@ -21,6 +21,7 @@ get_layers, get_no_split_params, get_prunable_layers, + match_targets, ) @@ -139,9 +140,11 @@ def on_initialize(self, state: "State", **kwargs) -> bool: layer_sparsity = self.sparsity for name, module in get_prunable_layers(layer).items(): - self._module_names[module] = f"{layer_name}.{name}" - self._module_sparsities[module] = layer_sparsity - self.register_hook(module, self.calibrate_module, "forward") + name = f"{layer_name}.{name}" + if not match_targets(name, self.ignore)[0]: + self._module_names[module] = name + self._module_sparsities[module] = layer_sparsity + self.register_hook(module, self.calibrate_module, "forward") # infer and run pipeline model_name = state.model.__class__.__name__ From 83fb4376d47ef9acd644e56d05c9ed9c597ce35f Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 7 Feb 2025 21:40:30 -0500 Subject: [PATCH 09/13] add owl test Signed-off-by: Kyle Sayers --- .../modifiers/obcq/sgpt_mixin.py | 6 +--- .../transformers/obcq/test_obcq_owl.py | 36 +++++++++++++++++++ .../test_compress_tensor_utils.py | 4 ++- 3 files changed, 40 insertions(+), 6 deletions(-) create mode 100644 tests/llmcompressor/transformers/obcq/test_obcq_owl.py diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py index ad55c852c1..8f1df7b800 100644 --- a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py +++ b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py @@ -27,7 +27,7 @@ class SparsityModifierMixin(HooksMixin): # modifier arguments - sparsity: Optional[Union[float, List[float]]] = None + sparsity: Optional[Union[float, List[float]]] sparsity_profile: Optional[str] = None mask_structure: str = "0:0" owl_m: Optional[int] = None @@ -71,7 +71,6 @@ def validate_sparsity_profile(cls, value: Optional[str]) -> bool: @model_validator(mode="after") def validate_model_after(model: "SparsityModifierMixin") -> "Modifier": - sparsity = model.sparsity profile = model.sparsity_profile owl_m = model.owl_m owl_lmbda = model.owl_lmbda @@ -83,9 +82,6 @@ def validate_model_after(model: "SparsityModifierMixin") -> "Modifier": if profile != "owl" and (owl_m is not None or owl_lmbda is not None): raise ValueError("Must provide both `owl_m` and `owl_lmbda`") - if owl_m is not None and sparsity is not None: - raise ValueError("Cannot provide both sparsity and owl parameters") - model._prune_n, model._prune_m = model._split_mask_structure(mask_structure) return model diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_owl.py b/tests/llmcompressor/transformers/obcq/test_obcq_owl.py new file mode 100644 index 0000000000..4948c6da34 --- /dev/null +++ b/tests/llmcompressor/transformers/obcq/test_obcq_owl.py @@ -0,0 +1,36 @@ +import pytest +import torch +from datasets import Dataset +from transformers import AutoModelForCausalLM + +from llmcompressor.modifiers.obcq import SparseGPTModifier +from llmcompressor.transformers.finetune.data.data_helpers import ( + format_calibration_data, +) +from llmcompressor.utils.pytorch.module import get_layers + + +@pytest.mark.integration +def test_infer_owl_layer_sparsity(): + target_sparsity = 0.7 + vocab_size = 512 + seq_len = 2048 + ds_size = 16 + + modifier = SparseGPTModifier( + sparsity=0.7, sparsity_profile="owl", owl_m=5, owl_lmbda=0.05 + ) + model = AutoModelForCausalLM.from_pretrained("Xenova/llama2.c-stories15M") + + dataset = Dataset.from_dict( + {"input_ids": torch.randint(0, vocab_size, (ds_size, seq_len))} + ) + dataloader = format_calibration_data(dataset) + + sequential_targets = modifier._infer_sequential_targets(model) + layers = get_layers(sequential_targets, model) + sparsities = modifier._infer_owl_layer_sparsity(model, layers, dataloader) + assert sparsities.keys() == layers.keys() + + for sparsity in sparsities.values(): + assert sparsity == pytest.approx(target_sparsity, abs=0.1) diff --git a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py index 58eda03ee5..eeb6e95ae5 100644 --- a/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py +++ b/tests/llmcompressor/transformers/sparsification/test_compress_tensor_utils.py @@ -672,7 +672,9 @@ def test_correct_compressor_inferred( if is_24: weights = _make_24_sparse(weights) else: - weights[0, :] = torch.ones(4, ) # guarantee not 24 sparse + weights[0, :] = torch.ones( + 4, + ) # guarantee not 24 sparse quantization_config = _quantization_config_from_string(quant_style, quant_type) quantization_args = quantization_config.config_groups["group_0"].weights From cacd87c24ee6f223ee89cddb9a8ad49c6cd429ce Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Fri, 7 Feb 2025 21:42:37 -0500 Subject: [PATCH 10/13] reduce diff Signed-off-by: Kyle Sayers --- .../transformers/oneshot/oneshot_configs/recipes/recipe.yaml | 1 + .../transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml index 54239b3b46..c5bf782d49 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/recipes/recipe.yaml @@ -3,6 +3,7 @@ test_stage: SparseGPTModifier: sparsity: 0.5 block_size: 128 + sequential_update: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file diff --git a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml index 7b795ba8e7..39f9d65762 100644 --- a/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml +++ b/tests/llmcompressor/transformers/oneshot/oneshot_configs/tiny_stories_conf1.yaml @@ -9,6 +9,7 @@ recipe: | SparseGPTModifier: sparsity: 0.5 block_size: 128 + sequential_update: False targets: [ 're:model.layers.3.mlp.gate_proj.weight' ] \ No newline at end of file From d1fff02068c03429199ec73a102ad18947280929 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 9 Feb 2025 12:31:19 -0500 Subject: [PATCH 11/13] revert targets functionality in tests Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/obcq/sgpt_mixin.py | 4 ++-- .../pytorch/modifiers/pruning/sparsegpt/test_pytorch.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py index 8f1df7b800..4a21f1b163 100644 --- a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py +++ b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py @@ -118,12 +118,12 @@ def on_initialize(self, state: "State", **kwargs) -> bool: self.sparsity = self._infer_owl_layer_sparsity(model, layers, dataloader) # get layers and validate sparsity - if isinstance(self.sparsity, (list, dict)) and len(layers) != len( + if isinstance(self.sparsity, (list, dict)) and len(target_layers) != len( self.sparsity ): raise ValueError( f"{self.__repr_name__} was initialized with {len(self.sparsity)} " - f"sparsities values, but model only has {len(layers)} layers" + f"sparsities values, but model has {len(layers)} target layers" ) # register hooks diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index b5adc45ce1..c1f0cb4253 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -32,7 +32,7 @@ def test_invalid_layerwise_recipes_raise_exceptions(self, sparsity, targets): modifier = SparseGPTModifier( sparsity=sparsity, block_size=128, - sequential_targets=targets, + targets=targets, ) testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1) @@ -48,9 +48,9 @@ def setUp(self): def test_successful_layerwise_recipe(self): sparsities = [0.5, 0.2] - sequential_targets = ["seq.fc1", "seq.fc2"] + targets = ["seq.fc1", "seq.fc2"] modifier = SparseGPTModifier( - sparsity=sparsities, block_size=128, sequential_targets=sequential_targets + sparsity=sparsities, block_size=128, targets=targets ) testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1) modifier.initialize(testing_harness.get_state()) From f02ec131a74eaf8aa9cb8d1db585b805d13d3b95 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sun, 9 Feb 2025 12:37:48 -0500 Subject: [PATCH 12/13] remove replicated arguments Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/obcq/base.py | 23 ++++--------------- .../modifiers/pruning/wanda/base.py | 23 ++++--------------- 2 files changed, 9 insertions(+), 37 deletions(-) diff --git a/src/llmcompressor/modifiers/obcq/base.py b/src/llmcompressor/modifiers/obcq/base.py index 1329f17de8..4b0c9e5026 100644 --- a/src/llmcompressor/modifiers/obcq/base.py +++ b/src/llmcompressor/modifiers/obcq/base.py @@ -1,5 +1,5 @@ import contextlib -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Tuple import torch from compressed_tensors.utils import ( @@ -8,7 +8,7 @@ update_offload_parameter, ) from loguru import logger -from pydantic import Field, PrivateAttr +from pydantic import PrivateAttr from llmcompressor.core import State from llmcompressor.modifiers import Modifier @@ -69,32 +69,19 @@ class SparseGPTModifier(SparsityModifierMixin, Modifier): to compress every layer in the model. Alias for `targets` :param targets: list of layer names to compress during OBCQ, or '__ALL__' to compress every layer in the model. Alias for `sequential_targets` + :param ignore: optional list of module class names or submodule names to not + quantize even if they match a target. Defaults to empty list. """ # modifier arguments - sparsity: Optional[Union[float, List[float]]] = None - sparsity_profile: Optional[str] = None - mask_structure: str = "0:0" - owl_m: Optional[int] = None - owl_lmbda: Optional[float] = None block_size: int = 128 dampening_frac: Optional[float] = 0.01 preserve_sparsity_mask: bool = False offload_hessians: bool = False - # data pipeline arguments - sequential_update: Optional[bool] = False # deprecated - sequential_targets: Union[str, List[str], None] = None - targets: Union[str, List[str]] = ["Linear"] - ignore: List[str] = Field(default_factory=list) - # private variables - _prune_n: Optional[int] = PrivateAttr(default=None) - _prune_m: Optional[int] = PrivateAttr(default=None) - _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict) - _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) - _module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) + _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) def calibrate_module( self, diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py index f4ef5e224c..3b0eb9f584 100644 --- a/src/llmcompressor/modifiers/pruning/wanda/base.py +++ b/src/llmcompressor/modifiers/pruning/wanda/base.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Tuple import torch from compressed_tensors.utils import ( @@ -7,7 +7,7 @@ update_offload_parameter, ) from loguru import logger -from pydantic import Field, PrivateAttr +from pydantic import PrivateAttr from llmcompressor.core import State from llmcompressor.modifiers import Modifier @@ -58,30 +58,15 @@ class WandaPruningModifier(SparsityModifierMixin, Modifier): to compress every layer in the model. Alias for `targets` :param targets: list of layer names to compress during OBCQ, or '__ALL__' to compress every layer in the model. Alias for `sequential_targets` + :param ignore: optional list of module class names or submodule names to not + quantize even if they match a target. Defaults to empty list. """ - # sparsity arguments - sparsity: Optional[Union[float, List[float]]] = None - sparsity_profile: Optional[str] = None - mask_structure: str = "0:0" - owl_m: Optional[int] = None - owl_lmbda: Optional[float] = None - - # data pipeline arguments - sequential_update: Optional[bool] = False # deprecated - sequential_targets: Union[str, List[str]] = None - targets: Union[str, List[str], None] = ["Linear"] - ignore: List[str] = Field(default_factory=list) - # private variables - _prune_n: Optional[int] = PrivateAttr(default=None) - _prune_m: Optional[int] = PrivateAttr(default=None) _row_scalars: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr( default_factory=dict ) _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict) - _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) - _module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) def calibrate_module( self, From 3b577b8e038cbea72a2f4d03f30d3ac21ae0339e Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 10 Feb 2025 09:16:34 -0500 Subject: [PATCH 13/13] clearer owl validation Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/obcq/sgpt_mixin.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py index 4a21f1b163..ff4b9af2ba 100644 --- a/src/llmcompressor/modifiers/obcq/sgpt_mixin.py +++ b/src/llmcompressor/modifiers/obcq/sgpt_mixin.py @@ -10,7 +10,6 @@ from pydantic import Field, PrivateAttr, field_validator, model_validator from llmcompressor.core import State -from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.utils.hooks import HooksMixin from llmcompressor.pipelines.basic import run_pipeline as run_basic from llmcompressor.pipelines.layer_sequential import ( @@ -70,17 +69,20 @@ def validate_sparsity_profile(cls, value: Optional[str]) -> bool: return value @model_validator(mode="after") - def validate_model_after(model: "SparsityModifierMixin") -> "Modifier": + def validate_model_after(model: "SparsityModifierMixin") -> "SparsityModifierMixin": profile = model.sparsity_profile owl_m = model.owl_m owl_lmbda = model.owl_lmbda mask_structure = model.mask_structure - if profile == "owl" and ((owl_m is not None) ^ (owl_lmbda is not None)): - raise ValueError("Must provide both `owl_m` and `owl_lmbda` or neither") - - if profile != "owl" and (owl_m is not None or owl_lmbda is not None): - raise ValueError("Must provide both `owl_m` and `owl_lmbda`") + has_owl_m = owl_m is not None + has_owl_lmbda = owl_lmbda is not None + has_owl = profile == "owl" + owl_args = (has_owl_m, has_owl_lmbda, has_owl) + if any(owl_args) and not all(owl_args): + raise ValueError( + 'Must provide all of `profile="owl"`, `owl_m` and `owl_lmbda` or none' + ) model._prune_n, model._prune_m = model._split_mask_structure(mask_structure)