diff --git a/docs/platforms/ascend_npu_deepseek_example.md b/docs/platforms/ascend_npu_deepseek_example.md
index d0b207f18586..45b608b26a8c 100644
--- a/docs/platforms/ascend_npu_deepseek_example.md
+++ b/docs/platforms/ascend_npu_deepseek_example.md
@@ -22,7 +22,6 @@ export SGLANG_ENABLE_OVERLAP_PLAN_STREAM=1
 #npu acceleration operator
 export SGLANG_NPU_USE_MLAPO=1
 export SGLANG_USE_FIA_NZ=1
-export ENABLE_MOE_NZ=1
 
 python3 -m sglang.launch_server \
     --model-path ${MODEL_PATH} \
@@ -71,7 +70,6 @@ export HCCL_BUFFSIZE=1536
 #npu acceleration operator
 export SGLANG_NPU_USE_MLAPO=1
 export SGLANG_USE_FIA_NZ=1
-export ENABLE_MOE_NZ=1
 export TASK_QUEUE_ENABLE=2
 
 python -m sglang.launch_server \
diff --git a/docs/platforms/ascend_npu_qwen3_examples.md b/docs/platforms/ascend_npu_qwen3_examples.md
index 958ad8c97398..5278a22a1001 100644
--- a/docs/platforms/ascend_npu_qwen3_examples.md
+++ b/docs/platforms/ascend_npu_qwen3_examples.md
@@ -62,7 +62,6 @@ export HCCL_BUFFSIZE=1536
 export HCCL_OP_EXPANSION_MODE=AIV
 export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=32
 export SGLANG_DEEPEP_BF16_DISPATCH=1
-export ENABLE_ASCEND_MOE_NZ=1
 
 python -m sglang.launch_server \
     --device npu \
@@ -84,7 +83,6 @@ export STREAMS_PER_DEVICE=32
 export HCCL_BUFFSIZE=1536
 export SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK=32
 export SGLANG_DEEPEP_BF16_DISPATCH=1
-export ENABLE_ASCEND_MOE_NZ=1
 
 python -m sglang.launch_server \
     --model-path Qwen/Qwen3-235B-A22B-Instruct-2507 \
diff --git a/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py
index 91a5da075807..b3bd7c2155e6 100644
--- a/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py
+++ b/python/sglang/srt/hardware_backend/npu/quantization/fused_moe_method_npu.py
@@ -150,42 +150,27 @@ def __init__(
 
 
 class NPUW8A8Int8DynamicMoEMethod(_NPUFusedMoEMethodBase):
-    def _release_weight_cache(self, weight: torch.Tensor):
-        # .contiguous() introduces additional memory overhead and needs to be released using resize_(0)
-        origin_weight = weight.data.transpose(1, 2)
-        new_weight = origin_weight.contiguous()
-        origin_weight.untyped_storage().resize_(0)
-        return new_weight
-
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        weight_data = self._release_weight_cache(layer.w13_weight.data)
-        layer.w13_weight = torch.nn.Parameter(weight_data, requires_grad=False)
-
-        weight_data = self._release_weight_cache(layer.w2_weight.data)
-        layer.w2_weight = torch.nn.Parameter(weight_data, requires_grad=False)
-
+        layer.w13_weight.data = npu_format_cast(layer.w13_weight.data.transpose(1, 2))
+        layer.w2_weight.data = npu_format_cast(layer.w2_weight.data.transpose(1, 2))
         layer.w13_weight_scale = torch.nn.Parameter(
-            layer.w13_weight_scale.data.squeeze(-1).contiguous().to(torch.float32),
-            requires_grad=False,
+            layer.w13_weight_scale.data.squeeze(-1), requires_grad=False
         )
         layer.w2_weight_scale = torch.nn.Parameter(
-            layer.w2_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False
+            layer.w2_weight_scale.data.squeeze(-1), requires_grad=False
         )
         # Compressed-tensors format doesn't have this field
         if hasattr(layer, "w13_weight_offset"):
             layer.w13_weight_offset = torch.nn.Parameter(
-                layer.w13_weight_offset.data.squeeze(-1).contiguous(),
+                layer.w13_weight_offset.data.squeeze(-1),
                 requires_grad=False,
             )
         if hasattr(layer, "w2_weight_offset"):
             layer.w2_weight_offset = torch.nn.Parameter(
-                layer.w2_weight_offset.data.squeeze(-1).contiguous(),
+                layer.w2_weight_offset.data.squeeze(-1),
                 requires_grad=False,
             )
-
-        layer.w13_weight.data = npu_format_cast(layer.w13_weight.data)
-        layer.w2_weight.data = npu_format_cast(layer.w2_weight.data)
 
     def apply(
         self,
         layer,
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 552346cb0f6b..7e8bb33ca70a 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -7,6 +7,7 @@
 
 from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph
 from sglang.srt.environ import envs
+from sglang.srt.hardware_backend.npu.utils import npu_format_cast
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe import (
     get_deepep_mode,
@@ -472,13 +473,6 @@ def forward(
             gmm2_weight_scale=self.w2_weight_scale,
         ).hidden_state
 
-    def release_weight_cache(self, weight: torch.Tensor):
-        # .contiguous() introduces additional memory overhead and needs to be released using resize_(0)
-        origin_weight = weight.data.transpose(1, 2)
-        new_weight = origin_weight.contiguous()
-        origin_weight.untyped_storage().resize_(0)
-        return new_weight
-
     def permute_w13_weight_scale(self, w: torch.Tensor, tile_n: int):
         if tile_n % 2 != 0:
             raise ValueError(f"tile_n must be even, got {tile_n}")
@@ -520,14 +514,12 @@ def reshape_w13_weight(self, weight: torch.Tensor, dim: int, chunk_size: int = 6
         return weight.view(*original_shape[:dim], -1, *original_shape[dim + 1 :])
 
     def _process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        w13 = self.release_weight_cache(layer.w13_weight)
-        torch_npu.npu_format_cast_(w13, 2)
-        cpu_w13 = w13.cpu()
+        cpu_w13 = layer.w13_weight.transpose(1, 2).cpu()
         w13 = self.reshape_w13_weight(cpu_w13, -1).npu()
-        torch_npu.npu_format_cast_(w13, 29)
+        w13 = npu_format_cast(w13)
         layer.w13_weight = torch.nn.Parameter(w13, requires_grad=False)
 
-        w2 = torch_npu.npu_format_cast(layer.w2_weight.data, 29)
+        w2 = npu_format_cast(layer.w2_weight)
         layer.w2_weight = torch.nn.Parameter(w2, requires_grad=False)
 
         w13_scale = layer.w13_weight_scale.data.squeeze(-1).contiguous()
diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py
index b7c052c016e2..628fadbd166e 100644
--- a/python/sglang/srt/layers/quantization/unquant.py
+++ b/python/sglang/srt/layers/quantization/unquant.py
@@ -25,6 +25,7 @@
     get_bool_env_var,
     is_cpu,
     is_hip,
+    is_npu,
     next_power_of_2,
     set_weight_attrs,
     use_intel_amx_backend,
@@ -40,6 +41,7 @@
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
+_is_npu = is_npu()
 
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _use_aiter:
@@ -47,6 +49,9 @@
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
+if _is_npu:
+    from sglang.srt.hardware_backend.npu.utils import npu_format_cast
+
 try:
     from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
 except ImportError:
@@ -296,6 +301,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 layer.num_local_experts, *new_shape_w2
             )
 
+        if _is_npu:
+            for weight_name in ["w13_weight", "w2_weight"]:
+                weight = getattr(layer, weight_name)
+                weight.data = weight.data.transpose(1, 2)
+                weight.data = npu_format_cast(
+                    weight.data,
+                )
+
         return
 
     def create_moe_runner(
@@ -494,14 +507,11 @@ def forward_npu(
         expert_tokens = expert_tokens.to(torch.int64)
         w13_bias = [layer.w13_weight_bias] if self.with_bias else None
         w2_bias = [layer.w2_weight_bias] if self.with_bias else None
-        if layer.w13_weight.shape[-1] == layer.hidden_size:
-            w13 = layer.w13_weight.transpose(1, 2)
-            w2 = layer.w2_weight.transpose(1, 2)
 
         # gmm1: gate_up_proj
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
-            weight=[w13],
+            weight=[layer.w13_weight],
             bias=w13_bias,
             split_item=2,
             group_list_type=0,
@@ -525,7 +535,7 @@
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
-            weight=[w2],
+            weight=[layer.w2_weight],
             bias=w2_bias,
             split_item=2,
             group_list_type=0,
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 7085ff68e513..437c4fa3e09f 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -71,6 +71,7 @@
 )
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
+    LazyValue,
     add_prefix,
     is_cuda,
     is_flashinfer_available,
@@ -1119,14 +1120,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             else:
                 logger.warning(f"Parameter {name} not found in params_dict")
 
-        # TODO mimic deepseek
-        # Lazy initialization of expert weights cache to avoid slowing down load_weights
         if not hasattr(self, "routed_experts_weights_of_layer"):
-            self.routed_experts_weights_of_layer = {
-                layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
-                for layer_id in range(self.start_layer, self.end_layer)
-                if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
-            }
+            self.routed_experts_weights_of_layer = LazyValue(
+                lambda: {
+                    layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+                    for layer_id in range(self.start_layer, self.end_layer)
+                    if isinstance(
+                        self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock
+                    )
+                }
+            )
 
     @classmethod
     def get_model_config_for_expert_location(cls, config):
diff --git a/test/registered/ascend/test_ascend_memory_consumption.py b/test/registered/ascend/test_ascend_memory_consumption.py
new file mode 100644
index 000000000000..2e6b09524476
--- /dev/null
+++ b/test/registered/ascend/test_ascend_memory_consumption.py
@@ -0,0 +1,76 @@
+"""
+Usage:
+python3 -m unittest test_ascend_memory_consumption.TestMemoryConsumptionAscend.test_memory_consumption
+"""
+
+import os
+import unittest
+
+import torch
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.ci.ci_register import register_npu_ci
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+register_npu_ci(est_time=400, suite="nightly-1-npu-a3", nightly=True)
+
+if "ASCEND_RT_VISIBLE_DEVICES" not in os.environ:
+    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
+DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
+    8000 + int(os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0")[0]) * 100
+)
+DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 1000}"
+
+
+class TestMemoryConsumptionAscend(CustomTestCase):
+
+    def test_memory_consumption(self):
+
+        model = "nytopop/Qwen3-30B-A3B.w8a8"
+        base_url = DEFAULT_URL_FOR_TEST
+
+        ### Calculate initial used memory
+        free_npu_memory, total_npu_memory = torch.npu.mem_get_info()
+        initial_used_memory = total_npu_memory - free_npu_memory
+
+        process = popen_launch_server(
+            model,
+            base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--trust-remote-code",
+                "--device",
+                "npu",
+                "--attention-backend",
+                "ascend",
+                "--tp-size",
+                "2",
+                "--mem-fraction-static",
+                "0.8",
+                "--cuda-graph-bs",
+                "1",
+                "--max-total-tokens",
+                "1024",
+                "--disable-radix-cache",
+                "--disable-cuda-graph",
+            ],
+        )
+
+        ### Calculate initial used memory
+        free_npu_memory, total_npu_memory = torch.npu.mem_get_info()
+        used_memory_after_server_starting = (
+            total_npu_memory - free_npu_memory - initial_used_memory
+        ) / (1 << 30)
+        self.assertLessEqual(float(used_memory_after_server_starting), 16.00)
+
+        # Clean up everything
+        kill_process_tree(process.pid)
+
+
+if __name__ == "__main__":
+    unittest.main()