Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions python/sglang/srt/layers/moe/moe_runner/marlin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.layers.moe.moe_runner.base import (
MoeQuantInfo,
MoeRunnerConfig,
RunnerInput,
RunnerOutput,
register_fused_func,
)
from sglang.srt.layers.moe.utils import MoeRunnerBackend

if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
StandardCombineInput,
StandardDispatchOutput,
)

MARLIN_MOE_WORKSPACE: Optional[torch.Tensor] = None


@dataclass
class MarlinRunnerInput(RunnerInput):
    """Input bundle passed to the Marlin runner core."""

    # Activations for the current batch of tokens.
    hidden_states: torch.Tensor
    # Per-token weights for the selected (top-k) experts.
    topk_weights: torch.Tensor
    # Per-token ids of the selected (top-k) experts.
    topk_ids: torch.Tensor
    # Raw router logits; the Marlin fused path passes these straight to
    # the kernel as ``gating_output``.
    router_logits: torch.Tensor

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        # Tags this input as belonging to the Marlin backend.
        return MoeRunnerBackend.MARLIN


@dataclass
class MarlinRunnerOutput(RunnerOutput):
    """Output bundle returned from the Marlin runner core."""

    # Resulting hidden states from the Marlin MoE computation.
    hidden_states: torch.Tensor

    @property
    def runner_backend(self) -> MoeRunnerBackend:
        # Tags this output as produced by the Marlin backend.
        return MoeRunnerBackend.MARLIN


@dataclass
class MarlinMoeQuantInfo(MoeQuantInfo):
    """Quantization payload consumed by the Marlin backend.

    Bundles the packed weights, scales, and kernel metadata that
    ``fused_marlin_moe`` consumes. GPTQ- and AWQ-specific fields default
    to None so each quantization method fills in only what it has.
    """

    # Packed quantized expert weights. Naming follows the surrounding MoE
    # code: w13 is presumably the fused gate/up projection and w2 the down
    # projection — packing layout is defined by the Marlin kernels.
    w13_qweight: torch.Tensor
    w2_qweight: torch.Tensor
    # Dequantization scales matching the packed weights above.
    w13_scales: torch.Tensor
    w2_scales: torch.Tensor
    # Sort indices forwarded to the kernel as sort_indices1/2; may be None.
    w13_g_idx_sort_indices: Optional[torch.Tensor]
    w2_g_idx_sort_indices: Optional[torch.Tensor]
    # Bit width of the quantized weights (forwarded as ``num_bits``).
    weight_bits: int

    # GPTQ specific (Optional): group indices and full-K flag.
    w13_g_idx: Optional[torch.Tensor] = None
    w2_g_idx: Optional[torch.Tensor] = None
    is_k_full: bool = True

    # AWQ specific (Optional): quantization zero points.
    w13_qzeros: Optional[torch.Tensor] = None
    w2_qzeros: Optional[torch.Tensor] = None

    # Optional expert-id mapping forwarded to the kernel.
    # NOTE(review): semantics inferred from the kwarg name — confirm.
    expert_map: Optional[torch.Tensor] = None


@register_fused_func("none", "marlin")
def fused_experts_none_to_marlin(
    dispatch_output: StandardDispatchOutput,
    quant_info: MarlinMoeQuantInfo,
    runner_config: MoeRunnerConfig,
) -> StandardCombineInput:
    """Fused expert path: standard dispatch output -> Marlin MoE kernel.

    Lazily builds (and caches in the module-level ``MARLIN_MOE_WORKSPACE``)
    a Marlin workspace for the device of the incoming hidden states, runs
    ``fused_marlin_moe`` with the quantization payload, and wraps the
    result in a ``StandardCombineInput``. Only SiLU activation is accepted.
    """
    global MARLIN_MOE_WORKSPACE

    from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import fused_marlin_moe
    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
    from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace

    assert runner_config.activation == "silu", "Only SiLU activation is supported."

    x = dispatch_output.hidden_states
    topk = dispatch_output.topk_output

    # Reuse the cached workspace unless it lives on a different device.
    workspace = MARLIN_MOE_WORKSPACE
    if workspace is None or workspace.device != x.device:
        workspace = marlin_make_workspace(x.device, max_blocks_per_sm=4)
        MARLIN_MOE_WORKSPACE = workspace

    out = fused_marlin_moe(
        hidden_states=x,
        w1=quant_info.w13_qweight,
        w2=quant_info.w2_qweight,
        w1_scale=quant_info.w13_scales,
        w2_scale=quant_info.w2_scales,
        gating_output=topk.router_logits,
        topk_weights=topk.topk_weights,
        topk_ids=topk.topk_ids,
        expert_map=quant_info.expert_map,
        g_idx1=quant_info.w13_g_idx,
        g_idx2=quant_info.w2_g_idx,
        sort_indices1=quant_info.w13_g_idx_sort_indices,
        sort_indices2=quant_info.w2_g_idx_sort_indices,
        w1_zeros=quant_info.w13_qzeros,
        w2_zeros=quant_info.w2_qzeros,
        workspace=workspace,
        num_bits=quant_info.weight_bits,
        is_k_full=quant_info.is_k_full,
        inplace=runner_config.inplace,
        routed_scaling_factor=runner_config.routed_scaling_factor,
    )
    # The kernel may emit a different dtype; cast back to match the input.
    return StandardCombineInput(hidden_states=out.to(x.dtype))
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/moe/moe_runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
self.runner_core = TritonKernelsRunnerCore(config)
elif runner_backend.is_deep_gemm():
self.runner_core = DeepGemmRunnerCore(config)
elif runner_backend.is_marlin():
self.runner_core = None # Marlin only supports fused path
else:
raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")

Expand Down
4 changes: 4 additions & 0 deletions python/sglang/srt/layers/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class MoeRunnerBackend(Enum):
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
CUTLASS = "cutlass"
MARLIN = "marlin"

def is_auto(self):
return self == MoeRunnerBackend.AUTO
Expand Down Expand Up @@ -87,6 +88,9 @@ def is_flashinfer_mxfp4(self):
def is_cutlass(self):
return self == MoeRunnerBackend.CUTLASS

def is_marlin(self):
    # True when this backend is the MARLIN runner backend.
    return self == MoeRunnerBackend.MARLIN


class DeepEPMode(Enum):

Expand Down
57 changes: 22 additions & 35 deletions python/sglang/srt/layers/quantization/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
npu_fused_experts,
)
from sglang.srt.layers.linear import LinearBase, set_weight_attrs
from sglang.srt.layers.moe import (
MoeRunner,
MoeRunnerBackend,
MoeRunnerConfig,
get_moe_runner_backend,
)
from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
Expand All @@ -37,10 +44,9 @@
from sglang.srt.utils.patch_torch import register_fake_if_exists

if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.token_dispatcher import (
StandardDispatchOutput,
CombineInput,
StandardDispatchOutput,
)

from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu
Expand Down Expand Up @@ -753,10 +759,6 @@ def create_weights(
layer.register_parameter("w2_qzeros", w2_qzeros)
set_weight_attrs(w2_qzeros, extra_weight_attrs)

device = layer.w13_qweight.device
if not _is_npu:
layer.workspace = marlin_make_workspace(device, 4)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
device = layer.w13_qweight.device
Expand Down Expand Up @@ -825,44 +827,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def create_moe_runner(
    self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
    """Create the Marlin MoE runner used later by ``apply``.

    ``layer`` is not read by this method.
    """
    # NOTE(review): asserts the globally configured runner backend is
    # "auto" before forcing MARLIN here — presumably to reject an
    # explicit, conflicting backend choice; confirm intent.
    assert get_moe_runner_backend().is_auto()
    self.moe_runner_config = moe_runner_config
    self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config)

def apply(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import (
fused_marlin_moe,
)
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

assert (
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."

x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
orig_dtype = x.dtype

topk_weights, topk_ids, router_logits = topk_output
quant_info = MarlinMoeQuantInfo(
w13_qweight=layer.w13_qweight,
w2_qweight=layer.w2_qweight,
w13_scales=layer.w13_scales,
w2_scales=layer.w2_scales,
w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
w13_qzeros=layer.w13_qzeros,
w2_qzeros=layer.w2_qzeros,
weight_bits=self.quant_config.weight_bits,
)

output = fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
layer.w13_scales,
layer.w2_scales,
router_logits,
topk_weights,
topk_ids,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros,
num_bits=self.quant_config.weight_bits,
).to(orig_dtype)
return StandardCombineInput(hidden_states=output)
return self.runner.run(dispatch_output, quant_info)


class AWQMoEAscendMethod(AWQMoEMethod):
Expand Down
55 changes: 21 additions & 34 deletions python/sglang/srt/layers/quantization/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@

import torch

from sglang.srt.layers.moe import (
MoeRunner,
MoeRunnerBackend,
MoeRunnerConfig,
get_moe_runner_backend,
)
from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo
from sglang.srt.layers.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
Expand Down Expand Up @@ -46,7 +53,6 @@
from sglang.srt.utils.patch_torch import register_fake_if_exists

if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.token_dispatcher import (
CombineInput,
StandardDispatchOutput,
Expand Down Expand Up @@ -1052,48 +1058,29 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def create_moe_runner(
    self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
    """Create the Marlin MoE runner used later by ``apply``.

    ``layer`` is not read by this method.
    """
    # NOTE(review): asserts the globally configured runner backend is
    # "auto" before forcing MARLIN here — presumably to reject an
    # explicit, conflicting backend choice; confirm intent.
    assert get_moe_runner_backend().is_auto()
    self.moe_runner_config = moe_runner_config
    self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config)

def apply(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> CombineInput:
from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import (
fused_marlin_moe,
quant_info = MarlinMoeQuantInfo(
w13_qweight=layer.w13_qweight,
w2_qweight=layer.w2_qweight,
w13_scales=layer.w13_scales,
w2_scales=layer.w2_scales,
w13_g_idx=layer.w13_g_idx,
w2_g_idx=layer.w2_g_idx,
w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
weight_bits=self.quant_config.weight_bits,
is_k_full=self.is_k_full,
)
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output

assert (
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."

# The input must currently be float16
orig_dtype = x.dtype
x = x.half()

topk_weights, topk_ids, router_logits = topk_output

output = fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,
layer.w13_scales,
layer.w2_scales,
router_logits,
topk_weights,
topk_ids,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.quant_config.weight_bits,
is_k_full=self.is_k_full,
).to(orig_dtype)
return StandardCombineInput(hidden_states=output)
return self.runner.run(dispatch_output, quant_info)


# Register fake implementations for torch.compile support
Expand Down
33 changes: 33 additions & 0 deletions test/srt/quant/test_awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,38 @@ def test_mmlu(self):
self.assertGreater(metrics["score"], 0.85)


class TestAWQMarlinFloat16(CustomTestCase):
    """Smoke-test awq_marlin quantization with an explicit float16 dtype.

    Launches the server once per class and scores a small MMLU slice.
    """

    @classmethod
    def setUpClass(cls):
        # Launch a server forcing float16 + awq_marlin quantization.
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.model = "QuantTrio/Qwen3-VL-30B-A3B-Instruct-AWQ"
        server_args = ["--dtype", "float16", "--quantization", "awq_marlin"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=server_args,
        )

    @classmethod
    def tearDownClass(cls):
        # Tear down the server process tree started in setUpClass.
        kill_process_tree(cls.process.pid)

    def test_mmlu(self):
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(eval_args)
        self.assertGreater(metrics["score"], 0.85)


# Allow running this test module directly (outside a test runner).
if __name__ == "__main__":
    unittest.main()
Loading