Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions vllm/model_executor/models/llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors import (
compressed_tensors as ct,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
Expand Down Expand Up @@ -829,11 +832,20 @@ def permute_qk_weight_for_rotary(
loaded_weight: torch.Tensor,
) -> tuple[str, torch.Tensor]:
# Helper function to permute the weight's channels
def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
def permute(
w: torch.Tensor,
n_heads: int,
is_nvfp4_weight_scale: bool,
is_ct_int8_or_fp8_weight_scale: bool,
):
# Calculate the expected shape of the weight.
# Do not rely on w's shape, as it may be in another layout.
attn_in = self.config.head_dim * n_heads
attn_out = self.config.hidden_size
attn_out = (
self.config.hidden_size
if not is_ct_int8_or_fp8_weight_scale
else w.shape[-1]
)
Comment on lines -836 to +848
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed — is it because CT transposes the scales? Could we just always use `w.shape` to decide?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that the original implementation was `attn_out = self.config.hidden_size`, and there is a comment above that says "Do not rely on w's shape, as it may be in another layout." I didn't want to break this by doing `attn_out = w.shape[-1]`.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you think this concern is safe to ignore and we can always use `attn_out = w.shape[-1]`, let me know.


# If the weight is FP4 packed as uint8, we need to divide attn_out
# by 2.
Expand All @@ -844,7 +856,7 @@ def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
# block size, which is currently 16.
elif (
w.dtype == torch.float8_e4m3fn
and is_weight_scale
and is_nvfp4_weight_scale
and w.shape[1] * 16 == attn_out
):
attn_out = attn_out // 16
Expand All @@ -862,19 +874,31 @@ def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
is_nvfp4_weight_scale = (
modules[-1] == "weight_scale" and loaded_weight.dtype == torch.float8_e4m3fn
)

if is_weight or is_nvfp4_weight_scale:
is_ct_int8_or_fp8_weight_scale = False
if modules[-1] == "weight_scale" and isinstance(
self.model.quant_config, ct.CompressedTensorsConfig
):
from compressed_tensors import CompressionFormat

is_ct_int8_or_fp8_weight_scale = self.model.quant_config.quant_format in [
CompressionFormat.int_quantized.value,
CompressionFormat.float_quantized.value,
] and loaded_weight.dtype in [torch.float16, torch.bfloat16, torch.float32]
Comment on lines +877 to +886
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to avoid CT specific logic here. Why don't we apply to all weight scales? Is it because of concern over packed weights? I think they wouldn't work anyway

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added CT-specific logic here for safety reasons; I don't know how other frameworks ship their weight scales.


if is_weight or is_nvfp4_weight_scale or is_ct_int8_or_fp8_weight_scale:
if "wk" in modules or "k_proj" in modules:
loaded_weight = permute(
loaded_weight,
self.config.num_key_value_heads,
is_nvfp4_weight_scale,
is_ct_int8_or_fp8_weight_scale,
)
elif "wq" in modules or "q_proj" in modules:
loaded_weight = permute(
loaded_weight,
self.config.num_attention_heads,
is_nvfp4_weight_scale,
is_ct_int8_or_fp8_weight_scale,
)

return name, loaded_weight