Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
711d217
Qwen3 MOE quick fix
eshoguli Nov 20, 2025
a5bae60
Merge branch 'ping1jing2:main' into memory_and_nz_fix
OrangeRedeng Dec 26, 2025
2b24ec3
Add nz support for MOE
OrangeRedeng Dec 26, 2025
5eda2b9
Merge branch 'ping1jing2:main' into memory_and_nz_fix
OrangeRedeng Dec 26, 2025
0195e52
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Dec 26, 2025
2ec286a
Update python/sglang/srt/layers/quantization/unquant.py
OrangeRedeng Dec 29, 2025
cecaea0
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Dec 29, 2025
d9a3818
Update unquant.py
OrangeRedeng Dec 29, 2025
1ad1ca1
Update unquant.py
OrangeRedeng Dec 29, 2025
e5484c9
Fix lint issue
OrangeRedeng Dec 29, 2025
153b9b7
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 12, 2026
d898855
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 13, 2026
25a0e56
Remove a non-used env ENABLE_ASCEND_MOE_NZ variable from ascend_npu_q…
OrangeRedeng Jan 13, 2026
61830a2
Remove a non-used env ENABLE_MOE_NZ variable from ascend_npu_qwen3_ex…
OrangeRedeng Jan 13, 2026
f586b40
Update NZ conversion
OrangeRedeng Jan 13, 2026
ecfafa1
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 13, 2026
4d38ade
Remove unnecessary function
OrangeRedeng Jan 13, 2026
2f4608d
Update layer.py
OrangeRedeng Jan 13, 2026
3699288
Update unquant.py
OrangeRedeng Jan 13, 2026
3092b31
Update layer.py
OrangeRedeng Jan 13, 2026
fe2aed7
Update layer.py
OrangeRedeng Jan 13, 2026
1054c9d
Update fused_moe_method_npu.py
OrangeRedeng Jan 13, 2026
da5158b
Update fused_moe_method_npu.py
OrangeRedeng Jan 14, 2026
0162b74
Update fused_moe_method_npu.py
OrangeRedeng Jan 14, 2026
019e2d6
Update fused_moe_method_npu.py
OrangeRedeng Jan 14, 2026
8018ee9
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 14, 2026
d02b451
Create test_ascend_memory_consumption.py
OrangeRedeng Jan 14, 2026
312ad28
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 15, 2026
edbba1b
Fix lint issue
OrangeRedeng Jan 15, 2026
fa13828
Fix lint issue
OrangeRedeng Jan 15, 2026
c78449b
Fix lint issue
OrangeRedeng Jan 15, 2026
5b0b787
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 16, 2026
5888bd6
Update fused_moe_method_npu.py
OrangeRedeng Jan 16, 2026
a2b332d
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 16, 2026
ab233ad
Update test_ascend_memory_consumption.py‎
OrangeRedeng Jan 16, 2026
5ceeab1
Update run_suite.py
OrangeRedeng Jan 16, 2026
5cfa59c
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 16, 2026
0d1cfd1
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 19, 2026
a3d2798
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 20, 2026
929e03b
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 22, 2026
f5462d1
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 26, 2026
b8d8285
Move transpose(1,2) from forward_npu() to process_weights
OrangeRedeng Jan 26, 2026
a80de0b
Quickfix
OrangeRedeng Jan 26, 2026
495ee00
Merge branch 'main' into memory_and_nz_fix
iforgetmyname Jan 27, 2026
47a5d8a
Delete test/srt/ascend/test_ascend_memory_consumption.py‎
OrangeRedeng Jan 27, 2026
87d6963
Rename test_ascend_memory_consumption.py‎
OrangeRedeng Jan 27, 2026
105e063
Delete test/srt/ascend/test_ascend_memory_consumption.py‎
OrangeRedeng Jan 27, 2026
c501e4e
Add test_ascend_memory_consumption.py
OrangeRedeng Jan 27, 2026
4824eb4
Update run_suite.py
OrangeRedeng Jan 27, 2026
f398506
Move test to test/registered
OrangeRedeng Jan 27, 2026
bacb1ee
Move test to test/registered
OrangeRedeng Jan 27, 2026
5929c9b
Delete test/srt/ascend/test_ascend_memory_consumption.py
OrangeRedeng Jan 27, 2026
8fbe3a1
Fix lint issue
OrangeRedeng Jan 27, 2026
f15c406
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 27, 2026
b08d02c
Merge branch 'main' into memory_and_nz_fix
OrangeRedeng Jan 27, 2026
d33144c
Merge branch 'main' into memory_and_nz_fix
iforgetmyname Jan 28, 2026
57fb4a1
Merge branch 'main' into memory_and_nz_fix
iforgetmyname Jan 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docs/platforms/ascend_npu_deepseek_example.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ export SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1
#npu acceleration operator
export SGLANG_NPU_USE_MLAPO=1
export SGLANG_USE_FIA_NZ=1
export ENABLE_MOE_NZ=1

python3 -m sglang.launch_server \
--model-path ${MODEL_PATH} \
Expand Down Expand Up @@ -71,7 +70,6 @@ export HCCL_BUFFSIZE=1536
#npu acceleration operator
export SGLANG_NPU_USE_MLAPO=1
export SGLANG_USE_FIA_NZ=1
export ENABLE_MOE_NZ=1
export TASK_QUEUE_ENABLE=2

python -m sglang.launch_server \
Expand Down
2 changes: 0 additions & 2 deletions docs/platforms/ascend_npu_qwen3_examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ export HCCL_BUFFSIZE=1536
export HCCL_OP_EXPANSION_MODE=AIV
export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=32
export SGLANG_DEEPEP_BF16_DISPATCH=1
export ENABLE_ASCEND_MOE_NZ=1

python -m sglang.launch_server \
--device npu \
Expand All @@ -84,7 +83,6 @@ export STREAMS_PER_DEVICE=32
export HCCL_BUFFSIZE=1536
export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=32
export SGLANG_DEEPEP_BF16_DISPATCH=1
export ENABLE_ASCEND_MOE_NZ=1

python -m sglang.launch_server \
--model-path Qwen/Qwen3-235B-A22B-Instruct-2507 \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,42 +150,27 @@ def __init__(

class NPUW8A8Int8DynamicMoEMethod(_NPUFusedMoEMethodBase):

def _release_weight_cache(self, weight: torch.Tensor):
# .contiguous() introduces additional memory overhead and needs to be released using resize_(0)
origin_weight = weight.data.transpose(1, 2)
new_weight = origin_weight.contiguous()
origin_weight.untyped_storage().resize_(0)
return new_weight

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight_data = self._release_weight_cache(layer.w13_weight.data)
layer.w13_weight = torch.nn.Parameter(weight_data, requires_grad=False)

weight_data = self._release_weight_cache(layer.w2_weight.data)
layer.w2_weight = torch.nn.Parameter(weight_data, requires_grad=False)

layer.w13_weight.data = npu_format_cast(layer.w13_weight.data.transpose(1, 2))
layer.w2_weight.data = npu_format_cast(layer.w2_weight.data.transpose(1, 2))
layer.w13_weight_scale = torch.nn.Parameter(
layer.w13_weight_scale.data.squeeze(-1).contiguous().to(torch.float32),
requires_grad=False,
layer.w13_weight_scale.data.squeeze(-1), requires_grad=False
)
layer.w2_weight_scale = torch.nn.Parameter(
layer.w2_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False
layer.w2_weight_scale.data.squeeze(-1), requires_grad=False
)
# Compressed-tensors format doesn't have this field
if hasattr(layer, "w13_weight_offset"):
layer.w13_weight_offset = torch.nn.Parameter(
layer.w13_weight_offset.data.squeeze(-1).contiguous(),
layer.w13_weight_offset.data.squeeze(-1),
requires_grad=False,
)
if hasattr(layer, "w2_weight_offset"):
layer.w2_weight_offset = torch.nn.Parameter(
layer.w2_weight_offset.data.squeeze(-1).contiguous(),
layer.w2_weight_offset.data.squeeze(-1),
requires_grad=False,
)

layer.w13_weight.data = npu_format_cast(layer.w13_weight.data)
layer.w2_weight.data = npu_format_cast(layer.w2_weight.data)

def apply(
self,
layer,
Expand Down
16 changes: 4 additions & 12 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph
from sglang.srt.environ import envs
from sglang.srt.hardware_backend.npu.utils import npu_format_cast
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe import (
get_deepep_mode,
Expand Down Expand Up @@ -472,13 +473,6 @@ def forward(
gmm2_weight_scale=self.w2_weight_scale,
).hidden_state

def release_weight_cache(self, weight: torch.Tensor):
# .contiguous() introduces additional memory overhead and needs to be released using resize_(0)
origin_weight = weight.data.transpose(1, 2)
new_weight = origin_weight.contiguous()
origin_weight.untyped_storage().resize_(0)
return new_weight

def permute_w13_weight_scale(self, w: torch.Tensor, tile_n: int):
if tile_n % 2 != 0:
raise ValueError(f"tile_n must be even, got {tile_n}")
Expand Down Expand Up @@ -520,14 +514,12 @@ def reshape_w13_weight(self, weight: torch.Tensor, dim: int, chunk_size: int = 6
return weight.view(*original_shape[:dim], -1, *original_shape[dim + 1 :])

def _process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13 = self.release_weight_cache(layer.w13_weight)
torch_npu.npu_format_cast_(w13, 2)
cpu_w13 = w13.cpu()
cpu_w13 = layer.w13_weight.transpose(1, 2).cpu()
w13 = self.reshape_w13_weight(cpu_w13, -1).npu()
torch_npu.npu_format_cast_(w13, 29)
w13 = npu_format_cast(w13)
layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)

w2 = torch_npu.npu_format_cast(layer.w2_weight.data, 29)
w2 = npu_format_cast(layer.w2_weight)
layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)

w13_scale = layer.w13_weight_scale.data.squeeze(-1).contiguous()
Expand Down
20 changes: 15 additions & 5 deletions python/sglang/srt/layers/quantization/unquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_bool_env_var,
is_cpu,
is_hip,
is_npu,
next_power_of_2,
set_weight_attrs,
use_intel_amx_backend,
Expand All @@ -40,13 +41,17 @@
_is_cpu_amx_available = cpu_has_amx_support()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _use_aiter:
from aiter import ActivationType
from aiter.fused_moe import fused_moe
from aiter.ops.shuffle import shuffle_weight

if _is_npu:
from sglang.srt.hardware_backend.npu.utils import npu_format_cast

try:
from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
except ImportError:
Expand Down Expand Up @@ -296,6 +301,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.num_local_experts, *new_shape_w2
)

if _is_npu:
for weight_name in ["w13_weight", "w2_weight"]:
weight = getattr(layer, weight_name)
weight.data = weight.data.transpose(1, 2)
weight.data = npu_format_cast(
weight.data,
)
Comment thread
OrangeRedeng marked this conversation as resolved.

return

def create_moe_runner(
Expand Down Expand Up @@ -494,14 +507,11 @@ def forward_npu(
expert_tokens = expert_tokens.to(torch.int64)
w13_bias = [layer.w13_weight_bias] if self.with_bias else None
w2_bias = [layer.w2_weight_bias] if self.with_bias else None
if layer.w13_weight.shape[-1] == layer.hidden_size:
w13 = layer.w13_weight.transpose(1, 2)
w2 = layer.w2_weight.transpose(1, 2)

# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w13],
weight=[layer.w13_weight],
bias=w13_bias,
split_item=2,
group_list_type=0,
Expand All @@ -525,7 +535,7 @@ def forward_npu(
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
weight=[layer.w2_weight],
bias=w2_bias,
split_item=2,
group_list_type=0,
Expand Down
17 changes: 10 additions & 7 deletions python/sglang/srt/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
)
from sglang.srt.server_args import get_global_server_args
from sglang.srt.utils import (
LazyValue,
add_prefix,
is_cuda,
is_flashinfer_available,
Expand Down Expand Up @@ -1119,14 +1120,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
logger.warning(f"Parameter {name} not found in params_dict")

# TODO mimic deepseek
# Lazy initialization of expert weights cache to avoid slowing down load_weights
if not hasattr(self, "routed_experts_weights_of_layer"):
self.routed_experts_weights_of_layer = {
layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
for layer_id in range(self.start_layer, self.end_layer)
if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
}
self.routed_experts_weights_of_layer = LazyValue(
lambda: {
layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
for layer_id in range(self.start_layer, self.end_layer)
if isinstance(
self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock
)
}
)

@classmethod
def get_model_config_for_expert_location(cls, config):
Expand Down
76 changes: 76 additions & 0 deletions test/registered/ascend/test_ascend_memory_consumption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Usage:
python3 -m unittest test_ascend_memory_consumption.TestMemoryConsumptionAscend.test_memory_consumption
"""

import os
import unittest

import torch

from sglang.srt.utils import kill_process_tree
from sglang.test.ci.ci_register import register_npu_ci
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)

# Register this file with the NPU CI scheduler as a nightly single-A3-node job;
# est_time is the expected runtime budget in seconds.
register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True)

# Default to the first two NPUs when the caller has not pinned devices
# (the test launches the server with --tp-size 2).
if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ:
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
# Derive a port from the first visible device id so that concurrent runs on
# different device groups do not collide on the same port.
DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
    8000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100
)
# NOTE(review): deliberately shadows the DEFAULT_URL_FOR_TEST imported from
# sglang.test.test_utils with a device-dependent URL — confirm this override
# is intended rather than using the shared default.
DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"


class TestMemoryConsumptionAscend(CustomTestCase):
    """Nightly check that SGLang server startup on Ascend NPU stays within a
    fixed device-memory budget.

    Launches a real server for a w8a8-quantized Qwen3 MoE model with TP=2 and
    asserts the additional NPU memory attributable to the server is <= 16 GiB.
    """

    def test_memory_consumption(self):
        model = "nytopop/Qwen3-30B-A3B.w8a8"
        base_url = DEFAULT_URL_FOR_TEST

        # Snapshot device memory BEFORE launching, so the later measurement
        # isolates memory consumed by the server itself.
        # NOTE(review): mem_get_info reports device-wide usage, so other
        # processes on the same NPU would skew this — assumes a dedicated CI
        # device.
        free_npu_memory, total_npu_memory = torch.npu.mem_get_info()
        initial_used_memory = total_npu_memory - free_npu_memory

        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--device",
                "npu",
                "--attention-backend",
                "ascend",
                "--tp-size",
                "2",
                "--mem-fraction-static",
                "0.8",
                "--cuda-graph-bs",
                "1",
                "--max-total-tokens",
                "1024",
                "--disable-radix-cache",
                "--disable-cuda-graph",
            ],
        )

        try:
            # Measure memory consumed by the running server, in GiB.
            free_npu_memory, total_npu_memory = torch.npu.mem_get_info()
            used_memory_after_server_starting = (
                total_npu_memory - free_npu_memory - initial_used_memory
            ) / (1 << 30)
            self.assertLessEqual(float(used_memory_after_server_starting), 16.00)
        finally:
            # Always tear the server down — even when the assertion fails —
            # so a failed nightly run does not leak a listening process.
            kill_process_tree(process.pid)


# Allow running this file directly, outside the CI test runner.
if __name__ == "__main__":
    unittest.main()
Loading