diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
index c60e09be495..67790db455e 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -33,6 +33,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
     CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A16Fp8,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import (
     find_matched_target,
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
index fafed717c62..c9457531675 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -2,8 +2,10 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
+from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8

 __all__ = [
     "CompressedTensorsScheme",
     "CompressedTensorsW8A8Fp8",
+    "CompressedTensorsW8A16Fp8",
 ]
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
new file mode 100644
index 00000000000..fa7d77f2832
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
@@ -0,0 +1,153 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable, List, Optional
+
+import torch
+from compressed_tensors.quantization import QuantizationStrategy
+
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from sglang.srt.layers.quantization.utils import convert_to_channelwise
+
+try:
+    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+        apply_fp8_marlin_linear,
+        prepare_fp8_layer_for_marlin,
+    )
+
+    MARLIN_FP8_AVAILABLE = True
+except ImportError:
+    MARLIN_FP8_AVAILABLE = False
+
+    def apply_fp8_marlin_linear(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+    def prepare_fp8_layer_for_marlin(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+
+__all__ = ["CompressedTensorsW8A16Fp8"]
+
+SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
+
+
+class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
+        self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
+
+        if not MARLIN_FP8_AVAILABLE:
+            raise ImportError(
+                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
+            )
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # ampere and up
+        return 80
+
+    # W8A8-Fp8 kernels support only per-tensor and per-channel cases.
+    # So if we have a fused module (QKV, MLP) with per-tensor scales,
+    # we expand each scale to its shard's channels.
+    def process_weights_after_loading(self, layer) -> None:
+        if self.strategy == QuantizationStrategy.TENSOR:
+            ws_channelwise = convert_to_channelwise(
+                layer.weight_scale, layer.logical_widths
+            )
+            layer.weight_scale = torch.nn.Parameter(
+                ws_channelwise, requires_grad=False
+            )
+        else:
+            # required by torch.compile to be torch.nn.Parameter
+            layer.weight_scale = torch.nn.Parameter(
+                layer.weight_scale.data, requires_grad=False
+            )
+
+        # Weights must be transposed for marlin
+        layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False)
+
+        if self.is_static_input_scheme:
+            # required by torch.compile to be torch.nn.Parameter
+            layer.input_scale = torch.nn.Parameter(
+                layer.input_scale.data, requires_grad=False
+            )
+
+        prepare_fp8_layer_for_marlin(layer, strategy="channel")
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size: int,
+        output_partition_sizes: List[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs,
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition,
+                dtype=torch.float8_e4m3fn,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+        elif self.strategy == QuantizationStrategy.TENSOR:
+            weight_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+        else:
+            raise ValueError(
+                f"Unsupported weight strategy={self.strategy}, "
+                f"supported strategies are {SUPPORTED_STRATEGIES}"
+            )
+
+        weight_scale[:] = torch.finfo(torch.float32).min
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE (to deal with converted checkpoints)
+        if self.is_static_input_scheme:
+            input_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("input_scale", input_scale)
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return apply_fp8_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            workspace=layer.workspace,
+            size_n=layer.output_size_per_partition,
+            size_k=layer.input_size_per_partition,
+            bias=bias,
+        )
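
Note on the TENSOR branch of process_weights_after_loading: prepare_fp8_layer_for_marlin is always invoked with strategy="channel", so a fused module (QKV, MLP) quantized with one scale per logical shard first has each scalar broadcast across that shard's output channels via convert_to_channelwise. A minimal sketch of that expansion, assuming the helper matches its vllm namesake (expand_to_channelwise and its body below are illustrative, not the sglang API):

import torch

# Illustrative stand-in for convert_to_channelwise; a sketch, not the sglang code.
def expand_to_channelwise(weight_scale: torch.Tensor, logical_widths: list) -> torch.Tensor:
    # weight_scale holds one float32 scalar per logical shard, e.g. [s_q, s_k, s_v].
    repeats = torch.tensor(logical_widths)
    # Repeat each shard's scalar across that shard's output channels, yielding a
    # (sum(logical_widths), 1) tensor shaped like ChannelQuantScaleParameter.
    return torch.repeat_interleave(weight_scale, repeats).unsqueeze(-1)

scales = torch.tensor([0.5, 0.25, 0.125])        # per-tensor scales for Q, K, V shards
print(expand_to_channelwise(scales, [4, 2, 2]))  # 8 rows: 0.5 x4, 0.25 x2, 0.125 x2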
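For reviewers unfamiliar with the Marlin path, the W8A16 scheme keeps weights in float8_e4m3fn with per-output-channel float32 scales while activations stay in the layer's original 16-bit dtype, so the output of apply_fp8_marlin_linear should be equivalent to dequantize-then-GEMM. A plain-PyTorch sketch of those numerics under that assumption (w8a16_fp8_linear_ref is a hypothetical reference, not the kernel; the real call additionally takes Marlin-repacked weights and a workspace buffer):

from typing import Optional

import torch

# Reference numerics only; does not model the Marlin repacking or workspace.
def w8a16_fp8_linear_ref(
    x: torch.Tensor,             # (batch, size_k) fp16/bf16 activations
    weight_t: torch.Tensor,      # (size_k, size_n) fp8 weight, i.e. layer.weight after .t()
    weight_scale: torch.Tensor,  # (size_n, 1) float32 per-channel scales
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Dequantize each output channel with its scale, then run an ordinary GEMM.
    w = weight_t.to(torch.float32) * weight_scale.t()  # broadcast over size_k
    out = x.to(torch.float32) @ w
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(x.dtype)

x = torch.randn(2, 8, dtype=torch.bfloat16)
w_t = (torch.randn(8, 4) * 0.1).to(torch.float8_e4m3fn)  # size_k=8, size_n=4
scale = torch.rand(4, 1) + 0.5
print(w8a16_fp8_linear_ref(x, w_t, scale).shape)  # torch.Size([2, 4])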