diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 742262216e05..2280f4535ee7 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -27,8 +27,8 @@ from sglang.srt.layers.moe.token_dispatcher.moriep import MoriEPNormalCombineInput from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import ( - NPUCompressedTensorsW4A16Int4DynamicMoEMethod, +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + NPUCompressedTensorsW4A16Int4DynamicMoE, ) from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz @@ -377,7 +377,7 @@ def forward_npu( else: input_quant = get_bool_env_var("DEEP_NORMAL_MODE_USE_INT8_QUANT") if not input_quant and not isinstance( - self.quant_method, NPUCompressedTensorsW4A16Int4DynamicMoEMethod + self.quant_method, NPUCompressedTensorsW4A16Int4DynamicMoE ): hidden_states, hidden_states_scale = torch_npu.npu_dynamic_quant( hidden_states diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 1609b6dbb346..de8a07ab3000 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -54,8 +54,8 @@ FusedMoEMethodBase, QuantizationConfig, ) -from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import ( - CompressedTensorsMxInt4MoEMethod, +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsMxInt4MoE, ) from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod @@ -682,6 +682,8 @@ def _weight_loader_impl( # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality method = self.quant_method + if hasattr(self, "scheme"): + method = self.scheme if method.__class__.__name__ == "KTEPWrapperMethod": method = method.gpu_method @@ -690,9 +692,9 @@ def _weight_loader_impl( if ( method.__class__.__name__ in [ - "CompressedTensorsWNA16MarlinMoEMethod", - "CompressedTensorsWNA16MoEMethod", - "CompressedTensorsWNA16TritonMoEMethod", + "CompressedTensorsWNA16MarlinMoE", + "CompressedTensorsWNA16MoE", + "CompressedTensorsWNA16TritonMoE", ] ) else loaded_weight @@ -703,10 +705,10 @@ def _weight_loader_impl( # Flashinfer assumes w31 format for w13_weight. Same for the scales. 
if self.use_flashinfer_trtllm_moe and ( - isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) - or isinstance(self.quant_method, Fp8MoEMethod) - or isinstance(self.quant_method, UnquantizedFusedMoEMethod) - or isinstance(self.quant_method, CompressedTensorsMxInt4MoEMethod) + isinstance(method, ModelOptNvFp4FusedMoEMethod) + or isinstance(method, Fp8MoEMethod) + or isinstance(method, UnquantizedFusedMoEMethod) + or isinstance(method, CompressedTensorsMxInt4MoE) ): shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id] @@ -739,7 +741,7 @@ def _weight_loader_impl( if ( ( - "compressed" in self.quant_method.__class__.__name__.lower() + "compressed" in method.__class__.__name__.lower() or "w4afp8" in self.quant_config.get_name() ) and (param.data[expert_id] != 1).any() @@ -767,9 +769,9 @@ def _weight_loader_impl( ) return - if "ModelOpt" in self.quant_method.__class__.__name__: + if "ModelOpt" in method.__class__.__name__: # Determine per-tensor weight scale patterns based on variant - is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) + is_fp4_variant = isinstance(method, ModelOptNvFp4FusedMoEMethod) # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor per_tensor_conditions = ( @@ -902,13 +904,16 @@ def weight_loader_fused( # compressed-tensors checkpoints with packed weights are stored flipped # TODO: check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality + method = self.quant_method + if hasattr(self, "scheme"): + method = self.scheme loaded_weight = ( loaded_weight.t().contiguous() if ( - self.quant_method.__class__.__name__ + method.__class__.__name__ in [ - "CompressedTensorsWNA16MoEMethod", - "CompressedTensorsWNA16TritonMoEMethod", + "CompressedTensorsWNA16MoE", + "CompressedTensorsWNA16TritonMoE", ] ) else loaded_weight diff --git a/python/sglang/srt/layers/moe/kt_ep_wrapper.py b/python/sglang/srt/layers/moe/kt_ep_wrapper.py index 3d8901e8f701..128853931805 100644 --- a/python/sglang/srt/layers/moe/kt_ep_wrapper.py +++ b/python/sglang/srt/layers/moe/kt_ep_wrapper.py @@ -128,7 +128,7 @@ class KTEPWrapperMethod(FusedMoEMethodBase): Example: # Wrap any GPU method with AMX/AVX CPU expert support - gpu_method = CompressedTensorsWNA16MoEMethod(quant_config, prefix) + gpu_method = CompressedTensorsWNA16MoE(quant_config, prefix) kt_config = KTConfig(layer_idx=0, num_gpu_experts=4, ...) method = KTEPWrapperMethod(gpu_method, kt_config) """ diff --git a/python/sglang/srt/layers/quantization/base_scheme.py b/python/sglang/srt/layers/quantization/base_scheme.py new file mode 100644 index 000000000000..ee55caa3ceae --- /dev/null +++ b/python/sglang/srt/layers/quantization/base_scheme.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.srt.layers.moe import MoeRunnerConfig + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput + +__all__ = ["BaseLinearScheme", "BaseMoEScheme"] + + +class BaseLinearScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes. + """ + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. 
Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. + """ + raise NotImplementedError + + @abstractmethod + def apply_weights( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] + ): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + :param bias: bias parameter + + """ + raise NotImplementedError + + +class BaseMoEScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes. + """ + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. + """ + raise NotImplementedError + + @abstractmethod + def apply_weights( + self, + layer: torch.nn.Module, + dispatch_output: "StandardDispatchOutput", + ): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. 
+ :param x: input to the layer + :param bias: bias parameter + + """ + raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py index 69e994acb484..4cbfed6f90e7 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py @@ -29,23 +29,31 @@ ) from pydantic import BaseModel +from sglang.srt.layers.moe import MoeRunnerConfig, get_moe_runner_backend from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, LinearMethodBase, QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 - CompressedTensorsMoEMethod, -) from sglang.srt.layers.quantization.compressed_tensors.schemes import ( WNA16_SUPPORTED_BITS, - CompressedTensorsScheme, + CompressedTensorsLinearScheme, + CompressedTensorsMoEScheme, + CompressedTensorsMxInt4MoE, CompressedTensorsW4A4Fp4, + CompressedTensorsW4A4Nvfp4MoE, CompressedTensorsW8A8Fp8, + CompressedTensorsW8A8Fp8MoE, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, + CompressedTensorsWNA16MoE, + CompressedTensorsWNA16TritonMoE, + NPUCompressedTensorsW4A8Int8DynamicMoE, + NPUCompressedTensorsW4A16Int4DynamicMoE, NPUCompressedTensorsW8A8Int8, + NPUCompressedTensorsW8A8Int8DynamicMoE, ) from sglang.srt.layers.quantization.compressed_tensors.utils import ( find_matched_target, @@ -53,13 +61,21 @@ should_ignore_layer, ) from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod -from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod -from sglang.srt.utils import is_cuda, is_npu +from sglang.srt.layers.quantization.unquant import ( + UnquantizedFusedMoEMethod, + UnquantizedLinearMethod, +) +from sglang.srt.utils import is_cuda, is_hip, is_npu _is_cuda = is_cuda() _is_npu = is_npu() +_is_hip = is_hip() if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) from sglang.srt.models.utils import WeightsMapper logger = logging.getLogger(__name__) @@ -152,7 +168,7 @@ def get_quant_method( # This allows mixed quantization: experts with int4, linear layers with fp8 if self.linear_fp8_config is not None: return Fp8LinearMethod(self.linear_fp8_config) - scheme = self.get_scheme(layer=layer, layer_name=prefix) + scheme = self.get_linear_scheme(layer=layer, layer_name=prefix) if scheme is None: return UnquantizedLinearMethod() layer.scheme = scheme @@ -160,7 +176,16 @@ def get_quant_method( from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, FusedMoE): - return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix) + layer.scheme = self.get_moe_scheme(layer=layer, layer_name=prefix) + if layer.scheme is None: # ignored layer + use_triton_kernels = get_moe_runner_backend().is_triton_kernels() + use_flashinfer_trtllm_moe = ( + get_moe_runner_backend().is_flashinfer_trtllm() + ) + return UnquantizedFusedMoEMethod( + use_triton_kernels, use_flashinfer_trtllm_moe + ) + return CompressedTensorsFusedMoEMethod(self) return None def _add_fused_moe_to_target_scheme_map(self): @@ -514,7 +539,7 @@ def _is_dynamic_token_w4( def _get_scheme_from_parts( self, weight_quant: BaseModel, input_quant: BaseModel - ) -> CompressedTensorsScheme: + ) -> CompressedTensorsLinearScheme: # Detect If Mixed 
Precision if self._is_wNa16_group_channel(weight_quant, input_quant): @@ -602,9 +627,99 @@ def _get_scheme_from_parts( raise NotImplementedError("No compressed-tensors compatible scheme was found.") - def get_scheme( + def get_moe_scheme( self, layer: torch.nn.Module, layer_name: Optional[str] = None - ) -> Optional[CompressedTensorsScheme]: + ) -> Optional[CompressedTensorsMoEScheme]: + """ + compressed-tensors supports non uniform in the following way: + + targets of config_groups: There can be N config_groups which each + have a quantization scheme. Each config_group has a list of targets + which can be a full layer_name, a regex for a layer_name, or + an nn.Module name. + + Detect whether a layer_name is found in any target and + use the quantization scheme corresponding to the matched target + to select the CompressedTensorsMoEScheme used for infernece. + """ + + # FusedMoE was made by combining multiple Linears so need to + # make sure quantization config for Linear can target it + self._add_fused_moe_to_target_scheme_map() + unfused_names = [ + layer_name + proj_name + for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"] + ] + # TODO: refactor this to use expert_mapping and check all layer numbers + all_scheme_dicts = [self.get_scheme_dict(layer, name) for name in unfused_names] + scheme_dict = all_scheme_dicts[0] if all_scheme_dicts else None + + # multiple schemes found + if not all(d == scheme_dict for d in all_scheme_dicts): + raise ValueError( + "All MoE projections need to have same " + "quantization scheme but found multiple" + ) + + if scheme_dict is None: # ignored layer + return None + + weight_quant = scheme_dict.get("weights") + input_quant = scheme_dict.get("input_activations") + + if self._is_wNa16_group_channel(weight_quant, input_quant): + if not _is_npu: + if ( + self._is_mxint4a16(weight_quant, input_quant) + and get_moe_runner_backend().is_flashinfer_trtllm() + ): + logger.info_once( + "Using CompressedTensorsMxInt4MoE with flashinfer_trtllm backend" + ) + return CompressedTensorsMxInt4MoE(self) + elif _is_hip: + logger.info_once("Using CompressedTensorsWNA16TritonMoE (ROCm)") + return CompressedTensorsWNA16TritonMoE(self) + else: + logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") + return CompressedTensorsWNA16MoE(self) + else: + if ( + self._is_dynamic_token_w4(weight_quant, input_quant) + and input_quant is None + ): + logger.info_once("Using NPUCompressedTensorsW4A16Int4DynamicMoE") + return NPUCompressedTensorsW4A16Int4DynamicMoE(self) + elif self._is_fp4a4_nvfp4(weight_quant, input_quant): + logger.info_once("Using CompressedTensorsW4A4Nvfp4MoE") + return CompressedTensorsW4A4Nvfp4MoE() + elif self._is_fp8_w8a8(weight_quant, input_quant): + logger.info_once("Using CompressedTensorsW8A8Fp8MoE") + return CompressedTensorsW8A8Fp8MoE(weight_quant, input_quant) + elif self._is_dynamic_token_w8a8(weight_quant, input_quant): + if _is_npu: + logger.info_once("Using NPUCompressedTensorsW8A8Int8DynamicMoE") + return NPUCompressedTensorsW8A8Int8DynamicMoE(weight_quant, input_quant) + else: + raise NotImplementedError( + f"The W8A8Int8 Fused MoE scheme is implemented only for NPU for now." + ) + elif self._is_dynamic_token_w4a8(weight_quant, input_quant): + if _is_npu: + logger.info_once("Using NPUCompressedTensorsW4A8Int8DynamicMoE") + return NPUCompressedTensorsW4A8Int8DynamicMoE(self) + else: + raise NotImplementedError( + f"The W4A8Int8 Fused MoE scheme is implemented only for NPU for now." 
+ ) + else: + raise RuntimeError( + f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}" + ) + + def get_linear_scheme( + self, layer: torch.nn.Module, layer_name: Optional[str] = None + ) -> Optional[CompressedTensorsLinearScheme]: """ compressed-tensors supports non uniform in the following way: @@ -842,3 +957,86 @@ def apply( if scheme is None: raise ValueError("A scheme must be defined for each layer") return scheme.apply_weights(layer, x, bias=bias) + + +class CompressedTensorsFusedMoEMethod(FusedMoEMethodBase): + + def __init__(self, quantization_config: CompressedTensorsConfig): + self.quantization_config = quantization_config + self.quant_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scheme.process_weights_after_loading(layer) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. See LinearMethodBase for param + details + """ + layer.scheme.create_weights( + layer=layer, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size_per_partition=intermediate_size_per_partition, + params_dtype=params_dtype, + **extra_weight_attrs, + ) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + return layer.scheme.create_moe_runner(layer, moe_runner_config) + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. 
See LinearMethodBase for param details + + """ + + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, dispatch_output) + + def apply_weights_with_router_logits( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> torch.Tensor: + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights_with_router_logits(layer, dispatch_output) + + def apply_without_routing_weights( + self, + layer, + hidden_states, + hidden_states_scale, + group_list_type, + group_list, + output_dtype, + ): + return layer.scheme.apply_without_routing_weights( + layer, + hidden_states, + hidden_states_scale, + group_list_type, + group_list, + output_dtype, + ) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py deleted file mode 100644 index 12184a4b6caf..000000000000 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ /dev/null @@ -1,2190 +0,0 @@ -# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors -# SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - -import enum -import logging -from enum import Enum -from typing import TYPE_CHECKING - -import torch -from compressed_tensors import CompressionFormat -from compressed_tensors.quantization import QuantizationStrategy - -from sglang.srt.distributed import ( - get_moe_expert_parallel_rank, - get_tensor_model_parallel_world_size, - get_tp_group, -) -from sglang.srt.distributed.device_communicators.pynccl_allocator import ( - use_symmetric_memory, -) -from sglang.srt.hardware_backend.npu.quantization.fused_moe_method_npu import ( - NPUW4A8Int8DynamicMoEMethod, - NPUW4A16Int4DynamicMoEMethod, - NPUW8A8Int8DynamicMoEMethod, -) -from sglang.srt.layers.dp_attention import is_allocation_symmetric -from sglang.srt.layers.moe import ( - MoeRunner, - MoeRunnerBackend, - MoeRunnerConfig, - get_moe_runner_backend, -) -from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType -from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo -from sglang.srt.layers.moe.utils import RoutingMethodType -from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase -from sglang.srt.layers.quantization.compressed_tensors.schemes import ( - WNA16_SUPPORTED_BITS, -) -from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant -from sglang.srt.layers.quantization.fp8_utils import ( - is_blackwell_supported, - normalize_e4m3fn_to_e4m3fnuz, -) -from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack -from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales -from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod -from sglang.srt.layers.quantization.utils import ( - all_close_1d, - per_tensor_dequantize, - prepare_static_weights_for_trtllm_fp4_moe, - reorder_w1w3_to_w3w1, - replace_parameter, - swizzle_blockscale, -) -from sglang.srt.utils import ( - get_bool_env_var, - is_cuda, - is_flashinfer_available, - is_hip, - is_npu, - next_power_of_2, - set_weight_attrs, -) - -if TYPE_CHECKING: - from sglang.srt.layers.moe.fused_moe_triton import FusedMoE - from sglang.srt.layers.moe.token_dispatcher import ( - CombineInput, - 
StandardDispatchOutput, - ) - from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( - CompressedTensorsConfig, - ) - -_is_hip = is_hip() -_is_npu = is_npu() -_is_cuda = is_cuda() - -_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip - -if _use_aiter: - from aiter import ActivationType, QuantType - from aiter.fused_moe import fused_moe - from aiter.ops.shuffle import shuffle_weight - -if is_flashinfer_available(): - from flashinfer.fp4_quantization import block_scale_interleave - from flashinfer.fused_moe import ( - convert_to_block_layout, - trtllm_mxint4_block_scale_moe, - ) - from flashinfer.fused_moe.core import ( - _maybe_get_cached_w3_w1_permute_indices, - get_w2_permute_indices_with_cache, - ) - - -logger = logging.getLogger(__name__) - - -class GPTQMarlinState(Enum): - REPACK = enum.auto() - READY = enum.auto() - - -__all__ = [ - "CompressedTensorsMoEMethod", - "CompressedTensorsW4A4Nvfp4MoEMethod", - "NPUCompressedTensorsW4A8Int8DynamicMoEMethod", - "CompressedTensorsW8A8Fp8MoEMethod", - "NPUCompressedTensorsW8A8Int8MoEMethod", - "CompressedTensorsWNA16MoEMethod", - "CompressedTensorsMxInt4MoEMethod", - "NPUCompressedTensorsW4A16Int4DynamicMoEMethod", -] - - -class CompressedTensorsMoEMethod(FusedMoEMethodBase): - def __new__(cls, *args, **kwargs): - if cls is CompressedTensorsMoEMethod: - return super().__new__(cls) - return super().__new__(cls) - - @staticmethod - def get_moe_method( - quant_config: CompressedTensorsConfig, - layer: torch.nn.Module, - prefix: str, - ) -> "CompressedTensorsMoEMethod": - # TODO: @dsikka: refactor this to use schemes as other kernels - # are supported + check if the layer is being ignored. - - # FusedMoE was made by combining multiple Linears so need to - # make sure quantization config for Linear can target it - quant_config._add_fused_moe_to_target_scheme_map() - unfused_names = [ - prefix + proj_name - for proj_name in [".0.gate_proj", ".0.up_proj", ".0.down_proj"] - ] - # TODO: refactor this to use expert_mapping and check all layer numbers - all_scheme_dicts = [ - quant_config.get_scheme_dict(layer, name) for name in unfused_names - ] - scheme_dict = all_scheme_dicts[0] if all_scheme_dicts else None - - # multiple schemes found - if not all(d == scheme_dict for d in all_scheme_dicts): - raise ValueError( - "All MoE projections need to have same " - "quantization scheme but found multiple" - ) - - use_triton_kernels = get_moe_runner_backend().is_triton_kernels() - use_flashinfer_trtllm_moe = get_moe_runner_backend().is_flashinfer_trtllm() - if scheme_dict is None: # ignored layer - return UnquantizedFusedMoEMethod( - use_triton_kernels, use_flashinfer_trtllm_moe - ) - - weight_quant = scheme_dict.get("weights") - input_quant = scheme_dict.get("input_activations") - - if quant_config._is_wNa16_group_channel(weight_quant, input_quant): - if not _is_npu: - if ( - quant_config._is_mxint4a16(weight_quant, input_quant) - and get_moe_runner_backend().is_flashinfer_trtllm() - ): - logger.info_once( - "Using CompressedTensorsMxInt4MoEMethod with flashinfer_trtllm backend" - ) - return CompressedTensorsMxInt4MoEMethod(quant_config) - elif _is_hip: - logger.info_once( - "Using CompressedTensorsWNA16TritonMoEMethod (ROCm)" - ) - return CompressedTensorsWNA16TritonMoEMethod(quant_config) - else: - logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") - return CompressedTensorsWNA16MoEMethod(quant_config) - else: - if ( - quant_config._is_dynamic_token_w4(weight_quant, input_quant) - and input_quant is None - 
): - logger.info_once( - "Using NPUCompressedTensorsW4A16Int4DynamicMoEMethod" - ) - return NPUCompressedTensorsW4A16Int4DynamicMoEMethod(quant_config) - elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): - logger.info_once("Using CompressedTensorsW4A4Nvfp4MoEMethod") - return CompressedTensorsW4A4Nvfp4MoEMethod(quant_config) - elif quant_config._is_fp8_w8a8(weight_quant, input_quant): - logger.info_once("Using CompressedTensorsW8A8Fp8MoEMethod") - return CompressedTensorsW8A8Fp8MoEMethod(quant_config) - elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): - if _is_npu: - logger.info_once("Using NPUCompressedTensorsW8A8Int8DynamicMoEMethod") - return NPUCompressedTensorsW8A8Int8DynamicMoEMethod(quant_config) - else: - raise NotImplementedError( - f"The W8A8Int8 Fused MoE scheme is implemented only for NPU for now." - ) - elif quant_config._is_dynamic_token_w4a8(weight_quant, input_quant): - if _is_npu: - logger.info_once("Using NPUCompressedTensorsW4A8Int8DynamicMoEMethod") - return NPUCompressedTensorsW4A8Int8DynamicMoEMethod(quant_config) - else: - raise NotImplementedError( - f"The W4A8Int8 Fused MoE scheme is implemented only for NPU for now." - ) - else: - raise RuntimeError( - f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}" - ) - - -class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): - - def __init__(self, quant_config: CompressedTensorsConfig): - if not is_blackwell_supported(): - raise ValueError( - "Current platform does not support NVFP4" - " quantization. Please use Blackwell and" - " above." - ) - self.quant_config = quant_config - self.group_size = 16 - self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm() - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported - - layer.params_dtype = params_dtype - - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - # 2 fp4 items are packed in the input dimension - hidden_size // 2, - requires_grad=False, - dtype=torch.uint8, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_packed", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - # 2 fp4 items are packed in the input dimension - intermediate_size_per_partition // 2, - dtype=torch.uint8, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_packed", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # Weight Scales - w13_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - # 2 fp4 items are packed in the input dimension - hidden_size // self.group_size, - dtype=torch.float8_e4m3fn, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} - ) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - - w2_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - # 2 fp4 items are packed in the input dimension - intermediate_size_per_partition // self.group_size, - dtype=torch.float8_e4m3fn, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - 
extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} - ) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # Weight Global Scales - w13_weight_scale_2 = torch.nn.Parameter( - torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) - - w2_weight_scale_2 = torch.nn.Parameter( - torch.empty(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) - - # Input Global Scales - w13_input_scale = torch.nn.Parameter( - torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_input_global_scale", w13_input_scale) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - set_weight_attrs(w13_input_scale, extra_weight_attrs) - - w2_input_scale = torch.nn.Parameter( - torch.empty(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w2_input_global_scale", w2_input_scale) - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} - ) - set_weight_attrs(w2_input_scale, extra_weight_attrs) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # From packed to weight - layer.w13_weight = torch.nn.Parameter( - layer.w13_weight_packed.data, requires_grad=False - ) - delattr(layer, "w13_weight_packed") - - layer.w2_weight = torch.nn.Parameter( - layer.w2_weight_packed.data, requires_grad=False - ) - delattr(layer, "w2_weight_packed") - - if self.use_flashinfer_trtllm: - w, s = reorder_w1w3_to_w3w1( - layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2 - ) - layer.w13_weight = torch.nn.Parameter(w, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False) - - if not torch.allclose( - layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1] - ): - logger.warning_once( - "w1_weight_global_scale must match w3_weight_global_scale. " - "Accuracy may be affected." 
- ) - - # Take inverse of global scale saved to disk - layer.w13_weight_scale_2 = torch.nn.Parameter( - 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False - ) - - layer.w2_weight_scale_2 = torch.nn.Parameter( - 1 / layer.w2_weight_global_scale.data, requires_grad=False - ) - - # w13 - if self.use_flashinfer_trtllm: - w13_input_global_scale = ( - layer.w13_input_global_scale.min() - .to(torch.float32) - .expand(layer.num_local_experts) - ) - else: - w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to( - torch.float32 - ) - layer.g1_alphas = torch.nn.Parameter( - ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), - requires_grad=False, - ) - - layer.w13_input_scale_quant = torch.nn.Parameter( - (w13_input_global_scale), requires_grad=False - ) - - # w2 - if self.use_flashinfer_trtllm: - w2_input_global_scale = ( - layer.w2_input_global_scale.min() - .to(torch.float32) - .expand(layer.num_local_experts) - ) - else: - w2_input_global_scale = layer.w2_input_global_scale - - layer.g2_alphas = torch.nn.Parameter( - ((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32), - requires_grad=False, - ) - - layer.w2_input_scale_quant = torch.nn.Parameter( - (w2_input_global_scale), requires_grad=False - ) - - # TensorRT-LLM specific processing - if self.use_flashinfer_trtllm: - # Prepare static weights for TRT-LLM kernel - ( - gemm1_weights_fp4_shuffled, - gemm1_scales_fp4_shuffled, - gemm2_weights_fp4_shuffled, - gemm2_scales_fp4_shuffled, - ) = prepare_static_weights_for_trtllm_fp4_moe( - layer.w13_weight, - layer.w2_weight, - layer.w13_weight_scale, - layer.w2_weight_scale, - layer.w2_weight.size(-2), # hidden_size - layer.w13_weight.size(-2) // 2, # intermediate_size - layer.w13_weight.size(0), # num_experts - ) - logger.debug("Finished shuffling weights for TRT-LLM MOE") - - layer.gemm1_weights_fp4_shuffled = torch.nn.Parameter( - gemm1_weights_fp4_shuffled, requires_grad=False - ) - layer.gemm2_weights_fp4_shuffled = torch.nn.Parameter( - gemm2_weights_fp4_shuffled, requires_grad=False - ) - layer.gemm1_scales_fp4_shuffled = torch.nn.Parameter( - gemm1_scales_fp4_shuffled, requires_grad=False - ) - layer.gemm2_scales_fp4_shuffled = torch.nn.Parameter( - gemm2_scales_fp4_shuffled, requires_grad=False - ) - - # Additional parameter needed for TRT-LLM - layer.g1_scale_c = torch.nn.Parameter( - (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), - requires_grad=False, - ) - - # Clean up weights that won't be used by TRT-LLM - del layer.w2_weight - del layer.w2_weight_scale - del layer.w13_weight - del layer.w13_weight_scale - else: - # swizzle weight scales - layer.w13_weight_scale = torch.nn.Parameter( - swizzle_blockscale(layer.w13_weight_scale), requires_grad=False - ) - - layer.w2_weight_scale = torch.nn.Parameter( - swizzle_blockscale(layer.w2_weight_scale), requires_grad=False - ) - - layer.cutlass_moe_params = CutlassMoEParams( - CutlassMoEType.BlockscaledFP4, - layer.w13_weight.device, - num_experts=layer.num_experts, - intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, - hidden_size=layer.w13_weight.shape[2] * 2, - ) - - def create_moe_runner( - self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig - ): - self.moe_runner_config = moe_runner_config - self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) - - def apply( - self, - layer: torch.nn.Module, - dispatch_output: StandardDispatchOutput, - ) -> CombineInput: - - from sglang.srt.layers.moe.token_dispatcher import 
StandardCombineInput - - x = dispatch_output.hidden_states - topk_output = dispatch_output.topk_output - - if self.use_flashinfer_trtllm: - from flashinfer import fp4_quantize, trtllm_fp4_block_scale_moe - - router_logits = topk_output.router_logits - topk_config = topk_output.topk_config - - # Quantize input hidden states using fp4_quantize - hs_fp4_bytes, hs_sf_bytes = fp4_quantize( - x, - layer.w13_input_scale_quant, - self.group_size, # sf_vec_size - False, # use_ue8m0 - False, # is_sf_swizzled_layout - ) - hs_fp4 = hs_fp4_bytes.reshape(x.shape[0], x.shape[1] // 2) - hs_scale = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1) - - correction_bias = ( - None - if topk_config.correction_bias is None - else topk_config.correction_bias.to(x.dtype) - ) - - assert layer.routing_method_type is not None - - # DeepSeekV3 style routing requires float32 router logits - if layer.routing_method_type == RoutingMethodType.DeepSeekV3: - router_logits = router_logits.to(torch.float32) - - routed_scaling_factor = self.moe_runner_config.routed_scaling_factor - routed_scaling_factor = ( - routed_scaling_factor if routed_scaling_factor is not None else 1.0 - ) - - with use_symmetric_memory( - get_tp_group(), disabled=not is_allocation_symmetric() - ): - num_tokens = hs_fp4.shape[0] - hidden_size = ( - hs_fp4.shape[-1] * 2 - if hs_fp4.dtype == torch.uint8 - else hs_fp4.shape[-1] - ) - symm_output = torch.empty( - num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device - ) - - output = trtllm_fp4_block_scale_moe( - routing_logits=router_logits, - routing_bias=correction_bias, - hidden_states=hs_fp4, - hidden_states_scale=hs_scale, - gemm1_weights=layer.gemm1_weights_fp4_shuffled, - gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), - gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=layer.gemm2_weights_fp4_shuffled, - gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), - gemm2_bias=None, - output1_scale_scalar=layer.g1_scale_c, - output1_scale_gate_scalar=layer.g1_alphas, - output2_scale_scalar=layer.g2_alphas, - num_experts=layer.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=layer.intermediate_size_per_partition, - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - routed_scaling_factor=routed_scaling_factor, - routing_method_type=layer.routing_method_type, - do_finalize=True, - tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), - output=symm_output, - )[0] - else: - from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 - - topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids - - output = cutlass_moe_fp4( - a=x, - a1_gscale=layer.w13_input_scale_quant, - w1_fp4=layer.w13_weight, - w1_blockscale=layer.w13_weight_scale, - w1_alphas=layer.g1_alphas, - a2_gscale=layer.w2_input_scale_quant, - w2_fp4=layer.w2_weight, - w2_blockscale=layer.w2_weight_scale, - w2_alphas=layer.g2_alphas, - topk_weights=topk_weights, - topk_ids=topk_ids, - params=layer.cutlass_moe_params, - apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input, - ).to(x.dtype) - - return StandardCombineInput(hidden_states=output) - - -class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): - - def __init__(self, quant_config: CompressedTensorsConfig): - self.quant_config = quant_config - self.weight_quant = 
self.quant_config.target_scheme_map["Linear"].get("weights") - self.input_quant = self.quant_config.target_scheme_map["Linear"].get( - "input_activations" - ) - - per_tensor = ( - self.weight_quant.strategy == QuantizationStrategy.TENSOR - and self.input_quant.strategy == QuantizationStrategy.TENSOR - ) - per_channel = ( - self.weight_quant.strategy == QuantizationStrategy.CHANNEL - and self.input_quant.strategy == QuantizationStrategy.TOKEN - ) - if not (per_tensor or per_channel): - assert self.weight_quant.strategy == QuantizationStrategy.BLOCK - self.weight_block_size = self.weight_quant.block_structure - assert self.weight_quant.dynamic is not None - else: - self.weight_block_size = None - self.block_quant = self.weight_block_size is not None - - self.static_input_scales = not self.input_quant.dynamic - if self.static_input_scales and per_channel: - raise ValueError( - "For FP8 Fused MoE layer, we require either per tensor or " - "channelwise, dynamic per token quantization." - ) - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported - - params_dtype = torch.float8_e4m3fn - - if self.block_quant: - assert self.weight_block_size is not None - layer.weight_block_size = self.weight_block_size - tp_size = get_tensor_model_parallel_world_size() - block_n, block_k = ( - self.weight_block_size[0], - self.weight_block_size[1], - ) - # NOTE: To ensure proper alignment of the block-wise quantization - # scales, the output_size of the weights for both the gate and up - # layers must be divisible by block_n. - # Required by column parallel or enabling merged weights - if intermediate_size_per_partition % block_n != 0: - raise ValueError( - f"The output_size of gate's and up's weight = " - f"{intermediate_size_per_partition} is not divisible by " - f"weight quantization block_n = {block_n}." - ) - if tp_size > 1 and intermediate_size_per_partition % block_k != 0: - # Required by row parallel - raise ValueError( - f"The input_size of down's weight = " - f"{intermediate_size_per_partition} is not divisible by " - f"weight quantization block_k = {block_k}." - ) - - # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - # per-tensor quantization - if self.weight_quant.strategy == QuantizationStrategy.TENSOR: - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. 
- w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value - elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL: - w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=torch.float32, - ), - requires_grad=False, - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), - requires_grad=False, - ) - weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value - elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: - w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), - (hidden_size + block_k - 1) // block_k, - dtype=torch.float32, - ), - requires_grad=False, - ) - w2_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, - (hidden_size + block_n - 1) // block_n, - (intermediate_size_per_partition + block_k - 1) // block_k, - dtype=torch.float32, - ), - requires_grad=False, - ) - weight_quant_method = FusedMoeWeightScaleSupported.BLOCK.value - else: - raise ValueError( - f"Unsupported weight quantization strategy: {self.weight_quant.strategy}" - ) - - layer.register_parameter("w13_weight_scale", w13_weight_scale) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - extra_weight_attrs.update({"quant_method": weight_quant_method}) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # INPUT_SCALES - if self.static_input_scales: - assert ( - self.input_quant.strategy == QuantizationStrategy.TENSOR - ), "Only per-tensor quantization is supported for static input scales" - w13_input_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_input_scale", w13_input_scale) - set_weight_attrs(w13_input_scale, extra_weight_attrs) - - w2_input_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, extra_weight_attrs) - else: - layer.w13_input_scale = None - layer.w2_input_scale = None - - def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> None: - # Fp8 moe kernels require a single activation scale. - # We take the max of all the scales in case they differ. - if self.static_input_scales: - if layer.w13_input_scale is None or layer.w2_input_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - if not all_close_1d(layer.w13_input_scale) or not all_close_1d( - layer.w2_input_scale - ): - logger.warning( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer." 
- ) - layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False - ) - layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False - ) - - if is_fp8_fnuz(): - # Normalize the weights and scales - w13_weight, w13_weight_scale, w13_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale - ) - ) - w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale - ) - # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter( - w13_weight_scale, requires_grad=False - ) - if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter( - w13_input_scale, requires_grad=False - ) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter( - w2_weight_scale, requires_grad=False - ) - if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter( - w2_input_scale, requires_grad=False - ) - if self.weight_quant.strategy == QuantizationStrategy.TENSOR: - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. - assert layer.w13_weight_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.num_local_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start : start + shard_size, :], - layer.w13_weight_scale[expert_id][shard_id], - ) - ( - layer.w13_weight[expert_id][start : start + shard_size, :], - _, - ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) - - start += shard_size - - layer.w13_weight_scale = torch.nn.Parameter( - max_w13_scales, requires_grad=False - ) - - if self.weight_quant.strategy == QuantizationStrategy.CHANNEL and _use_aiter: - with torch.no_grad(): - # Pre-shuffle weights - layer.w13_weight = torch.nn.Parameter( - shuffle_weight(layer.w13_weight.data, (16, 16)), - requires_grad=False, - ) - torch.cuda.empty_cache() - layer.w2_weight = torch.nn.Parameter( - shuffle_weight(layer.w2_weight.data, (16, 16)), - requires_grad=False, - ) - torch.cuda.empty_cache() - - def create_moe_runner( - self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig - ): - self.moe_runner_config = moe_runner_config - self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) - - def apply( - self, - layer: torch.nn.Module, - dispatch_output: StandardDispatchOutput, - ) -> CombineInput: - - from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput - - x = dispatch_output.hidden_states - topk_output = dispatch_output.topk_output - - moe_runner_config = self.moe_runner_config - - if _use_aiter and self.weight_quant.strategy == QuantizationStrategy.CHANNEL: - assert not moe_runner_config.no_combine, "unsupported" - topk_weights, topk_ids, _ = topk_output - if moe_runner_config.apply_router_weight_on_input: - assert ( - topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - x = x * topk_weights.to(x.dtype) - topk_weights = torch.ones_like( - topk_weights, dtype=torch.float32 - ) # topk_weights must be FP32 (float32) - output = 
fused_moe( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=( - ActivationType.Silu - if moe_runner_config.activation == "silu" - else ActivationType.Gelu - ), - quant_type=QuantType.per_Token, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) - return StandardCombineInput(hidden_states=output) - elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: - quant_info = TritonMoeQuantInfo( - w13_weight=layer.w13_weight, - w2_weight=layer.w2_weight, - use_fp8_w8a8=True, - w13_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a13_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=self.weight_block_size, - ) - return self.runner.run(dispatch_output, quant_info) - else: - quant_info = TritonMoeQuantInfo( - w13_weight=layer.w13_weight, - w2_weight=layer.w2_weight, - use_fp8_w8a8=True, - per_channel_quant=self.weight_quant.strategy - == QuantizationStrategy.CHANNEL, - w13_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a13_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) - return self.runner.run(dispatch_output, quant_info) - - -class NPUCompressedTensorsW8A8Int8DynamicMoEMethod(CompressedTensorsMoEMethod): - - def __init__(self, quant_config: CompressedTensorsConfig): - self.quant_config = quant_config - self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") - self.input_quant = self.quant_config.target_scheme_map["Linear"].get( - "input_activations" - ) - self.kernel = NPUW8A8Int8DynamicMoEMethod() - - self.static_input_scales = not self.input_quant.dynamic - per_channel = ( - self.weight_quant.strategy == QuantizationStrategy.CHANNEL - and self.input_quant.strategy == QuantizationStrategy.TOKEN - ) - if not per_channel: - raise ValueError( - "For INT8 Fused MoE layers, we require channelwise, " - "dynamic per token quantization. Found " - f"{self.weight_quant}, {self.input_quant}" - ) - - self.static_input_scales = not self.input_quant.dynamic - if self.static_input_scales: - raise ValueError( - "For INT8 Fused MoE layers, we require channelwise, " - "dynamic per token quantization. Found static input scales." 
- ) - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - - from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported - - params_dtype = torch.int8 - - # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL - w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add PER-CHANNEL quantization for FusedMoE.weight_loader. - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} - ) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # INPUT_SCALES - assert not self.static_input_scales - layer.w13_input_scale = None - layer.w2_input_scale = None - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - self.kernel.process_weights_after_loading(layer) - - def create_moe_runner( - self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig - ): - self.moe_runner_config = moe_runner_config - - def apply( - self, - layer: torch.nn.Module, - dispatch_output: StandardDispatchOutput, - ) -> CombineInput: - - return self.kernel.apply(layer, dispatch_output) - - -class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): - - def __init__(self, quant_config: CompressedTensorsConfig, num_gpu_experts=-1): - self.quant_config = quant_config - # TODO: @dsikka: refactor this to use schemes as other kernels - # are supported + check if the layer is being ignored. - config = self.quant_config.target_scheme_map["Linear"].get("weights") - self.num_bits = config.num_bits - self.packed_factor = 32 // config.num_bits - self.strategy = config.strategy - self.group_size = config.group_size - self.actorder = config.actorder - assert config.symmetric, "Only symmetric quantization is supported for MoE" - - if not ( - self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS - ): - raise ValueError( - "For Fused MoE layers, only ", - f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}", - ) - self.num_gpu_experts = num_gpu_experts - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - # Will transpose the loaded weight along the - # intermediate and hidden dim sizes. 
Will - # shard for TP along the transposed dims - extra_weight_attrs.update( - {"is_transposed": True, "quant_method": self.strategy} - ) - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size // self.packed_factor, - 2 * intermediate_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_packed", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size_per_partition // self.packed_factor, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_packed", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # In the case where we have actorder/g_idx, - # we do not partition the w2 scales - load_full_w2 = self.actorder and self.group_size != -1 - - if load_full_w2: - w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size - else: - w2_scales_size = intermediate_size_per_partition - - self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1 - - if self.strategy == "channel": - num_groups_w2 = num_groups_w13 = 1 - self.group_size = -1 - else: - num_groups_w2 = w2_scales_size // self.group_size - num_groups_w13 = hidden_size // self.group_size - - w13_scale = torch.nn.Parameter( - torch.ones( - num_experts, - num_groups_w13, - 2 * intermediate_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_scale) - set_weight_attrs(w13_scale, extra_weight_attrs) - - w2_scale = torch.nn.Parameter( - torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_scale) - set_weight_attrs(w2_scale, extra_weight_attrs) - set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2}) - - w2_weight_shape = torch.nn.Parameter( - torch.empty(num_experts, 2), requires_grad=False - ) - layer.register_parameter("w2_weight_shape", w2_weight_shape) - set_weight_attrs(w2_weight_shape, extra_weight_attrs) - w13_weight_shape = torch.nn.Parameter( - torch.empty(num_experts, 2), requires_grad=False - ) - - layer.register_parameter("w13_weight_shape", w13_weight_shape) - set_weight_attrs(w13_weight_shape, extra_weight_attrs) - - w13_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_g_idx", w13_g_idx) - set_weight_attrs(w13_g_idx, extra_weight_attrs) - - w2_g_idx = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_g_idx", w2_g_idx) - set_weight_attrs(w2_g_idx, extra_weight_attrs) - - w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) - set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) - - w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty( - num_experts, - intermediate_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) - set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) - - layer.a13_scale = None - layer.a2_scale = None - layer.marlin_state = GPTQMarlinState.REPACK - - if not hasattr(layer, 
"_original_shapes"): - layer._original_shapes = {} - - # Force record: these are the target GPTQ shapes for rollback. - layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape) - layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape) - - # Also record the shapes of the scales. - layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape) - layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - - # Skip if the layer is already converted to Marlin format to prevent double-packing. - if getattr(layer, "is_marlin_converted", False): - return - - if not hasattr(layer, "_original_shapes"): - layer._original_shapes = {} - - def replace_tensor(name, new_t): - target_attr = getattr(layer, name) - - # Only save if the key doesn't exist to prevent overwriting with Marlin shapes. - if name not in layer._original_shapes: - # This is a safety check; `create_weights` usually handles this already. - layer._original_shapes[name] = tuple(target_attr.shape) - - # It is important to use resize_() here since it ensures - # the same buffer is reused - target_attr.resize_(new_t.shape) - target_attr.copy_(new_t) - del new_t - - num_experts = layer.w13_weight_g_idx.shape[0] - device = layer.w13_weight_g_idx.device - - # when running models with grouped act order, - # resort to g_idx values provided in checkpoint - if self.actorder == "group": - w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx) - w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx) - w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx) - w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx) - - for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_weight_g_idx[e]).to( - torch.int32 - ) - w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_weight_g_idx[e]).to( - torch.int32 - ) - w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][ - w13_g_idx_sort_indices[e] - ] - w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][w2_g_idx_sort_indices[e]] - - replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx) - replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx) - replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) - replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) - - else: - layer.w13_weight_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_weight_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - - marlin_w13_qweight = gptq_marlin_moe_repack( - layer.w13_weight_packed, - layer.w13_g_idx_sort_indices, - layer.w13_weight_packed.shape[1] * self.packed_factor, - layer.w13_weight_packed.shape[2], - self.num_bits, - ) - replace_tensor("w13_weight_packed", marlin_w13_qweight) - marlin_w2_qweight = gptq_marlin_moe_repack( - layer.w2_weight_packed, - layer.w2_g_idx_sort_indices, - layer.w2_weight_packed.shape[1] * self.packed_factor, - layer.w2_weight_packed.shape[2], - self.num_bits, - ) - replace_tensor("w2_weight_packed", marlin_w2_qweight) - # Repack scales - marlin_w13_scales = 
marlin_moe_permute_scales( - layer.w13_weight_scale, - layer.w13_weight_packed.shape[2], - layer.w13_weight_scale.shape[2], - self.group_size, - ) - replace_tensor("w13_weight_scale", marlin_w13_scales) - - marlin_w2_scales = marlin_moe_permute_scales( - layer.w2_weight_scale, - layer.w2_weight_scale.shape[1] - * (self.group_size if self.group_size != -1 else self.packed_factor), - layer.w2_weight_scale.shape[2], - self.group_size, - ) - replace_tensor("w2_weight_scale", marlin_w2_scales) - - layer.is_marlin_converted = True - - def restore_weights_before_loading(self, layer: torch.nn.Module): - """Forcibly resize parameters back to their original shapes (e.g., GPTQ format) before loading weights.""" - - if not hasattr(layer, "_original_shapes"): - return - - for name, orig_shape in layer._original_shapes.items(): - param = getattr(layer, name, None) - - if param is not None and param.shape != orig_shape: - param.resize_(orig_shape) - - layer.is_marlin_converted = False - - def create_moe_runner( - self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig - ): - self.moe_runner_config = moe_runner_config - - def apply( - self, - layer: torch.nn.Module, - dispatch_output: StandardDispatchOutput, - ) -> CombineInput: - from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import ( - fused_marlin_moe, - ) - from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput - - assert ( - self.moe_runner_config.activation == "silu" - ), "Only SiLU activation is supported." - - x = dispatch_output.hidden_states - topk_output = dispatch_output.topk_output - - topk_weights, topk_ids, router_logits = topk_output - - # Get expert_map for EP support - expert_map = None - global_num_experts = -1 - if hasattr(layer, "dispatcher") and hasattr( - layer.dispatcher, "local_expert_mapping" - ): - expert_map = layer.dispatcher.local_expert_mapping - if expert_map is not None: - global_num_experts = self.moe_runner_config.num_experts - - output = fused_marlin_moe( - x, - layer.w13_weight_packed, - layer.w2_weight_packed, - layer.w13_weight_scale, - layer.w2_weight_scale, - router_logits, - topk_weights, - topk_ids, - global_num_experts=global_num_experts, - expert_map=expert_map, - g_idx1=layer.w13_weight_g_idx, - g_idx2=layer.w2_weight_g_idx, - sort_indices1=layer.w13_g_idx_sort_indices, - sort_indices2=layer.w2_g_idx_sort_indices, - num_bits=self.num_bits, - is_k_full=self.is_k_full, - routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, - ) - return StandardCombineInput(hidden_states=output) - - -class CompressedTensorsWNA16TritonMoEMethod(CompressedTensorsWNA16MoEMethod): - """ROCm/HIP-compatible W4A16 MoE method using Triton kernels instead of Marlin. - - Inherits weight creation from CompressedTensorsWNA16MoEMethod but converts - weights to the uint8-packed format expected by the Triton fused MoE kernel - instead of the Marlin-specific format. 
- """ - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - if getattr(layer, "is_triton_converted", False): - return - - num_experts = layer.w13_weight_packed.shape[0] - - # Convert w13 weights: [E, K//8, N] int32 -> [E, N, K//2] uint8 - w13 = layer.w13_weight_packed.data - w13 = w13.transpose(1, 2).contiguous().view(torch.uint8) - layer.w13_weight_packed = torch.nn.Parameter(w13, requires_grad=False) - - # Convert w2 weights: [E, K//8, N] int32 -> [E, N, K//2] uint8 - w2 = layer.w2_weight_packed.data - w2 = w2.transpose(1, 2).contiguous().view(torch.uint8) - layer.w2_weight_packed = torch.nn.Parameter(w2, requires_grad=False) - - # Convert w13 scales: [E, K//group_size, N] -> [E, N, K//group_size] - w13_scale = layer.w13_weight_scale.data - w13_scale = w13_scale.transpose(1, 2).contiguous() - layer.w13_weight_scale = torch.nn.Parameter(w13_scale, requires_grad=False) - - # Convert w2 scales: [E, K//group_size, N] -> [E, N, K//group_size] - w2_scale = layer.w2_weight_scale.data - w2_scale = w2_scale.transpose(1, 2).contiguous() - layer.w2_weight_scale = torch.nn.Parameter(w2_scale, requires_grad=False) - - layer.is_triton_converted = True - - def create_moe_runner( - self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig - ): - self.moe_runner_config = moe_runner_config - self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) - - def apply( - self, - layer: torch.nn.Module, - dispatch_output: "StandardDispatchOutput", - ) -> "CombineInput": - from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo - - assert ( - self.moe_runner_config.activation == "silu" - ), "Only SiLU activation is supported." - - quant_info = TritonMoeQuantInfo( - w13_weight=layer.w13_weight_packed, - w2_weight=layer.w2_weight_packed, - use_int4_w4a16=True, - w13_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - block_shape=[0, self.group_size], - ) - return self.runner.run(dispatch_output, quant_info) - - -class NPUCompressedTensorsW4A8Int8DynamicMoEMethod(CompressedTensorsMoEMethod): - - ### TODO: Get rid of code duplication with python/sglang/srt/modelslim/modelslim_moe.py @OrangeRedeng @TamirBaydasov - def __init__(self, quantization_config) -> None: - self.group_size = 0 - self.is_per_channel_weight = self.group_size == 0 - self.tp_size = 1 - self.activation_use_clip = ( - self.quantization_config.get("config_groups", {}) - .get("group_1", {}) - .get("activation_use_clip", False) - ) - self.kernel = NPUW4A8Int8DynamicMoEMethod() - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ) -> None: - from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported - - self.num_experts = num_experts - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} - ) - - # >> weight - w13_output_size = intermediate_size_per_partition - w2_output_size = hidden_size // 2 - w13_weight = torch.nn.Parameter( - torch.empty(num_experts, w13_output_size, hidden_size, dtype=torch.int8), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - w2_output_size, - intermediate_size_per_partition, - dtype=torch.int8, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, 
extra_weight_attrs) - - # >> scale - weight_scale_dtype = torch.int64 if self.activation_use_clip else torch.float32 - w13_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - 1, - dtype=weight_scale_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - - w2_weight_scale = torch.nn.Parameter( - torch.empty(num_experts, hidden_size, 1, dtype=weight_scale_dtype), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # >> offset - w13_weight_offset = torch.nn.Parameter( - torch.empty( - num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_offset", w13_weight_offset) - set_weight_attrs(w13_weight_offset, extra_weight_attrs) - - w2_weight_offset = torch.nn.Parameter( - torch.empty(num_experts, hidden_size, 1, dtype=torch.float32), - requires_grad=False, - ) - layer.register_parameter("w2_weight_offset", w2_weight_offset) - set_weight_attrs(w2_weight_offset, extra_weight_attrs) - - # >>> special param for w4a8 - if self.activation_use_clip: - self._init_activation_clip_params( - layer, - num_experts, - hidden_size, - intermediate_size_per_partition, - extra_weight_attrs, - ) - else: - self._init_extra_scale_params( - layer, - num_experts, - hidden_size, - intermediate_size_per_partition, - extra_weight_attrs, - ) - - def _init_activation_clip_params( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - extra_weight_attrs: dict, - ) -> None: - """ - Initializes bias and alpha parameters for quantization schemes that use activation clipping. - - This helper registers `w13_bias`, `w2_bias`, and `w2_alpha`, which are required to - shift and scale the activations or outputs to compensate for the precision loss - introduced by clamping activations. - """ - w13_bias = torch.nn.Parameter( - torch.ones( - num_experts, 2 * intermediate_size_per_partition, dtype=torch.float - ), - requires_grad=False, - ) - layer.register_parameter("w13_bias", w13_bias) - set_weight_attrs(w13_bias, extra_weight_attrs) - - w2_bias = torch.nn.Parameter( - torch.ones(num_experts, hidden_size, dtype=torch.float), - requires_grad=False, - ) - layer.register_parameter("w2_bias", w2_bias) - set_weight_attrs(w2_bias, extra_weight_attrs) - - w2_alpha = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float), requires_grad=False - ) - layer.register_parameter("w2_alpha", w2_alpha) - set_weight_attrs(w2_alpha, extra_weight_attrs) - - def _init_extra_scale_params( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - extra_weight_attrs: dict, - ) -> None: - """ - Initializes additional scaling, offset, and bias parameters for quantization schemes without activation clipping. - - This method registers the following parameters: - 1. Scale Biases: `w13_scale_bias` and `w2_scale_bias`. - 2. Secondary Quantization Params (initialized only for grouped quantization): - `w13_weight_scale_second`, `w13_weight_offset_second`, - `w2_weight_scale_second`, and `w2_weight_offset_second`. 
- """ - if not self.is_per_channel_weight: - w13_weight_scale_second = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.group_size, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale_second", w13_weight_scale_second) - set_weight_attrs(w13_weight_scale_second, extra_weight_attrs) - - w13_weight_offset_second = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.group_size, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter( - "w13_weight_offset_second", w13_weight_offset_second - ) - set_weight_attrs(w13_weight_offset_second, extra_weight_attrs) - - w2_weight_scale_second = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // self.group_size, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale_second", w2_weight_scale_second) - set_weight_attrs(w2_weight_scale_second, extra_weight_attrs) - - w2_weight_offset_second = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // self.group_size, - dtype=torch.float32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_offset_second", w2_weight_offset_second) - set_weight_attrs(w2_weight_offset_second, extra_weight_attrs) - - w13_scale_bias = torch.nn.Parameter( - torch.empty( - num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 - ), - requires_grad=False, - ) - layer.register_parameter("w13_scale_bias", w13_scale_bias) - set_weight_attrs(w13_scale_bias, extra_weight_attrs) - - w2_scale_bias = torch.nn.Parameter( - torch.empty( - num_experts, hidden_size, 16 // self.tp_size, dtype=torch.float32 - ), - requires_grad=False, - ) - layer.register_parameter("w2_scale_bias", w2_scale_bias) - set_weight_attrs(w2_scale_bias, extra_weight_attrs) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - self.kernel.process_weights_after_loading( - layer, self.is_per_channel_weight, self.activation_use_clip - ) - - def create_moe_runner( - self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig - ): - self.moe_runner_config = moe_runner_config - - def apply( - self, - layer: torch.nn.Module, - dispatch_output: StandardDispatchOutput, - ) -> CombineInput: - - return self.kernel.apply(layer, dispatch_output) - - def apply_without_routing_weights( - self, - layer, - hidden_states, - hidden_states_scale, - group_list_type, - group_list, - output_dtype, - ): - return self.kernel.apply_without_routing_weights( - layer, - hidden_states, - hidden_states_scale, - group_list_type, - group_list, - output_dtype, - ) - - -class NPUCompressedTensorsW4A16Int4DynamicMoEMethod(CompressedTensorsMoEMethod): - - def __init__(self, quantization_config) -> None: - self.pack_factor = 8 # weight dtype is int4, but use int32 to create - target = ( - "MoEGMM" if "MoEGMM" in quantization_config.target_scheme_map else "Linear" - ) - if target in quantization_config.target_scheme_map: - self.group_size = quantization_config.target_scheme_map[target][ - "weights" - ].group_size - else: - self.group_size = 128 - - self.kernel = NPUW4A16Int4DynamicMoEMethod() - - # TODO: See if we can merge this method's logic - # with CompressedTensorsWNA16MoEMethod. Need more models and tests. 
- # @OrangeRedeng @TamirBaydasov - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ) -> None: - from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported - - self.num_experts = num_experts - if ( - extra_weight_attrs.get( - "moe_intermediate_size", intermediate_size_per_partition - ) - // intermediate_size_per_partition - > 1 - ): - quant_method = FusedMoeWeightScaleSupported.GROUP.value - else: - quant_method = FusedMoeWeightScaleSupported.CHANNEL.value - extra_weight_attrs.update({"quant_method": quant_method}) - # weight - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.pack_factor, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // self.pack_factor, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # scale - weight_scale_dtype = torch.bfloat16 - w13_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.group_size, - dtype=weight_scale_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - set_weight_attrs(w13_weight_scale, extra_weight_attrs) - w2_weight_scale = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // self.group_size, - dtype=weight_scale_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) - set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - # offset - w13_weight_offset = torch.nn.Parameter( - torch.zeros( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.group_size, - dtype=weight_scale_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_offset", w13_weight_offset) - set_weight_attrs(w13_weight_offset, extra_weight_attrs) - - w2_weight_offset = torch.nn.Parameter( - torch.zeros( - num_experts, - hidden_size, - intermediate_size_per_partition // self.group_size, - dtype=weight_scale_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_offset", w2_weight_offset) - set_weight_attrs(w2_weight_offset, extra_weight_attrs) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - self.kernel.process_weights_after_loading(layer) - - def create_moe_runner( - self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig - ): - self.moe_runner_config = moe_runner_config - - def apply( - self, - layer: torch.nn.Module, - dispatch_output: StandardDispatchOutput, - ) -> CombineInput: - - return self.kernel.apply(layer, dispatch_output) - - def apply_without_routing_weights( - self, - layer, - hidden_states, - hidden_states_scale, - group_list_type, - group_list, - output_dtype, - ): - return self.kernel.apply_without_routing_weights( - layer, - hidden_states, - hidden_states_scale, - group_list_type, - group_list, - output_dtype, - ) - - -class CompressedTensorsMxInt4MoEMethod(CompressedTensorsMoEMethod): - def __init__(self, quant_config: CompressedTensorsConfig): - self.quant_config = quant_config - 
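For orientation, this constructor reads its settings from the checkpoint's compressed-tensors metadata; a config that passes the assertions that follow would look roughly like this (illustrative values, not taken from this PR):

# Illustrative sketch of the checkpoint metadata this scheme expects;
# field names follow the compressed-tensors convention, values are examples.
quantization_config = {
    "quant_method": "compressed-tensors",
    "format": "pack-quantized",   # CompressionFormat.pack_quantized
    "config_groups": {
        "group_0": {
            "targets": ["Linear"],
            "weights": {
                "num_bits": 4,
                "type": "int",
                "strategy": "group",
                "group_size": 32,
                "symmetric": True,
                "actorder": None,
            },
        }
    },
}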
config = self.quant_config.target_scheme_map["Linear"].get("weights") - self.num_bits = config.num_bits - self.packed_factor = 32 // config.num_bits - self.strategy = config.strategy - self.group_size = config.group_size - self.actorder = config.actorder - assert ( - config.strategy == "group" - and config.group_size == 32 - and config.num_bits == 4 - ), "MxInt4 only supports group strategy with group size 32" - assert config.symmetric, "Only symmetric quantization is supported for MoE" - assert ( - get_moe_runner_backend().is_flashinfer_trtllm() - ), "MxInt4 only supports flashinfer_trtllm backend" - assert ( - not config.actorder - ), "Actorder is not supported by flashinfer_trtllm backend" - self.moe_ep_rank = get_moe_expert_parallel_rank() - - if self.quant_config.quant_format != CompressionFormat.pack_quantized.value: - raise ValueError( - f"For Fused MoE layers, only {CompressionFormat.pack_quantized.value} " - "is supported for the mxint4" - ) - self._cache_permute_indices = {} - - def create_weights( - self, - layer: torch.nn.Module, - num_experts: int, - hidden_size: int, - intermediate_size_per_partition: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - extra_weight_attrs.update({"quant_method": self.strategy}) - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size // self.packed_factor, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_packed", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition // self.packed_factor, - dtype=torch.int32, - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight_packed", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - w2_scales_size = intermediate_size_per_partition - num_groups_w2 = w2_scales_size // self.group_size - num_groups_w13 = hidden_size // self.group_size - - assert params_dtype == torch.bfloat16 - w13_scale = torch.nn.Parameter( - torch.ones( - num_experts, - 2 * intermediate_size_per_partition, - num_groups_w13, - dtype=params_dtype, - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight_scale", w13_scale) - set_weight_attrs(w13_scale, extra_weight_attrs) - - w2_scale = torch.nn.Parameter( - torch.ones(num_experts, hidden_size, num_groups_w2, dtype=params_dtype), - requires_grad=False, - ) - layer.register_parameter("w2_weight_scale", w2_scale) - set_weight_attrs(w2_scale, extra_weight_attrs) - - w13_weight_shape = torch.nn.Parameter( - torch.empty(num_experts, 2), requires_grad=False - ) - - layer.register_parameter("w13_weight_shape", w13_weight_shape) - set_weight_attrs(w13_weight_shape, extra_weight_attrs) - - w2_weight_shape = torch.nn.Parameter( - torch.empty(num_experts, 2), requires_grad=False - ) - layer.register_parameter("w2_weight_shape", w2_weight_shape) - set_weight_attrs(w2_weight_shape, extra_weight_attrs) - - layer.a13_scale = None - layer.a2_scale = None - - # Adapted from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_gen_fused_moe.py - def prepare_static_weights_for_kernel( - self, - gemm1_weights, - gemm2_weights, - gemm1_scales, - gemm2_scales, - num_experts, - ): - """Prepare quantized weights for kernel (done offline with weights).""" - - epilogue_tile_m = 128 - gemm1_weights_mxint4_shuffled = [] - gemm1_scales_shuffled = [] - gemm2_weights_mxint4_shuffled = [] - gemm2_scales_shuffled = [] - - 
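The repack helper defined just below converts the checkpoint's offset-binary nibbles (stored as w + 8) into the signed, byte-packed layout the TRT-LLM kernel consumes; a tiny plain-Python illustration of that format difference (not part of this PR):

w = [-8, -1, 0, 7]                    # true signed int4 weights
hf = [(v + 8) & 0xF for v in w]       # checkpoint stores offset-binary nibbles
trt = [(w[i] & 0xF) | ((w[i + 1] & 0xF) << 4) for i in range(0, len(w), 2)]
print(hf)                             # [0, 7, 8, 15]
print([hex(b) for b in trt])          # ['0xf8', '0x70'] -- two signed nibbles per byte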
def repack(w): - assert w.dim() == 2 and w.dtype == torch.int32 - shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=w.device) - w = (w.unsqueeze(2) >> shifts) & 0x0F - w = (w - 8).to(torch.int8).reshape(w.shape[0], -1, 2) - w = (w[..., 0] & 0x0F) | ((w[..., 1] & 0x0F) << 4) - w = w.to(torch.uint8) - return w - - for i in range(num_experts): - # NOTE(HandH1998): - # the huggingface weight format follows (w/s + 8) to pack, - # however, trtllm requires (w/s) to pack - # we need to convert the weight to trtllm's format first - cur_expert_gemm1_weight = repack(gemm1_weights[i]) - cur_expert_gemm2_weight = repack(gemm2_weights[i]) - - # Calculate the permute indices for the following: - # 1. Reorder rows of W1 and scales for fused gated activation - # 2. Shuffle weights and scaling factors for transposed mma output - # for both w3_w1 and w2 weights and scale factors - permute_indices = _maybe_get_cached_w3_w1_permute_indices( - self._cache_permute_indices, - cur_expert_gemm1_weight, - epilogue_tile_m, - ) - gemm1_weights_shuffled = cur_expert_gemm1_weight[ - permute_indices.to(gemm1_weights.device) - ].contiguous() - permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices( - self._cache_permute_indices, - gemm1_scales[i].to(torch.bfloat16), - epilogue_tile_m, - num_elts_per_sf=32, - ) - gemm1_scales_shuffled.append( - block_scale_interleave( - gemm1_scales[i] - .to(torch.bfloat16)[permute_sf_indices.to(gemm1_scales.device)] - .contiguous() - ) - ) - - permute_indices = get_w2_permute_indices_with_cache( - self._cache_permute_indices, - cur_expert_gemm2_weight, - epilogue_tile_m, - ) - gemm2_weights_shuffled = cur_expert_gemm2_weight[ - permute_indices.to(gemm2_weights.device) - ].contiguous() - - permute_sf_indices = get_w2_permute_indices_with_cache( - self._cache_permute_indices, - gemm2_scales[i].to(torch.bfloat16), - epilogue_tile_m, - num_elts_per_sf=16, - ) - gemm2_scales_shuffled.append( - block_scale_interleave( - gemm2_scales[i] - .to(torch.bfloat16)[permute_sf_indices.to(gemm2_scales.device)] - .contiguous() - ) - ) - - block_k = 128 - gemm1_weights_shuffled = convert_to_block_layout( - gemm1_weights_shuffled.view(torch.uint8), block_k - ) - gemm2_weights_shuffled = convert_to_block_layout( - gemm2_weights_shuffled.view(torch.uint8), block_k - ) - - gemm1_weights_mxint4_shuffled.append(gemm1_weights_shuffled) - gemm2_weights_mxint4_shuffled.append(gemm2_weights_shuffled) - - gemm1_weights_mxint4_shuffled = torch.stack(gemm1_weights_mxint4_shuffled) - gemm2_weights_mxint4_shuffled = torch.stack(gemm2_weights_mxint4_shuffled) - gemm1_scales_shuffled = torch.stack(gemm1_scales_shuffled).view(torch.bfloat16) - gemm2_scales_shuffled = torch.stack(gemm2_scales_shuffled).view(torch.bfloat16) - - return ( - gemm1_weights_mxint4_shuffled, - gemm1_scales_shuffled, - gemm2_weights_mxint4_shuffled, - gemm2_scales_shuffled, - ) - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - - num_experts = layer.w13_weight_packed.shape[0] - ( - gemm1_weights_mxint4_shuffled, - gemm1_scales_shuffled, - gemm2_weights_mxint4_shuffled, - gemm2_scales_shuffled, - ) = self.prepare_static_weights_for_kernel( - layer.w13_weight_packed, - layer.w2_weight_packed, - layer.w13_weight_scale, - layer.w2_weight_scale, - num_experts=num_experts, - ) - replace_parameter(layer, "w13_weight_packed", gemm1_weights_mxint4_shuffled) - replace_parameter(layer, "w2_weight_packed", gemm2_weights_mxint4_shuffled) - replace_parameter(layer, "w13_weight_scale", gemm1_scales_shuffled) - 
replace_parameter(layer, "w2_weight_scale", gemm2_scales_shuffled) - - def create_moe_runner( - self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig - ): - self.moe_runner_config = moe_runner_config - - def apply( - self, - layer: torch.nn.Module, - dispatch_output: StandardDispatchOutput, - ) -> CombineInput: - from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput - - assert ( - self.moe_runner_config.is_gated - ), "Only gated MoEs are supported for flashinfer mxint4" - - x = dispatch_output.hidden_states - topk_output = dispatch_output.topk_output - - router_logits = topk_output.router_logits - topk_config = topk_output.topk_config - correction_bias = ( - None - if topk_config.correction_bias is None - else topk_config.correction_bias.to(x.dtype) - ) - - local_num_experts = self.moe_runner_config.num_local_experts - routing_method_type = layer.routing_method_type - assert routing_method_type is not None - # DeepSeekV3 style routing requires float32 router logits, - # see this PR for details: https://github.com/flashinfer-ai/flashinfer/commit/d84e1d560da0a27961c19ca788d96c19cb9dcfb6 - if routing_method_type == RoutingMethodType.DeepSeekV3: - router_logits = router_logits.to(torch.float32) - routed_scaling_factor = self.moe_runner_config.routed_scaling_factor - routed_scaling_factor = ( - routed_scaling_factor if routed_scaling_factor is not None else 1.0 - ) - - with use_symmetric_memory( - get_tp_group(), disabled=not is_allocation_symmetric() - ): - num_tokens = x.shape[0] - hidden_size = x.shape[-1] - symm_output = torch.empty( - num_tokens, hidden_size, dtype=torch.bfloat16, device=x.device - ) - - output = trtllm_mxint4_block_scale_moe( - routing_logits=router_logits, # float - routing_bias=correction_bias, - hidden_states=x, - gemm1_weights=layer.w13_weight_packed, - gemm1_weights_scale=layer.w13_weight_scale, - gemm1_alpha=self.moe_runner_config.gemm1_alpha, - gemm1_beta=None, - gemm1_clamp_limit=self.moe_runner_config.gemm1_clamp_limit, - gemm2_weights=layer.w2_weight_packed, - gemm2_weights_scale=layer.w2_weight_scale, - num_experts=self.moe_runner_config.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=self.moe_runner_config.intermediate_size_per_partition, - local_expert_offset=self.moe_ep_rank * local_num_experts, - local_num_experts=local_num_experts, - routed_scaling_factor=routed_scaling_factor, - routing_method_type=routing_method_type, - tune_max_num_tokens=next_power_of_2(x.shape[0]), - output=symm_output, - ) - - return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py index 70ca328c8a91..8f67e5ba5338 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py @@ -1,22 +1,44 @@ # SPDX-License-Identifier: Apache-2.0 -from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_scheme import ( + CompressedTensorsLinearScheme, + CompressedTensorsMoEScheme, +) +from .compressed_tensors_w4a4_mxint4_moe import CompressedTensorsMxInt4MoE from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4 +from .compressed_tensors_w4a4_nvfp4_moe import CompressedTensorsW4A4Nvfp4MoE +from .compressed_tensors_w4a8_int8_moe import 
NPUCompressedTensorsW4A8Int8DynamicMoE from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 +from .compressed_tensors_w8a8_fp8_moe import CompressedTensorsW8A8Fp8MoE from .compressed_tensors_w8a8_int8 import ( CompressedTensorsW8A8Int8, NPUCompressedTensorsW8A8Int8, ) +from .compressed_tensors_w8a8_int8_moe import NPUCompressedTensorsW8A8Int8DynamicMoE from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16 +from .compressed_tensors_wNa16_moe import ( + CompressedTensorsWNA16MoE, + CompressedTensorsWNA16TritonMoE, + NPUCompressedTensorsW4A16Int4DynamicMoE, +) __all__ = [ - "CompressedTensorsScheme", + "CompressedTensorsLinearScheme", + "CompressedTensorsMoEScheme", "CompressedTensorsW8A8Fp8", + "CompressedTensorsW8A8Fp8MoE", "CompressedTensorsW8A16Fp8", "CompressedTensorsW8A8Int8", "NPUCompressedTensorsW8A8Int8", + "NPUCompressedTensorsW8A8Int8DynamicMoE", "CompressedTensorsWNA16", + "CompressedTensorsWNA16MoE", + "CompressedTensorsWNA16TritonMoE", + "NPUCompressedTensorsW4A16Int4DynamicMoE", "WNA16_SUPPORTED_BITS", "CompressedTensorsW4A4Fp4", + "CompressedTensorsW4A4Nvfp4MoE", + "NPUCompressedTensorsW4A8Int8DynamicMoE", + "CompressedTensorsMxInt4MoE", ] diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py index 3795d0a54ebe..917d417e76b6 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -1,22 +1,27 @@ # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 -from abc import ABC, abstractmethod -from typing import Optional +from abc import abstractmethod +from typing import TYPE_CHECKING, Optional import torch -__all__ = ["CompressedTensorsScheme"] +from sglang.srt.layers.moe import MoeRunnerConfig +from sglang.srt.layers.quantization.base_scheme import BaseLinearScheme, BaseMoEScheme +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput -class CompressedTensorsScheme(ABC): +__all__ = ["CompressedTensorsLinearScheme", "CompressedTensorsMoEScheme"] + + +class CompressedTensorsLinearScheme(BaseLinearScheme): """ Abstract class used to describe the weight creation and forward pass of different quantization schemes supported by CompressedTensors. """ @classmethod - @abstractmethod def get_min_capability(cls) -> int: """ Get minimum device capability. @@ -54,3 +59,57 @@ def process_weights_after_loading(self, layer: torch.nn.Module): needs to occur. """ raise NotImplementedError + + +class CompressedTensorsMoEScheme(BaseMoEScheme): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by CompressedTensors. + """ + + @classmethod + def get_min_capability(cls) -> int: + """ + Get minimum device capability. + """ + raise NotImplementedError + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. 
Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. + """ + raise NotImplementedError + + @abstractmethod + def apply_weights( + self, + layer: torch.nn.Module, + dispatch_output: "StandardDispatchOutput", + ): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + :param bias: bias parameter + + """ + raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxint4_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxint4_moe.py new file mode 100644 index 000000000000..47db282016b6 --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_mxint4_moe.py @@ -0,0 +1,358 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch +from compressed_tensors import CompressionFormat + +from sglang.srt.distributed import get_moe_expert_parallel_rank, get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.layers.dp_attention import is_allocation_symmetric +from sglang.srt.layers.moe import MoeRunnerConfig +from sglang.srt.layers.moe.utils import RoutingMethodType, get_moe_runner_backend +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsMoEScheme, +) +from sglang.srt.layers.quantization.utils import replace_parameter +from sglang.srt.utils import is_flashinfer_available, next_power_of_2, set_weight_attrs + +logger = logging.getLogger(__name__) + +__all__ = ["CompressedTensorsMxInt4MoE"] + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) + from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( + CompressedTensorsConfig, + ) + +if is_flashinfer_available(): + from flashinfer.fp4_quantization import block_scale_interleave + from flashinfer.fused_moe import ( + convert_to_block_layout, + trtllm_mxint4_block_scale_moe, + ) + from flashinfer.fused_moe.core import ( + _maybe_get_cached_w3_w1_permute_indices, + get_w2_permute_indices_with_cache, + ) + + +class CompressedTensorsMxInt4MoE(CompressedTensorsMoEScheme): + def __init__(self, quant_config: CompressedTensorsConfig): + self.quant_config = quant_config + config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy + self.group_size = config.group_size + self.actorder = config.actorder + assert ( + config.strategy == "group" + and config.group_size == 32 + and config.num_bits == 4 + ), "MxInt4 only supports group strategy with group size 32" + assert config.symmetric, "Only symmetric quantization is supported for MoE" + assert ( + get_moe_runner_backend().is_flashinfer_trtllm() + ), "MxInt4 only supports flashinfer_trtllm backend" + assert ( + not 
config.actorder + ), "Actorder is not supported by flashinfer_trtllm backend" + self.moe_ep_rank = get_moe_expert_parallel_rank() + + if self.quant_config.quant_format != CompressionFormat.pack_quantized.value: + raise ValueError( + f"For Fused MoE layers, only {CompressionFormat.pack_quantized.value} " + "is supported for the mxint4" + ) + self._cache_permute_indices = {} + + @classmethod + def get_min_capability(cls) -> int: + # Requires sm100(blackwell) architecture + return 100 + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + assert ( + params_dtype == torch.bfloat16, + f"Params dtype should be torch.bfloat16, but got: {params_dtype}", + ) + + extra_weight_attrs.update({"quant_method": self.strategy}) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.packed_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.packed_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w2_scales_size = intermediate_size_per_partition + num_groups_w2 = w2_scales_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + w13_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + num_groups_w13, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, num_groups_w2, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + + w13_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w2_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + + # Adapted from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_gen_fused_moe.py + def prepare_static_weights_for_kernel( + self, + gemm1_weights, + gemm2_weights, + gemm1_scales, + gemm2_scales, + num_experts, + ): + """Prepare quantized weights for kernel (done offline with weights).""" + + epilogue_tile_m = 128 + gemm1_weights_mxint4_shuffled = [] + gemm1_scales_shuffled = [] + gemm2_weights_mxint4_shuffled = [] + gemm2_scales_shuffled = [] + + def repack(w): + assert w.dim() == 2 and w.dtype == torch.int32 + shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=w.device) + w = (w.unsqueeze(2) >> shifts) & 0x0F + w = (w - 8).to(torch.int8).reshape(w.shape[0], -1, 2) + w = (w[..., 0] & 0x0F) | ((w[..., 1] & 0x0F) << 4) + w = w.to(torch.uint8) + return w + + for i in range(num_experts): + # NOTE(HandH1998): + # the huggingface weight format follows (w/s + 8) to pack, 
+ # however, trtllm requires (w/s) to pack + # we need to convert the weight to trtllm's format first + cur_expert_gemm1_weight = repack(gemm1_weights[i]) + cur_expert_gemm2_weight = repack(gemm2_weights[i]) + + # Calculate the permute indices for the following: + # 1. Reorder rows of W1 and scales for fused gated activation + # 2. Shuffle weights and scaling factors for transposed mma output + # for both w3_w1 and w2 weights and scale factors + permute_indices = _maybe_get_cached_w3_w1_permute_indices( + self._cache_permute_indices, + cur_expert_gemm1_weight, + epilogue_tile_m, + ) + gemm1_weights_shuffled = cur_expert_gemm1_weight[ + permute_indices.to(gemm1_weights.device) + ].contiguous() + permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices( + self._cache_permute_indices, + gemm1_scales[i].to(torch.bfloat16), + epilogue_tile_m, + num_elts_per_sf=32, + ) + gemm1_scales_shuffled.append( + block_scale_interleave( + gemm1_scales[i] + .to(torch.bfloat16)[permute_sf_indices.to(gemm1_scales.device)] + .contiguous() + ) + ) + + permute_indices = get_w2_permute_indices_with_cache( + self._cache_permute_indices, + cur_expert_gemm2_weight, + epilogue_tile_m, + ) + gemm2_weights_shuffled = cur_expert_gemm2_weight[ + permute_indices.to(gemm2_weights.device) + ].contiguous() + + permute_sf_indices = get_w2_permute_indices_with_cache( + self._cache_permute_indices, + gemm2_scales[i].to(torch.bfloat16), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm2_scales_shuffled.append( + block_scale_interleave( + gemm2_scales[i] + .to(torch.bfloat16)[permute_sf_indices.to(gemm2_scales.device)] + .contiguous() + ) + ) + + block_k = 128 + gemm1_weights_shuffled = convert_to_block_layout( + gemm1_weights_shuffled.view(torch.uint8), block_k + ) + gemm2_weights_shuffled = convert_to_block_layout( + gemm2_weights_shuffled.view(torch.uint8), block_k + ) + + gemm1_weights_mxint4_shuffled.append(gemm1_weights_shuffled) + gemm2_weights_mxint4_shuffled.append(gemm2_weights_shuffled) + + gemm1_weights_mxint4_shuffled = torch.stack(gemm1_weights_mxint4_shuffled) + gemm2_weights_mxint4_shuffled = torch.stack(gemm2_weights_mxint4_shuffled) + gemm1_scales_shuffled = torch.stack(gemm1_scales_shuffled).view(torch.bfloat16) + gemm2_scales_shuffled = torch.stack(gemm2_scales_shuffled).view(torch.bfloat16) + + return ( + gemm1_weights_mxint4_shuffled, + gemm1_scales_shuffled, + gemm2_weights_mxint4_shuffled, + gemm2_scales_shuffled, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + num_experts = layer.w13_weight_packed.shape[0] + ( + gemm1_weights_mxint4_shuffled, + gemm1_scales_shuffled, + gemm2_weights_mxint4_shuffled, + gemm2_scales_shuffled, + ) = self.prepare_static_weights_for_kernel( + layer.w13_weight_packed, + layer.w2_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_scale, + num_experts=num_experts, + ) + replace_parameter(layer, "w13_weight_packed", gemm1_weights_mxint4_shuffled) + replace_parameter(layer, "w2_weight_packed", gemm2_weights_mxint4_shuffled) + replace_parameter(layer, "w13_weight_scale", gemm1_scales_shuffled) + replace_parameter(layer, "w2_weight_scale", gemm2_scales_shuffled) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + assert ( + self.moe_runner_config.is_gated 
+ ), "Only gated MoEs are supported for flashinfer mxint4" + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + router_logits = topk_output.router_logits + topk_config = topk_output.topk_config + correction_bias = ( + None + if topk_config.correction_bias is None + else topk_config.correction_bias.to(x.dtype) + ) + + local_num_experts = self.moe_runner_config.num_local_experts + routing_method_type = layer.routing_method_type + assert routing_method_type is not None + # DeepSeekV3 style routing requires float32 router logits, + # see this PR for details: https://github.com/flashinfer-ai/flashinfer/commit/d84e1d560da0a27961c19ca788d96c19cb9dcfb6 + if routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = router_logits.to(torch.float32) + routed_scaling_factor = self.moe_runner_config.routed_scaling_factor + routed_scaling_factor = ( + routed_scaling_factor if routed_scaling_factor is not None else 1.0 + ) + + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + num_tokens = x.shape[0] + hidden_size = x.shape[-1] + symm_output = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=x.device + ) + + output = trtllm_mxint4_block_scale_moe( + routing_logits=router_logits, # float + routing_bias=correction_bias, + hidden_states=x, + gemm1_weights=layer.w13_weight_packed, + gemm1_weights_scale=layer.w13_weight_scale, + gemm1_alpha=self.moe_runner_config.gemm1_alpha, + gemm1_beta=None, + gemm1_clamp_limit=self.moe_runner_config.gemm1_clamp_limit, + gemm2_weights=layer.w2_weight_packed, + gemm2_weights_scale=layer.w2_weight_scale, + num_experts=self.moe_runner_config.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=self.moe_runner_config.intermediate_size_per_partition, + local_expert_offset=self.moe_ep_rank * local_num_experts, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=routing_method_type, + tune_max_num_tokens=next_power_of_2(x.shape[0]), + output=symm_output, + ) + + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py index 5fe7b1acd001..1072dac7b3ca 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py @@ -13,7 +13,7 @@ PerTensorScaleParameter, ) from sglang.srt.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, + CompressedTensorsLinearScheme, ) from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend from sglang.srt.layers.quantization.modelopt_quant import ( @@ -28,7 +28,7 @@ __all__ = ["CompressedTensorsW4A4Fp4"] -class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): +class CompressedTensorsW4A4Fp4(CompressedTensorsLinearScheme): def __init__(self): self.group_size = 16 diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py new file mode 100644 index 000000000000..01f909a3093f --- /dev/null +++ 
b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py @@ -0,0 +1,421 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch + +from sglang.srt.distributed import get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.layers.dp_attention import is_allocation_symmetric +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType +from sglang.srt.layers.moe.utils import RoutingMethodType, get_moe_runner_backend +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsMoEScheme, +) +from sglang.srt.layers.quantization.fp8_utils import is_blackwell_supported +from sglang.srt.layers.quantization.utils import ( + prepare_static_weights_for_trtllm_fp4_moe, + reorder_w1w3_to_w3w1, + swizzle_blockscale, +) +from sglang.srt.utils import next_power_of_2, set_weight_attrs + +logger = logging.getLogger(__name__) + +__all__ = ["CompressedTensorsW4A4Nvfp4MoE"] + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) + + +class CompressedTensorsW4A4Nvfp4MoE(CompressedTensorsMoEScheme): + + def __init__(self): + if not is_blackwell_supported(): + raise ValueError( + "Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above." + ) + self.group_size = 16 + self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm() + + @classmethod + def get_min_capability(cls) -> int: + # Requires sm100(blackwell) architecture + return 100 + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + layer.params_dtype = params_dtype + + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + requires_grad=False, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Weight Scales + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.group_size, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + 
layer.register_parameter("w2_weight_scale", w2_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # Weight Global Scales + w13_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) + + w2_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) + + # Input Global Scales + w13_input_scale = torch.nn.Parameter( + torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_input_global_scale", w13_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_input_global_scale", w2_input_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # From packed to weight + layer.w13_weight = torch.nn.Parameter( + layer.w13_weight_packed.data, requires_grad=False + ) + delattr(layer, "w13_weight_packed") + + layer.w2_weight = torch.nn.Parameter( + layer.w2_weight_packed.data, requires_grad=False + ) + delattr(layer, "w2_weight_packed") + + if self.use_flashinfer_trtllm: + w, s = reorder_w1w3_to_w3w1( + layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2 + ) + layer.w13_weight = torch.nn.Parameter(w, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False) + + if not torch.allclose( + layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1] + ): + logger.warning_once( + "w1_weight_global_scale must match w3_weight_global_scale. " + "Accuracy may be affected." 
+ ) + + # Take inverse of global scale saved to disk + layer.w13_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w13_weight_global_scale[:, 0], requires_grad=False + ) + + layer.w2_weight_scale_2 = torch.nn.Parameter( + 1 / layer.w2_weight_global_scale.data, requires_grad=False + ) + + # w13 + if self.use_flashinfer_trtllm: + w13_input_global_scale = ( + layer.w13_input_global_scale.min() + .to(torch.float32) + .expand(layer.num_local_experts) + ) + else: + w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to( + torch.float32 + ) + layer.g1_alphas = torch.nn.Parameter( + ((1 / w13_input_global_scale) * layer.w13_weight_scale_2), + requires_grad=False, + ) + + layer.w13_input_scale_quant = torch.nn.Parameter( + (w13_input_global_scale), requires_grad=False + ) + + # w2 + if self.use_flashinfer_trtllm: + w2_input_global_scale = ( + layer.w2_input_global_scale.min() + .to(torch.float32) + .expand(layer.num_local_experts) + ) + else: + w2_input_global_scale = layer.w2_input_global_scale + + layer.g2_alphas = torch.nn.Parameter( + ((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32), + requires_grad=False, + ) + + layer.w2_input_scale_quant = torch.nn.Parameter( + (w2_input_global_scale), requires_grad=False + ) + + # TensorRT-LLM specific processing + if self.use_flashinfer_trtllm: + # Prepare static weights for TRT-LLM kernel + ( + gemm1_weights_fp4_shuffled, + gemm1_scales_fp4_shuffled, + gemm2_weights_fp4_shuffled, + gemm2_scales_fp4_shuffled, + ) = prepare_static_weights_for_trtllm_fp4_moe( + layer.w13_weight, + layer.w2_weight, + layer.w13_weight_scale, + layer.w2_weight_scale, + layer.w2_weight.size(-2), # hidden_size + layer.w13_weight.size(-2) // 2, # intermediate_size + layer.w13_weight.size(0), # num_experts + ) + logger.debug("Finished shuffling weights for TRT-LLM MOE") + + layer.gemm1_weights_fp4_shuffled = torch.nn.Parameter( + gemm1_weights_fp4_shuffled, requires_grad=False + ) + layer.gemm2_weights_fp4_shuffled = torch.nn.Parameter( + gemm2_weights_fp4_shuffled, requires_grad=False + ) + layer.gemm1_scales_fp4_shuffled = torch.nn.Parameter( + gemm1_scales_fp4_shuffled, requires_grad=False + ) + layer.gemm2_scales_fp4_shuffled = torch.nn.Parameter( + gemm2_scales_fp4_shuffled, requires_grad=False + ) + + # Additional parameter needed for TRT-LLM + layer.g1_scale_c = torch.nn.Parameter( + (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), + requires_grad=False, + ) + + # Clean up weights that won't be used by TRT-LLM + del layer.w2_weight + del layer.w2_weight_scale + del layer.w13_weight + del layer.w13_weight_scale + else: + # swizzle weight scales + layer.w13_weight_scale = torch.nn.Parameter( + swizzle_blockscale(layer.w13_weight_scale), requires_grad=False + ) + + layer.w2_weight_scale = torch.nn.Parameter( + swizzle_blockscale(layer.w2_weight_scale), requires_grad=False + ) + + layer.cutlass_moe_params = CutlassMoEParams( + CutlassMoEType.BlockscaledFP4, + layer.w13_weight.device, + num_experts=layer.num_experts, + intermediate_size_per_partition=layer.w2_weight.shape[2] * 2, + hidden_size=layer.w13_weight.shape[2] * 2, + ) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + + def apply_weights( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + from sglang.srt.layers.moe.token_dispatcher import 
StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + if self.use_flashinfer_trtllm: + from flashinfer import fp4_quantize, trtllm_fp4_block_scale_moe + + router_logits = topk_output.router_logits + topk_config = topk_output.topk_config + + # Quantize input hidden states using fp4_quantize + hs_fp4_bytes, hs_sf_bytes = fp4_quantize( + x, + layer.w13_input_scale_quant, + self.group_size, # sf_vec_size + False, # use_ue8m0 + False, # is_sf_swizzled_layout + ) + hs_fp4 = hs_fp4_bytes.reshape(x.shape[0], x.shape[1] // 2) + hs_scale = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1) + + correction_bias = ( + None + if topk_config.correction_bias is None + else topk_config.correction_bias.to(x.dtype) + ) + + assert layer.routing_method_type is not None + + # DeepSeekV3 style routing requires float32 router logits + if layer.routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = router_logits.to(torch.float32) + + routed_scaling_factor = self.moe_runner_config.routed_scaling_factor + routed_scaling_factor = ( + routed_scaling_factor if routed_scaling_factor is not None else 1.0 + ) + + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + num_tokens = hs_fp4.shape[0] + hidden_size = ( + hs_fp4.shape[-1] * 2 + if hs_fp4.dtype == torch.uint8 + else hs_fp4.shape[-1] + ) + symm_output = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device + ) + + output = trtllm_fp4_block_scale_moe( + routing_logits=router_logits, + routing_bias=correction_bias, + hidden_states=hs_fp4, + hidden_states_scale=hs_scale, + gemm1_weights=layer.gemm1_weights_fp4_shuffled, + gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view( + torch.float8_e4m3fn + ), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=layer.gemm2_weights_fp4_shuffled, + gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view( + torch.float8_e4m3fn + ), + gemm2_bias=None, + output1_scale_scalar=layer.g1_scale_c, + output1_scale_gate_scalar=layer.g1_alphas, + output2_scale_scalar=layer.g2_alphas, + num_experts=layer.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=layer.intermediate_size_per_partition, + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=layer.routing_method_type, + do_finalize=True, + tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), + output=symm_output, + )[0] + else: + from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 + + topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids + + output = cutlass_moe_fp4( + a=x, + a1_gscale=layer.w13_input_scale_quant, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_weight_scale, + w1_alphas=layer.g1_alphas, + a2_gscale=layer.w2_input_scale_quant, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_weight_scale, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + params=layer.cutlass_moe_params, + apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input, + ).to(x.dtype) + + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int8_moe.py 
b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int8_moe.py new file mode 100644 index 000000000000..b45b63fc4e44 --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int8_moe.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch + +from sglang.srt.hardware_backend.npu.quantization.fused_moe_method_npu import ( + NPUW4A8Int8DynamicMoEMethod, +) +from sglang.srt.layers.moe import MoeRunnerConfig +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsMoEScheme, +) +from sglang.srt.utils import set_weight_attrs + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) + +__all__ = ["NPUCompressedTensorsW4A8Int8DynamicMoE"] + + +logger = logging.getLogger(__name__) + + +class NPUCompressedTensorsW4A8Int8DynamicMoE(CompressedTensorsMoEScheme): + + ### TODO: Get rid of code duplication with python/sglang/srt/modelslim/modelslim_moe.py @OrangeRedeng @TamirBaydasov + def __init__(self, quantization_config) -> None: + self.group_size = 0 + self.is_per_channel_weight = self.group_size == 0 + self.tp_size = 1 + self.activation_use_clip = ( + quantization_config.get("config_groups", {}) + .get("group_1", {}) + .get("activation_use_clip", False) + ) + self.kernel = NPUW4A8Int8DynamicMoEMethod() + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + self.num_experts = num_experts + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + # >> weight + w13_output_size = intermediate_size_per_partition + w2_output_size = hidden_size // 2 + w13_weight = torch.nn.Parameter( + torch.empty(num_experts, w13_output_size, hidden_size, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + w2_output_size, + intermediate_size_per_partition, + dtype=torch.int8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # >> scale + weight_scale_dtype = torch.int64 if self.activation_use_clip else torch.float32 + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, 1, dtype=weight_scale_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # >> offset + w13_weight_offset = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_offset", w13_weight_offset) + set_weight_attrs(w13_weight_offset, extra_weight_attrs) + + w2_weight_offset = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, 1, 
dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_offset", w2_weight_offset) + set_weight_attrs(w2_weight_offset, extra_weight_attrs) + + # >>> special param for w4a8 + if self.activation_use_clip: + self._init_activation_clip_params( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + extra_weight_attrs, + ) + else: + self._init_extra_scale_params( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + extra_weight_attrs, + ) + + def _init_activation_clip_params( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + extra_weight_attrs: dict, + ) -> None: + """ + Initializes bias and alpha parameters for quantization schemes that use activation clipping. + + This helper registers `w13_bias`, `w2_bias`, and `w2_alpha`, which are required to + shift and scale the activations or outputs to compensate for the precision loss + introduced by clamping activations. + """ + w13_bias = torch.nn.Parameter( + torch.ones( + num_experts, 2 * intermediate_size_per_partition, dtype=torch.float + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, dtype=torch.float), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + w2_alpha = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float), requires_grad=False + ) + layer.register_parameter("w2_alpha", w2_alpha) + set_weight_attrs(w2_alpha, extra_weight_attrs) + + def _init_extra_scale_params( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + extra_weight_attrs: dict, + ) -> None: + """ + Initializes additional scaling, offset, and bias parameters for quantization schemes without activation clipping. + + This method registers the following parameters: + 1. Scale Biases: `w13_scale_bias` and `w2_scale_bias`. + 2. Secondary Quantization Params (initialized only for grouped quantization): + `w13_weight_scale_second`, `w13_weight_offset_second`, + `w2_weight_scale_second`, and `w2_weight_offset_second`. 
+ """ + if not self.is_per_channel_weight: + w13_weight_scale_second = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_second", w13_weight_scale_second) + set_weight_attrs(w13_weight_scale_second, extra_weight_attrs) + + w13_weight_offset_second = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter( + "w13_weight_offset_second", w13_weight_offset_second + ) + set_weight_attrs(w13_weight_offset_second, extra_weight_attrs) + + w2_weight_scale_second = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale_second", w2_weight_scale_second) + set_weight_attrs(w2_weight_scale_second, extra_weight_attrs) + + w2_weight_offset_second = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_offset_second", w2_weight_offset_second) + set_weight_attrs(w2_weight_offset_second, extra_weight_attrs) + + w13_scale_bias = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_scale_bias", w13_scale_bias) + set_weight_attrs(w13_scale_bias, extra_weight_attrs) + + w2_scale_bias = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, 16 // self.tp_size, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w2_scale_bias", w2_scale_bias) + set_weight_attrs(w2_scale_bias, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading( + layer, self.is_per_channel_weight, self.activation_use_clip + ) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + + def apply_weights( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + return self.kernel.apply(layer, dispatch_output) + + def apply_weights_with_router_logits( + self, + layer, + hidden_states, + hidden_states_scale, + group_list_type, + group_list, + output_dtype, + ): + return self.kernel.apply_without_routing_weights( + layer, + hidden_states, + hidden_states_scale, + group_list_type, + group_list, + output_dtype, + ) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py index 35d579de47d9..353f049b9953 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -12,7 +12,7 @@ PerTensorScaleParameter, ) from sglang.srt.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, + CompressedTensorsLinearScheme, ) from sglang.srt.layers.quantization.marlin_utils_fp8 import ( apply_fp8_marlin_linear, @@ -25,7 +25,7 @@ 
SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR] -class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): +class CompressedTensorsW8A16Fp8(CompressedTensorsLinearScheme): def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 423e74132610..b4334481a87c 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -14,7 +14,7 @@ PerTensorScaleParameter, ) from sglang.srt.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, + CompressedTensorsLinearScheme, ) from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_utils import ( @@ -42,7 +42,7 @@ } -class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): +class CompressedTensorsW8A8Fp8(CompressedTensorsLinearScheme): def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): self.weight_quant = weight_quant self.strategy = self.weight_quant.strategy diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8_moe.py new file mode 100644 index 000000000000..71651efb6b0f --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8_moe.py @@ -0,0 +1,384 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch +from compressed_tensors.quantization import QuantizationStrategy + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsMoEScheme, +) +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant +from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.layers.quantization.utils import all_close_1d, per_tensor_dequantize +from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs + +if TYPE_CHECKING: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) + +__all__ = ["CompressedTensorsW8A8Fp8MoE"] + +_is_hip = is_hip() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip + +if _use_aiter: + from aiter import ActivationType, QuantType + from aiter.fused_moe import fused_moe + from aiter.ops.shuffle import shuffle_weight + + +logger = logging.getLogger(__name__) + + +class CompressedTensorsW8A8Fp8MoE(CompressedTensorsMoEScheme): + + def __init__(self, weight_quant, input_quant): + self.weight_quant = weight_quant + self.input_quant = input_quant + + per_tensor = ( + self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy == QuantizationStrategy.TENSOR + ) + per_channel = ( + self.weight_quant.strategy == 
QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN + ) + if not (per_tensor or per_channel): + assert self.weight_quant.strategy == QuantizationStrategy.BLOCK + self.weight_block_size = self.weight_quant.block_structure + assert self.weight_quant.dynamic is not None + else: + self.weight_block_size = None + self.block_quant = self.weight_block_size is not None + + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales and per_channel: + raise ValueError( + "For FP8 Fused MoE layer, we require either per tensor or " + "channelwise, dynamic per token quantization." + ) + + @classmethod + def get_min_capability(cls) -> int: + # ampere and up + return 80 + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + params_dtype = torch.float8_e4m3fn + + if self.block_quant: + assert self.weight_block_size is not None + layer.weight_block_size = self.weight_block_size + tp_size = get_tensor_model_parallel_world_size() + block_n, block_k = ( + self.weight_block_size[0], + self.weight_block_size[1], + ) + # NOTE: To ensure proper alignment of the block-wise quantization + # scales, the output_size of the weights for both the gate and up + # layers must be divisible by block_n. + # Required by column parallel or enabling merged weights + if intermediate_size_per_partition % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1 and intermediate_size_per_partition % block_k != 0: + # Required by row parallel + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # per-tensor quantization + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
+ w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value + elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value + elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size_per_partition + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + weight_quant_method = FusedMoeWeightScaleSupported.BLOCK.value + else: + raise ValueError( + f"Unsupported weight quantization strategy: {self.weight_quant.strategy}" + ) + + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update({"quant_method": weight_quant_method}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + assert ( + self.input_quant.strategy == QuantizationStrategy.TENSOR + ), "Only per-tensor quantization is supported for static input scales" + w13_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): + logger.warning( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer." 
+ ) + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False + ) + + if is_fp8_fnuz(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False + ) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False + ) + if self.weight_quant.strategy == QuantizationStrategy.TENSOR: + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.num_local_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], + ) + ( + layer.w13_weight[expert_id][start : start + shard_size, :], + _, + ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) + + if self.weight_quant.strategy == QuantizationStrategy.CHANNEL and _use_aiter: + with torch.no_grad(): + # Pre-shuffle weights + layer.w13_weight = torch.nn.Parameter( + shuffle_weight(layer.w13_weight.data, (16, 16)), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + shuffle_weight(layer.w2_weight.data, (16, 16)), + requires_grad=False, + ) + torch.cuda.empty_cache() + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + + def apply_weights( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + moe_runner_config = self.moe_runner_config + + if _use_aiter and self.weight_quant.strategy == QuantizationStrategy.CHANNEL: + assert not moe_runner_config.no_combine, "unsupported" + topk_weights, topk_ids, _ = topk_output + if moe_runner_config.apply_router_weight_on_input: + assert ( + topk_weights.dim() == 2 + ), "`topk_weights` should be in shape (num_tokens, topk)" + _, topk = topk_weights.shape + assert ( + topk == 1 + ), "Only support topk=1 when `apply_router_weight_on_input` is True" + x = x * topk_weights.to(x.dtype) + topk_weights = torch.ones_like( + topk_weights, dtype=torch.float32 + ) # topk_weights must be FP32 (float32) + 
output = fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=( + ActivationType.Silu + if moe_runner_config.activation == "silu" + else ActivationType.Gelu + ), + quant_type=QuantType.per_Token, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + return StandardCombineInput(hidden_states=output) + elif self.weight_quant.strategy == QuantizationStrategy.BLOCK: + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, + use_fp8_w8a8=True, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a13_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.weight_block_size, + ) + return self.runner.run(dispatch_output, quant_info) + else: + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight, + w2_weight=layer.w2_weight, + use_fp8_w8a8=True, + per_channel_quant=self.weight_quant.strategy + == QuantizationStrategy.CHANNEL, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a13_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + return self.runner.run(dispatch_output, quant_info) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index efcd4b611fa9..05c5410b5751 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -16,18 +16,20 @@ PerTensorScaleParameter, ) from sglang.srt.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, + CompressedTensorsLinearScheme, ) from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 from sglang.srt.layers.quantization.utils import requantize_with_max_scale from sglang.srt.utils import is_cuda +__all__ = ["CompressedTensorsW8A8Int8", "NPUCompressedTensorsW8A8Int8"] + _is_cuda = is_cuda() if _is_cuda: from sgl_kernel import int8_scaled_mm -class CompressedTensorsW8A8Int8(CompressedTensorsScheme): +class CompressedTensorsW8A8Int8(CompressedTensorsLinearScheme): def __init__( self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8_moe.py new file mode 100644 index 000000000000..b391a1c6999a --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8_moe.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch +from compressed_tensors.quantization import QuantizationStrategy + +from sglang.srt.hardware_backend.npu.quantization.fused_moe_method_npu import ( + NPUW8A8Int8DynamicMoEMethod, +) +from sglang.srt.layers.moe import MoeRunnerConfig +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsMoEScheme, +) +from sglang.srt.utils import set_weight_attrs + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) + +__all__ = ["NPUCompressedTensorsW8A8Int8DynamicMoE"] + +logger = logging.getLogger(__name__) + + 
+class NPUCompressedTensorsW8A8Int8DynamicMoE(CompressedTensorsMoEScheme): + + def __init__(self, weight_quant, input_quant): + self.weight_quant = weight_quant + self.input_quant = input_quant + self.kernel = NPUW8A8Int8DynamicMoEMethod() + + self.static_input_scales = not self.input_quant.dynamic + per_channel = ( + self.weight_quant.strategy == QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN + ) + if not per_channel: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found " + f"{self.weight_quant}, {self.input_quant}" + ) + + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found static input scales." + ) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + params_dtype = torch.int8 + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. 
+ extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + assert not self.static_input_scales + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + + def apply_weights( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + return self.kernel.apply(layer, dispatch_output) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 415bb79602b3..15375212c30b 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -19,7 +19,7 @@ permute_param_layout_, ) from sglang.srt.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme, + CompressedTensorsLinearScheme, ) from sglang.srt.layers.quantization.marlin_utils import ( MarlinLinearLayerConfig, @@ -59,7 +59,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) -class CompressedTensorsWNA16(CompressedTensorsScheme): +class CompressedTensorsWNA16(CompressedTensorsLinearScheme): _kernel_backends_being_used: set[str] = set() def __init__(self, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py new file mode 100644 index 000000000000..6264f36d0459 --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16_moe.py @@ -0,0 +1,621 @@ +from __future__ import annotations + +import enum +import logging +from enum import Enum +from typing import TYPE_CHECKING + +import torch +from compressed_tensors import CompressionFormat + +from sglang.srt.hardware_backend.npu.quantization.fused_moe_method_npu import ( + NPUW4A16Int4DynamicMoEMethod, +) +from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS, + CompressedTensorsMoEScheme, +) +from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack +from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales +from sglang.srt.layers.quantization.utils import replace_parameter +from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, set_weight_attrs + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) + from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( + CompressedTensorsConfig, + ) + + +__all__ = [ + "CompressedTensorsWNA16MoE", + "CompressedTensorsWNA16TritonMoE", + "NPUCompressedTensorsW4A16Int4DynamicMoE", +] + +_is_hip = is_hip() +_is_cuda = is_cuda() + +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip + +if _use_aiter: + pass + + +logger = logging.getLogger(__name__) + + +class GPTQMarlinState(Enum): + REPACK = 
enum.auto() + READY = enum.auto() + + +class CompressedTensorsWNA16MoE(CompressedTensorsMoEScheme): + + def __init__(self, quant_config: CompressedTensorsConfig, num_gpu_experts=-1): + self.quant_config = quant_config + config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy + self.group_size = config.group_size + self.actorder = config.actorder + assert config.symmetric, "Only symmetric quantization is supported for MoE" + + if not ( + self.quant_config.quant_format == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS + ): + raise ValueError( + "For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}", + ) + self.num_gpu_experts = num_gpu_experts + + @classmethod + def get_min_capability(cls) -> int: + # ampere and up + return 80 + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. Will + # shard for TP along the transposed dims + extra_weight_attrs.update( + {"is_transposed": True, "quant_method": self.strategy} + ) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.packed_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # In the case where we have actorder/g_idx, + # we do not partition the w2 scales + load_full_w2 = self.actorder and self.group_size != -1 + + if load_full_w2: + w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size + else: + w2_scales_size = intermediate_size_per_partition + + self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1 + + if self.strategy == "channel": + num_groups_w2 = num_groups_w13 = 1 + self.group_size = -1 + else: + num_groups_w2 = w2_scales_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + w13_scale = torch.nn.Parameter( + torch.ones( + num_experts, + num_groups_w13, + 2 * intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2}) + + w2_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + 
set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + layer.marlin_state = GPTQMarlinState.REPACK + + if not hasattr(layer, "_original_shapes"): + layer._original_shapes = {} + + # Force record: these are the target GPTQ shapes for rollback. + layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape) + layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape) + + # Also record the shapes of the scales. + layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape) + layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # Skip if the layer is already converted to Marlin format to prevent double-packing. + if getattr(layer, "is_marlin_converted", False): + return + + if not hasattr(layer, "_original_shapes"): + layer._original_shapes = {} + + def replace_tensor(name, new_t): + target_attr = getattr(layer, name) + + # Only save if the key doesn't exist to prevent overwriting with Marlin shapes. + if name not in layer._original_shapes: + # This is a safety check; `create_weights` usually handles this already. 
+ layer._original_shapes[name] = tuple(target_attr.shape) + + # It is important to use resize_() here since it ensures + # the same buffer is reused + target_attr.resize_(new_t.shape) + target_attr.copy_(new_t) + del new_t + + num_experts = layer.w13_weight_g_idx.shape[0] + device = layer.w13_weight_g_idx.device + + # when running models with grouped act order, + # resort to g_idx values provided in checkpoint + if self.actorder == "group": + w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx) + + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_weight_g_idx[e]).to( + torch.int32 + ) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_weight_g_idx[e]).to( + torch.int32 + ) + w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][ + w13_g_idx_sort_indices[e] + ] + w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][w2_g_idx_sort_indices[e]] + + replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx) + replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx) + replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) + + else: + layer.w13_weight_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_weight_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = gptq_marlin_moe_repack( + layer.w13_weight_packed, + layer.w13_g_idx_sort_indices, + layer.w13_weight_packed.shape[1] * self.packed_factor, + layer.w13_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w13_weight_packed", marlin_w13_qweight) + marlin_w2_qweight = gptq_marlin_moe_repack( + layer.w2_weight_packed, + layer.w2_g_idx_sort_indices, + layer.w2_weight_packed.shape[1] * self.packed_factor, + layer.w2_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w2_weight_packed", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + layer.w13_weight_scale, + layer.w13_weight_packed.shape[2], + layer.w13_weight_scale.shape[2], + self.group_size, + ) + replace_tensor("w13_weight_scale", marlin_w13_scales) + + marlin_w2_scales = marlin_moe_permute_scales( + layer.w2_weight_scale, + layer.w2_weight_scale.shape[1] + * (self.group_size if self.group_size != -1 else self.packed_factor), + layer.w2_weight_scale.shape[2], + self.group_size, + ) + replace_tensor("w2_weight_scale", marlin_w2_scales) + + layer.is_marlin_converted = True + + def restore_weights_before_loading(self, layer: torch.nn.Module): + """Forcibly resize parameters back to their original shapes (e.g., GPTQ format) before loading weights.""" + + if not hasattr(layer, "_original_shapes"): + return + + for name, orig_shape in layer._original_shapes.items(): + param = getattr(layer, name, None) + + if param is not None and param.shape != orig_shape: + param.resize_(orig_shape) + + layer.is_marlin_converted = False + + def create_moe_runner( + self, 
layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + + def apply_weights( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import ( + fused_marlin_moe, + ) + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + assert ( + self.moe_runner_config.activation == "silu" + ), "Only SiLU activation is supported." + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + topk_weights, topk_ids, router_logits = topk_output + + # Get expert_map for EP support + expert_map = None + global_num_experts = -1 + if hasattr(layer, "dispatcher") and hasattr( + layer.dispatcher, "local_expert_mapping" + ): + expert_map = layer.dispatcher.local_expert_mapping + if expert_map is not None: + global_num_experts = self.moe_runner_config.num_experts + + output = fused_marlin_moe( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + global_num_experts=global_num_experts, + expert_map=expert_map, + g_idx1=layer.w13_weight_g_idx, + g_idx2=layer.w2_weight_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + num_bits=self.num_bits, + is_k_full=self.is_k_full, + routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, + ) + return StandardCombineInput(hidden_states=output) + + +class CompressedTensorsWNA16TritonMoE(CompressedTensorsWNA16MoE): + """ROCm/HIP-compatible W4A16 MoE method using Triton kernels instead of Marlin. + + Inherits weight creation from CompressedTensorsWNA16MoE but converts + weights to the uint8-packed format expected by the Triton fused MoE kernel + instead of the Marlin-specific format. 
+ """ + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if getattr(layer, "is_triton_converted", False): + return + + num_experts = layer.w13_weight_packed.shape[0] + + # Convert w13 weights: [E, K//8, N] int32 -> [E, N, K//2] uint8 + w13 = layer.w13_weight_packed.data + w13 = w13.transpose(1, 2).contiguous().view(torch.uint8) + layer.w13_weight_packed = torch.nn.Parameter(w13, requires_grad=False) + + # Convert w2 weights: [E, K//8, N] int32 -> [E, N, K//2] uint8 + w2 = layer.w2_weight_packed.data + w2 = w2.transpose(1, 2).contiguous().view(torch.uint8) + layer.w2_weight_packed = torch.nn.Parameter(w2, requires_grad=False) + + # Convert w13 scales: [E, K//group_size, N] -> [E, N, K//group_size] + w13_scale = layer.w13_weight_scale.data + w13_scale = w13_scale.transpose(1, 2).contiguous() + layer.w13_weight_scale = torch.nn.Parameter(w13_scale, requires_grad=False) + + # Convert w2 scales: [E, K//group_size, N] -> [E, N, K//group_size] + w2_scale = layer.w2_weight_scale.data + w2_scale = w2_scale.transpose(1, 2).contiguous() + layer.w2_weight_scale = torch.nn.Parameter(w2_scale, requires_grad=False) + + layer.is_triton_converted = True + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + + def apply_weights( + self, + layer: torch.nn.Module, + dispatch_output: "StandardDispatchOutput", + ) -> "CombineInput": + from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo + + assert ( + self.moe_runner_config.activation == "silu" + ), "Only SiLU activation is supported." + + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight_packed, + w2_weight=layer.w2_weight_packed, + use_int4_w4a16=True, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + block_shape=[0, self.group_size], + ) + return self.runner.run(dispatch_output, quant_info) + + +class NPUCompressedTensorsW4A16Int4DynamicMoE(CompressedTensorsMoEScheme): + + def __init__(self, quantization_config) -> None: + self.pack_factor = 8 # weight dtype is int4, but use int32 to create + target = ( + "MoEGMM" if "MoEGMM" in quantization_config.target_scheme_map else "Linear" + ) + if target in quantization_config.target_scheme_map: + self.group_size = quantization_config.target_scheme_map[target][ + "weights" + ].group_size + else: + self.group_size = 128 + + self.kernel = NPUW4A16Int4DynamicMoEMethod() + + # TODO: See if we can merge this method's logic + # with CompressedTensorsWNA16MoE. Need more models and tests. 
+ # @OrangeRedeng @TamirBaydasov + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + self.num_experts = num_experts + if ( + extra_weight_attrs.get( + "intermediate_size_full", intermediate_size_per_partition + ) + // intermediate_size_per_partition + > 1 + ): + quant_method = FusedMoeWeightScaleSupported.GROUP.value + else: + quant_method = FusedMoeWeightScaleSupported.CHANNEL.value + extra_weight_attrs.update({"quant_method": quant_method}) + # weight + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # scale + weight_scale_dtype = torch.bfloat16 + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # offset + w13_weight_offset = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_offset", w13_weight_offset) + set_weight_attrs(w13_weight_offset, extra_weight_attrs) + + w2_weight_offset = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_offset", w2_weight_offset) + set_weight_attrs(w2_weight_offset, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + + def apply_weights( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + return self.kernel.apply(layer, dispatch_output) + + def apply_without_routing_weights( + self, + layer, + hidden_states, + hidden_states_scale, + group_list_type, + group_list, + output_dtype, + ): + return self.kernel.apply_without_routing_weights( + layer, + hidden_states, + hidden_states_scale, + group_list_type, + group_list, + output_dtype, + )
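For reference, a minimal sketch of the packed-weight layout conversion that CompressedTensorsWNA16TritonMoE.process_weights_after_loading performs; the sizes below are illustrative placeholders rather than values from this change. Each int32 packs eight int4 values, so transposing and reinterpreting the packed tensor as uint8 turns [E, K//8, N] into [E, N, K//2]:

    import torch

    E, K, N = 2, 64, 32                                      # illustrative expert / input / output sizes
    packed = torch.zeros(E, K // 8, N, dtype=torch.int32)    # GPTQ-style int4 weights packed into int32
    triton_fmt = packed.transpose(1, 2).contiguous().view(torch.uint8)
    print(triton_fmt.shape)                                  # torch.Size([2, 32, 32]) == [E, N, K // 2]

The view(torch.uint8) call only reinterprets the underlying bytes (4 uint8 per int32), which is why the last dimension grows from K // 8 to K // 2 without copying or unpacking the int4 values.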