
Awq masking#1

Open
ZewenShen-Cohere wants to merge 2 commits into awq_bugfix from awq_masking

Conversation

@ZewenShen-Cohere
Owner

SUMMARY:
"please provide a brief summary"

TEST PLAN:
"please outline how the changes were tested"

# Cache loss_mask for each parent module, one mask per batch
_loss_masks: list[torch.Tensor | None] = PrivateAttr(
default_factory=list
)

Should this be None if the user isn't using loss masks?
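One way to address this question (a hypothetical sketch, not the PR's actual code; the `LossMaskCache` name and its API are invented for illustration) is to let the cache default to `None` when masking is disabled, so `None` unambiguously means "the user isn't using loss masks":

```python
import torch


class LossMaskCache:
    """Hypothetical sketch: the cache is None unless loss masks are enabled."""

    def __init__(self, use_loss_mask: bool = False):
        # None signals "masking disabled"; an empty list signals
        # "masking enabled, no batches cached yet".
        self._loss_masks: list[torch.Tensor] | None = (
            [] if use_loss_mask else None
        )

    def append(self, mask: torch.Tensor) -> None:
        # Silently ignore masks when masking is disabled.
        if self._loss_masks is not None:
            self._loss_masks.append(mask)

    @property
    def enabled(self) -> bool:
        return self._loss_masks is not None
```

With this shape, downstream code can branch on `cache.enabled` instead of checking list contents.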

num_elements = 0

# Compute the MSE loss for each batch
for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):

@HDCharles Jan 7, 2026


See the changes in vllm-project#2188, which will land soon.

I suspect it will make more sense to apply the mask in run_samples and in the concatenated fp16_output calculation, rather than in the loss calculation, if possible.

Owner Author


Sure, I'll make that change

# Context variable to store the current batch's loss_mask for hooks to access
_current_loss_mask: ContextVar[torch.Tensor | None] = ContextVar(
"_current_loss_mask", default=None
)

Do the loss masks change from sample to sample, or are they largely constant? I think this approach is fine if they change a lot, but if they tend to be constant we could potentially do something different, such as altering the AWQ modifier to take the loss mask into account directly.

I assume the chat template is usually going to be pretty consistent, so that may make more sense.

Also wondering whether the loss mask is usually a step function with a single edge; if so, it may make more sense to store just the edge rather than the entire mask.


I think we'd like to generalize masks to be fully expressive, to include things like padding tokens. This shouldn't be too much memory, just num_samples * seq_len * bool ~= 1mb, or 8b if you don't want to offload and instead keep as a tensor.

__all__ = ["SequentialPipeline", "_current_loss_mask"]

# Context variable to store the current batch's loss_mask for hooks to access
_current_loss_mask: ContextVar[torch.Tensor | None] = ContextVar(

I think, to better work in the LLM Compressor framework, we should store this variable on the State


# Set loss_mask in context variable if enabled, so hooks can access it
if dataset_args.use_loss_mask:
loss_mask_dict = activations.fetch(batch_idx, ["loss_mask"])

How did the "loss_mask" argument end up in the activations cache?

It's probably better if we implement a calculate_token_mask which gets called just once per batch with the model inputs (so it can use things like the attention mask). This way, we also don't need to continuously offload/onload the values.
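A minimal sketch of such a `calculate_token_mask` (the function name is taken from the comment above, but the `batch` dictionary keys and exact behavior are assumptions) that runs once per batch on the model inputs and folds in the attention mask so padding never contributes:

```python
import torch


def calculate_token_mask(batch: dict[str, torch.Tensor]) -> torch.Tensor:
    """Hypothetical sketch: derive a [batch, seq_len] boolean token mask once
    per batch from the model inputs, combining an explicit loss_mask (if
    present) with the attention mask so padding tokens are always excluded."""
    mask = torch.ones_like(batch["input_ids"], dtype=torch.bool)
    attention_mask = batch.get("attention_mask")
    if attention_mask is not None:
        mask &= attention_mask.bool()  # drop padding tokens
    loss_mask = batch.get("loss_mask")
    if loss_mask is not None:
        mask &= loss_mask.bool()  # drop tokens excluded from the loss
    return mask
```

Because the mask is derived from inputs already in memory, nothing extra needs to be offloaded or onloaded between batches.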

Owner Author


https://github.com/ZewenShen-Cohere/llm-compressor-fork/pull/1/changes#diff-7fa7c4bb4a7a6087e1af538b307f86d166cff0365bfe9977f9512fd8777df0a4L93

activations = IntermediatesCache.from_dataloader(
dataloader, model_device, offload_device=offload_device
)
automatically saves all the columns from the dataloader.

# mask shape: [batch, seq_len]
# output shape: [batch, seq_len, hidden_dim]
# Flatten both to [batch * seq_len, hidden_dim] and [batch * seq_len]
fp16_flat = fp16_batch.flatten(0, -2) # [batch * seq_len, hidden_dim]

Is all this logic actually required? I would assume that you don't need to do any flattening; instead, just use something like masked_scatter.
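For comparison, a flattening-free masked MSE can also be written with plain broadcasting rather than `masked_scatter`; this `masked_mse` is a hypothetical sketch under the shapes documented in the snippet above, not the PR's implementation:

```python
import torch


def masked_mse(
    fp16_out: torch.Tensor, int_w_out: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """Hypothetical sketch: MSE over masked tokens via broadcasting, with no
    flattening. mask: [batch, seq_len]; outputs: [batch, seq_len, hidden]."""
    sq_err = (fp16_out - int_w_out).pow(2)   # [batch, seq_len, hidden]
    m = mask.unsqueeze(-1).to(sq_err.dtype)  # [batch, seq_len, 1], broadcasts
    # Number of elements kept by the mask; clamp avoids division by zero
    # when every token is masked out.
    denom = (mask.sum() * sq_err.shape[-1]).clamp(min=1)
    return (sq_err * m).sum() / denom
```

The `unsqueeze(-1)` makes the [batch, seq_len] mask broadcast across the hidden dimension, so no reshaping of the activations is needed.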

@kylesayrs

My big points are:

- Use State rather than a context var, and also integrate with the basic pipeline
- Calculate the loss mask once (as soon as it comes out of the data_loader)
- Avoid flattening if possible, using something like masked_scatter

