diff --git a/python/sglang/srt/layers/moe/moe_runner/marlin.py b/python/sglang/srt/layers/moe/moe_runner/marlin.py
new file mode 100644
index 000000000000..45104dd27805
--- /dev/null
+++ b/python/sglang/srt/layers/moe/moe_runner/marlin.py
@@ -0,0 +1,157 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    MoeQuantInfo,
+    MoeRunnerConfig,
+    RunnerInput,
+    RunnerOutput,
+    register_fused_func,
+)
+from sglang.srt.layers.moe.utils import MoeRunnerBackend
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher import (
+        StandardCombineInput,
+        StandardDispatchOutput,
+    )
+
+# Workspace tensor handed to ``fused_marlin_moe``. Built on first use by
+# ``fused_experts_none_to_marlin`` and rebuilt when the input device differs.
+MARLIN_MOE_WORKSPACE: Optional[torch.Tensor] = None
+
+
+@dataclass
+class MarlinRunnerInput(RunnerInput):
+    """Input bundle passed to the Marlin runner core."""
+
+    hidden_states: torch.Tensor
+    topk_weights: torch.Tensor
+    topk_ids: torch.Tensor
+    router_logits: torch.Tensor
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        """Return the backend tag for this bundle (always MARLIN)."""
+        return MoeRunnerBackend.MARLIN
+
+
+@dataclass
+class MarlinRunnerOutput(RunnerOutput):
+    """Output bundle returned from the Marlin runner core."""
+
+    hidden_states: torch.Tensor
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        """Return the backend tag for this bundle (always MARLIN)."""
+        return MoeRunnerBackend.MARLIN
+
+
+@dataclass
+class MarlinMoeQuantInfo(MoeQuantInfo):
+    """Quantization payload consumed by the Marlin backend.
+
+    Every field is forwarded unchanged to ``fused_marlin_moe`` by
+    ``fused_experts_none_to_marlin``.
+    """
+
+    w13_qweight: torch.Tensor
+    w2_qweight: torch.Tensor
+    w13_scales: torch.Tensor
+    w2_scales: torch.Tensor
+    w13_g_idx_sort_indices: Optional[torch.Tensor]
+    w2_g_idx_sort_indices: Optional[torch.Tensor]
+    weight_bits: int
+
+    # GPTQ specific (Optional)
+    w13_g_idx: Optional[torch.Tensor] = None
+    w2_g_idx: Optional[torch.Tensor] = None
+    is_k_full: bool = True
+
+    # AWQ specific (Optional)
+    w13_qzeros: Optional[torch.Tensor] = None
+    w2_qzeros: Optional[torch.Tensor] = None
+
+    # Optional
+    expert_map: Optional[torch.Tensor] = None
+
+
+@register_fused_func("none", "marlin")
+def fused_experts_none_to_marlin(
+    dispatch_output: StandardDispatchOutput,
+    quant_info: MarlinMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+) -> StandardCombineInput:
+    """Run the fused Marlin MoE kernel on a standard dispatch output.
+
+    Args:
+        dispatch_output: Supplies ``hidden_states`` and ``topk_output``
+            (``router_logits``, ``topk_weights`` and ``topk_ids`` are read).
+        quant_info: Quantized weights and metadata forwarded to the kernel.
+        runner_config: Supplies ``activation``, ``inplace`` and
+            ``routed_scaling_factor``.
+
+    Returns:
+        A ``StandardCombineInput`` wrapping the kernel output converted to
+        the dtype of the incoming ``hidden_states``.
+
+    Raises:
+        ValueError: If ``runner_config.activation`` is not ``"silu"``.
+    """
+    global MARLIN_MOE_WORKSPACE
+    from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import fused_marlin_moe
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+    from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
+
+    hidden_states = dispatch_output.hidden_states
+    topk_output = dispatch_output.topk_output
+
+    # Raise explicitly rather than ``assert``: asserts are removed under
+    # ``python -O``, which would let an unsupported activation reach the kernel.
+    if runner_config.activation != "silu":
+        raise ValueError("Only SiLU activation is supported.")
+
+    # NOTE(review): the GPTQ ``apply`` body that this function replaces cast
+    # the activations to float16 before calling the kernel; no cast is done
+    # here. Confirm ``fused_marlin_moe`` accepts the dtype of ``hidden_states``.
+
+    # Build the shared workspace on first use and again after a device change.
+    if (
+        MARLIN_MOE_WORKSPACE is None
+        or MARLIN_MOE_WORKSPACE.device != hidden_states.device
+    ):
+        MARLIN_MOE_WORKSPACE = marlin_make_workspace(
+            hidden_states.device, max_blocks_per_sm=4
+        )
+
+    output = fused_marlin_moe(
+        hidden_states=hidden_states,
+        w1=quant_info.w13_qweight,
+        w2=quant_info.w2_qweight,
+        w1_scale=quant_info.w13_scales,
+        w2_scale=quant_info.w2_scales,
+        gating_output=topk_output.router_logits,
+        topk_weights=topk_output.topk_weights,
+        topk_ids=topk_output.topk_ids,
+        expert_map=quant_info.expert_map,
+        g_idx1=quant_info.w13_g_idx,
+        g_idx2=quant_info.w2_g_idx,
+        sort_indices1=quant_info.w13_g_idx_sort_indices,
+        sort_indices2=quant_info.w2_g_idx_sort_indices,
+        w1_zeros=quant_info.w13_qzeros,
+        w2_zeros=quant_info.w2_qzeros,
+        workspace=MARLIN_MOE_WORKSPACE,
+        num_bits=quant_info.weight_bits,
+        is_k_full=quant_info.is_k_full,
+        inplace=runner_config.inplace,
+        routed_scaling_factor=runner_config.routed_scaling_factor,
+    ).to(hidden_states.dtype)
+
+    return StandardCombineInput(
+        hidden_states=output,
+    )
diff --git a/python/sglang/srt/layers/moe/moe_runner/runner.py b/python/sglang/srt/layers/moe/moe_runner/runner.py
index fa0fd2559ed6..fde68df941ae 100644
--- a/python/sglang/srt/layers/moe/moe_runner/runner.py
+++ 
b/python/sglang/srt/layers/moe/moe_runner/runner.py @@ -37,6 +37,8 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig): self.runner_core = TritonKernelsRunnerCore(config) elif runner_backend.is_deep_gemm(): self.runner_core = DeepGemmRunnerCore(config) + elif runner_backend.is_marlin(): + self.runner_core = None # Marlin only supports fused path else: raise NotImplementedError(f"Unsupported runner backend: {runner_backend}") diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index cd85fc2f2656..a220318a9258 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -59,6 +59,7 @@ class MoeRunnerBackend(Enum): FLASHINFER_MXFP4 = "flashinfer_mxfp4" FLASHINFER_CUTEDSL = "flashinfer_cutedsl" CUTLASS = "cutlass" + MARLIN = "marlin" def is_auto(self): return self == MoeRunnerBackend.AUTO @@ -87,6 +88,9 @@ def is_flashinfer_mxfp4(self): def is_cutlass(self): return self == MoeRunnerBackend.CUTLASS + def is_marlin(self): + return self == MoeRunnerBackend.MARLIN + class DeepEPMode(Enum): diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index ab3105ec85e8..5497900a0ce3 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -11,6 +11,13 @@ npu_fused_experts, ) from sglang.srt.layers.linear import LinearBase, set_weight_attrs +from sglang.srt.layers.moe import ( + MoeRunner, + MoeRunnerBackend, + MoeRunnerConfig, + get_moe_runner_backend, +) +from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, @@ -37,10 +44,9 @@ from sglang.srt.utils.patch_torch import register_fake_if_exists if TYPE_CHECKING: - from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from 
sglang.srt.layers.moe.token_dispatcher import ( - StandardDispatchOutput, CombineInput, + StandardDispatchOutput, ) from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu @@ -753,10 +759,6 @@ def create_weights( layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) - device = layer.w13_qweight.device - if not _is_npu: - layer.workspace = marlin_make_workspace(device, 4) - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts = layer.w13_qweight.shape[0] device = layer.w13_qweight.device @@ -825,44 +827,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): + assert get_moe_runner_backend().is_auto() self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config) def apply( self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput, ) -> CombineInput: - from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import ( - fused_marlin_moe, - ) - from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput - assert ( - self.moe_runner_config.activation == "silu" - ), "Only SiLU activation is supported." 
- - x = dispatch_output.hidden_states - topk_output = dispatch_output.topk_output - orig_dtype = x.dtype - - topk_weights, topk_ids, router_logits = topk_output + quant_info = MarlinMoeQuantInfo( + w13_qweight=layer.w13_qweight, + w2_qweight=layer.w2_qweight, + w13_scales=layer.w13_scales, + w2_scales=layer.w2_scales, + w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices, + w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices, + w13_qzeros=layer.w13_qzeros, + w2_qzeros=layer.w2_qzeros, + weight_bits=self.quant_config.weight_bits, + ) - output = fused_marlin_moe( - x, - layer.w13_qweight, - layer.w2_qweight, - layer.w13_scales, - layer.w2_scales, - router_logits, - topk_weights, - topk_ids, - sort_indices1=layer.w13_g_idx_sort_indices, - sort_indices2=layer.w2_g_idx_sort_indices, - w1_zeros=layer.w13_qzeros, - w2_zeros=layer.w2_qzeros, - num_bits=self.quant_config.weight_bits, - ).to(orig_dtype) - return StandardCombineInput(hidden_states=output) + return self.runner.run(dispatch_output, quant_info) class AWQMoEAscendMethod(AWQMoEMethod): diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index 9d52bf30cc5d..74f64174ce2d 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -7,6 +7,13 @@ import torch +from sglang.srt.layers.moe import ( + MoeRunner, + MoeRunnerBackend, + MoeRunnerConfig, + get_moe_runner_backend, +) +from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo from sglang.srt.layers.parameter import ( BasevLLMParameter, ChannelQuantScaleParameter, @@ -46,7 +53,6 @@ from sglang.srt.utils.patch_torch import register_fake_if_exists if TYPE_CHECKING: - from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.token_dispatcher import ( CombineInput, StandardDispatchOutput, @@ -1052,48 +1058,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def create_moe_runner( self, layer: 
torch.nn.Module, moe_runner_config: MoeRunnerConfig ): + assert get_moe_runner_backend().is_auto() self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config) def apply( self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput, ) -> CombineInput: - from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import ( - fused_marlin_moe, + quant_info = MarlinMoeQuantInfo( + w13_qweight=layer.w13_qweight, + w2_qweight=layer.w2_qweight, + w13_scales=layer.w13_scales, + w2_scales=layer.w2_scales, + w13_g_idx=layer.w13_g_idx, + w2_g_idx=layer.w2_g_idx, + w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices, + w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices, + weight_bits=self.quant_config.weight_bits, + is_k_full=self.is_k_full, ) - from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput - - x = dispatch_output.hidden_states - topk_output = dispatch_output.topk_output - - assert ( - self.moe_runner_config.activation == "silu" - ), "Only SiLU activation is supported." 
- - # The input must currently be float16 - orig_dtype = x.dtype - x = x.half() - topk_weights, topk_ids, router_logits = topk_output - - output = fused_marlin_moe( - x, - layer.w13_qweight, - layer.w2_qweight, - layer.w13_scales, - layer.w2_scales, - router_logits, - topk_weights, - topk_ids, - g_idx1=layer.w13_g_idx, - g_idx2=layer.w2_g_idx, - sort_indices1=layer.w13_g_idx_sort_indices, - sort_indices2=layer.w2_g_idx_sort_indices, - num_bits=self.quant_config.weight_bits, - is_k_full=self.is_k_full, - ).to(orig_dtype) - return StandardCombineInput(hidden_states=output) + return self.runner.run(dispatch_output, quant_info) # Register fake implementations for torch.compile support diff --git a/test/srt/quant/test_awq.py b/test/srt/quant/test_awq.py index 63254935b363..7eb936284660 100644 --- a/test/srt/quant/test_awq.py +++ b/test/srt/quant/test_awq.py @@ -74,5 +74,38 @@ def test_mmlu(self): self.assertGreater(metrics["score"], 0.85) +class TestAWQMarlinFloat16(CustomTestCase): + """ + Verify that the model can be loaded with float16 dtype and awq_marlin quantization + """ + + @classmethod + def setUpClass(cls): + cls.model = "QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--dtype", "float16", "--quantization", "awq_marlin"], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.85) + + if __name__ == "__main__": unittest.main()