Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/llmcompressor/args/dataset_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,15 @@ class DatasetArguments(CustomDatasetArguments):
"Default is set to True."
},
)
# Opt-in switch for dataset-provided loss masks. Disabled by default, so
# datasets without a 'loss_mask' field are unaffected unless this is set.
use_loss_mask: bool = field(
default=False,
metadata={
"help": "Whether to use a loss mask from the dataset. When True, expects "
"the dataset to contain a 'loss_mask' field that indicates which tokens "
"should be included in loss calculations (e.g., for masking out prompts "
"and only computing loss on generated tokens). Default is False."
},
)

def is_dataset_provided(self) -> bool:
    """
    Report whether a dataset source has been supplied.

    :return: True when `dataset` is not None or, failing that, when
        `dataset_path` is not None; False when both are None
    """
    # Guard clause keeps the original short-circuit order: `dataset_path`
    # is only consulted when `dataset` is unset.
    if self.dataset is not None:
        return True
    return self.dataset_path is not None
5 changes: 5 additions & 0 deletions src/llmcompressor/entrypoints/oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def oneshot(
min_tokens_per_module: float | None = None,
moe_calibrate_all_experts: bool = True,
quantization_aware_calibration: bool = True,
use_loss_mask: bool = False,
# Miscellaneous arguments
output_dir: str | None = None,
log_dir: str | None = None,
Expand Down Expand Up @@ -339,6 +340,10 @@ def oneshot(
calibration in the sequential pipeline. When True, quantization is applied
during forward pass in calibration. When False, quantization is disabled
during forward pass in calibration. Default is set to True.
:param use_loss_mask: Whether to use a loss mask from the dataset. When True,
expects the dataset to contain a 'loss_mask' field that indicates which
tokens should be included in loss calculations (e.g., for masking out
prompts and only computing loss on generated tokens). Default is False.

# Miscellaneous arguments
:param output_dir: Path to save the output model after calibration.
Expand Down
86 changes: 81 additions & 5 deletions src/llmcompressor/modifiers/awq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.observers.base import Observer
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.pipelines.sequential.pipeline import _current_loss_mask
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
from llmcompressor.utils.helpers import calibration_forward_context
from llmcompressor.utils.pytorch.module import (
Expand Down Expand Up @@ -162,6 +163,10 @@ class AWQModifier(Modifier, QuantizationMixin):
_parent_args_cache: dict[Module, IntermediatesCache] = PrivateAttr(
default_factory=dict
)
# Cache loss_mask for each parent module, one mask per batch
_loss_masks: list[torch.Tensor | None] = PrivateAttr(
default_factory=list
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be `None` if the user isn't using loss masks?

# Dict[smooth layer name, (activation means, activation counts)]
_smooth_activation_means: dict[str, tuple[torch.FloatTensor, int]] = PrivateAttr(
default_factory=dict
Expand Down Expand Up @@ -269,6 +274,7 @@ def on_finalize(self, state: State, **kwargs) -> bool:
self.on_end(state, None)

self._parent_args_cache.clear()
self._loss_masks = []
self._smooth_activation_means.clear()
self._resolved_mappings.clear()

Expand Down Expand Up @@ -382,15 +388,48 @@ def cache_parent_kwargs_hook(
values = inspect.signature(module.forward).bind(*args, **kwargs)
self._parent_args_cache[module].append(values.arguments)

loss_mask = _current_loss_mask.get()
self._loss_masks.append(loss_mask)


def create_cache_smooth_activations_hook_fn(smooth_name):
def cache_smooth_activations_hook(
_module: Module,
args: tuple[torch.Tensor, ...],
_output: torch.Tensor,
):
# Get activation tensor: shape [batch, seq_len, hidden_dim]
activations = args[0].abs().detach()

# Flatten activations: [batch * seq_len, hidden_dim]
flat_activations = activations.flatten(0, -2)

# Try to get loss_mask from context variable
loss_mask = _current_loss_mask.get()

if loss_mask is not None:
# loss_mask shape: [batch, seq_len]
# Flatten to [batch * seq_len]

flat_mask = loss_mask.flatten().to(activations.device)

# Filter: only keep rows where mask == 1
valid_mask = flat_mask > 0
if valid_mask.any():
flat_activations = flat_activations[valid_mask]
else:
# No valid tokens in this batch, use zeros
logger.warning(
f"No valid tokens (mask==1) found for {smooth_name} in current batch"
)
flat_activations = torch.zeros(
(1, activations.shape[-1]),
device=activations.device,
dtype=activations.dtype
)

act_mean, count = _accumulate_mean(
args[0].abs().detach().flatten(0, -2),
flat_activations,
self._smooth_activation_means.get(smooth_name, None),
)
self._smooth_activation_means[smooth_name] = (act_mean.cpu(), count)
Expand All @@ -415,8 +454,11 @@ def cache_smooth_activations_hook(
# input activations to balance layers needed for loss function
# storing inputs to first balance layer is sufficient
# other balance layers get the same input

# TODO(zewen): this should be improved
layer_to_hook = mapping.parent.mlp if hasattr(mapping.parent, 'mlp') else mapping.balance_layers[0]
self.register_hook(
mapping.balance_layers[0],
layer_to_hook,
create_cache_smooth_activations_hook_fn(mapping.smooth_name),
"forward",
)
Expand Down Expand Up @@ -715,11 +757,45 @@ def _compute_loss(
num_elements = 0

# Compute the MSE loss for each batch
for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
Copy link
Copy Markdown

@HDCharles HDCharles Jan 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see changes in vllm-project#2188 which will land soon

I suspect it will make more sense to apply the mask in `run_samples` and in the concatenated `fp16_output` calculation, rather than in the loss calculation, if possible.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll make that change

for batch_idx, (fp16_batch, int_w_batch) in enumerate(
zip(fp16_outputs, int_w_outputs)
):
int_w_batch = int_w_batch.to(fp16_batch.device)

# Apply mask if available for this batch
if batch_idx < len(self._loss_masks) and self._loss_masks[batch_idx] is not None:
mask = self._loss_masks[batch_idx].to(fp16_batch.device)

# mask shape: [batch, seq_len]
# output shape: [batch, seq_len, hidden_dim]
# Flatten both to [batch * seq_len, hidden_dim] and [batch * seq_len]
fp16_flat = fp16_batch.flatten(0, -2) # [batch * seq_len, hidden_dim]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is all this logic actually required? I would assume that you don't need to do any flattening; instead, just use something like `masked_scatter`.

int_w_flat = int_w_batch.flatten(0, -2) # [batch * seq_len, hidden_dim]
mask_flat = mask.flatten() # [batch * seq_len]

# Only compute loss on valid (mask==1) positions
valid_mask = (mask_flat == 1)
if valid_mask.any():
# Extract only the valid tokens using boolean indexing
fp16_valid = fp16_flat[valid_mask] # [num_valid, hidden_dim]
int_w_valid = int_w_flat[valid_mask] # [num_valid, hidden_dim]

# Compute MSE loss on valid tokens only

else:
# No valid tokens, skip this batch
logger.warning(
f"No valid tokens (mask==1) found in batch {batch_idx} "
"during MSE loss computation"
)
else:
fp16_valid = fp16_batch
int_w_valid = int_w_batch

loss += torch.nn.functional.mse_loss(
fp16_batch, int_w_batch.to(fp16_batch.device), reduction="sum"
fp16_valid, int_w_valid, reduction="sum"
).item()
num_elements += fp16_batch.numel()
num_elements += fp16_valid.numel()

# Normalize the loss by the total number of elements
loss /= num_elements
Expand Down
18 changes: 17 additions & 1 deletion src/llmcompressor/pipelines/sequential/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
from contextvars import ContextVar
from typing import TYPE_CHECKING

import torch
Expand All @@ -24,7 +25,12 @@
if TYPE_CHECKING:
from llmcompressor.args.dataset_arguments import DatasetArguments

__all__ = ["SequentialPipeline"]
__all__ = ["SequentialPipeline", "_current_loss_mask"]

# Context variable to store the current batch's loss_mask for hooks to access
_current_loss_mask: ContextVar[torch.Tensor | None] = ContextVar(
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, to better work in the LLM Compressor framework, we should store this variable on the State

"_current_loss_mask", default=None
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do the loss masks change from sample to sample, or are they largely constant? I think this approach is fine if they change a lot, but if they tend to be constant we could instead alter the AWQ modifier to take the loss mask into account directly.

I assume the chat template is usually going to be pretty consistent so that may make more sense.

I'm also wondering whether the loss mask is usually a step function with a single edge; if so, it may make more sense to store just the edge rather than the entire mask.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'd like to generalize masks to be fully expressive, to include things like padding tokens. This shouldn't be too much memory, just num_samples * seq_len * bool ~= 1mb, or 8b if you don't want to offload and instead keep as a tensor.



@CalibrationPipeline.register("sequential")
Expand Down Expand Up @@ -104,7 +110,17 @@ def __call__(
# do a preliminary pass to trigger modifier hooks
for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc):
inputs = activations.fetch(batch_idx, subgraph.input_names)

# Set loss_mask in context variable if enabled, so hooks can access it
if dataset_args.use_loss_mask:
loss_mask_dict = activations.fetch(batch_idx, ["loss_mask"])
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How did the "loss_mask" argument end up in the activations cache?

It's probably better if we implement a calculate_token_mask which gets called just once per batch with the model inputs (so it can use things like the attention mask). This way, we also don't need to continuously offload/onload the values.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/ZewenShen-Cohere/llm-compressor-fork/pull/1/changes#diff-7fa7c4bb4a7a6087e1af538b307f86d166cff0365bfe9977f9512fd8777df0a4L93

activations = IntermediatesCache.from_dataloader(
dataloader, model_device, offload_device=offload_device
)
This call automatically saves all the columns from the dataloader.

_current_loss_mask.set(loss_mask_dict.get("loss_mask"))

subgraph.forward(model, **inputs)

# Clear the context variable after forward pass
if dataset_args.use_loss_mask:
_current_loss_mask.set(None)

LifecycleCallbacks.sequential_epoch_end(subgraph)

Expand Down