diff --git a/python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py b/python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py index fbddaab40632..9572c6a3821b 100644 --- a/python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py +++ b/python/sglang/multimodal_gen/runtime/layers/vocab_parallel_embedding.py @@ -438,8 +438,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert loaded_weight.shape[output_dim] == ( self.org_vocab_size // param.packed_factor ) - start_idx = start_idx // packed_factor - shard_size = shard_size // packed_factor + start_idx = round(start_idx // packed_factor) + shard_size = round(shard_size // packed_factor) else: assert loaded_weight.shape[output_dim] == self.org_vocab_size diff --git a/python/sglang/multimodal_gen/runtime/models/parameter.py b/python/sglang/multimodal_gen/runtime/models/parameter.py index ba9b42c664a8..5a30e3b809e7 100644 --- a/python/sglang/multimodal_gen/runtime/models/parameter.py +++ b/python/sglang/multimodal_gen/runtime/models/parameter.py @@ -418,6 +418,6 @@ def permute_param_layout_( def _adjust_shard_indexes_for_packing( shard_size, shard_offset, packed_factor ) -> tuple[Any, Any]: - shard_size = shard_size // packed_factor - shard_offset = shard_offset // packed_factor + shard_size = round(shard_size // packed_factor) + shard_offset = round(shard_offset // packed_factor) return shard_size, shard_offset diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 5a3997deeac9..442c2f0c7919 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -962,6 +962,7 @@ def _verify_quantization(self) -> None: "petit_nvfp4", "quark", "modelslim", + "humming", ] compatible_quantization_methods = { "modelopt_fp8": ["modelopt"], diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 949af5c320ac..7468a03a8739 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -1,3 +1,4 @@ +import json import os import subprocess import warnings @@ -110,6 +111,16 @@ def parse(self, value: str) -> str: return value +class EnvJSON(EnvField): + def parse(self, value: str | None) -> list | dict | None: + if not value: + return None + if os.path.exists(value): + with open(value) as f: + return json.load(f) + return json.loads(value) + + class EnvBool(EnvField): def parse(self, value: str) -> bool: value = value.lower() @@ -307,6 +318,12 @@ class Envs: SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2 = EnvBool(False) SGLANG_NVFP4_CKPT_FP8_NEXTN_MOE = EnvBool(False) + # Quantization (Humming) + SGLANG_HUMMING_ONLINE_QUANT_CONFIG = EnvJSON(None) + SGLANG_HUMMING_INPUT_QUANT_CONFIG = EnvJSON(None) + SGLANG_HUMMING_USE_F16_ACCUM = EnvBool(False) + SGLANG_HUMMING_MOE_GEMM_TYPE = EnvStr("") + # Flashinfer SGLANG_IS_FLASHINFER_AVAILABLE = EnvBool(True) SGLANG_ENABLE_FLASHINFER_FP8_GEMM = EnvBool(False) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 1e48d7f0b799..153220e312bb 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -66,6 +66,7 @@ "IPEXAWQLinearMethod", "PetitNvFp4LinearMethod", "QuarkInt4Fp8LinearMethod", + "HummingLinearMethod", ] _is_cpu = is_cpu() @@ -209,6 +210,7 @@ def __init__( # All the linear layer supports quant method. 
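+        # (LinearBase falls back to UnquantizedLinearMethod when quant_config
+        # is None, so the assert below only guards misconfigured quant configs.)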
assert self.quant_method is not None + self.with_bias = bias self.quant_method.create_weights( self, self.input_size, @@ -315,6 +317,7 @@ def __init__( input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix ) + self.with_bias = bias self.gather_output = gather_output self.use_presharded_weights = use_presharded_weights @@ -502,6 +505,7 @@ def __init__( tp_size: Optional[int] = None, use_presharded_weights: bool = False, ): + self.with_bias = bias self.output_sizes = output_sizes if tp_rank is None: tp_rank = get_tensor_model_parallel_rank() @@ -589,8 +593,8 @@ def weight_loader( # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = round(shard_size // param.pack_factor) + shard_offset = round(shard_offset // param.pack_factor) # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset @@ -622,8 +626,8 @@ def weight_loader( # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = round(shard_size // param.pack_factor) + shard_offset = round(shard_offset // param.pack_factor) # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset @@ -825,6 +829,7 @@ def __init__( v_head_size: Optional[int] = None, skip_block_quant_check: bool = False, ): + self.with_bias = bias self.hidden_size = hidden_size self.head_size = head_size self.v_head_size = v_head_size if v_head_size is not None else head_size @@ -1086,8 +1091,8 @@ def weight_loader( # If quantized, we need to adjust the offset and size to account # for the packing. if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = round(shard_size // param.pack_factor) + shard_offset = round(shard_offset // param.pack_factor) # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( @@ -1143,8 +1148,8 @@ def weight_loader( # for the packing. packed_dim = getattr(param, "packed_dim", None) if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor + shard_size = round(shard_size // param.pack_factor) + shard_offset = round(shard_offset // param.pack_factor) # Special case for Marlin. 
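+                    # Marlin repacks weights into tiles, so adjust_marlin_shard
+                    # additionally rescales the shard size/offset by the tile size.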
shard_size, shard_offset = adjust_marlin_shard( @@ -1272,6 +1277,7 @@ def __init__( input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix ) + self.with_bias = bias self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 044c590f2200..98611df63ff5 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -1381,3 +1381,72 @@ def silu_and_mul_masked_post_per_tensor_quant_fwd( NUM_STAGE=NUM_STAGES, ) return output + + +def moe_permute( + inputs: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + use_int64_offset: bool = False, + is_ep: bool = False, + outputs: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + from sgl_kernel import moe_permute_prepare + + expert_offsets, src2dst = moe_permute_prepare( + topk_ids=topk_ids, + num_experts=num_experts, + use_int64_offset=use_int64_offset, + is_ep=is_ep, + ) + output_shape = (topk_ids.nelement(), inputs.size(-1)) + if outputs is None: + outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device) + + assert outputs.shape == output_shape + assert outputs.dtype == inputs.dtype + assert outputs.device == inputs.device + + deepep_permute_triton_kernel[(inputs.shape[0],)]( + inputs, + outputs, + src2dst, + topk_ids, + None, + topk_ids.size(1), + inputs.size(1), + BLOCK_SIZE=512, + ) + + return outputs, src2dst, expert_offsets + + +def moe_unpermute( + inputs: torch.Tensor, + src2dst: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + outputs: torch.Tensor | None = None, +) -> torch.Tensor: + num_tokens = topk_ids.size(0) + output_shape = (num_tokens, inputs.size(1)) + if outputs is None: + outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device) + + assert outputs.shape == output_shape + assert outputs.dtype == inputs.dtype + assert outputs.device == inputs.device + + deepep_post_reorder_triton_kernel[(num_tokens,)]( + inputs, + outputs, + src2dst, + topk_ids, + topk_weights, + topk_ids.size(1), + inputs.size(1), + BLOCK_SIZE=512, + ) + + assert outputs is not None + return outputs diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 552346cb0f6b..656867d25bfa 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -95,7 +95,16 @@ def __init__( routed_scaling_factor=routed_scaling_factor, **kwargs, ) - if _use_aiter or _is_npu: + is_humming = ( + get_moe_runner_backend().is_humming() + or get_moe_runner_backend().is_auto() + and quant_config is not None + and quant_config.get_name() == "humming" + ) + if is_humming: + envs.SGLANG_DEEPEP_BF16_DISPATCH.set(True) + self.deprecate_flag = True + elif _use_aiter or _is_npu: self.deprecate_flag = False elif deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and isinstance( quant_config, Fp8Config diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 8154ca9ff65f..2b2084e4aee3 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -191,10 +191,13 @@ def __init__( if params_dtype is None: params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + self.layer_name = prefix self.layer_id = layer_id self.top_k = top_k 
self.hidden_size = hidden_size self.num_experts = num_experts + self.with_bias = with_bias self.num_fused_shared_experts = num_fused_shared_experts self.enable_flashinfer_cutlass_moe = ( diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py b/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py index 2840eb8fc66e..78080904b6e7 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py @@ -15,7 +15,10 @@ def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, num_experts: int + topk_ids: torch.Tensor, + block_size: int, + num_experts: int, + ignore_invalid_expert: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block @@ -81,5 +84,6 @@ def moe_align_block_size( num_tokens_post_pad, cumsum_buffer, True, + ignore_invalid_expert, ) return sorted_ids, expert_ids, num_tokens_post_pad diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/moe_fused_mul_sum.py b/python/sglang/srt/layers/moe/fused_moe_triton/moe_fused_mul_sum.py new file mode 100644 index 000000000000..c3efa0d59642 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/moe_fused_mul_sum.py @@ -0,0 +1,210 @@ +import torch +import triton +import triton.language as tl +from torch._subclasses.fake_tensor import FakeTensor + + +@triton.jit +def moe_fused_mul_sum_kernel( + inputs_ptr, + topk_weights_ptr, + outputs_ptr, + top_ids_ptr, + expert_map_ptr, + num_tokens, + stride_m, + has_expert_map: tl.constexpr, + is_ep: tl.constexpr, + top_k: tl.constexpr, + size: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_k = tl.program_id(0) + pid_m = tl.program_id(1) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K) + + m_mask = offs_m < num_tokens + k_mask = offs_k < size + mask = m_mask[:, None] & k_mask[None, :] + + a_base = inputs_ptr + (offs_m * stride_m)[:, None] + offs_k[None, :] + b_base = topk_weights_ptr + offs_m * top_k + + acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32) + + for n in tl.static_range(top_k): + b_val = tl.load(b_base + n, mask=m_mask, other=0.0).to(tl.float32) + if has_expert_map: + id_val = tl.load(top_ids_ptr + offs_m * top_k + n, mask=m_mask, other=0) + expert_mask = (id_val >= 0) & (tl.load(expert_map_ptr + id_val, mask=id_val >= 0, other=-1) >= 0) + a_vec = tl.load( + a_base + n * size, + mask=mask & expert_mask[:, None], + other=0.0, + ).to(tl.float32) + elif is_ep: + id_val = tl.load(top_ids_ptr + offs_m * top_k + n, mask=m_mask, other=0) + expert_mask = id_val >= 0 + a_vec = tl.load( + a_base + n * size, + mask=mask & expert_mask[:, None], + other=0.0, + ).to(tl.float32) + else: + a_vec = tl.load( + a_base + n * size, + mask=mask, + other=0.0, + ).to(tl.float32) + acc += a_vec * b_val[:, None] + + out_ptrs = outputs_ptr + (offs_m * size)[:, None] + offs_k[None, :] + tl.store( + out_ptrs, + acc.to(outputs_ptr.dtype.element_ty), + mask=mask, + ) + + +def _heuristic_config( + num_tokens: int, + top_k: int, + size: int, + element_size: int, +): + is_fp32 = element_size > 2 + is_sm90_plus = torch.cuda.get_device_capability() >= (9, 0) + is_sm80_before = torch.cuda.get_device_capability() < (8, 0) + + if is_sm90_plus: + # SM90/SM100+: prefer small tiles + many CTAs. 
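+        # Keeping BLOCK_M in {1, 2, 4} splits even small batches across many
+        # CTAs; BLOCK_K (capped below) then controls how wide each tile is.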
+ if is_fp32: + BLOCK_M = 1 if num_tokens <= 4 else 2 + else: + if num_tokens <= 4: + BLOCK_M = 1 + elif num_tokens <= 128: + BLOCK_M = 2 + else: + BLOCK_M = 4 + elif is_fp32: + if num_tokens <= 4: + BLOCK_M = 1 + elif num_tokens <= 32: + BLOCK_M = 2 + elif num_tokens <= 128: + BLOCK_M = 4 + else: + BLOCK_M = 4 + else: + if num_tokens <= 4: + BLOCK_M = 1 + elif num_tokens <= 32: + BLOCK_M = 2 + elif num_tokens <= 128: + BLOCK_M = 4 + elif num_tokens <= 1024: + BLOCK_M = 16 + else: + BLOCK_M = 8 + + if is_fp32: + max_block_k = 256 + elif is_sm80_before or is_sm90_plus: + max_block_k = 512 + else: + max_block_k = 1024 + BLOCK_K = min(triton.next_power_of_2(size), max_block_k) + BLOCK_K = max(BLOCK_K, 256) + + total = BLOCK_M * BLOCK_K + if is_fp32: + num_warps = max(8, min(16, total // 64)) + else: + num_warps = max(4, min(16, total // 256)) + + if is_sm80_before: + num_warps = min(num_warps, 8) + num_stages = 2 + elif is_sm90_plus: + num_warps = min(num_warps, 8) + num_stages = 4 if total <= 2048 else 2 + else: + num_stages = 4 if total <= 2048 else 2 + + return BLOCK_M, BLOCK_K, num_warps, num_stages + + +def moe_fused_mul_sum( + inputs: torch.Tensor, + topk_weights: torch.Tensor, + outputs: torch.Tensor | None = None, + topk_ids: torch.Tensor | None = None, + expert_map: torch.Tensor | None = None, + is_ep: bool = False, +) -> torch.Tensor: + """ + Fused kernel for MoE (Mixture of Experts) to perform weighted summation + of expert outputs. + + Args: + inputs: The output from experts. + Shape: (num_tokens, top_k, hidden_size). + topk_weights: The weights assigned to each expert for each token. + Shape: (num_tokens, top_k). + outputs: Optional pre-allocated output tensor. + Shape: (num_tokens, hidden_size). + topk_ids: Optional indices of the top-k experts. Used when + `expert_map` is provided. Shape: (num_tokens, top_k). + expert_map: Optional mapping for Expert Parallelism. A value < 0 + indicates an invalid token/expert pair that will be skipped. + + Returns: + The fused weighted sum of expert outputs. + Shape: (num_tokens, hidden_size). 
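+
+    Note:
+        With `is_ep=True`, top-k entries whose `topk_ids` are negative are
+        treated as invalid and contribute zero; `topk_ids` must then be given.
+
+    Example (illustrative; matches `(inputs * topk_weights[..., None]).sum(1)`
+    in the dense case, accumulated in fp32):
+        >>> x = torch.randn(8, 2, 128, device="cuda", dtype=torch.bfloat16)
+        >>> w = torch.rand(8, 2, device="cuda", dtype=torch.bfloat16)
+        >>> out = moe_fused_mul_sum(x, w)  # shape: (8, 128)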
+ """ + assert inputs.ndim == 3 + assert topk_weights.ndim == 2 + assert inputs.is_contiguous() + assert topk_weights.is_contiguous() + assert inputs.dtype in (torch.float32, torch.float16, torch.bfloat16) + assert topk_weights.dtype in (torch.float32, torch.float16, torch.bfloat16) + + num_tokens, top_k, size = inputs.shape + output_shape = (num_tokens, size) + if outputs is None: + outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device) + + assert outputs.shape == output_shape + assert topk_weights.shape == (num_tokens, top_k) + + if not isinstance(inputs, FakeTensor): + BLOCK_M, BLOCK_K, num_warps, num_stages = _heuristic_config( + num_tokens, + top_k, + size, + inputs.element_size(), + ) + grid = (triton.cdiv(size, BLOCK_K), triton.cdiv(num_tokens, BLOCK_M)) + moe_fused_mul_sum_kernel[grid]( + inputs, + topk_weights, + outputs, + topk_ids, + expert_map, + num_tokens, + top_k * size, + expert_map is not None, + is_ep, + top_k, + size, + BLOCK_M, + BLOCK_K, + num_warps=num_warps, + num_stages=num_stages, + ) + + return outputs diff --git a/python/sglang/srt/layers/moe/moe_runner/base.py b/python/sglang/srt/layers/moe/moe_runner/base.py index 088fbfbef33d..971d6a2d1eb4 100644 --- a/python/sglang/srt/layers/moe/moe_runner/base.py +++ b/python/sglang/srt/layers/moe/moe_runner/base.py @@ -50,6 +50,9 @@ class MoeRunnerConfig: gemm1_clamp_limit: Optional[float] = None swiglu_limit: Optional[float] = None + # MoE Layer + layer: Optional[torch.nn.Module] = None + @dataclass class RunnerInput(ABC): diff --git a/python/sglang/srt/layers/moe/moe_runner/humming.py b/python/sglang/srt/layers/moe/moe_runner/humming.py new file mode 100644 index 000000000000..bc8588f30384 --- /dev/null +++ b/python/sglang/srt/layers/moe/moe_runner/humming.py @@ -0,0 +1,789 @@ +from __future__ import annotations + +import json +import logging +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional +from weakref import WeakValueDictionary + +import torch +from sglang.srt.debug_utils.deepseek_v4_debug_utils import ( + deepseek_v4_moe_code_path_checker, +) +from sglang.srt.environ import envs +from sglang.srt.layers.moe.ep_moe.kernels import moe_permute, moe_unpermute +from sglang.srt.layers.moe.fused_moe_triton.moe_fused_mul_sum import moe_fused_mul_sum +from sglang.srt.layers.moe.moe_runner.base import ( + MoeQuantInfo, + MoeRunnerConfig, + MoeRunnerCore, + RunnerInput, + RunnerOutput, + register_fused_func, + register_post_permute, + register_pre_permute, +) +from sglang.srt.layers.moe.utils import MoeRunnerBackend +from sglang.srt.utils.custom_op import register_custom_op + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher.deepep import ( + DeepEPLLCombineInput, + DeepEPLLDispatchOutput, + DeepEPNormalCombineInput, + DeepEPNormalDispatchOutput, + ) + from sglang.srt.layers.moe.token_dispatcher.standard import ( + StandardCombineInput, + StandardDispatchOutput, + ) + + +logger = logging.getLogger(__name__) + + +try: + from humming import dtypes + from humming.config import GemmType as HummingGemmType + from humming.layer import HummingMethod + + _humming_available = True +except ModuleNotFoundError: + _humming_available = False + + +def get_standard_humming_moe_gemm_type() -> HummingGemmType: + env_gemm_type_str = envs.SGLANG_HUMMING_MOE_GEMM_TYPE.get().lower() + if env_gemm_type_str == "grouped": + gemm_type = HummingGemmType.GROUPED_CONTIGUOUS + elif env_gemm_type_str == "indexed": + gemm_type = HummingGemmType.INDEXED + else: + gemm_type = 
HummingGemmType.INDEXED + + logger.info_once(f"Using {gemm_type.value} gemm for humming moe") + + return gemm_type + + +@dataclass +class HummingRunnerInput(RunnerInput): + hidden_states: torch.Tensor + topk_weights: torch.Tensor + topk_ids: torch.Tensor + gemm_type: HummingGemmType + expert_num_tokens: torch.Tensor | None = None + expected_m: int | None = None + + @property + def runner_backend(self) -> MoeRunnerBackend: + return MoeRunnerBackend.HUMMING + + +@dataclass +class HummingRunnerOutput(RunnerOutput): + hidden_states: torch.Tensor + + @property + def runner_backend(self) -> MoeRunnerBackend: + return MoeRunnerBackend.HUMMING + + +@dataclass +class HummingMoeQuantInfo(MoeQuantInfo): + pass + + +@register_custom_op() +def humming_moe_runner_core_run( + moe_runner_id: int, + gemm_type: str, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_num_tokens: torch.Tensor | None = None, + expected_m: int | None = None, +) -> torch.Tensor: + runner = HummingRunnerCore.runner_cores[moe_runner_id] + if gemm_type == "indexed": + return runner._run_indexed_gemm( + hidden_states=hidden_states, + topk_ids=topk_ids, + topk_weights=topk_weights, + ) + elif gemm_type == "grouped_contiguous": + return runner._run_grouped_contiguous_gemm( + hidden_states=hidden_states, + topk_ids=topk_ids, + topk_weights=topk_weights, + ) + elif gemm_type == "grouped_masked": + assert expected_m is not None and expert_num_tokens is not None + return runner._run_grouped_masked_gemm( + hidden_states=hidden_states, + topk_ids=topk_ids, + topk_weights=topk_weights, + expected_m=expected_m, + expert_num_tokens=expert_num_tokens, + ) + else: + raise ValueError(f"Unknown gemm type: {gemm_type}") + + +class HummingRunnerCore(MoeRunnerCore): + runner_cores: WeakValueDictionary = WeakValueDictionary() + + def __init__(self, config: MoeRunnerConfig): + super().__init__(config) + assert config.layer is not None + self.layer = config.layer + assert config.num_local_experts is not None + assert config.num_experts is not None + self.num_experts = config.num_local_experts + self.global_num_experts = config.num_experts + self.activation = config.activation + self.swiglu_limit = config.swiglu_limit + self.humming_gemm_configs = {} + HummingRunnerCore.runner_cores[id(self)] = self + + @property + def runner_backend(self) -> MoeRunnerBackend: + return MoeRunnerBackend.HUMMING + + def get_humming_gemm_configs(self, humming_gemm_type: HummingGemmType): + if humming_gemm_type.value in self.humming_gemm_configs: + return self.humming_gemm_configs[humming_gemm_type.value] + + compute_config = { + "use_f16_accum": envs.SGLANG_HUMMING_USE_F16_ACCUM.get(), + "gemm_type": humming_gemm_type.value, + } + w13_tuning_config = HummingMethod.get_default_tuning_configs( + layer=self.layer, + use_f16_accum=envs.SGLANG_HUMMING_USE_F16_ACCUM.get(), + gemm_type=humming_gemm_type, + sublayer_name="w13", + ) + w2_tuning_config = HummingMethod.get_default_tuning_configs( + layer=self.layer, + use_f16_accum=envs.SGLANG_HUMMING_USE_F16_ACCUM.get(), + gemm_type=humming_gemm_type, + sublayer_name="w2", + ) + self.humming_gemm_configs[humming_gemm_type.value] = { + "compute_config": compute_config, + "w13_tuning_config": w13_tuning_config, + "w2_tuning_config": w2_tuning_config, + "compute_config_str": json.dumps(compute_config), + "w13_tuning_config_str": json.dumps(w13_tuning_config), + "w2_tuning_config_str": json.dumps(w2_tuning_config), + } + + return self.humming_gemm_configs[humming_gemm_type.value] + + def 
estimate_local_valid_shape_m( + self, + topk_ids: torch.Tensor, + expected_m: int | None = None, + ): + # estimate shape_m for kernel tuning + if expected_m is not None: + return expected_m * self.num_experts + + # TODO: update for EP and DP + return topk_ids.nelement() + + def get_buffer_metas( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + gemm_type: HummingGemmType, + ): + num_experts = self.num_experts + N = self.layer.intermediate_size_per_partition + K = self.layer.hidden_size + assert isinstance(num_experts, int) + assert isinstance(N, int) + assert isinstance(K, int) + + # hidden_states + # (-> quanted_gate_up_input) (if not BF16/FP16 activation) + # -> gate_up_output + # -> activation_output + # (-> quanted_down_input) (if not BF16/FP16 activation) + # -> down_output + # (-> output) (if not is_grouped_masked) + # Neighboring nodes are required to utilize distinct workspaces. + # The output must be derived from workspace1. + + is_grouped_masked = gemm_type == HummingGemmType.GROUPED_MASKED + output_shape: tuple[int, ...] + if gemm_type == HummingGemmType.GROUPED_MASKED: + if hidden_states.ndim == 3: + max_num_tokens = hidden_states.size(1) + else: + max_num_tokens = hidden_states.size(0) // num_experts + input_shape_m = num_experts * max_num_tokens + real_shape_m = num_experts * max_num_tokens + output_shape = (num_experts, max_num_tokens, K) + else: + input_shape_m = hidden_states.size(0) + real_shape_m = hidden_states.size(0) * topk_ids.size(1) + if gemm_type == HummingGemmType.GROUPED_CONTIGUOUS: + input_shape_m = real_shape_m + output_shape = (hidden_states.size(0), K) + + down_input_size = N + a_dtype = self.layer.humming_metas["w13"].a_dtype + c_dtype = self.layer.humming_metas["w13"].c_dtype + num_bits = a_dtype.num_bits + torch_dtype_map = { + dtypes.float16: torch.float16, + dtypes.bfloat16: torch.bfloat16, + dtypes.float8e4m3: torch.float8_e4m3fn, + dtypes.int8: torch.int8, + dtypes.int4: torch.uint8, + } + + buffer_metas = { + "quanted_gate_up_input": { + "shape": (input_shape_m, K), + "dtype": torch_dtype_map[a_dtype], + }, + "gate_up_output": { + "shape": (real_shape_m, N * 2), + "dtype": torch_dtype_map[c_dtype], + }, + "activation_output": { + "shape": (real_shape_m, down_input_size), + "dtype": torch_dtype_map[c_dtype], + }, + "quanted_down_input": { + "shape": (real_shape_m, down_input_size), + "dtype": torch_dtype_map[a_dtype], + }, + "down_output": { + "shape": output_shape if is_grouped_masked else (real_shape_m, K), + "dtype": torch_dtype_map[c_dtype], + }, + "output": { + "shape": output_shape, + "dtype": torch_dtype_map[c_dtype], + }, + } + + for key in buffer_metas: + meta = buffer_metas[key] + if "quanted" in key and a_dtype.num_bits == 4: + meta["shape"] = meta["shape"][:-1] + (meta["shape"][-1] // 2,) + + if num_bits == 16: + required_buffers = ["gate_up_output", "activation_output", "down_output"] + else: + required_buffers = [ + "quanted_gate_up_input", + "gate_up_output", + "activation_output", + "quanted_down_input", + "down_output", + ] + + # grouped masked moe use down_output as output + if gemm_type != HummingGemmType.GROUPED_MASKED: + required_buffers.append("output") + + return buffer_metas, required_buffers + + def _workspace_shapes( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + gemm_type: HummingGemmType, + ): + buffer_metas, required_buffers = self.get_buffer_metas( + hidden_states=hidden_states, + topk_ids=topk_ids, + gemm_type=gemm_type, + ) + + workspace1_nbytes = 0 + workspace2_nbytes = 0 + + for 
index, name in enumerate(required_buffers[::-1]): + buffer_meta = buffer_metas[name] + nelement = math.prod(buffer_meta["shape"]) + nbytes = nelement * buffer_meta["dtype"].itemsize + if index % 2 == 0: + workspace1_nbytes = max(workspace1_nbytes, nbytes) + else: + workspace2_nbytes = max(workspace2_nbytes, nbytes) + + output_key = ( + "down_output" if gemm_type == HummingGemmType.GROUPED_MASKED else "output" + ) + output_shape = buffer_metas[output_key]["shape"] + + return (workspace1_nbytes // 2,), (workspace2_nbytes // 2,), output_shape + + def make_workspaces( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + gemm_type: HummingGemmType, + ): + shapes = self._workspace_shapes(hidden_states, topk_ids, gemm_type) + workspace1_shape, workspace2_shape, output_shape = shapes + torch_dtype = self.layer.params_dtype + device = hidden_states.device + workspace1 = torch.empty(workspace1_shape, dtype=torch_dtype, device=device) + workspace2 = torch.empty(workspace2_shape, dtype=torch_dtype, device=device) + output = workspace1[: math.prod(output_shape)].view(*output_shape) + return workspace1, workspace2, output + + def prepare_buffers( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + gemm_type: HummingGemmType, + ) -> dict[str, torch.Tensor]: + workspace1, workspace2, output = self.make_workspaces( + hidden_states=hidden_states, + topk_ids=topk_ids, + gemm_type=gemm_type, + ) + buffer_metas, required_buffers = self.get_buffer_metas( + hidden_states=hidden_states, + topk_ids=topk_ids, + gemm_type=gemm_type, + ) + buffers = {"output": output} + for index, name in enumerate(required_buffers[::-1]): + buffer_meta = buffer_metas[name] + workspace = workspace1 if index % 2 == 0 else workspace2 + workspace = workspace.view(buffer_meta["dtype"]) + shape = buffer_meta["shape"] + tensor = workspace[: math.prod(shape)].view(*shape) + buffers[name] = tensor + + return buffers + + def apply_activation(self, inputs: torch.Tensor, outputs: torch.Tensor): + if self.activation == "silu": + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + from sglang.srt.layers.moe.moe_runner.deep_gemm import ( + _apply_swiglu_limit, + ) + + inputs = _apply_swiglu_limit(inputs, swiglu_limit=self.swiglu_limit) + deepseek_v4_moe_code_path_checker.observed += 1 + from sgl_kernel import silu_and_mul + + silu_and_mul(inputs, outputs) + elif self.activation == "gelu": + from sgl_kernel import gelu_and_mul + + gelu_and_mul(inputs, outputs) + else: + raise ValueError(f"Unsupported activation: {self.activation}") + + def run( + self, + runner_input: HummingRunnerInput, + quant_info: HummingMoeQuantInfo, + running_state: dict, + hooks: Optional[Any] = None, + ) -> HummingRunnerOutput: + if runner_input.hidden_states.size(0) == 0: + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + deepseek_v4_moe_code_path_checker.observed += 1 + return HummingRunnerOutput( + hidden_states=torch.empty_like(runner_input.hidden_states) + ) + + # To make it compatible with dynamic shapes in torch.compile, + # we wrap the main logic inside a torch op. + # (the moe_block_size selection in indexed gemm would break dynamic shapes). 
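+        # The op only takes tensors and plain values; the runner instance is
+        # looked up inside via HummingRunnerCore.runner_cores[id(self)].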
+ output = humming_moe_runner_core_run( + moe_runner_id=id(self), + gemm_type=runner_input.gemm_type.value, + hidden_states=runner_input.hidden_states, + topk_weights=runner_input.topk_weights, + topk_ids=runner_input.topk_ids, + expected_m=runner_input.expected_m, + expert_num_tokens=runner_input.expert_num_tokens, + ) + + return HummingRunnerOutput(hidden_states=output) + + def _prepare_indexed_gemm_kwargs( + self, topk_ids: torch.Tensor + ) -> tuple[dict[str, Any], dict[str, Any]]: + from sglang.srt.layers.moe.fused_moe_triton import moe_align_block_size + + configs = self.get_humming_gemm_configs(HummingGemmType.INDEXED) + valid_shape_m = self.estimate_local_valid_shape_m(topk_ids) + + for min_shape_m, max_shape_m, config in configs["w13_tuning_config"]: + if valid_shape_m > min_shape_m and valid_shape_m <= max_shape_m: + moe_block_size = config["block_shape"][0] + break + else: + raise ValueError(f"cannot found moe_block_size for shape {valid_shape_m}") + + sorted_ids, expert_ids, num_tokens_padded = moe_align_block_size( + topk_ids=topk_ids, + block_size=moe_block_size, + num_experts=self.num_experts, + ignore_invalid_expert=True, + ) + + moe_common_kwargs = { + "sorted_ids": sorted_ids, + "expert_ids": expert_ids, + "num_tokens_padded": num_tokens_padded, + "compute_config": configs["compute_config_str"], + "valid_shape_m": valid_shape_m, + } + + top_k = topk_ids.size(1) + moe_kwargs1 = {"top_k": top_k, "tuning_config": configs["w13_tuning_config_str"]} + moe_kwargs2 = {"top_k": 1, "tuning_config": configs["w2_tuning_config_str"]} + moe_kwargs1.update(moe_common_kwargs) + moe_kwargs2.update(moe_common_kwargs) + + return moe_kwargs1, moe_kwargs2 + + def _run_indexed_gemm( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ): + hidden_states = hidden_states.view(-1, hidden_states.size(-1)) + buffers = self.prepare_buffers( + hidden_states=hidden_states, + topk_ids=topk_ids, + gemm_type=HummingGemmType.INDEXED, + ) + + moe_kwargs1, moe_kwargs2 = self._prepare_indexed_gemm_kwargs(topk_ids) + + inputs, input_scale = HummingMethod.may_quant_input( + layer=self.layer, + inputs=hidden_states, + quanted_input=buffers.get("quanted_gate_up_input", None), + sublayer_name="w13", + ) + + HummingMethod.forward_layer( + layer=self.layer, + inputs=inputs, + input_scale=input_scale, + outputs=buffers["gate_up_output"], + sublayer_name="w13", + **moe_kwargs1, + ) + + self.apply_activation( + inputs=buffers["gate_up_output"], + outputs=buffers["activation_output"], + ) + + inputs, input_scale = HummingMethod.may_quant_input( + layer=self.layer, + inputs=buffers["activation_output"], + quanted_input=buffers.get("quanted_down_input", None), + sublayer_name="w2", + ) + + HummingMethod.forward_layer( + layer=self.layer, + inputs=inputs, + input_scale=input_scale, + outputs=buffers["down_output"].view(-1, hidden_states.size(-1)), + sublayer_name="w2", + **moe_kwargs2, + ) + + moe_fused_mul_sum( + inputs=buffers["down_output"].view(*topk_ids.shape, -1), + topk_weights=topk_weights, + topk_ids=topk_ids, + is_ep=self.num_experts != self.global_num_experts, + outputs=buffers["output"], + ) + + return buffers["output"] + + def _run_grouped_contiguous_gemm( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ): + configs = self.get_humming_gemm_configs(HummingGemmType.GROUPED_CONTIGUOUS) + valid_shape_m = self.estimate_local_valid_shape_m(topk_ids) + + buffers = self.prepare_buffers( + hidden_states=hidden_states, + 
topk_ids=topk_ids, + gemm_type=HummingGemmType.GROUPED_CONTIGUOUS, + ) + + hidden_states, src2dst, expert_first_token_offset = moe_permute( + inputs=hidden_states, + topk_ids=topk_ids, + num_experts=self.num_experts, + is_ep=self.num_experts != self.global_num_experts, + ) + + inputs, input_scale = HummingMethod.may_quant_input( + layer=self.layer, + inputs=hidden_states, + quanted_input=buffers.get("quanted_gate_up_input", None), + sublayer_name="w13", + ) + + HummingMethod.forward_layer( + layer=self.layer, + inputs=inputs, + input_scale=input_scale, + outputs=buffers["gate_up_output"], + valid_shape_m=valid_shape_m, + expert_layout=expert_first_token_offset, + compute_config=configs["compute_config_str"], + tuning_config=configs["w13_tuning_config_str"], + sublayer_name="w13", + ) + + self.apply_activation( + inputs=buffers["gate_up_output"], + outputs=buffers["activation_output"], + ) + + inputs, input_scale = HummingMethod.may_quant_input( + layer=self.layer, + inputs=buffers["activation_output"], + quanted_input=buffers.get("quanted_down_input", None), + sublayer_name="w2", + ) + + HummingMethod.forward_layer( + layer=self.layer, + inputs=inputs, + input_scale=input_scale, + outputs=buffers["down_output"], + valid_shape_m=valid_shape_m, + expert_layout=expert_first_token_offset, + compute_config=configs["compute_config_str"], + tuning_config=configs["w2_tuning_config_str"], + sublayer_name="w2", + ) + + moe_unpermute( + outputs=buffers["output"], + inputs=buffers["down_output"], + topk_weights=topk_weights, + topk_ids=topk_ids, + src2dst=src2dst, + ) + + return buffers["output"] + + def _run_grouped_masked_gemm( + self, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + expert_num_tokens: torch.Tensor, + expected_m: int, + ): + configs = self.get_humming_gemm_configs(HummingGemmType.GROUPED_MASKED) + valid_shape_m = self.estimate_local_valid_shape_m(topk_ids, expected_m) + hidden_states = hidden_states.view(-1, hidden_states.size(-1)) + + buffers = self.prepare_buffers( + hidden_states=hidden_states, + topk_ids=topk_ids, + gemm_type=HummingGemmType.GROUPED_MASKED, + ) + + inputs, input_scale = HummingMethod.may_quant_input( + layer=self.layer, + inputs=hidden_states, + quanted_input=buffers.get("quanted_gate_up_input", None), + sublayer_name="w13", + ) + + HummingMethod.forward_layer( + layer=self.layer, + inputs=inputs, + input_scale=input_scale, + outputs=buffers["gate_up_output"], + valid_shape_m=valid_shape_m, + expert_layout=expert_num_tokens, + compute_config=configs["compute_config_str"], + tuning_config=configs["w13_tuning_config_str"], + sublayer_name="w13", + ) + + self.apply_activation( + inputs=buffers["gate_up_output"], + outputs=buffers["activation_output"], + ) + + inputs, input_scale = HummingMethod.may_quant_input( + layer=self.layer, + inputs=buffers["activation_output"], + quanted_input=buffers.get("quanted_down_input", None), + sublayer_name="w2", + ) + + HummingMethod.forward_layer( + layer=self.layer, + inputs=inputs, + input_scale=input_scale, + outputs=buffers["down_output"].view(-1, hidden_states.size(-1)), + valid_shape_m=valid_shape_m, + expert_layout=expert_num_tokens, + compute_config=configs["compute_config_str"], + tuning_config=configs["w2_tuning_config_str"], + sublayer_name="w2", + ) + + return buffers["down_output"] + + +@register_fused_func("none", "humming") +def fused_experts_none_to_humming( + dispatch_output: StandardDispatchOutput, + quant_info: HummingMoeQuantInfo, + runner_config: MoeRunnerConfig, +) -> 
StandardCombineInput: + from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput + + hidden_states = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + topk_ids = topk_output.topk_ids + topk_weights = topk_output.topk_weights + + runner_input = HummingRunnerInput( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + gemm_type=get_standard_humming_moe_gemm_type(), + ) + + runner_core = HummingRunnerCore(runner_config) + runner_output = runner_core.run(runner_input, quant_info, {}) + + return StandardCombineInput(hidden_states=runner_output.hidden_states) + + +@register_pre_permute("deepep_ll", "humming") +def pre_permute_deepep_ll_to_humming( + dispatch_output: DeepEPLLDispatchOutput, + quant_info: HummingMoeQuantInfo, + runner_config: MoeRunnerConfig, + running_state: dict, +) -> HummingRunnerInput: + hidden_states = dispatch_output.hidden_states + topk_ids = dispatch_output.topk_ids + topk_weights = dispatch_output.topk_weights + running_state["topk_ids"] = topk_ids + running_state["topk_weights"] = topk_weights + + return HummingRunnerInput( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids.int(), + expert_num_tokens=dispatch_output.masked_m, + expected_m=dispatch_output.expected_m, + gemm_type=HummingGemmType.GROUPED_MASKED, + ) + + +@register_post_permute("humming", "deepep_ll") +def post_permute_humming_to_deepep_ll( + runner_output: HummingRunnerOutput, + quant_info: HummingMoeQuantInfo, + runner_config: MoeRunnerConfig, + running_state: dict, +) -> DeepEPLLCombineInput: + from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPLLCombineInput + + return DeepEPLLCombineInput( + hidden_states=runner_output.hidden_states, + topk_ids=running_state["topk_ids"], + topk_weights=running_state["topk_weights"], + ) + + +@register_pre_permute("deepep_normal", "humming") +def pre_permute_deepep_normal_to_humming( + dispatch_output: DeepEPNormalDispatchOutput, + quant_info: HummingMoeQuantInfo, + runner_config: MoeRunnerConfig, + running_state: dict, +) -> HummingRunnerInput: + hidden_states = dispatch_output.hidden_states + topk_ids = dispatch_output.topk_ids + topk_weights = dispatch_output.topk_weights + running_state["topk_ids"] = topk_ids + running_state["topk_weights"] = topk_weights + + return HummingRunnerInput( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids.int(), + gemm_type=get_standard_humming_moe_gemm_type(), + ) + + +@register_post_permute("humming", "deepep_normal") +def post_permute_humming_to_deepep_normal( + runner_output: HummingRunnerOutput, + quant_info: HummingMoeQuantInfo, + runner_config: MoeRunnerConfig, + running_state: dict, +) -> DeepEPNormalCombineInput: + from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPNormalCombineInput + + return DeepEPNormalCombineInput( + hidden_states=runner_output.hidden_states, + topk_ids=running_state["topk_ids"], + topk_weights=running_state["topk_weights"], + ) + + +@register_pre_permute("standard", "humming") +def pre_permute_standard_to_humming( + dispatch_output: StandardDispatchOutput, + quant_info: HummingMoeQuantInfo, + runner_config: MoeRunnerConfig, + running_state: dict, +) -> HummingRunnerInput: + hidden_states = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + topk_ids = topk_output.topk_ids + topk_weights = topk_output.topk_weights + + return HummingRunnerInput( + hidden_states=hidden_states, + topk_weights=topk_weights, + 
topk_ids=topk_ids.int(), + gemm_type=get_standard_humming_moe_gemm_type(), + ) + + +@register_post_permute("humming", "standard") +def post_permute_humming_to_standard( + runner_output: HummingRunnerOutput, + quant_info: HummingMoeQuantInfo, + runner_config: MoeRunnerConfig, + running_state: dict, +) -> StandardCombineInput: + from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput + + return StandardCombineInput(hidden_states=runner_output.hidden_states) diff --git a/python/sglang/srt/layers/moe/moe_runner/runner.py b/python/sglang/srt/layers/moe/moe_runner/runner.py index 8b58cd3115bd..a4b85f165461 100644 --- a/python/sglang/srt/layers/moe/moe_runner/runner.py +++ b/python/sglang/srt/layers/moe/moe_runner/runner.py @@ -37,6 +37,10 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig): self.runner_core = TritonKernelsRunnerCore(config) elif runner_backend.is_deep_gemm(): self.runner_core = DeepGemmRunnerCore(config) + elif runner_backend.is_humming(): + from sglang.srt.layers.moe.moe_runner.humming import HummingRunnerCore + + self.runner_core = HummingRunnerCore(config) elif runner_backend.is_marlin(): self.runner_core = None # Marlin only supports fused path elif runner_backend.is_flashinfer_trtllm(): diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index ba6ca01ff140..6b54bd7e8fde 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -63,6 +63,7 @@ class MoeRunnerBackend(Enum): FLASHINFER_CUTEDSL = "flashinfer_cutedsl" CUTLASS = "cutlass" MARLIN = "marlin" + HUMMING = "humming" def is_auto(self): return self == MoeRunnerBackend.AUTO @@ -94,6 +95,9 @@ def is_cutlass(self): def is_marlin(self): return self == MoeRunnerBackend.MARLIN + def is_humming(self): + return self == MoeRunnerBackend.HUMMING + class DeepEPMode(Enum): diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py index 9e4ca8ebba4a..5b8b3e7b327a 100644 --- a/python/sglang/srt/layers/parameter.py +++ b/python/sglang/srt/layers/parameter.py @@ -587,8 +587,8 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size) def _adjust_shard_indexes_for_packing( shard_size, shard_offset, packed_factor, marlin_tile_size ): - shard_size = shard_size // packed_factor - shard_offset = shard_offset // packed_factor + shard_size = round(shard_size // packed_factor) + shard_offset = round(shard_offset // packed_factor) if marlin_tile_size is not None: return _adjust_shard_indexes_for_marlin( shard_size=shard_size, diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 734b7f037040..933a1fdcfbe0 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -28,6 +28,7 @@ def override_quantization_method(self, *args, **kwargs): from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config from sglang.srt.layers.quantization.gguf import GGUFConfig from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig +from sglang.srt.layers.quantization.humming import HummingConfig from sglang.srt.layers.quantization.modelopt_quant import ( ModelOptFp4Config, ModelOptFp8Config, @@ -74,6 +75,7 @@ def override_quantization_method(self, *args, **kwargs): "auto-round": AutoRoundConfig, "modelslim": ModelSlimConfig, "quark_int4fp8_moe": QuarkInt4Fp8Config, + "humming": HummingConfig, } diff --git 
a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index e11f208455b1..bd9223a9bf04 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -192,6 +192,7 @@ def get_quant_method( and ( get_moe_runner_backend().is_flashinfer_mxfp4() or get_moe_runner_backend().is_marlin() + or get_moe_runner_backend().is_humming() ) ): from sglang.srt.layers.quantization.mxfp4_deepseek import ( diff --git a/python/sglang/srt/layers/quantization/humming.py b/python/sglang/srt/layers/quantization/humming.py new file mode 100644 index 000000000000..c3e3073ffd85 --- /dev/null +++ b/python/sglang/srt/layers/quantization/humming.py @@ -0,0 +1,856 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import dataclasses +import json +import math +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, List + +import regex as re +import torch +from sglang.srt.environ import envs +from sglang.srt.layers.linear import LinearBase, set_weight_attrs +from sglang.srt.layers.moe import ( + MoeRunner, + MoeRunnerBackend, + MoeRunnerConfig, + get_moe_runner_backend, +) +from sglang.srt.layers.parameter import ( + BasevLLMParameter, + BlockQuantScaleParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + ModelWeightParameter, + PackedvLLMParameter, + PerTensorScaleParameter, + RowvLLMParameter, +) +from sglang.srt.layers.quantization.base_config import ( + FusedMoEMethodBase, + LinearMethodBase, + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.unquant import ( + UnquantizedFusedMoEMethod, + UnquantizedLinearMethod, +) + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput + from sglang.srt.models.utils import WeightsMapper + + +try: + from humming.dtypes import DataType + from humming.layer import HummingMethod + from humming.schema import ( + BaseInputSchema, + BaseWeightSchema, + HummingInputSchema, + HummingWeightSchema, + ) + from humming.utils.weight import quantize_weight +except ModuleNotFoundError: + HummingMethod = None + + +def assert_humming_available(): + assert HummingMethod is not None, ( + "humming is not available, please run " + "'pip install git+https://github.com/inclusionAI/humming' to install it." 
+ ) + + +def prepare_padded_shape(shape, x): + padded_shape = math.ceil(shape / x) * x + return padded_shape, padded_shape - shape + + +def prepare_param(tensor, name, extra_attrs): + extra_attrs = extra_attrs.copy() + scale_type = extra_attrs.pop("scale_type", None) + param_cls_name_map = { + "block": BlockQuantScaleParameter, + "tensor": PerTensorScaleParameter, + "group": GroupQuantScaleParameter, + "channel": ChannelQuantScaleParameter, + "input_scale": PerTensorScaleParameter, + } + + param_cls: type[BasevLLMParameter] + if "packed_dim" in extra_attrs: + param_cls = PackedvLLMParameter + elif scale_type in param_cls_name_map: + param_cls = param_cls_name_map[scale_type] + elif "output_dim" in extra_attrs and "input_dim" in extra_attrs: + param_cls = ModelWeightParameter + elif "input_dim" in extra_attrs: + param_cls = RowvLLMParameter + elif "output_dim" in extra_attrs: + param_cls = ChannelQuantScaleParameter + else: + param_cls = BasevLLMParameter + + kwargs_keys = [ + "input_dim", + "output_dim", + "packed_dim", + "packed_factor", + "weight_loader", + ] + cls_kwargs = {} + for key in extra_attrs.copy(): + if key in kwargs_keys: + cls_kwargs[key] = extra_attrs.pop(key) + + param = param_cls(data=tensor, **cls_kwargs) + set_weight_attrs(param, extra_attrs) + + param.param_name = name + param.ignore_warning = True + if scale_type in ["tensor", "input_scale"]: + param.needs_scalar_to_array = True + + return param + + +def prepare_moe_param(tensor, name, extra_attrs): + param = torch.nn.Parameter(tensor, requires_grad=False) + if "scale_type" in extra_attrs: + extra_attrs["quant_method"] = extra_attrs["scale_type"] + + if "input_dim" in extra_attrs and "output_dim" in extra_attrs: + input_dim = extra_attrs["input_dim"] + output_dim = extra_attrs["output_dim"] + extra_attrs["is_transposed"] = input_dim < output_dim + + set_weight_attrs(param, extra_attrs) + param.param_name = name + return param + + +def may_pad_loaded_weight(param, loaded_weight): + pad_shape = getattr(param, "pad_shape", None) + if pad_shape is None: + return loaded_weight + value = 1 if loaded_weight.dtype == torch.float8_e8m0fnu else 0 + padding = [] + for x in pad_shape[::-1][: loaded_weight.ndim]: + padding += [0, x] + loaded_weight = torch.nn.functional.pad( + input=loaded_weight, + pad=padding, + value=value, + ) + return loaded_weight + + +def compressed_tensors_get_config(config: dict[str, Any], key: str): + assert key in ["weights", "input_activations"] + target_group_config = None + for group_config in config["config_groups"].values(): + if "Linear" in group_config["targets"]: + if "weights" not in group_config: + return None + if key not in group_config or group_config[key] is None: + return None + target_group_config = group_config[key].copy() + break + + if target_group_config is None: + return None + target_group_config["quant_method"] = config["quant_method"] + if config["quant_method"] == "compressed-tensors": + target_group_config["format"] = config["format"] + elif config["quant_method"] == "modelopt": + target_group_config["quant_algo"] = config["quant_algo"] + return target_group_config + + +class HummingConfig(QuantizationConfig): + packed_modules_mapping = {} + + def __init__(self, full_config: dict[str, Any] | None = None): + assert_humming_available() + self.full_config: dict[str, Any] = full_config or {} + + @classmethod + def get_name(cls) -> str: + return "humming" + + @classmethod + def get_supported_act_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod 
+ def get_min_capability(cls) -> int: + return 75 + + @classmethod + def get_config_filenames(cls) -> list[str]: + return [] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "HummingConfig": + return cls(full_config=config) + + def get_scaled_act_names(self) -> List[str]: + raise NotImplementedError + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> str | None: + if hf_quant_cfg["quant_method"] == "mxfp4": + # NOTE: gpt-oss has a special weight loading logic, so we don't support it now. + # TODO: integrate humming kernels to mxfp4.py + return None + return "humming" if user_quant == "humming" else None + + def apply_weight_name_mapper(self, hf_to_sglang_mapper: "WeightsMapper"): + self.hf_to_sglang_mapper = hf_to_sglang_mapper + + def is_layer_skipped(self, config: dict[str, Any], prefix: str): + keys = ["ignored_layers", "ignore", "modules_to_not_convert"] + ignored_layers = self.get_from_keys_or(config, keys, []) or [] + if hasattr(self, "hf_to_sglang_mapper"): + ignored_layers = self.hf_to_sglang_mapper.apply_list(ignored_layers) + + if any(module_name in prefix for module_name in ignored_layers): + return True + if "lm_head" in prefix: + return True + + for regex in config.get("dynamic", {}): + if regex[:1] != "-": + continue + if re.match(regex[2:], prefix): + return True + + return False + + def get_layer_weight_schema(self, config: dict[str, Any], prefix: str): + if self.is_layer_skipped(config, prefix): + return None + + if config["quant_method"] in ["compressed-tensors", "modelopt"]: + group_config = compressed_tensors_get_config(config, "weights") + if group_config is None: + return None + config = group_config + + layer_config = config + layer_dynamic = config.get("dynamic", {}) + if not isinstance(layer_dynamic, dict): + layer_dynamic = {} + for regex, override_config in layer_dynamic.items(): + if regex[:1] != "+": + continue + if re.match(regex[2:], prefix): + layer_config = config.copy() + layer_config.update(override_config) + break + + if "quant_method" in layer_config: + return BaseWeightSchema.from_config(layer_config) + return None + + def get_layer_input_schema(self, config: dict[str, Any], prefix: str): + if self.is_layer_skipped(config, prefix): + return None + if config["quant_method"] in ["compressed-tensors", "modelopt"]: + group_config = compressed_tensors_get_config(config, "input_activations") + if group_config is None: + return None + config = group_config + + if config.get("quant_method", None) in BaseInputSchema.INPUT_SCHEMA_MAP: + return BaseInputSchema.from_config(config) + return None + + def get_quant_config_for_layer( + self, prefix: str, layer_type: str + ) -> "HummingLayerQuantizationConfig | None": + weight_schema: BaseWeightSchema | None = None + force_weight_schema: HummingWeightSchema | None = None + + if self.full_config: + weight_schema = self.get_layer_weight_schema(self.full_config, prefix) + + is_online_quant = False + online_quant_config = envs.SGLANG_HUMMING_ONLINE_QUANT_CONFIG.get() or {} + if not self.full_config or online_quant_config.get("force_requant", False): + online_quant_config["quant_method"] = "humming" + schema = self.get_layer_weight_schema(online_quant_config, prefix) + if not self.full_config: + weight_schema = schema + is_online_quant = True + else: + force_weight_schema = schema + + if weight_schema is not None: + input_schema = None + force_input_schema = None + + if self.full_config: + input_schema = self.get_layer_input_schema(self.full_config, prefix) + + if 
envs.SGLANG_HUMMING_INPUT_QUANT_CONFIG.get(): + quant_config = envs.SGLANG_HUMMING_INPUT_QUANT_CONFIG.get().copy() + quant_config["quant_method"] = "humming" + force_input_schema = self.get_layer_input_schema(quant_config, prefix) + if input_schema is None: + input_schema = force_input_schema + + if force_weight_schema is not None and force_input_schema is None: + force_input_schema = HummingInputSchema() + + return HummingLayerQuantizationConfig( + weight_schema=weight_schema, + input_schema=input_schema, + force_weight_schema=force_weight_schema, + force_input_schema=force_input_schema, + is_online_quant=is_online_quant, + ) + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> "QuantizeMethodBase | None": + layer_type = "other" + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE + + if isinstance(layer, FusedMoE): + layer_type = "moe" + elif isinstance(layer, LinearBase): + layer_type = "linear" + quant_config = self.get_quant_config_for_layer(prefix, layer_type) + if quant_config is None: + if isinstance(layer, FusedMoE): + return UnquantizedFusedMoEMethod() + elif isinstance(layer, LinearBase): + return UnquantizedLinearMethod() + elif isinstance(layer, LinearBase): + return HummingLinearMethod(quant_config) + elif isinstance(layer, FusedMoE): + return HummingMoEMethod(quant_config) + return None + + +class HummingLayerQuantizationConfig(HummingConfig): + def __init__( + self, + weight_schema: "BaseWeightSchema", + input_schema: "BaseInputSchema | None" = None, + force_weight_schema: "HummingWeightSchema | None" = None, + force_input_schema: "HummingInputSchema | None" = None, + is_online_quant: bool = False, + ): + self.weight_schema = weight_schema + if input_schema is None: + input_schema = HummingInputSchema() + self.input_schema = input_schema + self.force_weight_schema = force_weight_schema + self.force_input_schema = force_input_schema + self.is_online_quant = is_online_quant + + @classmethod + def from_config(cls, config): + weight_schema = BaseWeightSchema.from_config(config) + return cls(weight_schema) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> QuantizeMethodBase | None: + raise NotImplementedError + + +class HummingLinearMethod(LinearMethodBase): + def __init__(self, quant_config: HummingLayerQuantizationConfig): + self.quant_config = quant_config + self.weight_schema = quant_config.weight_schema + self.input_schema = quant_config.input_schema + self.force_weight_schema = quant_config.force_weight_schema + self.force_input_schema = quant_config.force_input_schema + self.is_online_quant = self.quant_config.is_online_quant + + def prepare_weight_loader(self, layer: torch.nn.Module, weight_loader: Callable): + def new_weight_loader( + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + shard_id: str | int | None = None, + ): + name = param.param_name + float_dtypes = [torch.float16, torch.bfloat16, torch.float32] + is_unquantized = name == "weight" and loaded_weight.dtype in float_dtypes + if is_unquantized and self.is_online_quant: + # online quant (fp16/bf16 -> quant_type) + assert isinstance(self.weight_schema, HummingWeightSchema) + f16_dtype = DataType.from_torch_dtype(layer.param_dtype) + has_global_scale = "TENSOR" in str(self.weight_schema.weight_scale_type) + tensor_list = quantize_weight( + weight=loaded_weight, + dtype=self.weight_schema.b_dtype, + scale_dtype=self.weight_schema.bs_dtype or f16_dtype, + group_size=self.weight_schema.weight_scale_group_size, + 
has_zero_point=self.weight_schema.has_zero_point, + has_global_scale=has_global_scale, + is_fp_zero_point=self.weight_schema.is_fp_zero_point, + pack=True, + ) + + key_list = ["weight", "weight_scale", "zero_point", "global_scale"] + for key, tensor in zip(key_list, tensor_list): + if tensor is None or tensor.nelement() == 0: + continue + param = getattr(layer, key) + param.weight_loader(param, tensor, shard_id) + + return None + elif is_unquantized and not self.is_online_quant: + # fallback to unquantized linear + # some model skip some layer when quantizing model, but + # don't mark the layer as unquantized. + if not layer.is_fallback: + layer.is_fallback = True + for name, _ in list(layer.named_parameters()): + if name != "bias": + delattr(layer, name) + delattr(layer, "locks") + self.__class__ = UnquantizedLinearMethod # type: ignore + tensor = torch.empty( + (layer.output_partition_sizes_sum, layer.input_size_per_partition), + dtype=layer.param_dtype, + device=param.device, + ) + extra_weight_attrs = layer.extra_weight_attrs.copy() + orig_weight_loader = extra_weight_attrs.pop("weight_loader") + layer.weight = ModelWeightParameter( + data=tensor, + input_dim=1, + output_dim=0, + weight_loader=orig_weight_loader, + ) + layer.weight.tp_size = layer.tp_size + layer.weight.tp_rank = layer.tp_rank + set_weight_attrs(layer.weight, extra_weight_attrs) + + param = layer.weight + if shard_id is not None: + return layer.weight.weight_loader(param, loaded_weight, shard_id) + return layer.weight.weight_loader(param, loaded_weight) + + # weight processing logic for specific quantization schema + loaded_weight = self.weight_schema.process_loaded_weight( + tensor=loaded_weight, + name=name, + ) + if shard_id is not None: + return weight_loader(param, loaded_weight, shard_id) + return weight_loader(param, loaded_weight) + + return new_weight_loader + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: list[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.model_loader.weight_utils import default_weight_loader + + layer.is_fallback = False + layer.param_dtype = params_dtype + layer.input_size = input_size + layer.output_size = output_size + layer.input_size_per_partition = input_size_per_partition + layer.output_partition_sizes_sum = sum(output_partition_sizes) + layer.output_partition_sizes = output_partition_sizes + layer.extra_weight_attrs = extra_weight_attrs.copy() + + weight_loader = extra_weight_attrs.get("weight_loader", default_weight_loader) + new_weight_loader = self.prepare_weight_loader(layer, weight_loader) + extra_weight_attrs["weight_loader"] = new_weight_loader + + for key in ["weight_block_size", "block_structure"]: + block_size = getattr(self.weight_schema, key, None) + if block_size is not None: + layer.weight_block_size = block_size + + weight_tensor_attrs = self.weight_schema.get_tensors_attrs( + shape_n=layer.output_partition_sizes_sum, + shape_k=layer.input_size_per_partition, + param_dtype=params_dtype, + stack_size=len(layer.output_partition_sizes), + ) + + input_tensor_attrs = self.input_schema.get_tensors_attrs( + shape_k=layer.input_size_per_partition, + param_dtype=params_dtype, + stack_size=len(layer.output_partition_sizes), + ) + + tensors_attrs = weight_tensor_attrs | input_tensor_attrs + + for name, attrs in tensors_attrs.items(): + tensor = torch.empty(attrs["shape"], dtype=attrs["dtype"]) + extra_attrs = attrs.get("extra_attrs", 
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.model_loader.weight_utils import default_weight_loader
+
+        layer.is_fallback = False
+        layer.param_dtype = params_dtype
+        layer.input_size = input_size
+        layer.output_size = output_size
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_partition_sizes_sum = sum(output_partition_sizes)
+        layer.output_partition_sizes = output_partition_sizes
+        layer.extra_weight_attrs = extra_weight_attrs.copy()
+
+        weight_loader = extra_weight_attrs.get("weight_loader", default_weight_loader)
+        new_weight_loader = self.prepare_weight_loader(layer, weight_loader)
+        extra_weight_attrs["weight_loader"] = new_weight_loader
+
+        for key in ["weight_block_size", "block_structure"]:
+            block_size = getattr(self.weight_schema, key, None)
+            if block_size is not None:
+                layer.weight_block_size = block_size
+
+        weight_tensor_attrs = self.weight_schema.get_tensors_attrs(
+            shape_n=layer.output_partition_sizes_sum,
+            shape_k=layer.input_size_per_partition,
+            param_dtype=params_dtype,
+            stack_size=len(layer.output_partition_sizes),
+        )
+
+        input_tensor_attrs = self.input_schema.get_tensors_attrs(
+            shape_k=layer.input_size_per_partition,
+            param_dtype=params_dtype,
+            stack_size=len(layer.output_partition_sizes),
+        )
+
+        tensors_attrs = weight_tensor_attrs | input_tensor_attrs
+
+        for name, attrs in tensors_attrs.items():
+            tensor = torch.empty(attrs["shape"], dtype=attrs["dtype"])
+            extra_attrs = attrs.get("extra_attrs", {}).copy()
+            extra_attrs.update(extra_weight_attrs)
+            param = prepare_param(tensor, name, extra_attrs)
+            setattr(layer, name, param)
+
+        locks = torch.zeros(1024, dtype=torch.int32)
+        layer.register_buffer("locks", locks)
+
+        if self.force_input_schema is not None:
+            self.input_schema = self.force_input_schema
+
+        if not hasattr(layer, "weight"):
+            param = prepare_param(torch.tensor(0), "weight", extra_weight_attrs)
+            layer.weight = param
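The schema objects are expected to describe each parameter as a `{name: {"shape": ..., "dtype": ..., "extra_attrs": ...}}` mapping, which `create_weights` materializes via `prepare_param`. A hypothetical illustration of that contract for a 4-bit, group-128 weight (names and shapes here are assumptions, not the schema's actual output):

```python
import torch

shape_n, shape_k, group_size, pack_factor = 4096, 4096, 128, 8  # int4 packed in int32

tensors_attrs = {
    "weight": {"shape": (shape_n, shape_k // pack_factor), "dtype": torch.int32},
    "weight_scale": {"shape": (shape_n, shape_k // group_size), "dtype": torch.float16},
}

params = {
    name: torch.nn.Parameter(
        torch.empty(attrs["shape"], dtype=attrs["dtype"]), requires_grad=False
    )
    for name, attrs in tensors_attrs.items()
}
print({k: tuple(v.shape) for k, v in params.items()})
# {'weight': (4096, 512), 'weight_scale': (4096, 32)}
```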
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if layer.is_fallback:
+            return None
+
+        # convert from checkpoint format to humming format
+        if not isinstance(self.weight_schema, HummingWeightSchema):
+            self.weight_schema, tensors = self.weight_schema.convert_humming(
+                tensors=layer.state_dict(),
+                shape_n_stacks=layer.output_partition_sizes,
+                shape_k_stacks=[layer.input_size_per_partition],
+                param_dtype=layer.param_dtype,
+            )
+
+            self.input_schema, _ = self.input_schema.convert_humming(
+                tensors=layer.state_dict(),
+                shape_n_stacks=layer.output_partition_sizes,
+                shape_k_stacks=[layer.input_size_per_partition],
+                param_dtype=layer.param_dtype,
+            )
+
+            for name, _ in list(layer.named_parameters()):
+                delattr(layer, name)
+
+            for name, tensor in tensors.items():
+                param = torch.nn.Parameter(tensor, requires_grad=False)
+                setattr(layer, name, param)
+
+            del tensors
+
+        # force requant (original quant setting -> fp16/bf16 -> new quant setting)
+        assert isinstance(self.weight_schema, HummingWeightSchema)
+        force_requant = self.force_weight_schema is not None
+        if force_requant and self.weight_schema != self.force_weight_schema:
+            tensors = self.weight_schema.requant_tensors(
+                tensors=layer.state_dict(),
+                target_weight_schema=self.force_weight_schema,
+                param_dtype=layer.param_dtype,
+            )
+
+            self.weight_schema = self.force_weight_schema
+
+            for name, _ in list(layer.named_parameters()):
+                if name != "bias":
+                    delattr(layer, name)
+
+            for name, tensor in tensors.items():
+                param = torch.nn.Parameter(tensor, requires_grad=False)
+                setattr(layer, name, param)
+
+            del tensors
+
+        # prepare layer config for the humming kernel
+        HummingMethod.prepare_layer_meta(
+            layer=layer,
+            shape_n=layer.output_partition_sizes_sum,
+            shape_k=layer.input_size_per_partition,
+            weight_schema=self.weight_schema,
+            input_schema=self.input_schema,
+            pad_n_to_multiple=256,
+            pad_k_to_multiple=128,
+            has_bias=layer.with_bias,
+            torch_dtype=layer.param_dtype,
+        )
+
+        # preprocess weight for inference
+        HummingMethod.transform_humming_layer(layer)
+
+        # compute_config: kernel configs that do not directly affect weights
+        # but significantly impact kernel behavior or computation precision.
+        # see https://github.com/inclusionAI/humming/blob/main/docs/config.md
+        compute_config = {
+            "use_f16_accum": envs.SGLANG_HUMMING_USE_F16_ACCUM.get(),
+            "gemm_type": "dense",
+        }
+        self.compute_config = json.dumps(compute_config)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        flatten_inputs = x.view(-1, x.size(-1))
+        output = HummingMethod.forward_layer(
+            layer=layer,
+            inputs=flatten_inputs,
+            compute_config=self.compute_config,
+        )
+        output = output.view(*x.shape[:-1], output.size(-1))
+        return output
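The serialized compute config travels with every `forward_layer` call, so a single env toggle changes accumulation precision process-wide. Assuming the documented keys match what the kernel reads (see the config.md link above):

```python
import json
import os

# Must be set before sglang imports read envs.SGLANG_HUMMING_USE_F16_ACCUM.
os.environ["SGLANG_HUMMING_USE_F16_ACCUM"] = "1"

# Mirrors what process_weights_after_loading serializes for dense layers.
compute_config = json.dumps({"use_f16_accum": True, "gemm_type": "dense"})
```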
+
+
+class HummingMoEMethod(FusedMoEMethodBase):
+    def __init__(self, quant_config: HummingLayerQuantizationConfig) -> None:
+        self.quant_config = quant_config
+        self.weight_schema = quant_config.weight_schema
+        self.input_schema = quant_config.input_schema
+        self.force_weight_schema = quant_config.force_weight_schema
+        self.force_input_schema = quant_config.force_input_schema
+
+    def prepare_weight_loader(self, layer, weight_loader):
+        def new_weight_loader(
+            param: torch.nn.Parameter,
+            loaded_weight: torch.Tensor,
+            weight_name: str,
+            shard_id: str,
+            expert_id: int | None = None,
+        ):
+            name = param.param_name
+            float_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+            is_unquantized = name == "weight" and loaded_weight.dtype in float_dtypes
+            # online quant (fp16/bf16 -> quant_type)
+            if is_unquantized:
+                assert isinstance(self.weight_schema, HummingWeightSchema)
+                f16_dtype = DataType.from_torch_dtype(layer.param_dtype)
+                has_global_scale = "TENSOR" in str(self.weight_schema.weight_scale_type)
+                tensor_list = quantize_weight(
+                    weight=loaded_weight,
+                    dtype=self.weight_schema.b_dtype,
+                    scale_dtype=self.weight_schema.bs_dtype or f16_dtype,
+                    group_size=self.weight_schema.weight_scale_group_size,
+                    has_zero_point=self.weight_schema.has_zero_point,
+                    has_global_scale=has_global_scale,
+                    is_fp_zero_point=self.weight_schema.is_fp_zero_point,
+                    pack=True,
+                )
+
+                key_list = ["weight", "weight_scale", "zero_point", "global_scale"]
+                for key, tensor in zip(key_list, tensor_list):
+                    if tensor is None or tensor.nelement() == 0:
+                        continue
+                    sublayer_name = "w2" if shard_id == "w2" else "w13"
+
+                    param = getattr(layer, sublayer_name + "_" + key)
+                    param.weight_loader(
+                        param=param,
+                        loaded_weight=tensor.cpu(),
+                        weight_name=shard_id + "_" + key,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+
+                return None
+
+            # weight processing logic for specific quantization schema
+            loaded_weight = self.weight_schema.process_loaded_weight(
+                tensor=loaded_weight,
+                name=name,
+            )
+            return weight_loader(
+                param,
+                loaded_weight,
+                weight_name,
+                shard_id=shard_id,
+                expert_id=expert_id,
+            )
+
+        return new_weight_loader
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        with_bias: bool = False,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.model_loader.weight_utils import default_weight_loader
+
+        layer.num_experts = num_experts
+        layer.param_dtype = params_dtype
+        layer.intermediate_size = intermediate_size_per_partition
+        layer.with_bias = with_bias
+        weight_loader = extra_weight_attrs.get("weight_loader", default_weight_loader)
+        weight_loader = self.prepare_weight_loader(layer, weight_loader)
+        extra_weight_attrs["weight_loader"] = weight_loader
+
+        # sublayer: a layer contains multiple sets of weights for quantized GEMM
+        # (e.g., weight, weight_scale, etc.).
+        # The weight names of a sublayer start with the prefix "{sublayer_name}_".
+        layer.sublayer_configs = {
+            "w13": {
+                "shape_n": intermediate_size_per_partition * 2,
+                "shape_k": hidden_size,
+                "tensors_attrs": self.weight_schema.get_padded_tensors_attrs(
+                    shape_n=intermediate_size_per_partition * 2,
+                    shape_k=hidden_size,
+                    num_experts=num_experts,
+                    param_dtype=params_dtype,
+                    has_bias=with_bias,
+                ),
+            },
+            "w2": {
+                "shape_n": hidden_size,
+                "shape_k": intermediate_size_per_partition,
+                "tensors_attrs": self.weight_schema.get_padded_tensors_attrs(
+                    shape_n=hidden_size,
+                    shape_k=intermediate_size_per_partition,
+                    num_experts=num_experts,
+                    param_dtype=params_dtype,
+                    has_bias=with_bias,
+                ),
+            },
+        }
+
+        for sublayer_name, configs in layer.sublayer_configs.items():
+            for name, attrs in configs["tensors_attrs"].items():
+                tensor = torch.empty(attrs["shape"], dtype=attrs["dtype"])
+                extra_attrs = attrs.get("extra_attrs", {}).copy()
+                extra_attrs.update(extra_weight_attrs)
+                param = prepare_moe_param(tensor, name, extra_attrs)
+                setattr(layer, f"{sublayer_name}_{name}", param)
+
+        if self.force_input_schema is not None:
+            self.input_schema = self.force_input_schema
+
+        locks = torch.zeros(1024, dtype=torch.int32)
+        layer.register_buffer("locks", locks)
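Concretely, the two sublayers correspond to the fused gate/up projection and the down projection, and every created parameter is named `w13_*` or `w2_*` (`w13_weight`, `w13_weight_scale`, `w2_weight`, ...). For a hypothetical per-rank config:

```python
hidden_size = 4096
intermediate_size_per_partition = 1408  # intermediate size on this TP rank

# (shape_n, shape_k) of each expert's GEMM, as set up in create_weights above
sublayer_shapes = {
    "w13": (intermediate_size_per_partition * 2, hidden_size),  # gate+up fused
    "w2": (hidden_size, intermediate_size_per_partition),       # down projection
}
for name, (n, k) in sublayer_shapes.items():
    print(f"{name}: n={n}, k={k}")
```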
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if getattr(self, "processed", False):
+            return
+        self.processed = True
+        self.weight_schemas = {}
+        self.input_schemas = {}
+        for sublayer_name, configs in layer.sublayer_configs.items():
+            input_schema = self.input_schema
+            weight_schema = self.weight_schema
+            # convert from checkpoint format to humming format
+            if not isinstance(weight_schema, HummingWeightSchema):
+                tensors: dict[str, torch.Tensor] = dict(
+                    (key.removeprefix(sublayer_name + "_"), value)
+                    for key, value in layer.state_dict().items()
+                    if key.startswith(sublayer_name + "_")
+                )
+
+                shape_k_stacks = [configs["shape_k"]]
+                shape_n_stacks = [configs["shape_n"]]
+                if sublayer_name == "w13":
+                    shape_n_stacks = [configs["shape_n"] // 2] * 2
+
+                weight_schema, tensors = weight_schema.convert_humming(
+                    tensors=tensors,
+                    shape_n_stacks=shape_n_stacks,
+                    shape_k_stacks=shape_k_stacks,
+                    param_dtype=layer.param_dtype,
+                    num_experts=layer.num_experts,
+                )
+
+                input_schema, _ = input_schema.convert_humming(
+                    tensors=tensors,
+                    shape_n_stacks=shape_n_stacks,
+                    shape_k_stacks=shape_k_stacks,
+                    param_dtype=layer.param_dtype,
+                    num_experts=layer.num_experts,
+                )
+
+                for name, _ in list(layer.named_parameters()):
+                    if not name.startswith(sublayer_name + "_"):
+                        continue
+                    delattr(layer, name)
+
+                for name, tensor in tensors.items():
+                    name = f"{sublayer_name}_{name}"
+                    param = torch.nn.Parameter(tensor, requires_grad=False)
+                    setattr(layer, name, param)
+
+            self.weight_schemas[sublayer_name] = weight_schema
+            self.input_schemas[sublayer_name] = input_schema
+
+            # force requant (original quant setting -> fp16/bf16 -> new quant setting)
+            assert isinstance(weight_schema, HummingWeightSchema)
+            force_requant = self.force_weight_schema is not None
+            if force_requant and weight_schema != self.force_weight_schema:
+                tensors = dict(
+                    (key.removeprefix(sublayer_name + "_"), value)
+                    for key, value in layer.state_dict().items()
+                    if key.startswith(sublayer_name + "_")
+                )
+
+                tensors = weight_schema.requant_tensors(
+                    tensors=tensors,
+                    target_weight_schema=self.force_weight_schema,
+                    param_dtype=layer.param_dtype,
+                )
+
+                weight_schema = self.force_weight_schema
+                # keep the per-sublayer record in sync with the requanted schema
+                self.weight_schemas[sublayer_name] = weight_schema
+
+                for name, _ in list(layer.named_parameters()):
+                    if not name.startswith(sublayer_name + "_"):
+                        continue
+                    if name == sublayer_name + "_bias":
+                        continue
+                    delattr(layer, name)
+
+                for name, tensor in tensors.items():
+                    name = f"{sublayer_name}_{name}"
+                    param = torch.nn.Parameter(tensor, requires_grad=False)
+                    setattr(layer, name, param)
+
+                del tensors
+
+            # prepare layer config for the humming kernel
+            HummingMethod.prepare_layer_meta(
+                layer=layer,
+                shape_n=configs["shape_n"],
+                shape_k=configs["shape_k"],
+                pad_n_to_multiple=256,
+                pad_k_to_multiple=128,
+                input_schema=input_schema,
+                weight_schema=weight_schema,
+                has_bias=layer.with_bias,
+                num_experts=layer.num_experts,
+                torch_dtype=layer.param_dtype,
+                sublayer_name=sublayer_name,
+            )
+
+            # preprocess weight for inference
+            HummingMethod.transform_humming_layer(layer, sublayer_name=sublayer_name)
+
+    def create_moe_runner(
+        self,
+        layer: torch.nn.Module,
+        moe_runner_config: MoeRunnerConfig,
+    ):
+        assert get_moe_runner_backend().is_auto()
+        moe_runner_config = dataclasses.replace(moe_runner_config, layer=layer)
+        self.runner = MoeRunner(MoeRunnerBackend.HUMMING, moe_runner_config)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        dispatch_output: "DispatchOutput",
+    ) -> "CombineInput":
+        from sglang.srt.layers.moe.moe_runner.humming import HummingMoeQuantInfo
+
+        quant_info = HummingMoeQuantInfo()
+        return self.runner.run(dispatch_output, quant_info)
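End to end, the new backend is selected through the server flags registered later in this diff (`--quantization humming`, `--moe-runner-backend humming`). A sketch of a programmatic launch — whether a given checkpoint needs both flags is an assumption here, and `sgl.Engine` is taken to forward these `ServerArgs` fields:

```python
import sglang as sgl

# Hypothetical launch; model path and flag combination are illustrative only.
engine = sgl.Engine(
    model_path="/path/to/quantized-model",
    quantization="humming",
    moe_runner_backend="humming",
)
print(engine.generate("Hello", {"max_new_tokens": 8}))
```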
diff --git a/python/sglang/srt/layers/quantization/humming_utils.py b/python/sglang/srt/layers/quantization/humming_utils.py
new file mode 100644
index 000000000000..10af83e728c4
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/humming_utils.py
@@ -0,0 +1,155 @@
+from typing import Any
+
+import regex as re
+import torch
+from humming.layer import HummingInputSchema, HummingMethod
+from humming.schema import BaseWeightSchema
+from sglang.srt.environ import envs
+from sglang.srt.layers.linear import LinearBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+
+
+def humming_is_layer_skipped(config: dict[str, Any], prefix: str):
+    if not config:
+        return True
+
+    keys = ["ignored_layers", "ignore", "modules_to_not_convert"]
+    ignored_layers: list[str] = []
+    for key in keys:
+        ignored_layers = config.get(key, []) or []
+        if ignored_layers:
+            break
+
+    if any(module_name in prefix for module_name in ignored_layers):
+        return True
+    if "lm_head" in prefix:
+        return True
+
+    for regex in config.get("dynamic", {}):
+        if regex[:1] != "-":
+            continue
+        if re.match(regex[2:], prefix):
+            return True
+
+    return False
+
+
+def prepare_humming_layer(layer: LinearBase, quant_config: dict):
+    weight_schema = BaseWeightSchema.from_config(quant_config)
+    input_schema = HummingInputSchema()
+
+    shape_k_stacks = [layer.input_size_per_partition]
+    shape_n_stacks = layer.output_partition_sizes
+
+    # Step 1: convert weight to humming standard format
+    weight_schema, tensors = weight_schema.convert_humming(
+        tensors=dict(layer.named_parameters()),
+        shape_n_stacks=shape_n_stacks,
+        shape_k_stacks=shape_k_stacks,
+        param_dtype=layer.params_dtype,
+    )
+
+    layer.weight_schema = weight_schema
+
+    for name, _ in list(layer.named_parameters()):
+        delattr(layer, name)
+
+    for name, tensor in tensors.items():
+        param = torch.nn.Parameter(tensor, requires_grad=False)
+        setattr(layer, name, param)
+
+    # Step 2: transform weight (humming standard format) for forwarding
+    HummingMethod.prepare_layer_meta(
+        layer=layer,
+        shape_n=sum(layer.output_partition_sizes),
+        shape_k=layer.input_size_per_partition,
+        weight_schema=weight_schema,
+        input_schema=input_schema,
+        pad_n_to_multiple=256,
+        pad_k_to_multiple=128,
+        has_bias=layer.with_bias,
+        torch_dtype=layer.params_dtype,
+    )
+
+    HummingMethod.transform_humming_layer(layer)
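A quick illustration of the skip rules above: any substring hit in the ignore lists, an `lm_head` prefix, an empty config, or a `-`-prefixed dynamic regex (the two-character prefix is what `regex[2:]` strips) all cause a layer to be skipped. Config values here are hypothetical:

```python
from sglang.srt.layers.quantization.humming_utils import humming_is_layer_skipped

config = {
    "ignored_layers": ["model.layers.0.mlp"],
    "dynamic": {"-:.*shared_expert.*": {}},  # "-:" marks an exclusion pattern
}

assert humming_is_layer_skipped(config, "model.layers.0.mlp.gate_proj")
assert humming_is_layer_skipped(config, "model.layers.3.mlp.shared_expert.up_proj")
assert humming_is_layer_skipped(config, "lm_head")
assert humming_is_layer_skipped({}, "model.layers.1.self_attn.qkv_proj")  # empty config
assert not humming_is_layer_skipped(config, "model.layers.1.self_attn.qkv_proj")
```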
+
+
+def prepare_humming_moe_layer(layer: FusedMoE, quant_config: dict):
+    weight_schema = BaseWeightSchema.from_config(quant_config)
+    input_quant_config = envs.SGLANG_HUMMING_INPUT_QUANT_CONFIG.get() or {}
+    if humming_is_layer_skipped(input_quant_config, layer.layer_name):
+        input_schema = HummingInputSchema()
+    else:
+        # TODO: read input_quant_config from quant_config
+        input_schema = HummingInputSchema.from_config(input_quant_config)
+
+    shape_config = {
+        "w13": (
+            layer.intermediate_size_per_partition * 2,
+            layer.hidden_size,
+        ),
+        "w2": (
+            layer.hidden_size,
+            layer.intermediate_size_per_partition,
+        ),
+    }
+
+    layer.weight_schemas = {}
+    layer.input_schemas = {}
+
+    for sublayer_name in shape_config:
+        # Step 1: convert weight to humming standard format
+        tensors: dict[str, torch.Tensor] = dict(
+            (key.removeprefix(sublayer_name + "_"), value)
+            for key, value in layer.state_dict().items()
+            if key.startswith(sublayer_name + "_")
+        )
+
+        shape_n, shape_k = shape_config[sublayer_name]
+        shape_n_stacks = [shape_n]
+        shape_k_stacks = [shape_k]
+        if sublayer_name == "w13":
+            shape_n_stacks = [shape_n // 2] * 2
+
+        weight_schema_new, tensors = weight_schema.convert_humming(
+            tensors=tensors,
+            shape_n_stacks=shape_n_stacks,
+            shape_k_stacks=shape_k_stacks,
+            num_experts=layer.num_local_experts,
+            param_dtype=layer.params_dtype,
+        )
+
+        layer.weight_schemas[sublayer_name] = weight_schema_new
+        layer.input_schemas[sublayer_name] = input_schema
+
+        for name, _ in list(layer.named_parameters()):
+            if not name.startswith(sublayer_name + "_"):
+                continue
+            delattr(layer, name)
+
+        for name, tensor in tensors.items():
+            name = f"{sublayer_name}_{name}"
+            param = torch.nn.Parameter(tensor, requires_grad=False)
+            setattr(layer, name, param)
+
+        # Step 2: transform weight (humming standard format) for forwarding
+        HummingMethod.prepare_layer_meta(
+            layer=layer,
+            shape_n=shape_n,
+            shape_k=shape_k,
+            pad_n_to_multiple=256,
+            pad_k_to_multiple=128,
+            input_schema=input_schema,
+            weight_schema=weight_schema_new,
+            has_bias=layer.with_bias,
+            num_experts=layer.num_local_experts,
+            torch_dtype=layer.params_dtype,
+            sublayer_name=sublayer_name,
+        )
+
+        HummingMethod.transform_humming_layer(layer, sublayer_name=sublayer_name)
+
+    if not hasattr(layer, "locks"):
+        device = layer.w13_weight.device
+        locks = torch.zeros(1024, dtype=torch.int32, device=device)
+        layer.register_buffer("locks", locks)
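Both `HummingMoEMethod.process_weights_after_loading` and this helper rely on the same prefix-stripping idiom to carve one sublayer's tensors out of the flat layer state dict; in isolation:

```python
import torch

state = {
    "w13_weight": torch.empty(0),
    "w13_weight_scale": torch.empty(0),
    "w2_weight": torch.empty(0),
}
sublayer_name = "w13"
tensors = {
    key.removeprefix(sublayer_name + "_"): value
    for key, value in state.items()
    if key.startswith(sublayer_name + "_")
}
assert set(tensors) == {"weight", "weight_scale"}
```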
diff --git a/python/sglang/srt/layers/quantization/mxfp4_deepseek.py b/python/sglang/srt/layers/quantization/mxfp4_deepseek.py
index 8705f395eb87..fa4030ee5a86 100644
--- a/python/sglang/srt/layers/quantization/mxfp4_deepseek.py
+++ b/python/sglang/srt/layers/quantization/mxfp4_deepseek.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
+import dataclasses
 import logging
 from typing import TYPE_CHECKING
 
@@ -143,6 +144,11 @@ def create_moe_runner(self, layer, moe_runner_config):
             from sglang.srt.layers.moe.moe_runner import MoeRunner
 
             self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config)
+        elif self.moe_runner_backend.is_humming():
+            from sglang.srt.layers.moe.moe_runner import MoeRunner
+
+            moe_runner_config = dataclasses.replace(moe_runner_config, layer=layer)
+            self.runner = MoeRunner(MoeRunnerBackend.HUMMING, moe_runner_config)
 
         swiglu_limit = moe_runner_config.swiglu_limit
         is_2604b = envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B"
@@ -233,7 +239,35 @@ def process_weights_after_loading(self, layer: Module) -> None:
         if getattr(layer, "_mega_moe_weights_built", False):
             return
 
-        if self.moe_runner_backend.is_marlin():
+        if self.moe_runner_backend.is_humming():
+            from sglang.srt.layers.quantization.humming_utils import (
+                prepare_humming_moe_layer,
+            )
+
+            log_info_on_rank0(
+                logger,
+                f"Preparing DeepSeekV4 MXFP4 experts for Humming backend (layer: {self.prefix})...",
+            )
+
+            layer.register_parameter(
+                "w13_weight_scale",
+                Parameter(
+                    layer.w13_weight_scale_inv.to(torch.float8_e8m0fnu),
+                    requires_grad=False,
+                ),
+            )
+            layer.register_parameter(
+                "w2_weight_scale",
+                Parameter(
+                    layer.w2_weight_scale_inv.to(torch.float8_e8m0fnu),
+                    requires_grad=False,
+                ),
+            )
+
+            del layer.w13_weight_scale_inv, layer.w2_weight_scale_inv
+            prepare_humming_moe_layer(layer, {"quant_method": "mxfp4"})
+            return
+        elif self.moe_runner_backend.is_marlin():
             from sglang.srt.layers.quantization.marlin_utils import (
                 check_moe_marlin_supports_layer,
             )
@@ -380,7 +414,12 @@ def apply(
         layer: Module,
         dispatch_output: DispatchOutput,
     ) -> CombineInput:
-        if self.moe_runner_backend.is_marlin():
+        if self.moe_runner_backend.is_humming():
+            from sglang.srt.layers.moe.moe_runner.humming import HummingMoeQuantInfo
+
+            quant_info = HummingMoeQuantInfo()
+            return self.runner.run(dispatch_output, quant_info)
+        elif self.moe_runner_backend.is_marlin():
             from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
             from sglang.srt.layers.moe.topk import TopKOutputChecker
 
@@ -536,4 +575,3 @@ def apply(
             output.mul_(rsf)
 
         return StandardCombineInput(hidden_states=output)
-
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index d48dab8757ab..b6ce51a1b425 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -112,6 +112,8 @@
     "compressed-tensors",  # for Ktransformers
     "modelslim",  # for NPU
     "quark_int4fp8_moe",
+    "unquant",
+    "humming",
 ]
 
 SPECULATIVE_DRAFT_MODEL_QUANTIZATION_CHOICES = [*QUANTIZATION_CHOICES, "unquant"]
@@ -182,6 +184,7 @@
     "flashinfer_cutedsl",
     "cutlass",
     "marlin",
+    "humming",
 ]
 
 MOE_A2A_BACKEND_CHOICES = ["none", "deepep", "mooncake", "ascend_fuseep", "flashinfer"]
diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 984ea74377a5..60a9e9d23900 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -296,6 +296,7 @@ set(SOURCES
   "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
   "csrc/moe/marlin_moe_wna16/ops.cu"
   "csrc/moe/moe_align_kernel.cu"
+  "csrc/moe/moe_permute_prepare.cu"
   "csrc/moe/moe_fused_gate.cu"
   "csrc/moe/fused_qknorm_rope_kernel.cu"
  "csrc/moe/kimi_k2_moe_fused_gate.cu"
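The MXFP4 path above reinterprets the stored inverse scales as `torch.float8_e8m0fnu`, the exponent-only 8-bit dtype used for MX block scales, before handing the layer to `prepare_humming_moe_layer`. A minimal sketch of the cast — this assumes a PyTorch build that ships `float8_e8m0fnu` and supports casting through it:

```python
import torch

scales = torch.tensor([0.25, 1.0, 8.0], dtype=torch.float32)
# e8m0 stores only a power-of-two exponent, so these values round-trip exactly.
e8m0 = scales.to(torch.float8_e8m0fnu)
print(e8m0.to(torch.float32))  # tensor([0.2500, 1.0000, 8.0000])
```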
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 8274896795a5..c1804f8a5eb9 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -224,9 +224,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
       "experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
-      "pad_sorted_token_ids) -> ()");
+      "pad_sorted_token_ids, bool ignore_invalid_expert) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
 
+  m.def("moe_permute_prepare(Tensor topk_ids, int num_experts, bool use_int64, bool is_ep) -> Tensor[]");
+  m.impl("moe_permute_prepare", torch::kCUDA, &moe_permute_prepare);
+
   m.def(
       "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize, float "
       "moe_softcapping, Tensor? correction_bias) -> ()");
diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu
index b7bad43da6c0..8e7e7b61c55e 100644
--- a/sgl-kernel/csrc/moe/moe_align_kernel.cu
+++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu
@@ -29,11 +29,13 @@ __global__ void count_and_sort_expert_tokens_kernel(
     const scalar_t* __restrict__ topk_ids,
     int32_t* __restrict__ sorted_token_ids,
     int32_t* __restrict__ cumsum_buffer,
-    size_t numel) {
+    size_t numel,
+    bool ignore_invalid_expert) {
   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t stride = blockDim.x * gridDim.x;
 
   for (size_t i = tid; i < numel; i += stride) {
+    if (ignore_invalid_expert && topk_ids[i] < 0) continue;
     int32_t expert_id = topk_ids[i] + 1;
     int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
     sorted_token_ids[rank_post_pad] = i;
@@ -63,6 +65,7 @@ __global__ void moe_align_block_size_kernel(
     size_t numel,
     int32_t* __restrict__ cumsum,
     bool pad_sorted_token_ids,
+    bool ignore_invalid_expert,
    const int32_t scan_size,
     int32_t max_num_tokens_padded) {
   // Use a separate thread block to populate sorted_token_ids
@@ -95,6 +98,7 @@ __global__ void moe_align_block_size_kernel(
   __syncthreads();
 
   for (size_t i = tid; i < numel; i += stride) {
+    if (ignore_invalid_expert && topk_ids[i] < 0) continue;
     int expert_id = topk_ids[i] + 1;
     atomicAdd(&shared_counts[expert_id], 1);
   }
@@ -211,6 +215,7 @@ __global__ void moe_align_block_size_kernel(
   // Write prefix[0..num_experts - 1] and cumsum
   if (tid < num_experts) prefix[tid] = scan_buf[tid];
 #endif
+  __syncthreads();
 
   if (tid <= num_experts) {
     cumsum[tid] = prefix[tid];
@@ -242,6 +247,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
     int32_t block_size,
     size_t numel,
     bool pad_sorted_token_ids,
+    bool ignore_invalid_expert,
     int32_t max_num_tokens_padded) {
   // Adapted from
   // https://github.com/vllm-project/vllm/pull/29642/files#diff-5647b1413f4ae9aacba904eca8f8a8aee9079321eadff4c10101a2c6962dcc53R226
@@ -275,6 +281,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
   }
 
   for (size_t i = tid; i < numel; i += stride) {
+    if (ignore_invalid_expert && topk_ids[i] < 0) continue;
     int32_t expert_id = topk_ids[i] + 1;
     ++tokens_cnts[(tid + 1) * num_experts + expert_id];
   }
@@ -307,6 +314,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
   }
 
   for (size_t i = tid; i < numel; i += stride) {
+    if (ignore_invalid_expert && topk_ids[i] < 0) continue;
     int32_t expert_id = topk_ids[i] + 1;
     int32_t rank_post_pad = tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id];
     sorted_token_ids[rank_post_pad] = i;
@@ -322,7 +330,8 @@ void moe_align_block_size(
     torch::Tensor experts_ids,
     torch::Tensor num_tokens_post_pad,
     torch::Tensor cumsum_buffer,
-    bool pad_sorted_token_ids) {
+    bool pad_sorted_token_ids,
+    bool ignore_invalid_expert) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   int threads = 1024;
@@ -348,6 +357,7 @@ void moe_align_block_size(
         block_size,
         topk_ids.numel(),
         pad_sorted_token_ids,
+        ignore_invalid_expert,
         max_num_tokens_padded);
   } else {
     auto align_kernel = moe_align_block_size_kernel<scalar_t>;
@@ -364,6 +374,7 @@ void moe_align_block_size(
         topk_ids.numel(),
         cumsum_buffer.data_ptr<int32_t>(),
         pad_sorted_token_ids,
+        ignore_invalid_expert,
         scan_size,
         max_num_tokens_padded);
 
@@ -377,7 +388,8 @@ void moe_align_block_size(
           topk_ids.data_ptr<scalar_t>(),
           sorted_token_ids.data_ptr<int32_t>(),
           cumsum_buffer.data_ptr<int32_t>(),
-          topk_ids.numel());
+          topk_ids.numel(),
+          ignore_invalid_expert);
     }
   });
 }
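The new `ignore_invalid_expert` flag matters under expert parallelism, where `topk_ids` holds negative ids for tokens routed to experts on other ranks; without the guard, a `-1` id would be counted into the `expert_id + 1 == 0` sentinel bucket. The guarded counting is equivalent to:

```python
import torch

topk_ids = torch.tensor([2, -1, 0, 1, -1, 2])  # -1: expert owned by another EP rank
valid = topk_ids[topk_ids >= 0]                # entries counted when the flag is set
print(torch.bincount(valid, minlength=3))      # tensor([1, 1, 2])
```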
diff --git a/sgl-kernel/csrc/moe/moe_permute_prepare.cu b/sgl-kernel/csrc/moe/moe_permute_prepare.cu
new file mode 100644
index 000000000000..ae9f2088f612
--- /dev/null
+++ b/sgl-kernel/csrc/moe/moe_permute_prepare.cu
@@ -0,0 +1,110 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>
+
+#include "utils.h"
+
+// Binary search: find first index where sorted_topk_ids[index] >= target
+__device__ __forceinline__ int32_t lower_bound(const int32_t* __restrict__ data, int32_t n, int32_t target) {
+  int32_t lo = 0, hi = n;
+  while (lo < hi) {
+    int32_t mid = lo + (hi - lo) / 2;
+    if (data[mid] < target) {
+      lo = mid + 1;
+    } else {
+      hi = mid;
+    }
+  }
+  return lo;
+}
+
+// All blocks cooperate on both expert_offsets (via binary search) and src2dst.
+// No shared memory, no atomics.
+__global__ void moe_permute_prepare_kernel(
+    const int32_t* __restrict__ sorted_topk_ids,
+    const int64_t* __restrict__ reorder_ids,
+    void* __restrict__ expert_offsets,
+    int32_t* __restrict__ src2dst,
+    int32_t num_experts,
+    int32_t numel,
+    bool use_int64_offset,
+    bool is_ep) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = gridDim.x * blockDim.x;
+  int32_t neg_count = 0;
+  if (is_ep) neg_count = lower_bound(sorted_topk_ids, numel, 0);
+
+  // Compute expert_offsets via binary search
+  for (int e = tid; e <= num_experts; e += stride) {
+    int32_t offset;
+    if (e < num_experts) {
+      offset = lower_bound(sorted_topk_ids, numel, e) - neg_count;
+    } else {
+      offset = numel - neg_count;
+    }
+
+    if (use_int64_offset) {
+      reinterpret_cast<int64_t*>(expert_offsets)[e] = static_cast<int64_t>(offset);
+    } else {
+      reinterpret_cast<int32_t*>(expert_offsets)[e] = offset;
+    }
+  }
+
+  // Compute src2dst; when is_ep, entries with a negative expert id land at
+  // negative dst positions and are expected to be ignored downstream.
+  for (int i = tid; i < numel; i += stride) {
+    src2dst[reorder_ids[i]] = i - neg_count;
+  }
+}
+
+std::vector<torch::Tensor> moe_permute_prepare(
+    torch::Tensor topk_ids,
+    int64_t num_experts,
+    bool use_int64_offset,
+    bool is_ep) {
+  TORCH_CHECK(topk_ids.is_cuda(), "topk_ids must be a CUDA tensor");
+  TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, "topk_ids must be int32");
+
+  const auto device = topk_ids.device();
+  const at::cuda::CUDAGuard guard(device);
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  int32_t numel = topk_ids.numel();
+
+  // Sort using torch::sort
+  auto [sorted_topk_ids, reorder_ids] = torch::sort(topk_ids.flatten());
+
+  auto out_dtype = use_int64_offset ? at::ScalarType::Long : at::ScalarType::Int;
+  auto expert_offsets = torch::empty({num_experts + 1}, topk_ids.options().dtype(out_dtype));
+  auto src2dst = torch::empty({numel}, topk_ids.options());
+
+  const int threads = 256;
+  int num_blocks = std::max(1, (std::max((int)numel, (int)num_experts + 1) + threads - 1) / threads);
+
+  moe_permute_prepare_kernel<<<num_blocks, threads, 0, stream>>>(
+      sorted_topk_ids.data_ptr<int32_t>(),
+      reorder_ids.data_ptr<int64_t>(),
+      expert_offsets.data_ptr(),
+      src2dst.data_ptr<int32_t>(),
+      (int32_t)num_experts,
+      numel,
+      use_int64_offset,
+      is_ep);
+
+  return {expert_offsets, src2dst};
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 5e3cf24f9036..32a8378194bb 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -307,7 +307,14 @@ void moe_align_block_size(
     torch::Tensor experts_ids,
     torch::Tensor num_tokens_post_pad,
     torch::Tensor cumsum_buffer,
-    bool pad_sorted_token_ids);
+    bool pad_sorted_token_ids,
+    bool ignore_invalid_expert);
+
+std::vector<torch::Tensor> moe_permute_prepare(
+    torch::Tensor topk_ids,
+    int64_t num_experts,
+    bool use_int64,
+    bool is_ep);
 
 void topk_softmax(
     torch::Tensor& topk_weights,
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 1b97ef94f02b..0c9c06b6dc2f 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -93,6 +93,7 @@
     fused_qk_norm_rope,
     kimi_k2_moe_fused_gate,
     moe_align_block_size,
+    moe_permute_prepare,
     moe_fused_gate,
     moe_sum,
     moe_sum_reduce,
diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py
index d85e4b602751..b1c67c59a3c7 100755
--- a/sgl-kernel/python/sgl_kernel/moe.py
+++ b/sgl-kernel/python/sgl_kernel/moe.py
@@ -12,6 +12,7 @@ def moe_align_block_size(
     num_tokens_post_pad,
     cumsum_buffer,
     pad_sorted_token_ids=False,
+    ignore_invalid_expert=False,
 ):
     torch.ops.sgl_kernel.moe_align_block_size.default(
         topk_ids,
@@ -22,6 +23,21 @@ def moe_align_block_size(
         num_tokens_post_pad,
         cumsum_buffer,
         pad_sorted_token_ids,
+        ignore_invalid_expert,
+    )
+
+
+def moe_permute_prepare(
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    use_int64_offset: bool = False,
+    is_ep: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.ops.sgl_kernel.moe_permute_prepare.default(
+        topk_ids,
+        num_experts,
+        use_int64_offset,
+        is_ep,
     )
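For reference (and as a handy oracle in unit tests), the kernel's outputs can be reproduced on CPU with plain `torch` ops: sort the flattened ids, take CSR-style offsets per expert shifted past the negative entries when `is_ep`, and scatter destination slots through the sort permutation. The function name and test values below are illustrative:

```python
import torch


def moe_permute_prepare_ref(topk_ids, num_experts, is_ep=False):
    """CPU reference for moe_permute_prepare (see kernel above)."""
    flat = topk_ids.flatten().to(torch.int32)
    sorted_ids, reorder_ids = torch.sort(flat)
    neg_count = int((sorted_ids < 0).sum()) if is_ep else 0

    # expert_offsets[e] = first slot of expert e among the valid entries;
    # expert_offsets[num_experts] = total number of valid entries.
    targets = torch.arange(num_experts + 1, dtype=torch.int32)
    expert_offsets = torch.searchsorted(sorted_ids, targets) - neg_count
    expert_offsets[num_experts] = flat.numel() - neg_count

    # src2dst[src] = permuted position; negative for dropped (is_ep) entries.
    src2dst = torch.empty_like(flat)
    src2dst[reorder_ids] = torch.arange(flat.numel(), dtype=torch.int32) - neg_count
    return expert_offsets, src2dst


ids = torch.tensor([[2, 0], [1, 2], [-1, 0]], dtype=torch.int32)
offsets, src2dst = moe_permute_prepare_ref(ids, num_experts=3, is_ep=True)
print(offsets)  # tensor([0, 2, 3, 5])
print(src2dst)  # tensor([ 3,  0,  2,  4, -1,  1], dtype=torch.int32)
```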