diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 31f47b88bc2f..14f4e5f86904 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations @@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import ( WNA16_SUPPORTED_BITS, CompressedTensorsScheme, + CompressedTensorsW4A4Fp4, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, @@ -376,6 +377,35 @@ def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool # All conditions satisfied. return True + def _is_fp4a4_nvfp4( + self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs + ): + if weight_quant is None or input_quant is None: + return False + + is_tensor_group_quant = ( + weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value + and input_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value + ) + is_symmetric = weight_quant.symmetric and input_quant.symmetric + + is_group_size_16 = ( + weight_quant.group_size == 16 and input_quant.group_size == 16 + ) + is_float_type = ( + weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT + ) + is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4 + + return ( + is_tensor_group_quant + and is_float_type + and is_4_bits + and is_group_size_16 + and is_symmetric + ) + def _is_wNa16_group_channel( self, weight_quant: BaseModel, input_quant: BaseModel ) -> bool: @@ -411,6 +441,17 @@ def _get_scheme_from_parts( ) if is_activation_quantization_format(self.quant_format): + if self._is_fp4a4_nvfp4(weight_quant, input_quant): + is_fp4a4_nvfp4_supported = self._check_scheme_supported( + CompressedTensorsW4A4Fp4.get_min_capability(), error=False + ) + if is_fp4a4_nvfp4_supported: + return CompressedTensorsW4A4Fp4() + else: + raise NotImplementedError( + "Current platform does not support w4a4 nvfp4 quantization." + ) + if self._is_fp8_w8a8(weight_quant, input_quant): is_fp8_w8a8_supported = self._check_scheme_supported( CompressedTensorsW8A8Fp8.get_min_capability(), error=False diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 4620d4506516..b5e3964c85f4 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations @@ -13,19 +13,24 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase from sglang.srt.layers.quantization.compressed_tensors.schemes import ( WNA16_SUPPORTED_BITS, ) from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant -from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.layers.quantization.fp8_utils import ( + is_blackwell_supported, + normalize_e4m3fn_to_e4m3fnuz, +) from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales from sglang.srt.layers.quantization.utils import ( all_close_1d, per_tensor_dequantize, replace_parameter, + swizzle_blockscale, ) from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, set_weight_attrs @@ -60,6 +65,7 @@ class GPTQMarlinState(Enum): __all__ = [ "CompressedTensorsMoEMethod", + "CompressedTensorsW4A4Nvfp4MoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", "CompressedTensorsWNA16MoEMethod", ] @@ -86,7 +92,11 @@ def get_moe_method( if quant_config._is_wNa16_group_channel(weight_quant, input_quant): logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MoEMethod(quant_config) + elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): + logger.info_once("Using CompressedTensorsW4A4Nvfp4MoEMethod") + return CompressedTensorsW4A4Nvfp4MoEMethod(quant_config) elif quant_config._is_fp8_w8a8(weight_quant, input_quant): + logger.info_once("Using CompressedTensorsW8A8Fp8MoEMethod") return CompressedTensorsW8A8Fp8MoEMethod(quant_config) else: raise RuntimeError( @@ -94,6 +104,239 @@ def get_moe_method( ) +class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): + + def __init__(self, quant_config: CompressedTensorsConfig): + if not is_blackwell_supported(): + raise ValueError( + "Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above." + ) + self.quant_config = quant_config + self.group_size = 16 + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + layer.num_experts = num_experts + layer.params_dtype = params_dtype + + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + requires_grad=False, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Weight Scales + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.group_size, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # Weight Global Scales + w13_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) + + w2_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) + + # Input Global Scales + w13_input_scale = torch.nn.Parameter( + torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_input_global_scale", w13_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_input_global_scale", w2_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # From packed to weight + layer.w13_weight = torch.nn.Parameter( + layer.w13_weight_packed.data, requires_grad=False + ) + delattr(layer, "w13_weight_packed") + + layer.w2_weight = torch.nn.Parameter( + layer.w2_weight_packed.data, requires_grad=False + ) + delattr(layer, "w2_weight_packed") + + if not torch.allclose( + layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1] + ): + logger.warning_once( + "w1_weight_global_scale must match w3_weight_global_scale. " + "Accuracy may be affected." + ) + + # Take inverse of global scale saved to disk + layer.w13_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False + ) + + layer.w2_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w2_weight_global_scale.data, requires_grad=False + ) + + # w13 + w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to( + torch.float32 + ) + layer.g1_alphas = torch.nn.Parameter( + ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), + requires_grad=False, + ) + + layer.w13_input_scale_quant = torch.nn.Parameter( + (w13_input_global_scale), requires_grad=False + ) + + # w2 + w2_input_global_scale = layer.w2_input_global_scale + + layer.g2_alphas = torch.nn.Parameter( + ((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32), + requires_grad=False, + ) + + layer.w2_input_scale_quant = torch.nn.Parameter( + (w2_input_global_scale), requires_grad=False + ) + + # swizzle weight scales + layer.w13_weight_scale = torch.nn.Parameter( + swizzle_blockscale(layer.w13_weight_scale), requires_grad=False + ) + + layer.w2_weight_scale = torch.nn.Parameter( + swizzle_blockscale(layer.w2_weight_scale), requires_grad=False + ) + + layer.cutlass_moe_params = CutlassMoEParams( + CutlassMoEType.BlockscaledFP4, + layer.w13_weight.device, + num_experts=layer.num_experts, + intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, + hidden_size=layer.w13_weight.shape[2] * 2, + ) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids + + output = cutlass_moe_fp4( + a=x, + a1_gscale=layer.w13_input_scale_quant, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_weight_scale, + w1_alphas=layer.g1_alphas, + a2_gscale=layer.w2_input_scale_quant, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_weight_scale, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + params=layer.cutlass_moe_params, + apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input, + ).to(x.dtype) + + return StandardCombineInput(hidden_states=output) + + class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def __init__(self, quant_config: CompressedTensorsConfig): diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py index 6d9871917bbb..d5f4146e606f 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 @@ -13,4 +14,5 @@ "CompressedTensorsW8A8Int8", "CompressedTensorsWNA16", "WNA16_SUPPORTED_BITS", + "CompressedTensorsW4A4Fp4", ] diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py new file mode 100644 index 000000000000..a155b160422e --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -0,0 +1,168 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 +import logging +from collections.abc import Callable +from typing import Optional + +import torch +from torch.nn.parameter import Parameter + +from sglang.srt.layers.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter, +) +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme, +) +from sglang.srt.layers.quantization.modelopt_quant import ( + FLASHINFER_FP4_GEMM_BACKEND, + _sglang_fp4_gemm, + enable_flashinfer_fp4_gemm, + fp4_quantize, +) +from sglang.srt.layers.quantization.utils import swizzle_blockscale + +logger = logging.getLogger(__name__) + +__all__ = ["CompressedTensorsW4A4Fp4"] + + +class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): + def __init__(self): + self.group_size = 16 + + @classmethod + def get_min_capability(cls) -> int: + return 100 + + def create_weights( + self, + layer: torch.nn.Module, + output_partition_sizes: list[int], + input_size_per_partition: int, + params_dtype: torch.dtype, + weight_loader: Callable, + **kwargs, + ): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Weight + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_packed", weight) + + # Global Weight Scale + weight_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("weight_global_scale", weight_global_scale) + + # Per Group Weight Scale + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + + layer.register_parameter("weight_scale", weight_scale) + + input_global_scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + layer.register_parameter("input_global_scale", input_global_scale) + + def process_weights_after_loading(self, layer) -> None: + global_input_scale = layer.input_global_scale.max().to(torch.float32) + layer.input_global_scale = Parameter(global_input_scale, requires_grad=False) + + layer.weight_global_scale = Parameter( + layer.weight_global_scale.max().to(torch.float32), requires_grad=False + ) + + if FLASHINFER_FP4_GEMM_BACKEND == "trtllm": + # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. + # FlashInfer provides nvfp4_quantize to quantize + shuffle the + # layout but we use our own quantization so we have to call + # shuffles ourselves. + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a + + weight = layer.weight_packed.data + weight_scale = layer.weight_scale.data + + epilogue_tile_m = 128 + weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) + weight_scale = ( + shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) + .reshape(weight_scale.shape) + .view(torch.float8_e4m3fn) + ) + + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.weight_packed = Parameter(weight, requires_grad=False) + else: + swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) + layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) + layer.weight_packed = Parameter( + layer.weight_packed.data, requires_grad=False + ) + + layer.alpha = Parameter( + 1 / (layer.input_global_scale * layer.weight_global_scale), + requires_grad=False, + ) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + output_dtype = x.dtype + w_n, _ = layer.weight_packed.shape + output_shape = [x.shape[0], w_n] + + # quantize BF16 or FP16 to (FP4 and interleaved block scale) + x_fp4, x_blockscale = fp4_quantize(x, layer.input_global_scale) + + assert x_fp4.dtype == torch.uint8 + assert layer.weight_packed.dtype == torch.uint8 + assert layer.weight_scale.dtype == torch.float8_e4m3fn + assert layer.alpha.dtype == torch.float32 + + w = layer.weight_packed + w_blockscale = layer.weight_scale + if enable_flashinfer_fp4_gemm: + w = layer.weight_packed.T + w_blockscale = layer.weight_scale.T + + out = _sglang_fp4_gemm( + x_fp4, + w, + x_blockscale, + w_blockscale, + layer.alpha, + output_dtype, + w_n, + ) + if bias is not None: + out = out + bias + return out.view(*output_shape) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/utils.py b/python/sglang/srt/layers/quantization/compressed_tensors/utils.py index ab01c94d51d0..33d0d05b237b 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/utils.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/utils.py @@ -14,6 +14,7 @@ def is_activation_quantization_format(format: str) -> bool: CompressionFormat.naive_quantized.value, CompressionFormat.int_quantized.value, CompressionFormat.float_quantized.value, + CompressionFormat.nvfp4_pack_quantized.value, ] return format in _ACTIVATION_QUANTIZATION_FORMATS diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 28f7ad7d707b..84406d947796 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -42,6 +42,7 @@ is_layer_skipped, per_tensor_dequantize, requantize_with_max_scale, + swizzle_blockscale, ) from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.utils.common import ( @@ -1340,7 +1341,7 @@ def create_weights( # Only use `swizzle_blockscale` for shapes, not for real content layer.w13_blockscale_swizzled = Parameter( - self.swizzle_blockscale(layer.w13_weight_scale), requires_grad=False + swizzle_blockscale(layer.w13_weight_scale), requires_grad=False ) w2_weight_scale = ModelWeightParameter( @@ -1357,7 +1358,7 @@ def create_weights( layer.register_parameter("w2_weight_scale", w2_weight_scale) layer.w2_blockscale_swizzled = Parameter( - self.swizzle_blockscale(layer.w2_weight_scale), requires_grad=False + swizzle_blockscale(layer.w2_weight_scale), requires_grad=False ) from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported @@ -1402,31 +1403,6 @@ def create_weights( w2_input_scale._sglang_require_global_experts = True layer.register_parameter("w2_input_scale", w2_input_scale) - def swizzle_blockscale(self, scale: torch.Tensor): - assert scale.dtype == torch.float8_e4m3fn - # Pad and blockwise interleave weight_scale - scale_ndim = scale.ndim - if scale.ndim == 2: - scale = scale.unsqueeze(0) - assert scale.ndim == 3 - B, M, K = scale.shape - round_up_multiple = lambda x, m: (x + m - 1) // m * m - M_padded = round_up_multiple(M, 128) - K_padded = round_up_multiple(K, 4) - padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) - padded_scale[:B, :M, :K] = scale - batches, rows, cols = padded_scale.shape - assert rows % 128 == 0 - assert cols % 4 == 0 - padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, cols // 4, 4) - swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) - swizzled_scale = swizzled_scale.contiguous().cuda() - return ( - swizzled_scale.reshape(M_padded, K_padded) - if scale_ndim == 2 - else swizzled_scale.reshape(B, M_padded, K_padded) - ) - def prepare_static_weights_for_kernel( self, # args_dequant, @@ -1701,7 +1677,7 @@ def _slice_scale(w): # CUTLASS processing - handle w13 and w2 separately # Process w13 weights - w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale) + w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) del layer.w13_weight_scale layer.w13_blockscale_swizzled.data.copy_(w13_blockscale_swizzled) @@ -1734,13 +1710,13 @@ def _slice_scale(w): requires_grad=False, ) layer.w2_blockscale_swizzled = Parameter( - self.swizzle_blockscale(layer.w2_weight_scale), requires_grad=False + swizzle_blockscale(layer.w2_weight_scale), requires_grad=False ) layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) # Process w2 weights - w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) + w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) del layer.w2_weight_scale layer.w2_blockscale_swizzled.data.copy_(w2_blockscale_swizzled) diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index fc81f3140660..a2da44d00b69 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -563,3 +563,32 @@ def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): g_idx.to(device=orig_device), sort_indices.to(device=orig_device), ) + + +def swizzle_blockscale(scale: torch.Tensor): + """ + Swizzle the scale tensor into a blockwise interleaved format for NVFP4 quantization. + """ + assert scale.dtype == torch.float8_e4m3fn + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return ( + swizzled_scale.reshape(M_padded, K_padded) + if scale_ndim == 2 + else swizzled_scale.reshape(B, M_padded, K_padded) + ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index aa600834225e..6400a7c87f4c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -484,6 +484,10 @@ def __init__( tp_rank=tp_rank, tp_size=tp_size, ) + if not hasattr(self.gate_up_proj, "weight"): + self.gate_up_proj.weight = getattr(self.gate_up_proj, "weight_packed") + if not hasattr(self.down_proj, "weight"): + self.down_proj.weight = getattr(self.down_proj, "weight_packed") if hidden_act != "silu": raise ValueError( f"Unsupported activation: {hidden_act}. " diff --git a/python/sglang/srt/models/mistral_large_3.py b/python/sglang/srt/models/mistral_large_3.py index fd60ef61f7c0..cbbf893b6e9d 100644 --- a/python/sglang/srt/models/mistral_large_3.py +++ b/python/sglang/srt/models/mistral_large_3.py @@ -1,5 +1,5 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/mistral_large_3.py # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable import regex as re diff --git a/python/sglang/srt/models/mistral_large_3_eagle.py b/python/sglang/srt/models/mistral_large_3_eagle.py index 08f7271fde6c..a5ce7b6aabb6 100644 --- a/python/sglang/srt/models/mistral_large_3_eagle.py +++ b/python/sglang/srt/models/mistral_large_3_eagle.py @@ -1,3 +1,5 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/mistral_large_3_eagle.py +# SPDX-License-Identifier: Apache-2.0 from typing import Optional import torch diff --git a/python/sglang/srt/utils/mistral_utils.py b/python/sglang/srt/utils/mistral_utils.py index ecce3042df02..7be1ba5bc6f0 100644 --- a/python/sglang/srt/utils/mistral_utils.py +++ b/python/sglang/srt/utils/mistral_utils.py @@ -1,5 +1,5 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/mistral.py # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json from pathlib import Path from typing import Any @@ -11,7 +11,7 @@ def adapt_config_dict( config_dict: dict[str, Any], model: str, **kwargs -) -> PretrainedConfig: +) -> tuple[dict, PretrainedConfig]: config_dict.update(kwargs) config_dict = _remap_general_mistral_args(config_dict)