Merged
38 commits
6c29a1d  Refactor AWQ into schemes and backend kernels (Alisehen, Mar 21, 2026)
0618e35  bug fix (Alisehen, Mar 21, 2026)
90b4283  reduce extra method (Alisehen, Mar 22, 2026)
cdde14e  Merge branch 'sgl-project:main' into hyc/awq-scheme-refactor (Alisehen, Mar 22, 2026)
cf8bfeb  Merge remote-tracking branch 'origin/main' into hyc/awq-scheme-refactor (Alisehen, Mar 22, 2026)
571b834  bug fix (Alisehen, Mar 22, 2026)
9d89014  Merge branch 'main' into hyc/awq-scheme-refactor (Alisehen, Mar 22, 2026)
0abe523  Merge branch 'main' of https://github.com/sgl-project/sglang into hyc… (Alisehen, Apr 9, 2026)
2fd568d  follow the reviews (Alisehen, Apr 9, 2026)
7fd5b63  Merge branch 'main' into hyc/awq-scheme-refactor (Alisehen, Apr 9, 2026)
ffa6ed5  fix ci (Alisehen, Apr 15, 2026)
c6bde30  Merge branch 'hyc/awq-scheme-refactor' of https://github.com/Alisehen… (Alisehen, Apr 15, 2026)
1cfb42f  Merge branch 'main' into hyc/awq-scheme-refactor (Alisehen, Apr 15, 2026)
4ee2dda  fix (Alisehen, Apr 15, 2026)
ab79808  fix (Alisehen, Apr 15, 2026)
f37729c  fix (Alisehen, Apr 15, 2026)
0ed8dfa  Merge branch 'main' into hyc/awq-scheme-refactor (Alisehen, Apr 15, 2026)
336a3cf  Merge branch 'main' of https://github.com/sgl-project/sglang into hyc… (Alisehen, Apr 26, 2026)
b1bf125  Fix AWQ scheme import ordering (Alisehen, Apr 26, 2026)
27b39bc  Refactor AWQ CPU path into schemes (Alisehen, Apr 27, 2026)
c98b0f6  Merge branch 'main' into hyc/awq-scheme-refactor (Alisehen, Apr 27, 2026)
37b08bb  Merge remote-tracking branch 'origin/hyc/awq-scheme-refactor' into hy… (Alisehen, Apr 27, 2026)
ff7c24f  Format AWQ marlin fallback (Alisehen, Apr 27, 2026)
9751bab  Merge branch 'main' into hyc/awq-scheme-refactor (Alisehen, Apr 27, 2026)
178c91a  Address AWQ scheme review feedback (Alisehen, Apr 28, 2026)
b42cde5  Merge remote-tracking branch 'origin/hyc/awq-scheme-refactor' into hy… (Alisehen, Apr 28, 2026)
4350cb6  Fix AWQ dequant test import (Alisehen, Apr 28, 2026)
e04e8f6  Merge branch 'main' into hyc/awq-scheme-refactor (Alisehen, Apr 28, 2026)
825b704  bug fix (Alisehen, Apr 28, 2026)
aee44f8  Merge branch 'main' into hyc/awq-scheme-refactor (Alisehen, Apr 28, 2026)
c81189d  Merge branch 'main' into hyc/awq-scheme-refactor (Alisehen, Apr 28, 2026)
0fdcebb  Merge branch 'main' into hyc/awq-scheme-refactor (Alisehen, Apr 29, 2026)
e93958a  Fix AWQ scheme setup and MoE native swiglu import (Alisehen, Apr 29, 2026)
c6255cb  Merge branch 'main' into hyc/awq-scheme-refactor (Alisehen, Apr 29, 2026)
e5088b5  Merge branch 'main' into hyc/awq-scheme-refactor (Alisehen, Apr 29, 2026)
513bbea  Merge branch 'main' of https://github.com/sgl-project/sglang into hyc… (Alisehen, Apr 30, 2026)
1cc3126  Merge remote-tracking branch 'origin/hyc/awq-scheme-refactor' into hy… (Alisehen, Apr 30, 2026)
1b0145c  Use upstream GPT-OSS SwiGLU helper (Alisehen, Apr 30, 2026)
255 changes: 255 additions & 0 deletions python/sglang/srt/hardware_backend/gpu/quantization/awq_kernels.py
@@ -0,0 +1,255 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.layers.moe import MoeRunner
from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo
from sglang.srt.layers.quantization.marlin_utils import (
    apply_awq_marlin_linear,
    awq_to_marlin_zero_points,
    marlin_make_empty_g_idx,
    marlin_make_workspace,
    marlin_moe_permute_scales,
    marlin_permute_scales,
    moe_awq_to_marlin_zero_points,
)
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
from sglang.srt.utils import is_hip, is_xpu

if TYPE_CHECKING:
    from sglang.srt.layers.moe.token_dispatcher import (
        CombineInput,
        StandardDispatchOutput,
    )
    from sglang.srt.layers.quantization.base_config import QuantizationConfig

awq_marlin_moe_repack = None
awq_marlin_repack = None


def _unsupported_awq_dequantize(*args, **kwargs):
    raise RuntimeError("AWQ GPU kernels are unavailable on the current platform.")


awq_dequantize = _unsupported_awq_dequantize

# Pick the best available dequantize kernel for the current platform,
# falling back through progressively more portable implementations.
if is_xpu():
    try:
        from sgl_kernel import awq_dequantize
    except ImportError:
        pass
elif is_hip():
    try:
        from sglang.srt.layers.quantization.awq.awq_triton import (
            awq_dequantize_triton as awq_dequantize,
        )
    except ImportError:
        pass
else:
    try:
        from sglang.jit_kernel.awq_dequantize import awq_dequantize
        from sglang.jit_kernel.awq_marlin_repack import (
            awq_marlin_moe_repack,
            awq_marlin_repack,
        )
        from sglang.srt.utils.custom_op import register_custom_op_from_extern

        awq_dequantize = register_custom_op_from_extern(
            awq_dequantize,
            fake_impl=lambda qweight, scales, qzeros: qweight.new_empty(
                qweight.shape[:-1] + (qweight.shape[-1] * 8,), dtype=scales.dtype
            ),
        )
    except ImportError:
        try:
            from sglang.srt.layers.quantization.awq.awq_triton import (
                awq_dequantize_triton as awq_dequantize,
            )
        except ImportError:
            try:
                from sgl_kernel import awq_dequantize
            except ImportError:
                pass

_, scalar_types = get_scalar_types()


class AWQLinearKernel:
    def __init__(self, quant_config: Optional["QuantizationConfig"] = None):
        self.quant_config = quant_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Re-wrap the loaded tensors as inference-only parameters.
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
        reshaped_x = x.reshape(-1, x.shape[-1])
        # Dequantize to the activation dtype, then run a dense matmul.
        out = awq_dequantize(qweight, scales, qzeros)
        out = torch.matmul(reshaped_x, out)

        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)
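
A quick sketch of the pack_factor arithmetic in apply above; the tensor sizes here are assumed example values, not taken from the PR:

import torch

# Illustration only: 4-bit AWQ packs eight weights into one int32,
# so pack_factor == 8 and qweight holds N // 8 packed columns.
K, N, pack_factor = 4096, 11008, 8
x = torch.randn(2, 16, K)  # [batch, seq, K]
qweight = torch.zeros(K, N // pack_factor, dtype=torch.int32)
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
assert out_shape == (2, 16, N)  # dequantized weight is [K, N], so the GEMM yields [..., N]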


class AWQMarlinLinearKernel:
    def __init__(self, quant_config: Optional["QuantizationConfig"] = None):
        self.quant_config = quant_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.qweight.device
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

        layer.workspace = marlin_make_workspace(device)

        # Repack the AWQ weights, scales, and zero points into the Marlin layout.
        marlin_qweight = awq_marlin_repack(
            layer.qweight,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "qweight", marlin_qweight)

        marlin_scales = marlin_permute_scales(
            layer.scales,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "scales", marlin_scales)

        marlin_zp = awq_to_marlin_zero_points(
            layer.qzeros,
            size_k=layer.num_groups,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "qzeros", marlin_zp)

        # AWQ does not reorder activations, so g_idx stays empty.
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return apply_awq_marlin_linear(
            input=x,
            weight=layer.qweight,
            weight_scale=layer.scales,
            weight_zp=layer.qzeros,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias,
        )
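
For context, the Marlin path is a one-time repack followed by fused GEMMs on every forward pass. A minimal sketch of the expected call order, assuming a layer object that already carries qweight/qzeros/scales and the partition sizes (as the surrounding linear layers in this PR do):

kernel = AWQMarlinLinearKernel(quant_config)
kernel.process_weights_after_loading(layer)  # one-time repack into the Marlin tile layout
y = kernel.apply(layer, x, bias=None)        # fused Marlin GEMM on each forward pass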


class AWQMoEKernel:
    def __init__(self, quant_config: Optional["QuantizationConfig"] = None):
        self.quant_config = quant_config
        self.runner: Optional[MoeRunner] = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        num_experts = layer.w13_qweight.shape[0]
        device = layer.w13_qweight.device

        # AWQ has no activation reordering, so the g_idx sort indices are empty.
        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )

        # Repack the AWQ-packed expert weights into the Marlin tile layout.
        marlin_w13_qweight = awq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            size_k=layer.w13_qweight.shape[1],
            size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

        marlin_w2_qweight = awq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            size_k=layer.w2_qweight.shape[1],
            size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
Contributor

critical

The size_k parameter for marlin_moe_permute_scales appears to be incorrect for w13_scales. The k dimension of the w13 weight matrix is hidden_size, and the scales are grouped along this dimension. Therefore, size_k should be layer.hidden_size instead of layer.intermediate_size_per_partition.

Suggested change
-            size_k=layer.intermediate_size_per_partition,
+            size_k=layer.hidden_size,

(A shape sketch checking this point follows the file diff.)

            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w13_scales", marlin_w13_scales)

        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

        marlin_w13_zp = moe_awq_to_marlin_zero_points(
            layer.w13_qzeros,
            size_k=layer.w13_qzeros.shape[1],
            size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)

        marlin_w2_zp = moe_awq_to_marlin_zero_points(
            layer.w2_qzeros,
            size_k=layer.w2_qzeros.shape[1],
            size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: "StandardDispatchOutput",
    ) -> "CombineInput":
        if self.runner is None:
            raise RuntimeError("moe runner is not initialized")

        quant_info = MarlinMoeQuantInfo(
            w13_qweight=layer.w13_qweight,
            w2_qweight=layer.w2_qweight,
            w13_scales=layer.w13_scales,
            w2_scales=layer.w2_scales,
            w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
            w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
            w13_qzeros=layer.w13_qzeros,
            w2_qzeros=layer.w2_qzeros,
            weight_bits=self.quant_config.weight_bits,
        )
        return self.runner.run(dispatch_output, quant_info)
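
The reviewer's size_k point above can be sanity-checked with a small shape sketch; the sizes are assumed example values:

# w13 projects hidden_size -> 2 * intermediate_size, so its k dimension
# (the one the scale groups tile) is hidden_size. w2 goes the other way,
# which is why size_k=layer.intermediate_size_per_partition is correct there.
hidden_size, intermediate_size_per_partition, group_size = 4096, 1408, 128

w13_scale_rows = hidden_size // group_size                  # 32 group rows along k
wrong_rows = intermediate_size_per_partition // group_size  # 11 group rows
assert w13_scale_rows != wrong_rows
# marlin_moe_permute_scales would permute 11 rows of a 32-row tensor,
# hence the suggested size_k=layer.hidden_size.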