From e1a57c0183f51914e065d00e87d693099221a3c0 Mon Sep 17 00:00:00 2001
From: mgoin
Date: Mon, 12 May 2025 15:14:16 +0000
Subject: [PATCH 1/2] Use NVFP4 Marlin for CompressedTensorsW4A16Fp4

Signed-off-by: mgoin
---
 .../schemes/compressed_tensors_w4a16_nvfp4.py | 68 +++++++------------
 1 file changed, 26 insertions(+), 42 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
index f192a8164515..db8a16a53deb 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
@@ -2,13 +2,12 @@
 from typing import Callable, List, Optional
 
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import (  # noqa: E501
-    dequantize_to_dtype)
+from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
+    apply_fp4_marlin_linear, prepare_fp4_layer_for_marlin)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)
@@ -31,6 +30,10 @@ def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
 
         # Weight
         weight = ModelWeightParameter(data=torch.empty(
@@ -60,48 +63,29 @@ def create_weights(self, layer: torch.nn.Module,
 
         layer.register_parameter("weight_scale", weight_scale)
 
-    def swizzle_blockscale(self, scale: torch.tensor):
-        assert (scale.dtype == torch.float8_e4m3fn)
-        # Pad and blockwise interleave weight_scale
-        scale_ndim = scale.ndim
-        if scale.ndim == 2:
-            scale = scale.unsqueeze(0)
-        assert scale.ndim == 3
-        B, M, K = scale.shape
-        round_up_multiple = lambda x, m: (x + m - 1) // m * m
-        M_padded = round_up_multiple(M, 128)
-        K_padded = round_up_multiple(K, 4)
-        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
-        padded_scale[:B, :M, :K] = scale
-        batches, rows, cols = padded_scale.shape
-        assert rows % 128 == 0
-        assert cols % 4 == 0
-        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
-                                            cols // 4, 4)
-        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
-        swizzled_scale = swizzled_scale.contiguous().cuda()
-        return (swizzled_scale.reshape(M, K)
-                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
-
     def process_weights_after_loading(self, layer) -> None:
-        layer.weight_global_scale = Parameter(
-            layer.weight_global_scale.max().to(torch.float32),
-            requires_grad=False)
-        # Note: a post weight loading step but not required for the emulation
-        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
-        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
-                                                requires_grad=False)
+        # Process parameters for marlin repacking
+
+        # Rename weight_packed to weight that marlin expects
+        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
+        del layer.weight_packed
+        # Rename weight_global_scale to weight_scale_2 that marlin expects
+        layer.weight_scale_2 = Parameter(layer.weight_global_scale.max().to(
+            torch.float32),
+                                         requires_grad=False)
+        del layer.weight_global_scale
+
+        prepare_fp4_layer_for_marlin(layer)
 
     def apply_weights(self,
                       layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-
-        w_fp4 = layer.weight_packed.data
-        w_global_scale = layer.weight_global_scale
-        w_blockscale = layer.weight_scale_swizzled.data
-        w_dq = dequantize_to_dtype(w_fp4, w_blockscale, w_global_scale,
-                                   x.dtype, x.device, self.group_size)
-        out = F.linear(x, w_dq)
-        del w_dq, w_fp4, w_global_scale, w_blockscale
-        return out
+        return apply_fp4_marlin_linear(input=x,
+                                       weight=layer.weight,
+                                       weight_scale=layer.weight_scale,
+                                       weight_scale_2=layer.weight_scale_2,
+                                       workspace=layer.workspace,
+                                       size_n=layer.output_size_per_partition,
+                                       size_k=layer.input_size_per_partition,
+                                       bias=bias)

From 2154bf0cef4be02ea83dcb86ca4f730a36071593 Mon Sep 17 00:00:00 2001
From: Dipika
Date: Mon, 12 May 2025 13:11:35 -0400
Subject: [PATCH 2/2] update ct global scale processing

Signed-off-by: Dipika
---
 .../schemes/compressed_tensors_w4a16_nvfp4.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
index db8a16a53deb..caa4fe89c621 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py
@@ -70,9 +70,10 @@ def process_weights_after_loading(self, layer) -> None:
         layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
         del layer.weight_packed
         # Rename weight_global_scale to weight_scale_2 that marlin expects
-        layer.weight_scale_2 = Parameter(layer.weight_global_scale.max().to(
-            torch.float32),
-                                         requires_grad=False)
+        # Note: ct stores the inverse of what is expected by the marlin kernel
+        layer.weight_scale_2 = Parameter(
+            1 / layer.weight_global_scale.max().to(torch.float32),
+            requires_grad=False)
        del layer.weight_global_scale
 
         prepare_fp4_layer_for_marlin(layer)
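
Note (illustration only, not part of either commit): the second patch turns on the observation in its own comment that compressed-tensors stores the global scale as the inverse of what the Marlin NVFP4 kernel expects for weight_scale_2. A minimal sketch of that conversion is below; the scale value is illustrative, whereas in practice weight_global_scale is a PerTensorScaleParameter loaded from the checkpoint with one entry per fused output partition.

    import torch

    # Illustrative global scale; real values come from the
    # compressed-tensors checkpoint (one per fused output partition).
    weight_global_scale = torch.tensor([448.0, 448.0])

    # Marlin consumes the reciprocal, reduced to a single fp32 scalar
    # across the fused partitions (mirroring layer.weight_scale_2 above).
    weight_scale_2 = 1 / weight_global_scale.max().to(torch.float32)
    print(weight_scale_2)  # tensor(0.0022)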