diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index 6f3c16fcfe..aff27f6265 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -247,6 +247,15 @@ class DatasetArguments(CustomDatasetArguments):
             "Default is set to True."
         },
     )
+    use_loss_mask: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to use a loss mask from the dataset. When True, expects "
+            "the dataset to contain a 'loss_mask' field that indicates which tokens "
+            "should be included in loss calculations (e.g., for masking out prompts "
+            "and only computing loss on generated tokens). Default is False."
+        },
+    )
 
     def is_dataset_provided(self) -> bool:
         return self.dataset is not None or self.dataset_path is not None
diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py
index cd0cf66280..f4bcb86b00 100644
--- a/src/llmcompressor/entrypoints/oneshot.py
+++ b/src/llmcompressor/entrypoints/oneshot.py
@@ -264,6 +264,7 @@ def oneshot(
     min_tokens_per_module: float | None = None,
     moe_calibrate_all_experts: bool = True,
     quantization_aware_calibration: bool = True,
+    use_loss_mask: bool = False,
     # Miscellaneous arguments
     output_dir: str | None = None,
     log_dir: str | None = None,
@@ -339,6 +340,10 @@ def oneshot(
         calibration in the sequential pipeline. When True, quantization is applied
         during forward pass in calibration. When False, quantization is disabled
         during forward pass in calibration. Default is set to True.
+    :param use_loss_mask: Whether to use a loss mask from the dataset. When True,
+        expects the dataset to contain a 'loss_mask' field that indicates which
+        tokens should be included in loss calculations (e.g., for masking out
+        prompts and only computing loss on generated tokens). Default is False.
     # Miscellaneous arguments
     :param output_dir: Path to save the output model after calibration.
diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index bc6c35d9b5..572035bb4b 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -40,6 +40,7 @@
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.observers.base import Observer
 from llmcompressor.pipelines.cache import IntermediatesCache
+from llmcompressor.pipelines.sequential.pipeline import _current_loss_mask
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import (
@@ -162,6 +163,10 @@
     _parent_args_cache: dict[Module, IntermediatesCache] = PrivateAttr(
         default_factory=dict
     )
+    # Loss masks cached in forward order, one entry per calibration batch.
+    # NOTE(review): _compute_loss indexes this list by batch index; if more than
+    # one parent module appends here per batch, the alignment breaks -- confirm.
+    _loss_masks: list[torch.Tensor | None] = PrivateAttr(default_factory=list)
     # Dict[smooth layer name, (activation means, activation counts)]
     _smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr(
         default_factory=dict
@@ -269,6 +274,7 @@ def on_finalize(self, state: State, **kwargs) -> bool:
         self.on_end(state, None)
 
         self._parent_args_cache.clear()
+        self._loss_masks.clear()
         self._smooth_activation_means.clear()
         self._resolved_mappings.clear()
 
@@ -382,15 +388,35 @@
             values = inspect.signature(module.forward).bind(*args, **kwargs)
             self._parent_args_cache[module].append(values.arguments)
+            # Record the mask active for this batch so _compute_loss can
+            # recover it later by batch index
+            self._loss_masks.append(_current_loss_mask.get())
 
         def create_cache_smooth_activations_hook_fn(smooth_name):
             def cache_smooth_activations_hook(
                 _module: Module,
                 args: tuple[torch.Tensor, ...],
                 _output: torch.Tensor,
             ):
+                # Activations assumed [batch, seq_len, hidden_dim] -- flatten
+                # the leading dims so each row is one token position
+                flat_activations = args[0].abs().detach().flatten(0, -2)
+
+                loss_mask = _current_loss_mask.get()
+                if loss_mask is not None:
+                    # Mask is [batch, seq_len]; flatten to align with rows above
+                    valid_mask = loss_mask.flatten().to(flat_activations.device) > 0
+                    if not valid_mask.any():
+                        # Skip accumulation entirely rather than feeding zeros
+                        # into the running mean, which would bias the scales
+                        logger.warning(
+                            f"No valid tokens (mask==1) found for {smooth_name} "
+                            "in current batch"
+                        )
+                        return
+                    flat_activations = flat_activations[valid_mask]
                 act_mean, count = _accumulate_mean(
-                    args[0].abs().detach().flatten(0, -2),
+                    flat_activations,
                     self._smooth_activation_means.get(smooth_name, None),
                 )
                 self._smooth_activation_means[smooth_name] = (act_mean.cpu(), count)
 
@@ -415,8 +441,16 @@
             # input activations to balance layers needed for loss function
             # storing inputs to first balance layer is sufficient
             # other balance layers get the same input
+            # HACK(review): hooking parent.mlp instead of the first balance
+            # layer is a stopgap; the masked filtering assumes [batch, seq,
+            # hidden] inputs, which not every balance layer receives -- confirm
+            layer_to_hook = (
+                mapping.parent.mlp
+                if hasattr(mapping.parent, "mlp")
+                else mapping.balance_layers[0]
+            )
             self.register_hook(
-                mapping.balance_layers[0],
+                layer_to_hook,
                 create_cache_smooth_activations_hook_fn(mapping.smooth_name),
                 "forward",
             )
@@ -715,11 +749,42 @@ def _compute_loss(
         num_elements = 0
 
         # Compute the MSE loss for each batch
-        for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
+        for batch_idx, (fp16_batch, int_w_batch) in enumerate(
+            zip(fp16_outputs, int_w_outputs)
+        ):
+            int_w_batch = int_w_batch.to(fp16_batch.device)
+
+            # Restrict the loss to tokens the mask marks as valid when a mask
+            # was cached for this batch; otherwise score the full tensors
+            if (
+                batch_idx < len(self._loss_masks)
+                and self._loss_masks[batch_idx] is not None
+            ):
+                mask = self._loss_masks[batch_idx].to(fp16_batch.device)
+                # Outputs are [batch, seq_len, hidden_dim]; flatten to rows of
+                # hidden_dim so the [batch * seq_len] mask can index them
+                fp16_flat = fp16_batch.flatten(0, -2)
+                int_w_flat = int_w_batch.flatten(0, -2)
+                valid_mask = mask.flatten() > 0
+                if not valid_mask.any():
+                    # Nothing to score in this batch; skipping also avoids
+                    # reading fp16_valid/int_w_valid before assignment
+                    logger.warning(
+                        f"No valid tokens (mask==1) found in batch {batch_idx} "
+                        "during MSE loss computation"
+                    )
+                    continue
+                fp16_valid = fp16_flat[valid_mask]
+                int_w_valid = int_w_flat[valid_mask]
+            else:
+                fp16_valid = fp16_batch
+                int_w_valid = int_w_batch
+
             loss += torch.nn.functional.mse_loss(
-                fp16_batch, int_w_batch.to(fp16_batch.device), reduction="sum"
+                fp16_valid, int_w_valid, reduction="sum"
             ).item()
-            num_elements += fp16_batch.numel()
+            num_elements += fp16_valid.numel()
 
         # Normalize the loss by the total number of elements
-        loss /= num_elements
+        # Guard against every batch having been skipped (num_elements == 0)
+        loss /= max(num_elements, 1)
diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py
index 91516f280d..8a82fd352c 100644
--- a/src/llmcompressor/pipelines/sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/sequential/pipeline.py
@@ -1,4 +1,5 @@
 import contextlib
+from contextvars import ContextVar
 from typing import TYPE_CHECKING
 
 import torch
@@ -24,7 +25,12 @@
 if TYPE_CHECKING:
     from llmcompressor.args.dataset_arguments import DatasetArguments
 
-__all__ = ["SequentialPipeline"]
+__all__ = ["SequentialPipeline", "_current_loss_mask"]
+
+# Context variable to store the current batch's loss_mask for hooks to access
+_current_loss_mask: ContextVar[torch.Tensor | None] = ContextVar(
+    "_current_loss_mask", default=None
+)
 
 
 @CalibrationPipeline.register("sequential")
@@ -104,7 +110,17 @@ def __call__(
             # do a preliminary pass to trigger modifier hooks
             for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc):
                 inputs = activations.fetch(batch_idx, subgraph.input_names)
-                subgraph.forward(model, **inputs)
+
+                # expose this batch's loss_mask to modifier hooks during forward
+                if dataset_args.use_loss_mask:
+                    loss_mask_dict = activations.fetch(batch_idx, ["loss_mask"])
+                    token = _current_loss_mask.set(loss_mask_dict.get("loss_mask"))
+                try:
+                    subgraph.forward(model, **inputs)
+                finally:
+                    # restore the previous value even if forward raises
+                    if dataset_args.use_loss_mask:
+                        _current_loss_mask.reset(token)
 
             LifecycleCallbacks.sequential_epoch_end(subgraph)