[Bugfix] Enable attn quantization of Llama-4 by correctly permuting scales for rope (int8, fp8) #34243
The first hunk adds the compressed-tensors quantization import:

```diff
@@ -44,6 +44,9 @@
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.compressed_tensors import (
+    compressed_tensors as ct,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
```
The second hunk generalizes the `permute` helper so it can distinguish NVFP4 weight scales from compressed-tensors int8/fp8 weight scales, and derives `attn_out` from the scale's own last dimension in the latter case:

```diff
@@ -829,11 +832,20 @@ def permute_qk_weight_for_rotary(
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
         # Helper function to permute the weight's channels
-        def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
+        def permute(
+            w: torch.Tensor,
+            n_heads: int,
+            is_nvfp4_weight_scale: bool,
+            is_ct_int8_or_fp8_weight_scale: bool,
+        ):
             # Calculate the expected shape of the weight.
             # Do not rely on w's shape, as it may be in another layout.
             attn_in = self.config.head_dim * n_heads
-            attn_out = self.config.hidden_size
+            attn_out = (
+                self.config.hidden_size
+                if not is_ct_int8_or_fp8_weight_scale
+                else w.shape[-1]
+            )

             # If the weight is FP4 packed as uint8, we need to divide attn_out
             # by 2.
```
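For context on what `permute` is doing: Llama-style checkpoints store the rope channels of each attention head interleaved (re, im, re, im, …), while the consuming kernel expects a half-split layout (all re, then all im). A minimal NumPy sketch of that row permutation follows; this is a hypothetical standalone version for illustration (the interleaved-to-half-split direction is my reading of the code), not the vLLM implementation itself:

```python
import numpy as np

def permute_rows(w: np.ndarray, n_heads: int) -> np.ndarray:
    """Reorder rows from interleaved (re, im, re, im, ...) to
    half-split (all re, then all im) within each head."""
    attn_in, attn_out = w.shape
    head_dim = attn_in // n_heads
    return (
        w.reshape(n_heads, head_dim // 2, 2, attn_out)
        .transpose(0, 2, 1, 3)  # swap the pair axis with the re/im axis
        .reshape(attn_in, attn_out)
    )

# one head, head_dim=4: rows are [re0, im0, re1, im1]
w = np.arange(4).reshape(4, 1)
print(permute_rows(w, 1).ravel().tolist())  # -> [0, 2, 1, 3]
```

The same reshape/transpose applies per head, which is why the helper only needs `n_heads` and the expected `attn_in`/`attn_out` rather than the tensor's stored layout.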
The third hunk renames the flag at its NVFP4-specific use site:

```diff
@@ -844,7 +856,7 @@ def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
             # block size, which is currently 16.
             elif (
                 w.dtype == torch.float8_e4m3fn
-                and is_weight_scale
+                and is_nvfp4_weight_scale
                 and w.shape[1] * 16 == attn_out
             ):
                 attn_out = attn_out // 16
```
The fourth hunk detects compressed-tensors int8/fp8 weight scales at load time:

```diff
@@ -862,19 +874,31 @@ def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
         is_nvfp4_weight_scale = (
             modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
         )

+        is_ct_int8_or_fp8_weight_scale = False
+        if modules[-1] == "weight_scale" and isinstance(
+            self.model.quant_config, ct.CompressedTensorsConfig
+        ):
+            from compressed_tensors import CompressionFormat
+
+            is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
+                CompressionFormat.int_quantized.value,
+                CompressionFormat.float_quantized.value,
+            ] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
+
```
> **Comment on lines +877 to +886**
>
> **Member:** I would like to avoid CT-specific logic here. Why don't we apply this to all weight scales? Is it because of concern over packed weights? I think they wouldn't work anyway.
>
> **Contributor (Author):** I added the CT-specific logic for safety reasons. I don't know how other frameworks ship weight scales.
The hunk continues by routing both new flags through the `permute` calls for the q and k projections:

```diff
-        if is_weight or is_nvfp4_weight_scale:
+        if is_weight or is_nvfp4_weight_scale or is_ct_int8_or_fp8_weight_scale:
             if "wk" in modules or "k_proj" in modules:
                 loaded_weight = permute(
                     loaded_weight,
                     self.config.num_key_value_heads,
+                    is_nvfp4_weight_scale,
+                    is_ct_int8_or_fp8_weight_scale,
                 )
             elif "wq" in modules or "q_proj" in modules:
                 loaded_weight = permute(
                     loaded_weight,
                     self.config.num_attention_heads,
+                    is_nvfp4_weight_scale,
+                    is_ct_int8_or_fp8_weight_scale,
                 )

         return name, loaded_weight
```
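Why the scales must go through the same permutation as the weights can be shown independently: with per-output-channel quantization, dequantizing row `i` multiplies the quantized row by `scale[i]`, so any row reordering applied to the quantized weight must also be applied to its scale, or rows are dequantized with the wrong factor. A hedged NumPy sketch (helper names are mine, not vLLM's):

```python
import numpy as np

def permute_rows(x: np.ndarray, n_heads: int) -> np.ndarray:
    # the same head-wise rope row permutation, applied to weights and scales
    attn_in = x.shape[0]
    head_dim = attn_in // n_heads
    return (
        x.reshape(n_heads, head_dim // 2, 2, -1)
        .transpose(0, 2, 1, 3)
        .reshape(attn_in, -1)
    )

rng = np.random.default_rng(0)
w = rng.standard_normal((8, 4)).astype(np.float32)

# symmetric per-output-channel int8 quantization
scale = np.abs(w).max(axis=1, keepdims=True) / 127.0
w_q = np.round(w / scale)

# reference: permute the dequantized weight directly
ref = permute_rows(w_q * scale, 2)

# permuting weight and scale together matches; permuting only the weight does not
assert np.allclose(permute_rows(w_q, 2) * permute_rows(scale, 2), ref)
assert not np.allclose(permute_rows(w_q, 2) * scale, ref)
```

This is exactly the failure mode the PR title describes: without permuting the scales, the quantized q/k projections dequantize rows against mismatched scales after the rope reordering.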
> **Reviewer:** Why is this needed? Is it because CT transposes the scales? Could we just always use `w.shape` to decide?
>
> **Author:** Notice that the original implementation was `attn_out = self.config.hidden_size`, and there is a comment above which says "Do not rely on w's shape, as it may be in another layout." I didn't want to break that by doing `attn_out = w.shape[-1]` unconditionally.
>
> **Author:** If you think this is safe to ignore, and we should always use `attn_out = w.shape[-1]`, let me know.
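The author's caution about reading `attn_out` from `w.shape` has at least one concrete justification already present in this file: NVFP4 weights are packed two 4-bit values per uint8, so the stored last dimension is half the logical `attn_out`. A small sketch of that mismatch (illustrative numbers, not the model's real config):

```python
import numpy as np

hidden_size = 16  # logical attn_out
attn_in = 8

# FP4 packs two 4-bit values into each uint8, so the stored tensor
# is half as wide as the logical weight matrix
w_packed = np.zeros((attn_in, hidden_size // 2), dtype=np.uint8)

# reading attn_out from the tensor shape would give the packed width
assert w_packed.shape[-1] == hidden_size // 2
assert w_packed.shape[-1] != hidden_size
print(w_packed.shape)  # -> (8, 8)
```

So `w.shape[-1]` is only trustworthy for layouts known to store one logical element per stored element, which is why the PR gates it behind the int8/fp8 scale check instead of applying it unconditionally.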