Merged
Changes from all commits
40 commits
2ca9cb7
add awq quantization to ascend backend
Alisehen Sep 8, 2025
c74b35b
format
ErvinXie Sep 8, 2025
b009a54
code fix
Alisehen Sep 8, 2025
b865bce
Merge commit 'HEAD@{1}' into awq
Alisehen Sep 8, 2025
6857687
Merge branch 'main' into awq
Alisehen Sep 10, 2025
4515d25
format
ErvinXie Sep 16, 2025
e8cb970
Merge branch 'main' into awq
ErvinXie Sep 16, 2025
281d084
refact npu_fused_experts
ErvinXie Sep 17, 2025
a2c4e02
Merge branch 'awq' of https://github.com/kvcache-ai/sglang_awq into awq
ErvinXie Sep 17, 2025
7c30f09
minor
ErvinXie Sep 17, 2025
0744852
minor
ErvinXie Sep 17, 2025
a32ed84
minor
ErvinXie Sep 17, 2025
a3d9077
minor
ErvinXie Sep 17, 2025
98f63d0
Merge branch 'main' into awq
ErvinXie Sep 28, 2025
7edc183
Merge remote-tracking branch 'upstream/main' into awq
ZhengdQin Sep 28, 2025
44732a9
ci bug fix
ZhengdQin Sep 28, 2025
a2770f2
bug fix
ZhengdQin Sep 28, 2025
436e141
Merge branch 'main' into awq
Alisehen Sep 28, 2025
eda1ed7
Merge branch 'main' into awq
Alisehen Sep 29, 2025
e9530b2
Merge branch 'main' into awq
Alisehen Sep 29, 2025
1e38fe5
Merge branch 'main' into awq
Alisehen Oct 1, 2025
90ae315
merge main
ZhengdQin Oct 11, 2025
21ce362
Merge branch 'main' into awq
Alisehen Oct 11, 2025
709f011
Merge branch 'sgl-project:main' into awq
Alisehen Oct 11, 2025
b289972
Merge branch 'main' into awq
Alisehen Oct 11, 2025
af799de
ci fix
Alisehen Oct 11, 2025
2c00cf3
Merge branch 'awq' of github.com:kvcache-ai/sglang_awq into awq
Alisehen Oct 11, 2025
31c6ec7
chore: apply pre-commit autofix (trailing whitespace)
Alisehen Oct 11, 2025
6d62983
Merge branch 'main' into awq
Alisehen Oct 12, 2025
599d004
Merge branch 'main' into awq
Alisehen Oct 17, 2025
385e194
Merge branch 'main' into awq
ErvinXie Oct 18, 2025
7570fd2
format
ErvinXie Oct 18, 2025
168faac
format
ErvinXie Oct 18, 2025
71a8ec2
Merge branch 'main' into awq
ErvinXie Oct 18, 2025
8fc5424
Merge branch 'main' into awq
ErvinXie Oct 20, 2025
de5e2d2
Merge branch 'main' into awq
ErvinXie Oct 21, 2025
0b24a48
Merge remote-tracking branch 'upstream/main' into awq
Alisehen Oct 21, 2025
0bdb55b
format fix
Alisehen Oct 21, 2025
630f76d
Merge branch 'main' into awq
ErvinXie Oct 21, 2025
ed55e68
Merge branch 'main' into awq
ErvinXie Oct 23, 2025
1 change: 1 addition & 0 deletions python/sglang/srt/layers/linear.py
@@ -51,6 +51,7 @@
"CompressedTensorsLinearMethod",
"AWQMarlinLinearMethod",
"AWQLinearMethod",
"AWQLinearAscendMethod",
"GPTQMarlinLinearMethod",
"Fp8LinearMethod",
"BlockInt8LinearMethod",
180 changes: 176 additions & 4 deletions python/sglang/srt/layers/quantization/awq.py
@@ -31,6 +31,7 @@
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
from sglang.srt.layers.quantization.w8a8_int8 import npu_fused_experts

if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
@@ -39,11 +40,16 @@
CombineInput,
)

from sglang.srt.utils import is_cuda, is_hip, is_xpu
from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_xpu = is_xpu()
_is_npu = is_npu()

if _is_npu:
import torch_npu

if _is_cuda:
from sgl_kernel import (
awq_dequantize,
@@ -117,12 +123,17 @@ def get_name(self) -> str:
return "awq"

def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
return [torch.float16] if not _is_npu else [torch.float16, torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
# The AWQ kernel only supports Turing or newer GPUs.
return 75
if _is_npu:
raise NotImplementedError(
'NPU hardware does not support "get_min_capability" feature.'
)
else:
return 75

@staticmethod
def get_config_filenames() -> List[str]:
@@ -146,6 +157,16 @@ def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[LinearMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

if _is_npu:
if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return AWQLinearAscendMethod(self)
elif isinstance(layer, FusedMoE):
return AWQMoEAscendMethod(self)
return None

if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
@@ -575,6 +596,64 @@ def apply(
)


class AWQLinearAscendMethod(AWQLinearMethod):
"""Linear method for AWQ on Ascend.

Args:
quant_config: The AWQ quantization config.
"""

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
qweight_tmp = torch.zeros_like(layer.qweight.data)
qzeros_tmp = layer.qzeros.data
qzeros_list = []
shifts = [0, 4, 1, 5, 2, 6, 3, 7]

for i in range(0, self.quant_config.pack_factor):
shift_num = shifts[i] * 4
qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF)
qweight_tmp.bitwise_or_(
((layer.qweight.data >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i))
)
Comment on lines +616 to +618
Contributor (medium)
The bitwise operation used for repacking weights is functionally correct but unnecessarily complex and hard to read. Using (2 ** (4 * i)) for left-shifting and then masking can be simplified. A more direct and readable approach is to first mask the desired nibble with & 0xF and then shift it to its new position. This improves code clarity and maintainability.

Suggested change
-            qweight_tmp.bitwise_or_(
-                ((layer.qweight.data >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i))
-            )
+            qweight_tmp.bitwise_or_(
+                (((layer.qweight.data >> shift_num) & 0xF) << (4 * i))
+            )
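
For reference, a minimal standalone check (plain PyTorch, hypothetical packed value; int64 is used only to keep the high-nibble masks in range) confirming that the suggested form is equivalent to the original expression:

    import torch

    # One packed word holding eight 4-bit nibbles (hypothetical value).
    q = torch.tensor([0x76543210], dtype=torch.int64)

    for i, shift in enumerate([0, 4, 1, 5, 2, 6, 3, 7]):
        shift_num = shift * 4
        original = ((q >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i))
        suggested = ((q >> shift_num) & 0xF) << (4 * i)
        assert torch.equal(original, suggested)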


qweight_tmp.bitwise_xor_(0x88888888)

qzeros_tmp = torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1)
qzeros_tmp = -(qzeros_tmp - 8)
qzeros_tmp = qzeros_tmp.to(layer.scales.data.dtype)

layer.qzeros = torch.nn.Parameter(qzeros_tmp, requires_grad=False)
layer.qweight = torch.nn.Parameter(qweight_tmp, requires_grad=False)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
reshaped_x = x.reshape(-1, x.shape[-1])

if bias is not None and bias.dtype == torch.bfloat16:
bias = bias.float()

out = torch_npu.npu_weight_quant_batchmatmul(
reshaped_x,
qweight,
antiquant_scale=scales,
antiquant_offset=qzeros,
antiquant_group_size=self.quant_config.group_size,
bias=bias,
)

return out.reshape(out_shape)
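
A note on the repacking convention that this apply path relies on: the XOR with 0x88888888 flips the top bit of every nibble, which maps each unsigned 4-bit weight u in [0, 15] to the signed value u - 8, and the matching zero-point rewrite is -(z - 8). A small sanity check of this reading (plain PyTorch; that the Ascend kernel applies (weight + antiquant_offset) * antiquant_scale is an assumption here, not something stated in this diff):

    import torch

    u = torch.arange(16, dtype=torch.int32)  # unsigned 4-bit weight values 0..15
    v = u ^ 0x8                              # nibble after the XOR flip
    signed = torch.where(v >= 8, v - 16, v)  # reinterpret nibble as signed 4-bit
    assert torch.equal(signed, u - 8)        # XOR 0x8 == subtract 8 in signed 4-bit

    # Dequantization is then preserved: (u - z) * s == ((u - 8) + (-(z - 8))) * s,
    # so the repacked weight u - 8 plus the rewritten offset reproduces u - z.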


class AWQMoEMethod(FusedMoEMethodBase):

def __init__(self, quant_config: AWQMarlinConfig):
@@ -677,7 +756,8 @@ def create_weights(
set_weight_attrs(w2_qzeros, extra_weight_attrs)

device = layer.w13_qweight.device
layer.workspace = marlin_make_workspace(device, 4)
if not _is_npu:
layer.workspace = marlin_make_workspace(device, 4)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
@@ -785,3 +865,95 @@ def apply(
num_bits=self.quant_config.weight_bits,
).to(orig_dtype)
return StandardCombineInput(hidden_states=output)


class AWQMoEAscendMethod(AWQMoEMethod):
def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config
Comment on lines +871 to +872
Contributor (high)
The __init__ method of AWQMoEAscendMethod does not initialize its parent class AWQMoEMethod. This can lead to an improperly initialized object, as attributes set in the parent's __init__ (like self.quant_type) will be missing. While AWQMoEAscendMethod is specific to Ascend and AWQMoEMethod is for Marlin, inheriting methods like create_weights implies a need for proper parent initialization.

Given that AWQMoEAscendMethod is instantiated with an AWQConfig and not an AWQMarlinConfig, a direct super().__init__() call would cause a type error. A better approach would be to replicate the necessary initialization logic from the parent.

Suggested change
-    def __init__(self, quant_config: AWQConfig):
-        self.quant_config = quant_config
+    def __init__(self, quant_config: AWQConfig):
+        self.quant_config = quant_config
+        if self.quant_config.weight_bits != 4:
+            raise ValueError(f"{type(self).__name__} only supports 4bit now.")
+        self.quant_type = scalar_types.uint4


def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13_qweight_tmp = torch.zeros_like(layer.w13_qweight.data)
w2_qweight_tmp = torch.zeros_like(layer.w2_qweight.data)
w13_qzeros_list = []
w2_qzeros_list = []
shifts = [0, 4, 1, 5, 2, 6, 3, 7]
for i in range(0, self.quant_config.pack_factor):
shift_num = shifts[i] * 4
w13_qzeros_list.append(
(layer.w13_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF
)
w2_qzeros_list.append(
(layer.w2_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF
)
w13_qweight_tmp.bitwise_or_(
((layer.w13_qweight.data >> shift_num) * (2 ** (4 * i)))
& (0xF << (4 * i))
)
w2_qweight_tmp.bitwise_or_(
((layer.w2_qweight.data >> shift_num) * (2 ** (4 * i)))
& (0xF << (4 * i))
)
Comment on lines +888 to +895
Contributor (medium)
Similar to AWQLinearAscendMethod, the bitwise operation here for repacking weights is unnecessarily complex. Using * (2 ** (4 * i)) for left-shifting is less clear than using the left-shift operator << after masking the desired nibble. Simplifying this expression will improve code readability and maintainability.

            w13_qweight_tmp.bitwise_or_(
                (((layer.w13_qweight.data >> shift_num) & 0xF) << (4 * i))
            )
            w2_qweight_tmp.bitwise_or_(
                (((layer.w2_qweight.data >> shift_num) & 0xF) << (4 * i))
            )


w13_qweight_tmp.bitwise_xor_(0x88888888)
w2_qweight_tmp.bitwise_xor_(0x88888888)

w13_qzeros_tmp = torch.cat(w13_qzeros_list, dim=-1).reshape(
layer.w13_qzeros.shape[0], layer.w13_qzeros.shape[1], -1
)
w13_qzeros_tmp = -(w13_qzeros_tmp - 8)
w13_qzeros_tmp = w13_qzeros_tmp.to(layer.w13_scales.data.dtype)
w2_qzeros_tmp = torch.cat(w2_qzeros_list, dim=-1).reshape(
layer.w2_qzeros.shape[0], layer.w2_qzeros.shape[1], -1
)
w2_qzeros_tmp = -(w2_qzeros_tmp - 8)
w2_qzeros_tmp = w2_qzeros_tmp.to(layer.w2_scales.data.dtype)

layer.register_parameter(
"w13_qzeros", torch.nn.Parameter(w13_qzeros_tmp, requires_grad=False)
)
layer.register_parameter(
"w13_qweight", torch.nn.Parameter(w13_qweight_tmp, requires_grad=False)
)
layer.register_parameter(
"w2_qzeros", torch.nn.Parameter(w2_qzeros_tmp, requires_grad=False)
)
layer.register_parameter(
"w2_qweight", torch.nn.Parameter(w2_qweight_tmp, requires_grad=False)
)

def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config

def apply(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> torch.Tensor:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

assert (
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."

x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output

topk_weights, topk_ids, _ = topk_output
topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype)
output = npu_fused_experts(
hidden_states=x,
w13=layer.w13_qweight,
w13_scale=layer.w13_scales,
w13_offset=layer.w13_qzeros,
w2=layer.w2_qweight,
w2_scale=layer.w2_scales,
w2_offset=layer.w2_qzeros,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=topk_ids.shape[1],
use_wna16=True,
)
return StandardCombineInput(hidden_states=output)
29 changes: 29 additions & 0 deletions python/sglang/srt/layers/quantization/awq_triton.py
@@ -337,3 +337,32 @@ def awq_gemm_triton(
result = result.sum(0)

return result


def awq_dequantize_decomposition(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
) -> torch.Tensor:
qweight_tmp = qweight
qzeros_tmp = zeros
qweight_list = []
qzeros_list = []
shifts = [0, 4, 1, 5, 2, 6, 3, 7]
for i in range(0, 8):
shift_num = shifts[i] * 4
qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF)
qweight_list.append((qweight_tmp.reshape(-1, 1) >> shift_num) & 0xF)
qzeros_tmp = (
torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1).to(scales.dtype)
)
qweight_tmp = (
torch.cat(qweight_list, dim=-1)
.reshape(qweight_tmp.shape[0], -1)
.to(scales.dtype)
)
res = (
qweight_tmp.reshape(qzeros_tmp.shape[0], -1, qzeros_tmp.shape[1])
- qzeros_tmp.unsqueeze(1)
) * scales.unsqueeze(1)
return res.reshape(qweight_tmp.shape[0], -1)
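
A minimal usage sketch for the new helper (hypothetical shapes, assuming the usual AWQ 4-bit layout: eight nibbles per int32 along the output dimension, one scale/zero group per group_size input rows):

    import torch

    K, N, group_size = 256, 512, 128
    qweight = torch.randint(0, 2**31 - 1, (K, N // 8), dtype=torch.int32)
    qzeros = torch.randint(0, 2**31 - 1, (K // group_size, N // 8), dtype=torch.int32)
    scales = torch.rand(K // group_size, N, dtype=torch.float16)

    w = awq_dequantize_decomposition(qweight, scales, qzeros)
    assert w.shape == (K, N) and w.dtype == scales.dtype

Because this decomposition uses only plain tensor ops, deepseek_v2.py below can alias it as awq_dequantize on NPU, where the CUDA sgl_kernel implementation is unavailable.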
34 changes: 28 additions & 6 deletions python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -102,7 +102,12 @@ def npu_fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
**kwargs,
):
w13_offset = kwargs.get("w13_offset", None)
w2_offset = kwargs.get("w2_offset", None)
use_wna16 = kwargs.get("use_wna16", False)

original_shape = hidden_states.shape
original_dtype = hidden_states.dtype
scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32
@@ -127,12 +132,22 @@
)
expert_tokens = expert_tokens.to(torch.int64)
# gmm1: gate_up_proj
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
if not use_wna16:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
scale_args13 = {
"scale": [w13_scale.to(scale_dtype)],
"per_token_scale": [pertoken_scale],
}
else:
scale_args13 = {
"antiquant_scale": [w13_scale],
"antiquant_offset": [w13_offset],
}

hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w13],
scale=[w13_scale.to(scale_dtype)],
per_token_scale=[pertoken_scale],
**scale_args13,
split_item=2,
group_list_type=0,
group_type=0,
Expand All @@ -141,13 +156,20 @@ def npu_fused_experts(
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
if not use_wna16:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)

scale_args2 = {
"scale": [w2_scale.to(scale_dtype)],
"per_token_scale": [pertoken_scale],
}
else:
scale_args2 = {"antiquant_scale": [w2_scale], "antiquant_offset": [w2_offset]}
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale.to(scale_dtype)],
per_token_scale=[pertoken_scale],
**scale_args2,
split_item=2,
group_list_type=0,
group_type=0,
2 changes: 2 additions & 0 deletions python/sglang/srt/model_loader/loader.py
@@ -610,6 +610,8 @@ def load_weights_and_postprocess(model, weights, target_device):
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
if _is_npu:
torch.npu.empty_cache()


class LayeredModelLoader(DefaultModelLoader):
6 changes: 5 additions & 1 deletion python/sglang/srt/models/deepseek_v2.py
@@ -189,6 +189,10 @@
import custom_ops # noqa: F401
import sgl_kernel_npu # noqa: F401
import torch_npu # noqa: F401

from sglang.srt.layers.quantization.awq_triton import (
awq_dequantize_decomposition as awq_dequantize,
)
else:
pass

@@ -2965,7 +2969,7 @@ def post_load_weights(self, is_nextn=False, weight_names=None):
)
if hasattr(self_attn.kv_b_proj, "qweight"):
# AWQ compatible
if _is_cuda or _is_hip:
if _is_cuda or _is_hip or _is_npu:
w = awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
2 changes: 2 additions & 0 deletions python/sglang/srt/utils/common.py
@@ -510,6 +510,8 @@ def get_available_gpu_memory(
f"WARNING: current device is not {gpu_id}, but {torch.npu.current_device()}, ",
"which may cause useless memory allocation for torch NPU context.",
)
if empty_cache:
torch.npu.empty_cache()
free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()

if distributed: