diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py
index 0cdb4989ec73..4050bf0453e3 100644
--- a/vllm/model_executor/models/llama4.py
+++ b/vllm/model_executor/models/llama4.py
@@ -44,6 +44,9 @@
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.compressed_tensors import (
+    compressed_tensors as ct,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader,
@@ -829,11 +832,20 @@ def permute_qk_weight_for_rotary(
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
         # Helper function to permute the weight's channels
-        def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
+        def permute(
+            w: torch.Tensor,
+            n_heads: int,
+            is_nvfp4_weight_scale: bool,
+            is_ct_int8_or_fp8_weight_scale: bool,
+        ):
             # Calculate the expected shape of the weight.
             # Do not rely on w's shape, as it may be in another layout.
             attn_in = self.config.head_dim * n_heads
-            attn_out = self.config.hidden_size
+            attn_out = (
+                self.config.hidden_size
+                if not is_ct_int8_or_fp8_weight_scale
+                else w.shape[-1]
+            )
 
             # If the weight is FP4 packed as uint8, we need to divide attn_out
             # by 2.
@@ -844,7 +856,7 @@ def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
             # block size, which is currently 16.
             elif (
                 w.dtype == torch.float8_e4m3fn
-                and is_weight_scale
+                and is_nvfp4_weight_scale
                 and w.shape[1] * 16 == attn_out
             ):
                 attn_out = attn_out // 16
@@ -862,19 +874,31 @@ def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
         is_nvfp4_weight_scale = (
             modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
         )
-
-        if is_weight or is_nvfp4_weight_scale:
+        is_ct_int8_or_fp8_weight_scale = False
+        if modules[-1] == "weight_scale" and isinstance(
+            self.model.quant_config, ct.CompressedTensorsConfig
+        ):
+            from compressed_tensors import CompressionFormat
+
+            is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
+                CompressionFormat.int_quantized.value,
+                CompressionFormat.float_quantized.value,
+            ] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
+
+        if is_weight or is_nvfp4_weight_scale or is_ct_int8_or_fp8_weight_scale:
             if "wk" in modules or "k_proj" in modules:
                 loaded_weight = permute(
                     loaded_weight,
                     self.config.num_key_value_heads,
                     is_nvfp4_weight_scale,
+                    is_ct_int8_or_fp8_weight_scale,
                 )
             elif "wq" in modules or "q_proj" in modules:
                 loaded_weight = permute(
                     loaded_weight,
                     self.config.num_attention_heads,
                     is_nvfp4_weight_scale,
+                    is_ct_int8_or_fp8_weight_scale,
                 )
 
         return name, loaded_weight
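For context on why the compressed-tensors int8/fp8 `weight_scale` path takes `attn_out` from `w.shape[-1]`: a channel-wise scale has one value per output channel, so its trailing dimension is 1 rather than `hidden_size`, while its rows still need the same per-head rotary reordering as the weight. The sketch below is a minimal, standalone illustration (not taken from this patch); it assumes the helper's body uses the usual view/transpose/reshape rotary row permutation, and all names in it are illustrative.

```python
# Illustrative sketch only -- not part of the patch.
import torch


def permute_rows(w: torch.Tensor, n_heads: int, attn_out: int) -> torch.Tensor:
    # Reorder output channels from the interleaved (re, im, re, im, ...) layout
    # to the half-split layout expected by vLLM's rotary embedding, per head.
    attn_in = w.shape[0]  # head_dim * n_heads
    return (
        w.view(n_heads, attn_in // n_heads // 2, 2, attn_out)
        .transpose(1, 2)
        .reshape(attn_in, attn_out)
    )


head_dim, n_heads, hidden_size = 4, 2, 8
weight = torch.randn(head_dim * n_heads, hidden_size)
# A channel-wise int8/fp8 weight_scale: one scale per output channel, so its
# last dim is 1, not hidden_size.
scale = torch.rand(head_dim * n_heads, 1)

permuted_weight = permute_rows(weight, n_heads, hidden_size)    # attn_out = hidden_size
permuted_scale = permute_rows(scale, n_heads, scale.shape[-1])  # attn_out = w.shape[-1]
```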