Conversation
```python
# Cache loss_mask for each parent module, one mask per batch
_loss_masks: list[torch.Tensor | None] = PrivateAttr(
    default_factory=list
)
```
Should this be `None` if the user isn't using loss masks?
```python
num_elements = 0

# Compute the MSE loss for each batch
for fp16_batch, int_w_batch in zip(fp16_outputs, int_w_outputs):
```
See the changes in vllm-project#2188, which will land soon.
I suspect it will make more sense to apply the mask in run_samples and the concatenated fp16_output calculation, rather than in the loss calculation, if possible.
Sure, I'll make that change
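A minimal sketch of what that reordering might look like (hypothetical names; assumes per-batch outputs of shape `[batch, seq_len, hidden]` and masks of shape `[batch, seq_len]`): mask each batch's output before concatenation, so the loss computation itself needs no mask handling:

```python
import torch

def collect_masked_outputs(outputs, masks):
    # Keep only the masked positions from each batch's output, then
    # concatenate, so the downstream MSE loss sees only valid tokens.
    # outputs[i]: [batch, seq_len, hidden]; masks[i]: [batch, seq_len]
    return torch.cat([out[m.bool()] for out, m in zip(outputs, masks)])

# Usage with toy shapes: 2 batches of [2, 4, 8] outputs
outs = [torch.randn(2, 4, 8), torch.randn(2, 4, 8)]
masks = [torch.tensor([[1, 1, 0, 0], [0, 1, 1, 0]]),
         torch.tensor([[1, 0, 0, 0], [1, 1, 1, 1]])]
flat = collect_masked_outputs(outs, masks)  # shape: [9, 8] (9 valid tokens)
```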
```python
# Context variable to store the current batch's loss_mask for hooks to access
_current_loss_mask: ContextVar[torch.Tensor | None] = ContextVar(
    "_current_loss_mask", default=None
)
```
Do the loss masks change from sample to sample, or are they largely constant? I think this approach is fine if they change a lot, but if they tend to be constant we could potentially do something different, where we just alter the AWQ modifier to take the loss mask into account directly.
I assume the chat template is usually going to be pretty consistent, so that may make more sense.
Also wondering if the loss mask is usually a step function with a single edge; it may make more sense to store just the edge rather than the entire mask.
I think we'd like to generalize masks to be fully expressive, to include things like padding tokens. This shouldn't be too much memory, just num_samples * seq_len * bool ~= 1mb, or 8b if you don't want to offload and instead keep as a tensor.
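To make the memory trade-off concrete, here is a rough sketch (assumed sizes: 512 samples, 2048-token sequences) comparing full boolean masks against storing only a single edge index per sample, as suggested above:

```python
import torch

num_samples, seq_len = 512, 2048  # assumed calibration set size

# Full boolean masks: torch stores 1 byte per bool element
full_masks = torch.zeros(num_samples, seq_len, dtype=torch.bool)
bytes_full = full_masks.numel() * full_masks.element_size()
print(f"full boolean masks: {bytes_full / 2**20:.1f} MiB")  # 1.0 MiB

# If each mask were a single-edge step function, one index per sample
# would suffice (argmax returns int64, hence 8 bytes per sample)
edges = full_masks.int().argmax(dim=1)  # first nonzero position per sample
bytes_edges = edges.numel() * edges.element_size()
print(f"edge indices only: {bytes_edges / 2**10:.1f} KiB")  # 4.0 KiB
```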
```python
__all__ = ["SequentialPipeline", "_current_loss_mask"]


# Context variable to store the current batch's loss_mask for hooks to access
_current_loss_mask: ContextVar[torch.Tensor | None] = ContextVar(
    "_current_loss_mask", default=None
)
```
I think, to work better within the LLM Compressor framework, we should store this variable on the State.
```python
# Set loss_mask in context variable if enabled, so hooks can access it
if dataset_args.use_loss_mask:
    loss_mask_dict = activations.fetch(batch_idx, ["loss_mask"])
```
How did the "loss_mask" argument end up in the activations cache?
It's probably better if we implement a calculate_token_mask which gets called just once per batch with the model inputs (so it can use things like the attention mask). This way, we also don't need to continuously offload/onload the values.
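A sketch of what such a helper could look like (`calculate_token_mask` is the hypothetical name from the comment above; the batch keys are assumptions): compute the mask once per batch from the model inputs, combining it with the attention mask so padding tokens are excluded too:

```python
import torch

def calculate_token_mask(batch: dict) -> torch.Tensor:
    # Hypothetical one-shot mask builder: called once per batch with the
    # model inputs, so it can fold in the attention mask (padding tokens).
    attn = batch["attention_mask"].bool()  # [batch, seq_len]
    loss_mask = batch.get("loss_mask")
    if loss_mask is None:
        return attn
    return attn & loss_mask.bool()

# Usage: last token is padding, first token is a prompt token
batch = {
    "attention_mask": torch.tensor([[1, 1, 1, 0]]),
    "loss_mask": torch.tensor([[0, 1, 1, 1]]),
}
mask = calculate_token_mask(batch)  # [[False, True, True, False]]
```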
```python
activations = IntermediatesCache.from_dataloader(
    dataloader, model_device, offload_device=offload_device
)
```
will automatically save all the columns from the dataloader.
```python
# mask shape: [batch, seq_len]
# output shape: [batch, seq_len, hidden_dim]
# Flatten both to [batch * seq_len, hidden_dim] and [batch * seq_len]
fp16_flat = fp16_batch.flatten(0, -2)  # [batch * seq_len, hidden_dim]
```
Is all this logic actually required? I would assume that you don't need to do any flattening; instead, just use something like masked_scatter.
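For reference, boolean indexing already broadcasts over the hidden dimension, so the explicit flatten can likely be dropped; a minimal sketch with assumed shapes:

```python
import torch
import torch.nn.functional as F

fp16_batch = torch.randn(2, 5, 8)   # [batch, seq_len, hidden_dim]
int_w_batch = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 0, 0, 0],
                     [0, 1, 1, 1, 0]], dtype=torch.bool)  # [batch, seq_len]

# Indexing a [batch, seq_len, hidden] tensor with a [batch, seq_len] bool
# mask yields [num_masked_tokens, hidden] directly; no flatten required.
loss = F.mse_loss(fp16_batch[mask], int_w_batch[mask])
```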
My big points are
SUMMARY:
"please provide a brief summary"
TEST PLAN:
"please outline how the changes were tested"