Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
CompressedTensorsMxInt4MoE,
CompressedTensorsW4A4Fp4,
CompressedTensorsW4A4Nvfp4MoE,
CompressedTensorsW4A16Fp4,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Fp8MoE,
CompressedTensorsW8A8Int8,
Expand Down Expand Up @@ -491,6 +492,17 @@ def _is_fp4a4_nvfp4(
and is_symmetric
)

def _is_nvfp4_w4a16(
    self, weight_quant: BaseModel, input_quant: BaseModel
) -> bool:
    """Return True when the config describes weight-only NVFP4 (w4a16).

    A match requires all of the following:
      * ``weight_quant`` is set and ``input_quant`` is ``None``;
      * weights are 4-bit with type ``QuantizationType.FLOAT``;
      * weight quantization is symmetric and not dynamic;
      * the strategy is ``QuantizationStrategy.TENSOR_GROUP``;
      * ``weight_quant.group_size`` is 16.

    Args:
        weight_quant: Weight quantization arguments, or ``None``.
        input_quant: Activation quantization arguments, or ``None``.

    Returns:
        ``True`` only if every condition above holds, otherwise ``False``.
    """
    if weight_quant is None or input_quant is not None:
        return False

    is_fp4 = (
        weight_quant.num_bits == 4 and weight_quant.type == QuantizationType.FLOAT
    )
    is_symmetric = weight_quant.symmetric
    is_static = not weight_quant.dynamic
    is_tensor_group = weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP
    # CompressedTensorsW4A16Fp4 hard-codes group_size = 16 when it sizes
    # weight_scale, so a config with any other group size must not be
    # routed to that scheme.
    is_group_size_16 = weight_quant.group_size == 16

    return bool(
        is_fp4
        and is_symmetric
        and is_static
        and is_tensor_group
        and is_group_size_16
    )

def _is_wNa16_group_channel(
self, weight_quant: BaseModel, input_quant: BaseModel
) -> bool:
Expand Down Expand Up @@ -561,6 +573,17 @@ def _get_scheme_from_parts(
"Other method (CompressedTensorsW4A16Sparse24) is not supported now"
)

if self._is_nvfp4_w4a16(weight_quant, input_quant):
is_supported = self._check_scheme_supported(
CompressedTensorsW4A16Fp4.get_min_capability(), error=False
)
if is_supported:
return CompressedTensorsW4A16Fp4()
else:
raise NotImplementedError(
"Current platform does not support w4a16 nvfp4 quantization."
)

if is_activation_quantization_format(self.quant_format):
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
is_fp4a4_nvfp4_supported = self._check_scheme_supported(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
from .compressed_tensors_w4a4_nvfp4_moe import CompressedTensorsW4A4Nvfp4MoE
from .compressed_tensors_w4a8_int8_moe import NPUCompressedTensorsW4A8Int8DynamicMoE
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_fp8_moe import CompressedTensorsW8A8Fp8MoE
from .compressed_tensors_w8a8_int8 import (
Expand Down Expand Up @@ -38,6 +39,7 @@
"NPUCompressedTensorsW4A16Int4DynamicMoE",
"WNA16_SUPPORTED_BITS",
"CompressedTensorsW4A4Fp4",
"CompressedTensorsW4A16Fp4",
"CompressedTensorsW4A4Nvfp4MoE",
"NPUCompressedTensorsW4A8Int8DynamicMoE",
"CompressedTensorsMxInt4MoE",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
from collections.abc import Callable
from typing import Optional

import torch
from torch.nn.parameter import Parameter

from sglang.srt.layers.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsLinearScheme,
)
from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend
from sglang.srt.layers.quantization.modelopt_quant import (
enable_flashinfer_fp4_gemm,
fp4_gemm,
fp4_quantize,
)
from sglang.srt.layers.quantization.utils import swizzle_blockscale

logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsW4A16Fp4"]


class CompressedTensorsW4A16Fp4(CompressedTensorsLinearScheme):
"""weight-only NVFP4 quantization (w4a16)."""

def __init__(self):
    """Initialise the scheme with its fixed quantization group size."""
    # Number of input-dimension elements that share one weight_scale entry;
    # create_weights divides input_size_per_partition by this value.
    self.group_size = 16

@classmethod
def get_min_capability(cls) -> int:
    """Return the minimum capability value required by this scheme.

    The caller passes this value to ``_check_scheme_supported`` before
    instantiating the scheme.
    """
    # NOTE(review): presumably encodes a GPU compute capability as
    # major * 10 + minor (i.e. 10.0) -- confirm against the platform check.
    min_capability = 100
    return min_capability

def create_weights(
    self,
    layer: torch.nn.Module,
    output_partition_sizes: list[int],
    input_size_per_partition: int,
    params_dtype: torch.dtype,
    weight_loader: Callable,
    **kwargs,
):
    """Allocate and register the weight-only NVFP4 parameters on ``layer``.

    Three parameters are registered, in this order:
      * ``weight_packed``: ``uint8`` of shape
        ``(sum(output_partition_sizes), input_size_per_partition // 2)``
        (presumably two 4-bit values per byte -- confirm with the loader);
      * ``weight_global_scale``: ``float32`` with one entry per item of
        ``output_partition_sizes``;
      * ``weight_scale``: ``float8_e4m3fn`` of shape
        ``(sum(output_partition_sizes),
        input_size_per_partition // self.group_size)``.

    ``layer.logical_widths``, ``layer.input_size_per_partition`` and
    ``layer.output_size_per_partition`` are also set.

    Args:
        layer: Module that receives the parameters and size attributes.
        output_partition_sizes: Output widths of the logical sub-matrices
            held by this partition.
        input_size_per_partition: Input width held by this partition.
        params_dtype: Not used by this method; all dtypes are fixed.
        weight_loader: Loader callable attached to every parameter.
        **kwargs: Accepted and ignored.
    """
    out_features = sum(output_partition_sizes)

    # Record partition geometry on the layer.
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = out_features

    # Packed FP4 weights.
    packed_storage = torch.empty(
        (out_features, input_size_per_partition // 2), dtype=torch.uint8
    )
    packed_param = ModelWeightParameter(
        data=packed_storage,
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight_packed", packed_param)

    # One global scale per logical output partition.
    num_partitions = len(output_partition_sizes)
    global_scale_param = PerTensorScaleParameter(
        data=torch.empty(num_partitions, dtype=torch.float32),
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight_global_scale", global_scale_param)

    # One scale per group of `self.group_size` input elements.
    groups_per_row = input_size_per_partition // self.group_size
    group_scale_storage = torch.empty(
        (out_features, groups_per_row), dtype=torch.float8_e4m3fn
    )
    group_scale_param = GroupQuantScaleParameter(
        data=group_scale_storage,
        input_dim=1,
        output_dim=0,
        weight_loader=weight_loader,
    )
    layer.register_parameter("weight_scale", group_scale_param)

def process_weights_after_loading(self, layer) -> None:
weight_gs = layer.weight_global_scale.max().to(torch.float32)
input_gs = (1.0 / weight_gs).to(torch.float32)
layer.input_global_scale = Parameter(input_gs, requires_grad=False)
layer.weight_global_scale = Parameter(weight_gs, requires_grad=False)
Comment on lines +87 to +90
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Synthesizing a static activation scale (input_gs) from the weight scale (weight_gs) is mathematically incorrect for w4a16 (weight-only) quantization. Since the checkpoint does not provide activation scales, the model should ideally use dynamic quantization for activations (calculating the scale from the input x at runtime) to maintain accuracy. Using a fixed scale derived from weights will likely result in poor model performance.


if get_fp4_gemm_runner_backend().is_flashinfer_trtllm():
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

weight = layer.weight_packed.data
weight_scale = layer.weight_scale.data

epilogue_tile_m = 128
weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
weight_scale = (
shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
.reshape(weight_scale.shape)
.view(torch.float8_e4m3fn)
)

layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.weight_packed = Parameter(weight, requires_grad=False)
Comment on lines +92 to +107
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This section is missing the necessary padding logic for the flashinfer_trtllm backend. The FP4 kernels require specific alignments (e.g., N dimension must be a multiple of 128, and K dimension must be a multiple of 32). Without padding, layers with non-aligned dimensions will cause kernel failures or incorrect results. Please refer to the padding implementation in ModelOptFp4LinearMethod.process_weights_after_loading within modelopt_quant.py and apply similar logic here.

else:
swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
layer.weight_packed = Parameter(
layer.weight_packed.data, requires_grad=False
)

layer.alpha = Parameter(
1.0 / (input_gs * weight_gs),
requires_grad=False,
)

def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
output_dtype = x.dtype
w_n, _ = layer.weight_packed.shape
output_shape = [x.shape[0], w_n]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The output_shape calculation assumes a 2D input tensor. If the input x is 3D (e.g., [batch, seq, hidden]), this will lead to an incorrect shape and a runtime error during the view operation. Using x.shape[:-1] ensures compatibility with both 2D and 3D inputs.

Suggested change
output_shape = [x.shape[0], w_n]
output_shape = list(x.shape[:-1]) + [w_n]


x_fp4, x_blockscale = fp4_quantize(x, layer.input_global_scale)

assert x_fp4.dtype == torch.uint8
assert layer.weight_packed.dtype == torch.uint8
assert layer.weight_scale.dtype == torch.float8_e4m3fn
assert layer.alpha.dtype == torch.float32

w = layer.weight_packed
w_blockscale = layer.weight_scale
if (
enable_flashinfer_fp4_gemm
and not get_fp4_gemm_runner_backend().is_cutlass()
):
w = layer.weight_packed.T
w_blockscale = layer.weight_scale.T

out = fp4_gemm(
x_fp4,
w,
x_blockscale,
w_blockscale,
layer.alpha,
output_dtype,
w_n,
)
Comment on lines +146 to +154
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The fp4_gemm call is missing activation padding and output slicing. If the weights are padded to meet alignment requirements (as noted in the process_weights_after_loading feedback), the activations must be padded in the K-dimension to match, and the resulting output must be sliced to remove the N-dimension padding. See ModelOptFp4LinearMethod.apply for reference.

if bias is not None:
out = out + bias
return out.view(*output_shape)
Loading