-
Notifications
You must be signed in to change notification settings - Fork 0
Awq masking #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: awq_bugfix
Are you sure you want to change the base?
Awq masking #1
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,6 +40,7 @@ | |
| from llmcompressor.modifiers.utils.hooks import HooksMixin | ||
| from llmcompressor.observers.base import Observer | ||
| from llmcompressor.pipelines.cache import IntermediatesCache | ||
| from llmcompressor.pipelines.sequential.pipeline import _current_loss_mask | ||
| from llmcompressor.utils.fsdp.helpers import get_fsdp_parent | ||
| from llmcompressor.utils.helpers import calibration_forward_context | ||
| from llmcompressor.utils.pytorch.module import ( | ||
|
|
@@ -162,6 +163,10 @@ class AWQModifier(Modifier, QuantizationMixin): | |
| _parent_args_cache: dict[Module, IntermediatesCache] = PrivateAttr( | ||
| default_factory=dict | ||
| ) | ||
| # Cache loss_mask for each parent module, one mask per batch | ||
| _loss_masks: list[torch.Tensor | None] = PrivateAttr( | ||
| default_factory=list | ||
| ) | ||
| # Dict[smooth layer name, (activation means, activation counts)] | ||
| _smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr( | ||
| default_factory=dict | ||
|
|
@@ -269,6 +274,7 @@ def on_finalize(self, state: State, **kwargs) -> bool: | |
| self.on_end(state, None) | ||
|
|
||
| self._parent_args_cache.clear() | ||
| self._loss_masks = [] | ||
| self._smooth_activation_means.clear() | ||
| self._resolved_mappings.clear() | ||
|
|
||
|
|
@@ -382,15 +388,48 @@ def cache_parent_kwargs_hook( | |
| values = inspect.signature(module.forward).bind(*args, **kwargs) | ||
| self._parent_args_cache[module].append(values.arguments) | ||
|
|
||
| loss_mask = _current_loss_mask.get() | ||
| self._loss_masks.append(loss_mask) | ||
|
|
||
|
|
||
| def create_cache_smooth_activations_hook_fn(smooth_name): | ||
| def cache_smooth_activations_hook( | ||
| _module: Module, | ||
| args: tuple[torch.Tensor, ...], | ||
| _output: torch.Tensor, | ||
| ): | ||
| # Get activation tensor: shape [batch, seq_len, hidden_dim] | ||
| activations = args[0].abs().detach() | ||
|
|
||
| # Flatten activations: [batch * seq_len, hidden_dim] | ||
| flat_activations = activations.flatten(0, -2) | ||
|
|
||
| # Try to get loss_mask from context variable | ||
| loss_mask = _current_loss_mask.get() | ||
|
|
||
| if loss_mask is not None: | ||
| # loss_mask shape: [batch, seq_len] | ||
| # Flatten to [batch * seq_len] | ||
|
|
||
| flat_mask = loss_mask.flatten().to(activations.device) | ||
|
|
||
| # Filter: only keep rows where mask == 1 | ||
| valid_mask = flat_mask > 0 | ||
| if valid_mask.any(): | ||
| flat_activations = flat_activations[valid_mask] | ||
| else: | ||
| # No valid tokens in this batch, use zeros | ||
| logger.warning( | ||
| f"No valid tokens (mask==1) found for {smooth_name} in current batch" | ||
| ) | ||
| flat_activations = torch.zeros( | ||
| (1, activations.shape[-1]), | ||
| device=activations.device, | ||
| dtype=activations.dtype | ||
| ) | ||
|
|
||
| act_mean, count = _accumulate_mean( | ||
| args[0].abs().detach().flatten(0, -2), | ||
| flat_activations, | ||
| self._smooth_activation_means.get(smooth_name, None), | ||
| ) | ||
| self._smooth_activation_means[smooth_name] = (act_mean.cpu(), count) | ||
|
|
@@ -415,8 +454,11 @@ def cache_smooth_activations_hook( | |
| # input activations to balance layers needed for loss function | ||
| # storing inputs to first balance layer is sufficient | ||
| # other balance layers get the same input | ||
|
|
||
| #(zewen): this should be improved | ||
| layer_to_hook = mapping.parent.mlp if hasattr(mapping.parent, 'mlp') else mapping.balance_layers[0] | ||
| self.register_hook( | ||
| mapping.balance_layers[0], | ||
| layer_to_hook, | ||
| create_cache_smooth_activations_hook_fn(mapping.smooth_name), | ||
| "forward", | ||
| ) | ||
|
|
@@ -715,11 +757,45 @@ def _compute_loss( | |
| num_elements = 0 | ||
|
|
||
| # Compute the MSE loss for each batch | ||
| for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See changes in vllm-project#2188, which will land soon. I suspect it will make more sense to apply the mask in run_samples and the concatenated fp16_output calculation rather than in the loss calculation, if possible.
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I'll make that change |
||
| for batch_idx, (fp16_batch, int_w_batch) in enumerate( | ||
| zip(fp16_outputs, int_w_outputs) | ||
| ): | ||
| int_w_batch = int_w_batch.to(fp16_batch.device) | ||
|
|
||
| # Apply mask if available for this batch | ||
| if batch_idx < len(self._loss_masks) and self._loss_masks[batch_idx] is not None: | ||
| mask = self._loss_masks[batch_idx].to(fp16_batch.device) | ||
|
|
||
| # mask shape: [batch, seq_len] | ||
| # output shape: [batch, seq_len, hidden_dim] | ||
| # Flatten both to [batch * seq_len, hidden_dim] and [batch * seq_len] | ||
| fp16_flat = fp16_batch.flatten(0, -2) # [batch * seq_len, hidden_dim] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is all this logic actually required? I would assume that you don't need to do any flattening, instead just use something like |
||
| int_w_flat = int_w_batch.flatten(0, -2) # [batch * seq_len, hidden_dim] | ||
| mask_flat = mask.flatten() # [batch * seq_len] | ||
|
|
||
| # Only compute loss on valid (mask==1) positions | ||
| valid_mask = (mask_flat == 1) | ||
| if valid_mask.any(): | ||
| # Extract only the valid tokens using boolean indexing | ||
| fp16_valid = fp16_flat[valid_mask] # [num_valid, hidden_dim] | ||
| int_w_valid = int_w_flat[valid_mask] # [num_valid, hidden_dim] | ||
|
|
||
| # Compute MSE loss on valid tokens only | ||
|
|
||
| else: | ||
| # No valid tokens, skip this batch | ||
| logger.warning( | ||
| f"No valid tokens (mask==1) found in batch {batch_idx} " | ||
| "during MSE loss computation" | ||
| ) | ||
| else: | ||
| fp16_valid = fp16_batch | ||
| int_w_valid = int_w_batch | ||
|
|
||
| loss += torch.nn.functional.mse_loss( | ||
| fp16_batch, int_w_batch.to(fp16_batch.device), reduction="sum" | ||
| fp16_valid, int_w_valid, reduction="sum" | ||
| ).item() | ||
| num_elements += fp16_batch.numel() | ||
| num_elements += fp16_valid.numel() | ||
|
|
||
| # Normalize the loss by the total number of elements | ||
| loss /= num_elements | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import contextlib | ||
| from contextvars import ContextVar | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| import torch | ||
|
|
@@ -24,7 +25,12 @@ | |
| if TYPE_CHECKING: | ||
| from llmcompressor.args.dataset_arguments import DatasetArguments | ||
|
|
||
| __all__ = ["SequentialPipeline"] | ||
| __all__ = ["SequentialPipeline", "_current_loss_mask"] | ||
|
|
||
| # Context variable to store the current batch's loss_mask for hooks to access | ||
| _current_loss_mask: ContextVar[torch.Tensor | None] = ContextVar( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think, to better work in the LLM Compressor framework, we should store this variable on the |
||
| "_current_loss_mask", default=None | ||
| ) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do the loss masks change from sample to sample, or are they largely constant? I think this approach is fine if they change a lot, but we could potentially do something different where we just alter the AWQ modifier to take the loss mask into account directly if they tend to be constant. I assume the chat template is usually going to be pretty consistent, so that may make more sense. Also wondering if the loss mask is usually a step function with a single edge — it may make more sense to not store the entire mask and instead store just the edge. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we'd like to generalize masks to be fully expressive, to include things like padding tokens. This shouldn't be too much memory, just num_samples * seq_len * bool ~= 1mb, or 8b if you don't want to offload and instead keep as a tensor. |
||
|
|
||
|
|
||
| @CalibrationPipeline.register("sequential") | ||
|
|
@@ -104,7 +110,17 @@ def __call__( | |
| # do a preliminary pass to trigger modifier hooks | ||
| for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc): | ||
| inputs = activations.fetch(batch_idx, subgraph.input_names) | ||
|
|
||
| # Set loss_mask in context variable if enabled, so hooks can access it | ||
| if dataset_args.use_loss_mask: | ||
| loss_mask_dict = activations.fetch(batch_idx, ["loss_mask"]) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How did the "loss_mask" argument end up in the activations cache? It's probably better if we implement a
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. activations = IntermediatesCache.from_dataloader( |
||
| _current_loss_mask.set(loss_mask_dict.get("loss_mask")) | ||
|
|
||
| subgraph.forward(model, **inputs) | ||
|
|
||
| # Clear the context variable after forward pass | ||
| if dataset_args.use_loss_mask: | ||
| _current_loss_mask.set(None) | ||
|
|
||
| LifecycleCallbacks.sequential_epoch_end(subgraph) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this be none if the user isn't using loss masks?