diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 2934fc5ac36..58ae51436d4 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -799,9 +799,18 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_weight[expert_id][start : start + shard_size, :], layer.w13_weight_scale[expert_id][shard_id], ) - layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( - ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) - ) + if _is_cuda: + ( + layer.w13_weight[expert_id][start : start + shard_size, :], + _, + ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + else: + ( + layer.w13_weight[expert_id][start : start + shard_size, :], + _, + ) = vllm_ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id] + ) start += shard_size layer.w13_weight_scale = torch.nn.Parameter( diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index bc3813e48ba..fa69a1ffad0 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -15,6 +15,13 @@ is_hip, ) +try: + import vllm + + VLLM_AVAILABLE = True +except ImportError: + VLLM_AVAILABLE = False + use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL") _is_hip = is_hip() @@ -27,13 +34,8 @@ from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8 - if use_vllm_cutlass_w8a8_fp8_kernel: - try: - from vllm import _custom_ops as ops - - VLLM_AVAILABLE = True - except ImportError: - VLLM_AVAILABLE = False + if use_vllm_cutlass_w8a8_fp8_kernel and VLLM_AVAILABLE: + from vllm import _custom_ops as ops else: from sgl_kernel import fp8_scaled_mm @@ -253,68 +255,69 @@ def apply_fp8_linear( # torch.scaled_mm supports per tensor weights + activations only # so fallback to naive if per channel or per token - per_tensor_weights = weight_scale.numel() == 1 - per_tensor_activations = x_scale.numel() == 1 - - if per_tensor_weights and per_tensor_activations: - # Fused GEMM_DQ - output = torch._scaled_mm( - qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias, - ) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] + else: + per_tensor_weights = weight_scale.numel() == 1 + per_tensor_activations = x_scale.numel() == 1 + + if per_tensor_weights and per_tensor_activations: + # Fused GEMM_DQ + output = torch._scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] - return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) - else: - # Fallback for channelwise case, where we use unfused DQ - # due to limitations with scaled_mm - - # Symmetric quantized GEMM by definition computes the following: - # C = (s_x * X) (s_w * W) + bias - # This is equivalent to dequantizing the weights and activations - # before applying a GEMM. - # - # In order to compute quantized operands, a quantized kernel - # will rewrite the above like so: - # C = s_w * s_x * (X * W) + bias - # - # For the scaled_mm fallback case, we break this down, since it - # does not support s_w being a vector. - - # Making sure the dummy tensor is on the same device as the weight - global TORCH_DEVICE_IDENTITY - if TORCH_DEVICE_IDENTITY.device != weight.device: - TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) - - # GEMM - # This computes C = (X * W). - # Output in fp32 to allow subsequent ops to happen in-place - output = torch._scaled_mm( - qinput, - weight, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, - out_dtype=torch.float32, - ) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) - x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0]) - - # DQ - # C = sw * sx * (X * W) + bias - output = output * x_scale * weight_scale.t() - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) + else: + # Fallback for channelwise case, where we use unfused DQ + # due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # Making sure the dummy tensor is on the same device as the weight + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY.device != weight.device: + TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) + + # GEMM + # This computes C = (X * W). + # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm( + qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32, + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0]) + + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * weight_scale.t() + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index aecdd4cee82..2e3e8c89c9a 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -6,7 +6,6 @@ from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.quantization.utils import scalar_types from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.utils import is_cuda @@ -133,11 +132,16 @@ def get_quant_method( class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" - # (num_bits, is_sym) -> quant_type - TYPE_MAP = { - (4, True): scalar_types.uint4b8, - (8, True): scalar_types.uint8b128, - } + if VLLM_AVAILABLE: + from vllm.scalar_type import scalar_types + + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + else: + raise ImportError("vllm is not installed") def __init__( self, diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index abe49e80f7f..df7f1f0982b 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -1,15 +1,19 @@ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/scalar_type.py -import functools -import struct -from dataclasses import dataclass -from enum import Enum from types import MappingProxyType -from typing import List, Mapping, Optional, Tuple, Union +from typing import List, Mapping, Tuple, Union import torch +from sglang.srt.utils import is_cuda + +_is_cuda = is_cuda() + +if _is_cuda: + from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant +else: + from vllm import _custom_ops as vllm_ops + def is_layer_skipped( prefix: str, @@ -102,341 +106,12 @@ def requantize_with_max_scale( for idx, logical_width in enumerate(logical_widths): end = start + logical_width weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx]) - weight[start:end, :], _ = ops.scaled_fp8_quant(weight_dq, max_w_scale) + if _is_cuda: + weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale) + else: + weight[start:end, :], _ = vllm_ops.scaled_fp8_quant( + weight_dq, max_w_scale + ) start = end return max_w_scale, weight - - -# Mirrors enum in `core/scalar_type.hpp` -class NanRepr(Enum): - NONE = 0 # nans are not supported - IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s - EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s - - -# This ScalarType class is a parallel implementation of the C++ ScalarType -# class found in csrc/core/scalar_type.hpp. These two classes should be kept -# in sync until the inductor fully supports custom C++ classes. -@dataclass(frozen=True) -class ScalarType: - """ - ScalarType can represent a wide range of floating point and integer - types, in particular it can be used to represent sub-byte data types - (something that torch.dtype currently does not support). It is also - capable of representing types with a bias, i.e.: - `stored_value = value + bias`, - this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias - of 8). The implementation for this class can be found in - csrc/core/scalar_type.hpp, these type signatures should be kept in sync - with that file. - """ - - exponent: int - """ - Number of bits in the exponent if this is a floating point type - (zero if this an integer type) - """ - - mantissa: int - """ - Number of bits in the mantissa if this is a floating point type, - or the number bits representing an integer excluding the sign bit if - this an integer type. - """ - - signed: bool - "If the type is signed (i.e. has a sign bit)" - - bias: int - """ - bias used to encode the values in this scalar type - (value = stored_value - bias, default 0) for example if we store the - type as an unsigned integer with a bias of 128 then the value 0 will be - stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. - """ - - _finite_values_only: bool = False - """ - Private: if infs are supported, used `has_infs()` instead. - """ - - nan_repr: NanRepr = NanRepr.IEEE_754 - """ - How NaNs are represent in this scalar type, returns NanRepr value. - (not applicable for integer types) - """ - - def _floating_point_max_int(self) -> int: - assert ( - self.mantissa <= 52 and self.exponent <= 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - - max_mantissa = (1 << self.mantissa) - 1 - if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: - max_mantissa = max_mantissa - 1 - - max_exponent = (1 << self.exponent) - 2 - if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE: - assert ( - self.exponent < 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - max_exponent = max_exponent + 1 - - # adjust the exponent to match that of a double - # for now we assume the exponent bias is the standard 2^(e-1) -1, (where - # e is the exponent bits), there is some precedent for non-standard - # biases, example `float8_e4m3b11fnuz` here: - # https://github.com/jax-ml/ml_dtypes but to avoid premature over - # complication we are just assuming the standard exponent bias until - # there is a need to support non-standard biases - exponent_bias = (1 << (self.exponent - 1)) - 1 - exponent_bias_double = (1 << 10) - 1 # double e = 11 - - max_exponent_double = max_exponent - exponent_bias + exponent_bias_double - - # shift the mantissa and exponent into the proper positions for an - # IEEE double and bitwise-or them together. - return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52) - - def _floating_point_max(self) -> float: - double_raw = self._floating_point_max_int() - return struct.unpack("!d", struct.pack("!Q", double_raw))[0] - - def _raw_max(self) -> Union[int, float]: - if self.is_floating_point(): - return self._floating_point_max() - else: - assert ( - self.size_bits < 64 or self.size_bits == 64 and self.is_signed() - ), "Cannot represent max as an int" - return (1 << self.mantissa) - 1 - - def _raw_min(self) -> Union[int, float]: - if self.is_floating_point(): - assert ( - self.is_signed() - ), "We currently assume all floating point types are signed" - sign_bit_double = 1 << 63 - - max_raw = self._floating_point_max_int() - min_raw = max_raw | sign_bit_double - return struct.unpack("!d", struct.pack("!Q", min_raw))[0] - else: - assert ( - not self.is_signed() or self.size_bits <= 64 - ), "Cannot represent min as a int64_t" - - if self.is_signed(): - return -(1 << (self.size_bits - 1)) - else: - return 0 - - @functools.cached_property - def id(self) -> int: - """ - Convert the ScalarType to an int which can be passed to pytorch custom - ops. This layout of the int must be kept in sync with the C++ - ScalarType's from_id method. - """ - val = 0 - offset = 0 - - def or_and_advance(member, bit_width): - nonlocal val - nonlocal offset - bit_mask = (1 << bit_width) - 1 - val = val | (int(member) & bit_mask) << offset - offset = offset + bit_width - - or_and_advance(self.exponent, 8) - or_and_advance(self.mantissa, 8) - or_and_advance(self.signed, 1) - or_and_advance(self.bias, 32) - or_and_advance(self._finite_values_only, 1) - or_and_advance(self.nan_repr.value, 8) - - assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64" - - return val - - @property - def size_bits(self) -> int: - return self.exponent + self.mantissa + int(self.signed) - - def min(self) -> Union[int, float]: - """ - Min representable value for this scalar type. - (accounting for bias if there is one) - """ - return self._raw_min() - self.bias - - def max(self) -> Union[int, float]: - """ - Max representable value for this scalar type. - (accounting for bias if there is one) - """ - return self._raw_max() - self.bias - - def is_signed(self) -> bool: - """ - If the type is signed (i.e. has a sign bit), same as `signed` - added for consistency with: - https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html - """ - return self.signed - - def is_floating_point(self) -> bool: - "If the type is a floating point type" - return self.exponent != 0 - - def is_integer(self) -> bool: - "If the type is an integer type" - return self.exponent == 0 - - def has_bias(self) -> bool: - "If the type has a non-zero bias" - return self.bias != 0 - - def has_infs(self) -> bool: - "If the type is floating point and supports infinity" - return not self._finite_values_only - - def has_nans(self) -> bool: - return self.nan_repr != NanRepr.NONE.value - - def is_ieee_754(self) -> bool: - """ - If the type is a floating point type that follows IEEE 754 - conventions - """ - return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only - - def __str__(self) -> str: - """ - naming generally follows: https://github.com/jax-ml/ml_dtypes - for floating point types (leading f) the scheme is: - `float_em[flags]` - flags: - - no-flags: means it follows IEEE 754 conventions - - f: means finite values only (no infinities) - - n: means nans are supported (non-standard encoding) - for integer types the scheme is: - `[u]int[b]` - - if bias is not present it means its zero - """ - if self.is_floating_point(): - ret = ( - "float" - + str(self.size_bits) - + "_e" - + str(self.exponent) - + "m" - + str(self.mantissa) - ) - - if not self.is_ieee_754(): - if self._finite_values_only: - ret = ret + "f" - if self.nan_repr != NanRepr.NONE: - ret = ret + "n" - - return ret - else: - ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) - if self.has_bias(): - ret = ret + "b" + str(self.bias) - return ret - - def __repr__(self) -> str: - return "ScalarType." + self.__str__() - - # __len__ needs to be defined (and has to throw TypeError) for pytorch's - # opcheck to work. - def __len__(self) -> int: - raise TypeError - - # - # Convenience Constructors - # - - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType": - "Create a signed integer scalar type (size_bits includes sign-bit)." - ret = cls(0, size_bits - 1, True, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType": - """Create a unsigned integer scalar type.""" - ret = cls(0, size_bits, False, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType": - """ - Create a standard floating point type - (i.e. follows IEEE 754 conventions). - """ - assert mantissa > 0 and exponent > 0 - ret = cls(exponent, mantissa, True, 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_( - cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr - ) -> "ScalarType": - """ - Create a non-standard floating point type - (i.e. does not follow IEEE 754 conventions). - """ - assert mantissa > 0 and exponent > 0 - assert nan_repr != NanRepr.IEEE_754, ( - "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions" - ) - ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) - ret.id # noqa B018: make sure the id is cached - return ret - - -# naming generally follows: https://github.com/jax-ml/ml_dtypes -# for floating point types (leading f) the scheme is: -# `float_em[flags]` -# flags: -# - no-flags: means it follows IEEE 754 conventions -# - f: means finite values only (no infinities) -# - n: means nans are supported (non-standard encoding) -# for integer types the scheme is: -# `[u]int[b]` -# - if bias is not present it means its zero - - -class scalar_types: - int4 = ScalarType.int_(4, None) - uint4 = ScalarType.uint(4, None) - int8 = ScalarType.int_(8, None) - uint8 = ScalarType.uint(8, None) - float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) - float8_e5m2 = ScalarType.float_IEEE754(5, 2) - float16_e8m7 = ScalarType.float_IEEE754(8, 7) - float16_e5m10 = ScalarType.float_IEEE754(5, 10) - - # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main - float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) - - # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf - float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE) - - # "gptq" types - uint2b2 = ScalarType.uint(2, 2) - uint3b4 = ScalarType.uint(3, 4) - uint4b8 = ScalarType.uint(4, 8) - uint8b128 = ScalarType.uint(8, 128) - - # colloquial names - bfloat16 = float16_e8m7 - float16 = float16_e5m10 diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh index 98a569624cd..6933eeddfcf 100755 --- a/scripts/ci_install_dependency.sh +++ b/scripts/ci_install_dependency.sh @@ -27,3 +27,5 @@ pip install cuda-python nvidia-cuda-nvrtc-cu12 # For DeepSeek-VL2 pip install timm + +pip install sgl-kernel==0.0.5.post3 --force-reinstall