From 4ce70e6db9fc5aa448483314554c95f72169fe0d Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Wed, 17 Sep 2025 17:03:02 +0300 Subject: [PATCH 01/71] NPU Graph Compilation support and PassManager with AddRmsNorm & Quantize fuse --- python/sglang/srt/_custom_ops.py | 42 +++- python/sglang/srt/configs/model_config.py | 3 + .../srt/layers/attention/ascend_backend.py | 232 ++++++++++-------- python/sglang/srt/layers/communicator.py | 5 +- .../srt/layers/quantization/w8a8_int8.py | 41 +++- python/sglang/srt/layers/rotary_embedding.py | 36 ++- .../model_executor/compilation/custom_ops.py | 42 ++++ .../compilation/npu_compiler_backend.py | 42 ++++ .../compilation/npu_graph_compiler.py | 27 ++ .../compilation/pass_manager.py | 41 ++++ .../compilation/passes/w8a8_int8/div_fuse.py | 23 ++ .../passes/w8a8_int8/erase_copy.py | 70 ++++++ .../w8a8_int8/npu_add_rms_norm_quant_fuse.py | 36 +++ .../compilation/patch_dynamo.py | 54 ++++ .../srt/model_executor/cuda_graph_runner.py | 17 +- .../srt/model_executor/npu_graph_runner.py | 130 ++++++++-- .../sglang/srt/model_loader/weight_utils.py | 2 + python/sglang/srt/models/qwen3.py | 20 +- python/sglang/srt/utils/common.py | 4 +- 19 files changed, 700 insertions(+), 167 deletions(-) create mode 100644 python/sglang/srt/model_executor/compilation/custom_ops.py create mode 100644 python/sglang/srt/model_executor/compilation/npu_compiler_backend.py create mode 100644 python/sglang/srt/model_executor/compilation/npu_graph_compiler.py create mode 100644 python/sglang/srt/model_executor/compilation/pass_manager.py create mode 100644 python/sglang/srt/model_executor/compilation/passes/w8a8_int8/div_fuse.py create mode 100644 python/sglang/srt/model_executor/compilation/passes/w8a8_int8/erase_copy.py create mode 100644 python/sglang/srt/model_executor/compilation/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py create mode 100644 python/sglang/srt/model_executor/compilation/patch_dynamo.py diff --git a/python/sglang/srt/_custom_ops.py 
b/python/sglang/srt/_custom_ops.py index de47707c18a6..4201f5729835 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -4,13 +4,53 @@ import torch -from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu +from sglang.srt.utils import ( + direct_register_custom_op, + get_bool_env_var, + get_cmo_stream, + is_hip, + is_hpu, + is_npu, +) logger = logging.getLogger(__name__) use_vllm_custom_allreduce = get_bool_env_var( "USE_VLLM_CUSTOM_ALLREDUCE", default="false" ) + +import sglang.srt.utils + + +@torch.library.custom_op("sglang::wait_cmo_stream", mutates_args=()) +def wait_cmo_stream() -> None: + if is_npu() and get_cmo_stream(): + sglang.srt.utils.wait_cmo_stream() + + +@wait_cmo_stream.register_fake +def wait_cmo_stream_fake() -> None: + pass + + +def prepare_weight_cache(handle: torch.Tensor, cache: List[torch.Tensor]) -> None: + sglang.srt.utils.prepare_weight_cache(handle, cache) + + +def prepare_weight_cache_register_fake( + handle: torch.Tensor, cache: List[torch.Tensor] +) -> None: + pass + + +direct_register_custom_op( + op_name="prepare_weight_cache", + op_func=prepare_weight_cache, + mutates_args=["handle"], + fake_impl=prepare_weight_cache_register_fake, +) + + if not is_hpu(): # ROCm does not use vllm custom allreduce if use_vllm_custom_allreduce and not is_hip(): diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 285f7e554fed..1b1b375e846c 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -97,6 +97,7 @@ def __init__( model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, sampling_defaults: str = "openai", quantize_and_serve: bool = False, + enable_torch_compile: bool = False, ) -> None: # Parse args self.model_path = model_path @@ -106,6 +107,7 @@ def __init__( self.model_impl = model_impl self.sampling_defaults = sampling_defaults self.quantize_and_serve = quantize_and_serve + 
self.enable_torch_compile = enable_torch_compile # Validate quantize_and_serve configuration self._validate_quantize_and_serve_config() @@ -234,6 +236,7 @@ def from_server_args( model_impl=server_args.model_impl, sampling_defaults=server_args.sampling_defaults, quantize_and_serve=server_args.quantize_and_serve, + enable_torch_compile=server_args.enable_torch_compile, **kwargs, ) diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index 82526f0e875c..94de5a02702f 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -74,6 +74,7 @@ def update_verify_buffers_to_fill_after_draft( def __init__(self, model_runner: ModelRunner): super().__init__() + self.enable_torch_compile = False self.forward_metadata = None self.device = model_runner.device self.page_size = model_runner.page_size @@ -576,112 +577,151 @@ def forward_decode_graph( layer, forward_batch.out_cache_loc, k, v ) - if not self.use_mla: - k_cache = forward_batch.token_to_kv_pool.get_key_buffer( - layer.layer_id - ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim) - v_cache = forward_batch.token_to_kv_pool.get_value_buffer( - layer.layer_id - ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim) - query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim) - if self.forward_metadata.seq_lens_cpu_int is None: - actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list - else: - actual_seq_len_kv = ( - self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist() - ) + if not self.use_mla and self.enable_torch_compile: + k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) + v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) + query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim) num_tokens = query.shape[0] - workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( - 
query, - k_cache, - v_cache, - block_table=self.forward_metadata.block_tables, - block_size=self.page_size, - num_heads=layer.tp_q_head_num, - num_key_value_heads=layer.tp_k_head_num, - input_layout="BSH", - scale=layer.scaling, - actual_seq_lengths_kv=actual_seq_len_kv, - ) - output = torch.empty( - (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim), - dtype=q.dtype, - device=q.device, - ) - softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device) - torch_npu.npu_fused_infer_attention_score.out( - query, - k_cache, - v_cache, - block_table=self.forward_metadata.block_tables, - block_size=self.page_size, - num_heads=layer.tp_q_head_num, - num_key_value_heads=layer.tp_k_head_num, - input_layout="BSH", - scale=layer.scaling, - actual_seq_lengths_kv=actual_seq_len_kv, - workspace=workspace, - out=[output, softmax_lse], - ) - return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim) - else: - c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) - k_rope_cache = k_rope.view( - -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim - ) - c_kv_cache = c_kv.view( - -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank + attn_output = torch.empty( + (num_tokens, layer.tp_q_head_num, layer.v_head_dim), + dtype=query.dtype, + device=query.device, ) - q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous() - q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim) if self.forward_metadata.seq_lens_cpu_int is None: - actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list - else: - actual_seq_len_kv = ( - self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist() + actual_seq_len_kv = torch.from_numpy( + np.array(self.forward_metadata.seq_lens_cpu_list).astype(np.int32) ) + else: + actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_int - workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( - q_nope, - c_kv_cache, - c_kv_cache, - query_rope=q_rope, - 
key_rope=k_rope_cache, + torch_npu._npu_paged_attention( + query=query, + key_cache=k_cache, + value_cache=v_cache, num_heads=layer.tp_q_head_num, - num_key_value_heads=layer.tp_k_head_num, + num_kv_heads=layer.tp_k_head_num, + scale_value=layer.scaling, block_table=self.forward_metadata.block_tables, - block_size=self.page_size, - input_layout="BNSD", - scale=layer.scaling, - actual_seq_lengths_kv=actual_seq_len_kv, - antiquant_mode=0, - antiquant_scale=None, - sparse_mode=0, + context_lens=actual_seq_len_kv, + out=attn_output, ) - output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device) - softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device) + return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim) + else: + if not self.use_mla: + k_cache = forward_batch.token_to_kv_pool.get_key_buffer( + layer.layer_id + ).view(-1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim) + v_cache = forward_batch.token_to_kv_pool.get_value_buffer( + layer.layer_id + ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim) + query = q.reshape(-1, 1, layer.tp_q_head_num * layer.qk_head_dim) + if self.forward_metadata.seq_lens_cpu_int is None: + actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list + else: + actual_seq_len_kv = ( + self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist() + ) + num_tokens = query.shape[0] + workspace = ( + torch_npu._npu_fused_infer_attention_score_get_max_workspace( + query, + k_cache, + v_cache, + block_table=self.forward_metadata.block_tables, + block_size=self.page_size, + num_heads=layer.tp_q_head_num, + num_key_value_heads=layer.tp_k_head_num, + input_layout="BSH", + scale=layer.scaling, + actual_seq_lengths_kv=actual_seq_len_kv, + ) + ) + output = torch.empty( + (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim), + dtype=q.dtype, + device=q.device, + ) + softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device) + torch_npu.npu_fused_infer_attention_score.out( + query, + k_cache, + 
v_cache, + block_table=self.forward_metadata.block_tables, + block_size=self.page_size, + num_heads=layer.tp_q_head_num, + num_key_value_heads=layer.tp_k_head_num, + input_layout="BSH", + scale=layer.scaling, + actual_seq_lengths_kv=actual_seq_len_kv, + workspace=workspace, + out=[output, softmax_lse], + ) + return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim) + else: + c_kv, k_rope = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + k_rope_cache = k_rope.view( + -1, layer.tp_k_head_num, self.page_size, self.qk_rope_head_dim + ) + c_kv_cache = c_kv.view( + -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank + ) - torch_npu.npu_fused_infer_attention_score.out( - q_nope, - c_kv_cache, - c_kv_cache, - query_rope=q_rope, - key_rope=k_rope_cache, - num_heads=layer.tp_q_head_num, - num_key_value_heads=layer.tp_k_head_num, - block_table=self.forward_metadata.block_tables, - block_size=self.page_size, - input_layout="BNSD", - scale=layer.scaling, - actual_seq_lengths_kv=actual_seq_len_kv, - antiquant_mode=0, - antiquant_scale=None, - sparse_mode=0, - workspace=workspace, - out=[output, softmax_lse], - ) - return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank) + q_nope = q.view( + -1, layer.tp_q_head_num, 1, self.kv_lora_rank + ).contiguous() + q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim) + if self.forward_metadata.seq_lens_cpu_int is None: + actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list + else: + actual_seq_len_kv = ( + self.forward_metadata.seq_lens_cpu_int.cpu().int().tolist() + ) + + workspace = ( + torch_npu._npu_fused_infer_attention_score_get_max_workspace( + q_nope, + c_kv_cache, + c_kv_cache, + query_rope=q_rope, + key_rope=k_rope_cache, + num_heads=layer.tp_q_head_num, + num_key_value_heads=layer.tp_k_head_num, + block_table=self.forward_metadata.block_tables, + block_size=self.page_size, + input_layout="BNSD", + scale=layer.scaling, + 
actual_seq_lengths_kv=actual_seq_len_kv, + antiquant_mode=0, + antiquant_scale=None, + sparse_mode=0, + ) + ) + output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device) + softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device) + + torch_npu.npu_fused_infer_attention_score.out( + q_nope, + c_kv_cache, + c_kv_cache, + query_rope=q_rope, + key_rope=k_rope_cache, + num_heads=layer.tp_q_head_num, + num_key_value_heads=layer.tp_k_head_num, + block_table=self.forward_metadata.block_tables, + block_size=self.page_size, + input_layout="BNSD", + scale=layer.scaling, + actual_seq_lengths_kv=actual_seq_len_kv, + antiquant_mode=0, + antiquant_scale=None, + sparse_mode=0, + workspace=workspace, + out=[output, softmax_lse], + ) + return output.view(-1, layer.tp_q_head_num * self.kv_lora_rank) def forward_decode( self, diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index eaeeac51def2..5799da2d0f85 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 SGLang Team +# Copyright 2023-2025 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -51,7 +51,6 @@ is_hip, is_sm90_supported, is_sm100_supported, - prepare_weight_cache, ) _is_flashinfer_available = is_flashinfer_available() @@ -567,7 +566,7 @@ def _gather_hidden_states_and_residual( else: hidden_states = tensor_model_parallel_all_reduce(hidden_states) if context.cache is not None: - _ = prepare_weight_cache(hidden_states, context.cache) + torch.ops.sglang.prepare_weight_cache(hidden_states, context.cache) hidden_states, residual = layernorm(hidden_states, residual) return hidden_states, residual diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 5ceba2f67b6c..59b5e7568137 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -8,6 +8,7 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading +from sglang.srt.layers.linear import MergedColumnParallelLinear, QKVParallelLinear from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo from sglang.srt.layers.parameter import ( @@ -200,6 +201,7 @@ class W8A8Int8Config(QuantizationConfig): def __init__(self, quant_config: Dict[str, Any] = {}): super().__init__() + self.enable_torch_compile = quant_config.get("enable_torch_compile", False) self.quant_description = quant_config self.is_dynamic = quant_config.get("is_dynamic", False) ignore = cast(List[str], quant_config.get("ignore", [])) @@ -568,7 +570,10 @@ def apply( class NPU_W8A8LinearMethodImpl: """Linear method for NPU W8A8.""" - def __init__(self) -> None: + quant_config = None + + def __init__(self, quant_config) -> None: + NPU_W8A8LinearMethodImpl.quant_config = quant_config # aclnn quant matmul requires to transpose matrix B, set to true by default. 
self.transpose_weight = True @@ -614,9 +619,16 @@ def apply( original_dtype = x.dtype if original_dtype != torch.int8: + aclnn_input_scale_reciprocal = layer.aclnn_input_scale_reciprocal + if NPU_W8A8LinearMethodImpl.quant_config.enable_torch_compile and ( + isinstance(layer, MergedColumnParallelLinear) + or isinstance(layer, QKVParallelLinear) + ): + aclnn_input_scale_reciprocal = 1.0 / aclnn_input_scale_reciprocal + x = torch_npu.npu_quantize( x, - layer.aclnn_input_scale_reciprocal, + aclnn_input_scale_reciprocal, layer.aclnn_input_offset, torch.qint8, -1, @@ -633,7 +645,7 @@ def apply( layer.weight, layer.deq_scale, bias=quant_bias, - output_dtype=original_dtype, + output_dtype=layer.params_dtype, ) def process_weights_after_loading(self, layer): @@ -642,10 +654,23 @@ def process_weights_after_loading(self, layer): layer.input_scale.data.repeat(expanding_factor).to(device="npu"), requires_grad=False, ) - layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter( - layer.input_scale.data.repeat(expanding_factor).to(device="npu"), - requires_grad=False, - ) + prev_layer_fuse_reciprocal = isinstance( + layer, MergedColumnParallelLinear + ) or isinstance(layer, QKVParallelLinear) + if ( + NPU_W8A8LinearMethodImpl.quant_config.enable_torch_compile + and prev_layer_fuse_reciprocal + ): + layer.aclnn_input_scale_reciprocal = torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor).to(device="npu"), + requires_grad=False, + ) + else: + layer.aclnn_input_scale_reciprocal = 1.0 / torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor).to(device="npu"), + requires_grad=False, + ) + layer.aclnn_input_offset = torch.nn.Parameter( layer.input_offset.data.repeat(expanding_factor).to(device="npu"), requires_grad=False, @@ -740,7 +765,7 @@ def __init__(self, quantization_config: W8A8Int8Config) -> None: self.quant_method = ( NPU_W8A8LinearMethodMTImpl() if useMindIETurbo - else NPU_W8A8LinearMethodImpl() + else 
NPU_W8A8LinearMethodImpl(quantization_config) ) def create_weights( diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 64ac00ff19e9..45208e473892 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -212,31 +212,27 @@ def forward_npu( fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """A PyTorch-npu implementation of forward().""" + assert ( fused_set_kv_buffer_arg is None ), "fused_set_kv_buffer_arg is not supported for npu implementation" - if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"): - return self.forward_native( - positions, query, key, offsets, fused_set_kv_buffer_arg - ) - else: + rotary_mode = "half" + if self.is_neox_style: rotary_mode = "half" - if self.is_neox_style: - rotary_mode = "half" - else: - rotary_mode = "interleave" - mrope_section = [0, 0, 0] - query_out, key_out = torch_npu.npu_mrope( - positions, - query, - key, - self.cos_sin_cache, - self.head_size, - mrope_section=mrope_section, - rotary_mode=rotary_mode, - ) - return query_out, key_out + else: + rotary_mode = "interleave" + mrope_section = [0, 0, 0] + query_out, key_out = torch_npu.npu_mrope( + positions, + query, + key, + self.cos_sin_cache, + self.head_size, + mrope_section=mrope_section, + rotary_mode=rotary_mode, + ) + return query_out, key_out def forward_cpu( self, diff --git a/python/sglang/srt/model_executor/compilation/custom_ops.py b/python/sglang/srt/model_executor/compilation/custom_ops.py new file mode 100644 index 000000000000..7715f85d3bcc --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/custom_ops.py @@ -0,0 +1,42 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Optional + +import torch + +from sglang.srt.layers.dp_attention import _DpGatheredBufferWrapper + + +@torch.library.custom_op("sglang::_set_dp_buffer_len", mutates_args=()) +def _set_dp_buffer_len( + global_dp_buffer_len: Optional[int], num_tokens: Optional[int] +) -> None: + _DpGatheredBufferWrapper._global_dp_buffer_len = global_dp_buffer_len + _DpGatheredBufferWrapper._local_dp_buffer_len = num_tokens + + +@_set_dp_buffer_len.register_fake +def _set_dp_buffer_len_register_fake(global_dp_buffer_len, num_tokens) -> None: + pass + + +@torch.library.custom_op("sglang::_set_is_extend_in_batch", mutates_args=()) +def _set_is_extend_in_batch(is_extend_in_batch: bool) -> None: + _DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch) + + +@_set_is_extend_in_batch.register_fake +def _set_is_extend_in_batch_fake(is_extend_in_batch: bool) -> None: + pass diff --git a/python/sglang/srt/model_executor/compilation/npu_compiler_backend.py b/python/sglang/srt/model_executor/compilation/npu_compiler_backend.py new file mode 100644 index 000000000000..cdbe8d5d3a81 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/npu_compiler_backend.py @@ -0,0 +1,42 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Callable + +import torch +from torch._dynamo.eval_frame import DisableContext + +from sglang.srt.model_executor.compilation.pass_manager import PassManager +from sglang.srt.model_executor.compilation.passes.w8a8_int8.div_fuse import DivFuse +from sglang.srt.model_executor.compilation.passes.w8a8_int8.erase_copy import EraseCopy +from sglang.srt.model_executor.compilation.passes.w8a8_int8.npu_add_rms_norm_quant_fuse import ( + NpuAddRmsNormQuantFuse, +) + + +class NpuBackend: + def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: + DisableContext.compiled_function_args[DisableContext.batch_size] = ( + example_inputs + ) + NpuBackend.apply_passes(graph) + return graph + + def apply_passes(graph_module: torch.fx.GraphModule): + passManager = PassManager(graph_module) + passManager.add(NpuAddRmsNormQuantFuse) + passManager.add(DivFuse) + passManager.add(EraseCopy) + passManager.apply() + graph_module.recompile() diff --git a/python/sglang/srt/model_executor/compilation/npu_graph_compiler.py b/python/sglang/srt/model_executor/compilation/npu_graph_compiler.py new file mode 100644 index 000000000000..4f1816ca0e1e --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/npu_graph_compiler.py @@ -0,0 +1,27 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch + +from sglang.srt.model_executor.compilation.npu_compiler_backend import NpuBackend + + +class NpuGraphCompiler: + def __init__(self, model: torch.nn.Module): + torch._dynamo.reset() + + self.backend = NpuBackend() + self.compiled_callable = torch.compile( + model, fullgraph=True, dynamic=False, backend=self.backend + ) diff --git a/python/sglang/srt/model_executor/compilation/pass_manager.py b/python/sglang/srt/model_executor/compilation/pass_manager.py new file mode 100644 index 000000000000..2bc613768a80 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/pass_manager.py @@ -0,0 +1,41 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import torch + + +class PassManager: + def __init__(self, graph_module: torch.fx.GraphModule): + self.graph_module = graph_module + self.passes = [] + + def add(self, pass_): + self.passes.append(pass_) + + def apply(self): + updated = False + for pass_ in self.passes: + pass_instance = pass_() + if callable(pass_instance): + results = pass_instance(self.graph_module) + else: + results = torch.fx.replace_pattern( + self.graph_module, pass_.pattern, pass_.replacement + ) + + if not updated: + updated = len(results) != 0 + + if updated: + self.graph_module.recompile() diff --git a/python/sglang/srt/model_executor/compilation/passes/w8a8_int8/div_fuse.py b/python/sglang/srt/model_executor/compilation/passes/w8a8_int8/div_fuse.py new file mode 100644 index 000000000000..7b431c6a65b5 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/passes/w8a8_int8/div_fuse.py @@ -0,0 +1,23 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +class DivFuse: + def pattern(x): + y = 1.0 / x + z = 1.0 / y + return z + + def replacement(x): + return x diff --git a/python/sglang/srt/model_executor/compilation/passes/w8a8_int8/erase_copy.py b/python/sglang/srt/model_executor/compilation/passes/w8a8_int8/erase_copy.py new file mode 100644 index 000000000000..de34f61f3c11 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/passes/w8a8_int8/erase_copy.py @@ -0,0 +1,70 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +import torch + + +class EraseCopy: + def __call__(self, graph_module: torch.fx.GraphModule): + copy_node = None + prepare_weight_cache_default_node = None + + results = [] + for module in graph_module.modules(): + for node in list(module.graph.nodes): + if node.type == torch.nn.parameter.Parameter: + continue + if node.target == "copy_": + copy_node = node + prepare_weight_cache_default_node = None + continue + + if ( + copy_node + and node.target == torch.ops.sglang.prepare_weight_cache.default + ): + prepare_weight_cache_default_node = node + continue + + if copy_node and node.target == torch.ops.npu.npu_add_rms_norm_quant: + arg = copy_node.args[1] + + if prepare_weight_cache_default_node is not None: + prepare_weight_cache_default_node.args = ( + arg, + prepare_weight_cache_default_node.args[1], + ) + + node.args = ( + node.args[0], + arg, + node.args[2], + node.args[3], + node.args[4], + ) + + module.graph.erase_node(copy_node) + + result = ( + arg, + copy_node, + prepare_weight_cache_default_node, + ) + results.append(result) + + copy_node = None + prepare_weight_cache_default_node = None + + return results diff --git a/python/sglang/srt/model_executor/compilation/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py b/python/sglang/srt/model_executor/compilation/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py new file mode 100644 index 000000000000..ac97b70cf40a --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py @@ -0,0 +1,36 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch + + +class NpuAddRmsNormQuantFuse: + def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3): + output = torch.ops.npu.npu_add_rms_norm( + rms_norm_input, residual, rms_norm_weight, 1e-6 + ) + out0 = output[0] + out2 = output[2] + quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, v1, v2, v3) + return quantized_output, out2 + + def replacement( + rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3 + ): + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, residual, rms_norm_weight, 1.0 / scale, offset, epsilon=1e-6 + ) + quantized_output = output[0] + out2 = output[2] + return quantized_output, out2 diff --git a/python/sglang/srt/model_executor/compilation/patch_dynamo.py b/python/sglang/srt/model_executor/compilation/patch_dynamo.py new file mode 100644 index 000000000000..284582f86011 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/patch_dynamo.py @@ -0,0 +1,54 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + +import torch +from torch._dynamo.decorators import skip +from torch._dynamo.eval_frame import DisableContext, innermost_fn + + +def patch_dynamo_context(): + setattr(torch._dynamo.eval_frame.DisableContext, "compiled_function_args", {}) + setattr(torch._dynamo.eval_frame.DisableContext, "compiled_function", {}) + setattr(torch._dynamo.eval_frame.DisableContext, "batch_size", None) + + +original_disable_context_call = None +original_disable = None + + +def decorators_disable(fn=None, recursive=True): + if recursive: + if fn is not None: + fn = innermost_fn(fn) + assert callable(fn) + + DisableContext.compiled_function[DisableContext.batch_size] = fn + return DisableContext()(fn) + return DisableContext() + else: + return skip(fn) + + +def patch_dynamo_context_call(): + global original_disable + original_disable = torch._dynamo.decorators.disable + torch._dynamo.decorators.disable = decorators_disable + + +def restore_dynamo_context_call(): + global original_disable + torch._dynamo.decorators.disable = original_disable + original_disable = None diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 80f8c65648cd..559831d484c1 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 SGLang Team +# Copyright 2023-2025 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -39,8 +39,6 @@ DpPaddingMode, get_attention_tp_rank, get_attention_tp_size, - set_dp_buffer_len, - set_is_extend_in_batch, ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer @@ -77,6 +75,11 @@ _is_hip = is_hip() +from sglang.srt.model_executor.compilation.custom_ops import ( + _set_dp_buffer_len, + _set_is_extend_in_batch, +) + logger = logging.getLogger(__name__) if TYPE_CHECKING: @@ -520,7 +523,7 @@ def capture(self) -> None: ) logger.info(log_message) - def _capture_graph(self, graph, pool, stream, run_once_fn): + def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): memory_saver_adapter = TorchMemorySaverAdapter.create( enable=self.model_runner.server_args.enable_memory_saver and get_bool_env_var("SGLANG_MEMORY_SAVER_CUDA_GRAPH") @@ -660,8 +663,8 @@ def capture_one_batch_size(self, bs: int, forward: Callable): def run_once(): # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - set_dp_buffer_len(global_dp_buffer_len, num_tokens) - set_is_extend_in_batch(False) + _set_dp_buffer_len(global_dp_buffer_len, num_tokens) + _set_is_extend_in_batch(False) kwargs = {} if ( @@ -692,7 +695,7 @@ def run_once(): # Set graph pool id globally to be able to use symmetric memory set_graph_pool_id(get_global_graph_memory_pool()) out = self._capture_graph( - graph, get_global_graph_memory_pool(), stream, run_once + graph, get_global_graph_memory_pool(), stream, run_once, bs ) return graph, out diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py index f9961a36f3e4..97d8fa7c7441 100644 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 SGLang Team +# Copyright 2023-2025 SGLang Team # 
Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,11 +17,15 @@ import logging import threading +from contextlib import contextmanager from typing import TYPE_CHECKING, Optional, Union +import numpy as np import torch -from sglang.srt.configs.model_config import is_deepseek_nsa +import sglang.srt.model_executor.cuda_graph_runner +from sglang.srt.configs.model_config import AttentionArch, is_deepseek_nsa +from sglang.srt.distributed.parallel_state import GroupCoordinator from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner logger = logging.getLogger(__name__) @@ -29,33 +33,100 @@ if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner +from torch._dynamo.eval_frame import DisableContext + from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.model_executor.compilation.patch_dynamo import ( + patch_dynamo_context, + patch_dynamo_context_call, + restore_dynamo_context_call, +) +from sglang.srt.model_executor.compilation.npu_graph_compiler import NpuGraphCompiler from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +@contextmanager +def patch_model_npu( + model: torch.nn.Module, + enable_compile: bool, + num_tokens: int, + tp_group: GroupCoordinator, +): + yield model + + class NPUGraphRunner(CudaGraphRunner): """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile.""" def __init__(self, model_runner: ModelRunner): + if model_runner.server_args.enable_torch_compile: + patch_dynamo_context() + sglang.srt.model_executor.cuda_graph_runner.patch_model = patch_model_npu + model_runner.attn_backend.enable_torch_compile = ( + model_runner.server_args.enable_torch_compile + ) + super().__init__(model_runner) + self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA def 
_create_device_graph(self): return torch.npu.NPUGraph() - def _capture_graph(self, graph, pool, stream, run_once_fn): - with torch.npu.graph( - graph, - pool=pool, - stream=stream, - auto_dispatch_capture=True, - ): - out = run_once_fn() + def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): + if self.enable_torch_compile: + compiler = NpuGraphCompiler(run_once_fn) + + patch_dynamo_context_call() + DisableContext.batch_size = bs + try: + # compilation + out = compiler.compiled_callable() + + # capture function and args + out = compiler.compiled_callable() + finally: + DisableContext.batch_size = None + restore_dynamo_context_call() + + assert bs in DisableContext.compiled_function + assert DisableContext.compiled_function[bs] + assert bs in DisableContext.compiled_function_args + assert DisableContext.compiled_function_args[bs] + + compiled_function = DisableContext.compiled_function[bs] + args = DisableContext.compiled_function_args[bs] + with torch.npu.graph( + graph, + pool=pool, + stream=stream, + auto_dispatch_capture=True, + ): + compiled_function(*args) + + else: + with torch.npu.graph( + graph, + pool=pool, + stream=stream, + auto_dispatch_capture=True, + ): + out = run_once_fn() return out def _update_inputs(self, seq_lens): - self.graphs[self.bs].update( - cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}] - ) + if self.enable_torch_compile: + if self.use_mla: + self.graphs[self.bs].update( + cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}] + ) + else: + self.graphs[self.bs].update( + cpu_update_input=[{"context_lens": seq_lens}] + ) + else: + self.graphs[self.bs].update( + cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}] + ) def _cache_loc_dtype(self): return torch.int32 @@ -74,20 +145,37 @@ def replay( self.positions[: self.raw_num_token].copy_(forward_batch.positions) # Replay - if not is_deepseek_nsa(self.model_runner.model_config.hf_config): - if forward_batch.forward_mode.is_target_verify(): - seq_lens_cpu = 
forward_batch.seq_lens.cpu() + self.num_tokens_per_bs - seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs) + if self.enable_torch_compile: + seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * ( + self.bs - self.raw_bs + ) + if self.use_mla: + actual_seq_len_kv = seq_lens else: - seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * ( - self.bs - self.raw_bs + actual_seq_len_kv = torch.from_numpy( + np.array(seq_lens).astype(np.int32) ) - thread = threading.Thread(target=self._update_inputs, args=(seq_lens,)) + thread = threading.Thread( + target=self._update_inputs, args=(actual_seq_len_kv,) + ) thread.start() self.graphs[self.bs].replay() thread.join() else: - self.graphs[self.bs].replay() + if not is_deepseek_nsa(self.model_runner.model_config.hf_config): + if forward_batch.forward_mode.is_target_verify(): + seq_lens_cpu = forward_batch.seq_lens.cpu() + self.num_tokens_per_bs + seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs) + else: + seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * ( + self.bs - self.raw_bs + ) + thread = threading.Thread(target=self._update_inputs, args=(seq_lens,)) + thread.start() + self.graphs[self.bs].replay() + thread.join() + else: + self.graphs[self.bs].replay() output = self.output_buffers[self.bs] if isinstance(output, LogitsProcessorOutput): diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index a7b987e110f4..5ae85f44ea17 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -253,6 +253,8 @@ def get_quant_config( return ModelOptFp8Config.from_config(config) elif "FP4" in quant_algo: return ModelOptFp4Config.from_config(config) + + config["enable_torch_compile"] = model_config.enable_torch_compile return quant_cls.from_config(config) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index b031d6e03e34..2f68a7d00c1c 100644 --- 
a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -5,6 +5,12 @@ import torch from torch import nn +# TODO: refactor +# if supports_custom_op(): +# from sglang.srt._custom_ops import wait_cmo_stream, wait_cmo_stream_fake +# else: +# from sglang.srt.utils import wait_cmo_stream +from sglang.srt._custom_ops import wait_cmo_stream from sglang.srt.distributed import ( get_pp_group, get_tensor_model_parallel_rank, @@ -30,13 +36,7 @@ from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import ( - add_prefix, - get_cmo_stream, - is_cuda, - is_npu, - wait_cmo_stream, -) +from sglang.srt.utils import add_prefix, is_cuda, is_npu Qwen3Config = None @@ -281,8 +281,10 @@ def forward( ), ) hidden_states = self.mlp(hidden_states) - if _is_npu and get_cmo_stream(): - wait_cmo_stream() + # check if custom op is supported + # if _is_npu and get_cmo_stream(): + # wait_cmo_stream() + wait_cmo_stream() hidden_states, residual = self.layer_communicator.postprocess_layer( hidden_states, residual, forward_batch ) diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 0c1a13b5f86d..35cbd198b8ff 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1888,7 +1888,7 @@ def direct_register_custom_op( op_func: Callable, mutates_args: List[str], fake_impl: Optional[Callable] = None, - target_lib: Optional[Library] = None, + target_lib: Optional[Library] = None ): """ `torch.library.custom_op` can have significant overhead because it @@ -1937,7 +1937,7 @@ def direct_register_custom_op( try: my_lib.define(op_name + schema_str) - my_lib.impl(op_name, op_func, "CUDA") + my_lib.impl(op_name, op_func, "CUDA" if not is_npu() else "PrivateUse1") if fake_impl is not None: my_lib._register_fake(op_name, fake_impl) except RuntimeError as error: From 
b9744608f0f57fe6c45a75b1eb97b69d367bd1fa Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Fri, 31 Oct 2025 08:40:52 +0300 Subject: [PATCH 02/71] pre-commit & refactoring --- python/sglang/srt/_custom_ops.py | 7 +++++-- .../npu}/custom_ops.py | 0 .../npu}/npu_compiler_backend.py | 8 ++++---- .../npu}/npu_graph_compiler.py | 2 +- .../npu}/pass_manager.py | 0 .../npu}/passes/w8a8_int8/div_fuse.py | 0 .../npu}/passes/w8a8_int8/erase_copy.py | 0 .../w8a8_int8/npu_add_rms_norm_quant_fuse.py | 0 .../npu}/patch_dynamo.py | 0 .../srt/layers/quantization/w8a8_int8.py | 17 ++++++----------- .../srt/model_executor/cuda_graph_runner.py | 2 +- .../srt/model_executor/npu_graph_runner.py | 4 ++-- python/sglang/srt/models/qwen3.py | 19 ++++++++----------- python/sglang/srt/utils/common.py | 2 +- 14 files changed, 28 insertions(+), 33 deletions(-) rename python/sglang/srt/{model_executor/compilation => compilation/npu}/custom_ops.py (100%) rename python/sglang/srt/{model_executor/compilation => compilation/npu}/npu_compiler_backend.py (79%) rename python/sglang/srt/{model_executor/compilation => compilation/npu}/npu_graph_compiler.py (91%) rename python/sglang/srt/{model_executor/compilation => compilation/npu}/pass_manager.py (100%) rename python/sglang/srt/{model_executor/compilation => compilation/npu}/passes/w8a8_int8/div_fuse.py (100%) rename python/sglang/srt/{model_executor/compilation => compilation/npu}/passes/w8a8_int8/erase_copy.py (100%) rename python/sglang/srt/{model_executor/compilation => compilation/npu}/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py (100%) rename python/sglang/srt/{model_executor/compilation => compilation/npu}/patch_dynamo.py (100%) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 4201f5729835..f6ebfc372b97 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -7,7 +7,6 @@ from sglang.srt.utils import ( direct_register_custom_op, get_bool_env_var, - get_cmo_stream, 
is_hip, is_hpu, is_npu, @@ -24,7 +23,7 @@ @torch.library.custom_op("sglang::wait_cmo_stream", mutates_args=()) def wait_cmo_stream() -> None: - if is_npu() and get_cmo_stream(): + if sglang.srt.utils.get_cmo_stream(): sglang.srt.utils.wait_cmo_stream() @@ -33,6 +32,10 @@ def wait_cmo_stream_fake() -> None: pass +def get_cmo_stream() -> bool: + return True + + def prepare_weight_cache(handle: torch.Tensor, cache: List[torch.Tensor]) -> None: sglang.srt.utils.prepare_weight_cache(handle, cache) diff --git a/python/sglang/srt/model_executor/compilation/custom_ops.py b/python/sglang/srt/compilation/npu/custom_ops.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/custom_ops.py rename to python/sglang/srt/compilation/npu/custom_ops.py diff --git a/python/sglang/srt/model_executor/compilation/npu_compiler_backend.py b/python/sglang/srt/compilation/npu/npu_compiler_backend.py similarity index 79% rename from python/sglang/srt/model_executor/compilation/npu_compiler_backend.py rename to python/sglang/srt/compilation/npu/npu_compiler_backend.py index cdbe8d5d3a81..07754b636e95 100644 --- a/python/sglang/srt/model_executor/compilation/npu_compiler_backend.py +++ b/python/sglang/srt/compilation/npu/npu_compiler_backend.py @@ -17,10 +17,10 @@ import torch from torch._dynamo.eval_frame import DisableContext -from sglang.srt.model_executor.compilation.pass_manager import PassManager -from sglang.srt.model_executor.compilation.passes.w8a8_int8.div_fuse import DivFuse -from sglang.srt.model_executor.compilation.passes.w8a8_int8.erase_copy import EraseCopy -from sglang.srt.model_executor.compilation.passes.w8a8_int8.npu_add_rms_norm_quant_fuse import ( +from sglang.srt.compilation.npu.pass_manager import PassManager +from sglang.srt.compilation.npu.passes.w8a8_int8.div_fuse import DivFuse +from sglang.srt.compilation.npu.passes.w8a8_int8.erase_copy import EraseCopy +from sglang.srt.compilation.npu.passes.w8a8_int8.npu_add_rms_norm_quant_fuse import ( 
NpuAddRmsNormQuantFuse, ) diff --git a/python/sglang/srt/model_executor/compilation/npu_graph_compiler.py b/python/sglang/srt/compilation/npu/npu_graph_compiler.py similarity index 91% rename from python/sglang/srt/model_executor/compilation/npu_graph_compiler.py rename to python/sglang/srt/compilation/npu/npu_graph_compiler.py index 4f1816ca0e1e..b49d743167fb 100644 --- a/python/sglang/srt/model_executor/compilation/npu_graph_compiler.py +++ b/python/sglang/srt/compilation/npu/npu_graph_compiler.py @@ -14,7 +14,7 @@ import torch -from sglang.srt.model_executor.compilation.npu_compiler_backend import NpuBackend +from sglang.srt.compilation.npu.npu_compiler_backend import NpuBackend class NpuGraphCompiler: diff --git a/python/sglang/srt/model_executor/compilation/pass_manager.py b/python/sglang/srt/compilation/npu/pass_manager.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/pass_manager.py rename to python/sglang/srt/compilation/npu/pass_manager.py diff --git a/python/sglang/srt/model_executor/compilation/passes/w8a8_int8/div_fuse.py b/python/sglang/srt/compilation/npu/passes/w8a8_int8/div_fuse.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/passes/w8a8_int8/div_fuse.py rename to python/sglang/srt/compilation/npu/passes/w8a8_int8/div_fuse.py diff --git a/python/sglang/srt/model_executor/compilation/passes/w8a8_int8/erase_copy.py b/python/sglang/srt/compilation/npu/passes/w8a8_int8/erase_copy.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/passes/w8a8_int8/erase_copy.py rename to python/sglang/srt/compilation/npu/passes/w8a8_int8/erase_copy.py diff --git a/python/sglang/srt/model_executor/compilation/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py b/python/sglang/srt/compilation/npu/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py rename to 
python/sglang/srt/compilation/npu/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py diff --git a/python/sglang/srt/model_executor/compilation/patch_dynamo.py b/python/sglang/srt/compilation/npu/patch_dynamo.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/patch_dynamo.py rename to python/sglang/srt/compilation/npu/patch_dynamo.py diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 59b5e7568137..f4dadf2a59de 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -25,6 +25,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( apply_module_patch, cpu_has_amx_support, @@ -201,7 +202,6 @@ class W8A8Int8Config(QuantizationConfig): def __init__(self, quant_config: Dict[str, Any] = {}): super().__init__() - self.enable_torch_compile = quant_config.get("enable_torch_compile", False) self.quant_description = quant_config self.is_dynamic = quant_config.get("is_dynamic", False) ignore = cast(List[str], quant_config.get("ignore", [])) @@ -570,10 +570,8 @@ def apply( class NPU_W8A8LinearMethodImpl: """Linear method for NPU W8A8.""" - quant_config = None - - def __init__(self, quant_config) -> None: - NPU_W8A8LinearMethodImpl.quant_config = quant_config + def __init__(self) -> None: + self.enable_torch_compile = get_global_server_args().enable_torch_compile # aclnn quant matmul requires to transpose matrix B, set to true by default. 
self.transpose_weight = True @@ -620,7 +618,7 @@ def apply( original_dtype = x.dtype if original_dtype != torch.int8: aclnn_input_scale_reciprocal = layer.aclnn_input_scale_reciprocal - if NPU_W8A8LinearMethodImpl.quant_config.enable_torch_compile and ( + if get_global_server_args().enable_torch_compile and ( isinstance(layer, MergedColumnParallelLinear) or isinstance(layer, QKVParallelLinear) ): @@ -657,10 +655,7 @@ def process_weights_after_loading(self, layer): prev_layer_fuse_reciprocal = isinstance( layer, MergedColumnParallelLinear ) or isinstance(layer, QKVParallelLinear) - if ( - NPU_W8A8LinearMethodImpl.quant_config.enable_torch_compile - and prev_layer_fuse_reciprocal - ): + if self.enable_torch_compile and prev_layer_fuse_reciprocal: layer.aclnn_input_scale_reciprocal = torch.nn.Parameter( layer.input_scale.data.repeat(expanding_factor).to(device="npu"), requires_grad=False, @@ -765,7 +760,7 @@ def __init__(self, quantization_config: W8A8Int8Config) -> None: self.quant_method = ( NPU_W8A8LinearMethodMTImpl() if useMindIETurbo - else NPU_W8A8LinearMethodImpl(quantization_config) + else NPU_W8A8LinearMethodImpl() ) def create_weights( diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 559831d484c1..30f3baf9bc5f 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -75,7 +75,7 @@ _is_hip = is_hip() -from sglang.srt.model_executor.compilation.custom_ops import ( +from sglang.srt.compilation.npu.custom_ops import ( _set_dp_buffer_len, _set_is_extend_in_batch, ) diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py index 97d8fa7c7441..9e831523fad6 100644 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -36,12 +36,12 @@ from torch._dynamo.eval_frame import DisableContext 
from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.model_executor.compilation.patch_dynamo import ( +from sglang.srt.compilation.npu.npu_graph_compiler import NpuGraphCompiler +from sglang.srt.compilation.npu.patch_dynamo import ( patch_dynamo_context, patch_dynamo_context_call, restore_dynamo_context_call, ) -from sglang.srt.model_executor.compilation.npu_graph_compiler import NpuGraphCompiler from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 2f68a7d00c1c..8752c72a6472 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -5,12 +5,6 @@ import torch from torch import nn -# TODO: refactor -# if supports_custom_op(): -# from sglang.srt._custom_ops import wait_cmo_stream, wait_cmo_stream_fake -# else: -# from sglang.srt.utils import wait_cmo_stream -from sglang.srt._custom_ops import wait_cmo_stream from sglang.srt.distributed import ( get_pp_group, get_tensor_model_parallel_rank, @@ -36,7 +30,12 @@ from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import add_prefix, is_cuda, is_npu +from sglang.srt.utils import add_prefix, is_cuda, is_npu, supports_custom_op + +if supports_custom_op() and get_global_server_args().enable_torch_compile: + from sglang.srt._custom_ops import get_cmo_stream, wait_cmo_stream +else: + from sglang.srt.utils import get_cmo_stream, wait_cmo_stream Qwen3Config = None @@ -281,10 +280,8 @@ def forward( ), ) hidden_states = self.mlp(hidden_states) - # check if custom op is supported - # if _is_npu and get_cmo_stream(): - # wait_cmo_stream() - wait_cmo_stream() + if _is_npu and get_cmo_stream(): + wait_cmo_stream() hidden_states, residual = self.layer_communicator.postprocess_layer( hidden_states, residual, forward_batch ) 
diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 35cbd198b8ff..03357fc703fc 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1888,7 +1888,7 @@ def direct_register_custom_op( op_func: Callable, mutates_args: List[str], fake_impl: Optional[Callable] = None, - target_lib: Optional[Library] = None + target_lib: Optional[Library] = None, ): """ `torch.library.custom_op` can have significant overhead because it From 7a7bde7c22d7bdb44e2ed15acebbec2ac5936272 Mon Sep 17 00:00:00 2001 From: t00918722 <2201534206@qq.com> Date: Sat, 1 Nov 2025 17:56:26 +0800 Subject: [PATCH 03/71] pre-commit --- python/sglang/srt/model_executor/npu_graph_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py index 9e831523fad6..ea07cc4b24a3 100644 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -35,13 +35,13 @@ from torch._dynamo.eval_frame import DisableContext -from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.compilation.npu.npu_graph_compiler import NpuGraphCompiler from sglang.srt.compilation.npu.patch_dynamo import ( patch_dynamo_context, patch_dynamo_context_call, restore_dynamo_context_call, ) +from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors From 9bb77510cca5039ad7a6dd26055b66d0d806d343 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Wed, 5 Nov 2025 14:37:06 +0300 Subject: [PATCH 04/71] Merge branch 'main' into eshogulin/pass_manager: fix - custom_ops.py --- python/sglang/srt/_custom_ops.py | 8 +------ .../sglang/srt/compilation/npu/custom_ops.py | 23 +++++++++++++------ 2 files changed, 17 insertions(+), 14 deletions(-) diff --git 
a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 8777fd29d275..e3734ba087fe 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -4,17 +4,11 @@ import torch -from sglang.srt.utils import ( - direct_register_custom_op, - is_hip, - is_hpu, - is_npu, -) +from sglang.srt.utils import direct_register_custom_op, is_hip, is_hpu, is_npu logger = logging.getLogger(__name__) - import sglang.srt.utils diff --git a/python/sglang/srt/compilation/npu/custom_ops.py b/python/sglang/srt/compilation/npu/custom_ops.py index 7715f85d3bcc..7214c4845ba7 100644 --- a/python/sglang/srt/compilation/npu/custom_ops.py +++ b/python/sglang/srt/compilation/npu/custom_ops.py @@ -12,29 +12,38 @@ # limitations under the License. # ============================================================================== -from typing import Optional +from typing import List, Optional import torch -from sglang.srt.layers.dp_attention import _DpGatheredBufferWrapper +import sglang.srt.layers.dp_attention @torch.library.custom_op("sglang::_set_dp_buffer_len", mutates_args=()) def _set_dp_buffer_len( - global_dp_buffer_len: Optional[int], num_tokens: Optional[int] + global_dp_buffer_len: Optional[int], + num_tokens: Optional[int], + is_max_len: bool, + global_num_tokens: Optional[List[int]] = None, ) -> None: - _DpGatheredBufferWrapper._global_dp_buffer_len = global_dp_buffer_len - _DpGatheredBufferWrapper._local_dp_buffer_len = num_tokens + sglang.srt.layers.dp_attention.set_dp_buffer_len( + global_dp_buffer_len, num_tokens, is_max_len, global_num_tokens + ) @_set_dp_buffer_len.register_fake -def _set_dp_buffer_len_register_fake(global_dp_buffer_len, num_tokens) -> None: +def _set_dp_buffer_len_register_fake( + global_dp_buffer_len: Optional[int], + num_tokens: Optional[int], + is_max_len: bool, + global_num_tokens: Optional[List[int]] = None, +) -> None: pass @torch.library.custom_op("sglang::_set_is_extend_in_batch", mutates_args=()) def 
_set_is_extend_in_batch(is_extend_in_batch: bool) -> None: - _DpGatheredBufferWrapper.set_is_extend_in_batch(is_extend_in_batch) + sglang.srt.layers.dp_attention.set_is_extend_in_batch(is_extend_in_batch) @_set_is_extend_in_batch.register_fake From 704800507369ace42494b8c0d1486ce904fcea70 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Mon, 10 Nov 2025 16:43:07 +0800 Subject: [PATCH 05/71] cleanup & refactoring --- .../srt/compilation/{npu => }/custom_ops.py | 3 +- .../compilation/npu/npu_compiler_backend.py | 6 ++-- .../{w8a8_int8/erase_copy.py => w8a8_int8.py} | 32 ++++++++++++++++- .../npu/passes/w8a8_int8/div_fuse.py | 23 ------------ .../w8a8_int8/npu_add_rms_norm_quant_fuse.py | 36 ------------------- python/sglang/srt/configs/model_config.py | 3 -- .../srt/model_executor/cuda_graph_runner.py | 17 +++++---- .../srt/model_executor/npu_graph_runner.py | 19 ++++++++++ .../sglang/srt/model_loader/weight_utils.py | 2 -- python/sglang/srt/models/qwen3.py | 2 +- python/sglang/srt/server_args.py | 7 ++++ 11 files changed, 73 insertions(+), 77 deletions(-) rename python/sglang/srt/compilation/{npu => }/custom_ops.py (96%) rename python/sglang/srt/compilation/npu/passes/{w8a8_int8/erase_copy.py => w8a8_int8.py} (73%) delete mode 100644 python/sglang/srt/compilation/npu/passes/w8a8_int8/div_fuse.py delete mode 100644 python/sglang/srt/compilation/npu/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py diff --git a/python/sglang/srt/compilation/npu/custom_ops.py b/python/sglang/srt/compilation/custom_ops.py similarity index 96% rename from python/sglang/srt/compilation/npu/custom_ops.py rename to python/sglang/srt/compilation/custom_ops.py index 7214c4845ba7..21ac61f22f66 100644 --- a/python/sglang/srt/compilation/npu/custom_ops.py +++ b/python/sglang/srt/compilation/custom_ops.py @@ -26,13 +26,14 @@ def _set_dp_buffer_len( is_max_len: bool, global_num_tokens: Optional[List[int]] = None, ) -> None: + global set_dp_buffer_len_original 
sglang.srt.layers.dp_attention.set_dp_buffer_len( global_dp_buffer_len, num_tokens, is_max_len, global_num_tokens ) @_set_dp_buffer_len.register_fake -def _set_dp_buffer_len_register_fake( +def _set_dp_buffer_len_fake( global_dp_buffer_len: Optional[int], num_tokens: Optional[int], is_max_len: bool, diff --git a/python/sglang/srt/compilation/npu/npu_compiler_backend.py b/python/sglang/srt/compilation/npu/npu_compiler_backend.py index 07754b636e95..e4842f966f9a 100644 --- a/python/sglang/srt/compilation/npu/npu_compiler_backend.py +++ b/python/sglang/srt/compilation/npu/npu_compiler_backend.py @@ -18,9 +18,9 @@ from torch._dynamo.eval_frame import DisableContext from sglang.srt.compilation.npu.pass_manager import PassManager -from sglang.srt.compilation.npu.passes.w8a8_int8.div_fuse import DivFuse -from sglang.srt.compilation.npu.passes.w8a8_int8.erase_copy import EraseCopy -from sglang.srt.compilation.npu.passes.w8a8_int8.npu_add_rms_norm_quant_fuse import ( +from sglang.srt.compilation.npu.passes.w8a8_int8 import ( + DivFuse, + EraseCopy, NpuAddRmsNormQuantFuse, ) diff --git a/python/sglang/srt/compilation/npu/passes/w8a8_int8/erase_copy.py b/python/sglang/srt/compilation/npu/passes/w8a8_int8.py similarity index 73% rename from python/sglang/srt/compilation/npu/passes/w8a8_int8/erase_copy.py rename to python/sglang/srt/compilation/npu/passes/w8a8_int8.py index de34f61f3c11..ac2b86e7b171 100644 --- a/python/sglang/srt/compilation/npu/passes/w8a8_int8/erase_copy.py +++ b/python/sglang/srt/compilation/npu/passes/w8a8_int8.py @@ -12,10 +12,19 @@ # limitations under the License. 
# ============================================================================== - import torch +class DivFuse: + def pattern(x): + y = 1.0 / x + z = 1.0 / y + return z + + def replacement(x): + return x + + class EraseCopy: def __call__(self, graph_module: torch.fx.GraphModule): copy_node = None @@ -68,3 +77,24 @@ def __call__(self, graph_module: torch.fx.GraphModule): prepare_weight_cache_default_node = None return results + + +class NpuAddRmsNormQuantFuse: + def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3): + output = torch.ops.npu.npu_add_rms_norm( + rms_norm_input, residual, rms_norm_weight, 1e-6 + ) + out0 = output[0] + out2 = output[2] + quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, v1, v2, v3) + return quantized_output, out2 + + def replacement( + rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3 + ): + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, residual, rms_norm_weight, 1.0 / scale, offset, epsilon=1e-6 + ) + quantized_output = output[0] + out2 = output[2] + return quantized_output, out2 diff --git a/python/sglang/srt/compilation/npu/passes/w8a8_int8/div_fuse.py b/python/sglang/srt/compilation/npu/passes/w8a8_int8/div_fuse.py deleted file mode 100644 index 7b431c6a65b5..000000000000 --- a/python/sglang/srt/compilation/npu/passes/w8a8_int8/div_fuse.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright 2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - - -class DivFuse: - def pattern(x): - y = 1.0 / x - z = 1.0 / y - return z - - def replacement(x): - return x diff --git a/python/sglang/srt/compilation/npu/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py b/python/sglang/srt/compilation/npu/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py deleted file mode 100644 index ac97b70cf40a..000000000000 --- a/python/sglang/srt/compilation/npu/passes/w8a8_int8/npu_add_rms_norm_quant_fuse.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -import torch - - -class NpuAddRmsNormQuantFuse: - def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3): - output = torch.ops.npu.npu_add_rms_norm( - rms_norm_input, residual, rms_norm_weight, 1e-6 - ) - out0 = output[0] - out2 = output[2] - quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, v1, v2, v3) - return quantized_output, out2 - - def replacement( - rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3 - ): - output = torch.ops.npu.npu_add_rms_norm_quant( - rms_norm_input, residual, rms_norm_weight, 1.0 / scale, offset, epsilon=1e-6 - ) - quantized_output = output[0] - out2 = output[2] - return quantized_output, out2 diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index cdc4f5a6d809..bdf7ca22acb3 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -97,7 +97,6 @@ def __init__( model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, sampling_defaults: str = "openai", quantize_and_serve: bool = False, - enable_torch_compile: bool = False, ) -> None: # Parse args self.model_path = model_path @@ -107,7 +106,6 @@ def __init__( self.model_impl = model_impl self.sampling_defaults = sampling_defaults self.quantize_and_serve = quantize_and_serve - self.enable_torch_compile = enable_torch_compile # Validate quantize_and_serve configuration self._validate_quantize_and_serve_config() @@ -236,7 +234,6 @@ def from_server_args( model_impl=server_args.model_impl, sampling_defaults=server_args.sampling_defaults, quantize_and_serve=server_args.quantize_and_serve, - enable_torch_compile=server_args.enable_torch_compile, override_config_file=server_args.decrypted_config_file, **kwargs, ) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 
6d7f44e8cf9e..20eaca22aa78 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -39,6 +39,8 @@ DpPaddingMode, get_attention_tp_rank, get_attention_tp_size, + set_dp_buffer_len, + set_is_extend_in_batch, ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer @@ -75,11 +77,6 @@ _is_hip = is_hip() -from sglang.srt.compilation.npu.custom_ops import ( - _set_dp_buffer_len, - _set_is_extend_in_batch, -) - logger = logging.getLogger(__name__) if TYPE_CHECKING: @@ -540,6 +537,12 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): def _create_device_graph(self): return torch.cuda.CUDAGraph() + def _init_dp_gathered_buffer( + self, global_dp_buffer_len: int, local_dp_buffer_len: int, dp_max_padding: bool + ): + set_dp_buffer_len(global_dp_buffer_len, local_dp_buffer_len, dp_max_padding) + set_is_extend_in_batch(False) + def capture_one_batch_size(self, bs: int, forward: Callable): graph = self._create_device_graph() stream = self.stream @@ -663,12 +666,12 @@ def capture_one_batch_size(self, bs: int, forward: Callable): def run_once(): # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - _set_dp_buffer_len( + + self._init_dp_gathered_buffer( global_dp_buffer_len, num_tokens, forward_batch.dp_padding_mode.is_max_len(), ) - _set_is_extend_in_batch(False) kwargs = {} if ( diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py index ea07cc4b24a3..7a0b6c9531fc 100644 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -27,6 +27,8 @@ from sglang.srt.configs.model_config import AttentionArch, is_deepseek_nsa from sglang.srt.distributed.parallel_state import GroupCoordinator from 
sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import supports_custom_op logger = logging.getLogger(__name__) @@ -35,6 +37,10 @@ from torch._dynamo.eval_frame import DisableContext +from sglang.srt.compilation.custom_ops import ( + _set_dp_buffer_len, + _set_is_extend_in_batch, +) from sglang.srt.compilation.npu.npu_graph_compiler import NpuGraphCompiler from sglang.srt.compilation.npu.patch_dynamo import ( patch_dynamo_context, @@ -72,6 +78,19 @@ def __init__(self, model_runner: ModelRunner): def _create_device_graph(self): return torch.npu.NPUGraph() + def _init_dp_gathered_buffer( + self, global_dp_buffer_len: int, local_dp_buffer_len: int, dp_max_padding: bool + ): + if supports_custom_op() and get_global_server_args().enable_torch_compile: + _set_dp_buffer_len( + global_dp_buffer_len, local_dp_buffer_len, dp_max_padding + ) + _set_is_extend_in_batch(False) + else: + super()._init_dp_gathered_buffer( + global_dp_buffer_len, local_dp_buffer_len, dp_max_padding + ) + def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): if self.enable_torch_compile: compiler = NpuGraphCompiler(run_once_fn) diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 5ae85f44ea17..a7b987e110f4 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -253,8 +253,6 @@ def get_quant_config( return ModelOptFp8Config.from_config(config) elif "FP4" in quant_algo: return ModelOptFp4Config.from_config(config) - - config["enable_torch_compile"] = model_config.enable_torch_compile return quant_cls.from_config(config) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 8752c72a6472..774df497cc49 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -32,7 +32,7 @@ from sglang.srt.server_args 
import get_global_server_args from sglang.srt.utils import add_prefix, is_cuda, is_npu, supports_custom_op -if supports_custom_op() and get_global_server_args().enable_torch_compile: +if is_npu() and supports_custom_op() and get_global_server_args().enable_torch_compile: from sglang.srt._custom_ops import get_cmo_stream, wait_cmo_stream else: from sglang.srt.utils import get_cmo_stream, wait_cmo_stream diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 156be80acd4e..b5e4fd6563fa 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -31,6 +31,7 @@ from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.parser.reasoning_parser import ReasoningParser +from sglang.srt.utils import supports_custom_op from sglang.srt.utils.common import ( LORA_TARGET_ALL_MODULES, SUPPORTED_LORA_TARGET_MODULES, @@ -1825,6 +1826,12 @@ def _handle_other_validations(self): self.disable_cuda_graph = True self.skip_server_warmup = True + if is_npu() and not supports_custom_op(): + logger.warning( + "Torch compile is disabled because custom ops are not supported" + ) + self.enable_torch_compile = False + @staticmethod def add_cli_args(parser: argparse.ArgumentParser): From eb240d94509f22523f0f9666ea3a057f1a1519a4 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 11 Nov 2025 00:11:16 +0800 Subject: [PATCH 06/71] Pass Manager fix --- .../sglang/srt/compilation/npu/pass_manager.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/compilation/npu/pass_manager.py b/python/sglang/srt/compilation/npu/pass_manager.py index 2bc613768a80..e4da0bf8535d 100644 --- a/python/sglang/srt/compilation/npu/pass_manager.py +++ b/python/sglang/srt/compilation/npu/pass_manager.py @@ -27,12 +27,17 @@ def apply(self): updated = False for pass_ in self.passes: pass_instance = pass_() - if 
callable(pass_instance): - results = pass_instance(self.graph_module) - else: - results = torch.fx.replace_pattern( - self.graph_module, pass_.pattern, pass_.replacement - ) + results = [] + try: + if callable(pass_instance): + results = pass_instance(self.graph_module) + else: + results = torch.fx.replace_pattern( + self.graph_module, pass_.pattern, pass_.replacement + ) + except: + # pass was not applied + pass if not updated: updated = len(results) != 0 From 29c1d890bdb18645a8472806a9be5d5aaed45ba2 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 11 Nov 2025 19:16:40 +0800 Subject: [PATCH 07/71] Compilation: refactoring --- python/sglang/srt/compilation/npu/npu_graph_compiler.py | 6 ++++-- ...pu_compiler_backend.py => npu_graph_compiler_backend.py} | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) rename python/sglang/srt/compilation/npu/{npu_compiler_backend.py => npu_graph_compiler_backend.py} (94%) diff --git a/python/sglang/srt/compilation/npu/npu_graph_compiler.py b/python/sglang/srt/compilation/npu/npu_graph_compiler.py index b49d743167fb..46ab461e0c33 100644 --- a/python/sglang/srt/compilation/npu/npu_graph_compiler.py +++ b/python/sglang/srt/compilation/npu/npu_graph_compiler.py @@ -14,14 +14,16 @@ import torch -from sglang.srt.compilation.npu.npu_compiler_backend import NpuBackend +from sglang.srt.compilation.npu.npu_graph_compiler_backend import ( + NpuGraphCompilerBackend, +) class NpuGraphCompiler: def __init__(self, model: torch.nn.Module): torch._dynamo.reset() - self.backend = NpuBackend() + self.backend = NpuGraphCompilerBackend() self.compiled_callable = torch.compile( model, fullgraph=True, dynamic=False, backend=self.backend ) diff --git a/python/sglang/srt/compilation/npu/npu_compiler_backend.py b/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py similarity index 94% rename from python/sglang/srt/compilation/npu/npu_compiler_backend.py rename to python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py 
index e4842f966f9a..e40304adcb9a 100644 --- a/python/sglang/srt/compilation/npu/npu_compiler_backend.py +++ b/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py @@ -25,12 +25,12 @@ ) -class NpuBackend: +class NpuGraphCompilerBackend: def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: DisableContext.compiled_function_args[DisableContext.batch_size] = ( example_inputs ) - NpuBackend.apply_passes(graph) + NpuGraphCompilerBackend.apply_passes(graph) return graph def apply_passes(graph_module: torch.fx.GraphModule): From 3e98d1772982f3750194d04fc9d759794dca0f01 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Sun, 9 Nov 2025 03:22:04 +0800 Subject: [PATCH 08/71] NPU Piecewise Graph --- .../compilation/npu/compilation_context.py | 20 + python/sglang/srt/compilation/npu/config.py | 27 + .../srt/layers/attention/ascend_backend.py | 16 +- python/sglang/srt/layers/rotary_embedding.py | 2 +- .../compilation/npu_compilation_backend.py | 36 + .../compilation/npu_compiler_backend.py | 308 +++++++ .../compilation/npu_graph_backend.py | 62 ++ .../compilation/npu_graph_compiler.py | 45 + .../sglang/srt/model_executor/model_runner.py | 10 +- .../piecewise_npu_graph_runner_decode.py | 775 ++++++++++++++++++ python/sglang/srt/models/qwen3.py | 9 +- python/sglang/srt/server_args.py | 24 + 12 files changed, 1330 insertions(+), 4 deletions(-) create mode 100644 python/sglang/srt/compilation/npu/compilation_context.py create mode 100644 python/sglang/srt/compilation/npu/config.py create mode 100644 python/sglang/srt/model_executor/compilation/npu_compilation_backend.py create mode 100644 python/sglang/srt/model_executor/compilation/npu_compiler_backend.py create mode 100644 python/sglang/srt/model_executor/compilation/npu_graph_backend.py create mode 100644 python/sglang/srt/model_executor/compilation/npu_graph_compiler.py create mode 100644 python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py diff --git 
a/python/sglang/srt/compilation/npu/compilation_context.py b/python/sglang/srt/compilation/npu/compilation_context.py new file mode 100644 index 000000000000..11a01cb5c877 --- /dev/null +++ b/python/sglang/srt/compilation/npu/compilation_context.py @@ -0,0 +1,20 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch_npu + + +class CompilationContext: + graph_memory_pool = None + stream: torch_npu.npu.Stream = None diff --git a/python/sglang/srt/compilation/npu/config.py b/python/sglang/srt/compilation/npu/config.py new file mode 100644 index 000000000000..d6375c937413 --- /dev/null +++ b/python/sglang/srt/compilation/npu/config.py @@ -0,0 +1,27 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class CompilationConfig: + splitting_ops: Optional[list[str]] = None + replay_index: int = 1 + page_size: int = 0 + + @classmethod + def from_cli(cls, cli_value: str) -> "CompilationConfig": + return CompilationConfig() diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index 94de5a02702f..ea50f1643d24 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -75,6 +75,7 @@ def update_verify_buffers_to_fill_after_draft( def __init__(self, model_runner: ModelRunner): super().__init__() self.enable_torch_compile = False + self.enable_piecewise_npu_graph_decode = False self.forward_metadata = None self.device = model_runner.device self.page_size = model_runner.page_size @@ -577,7 +578,9 @@ def forward_decode_graph( layer, forward_batch.out_cache_loc, k, v ) - if not self.use_mla and self.enable_torch_compile: + if not self.use_mla and ( + self.enable_torch_compile or self.enable_piecewise_npu_graph_decode + ): k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim) @@ -595,6 +598,17 @@ def forward_decode_graph( else: actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_int + if ( + self.enable_piecewise_npu_graph_decode + and torch.compiler.is_dynamo_compiling() + ): + # input args for submodule forward + forward_batch.req_to_token_pool.req_to_token.add_( + forward_batch.req_to_token_pool.req_to_token + ) + forward_batch.req_pool_indices.add_(forward_batch.req_pool_indices) + forward_batch.seq_lens.add_(forward_batch.seq_lens) + torch_npu._npu_paged_attention( query=query, key_cache=k_cache, diff --git 
a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 0935b1750c93..bb69d164b3f5 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1385,7 +1385,7 @@ def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None: ): self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - @torch.compile(dynamic=True, backend=get_compiler_backend()) + # @torch.compile(dynamic=True, backend=get_compiler_backend()) def _forward_native( self, positions: torch.Tensor, diff --git a/python/sglang/srt/model_executor/compilation/npu_compilation_backend.py b/python/sglang/srt/model_executor/compilation/npu_compilation_backend.py new file mode 100644 index 000000000000..508933c36bd7 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/npu_compilation_backend.py @@ -0,0 +1,36 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from typing import Any + +import torch + +from sglang.srt.compilation.npu.compilation_context import CompilationContext + + +class NPUCompilationBackend: + def __init__( + self, graph: torch.fx.GraphModule, compilation_context: CompilationContext + ): + self.graph = graph + self.callable = None + + def __call__(self, *args) -> Any: + if not self.callable: + torch._dynamo.config.suppress_errors = True + self.callable = torch.compile( + self.graph.forward, fullgraph=False, dynamic=False, backend="eager" + ) + + return self.callable(*args) diff --git a/python/sglang/srt/model_executor/compilation/npu_compiler_backend.py b/python/sglang/srt/model_executor/compilation/npu_compiler_backend.py new file mode 100644 index 000000000000..032c4b291beb --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/npu_compiler_backend.py @@ -0,0 +1,308 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import dataclasses +import importlib +import logging +from typing import Any, Callable + +import torch +from torch._dynamo.eval_frame import DisableContext + +from sglang.srt.compilation.npu.compilation_context import CompilationContext +from sglang.srt.compilation.npu.config import CompilationConfig +from sglang.srt.compilation.npu.pass_manager import PassManager +from sglang.srt.compilation.npu.passes.w8a8_int8 import ( + DivFuse, + EraseCopy, + NpuAddRmsNormQuantFuse, +) +from sglang.srt.distributed import get_tensor_model_parallel_world_size + +logger = logging.getLogger(__name__) + + +class Submodule(torch.nn.Module): + block_tables = None + + def __init__(self, page_size, model_config): + self.page_size = page_size + self.config = model_config + + tp_size = get_tensor_model_parallel_world_size() + assert self.config.num_attention_heads % tp_size == 0 + self.num_heads = self.config.num_attention_heads // tp_size + + self.total_num_kv_heads = self.config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.hidden_size = self.config.hidden_size + self.head_dim = getattr( + self.config, "head_dim", self.hidden_size // self.config.num_attention_heads + ) + + self.scaling = self.head_dim**-0.5 + + def forward_with_calculation( + self, + l_args_2_req_to_token_pool_req_to_token, + l_args_2_req_pool_indices, + l_args_2_seq_lens, + query_2, + l_args_2_token_to_kv_pool_k_buffer_0_, + l_args_2_token_to_kv_pool_v_buffer_0_, + l_args_2_attn_backend_forward_metadata_block_tables, + l_args_2_attn_backend_forward_metadata_seq_lens_cpu_int, + output, + ): + Submodule.block_tables = ( + l_args_2_req_to_token_pool_req_to_token[ + l_args_2_req_pool_indices, : l_args_2_seq_lens.max() + ][:, :: self.page_size] + // self.page_size + ) + _npu_paged_attention = torch.ops.atb._npu_paged_attention( + query=query_2, + key_cache=l_args_2_token_to_kv_pool_k_buffer_0_, + value_cache=l_args_2_token_to_kv_pool_v_buffer_0_, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + scale_value=self.scaling, + block_table=Submodule.block_tables, + context_lens=l_args_2_attn_backend_forward_metadata_seq_lens_cpu_int, + out=output, + ) + + def forward( + self, + l_args_2_req_to_token_pool_req_to_token, + l_args_2_req_pool_indices, + l_args_2_seq_lens, + query_2, + l_args_2_token_to_kv_pool_k_buffer_0_, + l_args_2_token_to_kv_pool_v_buffer_0_, + l_args_2_attn_backend_forward_metadata_block_tables, + l_args_2_attn_backend_forward_metadata_seq_lens_cpu_int, + output, + ): + _npu_paged_attention = torch.ops.atb._npu_paged_attention( + query=query_2, + key_cache=l_args_2_token_to_kv_pool_k_buffer_0_, + value_cache=l_args_2_token_to_kv_pool_v_buffer_0_, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + scale_value=self.scaling, + block_table=Submodule.block_tables, + context_lens=l_args_2_attn_backend_forward_metadata_seq_lens_cpu_int, + out=output, + ) + + +def 
resolve_obj_by_qualname(qualname: str) -> Any: + module_name, obj_name = qualname.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, obj_name) + + +@dataclasses.dataclass +class SplitItem: + submod_name: str + graph_id: int + is_compiled_only: bool + graph: torch.fx.GraphModule + + +class NpuAddRmsNormFuse: + def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3): + output = torch.ops.npu.npu_add_rms_norm( + rms_norm_input, residual, rms_norm_weight, 1e-6 + ) + out0 = output[0] + out2 = output[2] + quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, v1, v2, v3) + return quantized_output, out2 + + def replacement( + rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3 + ): + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, residual, rms_norm_weight, 1.0 / scale, offset, epsilon=1e-6 + ) + quantized_output = output[0] + out2 = output[2] + return quantized_output, out2 + + +class NpuBackend: + graph: torch.fx.GraphModule + + def __init__( + self, + model_runner, + compilation_config: CompilationConfig, + compilation_context: CompilationContext, + page_size: int, + ): + self.model_runner = model_runner + self.model_config = model_runner.model.config + + self.compilation_config = compilation_config + self.page_size = page_size + self.compilation_context = compilation_context + + self.split_gm = None + + self.piecewise_graphs = None + self.submod_names_to_compile = None + + self.callables = {} + self.callables_by_branch = {} + + def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: + example_inputs_len = len(example_inputs) + if example_inputs_len in self.callables: + callable = self.callables[example_inputs_len] + return callable + + DisableContext.compiled_function_args[DisableContext.batch_size] = ( + example_inputs + ) + + self.graph = graph + NpuBackend.apply_passes(self.graph) + self.split_gm, self.piecewise_graphs = 
NpuBackend.split_graph( + self.graph, self.compilation_config.splitting_ops + ) + + npu_graph_backend = resolve_obj_by_qualname( + "sglang.srt.model_executor.compilation.npu_graph_backend.NPUGraphBackend" + ) + + self.submod_names_compiled_only = [ + item.submod_name for item in self.piecewise_graphs if item.is_compiled_only + ] + + named_modules = self.split_gm.named_modules() + submod = Submodule(self.page_size, self.model_config) + use_forward = False + for name, graph_module in named_modules: + if not name: + continue + + graph = getattr(self.split_gm, name) + if name in self.submod_names_compiled_only: + if use_forward: + self.split_gm.__dict__[name] = submod.forward + else: + self.split_gm.__dict__[name] = submod.forward_with_calculation + use_forward = True + else: + self.split_gm.__dict__[name] = npu_graph_backend( + self.model_runner, graph, self.compilation_context + ) + + self.split_gm(*example_inputs) + self.callables[example_inputs_len] = self.split_gm.forward + return self.split_gm.forward + + def apply_passes(graph_module: torch.fx.GraphModule): + passManager = PassManager(graph_module) + passManager.add(NpuAddRmsNormQuantFuse) + passManager.add(DivFuse) + passManager.add(EraseCopy) + passManager.apply() + graph_module.recompile() + + def split_graph( + graph: torch.fx.GraphModule, ops: list[str] + ) -> tuple[torch.fx.GraphModule, list[SplitItem]]: + subgraph_id = 0 + node_to_subgraph_id = {} + graphs_for_compilation = [] + + node_index = 0 + node_index_max = len(graph.graph.nodes) + + nodes = list(graph.graph.nodes) + + counter = 1 + ops_count = 3 + ops_step = ops_count + 1 + while node_index < node_index_max: + if ( + (node_index + ops_count) < node_index_max + and nodes[node_index + ops_count].op == "call_function" + and str(nodes[node_index + ops_count].target) in ops + ): + subgraph_id += 1 + graphs_for_compilation.append(subgraph_id) + + for submodule_node_index in range(node_index, node_index + ops_step): + submodule_node = 
nodes[submodule_node_index] + node_to_subgraph_id[submodule_node] = subgraph_id + counter = counter + 1 + node_index += ops_step + + subgraph_id += 1 + else: + node = nodes[node_index] + if node.op in ("output", "placeholder"): + node_index += 1 + elif node.op == "call_function" and str(node.target) in ops: + subgraph_id += 1 + graphs_for_compilation.append(subgraph_id) + + node_to_subgraph_id[node] = subgraph_id + node_index += 1 + + subgraph_id += 1 + else: + node_to_subgraph_id[node] = subgraph_id + node_index += 1 + counter += 1 + + split_gm = torch.fx.passes.split_module.split_module( + graph, + None, + lambda node: node_to_subgraph_id[node], + keep_original_order=True, + ) + + names = [name for (name, module) in split_gm.named_modules()] + + outputs = [] + for name in names: + if "." in name or name == "": + # recursive child module or the root module + continue + + module = getattr(split_gm, name) + + graph_id = int(name.replace("submod_", "")) + outputs.append( + SplitItem(name, graph_id, (graph_id in graphs_for_compilation), module) + ) + + # sort by intetger graph_id, rather than string name + outputs.sort(key=lambda x: x.graph_id) + + return split_gm, outputs diff --git a/python/sglang/srt/model_executor/compilation/npu_graph_backend.py b/python/sglang/srt/model_executor/compilation/npu_graph_backend.py new file mode 100644 index 000000000000..0481b0f6b163 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/npu_graph_backend.py @@ -0,0 +1,62 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Any + +import torch +import torch_npu + +from sglang.srt.compilation.npu.compilation_context import CompilationContext + + +class NPUGraphBackend: + def __init__( + self, + model_runner, + graph: torch.fx.GraphModule, + compilation_context: CompilationContext, + ): + self.model_runner = model_runner + self.graph = graph + self.compilation_context = compilation_context + + self.captured = False + self.output = None + self.npu_graph = None + + def __call__(self, *args) -> Any: + if not self.captured: + if not self.compilation_context.stream: + self.compilation_context.stream = torch_npu.npu.Stream() + + torch.cuda.synchronize() + + self.npu_graph = torch_npu.npu.NPUGraph() + with torch.npu.graph( + self.npu_graph, + stream=self.compilation_context.stream, + pool=self.compilation_context.graph_memory_pool, + ): + + self.output = self.graph.forward(*args) + + if not self.compilation_context.graph_memory_pool: + self.compilation_context.graph_memory_pool = self.npu_graph.pool() + + self.npu_graph.replay() + self.captured = True + else: + self.npu_graph.replay() + + return self.output diff --git a/python/sglang/srt/model_executor/compilation/npu_graph_compiler.py b/python/sglang/srt/model_executor/compilation/npu_graph_compiler.py new file mode 100644 index 000000000000..477ec4fd3518 --- /dev/null +++ b/python/sglang/srt/model_executor/compilation/npu_graph_compiler.py @@ -0,0 +1,45 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import pathlib +import sys + +import torch + +from sglang.srt.compilation.npu.compilation_context import CompilationContext +from sglang.srt.compilation.npu.config import CompilationConfig +from sglang.srt.model_executor.compilation.npu_compiler_backend import NpuBackend + + +class NpuGraphCompiler: + def __init__( + self, + model_runner, + model: torch.nn.Module, + compilation_config: CompilationConfig, + compilation_context: CompilationContext, + page_size: int, + ): + self.backend = NpuBackend( + model_runner, compilation_config, compilation_context, page_size + ) + self.model = model + + torch._dynamo.reset() + torch.compiler.allow_in_graph(sys.intern) + torch.compiler.allow_in_graph(pathlib.Path) + + self.compiled_callable = torch.compile( + self.model, fullgraph=True, dynamic=False, backend=self.backend + ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 963e2cd814c6..9252abb7816b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -115,6 +115,9 @@ from sglang.srt.model_executor.piecewise_cuda_graph_runner import ( PiecewiseCudaGraphRunner, ) +from sglang.srt.model_executor.piecewise_npu_graph_runner_decode import ( + PiecewiseNPUGraphRunnerDecode, +) from sglang.srt.model_loader import get_model from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from 
sglang.srt.model_loader.remote_instance_weight_loader_utils import ( @@ -477,6 +480,7 @@ def initialize(self, min_per_gpu_memory: float): self.init_cublas() self.init_attention_backend() self.init_device_graphs() + elif self.device in ["npu", "cpu"]: self.init_attention_backend() self.init_device_graphs() @@ -1989,7 +1993,11 @@ def init_device_graphs(self): lambda: CudaGraphRunner, { "cpu": CPUGraphRunner, - "npu": NPUGraphRunner, + "npu": ( + PiecewiseNPUGraphRunnerDecode + if self.server_args.enable_piecewise_npu_graph_decode + else NPUGraphRunner + ), }, ) self.graph_runner = graph_runners[self.device](self) diff --git a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py new file mode 100644 index 000000000000..6bb0d94604f0 --- /dev/null +++ b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py @@ -0,0 +1,775 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Run the model with npu graph and torch.compile.""" + +from __future__ import annotations + +import bisect +import gc +import os +from contextlib import contextmanager +from typing import TYPE_CHECKING, Callable, Optional, Union + +import torch +import torch._dynamo.config +import tqdm + +from sglang.srt.compilation.npu.compilation_context import CompilationContext +from sglang.srt.compilation.npu.config import CompilationConfig +from sglang.srt.compilation.npu.patch_dynamo import ( + patch_dynamo_context, + patch_dynamo_context_call, + restore_dynamo_context_call, +) +from sglang.srt.custom_op import CustomOp +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture +from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.model_executor.compilation.npu_graph_compiler import NpuGraphCompiler +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, + PPProxyTensors, + enable_num_token_non_padded, +) +from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin +from sglang.srt.utils import ( + get_available_gpu_memory, + get_device_memory_capacity, + rank0_log, +) + +torch._dynamo.config.skip_nnmodule_hook_guards = True +torch._dynamo.config.automatic_dynamic_shapes = False +torch._dynamo.config.guard_nn_modules = False + +import logging + +from torch._dynamo.eval_frame import DisableContext + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + +from sglang.srt.model_executor.cuda_graph_runner import model_capture_mode + +torch.cuda.CUDAGraph = torch.npu.NPUGraph +torch.cuda.synchronize = torch.npu.synchronize 
+torch.cuda.graph = torch.npu.graph +torch.cuda.stream = torch.npu.stream +torch.cuda.Stream = torch.npu.Stream +torch.cuda.current_stream = torch.npu.current_stream +torch.cuda.graph_pool_handle = torch.npu.graph_pool_handle + + +class CompiledGraph: + def __init__( + self, + bs: int, + forward_batch: ForwardBatch, + attn_backend: AscendAttnBackend, + callable, + ): + self.bs = bs + self.forward_batch = forward_batch + # TODO: debug only + self.attn_backend = attn_backend + self.callable = callable + + +def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): + for sub in model._modules.values(): + if isinstance(sub, CustomOp): + if reverse: + sub.leave_torch_compile() + else: + sub.enter_torch_compile(num_tokens=num_tokens) + if isinstance(sub, torch.nn.Module): + _to_torch(sub, reverse, num_tokens) + + +@contextmanager +def patch_model( + model: torch.nn.Module, + enable_compile: bool, + num_tokens: int, + tp_group: GroupCoordinator, +): + """Patch the model to make it compatible with with torch.compile""" + backup_ca_comm = None + + try: + if enable_compile: + _to_torch(model, reverse=False, num_tokens=num_tokens) + backup_ca_comm = tp_group.ca_comm + # Use custom-allreduce here. + # We found the custom allreduce is much faster than the built-in allreduce in torch, + # even with ENABLE_INTRA_NODE_COMM=1. + # tp_group.ca_comm = None + yield torch.compile( + torch.no_grad()(model.forward), + mode=os.environ.get( + "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-npugraphs" + ), + dynamic=False, + ) + else: + yield model.forward + finally: + if enable_compile: + _to_torch(model, reverse=True, num_tokens=num_tokens) + tp_group.ca_comm = backup_ca_comm + + +def set_torch_compile_config(): + import torch._dynamo.config + import torch._inductor.config + + torch._inductor.config.fx_graph_cache = False + + from packaging import version + + if version.parse(torch.__version__) < version.parse("2.8.0"): + # These things are cacheable by torch.compile. 
torch.compile just doesn't know it. + # This was fixed in PyTorch 2.8, but until then, we monkey patch. + import torch._higher_order_ops.auto_functionalize as af + + af.auto_functionalized_v2._cacheable = False + af.auto_functionalized._cacheable = False + + torch._dynamo.config.accumulated_cache_size_limit = 1024 + if hasattr(torch._dynamo.config, "cache_size_limit"): + torch._dynamo.config.cache_size_limit = 1024 + + +def get_batch_sizes_to_capture(model_runner: ModelRunner): + server_args = model_runner.server_args + capture_bs = server_args.cuda_graph_bs + + if capture_bs is None: + if server_args.speculative_algorithm is None: + if server_args.disable_cuda_graph_padding: + capture_bs = list(range(1, 33)) + list(range(48, 161, 16)) + else: + capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8)) + else: + # Since speculative decoding requires more npu graph memory, we + # capture less. + capture_bs = ( + list(range(1, 9)) + + list(range(10, 33, 2)) + + list(range(40, 64, 8)) + + list(range(80, 161, 16)) + ) + + gpu_mem = get_device_memory_capacity() + if gpu_mem is not None and gpu_mem > 96 * 1024: + capture_bs += list(range(160, 257, 8)) + if gpu_mem is not None and gpu_mem > 180 * 1000: + capture_bs += list(range(256, 513, 16)) + + if max(capture_bs) > model_runner.req_to_token_pool.size: + # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests + # is very small. We add more values here to make sure we capture the maximum bs. 
+ capture_bs += [model_runner.req_to_token_pool.size] + + if server_args.enable_two_batch_overlap: + capture_bs = [bs for bs in capture_bs if bs >= 2] + + if server_args.cuda_graph_max_bs: + capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs] + if max(capture_bs) < server_args.cuda_graph_max_bs: + capture_bs += list( + range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16) + ) + capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size] + capture_bs = list(sorted(set(capture_bs))) + assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}" + compile_bs = ( + [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] + if server_args.enable_torch_compile + else [] + ) + + return capture_bs, compile_bs + +class PiecewiseNPUGraphRunnerDecode: + """A PiecewiseNPUGraphRunnerDecode runs the forward pass of a model with npu graph and torch.compile.""" + + def __init__(self, model_runner: ModelRunner): + model_runner.attn_backend.enable_piecewise_npu_graph_decode = True + + patch_dynamo_context() + + self.inference_counter = 1 + self.init_forward_metadata_was_done = True + + # Parse args + self.model_runner = model_runner + self.compilation_config = CompilationConfig() + self.compilation_config.splitting_ops = ["atb._npu_paged_attention"] + self.compilation_context = CompilationContext() + + # self.compilation_context = model_runner.server_args.compilation_config + + self.graphs = {} + self.output_buffers = {} + self.enable_torch_compile = model_runner.server_args.enable_torch_compile + self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder + self.enable_dp_attention = model_runner.server_args.enable_dp_attention + # self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm + self.enable_two_batch_overlap = ( + model_runner.server_args.enable_two_batch_overlap + ) + self.speculative_algorithm 
= model_runner.server_args.speculative_algorithm + self.tp_size = model_runner.server_args.tp_size + self.dp_size = model_runner.server_args.dp_size + self.pp_size = model_runner.server_args.pp_size + + # Batch sizes to capture + self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) + rank0_log(f"Capture npu graph bs {self.capture_bs}") + self.capture_forward_mode: int = ForwardMode.DECODE + self.capture_hidden_mode: int = CaptureHiddenMode.NULL + self.num_tokens_per_bs = 1 + if model_runner.spec_algorithm.is_eagle(): + if self.model_runner.is_draft_worker: + raise RuntimeError("This should not happen") + else: + self.capture_forward_mode = ForwardMode.TARGET_VERIFY + self.num_tokens_per_bs = ( + self.model_runner.server_args.speculative_num_draft_tokens + ) + + # Attention backend + self.max_bs = max(self.capture_bs) + self.max_num_token = self.max_bs * self.num_tokens_per_bs + self.model_runner.attn_backend.init_cuda_graph_state( + self.max_bs, self.max_num_token + ) + self.seq_len_fill_value = ( + self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() + ) + # FIXME(lsyin): leave it here for now, I don't know whether it is necessary + self.encoder_len_fill_value = 0 + self.seq_lens_cpu = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 + ) + + set_torch_compile_config() + + # if self.model_runner.server_args.lora_paths is not None: + # self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs) + + # Graph inputs + with torch.device(self.model_runner.device): + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) + self.seq_lens = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 + ) + self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32) + self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.mrope_positions = torch.zeros((3, 
self.max_bs), dtype=torch.int64) + self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32) + self.tbo_plugin = TboCudaGraphRunnerPlugin() + + self.block_tables = torch.full((160, 160), 0, dtype=torch.int32) + + # pipeline parallelism + if self.pp_size > 1: + self.pp_proxy_tensors = { + "hidden_states": torch.zeros( + (self.max_bs, self.model_runner.model_config.hidden_size), + dtype=torch.bfloat16, + ), + "residual": torch.zeros( + (self.max_bs, self.model_runner.model_config.hidden_size), + dtype=torch.bfloat16, + ), + } + + # Speculative_inference + if model_runner.spec_algorithm.is_eagle3(): + self.model_runner.model.set_eagle3_layers_to_capture() + + if self.is_encoder_decoder: + # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch + self.encoder_lens = torch.full( + (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32 + ) + else: + self.encoder_lens = None + + if self.enable_dp_attention: # or self.enable_sp_layernorm: + # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer + self.gathered_buffer = torch.zeros( + ( + self.max_bs * self.dp_size * self.num_tokens_per_bs, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) + self.global_num_tokens_gpu = torch.zeros( + (self.dp_size,), dtype=torch.int32 + ) + + try: + with model_capture_mode(): + self.capture() + except RuntimeError as e: + raise Exception( + f"Graph compilation failed: {e}\n{NPU_GRAPH_CAPTURE_FAILED_MSG}" + ) + + def can_run(self, forward_batch: ForwardBatch): + if self.enable_dp_attention: # or self.enable_sp_layernorm: + total_global_tokens = sum(forward_batch.global_num_tokens_cpu) + + is_bs_supported = forward_batch.can_run_dp_cuda_graph and ( + total_global_tokens in self.graphs + if self.disable_padding + else total_global_tokens <= self.max_bs + ) + else: + is_bs_supported = ( + forward_batch.batch_size in self.graphs + if self.disable_padding + else 
forward_batch.batch_size <= self.max_bs + ) + + # NOTE: npu graph cannot handle mixed batch (encoder_len = 0) + # If mixed batch cannot be supported, then encoder_lens can be removed in npu graph + # because the full_text_row_masked_out_mask tensor will always be ones + is_encoder_lens_supported = ( + torch.all(forward_batch.encoder_lens > 0) + if self.is_encoder_decoder + else True + ) + + is_tbo_supported = ( + forward_batch.can_run_tbo if self.enable_two_batch_overlap else True + ) + + can_run_value = ( + is_bs_supported and is_encoder_lens_supported and is_tbo_supported + ) + return can_run_value + + def capture(self, forward_batch_: ForwardBatch = None, bs_: int = None): + with graph_capture() as graph_capture_context: + self.stream = graph_capture_context.stream + + self.model_runner.tp_group.barrier() + + avail_mem = get_available_gpu_memory( + self.model_runner.device, self.model_runner.gpu_id, empty_cache=False + ) + + # Reverse the order to enable better memory sharing across cuda graphs. 
+ capture_range = ( + tqdm.tqdm(list(reversed(self.capture_bs))) + if get_tensor_model_parallel_rank() == 0 + else reversed(self.capture_bs) + ) + + for bs in capture_range: + if get_tensor_model_parallel_rank() == 0: + avail_mem = get_available_gpu_memory( + self.model_runner.device, + self.model_runner.gpu_id, + empty_cache=False, + ) + capture_range.set_description( + f"Capturing batches ({avail_mem=:.2f} GB)" + ) + + (compiled_graph, output_buffers) = self.capture_one_batch_size( + bs, self.model_runner.model.forward, forward_batch_=forward_batch_ + ) + self.graphs[bs] = compiled_graph + self.output_buffers[bs] = output_buffers + + def init_forward_metadata_attn_backend( + self, bs: int, attn_backend: AscendAttnBackend, forward_batch: ForwardBatch + ): + attn_backend.forward_metadata.block_tables = self.block_tables + + seq_lens_cpu_int = forward_batch.seq_lens_cpu_int + seq_lens_cpu_int[ + : attn_backend.forward_metadata.seq_lens_cpu_int.shape[0] + ].copy_(attn_backend.forward_metadata.seq_lens_cpu_int) + attn_backend.forward_metadata.seq_lens_cpu_int = seq_lens_cpu_int + + def init_forward_batch( + self, bs: int, attn_backend: AscendAttnBackend, forward_batch_: ForwardBatch + ) -> ForwardBatch: + if forward_batch_: + return forward_batch_ + + num_tokens = bs * self.num_tokens_per_bs + + with torch.device(self.model_runner.device): + req_pool_indices = torch.zeros((bs,), dtype=torch.int32) + seq_lens = torch.full((bs,), self.seq_len_fill_value, dtype=torch.int32) + out_cache_loc = torch.zeros((bs,), dtype=torch.int32) + positions = torch.zeros((bs,), dtype=torch.int64) + input_ids = torch.zeros((bs,), dtype=torch.int64) + mrope_positions = torch.zeros((3, self.max_num_token), dtype=torch.int64) + + assert self.is_encoder_decoder == False + encoder_lens = None + num_token_non_padded = None + + # pipeline parallelism + assert self.pp_size <= 1 + + assert self.enable_dp_attention == False + # assert self.enable_sp_layernorm == False + global_num_tokens = None + 
gathered_buffer = None + + spec_info = self.get_spec_info(num_tokens) + if self.capture_hidden_mode != CaptureHiddenMode.FULL: + self.capture_hidden_mode = ( + spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL + ) + + # assert self.model_runner.server_args.lora_paths is None + # lora_paths = None + + forward_batch = ForwardBatch( + forward_mode=self.capture_forward_mode, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=attn_backend, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens.sum(), + encoder_lens=encoder_lens, + return_logprob=False, + positions=positions, + global_num_tokens_gpu=global_num_tokens, + # gathered_buffer=gathered_buffer, + mrope_positions=mrope_positions, + spec_algorithm=self.model_runner.spec_algorithm, + spec_info=spec_info, + capture_hidden_mode=self.capture_hidden_mode, + num_token_non_padded=self.num_token_non_padded, + global_forward_mode=self.capture_forward_mode, + # lora_paths=lora_paths, + ) + + seq_lens_cpu_int = torch.zeros((bs,), dtype=torch.int32, device="cpu") + forward_batch.seq_lens_cpu_int = seq_lens_cpu_int + + seq_lens_cpu = torch.full((bs,), 1, dtype=torch.int32, device="cpu") + forward_batch.seq_lens_cpu = seq_lens_cpu + + # TODO: don't use loop here + for i in range(bs): + forward_batch.global_forward_mode = None + forward_batch.input_ids[i] = 323 + forward_batch.num_token_non_padded = None + forward_batch.out_cache_loc[i] = 134 + forward_batch.positions[i] = 6 + forward_batch.seq_lens[i] = 7 + forward_batch.seq_lens_cpu[i] = 7 + forward_batch.seq_lens_cpu_int[i] = 7 + forward_batch.req_pool_indices[i] = 1 + forward_batch.seq_lens_sum = sum(forward_batch.seq_lens) + + if self.enable_dp_attention: # or self.enable_sp_layernorm: + assert False + assert self.pp_size <= 1 + assert self.enable_dp_attention == False + # assert 
self.enable_sp_layernorm == False + assert enable_num_token_non_padded(self.model_runner.server_args) == False + assert self.enable_two_batch_overlap == False + + attn_backend.init_forward_metadata(forward_batch) + + self.init_forward_metadata_attn_backend(bs, attn_backend, forward_batch) + + # Clean intermediate result cache for DP attention + forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None + return forward_batch + + def capture_one_batch_size( + self, + bs: int, + forward: Callable, + forward_batch_: ForwardBatch = None, + compile: bool = True, + ): + attn_backend = self.model_runner.attn_backend + # TODO: absent in CUDAGraphRunner + attn_backend.init_cuda_graph_state(bs, self.max_num_token) + + self.model_runner.attn_backend = attn_backend + + for _ in range(2): + forward_batch = self.init_forward_batch(bs, attn_backend, forward_batch_) + + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + + self.model_runner.attn_backend.graph_mode = True + self.model_runner.model( + forward_batch.input_ids, forward_batch.positions, forward_batch + ) + + forward_batch = self.init_forward_batch(bs, attn_backend, forward_batch_) + + self.compilation_context.stream = self.stream + self.model_runner.attn_backend.graph_mode = True + + compiler = NpuGraphCompiler( + self.model_runner, + self.model_runner.model, + self.compilation_config, + self.compilation_context, + self.model_runner.page_size, + ) + + patch_dynamo_context_call() + DisableContext.batch_size = bs + + logits_output_or_pp_proxy_tensors = compiler.compiled_callable( + forward_batch.input_ids, forward_batch.positions, forward_batch + ) + + try: + logits_output_or_pp_proxy_tensors = compiler.compiled_callable( + forward_batch.input_ids, forward_batch.positions, forward_batch + ) + finally: + DisableContext.batch_size = None + restore_dynamo_context_call() + + assert DisableContext.compiled_function + assert DisableContext.compiled_function_args + + compiled_graph = 
CompiledGraph( + bs, forward_batch, None, compiler.compiled_callable + ) + + torch._dynamo.reset() + gc.collect() + + return (compiled_graph, logits_output_or_pp_proxy_tensors) + + def recapture_if_needed(self, forward_batch: ForwardBatch): + assert False + + # If the capture_hidden_mode changes, we need to recapture the graph + hidden_mode_from_spec_info = getattr( + forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL + ) + if ( + forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL + and self.capture_hidden_mode != CaptureHiddenMode.FULL + ): + self.capture_hidden_mode = CaptureHiddenMode.FULL + self.capture() + elif ( + forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL + and self.capture_hidden_mode != hidden_mode_from_spec_info + ): + self.capture_hidden_mode = hidden_mode_from_spec_info + self.capture() + + def replay_prepare( + self, + forward_batch: ForwardBatch, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ): + raw_bs = forward_batch.batch_size + raw_num_token = raw_bs * self.num_tokens_per_bs + + # Pad + if self.enable_dp_attention: # or self.enable_sp_layernorm: + index = bisect.bisect_left( + self.capture_bs, sum(forward_batch.global_num_tokens_cpu) + ) + else: + index = bisect.bisect_left(self.capture_bs, raw_bs) + + bs = self.capture_bs[index] + compiled_graph = self.graphs[bs] + + compiled_graph.forward_batch.input_ids[ + : forward_batch.input_ids.shape[0] + ].copy_(forward_batch.input_ids) + forward_batch.input_ids = compiled_graph.forward_batch.input_ids + + compiled_graph.forward_batch.seq_lens[: forward_batch.seq_lens.shape[0]].copy_( + forward_batch.seq_lens + ) + forward_batch.seq_lens = compiled_graph.forward_batch.seq_lens + + compiled_graph.forward_batch.req_pool_indices[ + : forward_batch.req_pool_indices.shape[0] + ].copy_(forward_batch.req_pool_indices) + forward_batch.req_pool_indices = compiled_graph.forward_batch.req_pool_indices + + compiled_graph.forward_batch.out_cache_loc[ + : 
forward_batch.out_cache_loc.shape[0] + ].copy_(forward_batch.out_cache_loc) + forward_batch.out_cache_loc = compiled_graph.forward_batch.out_cache_loc + + compiled_graph.forward_batch.positions[ + : forward_batch.positions.shape[0] + ].copy_(forward_batch.positions) + forward_batch.positions = compiled_graph.forward_batch.positions + + if forward_batch.seq_lens_cpu is not None: + compiled_graph.forward_batch.seq_lens_cpu[ + : forward_batch.seq_lens_cpu.shape[0] + ].copy_(forward_batch.seq_lens_cpu) + forward_batch.seq_lens_cpu = compiled_graph.forward_batch.seq_lens_cpu + + if pp_proxy_tensors: + for key in self.pp_proxy_tensors.keys(): + dim = pp_proxy_tensors[key].shape[0] + self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key]) + + if self.is_encoder_decoder: + assert False + + if forward_batch.mrope_positions is not None: + compiled_graph.forward_batch.mrope_positions[:, :raw_num_token].copy_( + forward_batch.mrope_positions + ) + + if self.enable_dp_attention: # or self.enable_sp_layernorm: + assert False + + if enable_num_token_non_padded(self.model_runner.server_args): + assert False + + if self.enable_two_batch_overlap: + assert False + + # Store fields + self.raw_bs = raw_bs + self.raw_num_token = raw_num_token + self.bs = bs + + def replay( + self, + forward_batch: ForwardBatch, + skip_attn_backend_init: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + self.replay_prepare(forward_batch, pp_proxy_tensors) + compiled_graph = self.graphs[self.bs] + + def init(): + attn_backend = self.model_runner.attn_backend + forward_batch.attn_backend = attn_backend + + compiled_graph: CompiledGraph = self.graphs[self.bs] + + attn_backend = self.model_runner.attn_backend + if not self.init_forward_metadata_was_done: + attn_backend.init_forward_metadata(forward_batch) + self.init_forward_metadata_was_done = True + else: + if forward_batch.extend_seq_lens is not None: + 
attn_backend.forward_metadata.extend_seq_lens_cpu_int = ( + forward_batch.extend_seq_lens.cpu().int() + ) + attn_backend.forward_metadata.seq_lens_cpu_int = ( + forward_batch.seq_lens_cpu.int() + ) + + self.init_forward_metadata_attn_backend( + self.bs, attn_backend, compiled_graph.forward_batch + ) + + init() + + self.model_runner.attn_backend.graph_mode = True + + DisableContext.compiled_function[self.bs]( + *DisableContext.compiled_function_args[self.bs] + ) + + output = self.output_buffers[self.bs] + + if isinstance(output, LogitsProcessorOutput): + result = LogitsProcessorOutput( + next_token_logits=output.next_token_logits[: self.raw_num_token], + hidden_states=( + output.hidden_states[: self.raw_num_token] + if output.hidden_states is not None + else None + ), + ) + else: + assert isinstance(output, PPProxyTensors) + result = PPProxyTensors( + {k: v[: self.bs] for k, v in output.tensors.items()} + ) + + return result + + def get_spec_info(self, num_tokens: int): + spec_info = None + if self.model_runner.spec_algorithm.is_eagle(): + from sglang.srt.speculative.eagle_utils import EagleVerifyInput + + if self.model_runner.is_draft_worker: + raise RuntimeError("This should not happen.") + else: + spec_info = EagleVerifyInput( + draft_token=None, + custom_mask=torch.ones( + (num_tokens * self.model_runner.model_config.context_len), + dtype=torch.bool, + device=self.model_runner.device, + ), + positions=None, + retrive_index=None, + retrive_next_token=None, + retrive_next_sibling=None, + retrive_cum_len=None, + spec_steps=self.model_runner.server_args.speculative_num_steps, + topk=self.model_runner.server_args.speculative_eagle_topk, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + capture_hidden_mode=CaptureHiddenMode.FULL, + seq_lens_sum=None, + seq_lens_cpu=None, + ) + + return spec_info + + +NPU_GRAPH_CAPTURE_FAILED_MSG = ( + "Possible solutions:\n" + "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" + "2. 
set --cuda-graph-max-bs to a smaller value (e.g., 16)\n" + "3. disable torch compile by not using --enable-torch-compile\n" + "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" +) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 774df497cc49..1fb01e069b03 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -32,7 +32,14 @@ from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, is_cuda, is_npu, supports_custom_op -if is_npu() and supports_custom_op() and get_global_server_args().enable_torch_compile: +if ( + is_npu() + and supports_custom_op() + and ( + get_global_server_args().enable_torch_compile + or get_global_server_args().enable_piecewise_npu_graph_decode + ) +): from sglang.srt._custom_ops import get_cmo_stream, wait_cmo_stream else: from sglang.srt.utils import get_cmo_stream, wait_cmo_stream diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index b5e4fd6563fa..3c0e2074444e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -26,6 +26,7 @@ import orjson +from sglang.srt.compilation.npu.config import CompilationConfig from sglang.srt.connector import ConnectorType from sglang.srt.environ import ToolStrictLevel, envs from sglang.srt.function_call.function_call_parser import FunctionCallParser @@ -466,6 +467,7 @@ class ServerArgs: cuda_graph_bs: Optional[List[int]] = None disable_cuda_graph: bool = False disable_cuda_graph_padding: bool = False + enable_piecewise_npu_graph_decode: bool = False enable_profile_cuda_graph: bool = False enable_cudagraph_gc: bool = False enable_nccl_nvls: bool = False @@ -539,6 +541,8 @@ class ServerArgs: # FIXME: hack to reduce ITL when decode bs is small disaggregation_decode_polling_interval: int = 1 + 
compilation_config: Optional[CompilationConfig] = None + # For model weight update and weight loading custom_weight_loader: Optional[List[str]] = None weight_loader_disable_mmap: bool = False @@ -642,6 +646,10 @@ def __post_init__(self): # Handle elastic expert parallelism. self._handle_elastic_ep() + if not self.compilation_config: + self.compilation_config = CompilationConfig() + self.compilation_config.splitting_ops = ["atb._npu_paged_attention"] + def _handle_deprecated_args(self): # handle deprecated tool call parsers deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"} @@ -1173,6 +1181,7 @@ def _handle_attention_backend_compatibility(self): "Cuda graph is disabled because of using torch native attention backend" ) self.disable_cuda_graph = True + self.enable_piecewise_npu_graph_decode = False if self.attention_backend == "flex_attention": logger.warning( @@ -1400,6 +1409,7 @@ def _handle_a2a_moe(self): if self.deepep_mode == "normal": logger.warning("Cuda graph is disabled because deepep_mode=`normal`") self.disable_cuda_graph = True + self.enable_piecewise_npu_graph_decode = False self.ep_size = self.tp_size logger.warning( f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." 
@@ -1673,6 +1683,8 @@ def _handle_disaggregation(self): self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp) self.disable_cuda_graph = True logger.warning("Cuda graph is disabled for prefill server") + self.enable_piecewise_npu_graph_decode = False + logger.warning("Piecewise graph is disabled for prefill server") def _handle_tokenizer_batching(self): if self.enable_tokenizer_batch_encode and self.enable_dynamic_batch_tokenizer: @@ -2671,6 +2683,13 @@ def add_cli_args(parser: argparse.ArgumentParser): choices=NSA_CHOICES, ) + parser.add_argument( + "--compilation-config", + type=CompilationConfig.from_cli, + default=None, + help="Compilation config.", + ) + # Speculative decoding parser.add_argument( "--speculative-algorithm", @@ -3149,6 +3168,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Disable cuda graph.", ) + parser.add_argument( + "--enable-piecewise-npu-graph-decode", + action="store_true", + help="Optimize the model with piecewise npu graph for decode.", + ) parser.add_argument( "--disable-cuda-graph-padding", action="store_true", From 3d9516a57b7a45679597f8680451c9765fa9c87a Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 11 Nov 2025 18:13:39 +0800 Subject: [PATCH 09/71] rollback --- python/sglang/srt/layers/rotary_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index bb69d164b3f5..0935b1750c93 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1385,7 +1385,7 @@ def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None: ): self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - # @torch.compile(dynamic=True, backend=get_compiler_backend()) + @torch.compile(dynamic=True, backend=get_compiler_backend()) def _forward_native( self, positions: torch.Tensor, From 
2c1b6fe1c32fdda33b2f94dd932a31e3c15c1fd9 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 11 Nov 2025 18:18:32 +0800 Subject: [PATCH 10/71] linter --- .../srt/model_executor/piecewise_npu_graph_runner_decode.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py index 6bb0d94604f0..edb02787bffb 100644 --- a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py @@ -208,6 +208,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): return capture_bs, compile_bs + class PiecewiseNPUGraphRunnerDecode: """A PiecewiseNPUGraphRunnerDecode runs the forward pass of a model with npu graph and torch.compile.""" From 55016b0e8fa6afbfb08938b27306721309a61127 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 11 Nov 2025 18:53:13 +0800 Subject: [PATCH 11/71] refactoring --- .../compilation/npu_compilation_backend.py | 36 ------------------- ...ler.py => piecewise_npu_graph_compiler.py} | 6 ++-- ...> piecewise_npu_graph_compiler_backend.py} | 29 +++++++-------- .../piecewise_npu_graph_runner_decode.py | 6 ++-- 4 files changed, 19 insertions(+), 58 deletions(-) delete mode 100644 python/sglang/srt/model_executor/compilation/npu_compilation_backend.py rename python/sglang/srt/model_executor/compilation/{npu_graph_compiler.py => piecewise_npu_graph_compiler.py} (87%) rename python/sglang/srt/model_executor/compilation/{npu_compiler_backend.py => piecewise_npu_graph_compiler_backend.py} (94%) diff --git a/python/sglang/srt/model_executor/compilation/npu_compilation_backend.py b/python/sglang/srt/model_executor/compilation/npu_compilation_backend.py deleted file mode 100644 index 508933c36bd7..000000000000 --- a/python/sglang/srt/model_executor/compilation/npu_compilation_backend.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2025 
SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Any - -import torch - -from sglang.srt.compilation.npu.compilation_context import CompilationContext - - -class NPUCompilationBackend: - def __init__( - self, graph: torch.fx.GraphModule, compilation_context: CompilationContext - ): - self.graph = graph - self.callable = None - - def __call__(self, *args) -> Any: - if not self.callable: - torch._dynamo.config.suppress_errors = True - self.callable = torch.compile( - self.graph.forward, fullgraph=False, dynamic=False, backend="eager" - ) - - return self.callable(*args) diff --git a/python/sglang/srt/model_executor/compilation/npu_graph_compiler.py b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py similarity index 87% rename from python/sglang/srt/model_executor/compilation/npu_graph_compiler.py rename to python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py index 477ec4fd3518..a4c354f1d733 100644 --- a/python/sglang/srt/model_executor/compilation/npu_graph_compiler.py +++ b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py @@ -19,10 +19,10 @@ from sglang.srt.compilation.npu.compilation_context import CompilationContext from sglang.srt.compilation.npu.config import CompilationConfig -from sglang.srt.model_executor.compilation.npu_compiler_backend import NpuBackend +from 
sglang.srt.model_executor.compilation.piecewise_npu_graph_compiler_backend import PiecewiseNpuGraphCompilerBackend -class NpuGraphCompiler: +class PiecewiseNpuGraphCompiler: def __init__( self, model_runner, @@ -31,7 +31,7 @@ def __init__( compilation_context: CompilationContext, page_size: int, ): - self.backend = NpuBackend( + self.backend = PiecewiseNpuGraphCompilerBackend( model_runner, compilation_config, compilation_context, page_size ) self.model = model diff --git a/python/sglang/srt/model_executor/compilation/npu_compiler_backend.py b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py similarity index 94% rename from python/sglang/srt/model_executor/compilation/npu_compiler_backend.py rename to python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py index 032c4b291beb..512ac1065d2c 100644 --- a/python/sglang/srt/model_executor/compilation/npu_compiler_backend.py +++ b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py @@ -18,11 +18,11 @@ from typing import Any, Callable import torch -from torch._dynamo.eval_frame import DisableContext from sglang.srt.compilation.npu.compilation_context import CompilationContext from sglang.srt.compilation.npu.config import CompilationConfig from sglang.srt.compilation.npu.pass_manager import PassManager +from sglang.srt.compilation.npu.npu_compiler_backend import NpuBackend from sglang.srt.compilation.npu.passes.w8a8_int8 import ( DivFuse, EraseCopy, @@ -152,7 +152,7 @@ def replacement( return quantized_output, out2 -class NpuBackend: +class PiecewiseNpuGraphCompilerBackend(NpuBackend): graph: torch.fx.GraphModule def __init__( @@ -170,9 +170,7 @@ def __init__( self.compilation_context = compilation_context self.split_gm = None - self.piecewise_graphs = None - self.submod_names_to_compile = None self.callables = {} self.callables_by_branch = {} @@ -183,13 +181,10 @@ def __call__(self, graph: torch.fx.GraphModule, 
example_inputs) -> Callable: callable = self.callables[example_inputs_len] return callable - DisableContext.compiled_function_args[DisableContext.batch_size] = ( - example_inputs - ) - + super().__call__(graph, example_inputs) + self.graph = graph - NpuBackend.apply_passes(self.graph) - self.split_gm, self.piecewise_graphs = NpuBackend.split_graph( + self.split_gm, self.piecewise_graphs = PiecewiseNpuGraphCompilerBackend.split_graph( self.graph, self.compilation_config.splitting_ops ) @@ -224,13 +219,13 @@ def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: self.callables[example_inputs_len] = self.split_gm.forward return self.split_gm.forward - def apply_passes(graph_module: torch.fx.GraphModule): - passManager = PassManager(graph_module) - passManager.add(NpuAddRmsNormQuantFuse) - passManager.add(DivFuse) - passManager.add(EraseCopy) - passManager.apply() - graph_module.recompile() + # def apply_passes(graph_module: torch.fx.GraphModule): + # passManager = PassManager(graph_module) + # passManager.add(NpuAddRmsNormQuantFuse) + # passManager.add(DivFuse) + # passManager.add(EraseCopy) + # passManager.apply() + # graph_module.recompile() def split_graph( graph: torch.fx.GraphModule, ops: list[str] diff --git a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py index edb02787bffb..5d56bb808484 100644 --- a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py @@ -37,7 +37,9 @@ from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.model_executor.compilation.npu_graph_compiler import NpuGraphCompiler +from sglang.srt.model_executor.compilation.piecewise_npu_graph_compiler import ( 
+ PiecewiseNpuGraphCompiler, +) from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, @@ -544,7 +546,7 @@ def capture_one_batch_size( self.compilation_context.stream = self.stream self.model_runner.attn_backend.graph_mode = True - compiler = NpuGraphCompiler( + compiler = PiecewiseNpuGraphCompiler( self.model_runner, self.model_runner.model, self.compilation_config, From fbff08d8f32c5111aeabf7bc2846c809b8732eb3 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 11 Nov 2025 19:08:39 +0800 Subject: [PATCH 12/71] refactoring --- .../piecewise_npu_graph_compiler.py | 4 +- .../piecewise_npu_graph_compiler_backend.py | 22 +--- .../piecewise_npu_graph_runner_decode.py | 111 +----------------- 3 files changed, 11 insertions(+), 126 deletions(-) diff --git a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py index a4c354f1d733..a42917c7eafc 100644 --- a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py +++ b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py @@ -19,7 +19,9 @@ from sglang.srt.compilation.npu.compilation_context import CompilationContext from sglang.srt.compilation.npu.config import CompilationConfig -from sglang.srt.model_executor.compilation.piecewise_npu_graph_compiler_backend import PiecewiseNpuGraphCompilerBackend +from sglang.srt.model_executor.compilation.piecewise_npu_graph_compiler_backend import ( + PiecewiseNpuGraphCompilerBackend, +) class PiecewiseNpuGraphCompiler: diff --git a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py index 512ac1065d2c..8c2c479b73f4 100644 --- a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py +++ 
b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py @@ -21,13 +21,7 @@ from sglang.srt.compilation.npu.compilation_context import CompilationContext from sglang.srt.compilation.npu.config import CompilationConfig -from sglang.srt.compilation.npu.pass_manager import PassManager from sglang.srt.compilation.npu.npu_compiler_backend import NpuBackend -from sglang.srt.compilation.npu.passes.w8a8_int8 import ( - DivFuse, - EraseCopy, - NpuAddRmsNormQuantFuse, -) from sglang.srt.distributed import get_tensor_model_parallel_world_size logger = logging.getLogger(__name__) @@ -182,10 +176,12 @@ def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: return callable super().__call__(graph, example_inputs) - + self.graph = graph - self.split_gm, self.piecewise_graphs = PiecewiseNpuGraphCompilerBackend.split_graph( - self.graph, self.compilation_config.splitting_ops + self.split_gm, self.piecewise_graphs = ( + PiecewiseNpuGraphCompilerBackend.split_graph( + self.graph, self.compilation_config.splitting_ops + ) ) npu_graph_backend = resolve_obj_by_qualname( @@ -219,14 +215,6 @@ def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: self.callables[example_inputs_len] = self.split_gm.forward return self.split_gm.forward - # def apply_passes(graph_module: torch.fx.GraphModule): - # passManager = PassManager(graph_module) - # passManager.add(NpuAddRmsNormQuantFuse) - # passManager.add(DivFuse) - # passManager.add(EraseCopy) - # passManager.apply() - # graph_module.recompile() - def split_graph( graph: torch.fx.GraphModule, ops: list[str] ) -> tuple[torch.fx.GraphModule, list[SplitItem]]: diff --git a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py index 5d56bb808484..71f02cf9e7db 100644 --- a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py +++ 
b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py @@ -17,8 +17,6 @@ import bisect import gc -import os -from contextlib import contextmanager from typing import TYPE_CHECKING, Callable, Optional, Union import torch @@ -32,9 +30,8 @@ patch_dynamo_context_call, restore_dynamo_context_call, ) -from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import get_tensor_model_parallel_rank -from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture +from sglang.srt.distributed.parallel_state import graph_capture from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.model_executor.compilation.piecewise_npu_graph_compiler import ( @@ -89,76 +86,10 @@ def __init__( ): self.bs = bs self.forward_batch = forward_batch - # TODO: debug only self.attn_backend = attn_backend self.callable = callable -def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): - for sub in model._modules.values(): - if isinstance(sub, CustomOp): - if reverse: - sub.leave_torch_compile() - else: - sub.enter_torch_compile(num_tokens=num_tokens) - if isinstance(sub, torch.nn.Module): - _to_torch(sub, reverse, num_tokens) - - -@contextmanager -def patch_model( - model: torch.nn.Module, - enable_compile: bool, - num_tokens: int, - tp_group: GroupCoordinator, -): - """Patch the model to make it compatible with with torch.compile""" - backup_ca_comm = None - - try: - if enable_compile: - _to_torch(model, reverse=False, num_tokens=num_tokens) - backup_ca_comm = tp_group.ca_comm - # Use custom-allreduce here. - # We found the custom allreduce is much faster than the built-in allreduce in torch, - # even with ENABLE_INTRA_NODE_COMM=1. 
- # tp_group.ca_comm = None - yield torch.compile( - torch.no_grad()(model.forward), - mode=os.environ.get( - "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-npugraphs" - ), - dynamic=False, - ) - else: - yield model.forward - finally: - if enable_compile: - _to_torch(model, reverse=True, num_tokens=num_tokens) - tp_group.ca_comm = backup_ca_comm - - -def set_torch_compile_config(): - import torch._dynamo.config - import torch._inductor.config - - torch._inductor.config.fx_graph_cache = False - - from packaging import version - - if version.parse(torch.__version__) < version.parse("2.8.0"): - # These things are cacheable by torch.compile. torch.compile just doesn't know it. - # This was fixed in PyTorch 2.8, but until then, we monkey patch. - import torch._higher_order_ops.auto_functionalize as af - - af.auto_functionalized_v2._cacheable = False - af.auto_functionalized._cacheable = False - - torch._dynamo.config.accumulated_cache_size_limit = 1024 - if hasattr(torch._dynamo.config, "cache_size_limit"): - torch._dynamo.config.cache_size_limit = 1024 - - def get_batch_sizes_to_capture(model_runner: ModelRunner): server_args = model_runner.server_args capture_bs = server_args.cuda_graph_bs @@ -275,11 +206,6 @@ def __init__(self, model_runner: ModelRunner): (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 ) - set_torch_compile_config() - - # if self.model_runner.server_args.lora_paths is not None: - # self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs) - # Graph inputs with torch.device(self.model_runner.device): self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) @@ -440,11 +366,8 @@ def init_forward_batch( encoder_lens = None num_token_non_padded = None - # pipeline parallelism assert self.pp_size <= 1 - assert self.enable_dp_attention == False - # assert self.enable_sp_layernorm == False global_num_tokens = None gathered_buffer = None @@ -454,9 +377,6 @@ def init_forward_batch( spec_info.capture_hidden_mode if spec_info 
else CaptureHiddenMode.NULL ) - # assert self.model_runner.server_args.lora_paths is None - # lora_paths = None - forward_batch = ForwardBatch( forward_mode=self.capture_forward_mode, batch_size=bs, @@ -472,14 +392,12 @@ def init_forward_batch( return_logprob=False, positions=positions, global_num_tokens_gpu=global_num_tokens, - # gathered_buffer=gathered_buffer, mrope_positions=mrope_positions, spec_algorithm=self.model_runner.spec_algorithm, spec_info=spec_info, capture_hidden_mode=self.capture_hidden_mode, num_token_non_padded=self.num_token_non_padded, global_forward_mode=self.capture_forward_mode, - # lora_paths=lora_paths, ) seq_lens_cpu_int = torch.zeros((bs,), dtype=torch.int32, device="cpu") @@ -488,7 +406,6 @@ def init_forward_batch( seq_lens_cpu = torch.full((bs,), 1, dtype=torch.int32, device="cpu") forward_batch.seq_lens_cpu = seq_lens_cpu - # TODO: don't use loop here for i in range(bs): forward_batch.global_forward_mode = None forward_batch.input_ids[i] = 323 @@ -505,7 +422,6 @@ def init_forward_batch( assert False assert self.pp_size <= 1 assert self.enable_dp_attention == False - # assert self.enable_sp_layernorm == False assert enable_num_token_non_padded(self.model_runner.server_args) == False assert self.enable_two_batch_overlap == False @@ -525,7 +441,6 @@ def capture_one_batch_size( compile: bool = True, ): attn_backend = self.model_runner.attn_backend - # TODO: absent in CUDAGraphRunner attn_backend.init_cuda_graph_state(bs, self.max_num_token) self.model_runner.attn_backend = attn_backend @@ -581,26 +496,6 @@ def capture_one_batch_size( return (compiled_graph, logits_output_or_pp_proxy_tensors) - def recapture_if_needed(self, forward_batch: ForwardBatch): - assert False - - # If the capture_hidden_mode changes, we need to recapture the graph - hidden_mode_from_spec_info = getattr( - forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL - ) - if ( - forward_batch.capture_hidden_mode == CaptureHiddenMode.FULL - and 
self.capture_hidden_mode != CaptureHiddenMode.FULL - ): - self.capture_hidden_mode = CaptureHiddenMode.FULL - self.capture() - elif ( - forward_batch.capture_hidden_mode != CaptureHiddenMode.FULL - and self.capture_hidden_mode != hidden_mode_from_spec_info - ): - self.capture_hidden_mode = hidden_mode_from_spec_info - self.capture() - def replay_prepare( self, forward_batch: ForwardBatch, @@ -610,7 +505,7 @@ def replay_prepare( raw_num_token = raw_bs * self.num_tokens_per_bs # Pad - if self.enable_dp_attention: # or self.enable_sp_layernorm: + if self.enable_dp_attention: index = bisect.bisect_left( self.capture_bs, sum(forward_batch.global_num_tokens_cpu) ) @@ -664,7 +559,7 @@ def replay_prepare( forward_batch.mrope_positions ) - if self.enable_dp_attention: # or self.enable_sp_layernorm: + if self.enable_dp_attention: assert False if enable_num_token_non_padded(self.model_runner.server_args): From 3e5db77d29b90986e96ceaa778c1322b98132e72 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 11 Nov 2025 19:18:22 +0800 Subject: [PATCH 13/71] Compilation: refactoring --- .../compilation/piecewise_npu_graph_compiler_backend.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py index 8c2c479b73f4..bfa12f1ef7f1 100644 --- a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py +++ b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py @@ -21,7 +21,9 @@ from sglang.srt.compilation.npu.compilation_context import CompilationContext from sglang.srt.compilation.npu.config import CompilationConfig -from sglang.srt.compilation.npu.npu_compiler_backend import NpuBackend +from sglang.srt.compilation.npu.npu_graph_compiler_backend import ( + NpuGraphCompilerBackend, +) from sglang.srt.distributed import 
get_tensor_model_parallel_world_size logger = logging.getLogger(__name__) @@ -146,7 +148,7 @@ def replacement( return quantized_output, out2 -class PiecewiseNpuGraphCompilerBackend(NpuBackend): +class PiecewiseNpuGraphCompilerBackend(NpuGraphCompilerBackend): graph: torch.fx.GraphModule def __init__( From 30da7fede35be2c8dcceb45ce38a485f738fe644 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 13 Nov 2025 19:50:58 +0800 Subject: [PATCH 14/71] model_type check --- python/sglang/srt/compilation/npu/npu_graph_compiler.py | 4 ++-- .../srt/compilation/npu/npu_graph_compiler_backend.py | 6 +++++- python/sglang/srt/model_executor/npu_graph_runner.py | 4 +++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/compilation/npu/npu_graph_compiler.py b/python/sglang/srt/compilation/npu/npu_graph_compiler.py index 46ab461e0c33..66e3e6a5ef4c 100644 --- a/python/sglang/srt/compilation/npu/npu_graph_compiler.py +++ b/python/sglang/srt/compilation/npu/npu_graph_compiler.py @@ -20,10 +20,10 @@ class NpuGraphCompiler: - def __init__(self, model: torch.nn.Module): + def __init__(self, model: torch.nn.Module, model_type: torch.dtype): torch._dynamo.reset() - self.backend = NpuGraphCompilerBackend() + self.backend = NpuGraphCompilerBackend(model_type) self.compiled_callable = torch.compile( model, fullgraph=True, dynamic=False, backend=self.backend ) diff --git a/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py b/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py index e40304adcb9a..e7d51e48eead 100644 --- a/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py +++ b/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py @@ -26,11 +26,15 @@ class NpuGraphCompilerBackend: + def __init__(self, model_type: torch.dtype): + self.model_type = model_type + def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: DisableContext.compiled_function_args[DisableContext.batch_size] = ( 
example_inputs ) - NpuGraphCompilerBackend.apply_passes(graph) + if self.model_type == torch.bfloat16: + NpuGraphCompilerBackend.apply_passes(graph) return graph def apply_passes(graph_module: torch.fx.GraphModule): diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py index 4adb4c518be3..d7dc3892394d 100644 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -101,7 +101,9 @@ def _init_dp_gathered_buffer( def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): if self.enable_torch_compile: - compiler = NpuGraphCompiler(run_once_fn) + compiler = NpuGraphCompiler( + run_once_fn, self.model_runner.model_config.dtype + ) patch_dynamo_context_call() DisableContext.batch_size = bs From 18084796bab85a73b9ec0c44efa6d5842a54f55d Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Fri, 14 Nov 2025 13:58:11 +0000 Subject: [PATCH 15/71] PiecewiseNpuGraphCompilerBackend quick fix --- .../compilation/piecewise_npu_graph_compiler_backend.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py index bfa12f1ef7f1..bb786d3e9999 100644 --- a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py +++ b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py @@ -158,6 +158,8 @@ def __init__( compilation_context: CompilationContext, page_size: int, ): + super().__init__(model_runner.model_config.dtype) + self.model_runner = model_runner self.model_config = model_runner.model.config From bcfc2c547dd08364905a043f7e6e06bd134611ef Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Mon, 17 Nov 2025 16:08:07 +0000 Subject: [PATCH 16/71] CompilationConfig reusage --- .../srt/compilation/compilation_config.py | 18 
+++++++++++-- python/sglang/srt/compilation/npu/config.py | 27 ------------------- .../piecewise_npu_graph_compiler.py | 2 +- .../piecewise_npu_graph_compiler_backend.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 26 +++++++++--------- .../piecewise_npu_graph_runner_decode.py | 14 +++++----- python/sglang/srt/server_args.py | 6 +---- 7 files changed, 41 insertions(+), 54 deletions(-) delete mode 100644 python/sglang/srt/compilation/npu/config.py diff --git a/python/sglang/srt/compilation/compilation_config.py b/python/sglang/srt/compilation/compilation_config.py index fbf1493e1244..0acf483003ae 100644 --- a/python/sglang/srt/compilation/compilation_config.py +++ b/python/sglang/srt/compilation/compilation_config.py @@ -1,14 +1,23 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py -from typing import List +import json +from typing import List, Optional # TODO(Yuwei): support better compile config support class CompilationConfig: - def __init__(self, capture_sizes: List[int], compiler: str = "eager"): + splitting_ops: Optional[list[str]] = None + + def __init__( + self, + capture_sizes: List[int] = [], + compiler: str = "eager", + splitting_ops: list[str] = [], + ): self.traced_files = set() self.capture_sizes = capture_sizes self.compiler = compiler + self.splitting_ops = splitting_ops def add_traced_file(self, file_path: str): self.traced_files.add(file_path) @@ -18,3 +27,8 @@ def get_traced_files(self): def get_capture_sizes(self): return self.capture_sizes + + @classmethod + def from_cli(cls, args) -> "CompilationConfig": + args_dict = json.loads(args) + return CompilationConfig(**args_dict) diff --git a/python/sglang/srt/compilation/npu/config.py b/python/sglang/srt/compilation/npu/config.py deleted file mode 100644 index d6375c937413..000000000000 --- a/python/sglang/srt/compilation/npu/config.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2023-2025 SGLang Team -# Licensed under the Apache 
License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from dataclasses import dataclass -from typing import Optional - - -@dataclass -class CompilationConfig: - splitting_ops: Optional[list[str]] = None - replay_index: int = 1 - page_size: int = 0 - - @classmethod - def from_cli(cls, cli_value: str) -> "CompilationConfig": - return CompilationConfig() diff --git a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py index a42917c7eafc..4fdb42fa27bd 100644 --- a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py +++ b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py @@ -17,8 +17,8 @@ import torch +from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.compilation.npu.compilation_context import CompilationContext -from sglang.srt.compilation.npu.config import CompilationConfig from sglang.srt.model_executor.compilation.piecewise_npu_graph_compiler_backend import ( PiecewiseNpuGraphCompilerBackend, ) diff --git a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py index bb786d3e9999..4ddf3e85db74 100644 --- a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py +++ 
b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py @@ -19,8 +19,8 @@ import torch +from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.compilation.npu.compilation_context import CompilationContext -from sglang.srt.compilation.npu.config import CompilationConfig from sglang.srt.compilation.npu.npu_graph_compiler_backend import ( NpuGraphCompilerBackend, ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f479257620fa..16b849d4cf2b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2000,18 +2000,20 @@ def init_device_graphs(self): logger.info( f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" ) - graph_runners = defaultdict( - lambda: CudaGraphRunner, - { - "cpu": CPUGraphRunner, - "npu": ( - PiecewiseNPUGraphRunnerDecode - if self.server_args.enable_piecewise_npu_graph_decode - else NPUGraphRunner - ), - }, - ) - self.graph_runner = graph_runners[self.device](self) + + if self.server_args.enable_piecewise_npu_graph_decode: + self.graph_runner = PiecewiseNPUGraphRunnerDecode( + self, self.server_args.compilation_config + ) + else: + graph_runners = defaultdict( + lambda: CudaGraphRunner, + { + "cpu": CPUGraphRunner, + "npu": NPUGraphRunner, + }, + ) + self.graph_runner = graph_runners[self.device](self) after_mem = get_available_gpu_memory(self.device, self.gpu_id) self.graph_mem_usage = before_mem - after_mem diff --git a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py index 71f02cf9e7db..220803a46825 100644 --- a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py @@ -23,8 +23,8 @@ 
import torch._dynamo.config import tqdm +from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.compilation.npu.compilation_context import CompilationContext -from sglang.srt.compilation.npu.config import CompilationConfig from sglang.srt.compilation.npu.patch_dynamo import ( patch_dynamo_context, patch_dynamo_context_call, @@ -145,7 +145,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): class PiecewiseNPUGraphRunnerDecode: """A PiecewiseNPUGraphRunnerDecode runs the forward pass of a model with npu graph and torch.compile.""" - def __init__(self, model_runner: ModelRunner): + def __init__( + self, model_runner: ModelRunner, compilation_config: CompilationConfig + ): model_runner.attn_backend.enable_piecewise_npu_graph_decode = True patch_dynamo_context() @@ -155,12 +157,12 @@ def __init__(self, model_runner: ModelRunner): # Parse args self.model_runner = model_runner - self.compilation_config = CompilationConfig() - self.compilation_config.splitting_ops = ["atb._npu_paged_attention"] + if compilation_config is None: + compilation_config = CompilationConfig() + compilation_config.splitting_ops = ["atb._npu_paged_attention"] + self.compilation_config = compilation_config self.compilation_context = CompilationContext() - # self.compilation_context = model_runner.server_args.compilation_config - self.graphs = {} self.output_buffers = {} self.enable_torch_compile = model_runner.server_args.enable_torch_compile diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2f09b95dd66f..0737a298e7fb 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -26,7 +26,7 @@ import orjson -from sglang.srt.compilation.npu.config import CompilationConfig +from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.connector import ConnectorType from sglang.srt.environ import ToolStrictLevel, envs from sglang.srt.function_call.function_call_parser 
import FunctionCallParser @@ -658,10 +658,6 @@ def __post_init__(self): # Handle elastic expert parallelism. self._handle_elastic_ep() - if not self.compilation_config: - self.compilation_config = CompilationConfig() - self.compilation_config.splitting_ops = ["atb._npu_paged_attention"] - def _handle_deprecated_args(self): # handle deprecated tool call parsers deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"} From a6a159dd7730c008723e8a5eb7a941362e37a954 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 18 Nov 2025 14:32:16 +0000 Subject: [PATCH 17/71] --torch-compile-max-bs support --- python/sglang/srt/model_executor/npu_graph_runner.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py index d7dc3892394d..feac5dcfa982 100644 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -100,7 +100,8 @@ def _init_dp_gathered_buffer( ) def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): - if self.enable_torch_compile: + if self.enable_torch_compile and (not self.compile_bs or bs in self.compile_bs): + self.model_runner.attn_backend.enable_torch_compile = True compiler = NpuGraphCompiler( run_once_fn, self.model_runner.model_config.dtype ) @@ -133,6 +134,7 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): compiled_function(*args) else: + self.model_runner.attn_backend.enable_torch_compile = False with torch.npu.graph( graph, pool=pool, @@ -143,7 +145,9 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): return out def _update_inputs(self, seq_lens): - if self.enable_torch_compile: + if self.enable_torch_compile and ( + not self.compile_bs or self.bs in self.compile_bs + ): if self.use_mla: self.graphs[self.bs].update( cpu_update_input=[{"actual_seq_lengths_kv": seq_lens}] @@ -202,7 +206,9 @@ 
def replay( self.positions[: self.raw_num_token].copy_(forward_batch.positions) # Replay - if self.enable_torch_compile: + if self.enable_torch_compile and ( + not self.compile_bs or self.bs in self.compile_bs + ): seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * ( self.bs - self.raw_bs ) From c08d0766679b8b270597dcf5dee7a829a4518196 Mon Sep 17 00:00:00 2001 From: XDaoHong Date: Fri, 14 Nov 2025 13:53:50 +0000 Subject: [PATCH 18/71] TorchAir compilation support --- .../srt/layers/attention/ascend_backend.py | 46 +++- python/sglang/srt/mem_cache/memory_pool.py | 117 ++++++++- .../sglang/srt/model_executor/model_runner.py | 9 +- .../npu_compile_model_runner.py | 247 ++++++++++++++++++ python/sglang/srt/server_args.py | 5 + python/sglang/srt/utils/common.py | 8 +- 6 files changed, 409 insertions(+), 23 deletions(-) create mode 100644 python/sglang/srt/model_executor/npu_compile_model_runner.py diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index ffda63e5a3f1..048a8be20848 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -74,8 +74,10 @@ def update_verify_buffers_to_fill_after_draft( def __init__(self, model_runner: ModelRunner): super().__init__() - self.enable_torch_compile = False - self.enable_piecewise_npu_graph_decode = False + self.enable_torch_compile = model_runner.server_args.enable_torch_compile + self.enable_piecewise_npu_graph_decode = ( + model_runner.server_args.enable_piecewise_npu_graph_decode + ) self.forward_metadata = None self.device = model_runner.device self.page_size = model_runner.page_size @@ -108,6 +110,14 @@ def __init__(self, model_runner: ModelRunner): self.mtp_mask = torch.tril(torch.ones(2048, 2048, dtype=torch.bool)).npu() self.mtp_mask = ~self.mtp_mask + self.enable_torch_air_compile = ( + model_runner.server_args.disable_cuda_graph + and 
model_runner.server_args.enable_torch_compile + ) + if self.enable_torch_air_compile: + max_total_tokens = model_runner.max_total_num_tokens + self.max_seqlen_pad = max_total_tokens // model_runner.server_args.page_size + def init_forward_metadata(self, forward_batch: ForwardBatch): """Init the metadata for a forward pass.""" tp_size = get_attention_tp_size() @@ -115,17 +125,41 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): seq_lens_max = forward_batch.seq_lens.max() if forward_batch.forward_mode.is_target_verify(): seq_lens_max += self.speculative_num_draft_tokens - self.forward_metadata.block_tables = ( + + block_tables = ( forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, :seq_lens_max ][:, :: self.page_size] // self.page_size ) + + if ( + self.enable_torch_air_compile + and forward_batch.forward_mode.is_decode_or_idle() + ): + bs = forward_batch.input_ids.size(0) + device = forward_batch.input_ids.device + self.forward_metadata.block_tables = torch.full( + (bs, self.max_seqlen_pad), -1, dtype=torch.int32, device=device + ) + self.forward_metadata.block_tables[:, : block_tables.size(1)].copy_( + block_tables + ) + else: + self.forward_metadata.block_tables = block_tables + if forward_batch.extend_seq_lens is not None: self.forward_metadata.extend_seq_lens_cpu_int = ( forward_batch.extend_seq_lens.cpu().int() ) self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() + if ( + self.enable_torch_air_compile + and forward_batch.forward_mode.is_decode_or_idle() + ): + self.forward_metadata.seq_lens_cpu_list = ( + self.forward_metadata.seq_lens_cpu_int.tolist() + ) if ( not forward_batch.forward_mode.is_draft_extend_v2() and not forward_batch.forward_mode.is_draft_extend() @@ -811,7 +845,11 @@ def forward_decode( atten_mask=None, block_size=self.page_size, block_table=self.forward_metadata.block_tables, - actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int, + actual_seq_lengths_kv=( + 
self.forward_metadata.seq_lens_cpu_list + if self.enable_torch_air_compile + else self.forward_metadata.seq_lens_cpu_int + ), scale=layer.scaling, ) else: diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 739e289439e4..e00554fff1ab 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -49,7 +49,15 @@ set_mla_kv_buffer_triton, set_mla_kv_scale_buffer_triton, ) -from sglang.srt.utils import is_cuda, is_float4_e2m1fn_x2, is_npu, next_power_of_2 +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import ( + get_bool_env_var, + is_cuda, + is_float4_e2m1fn_x2, + is_npu, + next_power_of_2, + supports_custom_op, +) if TYPE_CHECKING: from sglang.srt.managers.cache_controller import LayerDoneCounter @@ -1156,6 +1164,39 @@ def set_kv_buffer( class AscendTokenToKVPool(MHATokenToKVPool): + def __init__( + self, + size: int, + page_size: int, + dtype: torch.dtype, + head_num: int, + head_dim: int, + layer_num: int, + device: str, + enable_memory_saver: bool, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, + ): + self.is_npu = is_npu() + self.supports_custom_op = supports_custom_op() + self.enable_torch_air_compile = ( + get_global_server_args().disable_cuda_graph + and get_global_server_args().enable_torch_compile + ) + + self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False") + super().__init__( + size, + page_size, + dtype, + head_num, + head_dim, + layer_num, + device, + enable_memory_saver, + start_layer, + end_layer, + ) def _create_buffers(self): with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): @@ -1175,9 +1216,33 @@ def _create_buffers(self): dtype=self.store_dtype, device=self.device, ) + self.k_buffer = self.kv_buffer[0] self.v_buffer = self.kv_buffer[1] + if ( + self.is_npu + and self.supports_custom_op + and self.enable_torch_air_compile + ): + self.k_buffer = [] + self.v_buffer = [] + 
for i in range(self.layer_num): + k_buffer_layer = self.kv_buffer[0][i] + v_buffer_layer = self.kv_buffer[1][i] + if self.use_fia: + k_buffer_layer = k_buffer_layer.view( + -1, 1, self.head_num, self.head_dim + ) + v_buffer_layer = v_buffer_layer.view( + -1, 1, self.head_num, self.head_dim + ) + self.k_buffer.append(k_buffer_layer) + self.v_buffer.append(v_buffer_layer) + else: + self.k_buffer = self.kv_buffer[0] + self.v_buffer = self.kv_buffer[1] + # for disagg def get_contiguous_buf_infos(self): # layer_num x [seq_len, head_num, head_dim] @@ -1231,17 +1296,45 @@ def set_kv_buffer( cache_k = cache_k.view(self.store_dtype) cache_v = cache_v.view(self.store_dtype) - torch_npu._npu_reshape_and_cache( - key=cache_k, - value=cache_v, - key_cache=self.k_buffer[layer_id - self.start_layer].view( - -1, self.page_size, self.head_num, self.head_dim - ), - value_cache=self.v_buffer[layer_id - self.start_layer].view( - -1, self.page_size, self.head_num, self.head_dim - ), - slot_indices=loc, - ) + if self.is_npu and self.supports_custom_op and self.enable_torch_air_compile: + if self.use_fia: + k_buffer_layer = self.k_buffer[layer_id - self.start_layer] + v_buffer_layer = self.v_buffer[layer_id - self.start_layer] + + torch_npu.npu_scatter_nd_update_( + k_buffer_layer, + loc.view(-1, 1), + cache_k.view(-1, 1, self.head_num, self.head_dim), + ) + torch_npu.npu_scatter_nd_update_( + v_buffer_layer, + loc.view(-1, 1), + cache_v.view(-1, 1, self.head_num, self.head_dim), + ) + else: + torch_npu._npu_reshape_and_cache( + key=cache_k, + value=cache_v, + key_cache=self.k_buffer[layer_id - self.start_layer].view( + -1, self.page_size, self.head_num, self.head_dim + ), + value_cache=self.v_buffer[layer_id - self.start_layer].view( + -1, self.page_size, self.head_num, self.head_dim + ), + slot_indices=loc, + ) + else: + torch_npu._npu_reshape_and_cache( + key=cache_k, + value=cache_v, + key_cache=self.k_buffer[layer_id - self.start_layer].view( + -1, self.page_size, self.head_num, 
self.head_dim + ), + value_cache=self.v_buffer[layer_id - self.start_layer].view( + -1, self.page_size, self.head_num, self.head_dim + ), + slot_indices=loc, + ) class MLATokenToKVPool(KVCache): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 16b849d4cf2b..1269895a7fc0 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -112,6 +112,7 @@ from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_executor.npu_compile_model_runner import NPUCompileModelRunner from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner from sglang.srt.model_executor.piecewise_cuda_graph_runner import ( PiecewiseCudaGraphRunner, @@ -1989,7 +1990,7 @@ def init_device_graphs(self): # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models return - if self.device != "cpu" and self.server_args.disable_cuda_graph: + if self.device not in ["cpu", "npu"] and self.server_args.disable_cuda_graph: return if self.device == "cpu" and not self.server_args.enable_torch_compile: @@ -2010,7 +2011,11 @@ def init_device_graphs(self): lambda: CudaGraphRunner, { "cpu": CPUGraphRunner, - "npu": NPUGraphRunner, + "npu": ( + NPUCompileModelRunner + if self.server_args.disable_cuda_graph + else NPUGraphRunner + ), }, ) self.graph_runner = graph_runners[self.device](self) diff --git a/python/sglang/srt/model_executor/npu_compile_model_runner.py b/python/sglang/srt/model_executor/npu_compile_model_runner.py new file mode 100644 index 000000000000..d69767e1fbc2 --- /dev/null +++ b/python/sglang/srt/model_executor/npu_compile_model_runner.py @@ -0,0 +1,247 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 
2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Run the model with torch air backend""" + +from __future__ import annotations + +import inspect +import logging +from typing import TYPE_CHECKING, Callable, Optional, Union + +import torch +import tqdm + +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.utils import get_available_gpu_memory, get_compiler_backend + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.model_executor.cuda_graph_runner import ( + CudaGraphRunner +) +from sglang.srt.model_executor.forward_batch_info import ( + ForwardBatch, + PPProxyTensors, + CaptureHiddenMode, +) + + +class NPUCompileModelRunner(CudaGraphRunner): + def __init__(self, model_runner: ModelRunner): + super().__init__(model_runner) + + def capture(self) -> None: + # Reverse the order to enable better memory sharing across cuda graphs. 
+ compile_range = ( + tqdm.tqdm(list(reversed(self.compile_bs))) + if get_tensor_model_parallel_rank() == 0 + else reversed(self.compile_bs) + ) + + backend = get_compiler_backend("reduce-overhead") + compile_forward = torch.compile( + torch.no_grad()(self.model_runner.model.forward), + fullgraph=True, + dynamic=True, + backend=backend, + ) + + self.model_runner.model.compile_forward = compile_forward + + @torch.compile(dynamic=True, backend=get_compiler_backend()) + def run_for_init(input): + return input + 1 + + run_for_init(torch.zeros([1]).to(self.model_runner.device)) + + for i, bs in enumerate(compile_range): + if get_tensor_model_parallel_rank() == 0: + avail_mem = get_available_gpu_memory( + self.model_runner.device, + self.model_runner.gpu_id, + empty_cache=False, + ) + compile_range.set_description( + f"Compiling batches ({bs=} {avail_mem=:.2f} GB)" + ) + + self.warm_up(bs, compile_forward) + + def replay( + self, + forward_batch: ForwardBatch, + skip_attn_backend_init: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + if not skip_attn_backend_init: + forward_batch.attn_backend.init_forward_metadata(forward_batch) + + kwargs = {} + if pp_proxy_tensors is not None: + kwargs["pp_proxy_tensors"] = pp_proxy_tensors + + with torch.no_grad(): + return self.model_runner.model.compile_forward( + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + **kwargs, + ) + + def prepare_forward_batch(self, bs: int, num_tokens: int) -> ForwardBatch: + # Graph inputs + with torch.device(self.model_runner.device): + input_ids = torch.zeros((num_tokens,), dtype=torch.int64) + req_pool_indices = torch.zeros((bs,), dtype=torch.int64) + seq_lens = torch.full((bs,), self.seq_len_fill_value, dtype=torch.int64) + out_cache_loc = torch.zeros((num_tokens,), dtype=torch.int32) + positions = torch.zeros((num_tokens,), dtype=torch.int64) + num_token_non_padded = torch.tensor(num_tokens, 
dtype=torch.int32) + + if self.is_encoder_decoder: + encoder_lens = self.encoder_lens[:bs] + else: + encoder_lens = None + mrope_positions = None + + # pipeline parallelism + if self.pp_size > 1: + pp_proxy_tensors = PPProxyTensors( + {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()} + ) + + if self.require_mlp_tp_gather: + global_num_tokens = torch.tensor( + [ + num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) + for i in range(self.dp_size) + ], + dtype=torch.int64, + device=input_ids.device, + ) + elif self.require_attn_tp_gather: + global_num_tokens = torch.tensor( + [num_tokens], dtype=torch.int64, device=input_ids.device + ) + else: + global_num_tokens = None + gathered_buffer = None + + spec_info = self.get_spec_info(num_tokens) + if self.capture_hidden_mode != CaptureHiddenMode.FULL: + self.capture_hidden_mode = ( + spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL + ) + + forward_batch = ForwardBatch( + forward_mode=self.capture_forward_mode, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.cpu(), + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens.sum().item(), + encoder_lens=encoder_lens, + return_logprob=False, + positions=positions, + global_num_tokens_gpu=global_num_tokens, + mrope_positions=mrope_positions, + spec_algorithm=self.model_runner.spec_algorithm, + spec_info=spec_info, + capture_hidden_mode=self.capture_hidden_mode, + num_token_non_padded=num_token_non_padded, + global_forward_mode=None, + mm_inputs=[None] * bs, + lora_ids=[None] * bs, + global_num_tokens_cpu=[num_tokens], + ) + return forward_batch + + def warm_up(self, bs: int, forward: Callable): + num_tokens = bs * self.num_tokens_per_bs + forward_batch = self.prepare_forward_batch(bs, num_tokens) + 
forward_batch.attn_backend.init_forward_metadata(forward_batch) + + # Run and compile + def run_once(): + # Clean intermediate result cache for DP attention + forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None + + kwargs = {} + if ( + self.pp_size > 1 + and "pp_proxy_tensors" in inspect.signature(forward).parameters + ): + kwargs["pp_proxy_tensors"] = forward_batch.pp_proxy_tensors + self.mark_static(forward_batch, kwargs.get("pp_proxy_tensors")) + + with torch.no_grad(): + logits_output_or_pp_proxy_tensors = forward( + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + **kwargs, + ) + return logits_output_or_pp_proxy_tensors + + torch.npu.synchronize() + self.model_runner.tp_group.barrier() + run_once() + + def mark_static( + self, forward_batch: ForwardBatch, pp_proxy_tensors: PPProxyTensors = None + ): + def mark_tensor_static(model_input, is_cache=False): + if model_input is not None: + if isinstance(model_input, torch.Tensor): + torch._dynamo.mark_static(model_input) + elif is_cache: + for buffer_per_layer in model_input: + torch._dynamo.mark_static(buffer_per_layer) + elif isinstance(model_input, PPProxyTensors): + for pp_out in model_input.tensors.items(): + torch._dynamo.mark_static(pp_out) + elif isinstance(model_input, tuple): + for value in model_input: + torch._dynamo.mark_static(value) + else: + raise ValueError( + f"Unsupported type with mark static: {type(model_input)}" + ) + + mark_tensor_static(pp_proxy_tensors) + mark_tensor_static(forward_batch.input_ids) + mark_tensor_static(forward_batch.positions) + mark_tensor_static(forward_batch.input_embeds) + mark_tensor_static(forward_batch.out_cache_loc) + mark_tensor_static(forward_batch.attn_backend.forward_metadata.block_tables) + try: + mark_tensor_static(forward_batch.token_to_kv_pool.k_buffer, is_cache=True) + mark_tensor_static(forward_batch.token_to_kv_pool.v_buffer, is_cache=True) + except AttributeError as e: + 
mark_tensor_static(forward_batch.token_to_kv_pool.kv_buffer, is_cache=True) + + def can_run(self, forward_batch: ForwardBatch): + return ( + forward_batch.forward_mode.is_decode() + and forward_batch.batch_size in self.compile_bs + ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0737a298e7fb..d17d15933017 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -3345,6 +3345,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Optimize the model with torch.compile. Experimental feature.", ) + parser.add_argument( + "--enable-torch-air-compile", + action="store_true", + help="Optimize the model with Torch Ascend Intermediate Representation compilation. Experimental feature.", + ) parser.add_argument( "--enable-piecewise-cuda-graph", action="store_true", diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 3e9cbd5bcbce..48b886b130f9 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1887,7 +1887,7 @@ def get_npu_compiler_config(): return config -def get_compiler_backend() -> str: +def get_compiler_backend(mode=None) -> str: if hasattr(torch, "hpu") and torch.hpu.is_available(): return "hpu_backend" @@ -1902,10 +1902,8 @@ def get_compiler_backend() -> str: "Please install torchair for torch.compile support on NPU." 
) compiler_config = CompilerConfig() - predefined_config = get_npu_compiler_config() - for k, v in predefined_config.items(): - setattr(compiler_config.experimental_config, k, v) - + # TODO(iforgetmyname): Change this default value once torch_npu version 7.2.0 + compiler_config.mode = "max-autotune" if mode is None else mode npu_backend = torchair.get_npu_backend(compiler_config=compiler_config) return npu_backend From 73f2ee94104e6e90865d01231e5ef3b3b0b24b76 Mon Sep 17 00:00:00 2001 From: Eduard Shogulin Date: Wed, 19 Nov 2025 09:27:30 +0000 Subject: [PATCH 19/71] runner selection fix: model forward usage --- python/sglang/srt/model_executor/model_runner.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1269895a7fc0..c053ec05626f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1993,6 +1993,13 @@ def init_device_graphs(self): if self.device not in ["cpu", "npu"] and self.server_args.disable_cuda_graph: return + if ( + self.device == "npu" + and self.server_args.disable_cuda_graph + and not self.server_args.enable_torch_compile + ): + return + if self.device == "cpu" and not self.server_args.enable_torch_compile: return From 2f976413c2641d77a317803c52327ff8c6b6d8f2 Mon Sep 17 00:00:00 2001 From: XDaoHong Date: Wed, 19 Nov 2025 15:19:53 +0800 Subject: [PATCH 20/71] add test for torchair Co-authored-by: ZhengdQin --- .../test_ascend_compile_graph_tp1_bf16.py | 103 ++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py diff --git a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py new file mode 100644 index 000000000000..2ecc97a95edf --- /dev/null +++ b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py @@ -0,0 +1,103 @@ +import os +import unittest 
+from types import SimpleNamespace +from urllib.parse import urlparse + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, + run_bench_offline_throughput, +) + +TEST_MODEL_MATRIX = { + "Qwen/Qwen2.5-7B-Instruct": { + "accuracy": 0.84, + "latency": 150, + "output_throughput": 30, + }, +} + +os.environ["ASCEND_USE_FIA"] = "true" + + +class TestAscendTp1Bf16(CustomTestCase): + + @classmethod + def setUpClass(cls): + cls.models = TEST_MODEL_MATRIX.keys() + cls.base_url = DEFAULT_URL_FOR_TEST + cls.url = urlparse(DEFAULT_URL_FOR_TEST) + cls.common_args = [ + "--trust-remote-code", + "--mem-fraction-static", + 0.6, + "--attention-backend", + "ascend", + "--disable-radix-cache", + "--enable-torch-compile", + "--watchdog-timeout", + 30000, + "--disable-cuda-graph", + ] + + def test_a_gsm8k(self): + for model in self.models: + with self.subTest(model=model): + print(f"##=== Testing accuracy: {model} ===##") + + process = popen_launch_server( + model, + self.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + *self.common_args, + ], + ) + + try: + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=1319, + max_new_tokens=512, + parallel=32, + host=f"http://{self.url.hostname}", + port=int(self.url.port), + ) + + metrics = run_eval_few_shot_gsm8k(args) + self.assertGreaterEqual( + metrics["accuracy"], + TEST_MODEL_MATRIX[model]["accuracy"], + ) + finally: + kill_process_tree(process.pid) + + def test_b_throughput(self): + for model in self.models: + with self.subTest(model=model): + print(f"##=== Testing throughput: {model} ===##") + + output_throughput = run_bench_offline_throughput( + model, + [ + *self.common_args, + ], + ) + + print(f"##=== {model} throughput: {output_throughput} ===##") + + if is_in_ci(): + 
self.assertGreater( + output_throughput, + TEST_MODEL_MATRIX[model]["output_throughput"], + ) + + +if __name__ == "__main__": + unittest.main() From 7154cf4ae32b97e378d7bd68bc9b53d03e4c3b97 Mon Sep 17 00:00:00 2001 From: Eduard Shogulin Date: Wed, 19 Nov 2025 09:51:14 +0000 Subject: [PATCH 21/71] TorchAir compilation support: refactoring --- .../srt/model_executor/npu_compile_model_runner.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/model_executor/npu_compile_model_runner.py b/python/sglang/srt/model_executor/npu_compile_model_runner.py index d69767e1fbc2..46db823df63d 100644 --- a/python/sglang/srt/model_executor/npu_compile_model_runner.py +++ b/python/sglang/srt/model_executor/npu_compile_model_runner.py @@ -31,19 +31,18 @@ from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.model_executor.cuda_graph_runner import ( - CudaGraphRunner -) +from sglang.srt.model_executor.cuda_graph_runner import get_batch_sizes_to_capture from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, ForwardBatch, PPProxyTensors, - CaptureHiddenMode, ) -class NPUCompileModelRunner(CudaGraphRunner): +class NPUCompileModelRunner: def __init__(self, model_runner: ModelRunner): - super().__init__(model_runner) + self.model_runner = model_runner + _, self.compile_bs = get_batch_sizes_to_capture(model_runner) def capture(self) -> None: # Reverse the order to enable better memory sharing across cuda graphs. 
From dfaee00c534de8745aec3429966cb838730cabc3 Mon Sep 17 00:00:00 2001 From: Eduard Shogulin Date: Wed, 19 Nov 2025 09:52:36 +0000 Subject: [PATCH 22/71] NPU Piecewise Graph: refactoring --- .../piecewise_npu_graph_runner_decode.py | 56 +------------------ 1 file changed, 2 insertions(+), 54 deletions(-) diff --git a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py index 220803a46825..4e89cd5dae6c 100644 --- a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py @@ -37,6 +37,7 @@ from sglang.srt.model_executor.compilation.piecewise_npu_graph_compiler import ( PiecewiseNpuGraphCompiler, ) +from sglang.srt.model_executor.cuda_graph_runner import get_batch_sizes_to_capture from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, @@ -47,7 +48,6 @@ from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin from sglang.srt.utils import ( get_available_gpu_memory, - get_device_memory_capacity, rank0_log, ) @@ -90,58 +90,6 @@ def __init__( self.callable = callable -def get_batch_sizes_to_capture(model_runner: ModelRunner): - server_args = model_runner.server_args - capture_bs = server_args.cuda_graph_bs - - if capture_bs is None: - if server_args.speculative_algorithm is None: - if server_args.disable_cuda_graph_padding: - capture_bs = list(range(1, 33)) + list(range(48, 161, 16)) - else: - capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8)) - else: - # Since speculative decoding requires more npu graph memory, we - # capture less. 
- capture_bs = ( - list(range(1, 9)) - + list(range(10, 33, 2)) - + list(range(40, 64, 8)) - + list(range(80, 161, 16)) - ) - - gpu_mem = get_device_memory_capacity() - if gpu_mem is not None and gpu_mem > 96 * 1024: - capture_bs += list(range(160, 257, 8)) - if gpu_mem is not None and gpu_mem > 180 * 1000: - capture_bs += list(range(256, 513, 16)) - - if max(capture_bs) > model_runner.req_to_token_pool.size: - # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests - # is very small. We add more values here to make sure we capture the maximum bs. - capture_bs += [model_runner.req_to_token_pool.size] - - if server_args.enable_two_batch_overlap: - capture_bs = [bs for bs in capture_bs if bs >= 2] - - if server_args.cuda_graph_max_bs: - capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs] - if max(capture_bs) < server_args.cuda_graph_max_bs: - capture_bs += list( - range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16) - ) - capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size] - capture_bs = list(sorted(set(capture_bs))) - assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}" - compile_bs = ( - [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] - if server_args.enable_torch_compile - else [] - ) - - return capture_bs, compile_bs - - class PiecewiseNPUGraphRunnerDecode: """A PiecewiseNPUGraphRunnerDecode runs the forward pass of a model with npu graph and torch.compile.""" @@ -179,7 +127,7 @@ def __init__( self.pp_size = model_runner.server_args.pp_size # Batch sizes to capture - self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) + self.capture_bs, _ = get_batch_sizes_to_capture(model_runner) rank0_log(f"Capture npu graph bs {self.capture_bs}") self.capture_forward_mode: int = ForwardMode.DECODE self.capture_hidden_mode: int = CaptureHiddenMode.NULL From 253c14da9e5be101b8a15aa9de02e26532cfcb70 Mon Sep 17 00:00:00 
2001 From: Eduard Shogulin Date: Wed, 19 Nov 2025 10:58:27 +0000 Subject: [PATCH 23/71] linter fix after merge commit --- python/sglang/srt/model_executor/model_runner.py | 2 +- .../srt/model_executor/piecewise_npu_graph_runner_decode.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 258438d5fda2..8424250ab872 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -114,8 +114,8 @@ from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors -from sglang.srt.model_executor.npu_compile_model_runner import NPUCompileModelRunner from sglang.srt.model_executor.hook_manager import register_hooks +from sglang.srt.model_executor.npu_compile_model_runner import NPUCompileModelRunner from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner from sglang.srt.model_executor.piecewise_cuda_graph_runner import ( PiecewiseCudaGraphRunner, diff --git a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py index 4e89cd5dae6c..76ec726e92bb 100644 --- a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py @@ -46,10 +46,7 @@ enable_num_token_non_padded, ) from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin -from sglang.srt.utils import ( - get_available_gpu_memory, - rank0_log, -) +from sglang.srt.utils import get_available_gpu_memory, rank0_log torch._dynamo.config.skip_nnmodule_hook_guards = True torch._dynamo.config.automatic_dynamic_shapes = False From 85d808eec06ddb6df66027c85b7728fb71d165aa Mon Sep 17 00:00:00 2001 From: 
Eduard Shogulin Date: Wed, 19 Nov 2025 16:11:42 +0000 Subject: [PATCH 24/71] NPUGraph compilation (fp16) & NPU Piecewise Graph tests --- .../test_ascend_npu_graph_compile_tp1_bf16.py | 61 +++++++++++++++++++ ...est_ascend_npu_piecewise_graph_tp1_bf16.py | 61 +++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py create mode 100644 test/srt/ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py diff --git a/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py b/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py new file mode 100644 index 000000000000..85d1c2e83fbf --- /dev/null +++ b/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py @@ -0,0 +1,61 @@ +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +DEFAULT_MODEL_NAME_FOR_TEST = "Qwen/Qwen2.5-7B-Instruct" + + +class TestAscendNpuGraphCompile(CustomTestCase): + def test_gsm8k(self): + model = DEFAULT_MODEL_NAME_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + url = urlparse(base_url) + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--attention-backend", + "ascend", + "--mem-fraction-static", + 0.7, + "--enable-torch-compile", + "--cuda-graph-bs", + "128", + "--cuda-graph-max-bs", + "128", + "--tp-size", + "1", + ], + ) + + try: + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=1319, + max_new_tokens=512, + parallel=128, + host=f"http://{url.hostname}", + port=int(url.port), + ) + + metrics = run_eval_few_shot_gsm8k(args) + self.assertGreaterEqual(metrics["accuracy"], 0.62) + 
self.assertLessEqual(metrics["latency"], 150) + finally: + kill_process_tree(process.pid) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py new file mode 100644 index 000000000000..6a1eb57ea870 --- /dev/null +++ b/test/srt/ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py @@ -0,0 +1,61 @@ +import unittest +from types import SimpleNamespace +from urllib.parse import urlparse + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +DEFAULT_MODEL_NAME_FOR_TEST = "Qwen/Qwen2.5-7B-Instruct" + + +class TestAscendNpuPiecewiseGraph(CustomTestCase): + def test_gsm8k(self): + model = DEFAULT_MODEL_NAME_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + url = urlparse(base_url) + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--attention-backend", + "ascend", + "--mem-fraction-static", + 0.7, + "--enable-piecewise-npu-graph-decode", + "--cuda-graph-bs", + "128", + "--cuda-graph-max-bs", + "128", + "--tp-size", + "1", + ], + ) + + try: + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=1319, + max_new_tokens=512, + parallel=128, + host=f"http://{url.hostname}", + port=int(url.port), + ) + + metrics = run_eval_few_shot_gsm8k(args) + self.assertGreaterEqual(metrics["accuracy"], 0.62) + self.assertLessEqual(metrics["latency"], 150) + finally: + kill_process_tree(process.pid) + + +if __name__ == "__main__": + unittest.main() From 11074d973a16112da7388aef82e8411178301f10 Mon Sep 17 00:00:00 2001 From: Eduard Shogulin Date: Wed, 19 Nov 2025 16:32:10 +0000 Subject: [PATCH 25/71] TorchAir compilation support: refactoring 2 
--- python/sglang/srt/model_executor/npu_compile_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/model_executor/npu_compile_model_runner.py b/python/sglang/srt/model_executor/npu_compile_model_runner.py index 46db823df63d..b68bcdf6f3d4 100644 --- a/python/sglang/srt/model_executor/npu_compile_model_runner.py +++ b/python/sglang/srt/model_executor/npu_compile_model_runner.py @@ -43,6 +43,7 @@ class NPUCompileModelRunner: def __init__(self, model_runner: ModelRunner): self.model_runner = model_runner _, self.compile_bs = get_batch_sizes_to_capture(model_runner) + self.capture() def capture(self) -> None: # Reverse the order to enable better memory sharing across cuda graphs. From e06675b4c7a4afeecd7facde1d6d46753bb0dcd6 Mon Sep 17 00:00:00 2001 From: Eduard Shogulin Date: Fri, 21 Nov 2025 08:52:50 +0000 Subject: [PATCH 26/71] CompilationConfig comments fix + linter fix --- .../srt/compilation/compilation_config.py | 6 ++---- .../piecewise_npu_graph_compiler_backend.py | 21 ------------------- .../npu_compile_model_runner.py | 8 ++++--- 3 files changed, 7 insertions(+), 28 deletions(-) diff --git a/python/sglang/srt/compilation/compilation_config.py b/python/sglang/srt/compilation/compilation_config.py index 0acf483003ae..0d8df1fd85c4 100644 --- a/python/sglang/srt/compilation/compilation_config.py +++ b/python/sglang/srt/compilation/compilation_config.py @@ -1,18 +1,16 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py import json -from typing import List, Optional +from typing import List # TODO(Yuwei): support better compile config support class CompilationConfig: - splitting_ops: Optional[list[str]] = None - def __init__( self, capture_sizes: List[int] = [], compiler: str = "eager", - splitting_ops: list[str] = [], + splitting_ops: List[str] = [], ): self.traced_files = set() self.capture_sizes = capture_sizes diff --git 
a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py index 4ddf3e85db74..7bf409a1e8e0 100644 --- a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py +++ b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py @@ -127,27 +127,6 @@ class SplitItem: graph: torch.fx.GraphModule -class NpuAddRmsNormFuse: - def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3): - output = torch.ops.npu.npu_add_rms_norm( - rms_norm_input, residual, rms_norm_weight, 1e-6 - ) - out0 = output[0] - out2 = output[2] - quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, v1, v2, v3) - return quantized_output, out2 - - def replacement( - rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3 - ): - output = torch.ops.npu.npu_add_rms_norm_quant( - rms_norm_input, residual, rms_norm_weight, 1.0 / scale, offset, epsilon=1e-6 - ) - quantized_output = output[0] - out2 = output[2] - return quantized_output, out2 - - class PiecewiseNpuGraphCompilerBackend(NpuGraphCompilerBackend): graph: torch.fx.GraphModule diff --git a/python/sglang/srt/model_executor/npu_compile_model_runner.py b/python/sglang/srt/model_executor/npu_compile_model_runner.py index b68bcdf6f3d4..7f9466f26261 100644 --- a/python/sglang/srt/model_executor/npu_compile_model_runner.py +++ b/python/sglang/srt/model_executor/npu_compile_model_runner.py @@ -41,11 +41,13 @@ class NPUCompileModelRunner: def __init__(self, model_runner: ModelRunner): + print(f"NPUCompileModelRunner::__init__", flush=True) self.model_runner = model_runner _, self.compile_bs = get_batch_sizes_to_capture(model_runner) self.capture() def capture(self) -> None: + print(f"NPUCompileModelRunner::capture", flush=True) # Reverse the order to enable better memory sharing across cuda graphs. 
compile_range = ( tqdm.tqdm(list(reversed(self.compile_bs))) @@ -88,6 +90,7 @@ def replay( skip_attn_backend_init: bool = False, pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + print(f"NPUCompileModelRunner::replay", flush=True) if not skip_attn_backend_init: forward_batch.attn_backend.init_forward_metadata(forward_batch) @@ -241,7 +244,6 @@ def mark_tensor_static(model_input, is_cache=False): mark_tensor_static(forward_batch.token_to_kv_pool.kv_buffer, is_cache=True) def can_run(self, forward_batch: ForwardBatch): - return ( - forward_batch.forward_mode.is_decode() - and forward_batch.batch_size in self.compile_bs + return forward_batch.forward_mode.is_decode() and ( + forward_batch.batch_size in self.compile_bs ) From 0c09c2458ffcbd2ed06c158836ce03dbb19e1c36 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Fri, 21 Nov 2025 17:12:09 +0300 Subject: [PATCH 27/71] backend instantiation in get_compiler_backend --- .../srt/compilation/npu/npu_graph_compiler.py | 23 +++++++++++----- .../piecewise_npu_graph_compiler.py | 20 +++++++------- .../piecewise_npu_graph_compiler_backend.py | 3 +-- .../sglang/srt/model_executor/model_runner.py | 27 +++++++++---------- .../srt/model_executor/npu_graph_runner.py | 4 ++- .../piecewise_npu_graph_runner_decode.py | 13 +++++---- python/sglang/srt/utils/common.py | 24 ++++++++++++++++- 7 files changed, 74 insertions(+), 40 deletions(-) diff --git a/python/sglang/srt/compilation/npu/npu_graph_compiler.py b/python/sglang/srt/compilation/npu/npu_graph_compiler.py index 66e3e6a5ef4c..24a09a62f0a4 100644 --- a/python/sglang/srt/compilation/npu/npu_graph_compiler.py +++ b/python/sglang/srt/compilation/npu/npu_graph_compiler.py @@ -14,16 +14,27 @@ import torch -from sglang.srt.compilation.npu.npu_graph_compiler_backend import ( - NpuGraphCompilerBackend, -) +from sglang.srt.compilation.compilation_config import CompilationConfig +from sglang.srt.utils.common import 
get_compiler_backend class NpuGraphCompiler: - def __init__(self, model: torch.nn.Module, model_type: torch.dtype): + def __init__( + self, + model_runner, + model: torch.nn.Module, + compilation_config: CompilationConfig, + ): torch._dynamo.reset() - self.backend = NpuGraphCompilerBackend(model_type) + backend = get_compiler_backend( + ( + "npugraph_fused" + if compilation_config is None or compilation_config.compiler is None + else compilation_config.compiler + ), + model_runner.model_config.dtype, + ) self.compiled_callable = torch.compile( - model, fullgraph=True, dynamic=False, backend=self.backend + model, fullgraph=True, dynamic=False, backend=backend ) diff --git a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py index 4fdb42fa27bd..014837af3b95 100644 --- a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py +++ b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py @@ -19,9 +19,7 @@ from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.compilation.npu.compilation_context import CompilationContext -from sglang.srt.model_executor.compilation.piecewise_npu_graph_compiler_backend import ( - PiecewiseNpuGraphCompilerBackend, -) +from sglang.srt.utils.common import get_compiler_backend class PiecewiseNpuGraphCompiler: @@ -31,17 +29,21 @@ def __init__( model: torch.nn.Module, compilation_config: CompilationConfig, compilation_context: CompilationContext, - page_size: int, ): - self.backend = PiecewiseNpuGraphCompilerBackend( - model_runner, compilation_config, compilation_context, page_size + backend = get_compiler_backend( + ( + "piecewise" + if compilation_config.compiler is None + else compilation_config.compiler + ), + model_runner, + compilation_config, + compilation_context, ) - self.model = model - torch._dynamo.reset() torch.compiler.allow_in_graph(sys.intern) 
torch.compiler.allow_in_graph(pathlib.Path) self.compiled_callable = torch.compile( - self.model, fullgraph=True, dynamic=False, backend=self.backend + model, fullgraph=True, dynamic=False, backend=backend ) diff --git a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py index 7bf409a1e8e0..46b78388cee2 100644 --- a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py +++ b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py @@ -135,7 +135,6 @@ def __init__( model_runner, compilation_config: CompilationConfig, compilation_context: CompilationContext, - page_size: int, ): super().__init__(model_runner.model_config.dtype) @@ -143,7 +142,7 @@ def __init__( self.model_config = model_runner.model.config self.compilation_config = compilation_config - self.page_size = page_size + self.page_size = model_runner.page_size self.compilation_context = compilation_context self.split_gm = None diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index c7f9b175a960..5cec87ae815a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2067,23 +2067,22 @@ def init_device_graphs(self): f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. 
avail mem={before_mem:.2f} GB" ) - if self.server_args.enable_piecewise_npu_graph_decode: - self.graph_runner = PiecewiseNPUGraphRunnerDecode( - self, self.server_args.compilation_config - ) - else: - graph_runners = defaultdict( - lambda: CudaGraphRunner, - { - "cpu": CPUGraphRunner, - "npu": ( + graph_runners = defaultdict( + lambda: CudaGraphRunner, + { + "cpu": CPUGraphRunner, + "npu": ( + PiecewiseNPUGraphRunnerDecode + if self.server_args.enable_piecewise_npu_graph_decode + else ( NPUCompileModelRunner if self.server_args.disable_cuda_graph else NPUGraphRunner - ), - }, - ) - self.graph_runner = graph_runners[self.device](self) + ) + ), + }, + ) + self.graph_runner = graph_runners[self.device](self) after_mem = get_available_gpu_memory(self.device, self.gpu_id) self.graph_mem_usage = before_mem - after_mem diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py index feac5dcfa982..2ce19c17b4b6 100644 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -103,7 +103,9 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): if self.enable_torch_compile and (not self.compile_bs or bs in self.compile_bs): self.model_runner.attn_backend.enable_torch_compile = True compiler = NpuGraphCompiler( - run_once_fn, self.model_runner.model_config.dtype + self.model_runner, + run_once_fn, + get_global_server_args().compilation_config, ) patch_dynamo_context_call() diff --git a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py index 76ec726e92bb..d3c83c7dd9f6 100644 --- a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py @@ -45,6 +45,7 @@ PPProxyTensors, enable_num_token_non_padded, ) +from sglang.srt.server_args import get_global_server_args 
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin from sglang.srt.utils import get_available_gpu_memory, rank0_log @@ -90,9 +91,7 @@ def __init__( class PiecewiseNPUGraphRunnerDecode: """A PiecewiseNPUGraphRunnerDecode runs the forward pass of a model with npu graph and torch.compile.""" - def __init__( - self, model_runner: ModelRunner, compilation_config: CompilationConfig - ): + def __init__(self, model_runner: ModelRunner): model_runner.attn_backend.enable_piecewise_npu_graph_decode = True patch_dynamo_context() @@ -102,9 +101,11 @@ def __init__( # Parse args self.model_runner = model_runner + compilation_config = get_global_server_args().compilation_config if compilation_config is None: - compilation_config = CompilationConfig() - compilation_config.splitting_ops = ["atb._npu_paged_attention"] + compilation_config = CompilationConfig( + compiler="piecewise", splitting_ops=["atb._npu_paged_attention"] + ) self.compilation_config = compilation_config self.compilation_context = CompilationContext() @@ -413,7 +414,6 @@ def capture_one_batch_size( self.model_runner.model, self.compilation_config, self.compilation_context, - self.model_runner.page_size, ) patch_dynamo_context_call() @@ -527,7 +527,6 @@ def replay( pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> Union[LogitsProcessorOutput, PPProxyTensors]: self.replay_prepare(forward_batch, pp_proxy_tensors) - compiled_graph = self.graphs[self.bs] def init(): attn_backend = self.model_runner.attn_backend diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 51ebc7d90e2d..4008a70ec7c1 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -91,6 +91,7 @@ from torch.utils._contextlib import _DecoratorContextManager from typing_extensions import Literal +from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.environ import envs from sglang.srt.metrics.func_timer import enable_func_timer @@ 
-1887,11 +1888,32 @@ def get_npu_compiler_config(): return config -def get_compiler_backend(mode=None) -> str: +def get_compiler_backend( + mode=None, + model_runner=None, + compilation_config: CompilationConfig = None, + compilation_context=None, +) -> str: if hasattr(torch, "hpu") and torch.hpu.is_available(): return "hpu_backend" if hasattr(torch, "npu") and torch.npu.is_available(): + if mode == "piecewise": + from sglang.srt.model_executor.compilation.piecewise_npu_graph_compiler_backend import ( + PiecewiseNpuGraphCompilerBackend, + ) + + return PiecewiseNpuGraphCompilerBackend( + model_runner, compilation_config, compilation_context + ) + + if mode == "npugraph_fused": + from sglang.srt.compilation.npu.npu_graph_compiler_backend import ( + NpuGraphCompilerBackend, + ) + + return NpuGraphCompilerBackend(model_runner) + try: import torchair import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce From 7eefeee7922c0911056b0a655597e8241225ed6d Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Wed, 26 Nov 2025 13:43:40 +0300 Subject: [PATCH 28/71] linter fix --- python/sglang/srt/model_executor/npu_graph_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py index 4d1eb804706c..cc53b3a07fea 100644 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -26,7 +26,7 @@ import torch import sglang.srt.model_executor.cuda_graph_runner -from sglang.srt.configs.model_config import AttentionArch, +from sglang.srt.configs.model_config import AttentionArch from sglang.srt.distributed.parallel_state import GroupCoordinator from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner From 8c63980be53619d0a923ccd23fdcd384b863e10a Mon Sep 17 00:00:00 2001 From: Edward Shogulin 
Date: Wed, 26 Nov 2025 14:17:59 +0300 Subject: [PATCH 29/71] dynamo patch removing --- .../npu/npu_graph_compiler_backend.py | 4 -- .../srt/compilation/npu/patch_dynamo.py | 54 ------------------- .../srt/model_executor/npu_graph_runner.py | 32 ++--------- .../piecewise_npu_graph_runner_decode.py | 29 ++-------- 4 files changed, 8 insertions(+), 111 deletions(-) delete mode 100644 python/sglang/srt/compilation/npu/patch_dynamo.py diff --git a/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py b/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py index e7d51e48eead..bc16a315d7ed 100644 --- a/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py +++ b/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py @@ -15,7 +15,6 @@ from typing import Callable import torch -from torch._dynamo.eval_frame import DisableContext from sglang.srt.compilation.npu.pass_manager import PassManager from sglang.srt.compilation.npu.passes.w8a8_int8 import ( @@ -30,9 +29,6 @@ def __init__(self, model_type: torch.dtype): self.model_type = model_type def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: - DisableContext.compiled_function_args[DisableContext.batch_size] = ( - example_inputs - ) if self.model_type == torch.bfloat16: NpuGraphCompilerBackend.apply_passes(graph) return graph diff --git a/python/sglang/srt/compilation/npu/patch_dynamo.py b/python/sglang/srt/compilation/npu/patch_dynamo.py deleted file mode 100644 index 284582f86011..000000000000 --- a/python/sglang/srt/compilation/npu/patch_dynamo.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import annotations - -import torch -from torch._dynamo.decorators import skip -from torch._dynamo.eval_frame import DisableContext, innermost_fn - - -def patch_dynamo_context(): - setattr(torch._dynamo.eval_frame.DisableContext, "compiled_function_args", {}) - setattr(torch._dynamo.eval_frame.DisableContext, "compiled_function", {}) - setattr(torch._dynamo.eval_frame.DisableContext, "batch_size", None) - - -original_disable_context_call = None -original_disable = None - - -def decorators_disable(fn=None, recursive=True): - if recursive: - if fn is not None: - fn = innermost_fn(fn) - assert callable(fn) - - DisableContext.compiled_function[DisableContext.batch_size] = fn - return DisableContext()(fn) - return DisableContext() - else: - return skip(fn) - - -def patch_dynamo_context_call(): - global original_disable - original_disable = torch._dynamo.decorators.disable - torch._dynamo.decorators.disable = decorators_disable - - -def restore_dynamo_context_call(): - global original_disable - torch._dynamo.decorators.disable = original_disable - original_disable = None diff --git a/python/sglang/srt/model_executor/npu_graph_runner.py b/python/sglang/srt/model_executor/npu_graph_runner.py index cc53b3a07fea..ddab08b29c5e 100644 --- a/python/sglang/srt/model_executor/npu_graph_runner.py +++ b/python/sglang/srt/model_executor/npu_graph_runner.py @@ -44,18 +44,12 @@ if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner -from 
torch._dynamo.eval_frame import DisableContext from sglang.srt.compilation.custom_ops import ( _set_dp_buffer_len, _set_is_extend_in_batch, ) from sglang.srt.compilation.npu.npu_graph_compiler import NpuGraphCompiler -from sglang.srt.compilation.npu.patch_dynamo import ( - patch_dynamo_context, - patch_dynamo_context_call, - restore_dynamo_context_call, -) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors @@ -74,8 +68,6 @@ class NPUGraphRunner(CudaGraphRunner): """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile.""" def __init__(self, model_runner: ModelRunner): - if model_runner.server_args.enable_torch_compile: - patch_dynamo_context() sglang.srt.model_executor.cuda_graph_runner.patch_model = patch_model_npu model_runner.attn_backend.enable_torch_compile = ( model_runner.server_args.enable_torch_compile @@ -123,32 +115,16 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): get_global_server_args().compilation_config, ) - patch_dynamo_context_call() - DisableContext.batch_size = bs - try: - # compilation - out = compiler.compiled_callable() - - # capture function and args - out = compiler.compiled_callable() - finally: - DisableContext.batch_size = None - restore_dynamo_context_call() - - assert bs in DisableContext.compiled_function - assert DisableContext.compiled_function[bs] - assert bs in DisableContext.compiled_function_args - assert DisableContext.compiled_function_args[bs] - - compiled_function = DisableContext.compiled_function[bs] - args = DisableContext.compiled_function_args[bs] + # compilation + out = compiler.compiled_callable() + with torch.npu.graph( graph, pool=pool, stream=stream, auto_dispatch_capture=True, ): - compiled_function(*args) + compiler.compiled_callable() else: self.model_runner.attn_backend.enable_torch_compile = False diff --git 
a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py index d3c83c7dd9f6..804e15d4eadf 100644 --- a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py @@ -25,11 +25,6 @@ from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.compilation.npu.compilation_context import CompilationContext -from sglang.srt.compilation.npu.patch_dynamo import ( - patch_dynamo_context, - patch_dynamo_context_call, - restore_dynamo_context_call, -) from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import graph_capture from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend @@ -55,7 +50,6 @@ import logging -from torch._dynamo.eval_frame import DisableContext logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -93,9 +87,6 @@ class PiecewiseNPUGraphRunnerDecode: def __init__(self, model_runner: ModelRunner): model_runner.attn_backend.enable_piecewise_npu_graph_decode = True - - patch_dynamo_context() - self.inference_counter = 1 self.init_forward_metadata_was_done = True @@ -416,24 +407,10 @@ def capture_one_batch_size( self.compilation_context, ) - patch_dynamo_context_call() - DisableContext.batch_size = bs - logits_output_or_pp_proxy_tensors = compiler.compiled_callable( forward_batch.input_ids, forward_batch.positions, forward_batch ) - try: - logits_output_or_pp_proxy_tensors = compiler.compiled_callable( - forward_batch.input_ids, forward_batch.positions, forward_batch - ) - finally: - DisableContext.batch_size = None - restore_dynamo_context_call() - - assert DisableContext.compiled_function - assert DisableContext.compiled_function_args - compiled_graph = CompiledGraph( bs, forward_batch, None, compiler.compiled_callable ) @@ -555,8 +532,10 @@ def init(): 
self.model_runner.attn_backend.graph_mode = True - DisableContext.compiled_function[self.bs]( - *DisableContext.compiled_function_args[self.bs] + compiled_graph = self.graphs[self.bs] + forward_batch = compiled_graph.forward_batch + compiled_graph.callable( + forward_batch.input_ids, forward_batch.positions, forward_batch ) output = self.output_buffers[self.bs] From 2e025683e2aa9c2b273dbf344052df7eae4ac206 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 27 Nov 2025 15:38:59 +0300 Subject: [PATCH 30/71] fix on main branch: compilation --- python/sglang/srt/_custom_ops.py | 52 +++++ .../srt/layers/attention/ascend_backend.py | 200 +++++++++--------- .../srt/model_executor/npu_graph_runner.py | 63 ++---- .../piecewise_npu_graph_runner_decode.py | 1 - python/sglang/srt/models/qwen3.py | 22 +- 5 files changed, 175 insertions(+), 163 deletions(-) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index e3734ba087fe..9f371cb74f52 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -9,9 +9,61 @@ logger = logging.getLogger(__name__) +import sgl_kernel_npu.norm.split_qkv_rmsnorm_rope + import sglang.srt.utils +@torch.library.custom_op("sglang::split_qkv_rmsnorm_rope", mutates_args=()) +def split_qkv_rmsnorm_rope( + input: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_hiddent_size: int, + kv_hidden_size: int, + head_dim: int, + eps: float, + q_bias: torch.Tensor, + k_bias: torch.Tensor, +) -> List[torch.Tensor]: + q, k, v = sgl_kernel_npu.norm.split_qkv_rmsnorm_rope.split_qkv_rmsnorm_rope( + input, + sin, + cos, + q_weight, + k_weight, + q_hiddent_size, + kv_hidden_size, + head_dim, + eps, + q_bias, + k_bias, + ) + return [q, k, v] + + +@split_qkv_rmsnorm_rope.register_fake +def split_qkv_rmsnorm_rope_fake( + input: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + 
q_hiddent_size: int, + kv_hidden_size: int, + head_dim: int, + eps: float, + q_bias: torch.Tensor, + k_bias: torch.Tensor, +) -> List[torch.Tensor]: + q = torch.empty((128, 4096), dtype=input.dtype, device=input.device) + k = torch.empty((128, 512), dtype=input.dtype, device=input.device) + v = torch.empty((128, 512), dtype=input.dtype, device=input.device) + return [q, k, v] + + @torch.library.custom_op("sglang::wait_cmo_stream", mutates_args=()) def wait_cmo_stream() -> None: if sglang.srt.utils.get_cmo_stream(): diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index ce88d6d6585a..1cbc92b1d1e0 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -153,13 +153,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.extend_seq_lens.cpu().int() ) self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int() - if ( - self.enable_torch_air_compile - and forward_batch.forward_mode.is_decode_or_idle() - ): - self.forward_metadata.seq_lens_cpu_list = ( - self.forward_metadata.seq_lens_cpu_int.tolist() - ) if ( not forward_batch.forward_mode.is_draft_extend_v2() and not forward_batch.forward_mode.is_draft_extend() @@ -660,51 +653,10 @@ def forward_decode_graph( layer, forward_batch.out_cache_loc, k, v ) - if not self.use_mla and ( - self.enable_torch_compile or self.enable_piecewise_npu_graph_decode - ): - k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id) - v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id) - query = q.reshape(-1, layer.tp_q_head_num, layer.qk_head_dim) - num_tokens = query.shape[0] - attn_output = torch.empty( - (num_tokens, layer.tp_q_head_num, layer.v_head_dim), - dtype=query.dtype, - device=query.device, - ) - - if self.forward_metadata.seq_lens_cpu_int is None: - actual_seq_len_kv = torch.from_numpy( - 
np.array(self.forward_metadata.seq_lens_cpu_list).astype(np.int32) - ) - else: - actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_int - - if ( - self.enable_piecewise_npu_graph_decode - and torch.compiler.is_dynamo_compiling() - ): - # input args for submodule forward - forward_batch.req_to_token_pool.req_to_token.add_( - forward_batch.req_to_token_pool.req_to_token - ) - forward_batch.req_pool_indices.add_(forward_batch.req_pool_indices) - forward_batch.seq_lens.add_(forward_batch.seq_lens) - - torch_npu._npu_paged_attention( - query=query, - key_cache=k_cache, - value_cache=v_cache, - num_heads=layer.tp_q_head_num, - num_kv_heads=layer.tp_k_head_num, - scale_value=layer.scaling, - block_table=self.forward_metadata.block_tables, - context_lens=actual_seq_len_kv, - out=attn_output, - ) - return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim) - else: - if not self.use_mla: + if not self.use_mla: + num_tokens = q.shape[0] + """PA will support bs Date: Thu, 27 Nov 2025 17:00:32 +0300 Subject: [PATCH 31/71] auto merge fix --- .../srt/layers/attention/ascend_backend.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index 8e7199183031..7180aa128c81 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -238,6 +238,14 @@ def __init__(self, model_runner: ModelRunner): self.ascend_attn_mask_builder.mix_mask_cache, ) + self.enable_torch_air_compile = ( + model_runner.server_args.disable_cuda_graph + and model_runner.server_args.enable_torch_compile + ) + if self.enable_torch_air_compile: + max_total_tokens = model_runner.max_total_num_tokens + self.max_seqlen_pad = max_total_tokens // model_runner.server_args.page_size + def get_verify_buffers_to_fill_after_draft(self): """ Return buffers for verify attention kernels that needs to be 
filled after draft. @@ -251,14 +259,6 @@ def update_verify_buffers_to_fill_after_draft( ): pass - self.enable_torch_air_compile = ( - model_runner.server_args.disable_cuda_graph - and model_runner.server_args.enable_torch_compile - ) - if self.enable_torch_air_compile: - max_total_tokens = model_runner.max_total_num_tokens - self.max_seqlen_pad = max_total_tokens // model_runner.server_args.page_size - def init_forward_metadata(self, forward_batch: ForwardBatch): """Init the metadata for a forward pass.""" tp_size = get_attention_tp_size() From f9891479bdc16ef6e24091997f8a4ef217f6a7ec Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 27 Nov 2025 18:15:02 +0300 Subject: [PATCH 32/71] tests suit update --- test/srt/run_suite.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 1a0f89bf5b73..d450de8adcf0 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -360,6 +360,9 @@ TestFile("ascend/test_ascend_hicache_mha.py", 400), TestFile("ascend/test_ascend_sampling_backend.py", 400), TestFile("ascend/test_ascend_tp1_bf16.py", 400), + TestFile("ascend/test_ascend_compile_graph_tp1_bf16.py", 400), + TestFile("ascend/test_ascend_npu_graph_compile_tp1_bf16.py", 400), + TestFile("ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py", 400), ], "per-commit-2-npu-a2": [ TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400), From bf1251dd13b5d9863d7fa5b4c31e152238f487d1 Mon Sep 17 00:00:00 2001 From: OrangeRedeng Date: Thu, 27 Nov 2025 18:16:02 +0300 Subject: [PATCH 33/71] Add npu_add_rms_norm_dynamic_quant fuse --- .../npu/npu_graph_compiler_backend.py | 2 ++ .../srt/compilation/npu/passes/w8a8_int8.py | 24 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py b/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py index bc16a315d7ed..24ef4b6a5cd5 100644 --- 
a/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py +++ b/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py @@ -20,6 +20,7 @@ from sglang.srt.compilation.npu.passes.w8a8_int8 import ( DivFuse, EraseCopy, + NpuAddRmsNormDynamicQuantFuse, NpuAddRmsNormQuantFuse, ) @@ -36,6 +37,7 @@ def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: def apply_passes(graph_module: torch.fx.GraphModule): passManager = PassManager(graph_module) passManager.add(NpuAddRmsNormQuantFuse) + passManager.add(NpuAddRmsNormDynamicQuantFuse) passManager.add(DivFuse) passManager.add(EraseCopy) passManager.apply() diff --git a/python/sglang/srt/compilation/npu/passes/w8a8_int8.py b/python/sglang/srt/compilation/npu/passes/w8a8_int8.py index ac2b86e7b171..d8353111e223 100644 --- a/python/sglang/srt/compilation/npu/passes/w8a8_int8.py +++ b/python/sglang/srt/compilation/npu/passes/w8a8_int8.py @@ -98,3 +98,27 @@ def replacement( quantized_output = output[0] out2 = output[2] return quantized_output, out2 + + +class NpuAddRmsNormDynamicQuantFuse: + def pattern(rms_norm_input, residual, rms_norm_weight): + output = torch.ops.npu.npu_add_rms_norm( + rms_norm_input, residual, rms_norm_weight, 1e-6 + ) + out0 = output[0] + out2 = output[2] + quantized_output = torch.ops.npu.npu_dynamic_quant(out0) + return quantized_output, out2, dynamic_scale + + def replacement(rms_norm_input, residual, rms_norm_weight): + output = torch.ops.npu.npu_add_rms_norm_dynamic_quant( + x1=rms_norm_input, + x2=residual, + gamma=rms_norm_weight, + epsilon=1e-6, + output_mask=[True, True], + ) + quantized_output = output[0] + out2 = output[2] + dynamic_scale = output[3] + return quantized_output, out2, dynamic_scale From e6eb29cab608b3f1fb6971eaeb17b71468eaa7b0 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 27 Nov 2025 18:57:34 +0300 Subject: [PATCH 34/71] NPU Graph compilation: attention architecture check --- python/sglang/srt/model_executor/model_runner.py | 11 
+++++++++++ 1 file changed, 11 insertions(+) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ae38d8b9330d..a6d092b2baae 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -395,6 +395,17 @@ def __init__( else: self.piecewise_cuda_graph_runner = None + if _is_npu and ( + not self.server_args.disable_cuda_graph + and self.server_args.enable_torch_compile + and not self.model_config.attention_arch == AttentionArch.MLA + ): + log_info_on_rank0( + logger, + "Disable torch compile for NPU graph because attention architecture is not suitable", + ) + self.server_args.enable_torch_compile = False + def init_mindspore_runner(self): # Init the mindspore runner # for now, there is only some communication initialization work From caba95e45621190bdbb7900a9062fed98022f3c2 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 27 Nov 2025 19:09:16 +0300 Subject: [PATCH 35/71] Add npu_add_rms_norm_dynamic_quant fuse: quick fix --- python/sglang/srt/compilation/npu/passes/w8a8_int8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/compilation/npu/passes/w8a8_int8.py b/python/sglang/srt/compilation/npu/passes/w8a8_int8.py index d8353111e223..20c8dbfa6100 100644 --- a/python/sglang/srt/compilation/npu/passes/w8a8_int8.py +++ b/python/sglang/srt/compilation/npu/passes/w8a8_int8.py @@ -107,7 +107,7 @@ def pattern(rms_norm_input, residual, rms_norm_weight): ) out0 = output[0] out2 = output[2] - quantized_output = torch.ops.npu.npu_dynamic_quant(out0) + quantized_output, dynamic_scale = torch.ops.npu.npu_dynamic_quant(out0) return quantized_output, out2, dynamic_scale def replacement(rms_norm_input, residual, rms_norm_weight): From 3f878797b7fbdc4fdd0d47d8fb31d4779f7a6dec Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Fri, 28 Nov 2025 12:17:29 +0300 Subject: [PATCH 36/71] Qwen3 MoE compilation support for 
NPU --- python/sglang/srt/models/qwen3_moe.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 3fbe81257290..b99f18b6469a 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -71,6 +71,7 @@ is_flashinfer_available, is_non_idle_and_non_empty, is_npu, + supports_custom_op, ) Qwen3MoeConfig = None @@ -82,7 +83,13 @@ _is_npu = is_npu() if _is_npu: - from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope + if supports_custom_op() and ( + get_global_server_args().enable_torch_compile + or get_global_server_args().enable_piecewise_npu_graph_decode + ): + from sglang.srt._custom_ops import split_qkv_rmsnorm_rope + else: + from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope class Qwen3MoeSparseMoeBlock(nn.Module): From faea888e7570218a600841244fcb2910baddbb9d Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 2 Dec 2025 11:16:47 +0300 Subject: [PATCH 37/71] linter quick fix --- python/sglang/srt/model_executor/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 68fe99fc532c..e61438d5bb5a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -131,8 +131,8 @@ PPProxyTensors, ) from sglang.srt.model_executor.hook_manager import register_forward_hooks -from sglang.srt.model_executor.npu_compile_model_runner import NPUCompileModelRunner from sglang.srt.model_executor.input_buffers import GraphInputBuffers +from sglang.srt.model_executor.npu_compile_model_runner import NPUCompileModelRunner from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner from sglang.srt.model_executor.piecewise_cuda_graph_runner import ( PiecewiseCudaGraphRunner, From 
85720d6a7a45960135a2998288491ba9b9a52aed Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Sat, 29 Nov 2025 11:56:30 +0300 Subject: [PATCH 38/71] SlitQkvRmsnormRopeFuse fuse --- .../srt/compilation/npu/npu_graph_compiler.py | 1 + .../npu/npu_graph_compiler_backend.py | 41 +++++- .../srt/compilation/npu/pass_manager.py | 8 +- .../sglang/srt/compilation/npu/passes/fp16.py | 128 ++++++++++++++++++ python/sglang/srt/models/qwen3.py | 60 ++------ 5 files changed, 183 insertions(+), 55 deletions(-) create mode 100644 python/sglang/srt/compilation/npu/passes/fp16.py diff --git a/python/sglang/srt/compilation/npu/npu_graph_compiler.py b/python/sglang/srt/compilation/npu/npu_graph_compiler.py index 24a09a62f0a4..dd4c059a8d10 100644 --- a/python/sglang/srt/compilation/npu/npu_graph_compiler.py +++ b/python/sglang/srt/compilation/npu/npu_graph_compiler.py @@ -35,6 +35,7 @@ def __init__( ), model_runner.model_config.dtype, ) + backend.init(model_runner.model_config) self.compiled_callable = torch.compile( model, fullgraph=True, dynamic=False, backend=backend ) diff --git a/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py b/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py index 24ef4b6a5cd5..aca189cb8102 100644 --- a/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py +++ b/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py @@ -17,12 +17,14 @@ import torch from sglang.srt.compilation.npu.pass_manager import PassManager +from sglang.srt.compilation.npu.passes.fp16 import SplitQkvRmsnormRopeFuse from sglang.srt.compilation.npu.passes.w8a8_int8 import ( DivFuse, EraseCopy, NpuAddRmsNormDynamicQuantFuse, NpuAddRmsNormQuantFuse, ) +from sglang.srt.layers.dp_attention import get_attention_tp_size class NpuGraphCompilerBackend: @@ -31,11 +33,46 @@ def __init__(self, model_type: torch.dtype): def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: if self.model_type == torch.bfloat16: - 
NpuGraphCompilerBackend.apply_passes(graph) + self.apply_passes(graph) return graph - def apply_passes(graph_module: torch.fx.GraphModule): + def init(self, config): + config = config.hf_config + + hidden_size = config.hidden_size + + num_heads = config.num_attention_heads + num_kv_heads = config.num_key_value_heads + + head_dim = getattr(config, "head_dim", None) + self.rms_norm_eps = config.rms_norm_eps + + total_num_heads = num_heads + attn_tp_size = get_attention_tp_size() + + assert total_num_heads % attn_tp_size == 0 + num_heads = total_num_heads // attn_tp_size + total_num_kv_heads = num_kv_heads + num_kv_heads = max(1, total_num_kv_heads // attn_tp_size) + + self.head_dim = head_dim or hidden_size // total_num_heads + self.q_size = num_heads * self.head_dim + self.kv_size = num_kv_heads * self.head_dim + + self.q_shape = (self.head_dim, self.q_size) + self.k_shape = (self.head_dim, self.kv_size) + + def apply_passes(self, graph_module: torch.fx.GraphModule): passManager = PassManager(graph_module) + passManager.add( + SplitQkvRmsnormRopeFuse, + q_size=self.q_size, + kv_size=self.kv_size, + head_dim=self.head_dim, + q_shape=self.q_shape, + k_shape=self.k_shape, + variance_epsilon=self.rms_norm_eps, + ) passManager.add(NpuAddRmsNormQuantFuse) passManager.add(NpuAddRmsNormDynamicQuantFuse) passManager.add(DivFuse) diff --git a/python/sglang/srt/compilation/npu/pass_manager.py b/python/sglang/srt/compilation/npu/pass_manager.py index e4da0bf8535d..52318e04b131 100644 --- a/python/sglang/srt/compilation/npu/pass_manager.py +++ b/python/sglang/srt/compilation/npu/pass_manager.py @@ -20,13 +20,13 @@ def __init__(self, graph_module: torch.fx.GraphModule): self.graph_module = graph_module self.passes = [] - def add(self, pass_): - self.passes.append(pass_) + def add(self, pass_, **kwargs): + self.passes.append((pass_, kwargs)) def apply(self): updated = False - for pass_ in self.passes: - pass_instance = pass_() + for pass_, kwargs in self.passes: + pass_instance = 
pass_(**kwargs) results = [] try: if callable(pass_instance): diff --git a/python/sglang/srt/compilation/npu/passes/fp16.py b/python/sglang/srt/compilation/npu/passes/fp16.py new file mode 100644 index 000000000000..f41e49e70717 --- /dev/null +++ b/python/sglang/srt/compilation/npu/passes/fp16.py @@ -0,0 +1,128 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch + + +class SplitQkvRmsnormRopeFuse: + instance = None + + def __init__( + self, + q_size: int, + kv_size: int, + head_dim: int, + q_shape, + k_shape, + variance_epsilon: float, + ): + self.q_size = q_size + self.kv_size = kv_size + self.head_dim = head_dim + self.q_shape = q_shape + self.k_shape = k_shape + self.variance_epsilon = variance_epsilon + + SplitQkvRmsnormRopeFuse.instance = self + + def pattern( + output_parallel, + q_norm_parameters_weight, + k_norm_parameters_weight, + positions, + cos_sin_cache, + ): + # pattern matching brokes if make static method as class method + self = SplitQkvRmsnormRopeFuse.instance + + split = output_parallel.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = split[0] + k = split[1] + v = split[2] + + q_by_head = q.reshape(-1, self.head_dim) + npu_rms_norm_q = torch.ops.npu.npu_rms_norm( + q_by_head, q_norm_parameters_weight, self.variance_epsilon + ) + q_by_head_1 = npu_rms_norm_q[0] + + k_by_head = k.reshape(-1, self.head_dim) + 
npu_rms_norm_k = torch.ops.npu.npu_rms_norm( + k_by_head, k_norm_parameters_weight, self.variance_epsilon + ) + k_by_head_1 = npu_rms_norm_k[0] + + q_1 = q_by_head_1.view(self.q_shape) + k_1 = k_by_head_1.view(self.k_shape) + + npu_mrope = torch.ops.npu.npu_mrope( + positions, + q_1, + k_1, + cos_sin_cache, + self.head_dim, + mrope_section=[0, 0, 0], + rotary_mode="half", + ) + query_out = npu_mrope[0] + key_out = npu_mrope[1] + + return v, query_out, key_out + + def replacement( + output_parallel, + q_norm_parameters_weight, + k_norm_parameters_weight, + positions, + cos_sin_cache, + ): + # pattern matching brokes if make static method as class method + self = SplitQkvRmsnormRopeFuse.instance + + flatten = positions.flatten() + cos_sin = cos_sin_cache.index_select(0, flatten) + + reshape = cos_sin.reshape(-1, 2, 64) + repeat = reshape.repeat(1, 1, 2) + chunk = repeat.chunk(2, dim=-2) + cos = chunk[0] + sin = chunk[1] + + cos_view = cos.view(-1, 1, 1, self.head_dim) + cos_contiguous = cos_view.contiguous() + + sin_view = sin.view(-1, 1, 1, self.head_dim) + sin_contiguous = sin_view.contiguous() + + split_qkv_rmsnorm_rope_default = ( + torch.ops.sglang.split_qkv_rmsnorm_rope.default( + output_parallel, + sin_contiguous, + cos_contiguous, + q_norm_parameters_weight, + k_norm_parameters_weight, + self.q_size, + self.kv_size, + self.head_dim, + self.variance_epsilon, + q_bias=None, + k_bias=None, + ) + ) + + q = split_qkv_rmsnorm_rope_default[0] + k = split_qkv_rmsnorm_rope_default[1] + v = split_qkv_rmsnorm_rope_default[2] + + return v, q, k diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index cceea3299141..5e8b2edf8cfd 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -32,20 +32,15 @@ from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, is_cuda, is_npu, supports_custom_op -if is_npu(): - if supports_custom_op() and ( +if ( + is_npu() + and 
supports_custom_op() + and ( get_global_server_args().enable_torch_compile or get_global_server_args().enable_piecewise_npu_graph_decode - ): - from sglang.srt._custom_ops import ( - get_cmo_stream, - split_qkv_rmsnorm_rope, - wait_cmo_stream, - ) - else: - from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope - - from sglang.srt.utils import get_cmo_stream, wait_cmo_stream + ) +): + from sglang.srt._custom_ops import get_cmo_stream, wait_cmo_stream else: from sglang.srt.utils import get_cmo_stream, wait_cmo_stream @@ -172,33 +167,6 @@ def _apply_qk_norm( k = k_by_head.view(k.shape) return q, k - def forward_prepare_native(self, positions, hidden_states): - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self._apply_qk_norm(q, k) - q, k = self.rotary_emb(positions, q, k) - return q, k, v - - def forward_prepare_npu(self, positions, hidden_states): - qkv, _ = self.qkv_proj(hidden_states) - - if self.attn.layer_id == 0: - self.rotary_emb.get_cos_sin_with_position(positions) - q, k, v = split_qkv_rmsnorm_rope( - qkv, - self.rotary_emb.position_sin, - self.rotary_emb.position_cos, - self.q_norm.weight, - self.k_norm.weight, - self.q_size, - self.kv_size, - self.head_dim, - self.q_norm.variance_epsilon, - q_bias=getattr(self.q_norm, "bias", None), - k_bias=getattr(self.k_norm, "bias", None), - ) - return q, k, v - def forward( self, positions: torch.Tensor, @@ -208,16 +176,10 @@ def forward( if get_global_server_args().rl_on_policy_target is not None: hidden_states = hidden_states.bfloat16() - if not _is_npu: - q, k, v = self.forward_prepare_native( - positions=positions, - hidden_states=hidden_states, - ) - else: - q, k, v = self.forward_prepare_npu( - positions=positions, - hidden_states=hidden_states, - ) + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = 
self.rotary_emb(positions, q, k) if get_global_server_args().rl_on_policy_target is not None: q = q.to(torch.bfloat16) From fd0e1e8cd8acdea02c9a0300805626252a3055e4 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 2 Dec 2025 14:46:43 +0300 Subject: [PATCH 39/71] headers quick fix --- python/sglang/srt/model_executor/npu_compile_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/npu_compile_model_runner.py b/python/sglang/srt/model_executor/npu_compile_model_runner.py index 7f9466f26261..0f3a1a19b3b8 100644 --- a/python/sglang/srt/model_executor/npu_compile_model_runner.py +++ b/python/sglang/srt/model_executor/npu_compile_model_runner.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 SGLang Team +# Copyright 2023-2025 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at From 4a61b7eca2bafb93dfe87fff9615445803c47918 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Wed, 3 Dec 2025 12:33:37 +0300 Subject: [PATCH 40/71] lint after merge + Piecewise Graph fix --- python/sglang/srt/_custom_ops.py | 2 +- .../model_executor/compilation/piecewise_npu_graph_compiler.py | 2 ++ .../srt/model_executor/piecewise_npu_graph_runner_decode.py | 2 -- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 38c674252031..3b4c53a7f4fd 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -4,7 +4,7 @@ import torch -from sglang.srt.utils import direct_register_custom_op, is_cuda, is_hip, is_hpu, is_npu +from sglang.srt.utils import direct_register_custom_op, is_cuda, is_hip logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py 
index 014837af3b95..d4d72612d5c6 100644 --- a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py +++ b/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py @@ -40,6 +40,8 @@ def __init__( compilation_config, compilation_context, ) + backend.init(model_runner.model_config) + torch._dynamo.reset() torch.compiler.allow_in_graph(sys.intern) torch.compiler.allow_in_graph(pathlib.Path) diff --git a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py index d67fc6408065..a2e4753e7f4b 100644 --- a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py @@ -41,7 +41,6 @@ enable_num_token_non_padded, ) from sglang.srt.server_args import get_global_server_args -from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin from sglang.srt.utils import get_available_gpu_memory, rank0_log torch._dynamo.config.skip_nnmodule_hook_guards = True @@ -155,7 +154,6 @@ def __init__(self, model_runner: ModelRunner): self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32) - self.tbo_plugin = TboCudaGraphRunnerPlugin() self.block_tables = torch.full((160, 160), 0, dtype=torch.int32) From 17f0af5234f4a40f68b150eafcf195ed2da39216 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Wed, 3 Dec 2025 21:05:14 +0300 Subject: [PATCH 41/71] enable_torch_compile update rollback --- python/sglang/srt/model_executor/model_runner.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e61438d5bb5a..8690093d36b9 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ 
b/python/sglang/srt/model_executor/model_runner.py @@ -378,17 +378,6 @@ def __init__( self._model_update_group = {} self._weights_send_group = {} - if _is_npu and ( - not self.server_args.disable_cuda_graph - and self.server_args.enable_torch_compile - and not self.model_config.attention_arch == AttentionArch.MLA - ): - log_info_on_rank0( - logger, - "Disable torch compile for NPU graph because attention architecture is not suitable", - ) - self.server_args.enable_torch_compile = False - def init_mindspore_runner(self): # Init the mindspore runner # for now, there is only some communication initialization work From 752657c2dedbf414a6808c22231fc4632bb974af Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Sun, 7 Dec 2025 21:31:03 +0300 Subject: [PATCH 42/71] Merge fixes: moving in accordance with refactoring & cleanup --- .../sglang/srt/compilation/npu/passes/fp16.py | 2 +- .../compilation/npu_graph_backend.py | 0 .../piecewise_npu_graph_compiler.py | 0 .../piecewise_npu_graph_compiler_backend.py | 2 +- .../graph_runner}/npu_compile_model_runner.py | 0 .../piecewise_npu_graph_runner_decode.py | 6 +- python/sglang/srt/mem_cache/memory_pool.py | 183 +------------- .../sglang/srt/model_executor/model_runner.py | 10 +- python/sglang/srt/models/qwen3_moe.py | 232 +++++++++++++++--- python/sglang/srt/utils/common.py | 2 +- 10 files changed, 215 insertions(+), 222 deletions(-) rename python/sglang/srt/{model_executor => hardware_backend/npu/graph_runner}/compilation/npu_graph_backend.py (100%) rename python/sglang/srt/{model_executor => hardware_backend/npu/graph_runner}/compilation/piecewise_npu_graph_compiler.py (100%) rename python/sglang/srt/{model_executor => hardware_backend/npu/graph_runner}/compilation/piecewise_npu_graph_compiler_backend.py (98%) rename python/sglang/srt/{model_executor => hardware_backend/npu/graph_runner}/npu_compile_model_runner.py (100%) rename python/sglang/srt/{model_executor => 
hardware_backend/npu/graph_runner}/piecewise_npu_graph_runner_decode.py (99%) diff --git a/python/sglang/srt/compilation/npu/passes/fp16.py b/python/sglang/srt/compilation/npu/passes/fp16.py index c1f199132d43..fade634948cd 100644 --- a/python/sglang/srt/compilation/npu/passes/fp16.py +++ b/python/sglang/srt/compilation/npu/passes/fp16.py @@ -12,8 +12,8 @@ # limitations under the License. # ============================================================================== -import torch import sgl_kernel_npu +import torch class SplitQkvRmsnormRopeFuse: diff --git a/python/sglang/srt/model_executor/compilation/npu_graph_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_backend.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/npu_graph_backend.py rename to python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_backend.py diff --git a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py similarity index 100% rename from python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler.py rename to python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py diff --git a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py similarity index 98% rename from python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py rename to python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py index 46b78388cee2..8031122b65d2 100644 --- a/python/sglang/srt/model_executor/compilation/piecewise_npu_graph_compiler_backend.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py @@ 
-167,7 +167,7 @@ def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: ) npu_graph_backend = resolve_obj_by_qualname( - "sglang.srt.model_executor.compilation.npu_graph_backend.NPUGraphBackend" + "sglang.srt.hardware_backend.npu.graph_runner.compilation.npu_graph_backend.NPUGraphBackend" ) self.submod_names_compiled_only = [ diff --git a/python/sglang/srt/model_executor/npu_compile_model_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py similarity index 100% rename from python/sglang/srt/model_executor/npu_compile_model_runner.py rename to python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py diff --git a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py similarity index 99% rename from python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py rename to python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py index a2e4753e7f4b..8ca9ace090bc 100644 --- a/python/sglang/srt/model_executor/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py @@ -27,11 +27,11 @@ from sglang.srt.compilation.npu.compilation_context import CompilationContext from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import graph_capture -from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend -from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.model_executor.compilation.piecewise_npu_graph_compiler import ( +from sglang.srt.hardware_backend.npu.attention.ascend_backend import AscendAttnBackend +from sglang.srt.hardware_backend.npu.graph_runner.compilation.piecewise_npu_graph_compiler import ( PiecewiseNpuGraphCompiler, ) +from sglang.srt.layers.logits_processor import 
LogitsProcessorOutput from sglang.srt.model_executor.cuda_graph_runner import get_batch_sizes_to_capture from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 6ef667911253..99550e64adb2 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -50,14 +50,7 @@ set_mla_kv_buffer_triton, set_mla_kv_scale_buffer_triton, ) -from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import ( - get_bool_env_var, - is_cuda, - is_npu, - next_power_of_2, - supports_custom_op, -) +from sglang.srt.utils import is_cuda, is_npu, next_power_of_2 if TYPE_CHECKING: from sglang.srt.managers.cache_controller import LayerDoneCounter @@ -1310,180 +1303,6 @@ def set_kv_buffer( ) -class AscendTokenToKVPool(MHATokenToKVPool): - def __init__( - self, - size: int, - page_size: int, - dtype: torch.dtype, - head_num: int, - head_dim: int, - layer_num: int, - device: str, - enable_memory_saver: bool, - start_layer: Optional[int] = None, - end_layer: Optional[int] = None, - ): - self.is_npu = is_npu() - self.supports_custom_op = supports_custom_op() - self.enable_torch_air_compile = ( - get_global_server_args().disable_cuda_graph - and get_global_server_args().enable_torch_compile - ) - - self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False") - super().__init__( - size, - page_size, - dtype, - head_num, - head_dim, - layer_num, - device, - enable_memory_saver, - start_layer, - end_layer, - ) - - def _create_buffers(self): - with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): - # [size, head_num, head_dim] for each layer - # The padded slot 0 is used for writing dummy outputs from padded tokens. - # Continuous memory improves the efficiency of Ascend`s transmission backend, - # while other backends remain unchanged. 
- self.kv_buffer = torch.zeros( - ( - 2, - self.layer_num, - self.size // self.page_size + 1, - self.page_size, - self.head_num, - self.head_dim, - ), - dtype=self.store_dtype, - device=self.device, - ) - - self.k_buffer = self.kv_buffer[0] - self.v_buffer = self.kv_buffer[1] - - if ( - self.is_npu - and self.supports_custom_op - and self.enable_torch_air_compile - ): - self.k_buffer = [] - self.v_buffer = [] - for i in range(self.layer_num): - k_buffer_layer = self.kv_buffer[0][i] - v_buffer_layer = self.kv_buffer[1][i] - if self.use_fia: - k_buffer_layer = k_buffer_layer.view( - -1, 1, self.head_num, self.head_dim - ) - v_buffer_layer = v_buffer_layer.view( - -1, 1, self.head_num, self.head_dim - ) - self.k_buffer.append(k_buffer_layer) - self.v_buffer.append(v_buffer_layer) - else: - self.k_buffer = self.kv_buffer[0] - self.v_buffer = self.kv_buffer[1] - - # for disagg - def get_contiguous_buf_infos(self): - # layer_num x [seq_len, head_num, head_dim] - # layer_num x [page_num, page_size, head_num, head_dim] - kv_data_ptrs = [ - self.get_key_buffer(i).data_ptr() - for i in range(self.start_layer, self.start_layer + self.layer_num) - ] + [ - self.get_value_buffer(i).data_ptr() - for i in range(self.start_layer, self.start_layer + self.layer_num) - ] - kv_data_lens = [ - self.get_key_buffer(i).nbytes - for i in range(self.start_layer, self.start_layer + self.layer_num) - ] + [ - self.get_value_buffer(i).nbytes - for i in range(self.start_layer, self.start_layer + self.layer_num) - ] - kv_item_lens = [ - self.get_key_buffer(i)[0].nbytes - for i in range(self.start_layer, self.start_layer + self.layer_num) - ] + [ - self.get_value_buffer(i)[0].nbytes - for i in range(self.start_layer, self.start_layer + self.layer_num) - ] - return kv_data_ptrs, kv_data_lens, kv_item_lens - - def set_kv_buffer( - self, - layer: RadixAttention, - loc: torch.Tensor, - cache_k: torch.Tensor, - cache_v: torch.Tensor, - k_scale: Optional[float] = None, - v_scale: Optional[float] = None, 
- layer_id_override: Optional[int] = None, - ): - if layer_id_override is not None: - layer_id = layer_id_override - else: - layer_id = layer.layer_id - if cache_k.dtype != self.dtype: - if k_scale is not None: - cache_k.div_(k_scale) - if v_scale is not None: - cache_v.div_(v_scale) - cache_k = cache_k.to(self.dtype) - cache_v = cache_v.to(self.dtype) - - if self.store_dtype != self.dtype: - cache_k = cache_k.view(self.store_dtype) - cache_v = cache_v.view(self.store_dtype) - - if self.is_npu and self.supports_custom_op and self.enable_torch_air_compile: - if self.use_fia: - k_buffer_layer = self.k_buffer[layer_id - self.start_layer] - v_buffer_layer = self.v_buffer[layer_id - self.start_layer] - - torch_npu.npu_scatter_nd_update_( - k_buffer_layer, - loc.view(-1, 1), - cache_k.view(-1, 1, self.head_num, self.head_dim), - ) - torch_npu.npu_scatter_nd_update_( - v_buffer_layer, - loc.view(-1, 1), - cache_v.view(-1, 1, self.head_num, self.head_dim), - ) - else: - torch_npu._npu_reshape_and_cache( - key=cache_k, - value=cache_v, - key_cache=self.k_buffer[layer_id - self.start_layer].view( - -1, self.page_size, self.head_num, self.head_dim - ), - value_cache=self.v_buffer[layer_id - self.start_layer].view( - -1, self.page_size, self.head_num, self.head_dim - ), - slot_indices=loc, - ) - else: - torch_npu._npu_reshape_and_cache( - key=cache_k, - value=cache_v, - key_cache=self.k_buffer[layer_id - self.start_layer].view( - -1, self.page_size, self.head_num, self.head_dim - ), - value_cache=self.v_buffer[layer_id - self.start_layer].view( - -1, self.page_size, self.head_num, self.head_dim - ), - slot_indices=loc, - ) - - class MLATokenToKVPool(KVCache): def __init__( self, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b067f4a4d144..45b0f60a7604 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -78,7 +78,13 @@ 
set_global_expert_location_metadata, ) from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater +from sglang.srt.hardware_backend.npu.graph_runner.npu_compile_model_runner import ( + NPUCompileModelRunner, +) from sglang.srt.hardware_backend.npu.graph_runner.npu_graph_runner import NPUGraphRunner +from sglang.srt.hardware_backend.npu.graph_runner.piecewise_npu_graph_runner_decode import ( + PiecewiseNPUGraphRunnerDecode, +) from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.attention.attention_registry import ( ATTENTION_BACKENDS, @@ -130,13 +136,9 @@ ) from sglang.srt.model_executor.hook_manager import register_forward_hooks from sglang.srt.model_executor.input_buffers import GraphInputBuffers -from sglang.srt.model_executor.npu_compile_model_runner import NPUCompileModelRunner from sglang.srt.model_executor.piecewise_cuda_graph_runner import ( PiecewiseCudaGraphRunner, ) -from sglang.srt.model_executor.piecewise_npu_graph_runner_decode import ( - PiecewiseNPUGraphRunnerDecode, -) from sglang.srt.model_loader import get_model from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader from sglang.srt.model_loader.remote_instance_weight_loader_utils import ( diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index b99f18b6469a..9737ac7197a8 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -18,10 +18,12 @@ """Inference-only Qwen3MoE model compatible with HuggingFace weights.""" import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple +import math +from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar import torch from torch import nn +from transformers import PretrainedConfig from sglang.srt.distributed import ( get_moe_expert_parallel_world_size, @@ -71,9 +73,15 @@ is_flashinfer_available, is_non_idle_and_non_empty, is_npu, - supports_custom_op, ) +_is_cuda = is_cuda() + +if _is_cuda: 
+ from sgl_kernel import fused_qk_norm_rope + +TConfig = TypeVar("TConfig", bound=PretrainedConfig) + Qwen3MoeConfig = None _is_flashinfer_available = is_flashinfer_available() @@ -83,13 +91,119 @@ _is_npu = is_npu() if _is_npu: - if supports_custom_op() and ( - get_global_server_args().enable_torch_compile - or get_global_server_args().enable_piecewise_npu_graph_decode - ): - from sglang.srt._custom_ops import split_qkv_rmsnorm_rope + from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope + + +def compute_yarn_parameters( + config: PretrainedConfig, +) -> tuple[float, float, float, float]: + """ + Refer to https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L197C1-L288C1 + Computes the inverse frequencies with NTK scaling. Please refer to the + [original paper](https://huggingface.co/papers/2309.00071) + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + Returns: + factor: float, the scaling factor for the RoPE embeddings + low: float, the lower bound of the dimension range + high: float, the upper bound of the dimension range + attention_factor: float, the post-processing scaling factor applied to the computed cos/sin + """ + + # The config does not contain rope_scaling, which means the model is not using yarn + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is None: + return 1.0, 0, 0, 1.0 + + base = config.rope_theta + partial_rotary_factor = ( + config.partial_rotary_factor + if hasattr(config, "partial_rotary_factor") + else 1.0 + ) + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + dim = int(head_dim * partial_rotary_factor) + factor = getattr(rope_scaling, "factor", 1.0) + attention_factor = rope_scaling.get("attention_factor") + mscale = rope_scaling.get("mscale") + mscale_all_dim = rope_scaling.get("mscale_all_dim") + + if "original_max_position_embeddings" in rope_scaling: + 
original_max_position_embeddings = rope_scaling[ + "original_max_position_embeddings" + ] + factor = config.max_position_embeddings / original_max_position_embeddings else: - from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope + original_max_position_embeddings = config.max_position_embeddings + + def get_mscale(scale, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + # Sets the attention factor as suggested in the paper + if attention_factor is None: + if mscale and mscale_all_dim: + attention_factor = float( + get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim) + ) + else: + attention_factor = get_mscale(factor) + + # Optional config options + # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) + beta_fast = rope_scaling.get("beta_fast") or 32 + beta_slow = rope_scaling.get("beta_slow") or 1 + + # Compute the inverse frequencies + def find_correction_dim(num_rotations, dim, base, max_position_embeddings): + """Inverse dimension formula to find the dimension based on the number of rotations""" + return ( + dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) + ) / (2 * math.log(base)) + + def find_correction_range( + low_rot, high_rot, dim, base, max_position_embeddings, truncate + ): + """Find dimension range bounds based on rotations""" + low = find_correction_dim(low_rot, dim, base, max_position_embeddings) + high = find_correction_dim(high_rot, dim, base, max_position_embeddings) + if truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + truncate = rope_scaling.get("truncate", True) + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate + ) + + # These parts are implemented in the fusedQKNormRopeKernel.cu + # # def linear_ramp_factor(min, max, dim): + # # if min == max: + # # max += 0.001 # Prevent singularity + + # # 
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + # # ramp_func = torch.clamp(linear_func, 0, 1) + # # return ramp_func + + # # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs + # # to expand the possible context length. In other words, interpolation = apply scaling factor. + # # pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim) + # # inv_freq_extrapolation = 1.0 / pos_freqs + # # inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + # # # Get n-dimensional rotational scaling corrected for extrapolation + # # inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) + # # inv_freq = ( + # # inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) + # # + inv_freq_extrapolation * inv_freq_extrapolation_factor + # # ) + # # return inv_freq, attention_factor + return factor, low, high, attention_factor class Qwen3MoeSparseMoeBlock(nn.Module): @@ -293,6 +407,7 @@ def __init__( head_dim: Optional[int] = None, rms_norm_eps: float = 1e-06, attention_bias: bool = False, + config: Optional[TConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", dual_chunk_attention_config: Optional[dict[str, Any]] = None, @@ -304,6 +419,7 @@ def __init__( attn_tp_rank = get_attention_tp_rank() attn_tp_size = get_attention_tp_size() + self.config = config self.total_num_heads = num_heads assert self.total_num_heads % attn_tp_size == 0 self.num_heads = self.total_num_heads // attn_tp_size @@ -359,6 +475,14 @@ def __init__( self.compatible_with_fused_kv_buffer = ( False if isinstance(self.rotary_emb, MRotaryEmbedding) else True ) + self.compatible_with_fused_qk_norm_rope = ( + not isinstance(self.rotary_emb, MRotaryEmbedding) + ) and self.head_dim in (64, 128, 256) + self.use_fused_qk_norm_rope = ( + get_global_server_args().enable_fused_qk_norm_rope + and 
self.compatible_with_fused_qk_norm_rope + ) + self._used_fused_qk_norm_rope_last_call = False self.attn = RadixAttention( self.num_heads, @@ -386,6 +510,9 @@ def _apply_qk_norm( k_by_head = k.reshape(-1, self.head_dim) k_by_head = self.k_norm(k_by_head) current_stream.wait_stream(self.alt_stream) + q = q_by_head.view(q.shape) + k = k_by_head.view(k.shape) + return q, k else: q_by_head = q.reshape(-1, self.head_dim) q_by_head = self.q_norm(q_by_head) @@ -429,6 +556,7 @@ def forward_prepare_npu( q_bias=getattr(self.q_norm, "bias", None), k_bias=getattr(self.k_norm, "bias", None), ) + inner_state = q, k, v, forward_batch return None, forward_batch, inner_state @@ -439,26 +567,61 @@ def forward_prepare_native( forward_batch: ForwardBatch, ): qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self._apply_qk_norm(q, k) - q, k = self.rotary_emb( - positions, - q, - k, - fused_set_kv_buffer_arg=( - create_fused_set_kv_buffer_arg( - value=v, - layer=self.attn, - forward_batch=forward_batch, - ) - if enable_fused_set_kv_buffer(forward_batch) - and self.compatible_with_fused_kv_buffer - else None - ), - ) + + q, k, v = self.apply_qk_norm_rope(qkv, positions, forward_batch) + inner_state = q, k, v, forward_batch return None, forward_batch, inner_state + def apply_qk_norm_rope(self, qkv, positions, forward_batch): + use_fused = self.use_fused_qk_norm_rope and qkv.dtype == torch.bfloat16 + if use_fused: + theta = getattr(self.config, "rope_theta", 10000.0) + positions = ( + positions.view(-1).to(dtype=torch.int32, device=qkv.device).contiguous() + ) + factor, low, high, attention_factor = compute_yarn_parameters(self.config) + fused_qk_norm_rope( + qkv, + self.num_heads, + self.num_kv_heads, + self.num_kv_heads, + self.head_dim, + self.q_norm.variance_epsilon, + self.q_norm.weight, + self.k_norm.weight, + theta, + self.rotary_emb.is_neox_style, + positions, + factor, + low, + high, + attention_factor, + ) + q, 
k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + self._used_fused_qk_norm_rope_last_call = True + else: + # Fallback to non-fused QK Norm & RoPE implementation + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb( + positions, + q, + k, + fused_set_kv_buffer_arg=( + create_fused_set_kv_buffer_arg( + value=v, + layer=self.attn, + forward_batch=forward_batch, + ) + if enable_fused_set_kv_buffer(forward_batch) + and self.compatible_with_fused_kv_buffer + else None + ), + ) + self._used_fused_qk_norm_rope_last_call = False + return q, k, v + def forward_prepare( self, positions: torch.Tensor, @@ -484,12 +647,20 @@ def forward_core(self, intermediate_state): hidden_states, forward_batch, inner_state = intermediate_state if inner_state is None: return hidden_states + + q, k, v, fb = inner_state + + must_save_kv = self._used_fused_qk_norm_rope_last_call + save_kv_cache = must_save_kv or not ( + enable_fused_set_kv_buffer(forward_batch) + and self.compatible_with_fused_kv_buffer + ) attn_output = self.attn( - *inner_state, - save_kv_cache=not ( - enable_fused_set_kv_buffer(forward_batch) - and self.compatible_with_fused_kv_buffer - ), + q, + k, + v, + fb, + save_kv_cache=save_kv_cache, ) output, _ = self.o_proj(attn_output) return output @@ -542,6 +713,7 @@ def __init__( head_dim=head_dim, rms_norm_eps=rms_norm_eps, attention_bias=attention_bias, + config=config, quant_config=quant_config, prefix=add_prefix("self_attn", prefix), dual_chunk_attention_config=dual_chunk_attention_config, diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index ddb3fba439a0..65807dcbff37 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1923,7 +1923,7 @@ def get_compiler_backend( if hasattr(torch, "npu") and torch.npu.is_available(): if mode == "piecewise": - from 
sglang.srt.model_executor.compilation.piecewise_npu_graph_compiler_backend import ( + from sglang.srt.hardware_backend.npu.graph_runner.compilation.piecewise_npu_graph_compiler_backend import ( PiecewiseNpuGraphCompilerBackend, ) From 3ac87be62d808b569481484de5e0a3c1019140bb Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Sun, 7 Dec 2025 21:34:26 +0300 Subject: [PATCH 43/71] Merge fixes: 1) updated cache prefetch support 2) ModelWeightParameter & ChannelQuantScaleParameter support --- .../custom_all_reduce_ops.py | 90 +------ python/sglang/srt/hardware_backend/npu/cmo.py | 8 + .../hardware_backend/npu/cmo_custom_ops.py | 39 +++ .../npu/quantization/linear_method_npu.py | 50 ++-- python/sglang/srt/layers/communicator.py | 13 +- python/sglang/srt/models/qwen3.py | 35 +-- python/sglang/srt/models/qwen3_moe.py | 225 ++---------------- 7 files changed, 136 insertions(+), 324 deletions(-) create mode 100644 python/sglang/srt/hardware_backend/npu/cmo_custom_ops.py diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_ops.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_ops.py index 3b4c53a7f4fd..3d7e3a56b795 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_ops.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_ops.py @@ -4,101 +4,13 @@ import torch -from sglang.srt.utils import direct_register_custom_op, is_cuda, is_hip +from sglang.srt.utils import is_cuda, is_hip logger = logging.getLogger(__name__) _is_cuda = is_cuda() _is_hip = is_hip() -import sgl_kernel_npu.norm.split_qkv_rmsnorm_rope - -import sglang.srt.utils - - -@torch.library.custom_op("sglang::split_qkv_rmsnorm_rope", mutates_args=()) -def split_qkv_rmsnorm_rope( - input: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - q_weight: torch.Tensor, - k_weight: torch.Tensor, - q_hiddent_size: int, - kv_hidden_size: int, - head_dim: int, - eps: float, - q_bias: torch.Tensor, - k_bias: 
torch.Tensor, -) -> List[torch.Tensor]: - q, k, v = sgl_kernel_npu.norm.split_qkv_rmsnorm_rope.split_qkv_rmsnorm_rope( - input, - sin, - cos, - q_weight, - k_weight, - q_hiddent_size, - kv_hidden_size, - head_dim, - eps, - q_bias, - k_bias, - ) - return [q, k, v] - - -@split_qkv_rmsnorm_rope.register_fake -def split_qkv_rmsnorm_rope_fake( - input: torch.Tensor, - sin: torch.Tensor, - cos: torch.Tensor, - q_weight: torch.Tensor, - k_weight: torch.Tensor, - q_hiddent_size: int, - kv_hidden_size: int, - head_dim: int, - eps: float, - q_bias: torch.Tensor, - k_bias: torch.Tensor, -) -> List[torch.Tensor]: - q = torch.empty((128, 4096), dtype=input.dtype, device=input.device) - k = torch.empty((128, 512), dtype=input.dtype, device=input.device) - v = torch.empty((128, 512), dtype=input.dtype, device=input.device) - return [q, k, v] - - -@torch.library.custom_op("sglang::wait_cmo_stream", mutates_args=()) -def wait_cmo_stream() -> None: - if sglang.srt.utils.get_cmo_stream(): - sglang.srt.utils.wait_cmo_stream() - - -@wait_cmo_stream.register_fake -def wait_cmo_stream_fake() -> None: - pass - - -def get_cmo_stream() -> bool: - return True - - -def prepare_weight_cache(handle: torch.Tensor, cache: List[torch.Tensor]) -> None: - sglang.srt.utils.prepare_weight_cache(handle, cache) - - -def prepare_weight_cache_register_fake( - handle: torch.Tensor, cache: List[torch.Tensor] -) -> None: - pass - - -direct_register_custom_op( - op_name="prepare_weight_cache", - op_func=prepare_weight_cache, - mutates_args=["handle"], - fake_impl=prepare_weight_cache_register_fake, -) - - IS_CUSTOM_AR_AVAILABLE = _is_cuda or _is_hip IS_QUICK_AR_AVAILABLE = _is_hip # TODO(zyksir): mscclpp is untested on AMD and therefore disabled. 
diff --git a/python/sglang/srt/hardware_backend/npu/cmo.py b/python/sglang/srt/hardware_backend/npu/cmo.py index 40f3b4f1696b..7b8d117735bc 100644 --- a/python/sglang/srt/hardware_backend/npu/cmo.py +++ b/python/sglang/srt/hardware_backend/npu/cmo.py @@ -1,5 +1,7 @@ import torch +from sglang.srt.layers.parameter import ModelWeightParameter + cmo_stream = None @@ -18,6 +20,12 @@ def set_cmo_stream(stream): cmo_stream = stream +def get_weight_cache(layer): + if isinstance(layer.weight, ModelWeightParameter): + return layer.weight_data + return layer.weight + + def prepare_weight_cache(handle, cache, PREFETCH_MAX_SIZE=1000000000): """ PREFETCH_MAX_SIZE: maximum size (bytes) for each prefetch operation. diff --git a/python/sglang/srt/hardware_backend/npu/cmo_custom_ops.py b/python/sglang/srt/hardware_backend/npu/cmo_custom_ops.py new file mode 100644 index 000000000000..0331ce4d7b2f --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/cmo_custom_ops.py @@ -0,0 +1,39 @@ +from typing import List + +import torch + +import sglang.srt.hardware_backend.npu.cmo +from sglang.srt.utils import direct_register_custom_op + + +@torch.library.custom_op("sglang::wait_cmo_stream", mutates_args=()) +def wait_cmo_stream() -> None: + if sglang.srt.hardware_backend.npu.cmo.get_cmo_stream(): + sglang.srt.hardware_backend.npu.cmo.wait_cmo_stream() + + +@wait_cmo_stream.register_fake +def wait_cmo_stream_fake() -> None: + pass + + +def get_cmo_stream() -> bool: + return True + + +def prepare_weight_cache(handle: torch.Tensor, cache: List[torch.Tensor]) -> None: + sglang.srt.hardware_backend.npu.cmo.prepare_weight_cache(handle, cache) + + +def prepare_weight_cache_register_fake( + handle: torch.Tensor, cache: List[torch.Tensor] +) -> None: + pass + + +direct_register_custom_op( + op_name="prepare_weight_cache", + op_func=prepare_weight_cache, + mutates_args=["handle"], + fake_impl=prepare_weight_cache_register_fake, +) diff --git 
a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py index 06a91fedd55d..fcbdc3d1e49b 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py @@ -41,10 +41,13 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") output_size_per_partition = sum(output_partition_sizes) + weight_data = torch.empty( + (output_size_per_partition, input_size_per_partition), dtype=torch.int8 + ) + layer.__dict__["weight_data"] = weight_data + weight = ModelWeightParameter( - data=torch.empty( - (output_size_per_partition, input_size_per_partition), dtype=torch.int8 - ), + data=weight_data, input_dim=1, output_dim=0, weight_loader=weight_loader, @@ -79,8 +82,11 @@ def create_weights( input_offset.ignore_warning = True layer.register_parameter("input_offset", input_offset) + quant_bias_data = torch.empty(output_size_per_partition, dtype=torch.int32) + layer.__dict__["quant_bias_data"] = quant_bias_data + quant_bias = ChannelQuantScaleParameter( - data=torch.empty(output_size_per_partition, dtype=torch.int32), + data=quant_bias_data, output_dim=0, weight_loader=weight_loader, ) @@ -92,8 +98,12 @@ def create_weights( deq_scale_dtype = torch.int64 else: raise ValueError(f"Unsupported params_dtype: {params_dtype}") + + deq_scale_data = torch.empty(output_size_per_partition, dtype=deq_scale_dtype) + layer.__dict__["deq_scale_data"] = deq_scale_data + deq_scale = ChannelQuantScaleParameter( - data=torch.empty(output_size_per_partition, dtype=deq_scale_dtype), + data=deq_scale_data, output_dim=0, weight_loader=weight_loader, ) @@ -118,7 +128,7 @@ def apply( x = torch.ops.npu.npu_quantize( x, - layer.aclnn_input_scale_reciprocal, + aclnn_input_scale_reciprocal, layer.aclnn_input_offset, torch.qint8, -1, @@ -129,11 +139,11 @@ def apply( if isinstance(layer, 
RowParallelLinear) and layer.tp_rank > 0: quant_bias = None else: - quant_bias = layer.quant_bias + quant_bias = layer.quant_bias_data return torch.ops.npu.npu_quant_matmul( x, - layer.weight, - layer.deq_scale, + layer.weight_data, + layer.deq_scale_data, bias=quant_bias, output_dtype=layer.params_dtype, ) @@ -141,6 +151,7 @@ def apply( def process_weights_after_loading(self, layer: torch.nn.Module): layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = npu_format_cast(layer.weight.data) + layer.weight_data = layer.weight.data layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data) @@ -184,18 +195,26 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") output_size_per_partition = sum(output_partition_sizes) + weight_data = torch.empty( + (output_size_per_partition, input_size_per_partition), dtype=torch.int8 + ) + layer.__dict__["weight_data"] = weight_data + weight = ModelWeightParameter( - data=torch.empty( - (output_size_per_partition, input_size_per_partition), dtype=torch.int8 - ), + data=weight_data, input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) + weight_scale_data = torch.empty( + (output_size_per_partition, 1), dtype=params_dtype + ) + layer.__dict__["weight_scale_data"] = weight_scale_data + weight_scale = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), dtype=params_dtype), + data=weight_scale_data, output_dim=0, weight_loader=weight_loader, ) @@ -218,8 +237,8 @@ def apply( quant_out, dynamic_scale = torch.ops.npu.npu_dynamic_quant(x) return torch.ops.npu.npu_quant_matmul( quant_out, - layer.weight, - layer.weight_scale, + layer.weight_data, + layer.weight_scale_data, pertoken_scale=dynamic_scale, bias=bias, output_dtype=original_dtype, @@ -228,6 +247,7 @@ def apply( def process_weights_after_loading(self, layer: torch.nn.Module): 
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = npu_format_cast(layer.weight.data) + layer.weight_data = layer.weight.data layer.weight_scale.data = layer.weight_scale.data.flatten() layer.weight_offset.data = layer.weight_offset.data.flatten() diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 5d03cd84b035..0df8b3f80e72 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -1,4 +1,4 @@ -# Copyright 2023-2025 SGLang Team +# Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -73,8 +73,11 @@ from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant -elif _is_npu: - from sglang.srt.hardware_backend.npu.cmo import prepare_weight_cache + +if _is_npu: + from sglang.srt.hardware_backend.npu.cmo_custom_ops import ( # noqa + prepare_weight_cache, + ) FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048 @@ -780,7 +783,9 @@ def _gather_hidden_states_and_residual( else: hidden_states = tensor_model_parallel_all_reduce(hidden_states) if _is_npu and context.cache is not None: - torch.ops.sglang.prepare_weight_cache(hidden_states, context.cache) + _ = torch.ops.sglang.prepare_weight_cache( + hidden_states, context.cache + ) hidden_states, residual = layernorm(hidden_states, residual) return hidden_states, residual diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 33a13c7da2b7..6163dded80fe 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -32,25 +32,29 @@ from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import add_prefix, is_cuda, is_npu, supports_custom_op -if ( - is_npu() - and supports_custom_op() 
- and ( - get_global_server_args().enable_torch_compile - or get_global_server_args().enable_piecewise_npu_graph_decode - ) -): - from python.sglang.srt.distributed.device_communicators.custom_all_reduce_ops import get_cmo_stream, wait_cmo_stream -else: - from sglang.srt.utils import get_cmo_stream, wait_cmo_stream -from sglang.srt.utils import add_prefix, is_cuda, is_npu - Qwen3Config = None logger = logging.getLogger(__name__) _is_cuda = is_cuda() _is_npu = is_npu() +if _is_npu: + if supports_custom_op() and ( + get_global_server_args().enable_torch_compile + or get_global_server_args().enable_piecewise_npu_graph_decode + ): + from sglang.srt.hardware_backend.npu.cmo import get_weight_cache + from sglang.srt.hardware_backend.npu.cmo_custom_ops import ( + get_cmo_stream, + wait_cmo_stream, + ) + else: + from sglang.srt.hardware_backend.npu.cmo import ( + get_cmo_stream, + get_weight_cache, + wait_cmo_stream, + ) + class Qwen3Attention(nn.Module): def __init__( @@ -282,7 +286,10 @@ def forward( residual, forward_batch, cache=( - [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight] + [ + get_weight_cache(self.mlp.gate_up_proj), + get_weight_cache(self.mlp.down_proj), + ] if _is_npu else None ), diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 9737ac7197a8..3fbe81257290 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -18,12 +18,10 @@ """Inference-only Qwen3MoE model compatible with HuggingFace weights.""" import logging -import math -from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn -from transformers import PretrainedConfig from sglang.srt.distributed import ( get_moe_expert_parallel_world_size, @@ -75,13 +73,6 @@ is_npu, ) -_is_cuda = is_cuda() - -if _is_cuda: - from sgl_kernel import fused_qk_norm_rope - -TConfig = TypeVar("TConfig", 
bound=PretrainedConfig) - Qwen3MoeConfig = None _is_flashinfer_available = is_flashinfer_available() @@ -94,118 +85,6 @@ from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope -def compute_yarn_parameters( - config: PretrainedConfig, -) -> tuple[float, float, float, float]: - """ - Refer to https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L197C1-L288C1 - Computes the inverse frequencies with NTK scaling. Please refer to the - [original paper](https://huggingface.co/papers/2309.00071) - Args: - config ([`~transformers.PretrainedConfig`]): - The model configuration. - Returns: - factor: float, the scaling factor for the RoPE embeddings - low: float, the lower bound of the dimension range - high: float, the upper bound of the dimension range - attention_factor: float, the post-processing scaling factor applied to the computed cos/sin - """ - - # The config does not contain rope_scaling, which means the model is not using yarn - rope_scaling = getattr(config, "rope_scaling", None) - if rope_scaling is None: - return 1.0, 0, 0, 1.0 - - base = config.rope_theta - partial_rotary_factor = ( - config.partial_rotary_factor - if hasattr(config, "partial_rotary_factor") - else 1.0 - ) - head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) - dim = int(head_dim * partial_rotary_factor) - factor = getattr(rope_scaling, "factor", 1.0) - attention_factor = rope_scaling.get("attention_factor") - mscale = rope_scaling.get("mscale") - mscale_all_dim = rope_scaling.get("mscale_all_dim") - - if "original_max_position_embeddings" in rope_scaling: - original_max_position_embeddings = rope_scaling[ - "original_max_position_embeddings" - ] - factor = config.max_position_embeddings / original_max_position_embeddings - else: - original_max_position_embeddings = config.max_position_embeddings - - def get_mscale(scale, mscale=1): - if scale <= 1: - return 1.0 - return 0.1 * mscale * 
math.log(scale) + 1.0 - - # Sets the attention factor as suggested in the paper - if attention_factor is None: - if mscale and mscale_all_dim: - attention_factor = float( - get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dim) - ) - else: - attention_factor = get_mscale(factor) - - # Optional config options - # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) - beta_fast = rope_scaling.get("beta_fast") or 32 - beta_slow = rope_scaling.get("beta_slow") or 1 - - # Compute the inverse frequencies - def find_correction_dim(num_rotations, dim, base, max_position_embeddings): - """Inverse dimension formula to find the dimension based on the number of rotations""" - return ( - dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) - ) / (2 * math.log(base)) - - def find_correction_range( - low_rot, high_rot, dim, base, max_position_embeddings, truncate - ): - """Find dimension range bounds based on rotations""" - low = find_correction_dim(low_rot, dim, base, max_position_embeddings) - high = find_correction_dim(high_rot, dim, base, max_position_embeddings) - if truncate: - low = math.floor(low) - high = math.ceil(high) - return max(low, 0), min(high, dim - 1) - - truncate = rope_scaling.get("truncate", True) - low, high = find_correction_range( - beta_fast, beta_slow, dim, base, original_max_position_embeddings, truncate - ) - - # These parts are implemented in the fusedQKNormRopeKernel.cu - # # def linear_ramp_factor(min, max, dim): - # # if min == max: - # # max += 0.001 # Prevent singularity - - # # linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - # # ramp_func = torch.clamp(linear_func, 0, 1) - # # return ramp_func - - # # Note on variable naming: "interpolation" comes from the original technique, where we interpolate the position IDs - # # to expand the possible context length. In other words, interpolation = apply scaling factor. 
- # # pos_freqs = base ** (torch.arange(0, dim, 2).to(device=device, dtype=torch.float) / dim) - # # inv_freq_extrapolation = 1.0 / pos_freqs - # # inv_freq_interpolation = 1.0 / (factor * pos_freqs) - - # # # Get n-dimensional rotational scaling corrected for extrapolation - # # inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2).to(device=device, dtype=torch.float) - # # inv_freq = ( - # # inv_freq_interpolation * (1 - inv_freq_extrapolation_factor) - # # + inv_freq_extrapolation * inv_freq_extrapolation_factor - # # ) - # # return inv_freq, attention_factor - return factor, low, high, attention_factor - - class Qwen3MoeSparseMoeBlock(nn.Module): def __init__( self, @@ -407,7 +286,6 @@ def __init__( head_dim: Optional[int] = None, rms_norm_eps: float = 1e-06, attention_bias: bool = False, - config: Optional[TConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", dual_chunk_attention_config: Optional[dict[str, Any]] = None, @@ -419,7 +297,6 @@ def __init__( attn_tp_rank = get_attention_tp_rank() attn_tp_size = get_attention_tp_size() - self.config = config self.total_num_heads = num_heads assert self.total_num_heads % attn_tp_size == 0 self.num_heads = self.total_num_heads // attn_tp_size @@ -475,14 +352,6 @@ def __init__( self.compatible_with_fused_kv_buffer = ( False if isinstance(self.rotary_emb, MRotaryEmbedding) else True ) - self.compatible_with_fused_qk_norm_rope = ( - not isinstance(self.rotary_emb, MRotaryEmbedding) - ) and self.head_dim in (64, 128, 256) - self.use_fused_qk_norm_rope = ( - get_global_server_args().enable_fused_qk_norm_rope - and self.compatible_with_fused_qk_norm_rope - ) - self._used_fused_qk_norm_rope_last_call = False self.attn = RadixAttention( self.num_heads, @@ -510,9 +379,6 @@ def _apply_qk_norm( k_by_head = k.reshape(-1, self.head_dim) k_by_head = self.k_norm(k_by_head) current_stream.wait_stream(self.alt_stream) - q = q_by_head.view(q.shape) - k = k_by_head.view(k.shape) - 
return q, k else: q_by_head = q.reshape(-1, self.head_dim) q_by_head = self.q_norm(q_by_head) @@ -556,7 +422,6 @@ def forward_prepare_npu( q_bias=getattr(self.q_norm, "bias", None), k_bias=getattr(self.k_norm, "bias", None), ) - inner_state = q, k, v, forward_batch return None, forward_batch, inner_state @@ -567,61 +432,26 @@ def forward_prepare_native( forward_batch: ForwardBatch, ): qkv, _ = self.qkv_proj(hidden_states) - - q, k, v = self.apply_qk_norm_rope(qkv, positions, forward_batch) - + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb( + positions, + q, + k, + fused_set_kv_buffer_arg=( + create_fused_set_kv_buffer_arg( + value=v, + layer=self.attn, + forward_batch=forward_batch, + ) + if enable_fused_set_kv_buffer(forward_batch) + and self.compatible_with_fused_kv_buffer + else None + ), + ) inner_state = q, k, v, forward_batch return None, forward_batch, inner_state - def apply_qk_norm_rope(self, qkv, positions, forward_batch): - use_fused = self.use_fused_qk_norm_rope and qkv.dtype == torch.bfloat16 - if use_fused: - theta = getattr(self.config, "rope_theta", 10000.0) - positions = ( - positions.view(-1).to(dtype=torch.int32, device=qkv.device).contiguous() - ) - factor, low, high, attention_factor = compute_yarn_parameters(self.config) - fused_qk_norm_rope( - qkv, - self.num_heads, - self.num_kv_heads, - self.num_kv_heads, - self.head_dim, - self.q_norm.variance_epsilon, - self.q_norm.weight, - self.k_norm.weight, - theta, - self.rotary_emb.is_neox_style, - positions, - factor, - low, - high, - attention_factor, - ) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - self._used_fused_qk_norm_rope_last_call = True - else: - # Fallback to non-fused QK Norm & RoPE implementation - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self._apply_qk_norm(q, k) - q, k = self.rotary_emb( - positions, - q, - k, - 
fused_set_kv_buffer_arg=( - create_fused_set_kv_buffer_arg( - value=v, - layer=self.attn, - forward_batch=forward_batch, - ) - if enable_fused_set_kv_buffer(forward_batch) - and self.compatible_with_fused_kv_buffer - else None - ), - ) - self._used_fused_qk_norm_rope_last_call = False - return q, k, v - def forward_prepare( self, positions: torch.Tensor, @@ -647,20 +477,12 @@ def forward_core(self, intermediate_state): hidden_states, forward_batch, inner_state = intermediate_state if inner_state is None: return hidden_states - - q, k, v, fb = inner_state - - must_save_kv = self._used_fused_qk_norm_rope_last_call - save_kv_cache = must_save_kv or not ( - enable_fused_set_kv_buffer(forward_batch) - and self.compatible_with_fused_kv_buffer - ) attn_output = self.attn( - q, - k, - v, - fb, - save_kv_cache=save_kv_cache, + *inner_state, + save_kv_cache=not ( + enable_fused_set_kv_buffer(forward_batch) + and self.compatible_with_fused_kv_buffer + ), ) output, _ = self.o_proj(attn_output) return output @@ -713,7 +535,6 @@ def __init__( head_dim=head_dim, rms_norm_eps=rms_norm_eps, attention_bias=attention_bias, - config=config, quant_config=quant_config, prefix=add_prefix("self_attn", prefix), dual_chunk_attention_config=dual_chunk_attention_config, From 105050fd9c0a644b2d8b58025784405488209c55 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Wed, 10 Dec 2025 20:33:31 +0300 Subject: [PATCH 44/71] cleanup --- python/sglang/srt/server_args.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 653a0c801004..dc3c92571fd3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -3606,11 +3606,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Optimize the model with torch.compile. 
Experimental feature.", ) - parser.add_argument( - "--enable-torch-air-compile", - action="store_true", - help="Optimize the model with Torch Ascend Intermediate Representation compilation. Experimental feature.", - ) parser.add_argument( "--enable-torch-compile-debug-mode", action="store_true", From 90caee290f24e815daab2025ee909b88f21f3711 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Wed, 10 Dec 2025 20:33:31 +0300 Subject: [PATCH 45/71] Cleanup & ComilationConfig usage update & master merge refactoring --- .../compilation}/compilation_context.py | 0 .../graph_runner/compilation/custom_ops.py | 55 +++++++++++++++++++ .../compilation/npu_graph_backend.py | 4 +- .../compilation}/npu_graph_compiler.py | 11 ++-- .../npu_graph_compiler_backend.py | 10 +++- .../graph_runner/compilation}/pass_manager.py | 14 ++++- .../graph_runner/compilation}/passes/fp16.py | 29 +++++----- .../compilation}/passes/w8a8_int8.py | 0 .../piecewise_npu_graph_compiler.py | 18 +++--- .../piecewise_npu_graph_compiler_backend.py | 8 ++- .../graph_runner/npu_compile_model_runner.py | 9 ++- .../npu/graph_runner/npu_graph_runner.py | 10 ++-- .../piecewise_npu_graph_runner_decode.py | 12 ++-- .../sglang/srt/hardware_backend/npu/utils.py | 40 ++++++++++++++ .../sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/server_args.py | 6 ++ python/sglang/srt/utils/common.py | 50 +++++++++++------ 17 files changed, 208 insertions(+), 70 deletions(-) rename python/sglang/srt/{compilation/npu => hardware_backend/npu/graph_runner/compilation}/compilation_context.py (100%) create mode 100644 python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py rename python/sglang/srt/{compilation/npu => hardware_backend/npu/graph_runner/compilation}/npu_graph_compiler.py (83%) rename python/sglang/srt/{compilation/npu => hardware_backend/npu/graph_runner/compilation}/npu_graph_compiler_backend.py (89%) rename python/sglang/srt/{compilation/npu => 
hardware_backend/npu/graph_runner/compilation}/pass_manager.py (79%) rename python/sglang/srt/{compilation/npu => hardware_backend/npu/graph_runner/compilation}/passes/fp16.py (86%) rename python/sglang/srt/{compilation/npu => hardware_backend/npu/graph_runner/compilation}/passes/w8a8_int8.py (100%) diff --git a/python/sglang/srt/compilation/npu/compilation_context.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/compilation_context.py similarity index 100% rename from python/sglang/srt/compilation/npu/compilation_context.py rename to python/sglang/srt/hardware_backend/npu/graph_runner/compilation/compilation_context.py diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py new file mode 100644 index 000000000000..1c4dfb5473b1 --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py @@ -0,0 +1,55 @@ +from typing import List + +import sgl_kernel_npu.norm.split_qkv_rmsnorm_rope +import torch + + +@torch.library.custom_op("sglang::split_qkv_rmsnorm_rope", mutates_args=()) +def split_qkv_rmsnorm_rope( + input: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_hidden_size: int, + kv_hiddem_size: int, + head_dim: int, + eps: float, + q_bias: torch.Tensor, + k_bias: torch.Tensor, +) -> List[torch.Tensor]: + q, k, v = sgl_kernel_npu.norm.split_qkv_rmsnorm_rope.split_qkv_rmsnorm_rope( + input, + sin, + cos, + q_weight, + k_weight, + q_hidden_size, + kv_hiddem_size, + head_dim, + eps, + q_bias, + k_bias, + ) + return [q, k, v] + + +@split_qkv_rmsnorm_rope.register_fake +def split_qkv_rmsnorm_rope( + input: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_hidden_size: int, + kv_hiddem_size: int, + head_dim: int, + eps: float, + q_bias: torch.Tensor, + k_bias: torch.Tensor, +) -> 
List[torch.Tensor]: + # TODO: generalize shape + q = torch.empty((128, 2048), dtype=input.dtype, device=input.device) + k = torch.empty((128, 256), dtype=input.dtype, device=input.device) + v = torch.empty((128, 256), dtype=input.dtype, device=input.device) + return [q, k, v] diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_backend.py index 0481b0f6b163..07535210910c 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_backend.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_backend.py @@ -17,7 +17,9 @@ import torch import torch_npu -from sglang.srt.compilation.npu.compilation_context import CompilationContext +from sglang.srt.hardware_backend.npu.graph_runner.compilation.compilation_context import ( + CompilationContext, +) class NPUGraphBackend: diff --git a/python/sglang/srt/compilation/npu/npu_graph_compiler.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler.py similarity index 83% rename from python/sglang/srt/compilation/npu/npu_graph_compiler.py rename to python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler.py index dd4c059a8d10..fa8d9b5283dc 100644 --- a/python/sglang/srt/compilation/npu/npu_graph_compiler.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler.py @@ -27,15 +27,14 @@ def __init__( ): torch._dynamo.reset() + if compilation_config is None: + compilation_config = CompilationConfig(compiler="npugraph") + backend = get_compiler_backend( - ( - "npugraph_fused" - if compilation_config is None or compilation_config.compiler is None - else compilation_config.compiler - ), - model_runner.model_config.dtype, + compilation_config=compilation_config, model_runner=model_runner ) backend.init(model_runner.model_config) + self.compiled_callable = torch.compile( 
model, fullgraph=True, dynamic=False, backend=backend ) diff --git a/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py similarity index 89% rename from python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py rename to python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py index aca189cb8102..2ed6513b3088 100644 --- a/python/sglang/srt/compilation/npu/npu_graph_compiler_backend.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py @@ -16,9 +16,13 @@ import torch -from sglang.srt.compilation.npu.pass_manager import PassManager -from sglang.srt.compilation.npu.passes.fp16 import SplitQkvRmsnormRopeFuse -from sglang.srt.compilation.npu.passes.w8a8_int8 import ( +from sglang.srt.hardware_backend.npu.graph_runner.compilation.pass_manager import ( + PassManager, +) +from sglang.srt.hardware_backend.npu.graph_runner.compilation.passes.fp16 import ( + SplitQkvRmsnormRopeFuse, +) +from sglang.srt.hardware_backend.npu.graph_runner.compilation.passes.w8a8_int8 import ( DivFuse, EraseCopy, NpuAddRmsNormDynamicQuantFuse, diff --git a/python/sglang/srt/compilation/npu/pass_manager.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py similarity index 79% rename from python/sglang/srt/compilation/npu/pass_manager.py rename to python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py index 52318e04b131..ab1ecd806dae 100644 --- a/python/sglang/srt/compilation/npu/pass_manager.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py @@ -12,8 +12,12 @@ # limitations under the License. 
# ============================================================================== +import logging + import torch +logger = logging.getLogger(__name__) + class PassManager: def __init__(self, graph_module: torch.fx.GraphModule): @@ -35,9 +39,15 @@ def apply(self): results = torch.fx.replace_pattern( self.graph_module, pass_.pattern, pass_.replacement ) - except: + + logger.debug( + f"PassManager::apply: pass_instance={type(pass_instance)}: results({len(results)})={results}" + ) + except Exception as e: # pass was not applied - pass + logger.debug( + f"PassManager::apply: pass_instance={type(pass_instance)}: ignored={e}" + ) if not updated: updated = len(results) != 0 diff --git a/python/sglang/srt/compilation/npu/passes/fp16.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/fp16.py similarity index 86% rename from python/sglang/srt/compilation/npu/passes/fp16.py rename to python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/fp16.py index fade634948cd..1d32ca842e9b 100644 --- a/python/sglang/srt/compilation/npu/passes/fp16.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/fp16.py @@ -12,9 +12,10 @@ # limitations under the License. 
# ============================================================================== -import sgl_kernel_npu import torch +import sglang.srt.hardware_backend.npu.graph_runner.compilation.custom_ops # noqa + class SplitQkvRmsnormRopeFuse: instance = None @@ -106,20 +107,18 @@ def replacement( sin_view = sin.view(-1, 1, 1, self.head_dim) sin_contiguous = sin_view.contiguous() - split_qkv_rmsnorm_rope_default = ( - sgl_kernel_npu.norm.split_qkv_rmsnorm_rope.default( - output_parallel, - sin_contiguous, - cos_contiguous, - q_norm_parameters_weight, - k_norm_parameters_weight, - self.q_size, - self.kv_size, - self.head_dim, - self.variance_epsilon, - q_bias=None, - k_bias=None, - ) + split_qkv_rmsnorm_rope_default = torch.ops.sglang.split_qkv_rmsnorm_rope( + output_parallel, + sin_contiguous, + cos_contiguous, + q_norm_parameters_weight, + k_norm_parameters_weight, + self.q_size, + self.kv_size, + self.head_dim, + self.variance_epsilon, + q_bias=None, + k_bias=None, ) q = split_qkv_rmsnorm_rope_default[0] diff --git a/python/sglang/srt/compilation/npu/passes/w8a8_int8.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py similarity index 100% rename from python/sglang/srt/compilation/npu/passes/w8a8_int8.py rename to python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py index d4d72612d5c6..5477f0c4131b 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py @@ -18,7 +18,9 @@ import torch from sglang.srt.compilation.compilation_config import CompilationConfig -from sglang.srt.compilation.npu.compilation_context import CompilationContext +from 
sglang.srt.hardware_backend.npu.graph_runner.compilation.compilation_context import ( + CompilationContext, +) from sglang.srt.utils.common import get_compiler_backend @@ -30,15 +32,13 @@ def __init__( compilation_config: CompilationConfig, compilation_context: CompilationContext, ): + if compilation_config is None: + compilation_config = CompilationConfig(compiler="piecewise") + backend = get_compiler_backend( - ( - "piecewise" - if compilation_config.compiler is None - else compilation_config.compiler - ), - model_runner, - compilation_config, - compilation_context, + model_runner=model_runner, + compilation_config=compilation_config, + compilation_context=compilation_context, ) backend.init(model_runner.model_config) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py index 8031122b65d2..be5556b2d6eb 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py @@ -20,11 +20,13 @@ import torch from sglang.srt.compilation.compilation_config import CompilationConfig -from sglang.srt.compilation.npu.compilation_context import CompilationContext -from sglang.srt.compilation.npu.npu_graph_compiler_backend import ( +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.hardware_backend.npu.graph_runner.compilation.compilation_context import ( + CompilationContext, +) +from sglang.srt.hardware_backend.npu.graph_runner.compilation.npu_graph_compiler_backend import ( NpuGraphCompilerBackend, ) -from sglang.srt.distributed import get_tensor_model_parallel_world_size logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py 
b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py index 0f3a1a19b3b8..d3c1352dc4f6 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py @@ -37,17 +37,16 @@ ForwardBatch, PPProxyTensors, ) +from sglang.srt.server_args import get_global_server_args class NPUCompileModelRunner: def __init__(self, model_runner: ModelRunner): - print(f"NPUCompileModelRunner::__init__", flush=True) self.model_runner = model_runner _, self.compile_bs = get_batch_sizes_to_capture(model_runner) self.capture() def capture(self) -> None: - print(f"NPUCompileModelRunner::capture", flush=True) # Reverse the order to enable better memory sharing across cuda graphs. compile_range = ( tqdm.tqdm(list(reversed(self.compile_bs))) @@ -55,7 +54,11 @@ def capture(self) -> None: else reversed(self.compile_bs) ) - backend = get_compiler_backend("reduce-overhead") + backend = get_compiler_backend( + mode="reduce-overhead", + compilation_config=get_global_server_args().compilation_config, + ) + compile_forward = torch.compile( torch.no_grad()(self.model_runner.model.forward), fullgraph=True, diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py index a884b9812aa4..950f8c3871b0 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py @@ -48,7 +48,9 @@ _set_dp_buffer_len, _set_is_extend_in_batch, ) -from sglang.srt.compilation.npu.npu_graph_compiler import NpuGraphCompiler +from sglang.srt.hardware_backend.npu.graph_runner.compilation.npu_graph_compiler import ( + NpuGraphCompiler, +) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors @@ -108,9 
+110,9 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): if self.enable_torch_compile and (not self.compile_bs or bs in self.compile_bs): self.model_runner.attn_backend.enable_torch_compile = True compiler = NpuGraphCompiler( - self.model_runner, - run_once_fn, - get_global_server_args().compilation_config, + model_runner=self.model_runner, + model=run_once_fn, + compilation_config=get_global_server_args().compilation_config, ) # compilation diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py index 8ca9ace090bc..a885f08e9aaf 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py @@ -24,10 +24,12 @@ import tqdm from sglang.srt.compilation.compilation_config import CompilationConfig -from sglang.srt.compilation.npu.compilation_context import CompilationContext from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import graph_capture from sglang.srt.hardware_backend.npu.attention.ascend_backend import AscendAttnBackend +from sglang.srt.hardware_backend.npu.graph_runner.compilation.compilation_context import ( + CompilationContext, +) from sglang.srt.hardware_backend.npu.graph_runner.compilation.piecewise_npu_graph_compiler import ( PiecewiseNpuGraphCompiler, ) @@ -398,10 +400,10 @@ def capture_one_batch_size( self.model_runner.attn_backend.graph_mode = True compiler = PiecewiseNpuGraphCompiler( - self.model_runner, - self.model_runner.model, - self.compilation_config, - self.compilation_context, + model_runner=self.model_runner, + model=self.model_runner.model, + compilation_config=self.compilation_config, + compilation_context=self.compilation_context, ) logits_output_or_pp_proxy_tensors = 
compiler.compiled_callable( diff --git a/python/sglang/srt/hardware_backend/npu/utils.py b/python/sglang/srt/hardware_backend/npu/utils.py index 97c126db54a1..115e749be2b9 100644 --- a/python/sglang/srt/hardware_backend/npu/utils.py +++ b/python/sglang/srt/hardware_backend/npu/utils.py @@ -58,6 +58,46 @@ def set_default_server_args(args: "ServerArgs"): else: args.hicache_mem_layout = "page_first_direct" + if args.enable_piecewise_npu_graph_decode and args.enable_torch_air_compile: + raise ValueError( + "Cannot enable both --enable-piecewise-npu-graph-decode and --enable-torch-air-compile" + ) + + if args.compilation_config: + if args.compilation_config.compiler == "npugraph": + args.enable_torch_compile = True + + if args.disable_cuda_graph: + raise ValueError( + f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --disable-cuda-graph" + ) + + if args.enable_piecewise_npu_graph_decode: + raise ValueError( + f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-piecewise-npu-graph-decode" + ) + + if args.enable_torch_air_compile: + raise ValueError( + f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-torch-air-compile" + ) + + if args.compilation_config.compiler == "piecewise": + args.enable_piecewise_npu_graph_decode = True + + if args.enable_torch_air_compile: + raise ValueError( + f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-torch-air-compile" + ) + + if args.compilation_config.compiler == "torchair": + args.enable_torch_air_compile = True + + if args.enable_piecewise_npu_graph_decode: + raise ValueError( + f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-piecewise-npu-graph-decode" + ) + @_call_once def init_npu_backend(): diff --git a/python/sglang/srt/model_executor/model_runner.py 
b/python/sglang/srt/model_executor/model_runner.py index fd7ab8f2bf6c..baed481161d4 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2460,7 +2460,7 @@ def init_device_graphs(self): if self.server_args.enable_piecewise_npu_graph_decode else ( NPUCompileModelRunner - if self.server_args.disable_cuda_graph + if self.server_args.enable_torch_air_compile else NPUGraphRunner ) ), diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index dc3c92571fd3..7dd6d04a873e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -514,6 +514,7 @@ class ServerArgs: tbo_token_distribution_threshold: float = 0.48 enable_torch_compile: bool = False enable_piecewise_cuda_graph: bool = False + enable_torch_air_compile: bool = False enable_torch_compile_debug_mode: bool = False torch_compile_max_bs: int = 32 piecewise_cuda_graph_max_tokens: int = 4096 @@ -3611,6 +3612,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable debug mode for torch compile", ) + parser.add_argument( + "--enable-torch-air-compile", + action="store_true", + help="Optimize the model with Torch Ascend Intermediate Representation compilation. 
Experimental feature.", + ) parser.add_argument( "--enable-piecewise-cuda-graph", action="store_true", diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 73a7d0f31feb..a018750225f9 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1912,7 +1912,7 @@ def get_npu_compiler_config(): def get_compiler_backend( - mode=None, + mode: str = None, model_runner=None, compilation_config: CompilationConfig = None, compilation_context=None, @@ -1921,7 +1921,10 @@ def get_compiler_backend( return "hpu_backend" if hasattr(torch, "npu") and torch.npu.is_available(): - if mode == "piecewise": + if compilation_config is None: + compilation_config = CompilationConfig(compiler="torchair") + + if compilation_config.compiler == "piecewise": from sglang.srt.hardware_backend.npu.graph_runner.compilation.piecewise_npu_graph_compiler_backend import ( PiecewiseNpuGraphCompilerBackend, ) @@ -1930,27 +1933,38 @@ def get_compiler_backend( model_runner, compilation_config, compilation_context ) - if mode == "npugraph_fused": - from sglang.srt.compilation.npu.npu_graph_compiler_backend import ( + if compilation_config.compiler == "npugraph": + from sglang.srt.hardware_backend.npu.graph_runner.compilation.npu_graph_compiler_backend import ( NpuGraphCompilerBackend, ) return NpuGraphCompilerBackend(model_runner) - try: - import torchair - import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce - from torchair.configs.compiler_config import CompilerConfig - except ImportError as e: - raise ImportError( - "NPU detected, but torchair package is not installed. " - "Please install torchair for torch.compile support on NPU." 
- ) - compiler_config = CompilerConfig() - # TODO(iforgetmyname): Change this default value once torch_npu version 7.2.0 - compiler_config.mode = "max-autotune" if mode is None else mode - npu_backend = torchair.get_npu_backend(compiler_config=compiler_config) - return npu_backend + if compilation_config.compiler == "torchair": + try: + import torchair + import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce + from torchair.configs.compiler_config import CompilerConfig + except ImportError as e: + raise ImportError( + "NPU detected, but torchair package is not installed. " + "Please install torchair for torch.compile support on NPU." + ) + compiler_config = CompilerConfig() + + # TODO(iforgetmyname): Change this default value once torch_npu version 7.2.0 + # compiler_config.mode = "max-autotune" if mode is None else mode + + predefined_config = get_npu_compiler_config() + for k, v in predefined_config.items(): + setattr(compiler_config.experimental_config, k, v) + + npu_backend = torchair.get_npu_backend(compiler_config=compiler_config) + return npu_backend + + raise ValueError( + f"unrecognized compiler backend '{compilation_config.compiler}'" + ) return "inductor" From a5e87f6fb80109a9b6ecd35f8b569c72c06cc620 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 11 Dec 2025 16:57:22 +0300 Subject: [PATCH 46/71] Compilation backends: model type quick fix --- .../graph_runner/compilation/npu_graph_compiler_backend.py | 4 ++-- .../compilation/piecewise_npu_graph_compiler_backend.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py index 2ed6513b3088..816e696559a5 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py +++ 
b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py @@ -32,8 +32,8 @@ class NpuGraphCompilerBackend: - def __init__(self, model_type: torch.dtype): - self.model_type = model_type + def __init__(self, model_runner): + self.model_type = model_runner.model_config.dtype def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: if self.model_type == torch.bfloat16: diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py index be5556b2d6eb..bd0a98215ed8 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py @@ -138,7 +138,7 @@ def __init__( compilation_config: CompilationConfig, compilation_context: CompilationContext, ): - super().__init__(model_runner.model_config.dtype) + super().__init__(model_runner) self.model_runner = model_runner self.model_config = model_runner.model.config From daf81b289e7c0779a629b1a6864f01c849bb05bb Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Fri, 12 Dec 2025 00:35:52 +0300 Subject: [PATCH 47/71] TorchAir compilation backend: Ascend attention backend quick fix --- .../srt/hardware_backend/npu/attention/ascend_backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index 92c0195966f3..f1c5f1ea7b73 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -257,8 +257,7 @@ def __init__(self, model_runner: ModelRunner): self.ringmla_mask = self.ascend_attn_mask_builder.ringmla_mask 
self.enable_torch_air_compile = ( - model_runner.server_args.disable_cuda_graph - and model_runner.server_args.enable_torch_compile + model_runner.server_args.enable_torch_air_compile ) if self.enable_torch_air_compile: max_total_tokens = model_runner.max_total_num_tokens From ea25b3f528f148924cad9ee600d2801bfbd0fd2c Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Fri, 12 Dec 2025 09:25:32 +0300 Subject: [PATCH 48/71] torchair compilation test fix --- test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py index 2ecc97a95edf..433e82e4f9a0 100644 --- a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py +++ b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py @@ -39,10 +39,9 @@ def setUpClass(cls): "--attention-backend", "ascend", "--disable-radix-cache", - "--enable-torch-compile", + "--enable-torch-air-compile", "--watchdog-timeout", 30000, - "--disable-cuda-graph", ] def test_a_gsm8k(self): From 40389ddb4c7bac30f386828e271168977a0503e4 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Fri, 12 Dec 2025 20:22:44 +0300 Subject: [PATCH 49/71] Capturing compiled code issue: fix - dynamo patching --- .../compilation/npu_graph_compiler_backend.py | 4 ++ .../graph_runner/compilation/patch_dynamo.py | 54 +++++++++++++++++++ .../npu/graph_runner/npu_graph_runner.py | 33 ++++++++++-- .../piecewise_npu_graph_runner_decode.py | 27 ++++++++-- 4 files changed, 111 insertions(+), 7 deletions(-) create mode 100644 python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py index 816e696559a5..63e26bd4259a 100644 --- 
a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py @@ -15,6 +15,7 @@ from typing import Callable import torch +from torch._dynamo.eval_frame import DisableContext from sglang.srt.hardware_backend.npu.graph_runner.compilation.pass_manager import ( PassManager, @@ -36,6 +37,9 @@ def __init__(self, model_runner): self.model_type = model_runner.model_config.dtype def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: + DisableContext.compiled_function_args[DisableContext.batch_size] = ( + example_inputs + ) if self.model_type == torch.bfloat16: self.apply_passes(graph) return graph diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py new file mode 100644 index 000000000000..a5883bf8d4f3 --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py @@ -0,0 +1,54 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import annotations + +import torch +from torch._dynamo.decorators import skip +from torch._dynamo.eval_frame import DisableContext, innermost_fn + + +def patch_dynamo_context(): + setattr(torch._dynamo.eval_frame.DisableContext, "compiled_function_args", {}) + setattr(torch._dynamo.eval_frame.DisableContext, "compiled_function", {}) + setattr(torch._dynamo.eval_frame.DisableContext, "batch_size", None) + + +original_disable_context_call = None +original_disable = None + + +def decorators_disable(fn=None, recursive=True): + if recursive: + if fn is not None: + fn = innermost_fn(fn) + assert callable(fn) + + DisableContext.compiled_function[DisableContext.batch_size] = fn + return DisableContext()(fn) + return DisableContext() + else: + return skip(fn) + + +def patch_dynamo_context_call(): + global original_disable + original_disable = torch._dynamo.decorators.disable + torch._dynamo.decorators.disable = decorators_disable + + +def restore_dynamo_context_call(): + global original_disable + torch._dynamo.decorators.disable = original_disable + original_disable = None \ No newline at end of file diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py index 950f8c3871b0..de1b320622c2 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py @@ -44,6 +44,8 @@ if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner +from torch._dynamo.eval_frame import DisableContext + from sglang.srt.compilation.custom_ops import ( _set_dp_buffer_len, _set_is_extend_in_batch, @@ -51,6 +53,11 @@ from sglang.srt.hardware_backend.npu.graph_runner.compilation.npu_graph_compiler import ( NpuGraphCompiler, ) +from 
sglang.srt.hardware_backend.npu.graph_runner.compilation.patch_dynamo import ( + patch_dynamo_context, + patch_dynamo_context_call, + restore_dynamo_context_call, +) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors @@ -69,6 +76,9 @@ class NPUGraphRunner(CudaGraphRunner): """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile.""" def __init__(self, model_runner: ModelRunner): + if model_runner.server_args.enable_torch_compile: + patch_dynamo_context() + sglang.srt.model_executor.cuda_graph_runner.patch_model = patch_model_npu model_runner.attn_backend.enable_torch_compile = ( model_runner.server_args.enable_torch_compile @@ -115,8 +125,25 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): compilation_config=get_global_server_args().compilation_config, ) - # compilation - out = compiler.compiled_callable() + patch_dynamo_context_call() + DisableContext.batch_size = bs + try: + # compilation + out = compiler.compiled_callable() + + # capture function and args + out = compiler.compiled_callable() + finally: + DisableContext.batch_size = None + restore_dynamo_context_call() + + assert bs in DisableContext.compiled_function + assert DisableContext.compiled_function[bs] + assert bs in DisableContext.compiled_function_args + assert DisableContext.compiled_function_args[bs] + + compiled_function = DisableContext.compiled_function[bs] + args = DisableContext.compiled_function_args[bs] with torch.npu.graph( graph, @@ -124,7 +151,7 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): stream=stream, auto_dispatch_capture=True, ): - compiler.compiled_callable() + compiled_function(*args) else: self.model_runner.attn_backend.enable_torch_compile = False diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py 
b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py index a885f08e9aaf..4e4e2085aff4 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py @@ -22,6 +22,7 @@ import torch import torch._dynamo.config import tqdm +from torch._dynamo.eval_frame import DisableContext from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.distributed import get_tensor_model_parallel_rank @@ -30,6 +31,11 @@ from sglang.srt.hardware_backend.npu.graph_runner.compilation.compilation_context import ( CompilationContext, ) +from sglang.srt.hardware_backend.npu.graph_runner.compilation.patch_dynamo import ( + patch_dynamo_context, + patch_dynamo_context_call, + restore_dynamo_context_call, +) from sglang.srt.hardware_backend.npu.graph_runner.compilation.piecewise_npu_graph_compiler import ( PiecewiseNpuGraphCompiler, ) @@ -87,6 +93,7 @@ class PiecewiseNPUGraphRunnerDecode: def __init__(self, model_runner: ModelRunner): model_runner.attn_backend.enable_piecewise_npu_graph_decode = True + patch_dynamo_context() self.inference_counter = 1 self.init_forward_metadata_was_done = True @@ -406,6 +413,9 @@ def capture_one_batch_size( compilation_context=self.compilation_context, ) + patch_dynamo_context_call() + DisableContext.batch_size = bs + logits_output_or_pp_proxy_tensors = compiler.compiled_callable( forward_batch.input_ids, forward_batch.positions, forward_batch ) @@ -414,6 +424,17 @@ def capture_one_batch_size( bs, forward_batch, None, compiler.compiled_callable ) + try: + logits_output_or_pp_proxy_tensors = compiler.compiled_callable( + forward_batch.input_ids, forward_batch.positions, forward_batch + ) + finally: + DisableContext.batch_size = None + restore_dynamo_context_call() + + assert DisableContext.compiled_function + assert DisableContext.compiled_function_args + 
torch._dynamo.reset() gc.collect() @@ -531,10 +552,8 @@ def init(): self.model_runner.attn_backend.graph_mode = True - compiled_graph = self.graphs[self.bs] - forward_batch = compiled_graph.forward_batch - compiled_graph.callable( - forward_batch.input_ids, forward_batch.positions, forward_batch + DisableContext.compiled_function[self.bs]( + *DisableContext.compiled_function_args[self.bs] ) output = self.output_buffers[self.bs] From f5424d84782259425a8a79fb5dc420724bdf1b0b Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Fri, 12 Dec 2025 20:23:08 +0300 Subject: [PATCH 50/71] comments fix --- .../npu/attention/ascend_backend.py | 10 ++++----- .../graph_runner/compilation/patch_dynamo.py | 2 +- .../sglang/srt/hardware_backend/npu/utils.py | 16 +++++++------- .../sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/server_args.py | 4 ++-- python/sglang/srt/utils/common.py | 22 +++++++++++++------ .../test_ascend_compile_graph_tp1_bf16.py | 2 +- 7 files changed, 33 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index 23f3ccfd3656..037f646f7e7c 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -258,10 +258,10 @@ def __init__(self, model_runner: ModelRunner): if self.use_mla: self.ringmla_mask = self.ascend_attn_mask_builder.ringmla_mask - self.enable_torch_air_compile = ( - model_runner.server_args.enable_torch_air_compile + self.enable_torch_npugraph_ex_compile = ( + model_runner.server_args.enable_torch_npugraph_ex_compile ) - if self.enable_torch_air_compile: + if self.enable_torch_npugraph_ex_compile: max_total_tokens = model_runner.max_total_num_tokens self.max_seqlen_pad = max_total_tokens // model_runner.server_args.page_size @@ -293,7 +293,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) if 
( - self.enable_torch_air_compile + self.enable_torch_npugraph_ex_compile and forward_batch.forward_mode.is_decode_or_idle() ): bs = forward_batch.input_ids.size(0) @@ -1245,7 +1245,7 @@ def forward_decode( block_table=self.forward_metadata.block_tables, actual_seq_lengths_kv=( self.forward_metadata.seq_lens_cpu_list - if self.enable_torch_air_compile + if self.enable_torch_npugraph_ex_compile else self.forward_metadata.seq_lens_cpu_int ), scale=layer.scaling, diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py index a5883bf8d4f3..284582f86011 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py @@ -51,4 +51,4 @@ def patch_dynamo_context_call(): def restore_dynamo_context_call(): global original_disable torch._dynamo.decorators.disable = original_disable - original_disable = None \ No newline at end of file + original_disable = None diff --git a/python/sglang/srt/hardware_backend/npu/utils.py b/python/sglang/srt/hardware_backend/npu/utils.py index 6817067758ca..889ea513d7fd 100644 --- a/python/sglang/srt/hardware_backend/npu/utils.py +++ b/python/sglang/srt/hardware_backend/npu/utils.py @@ -58,9 +58,9 @@ def set_default_server_args(args: "ServerArgs"): else: args.hicache_mem_layout = "page_first_direct" - if args.enable_piecewise_npu_graph_decode and args.enable_torch_air_compile: + if args.enable_piecewise_npu_graph_decode and args.enable_torch_npugraph_ex_compile: raise ValueError( - "Cannot enable both --enable-piecewise-npu-graph-decode and --enable-torch-air-compile" + "Cannot enable both --enable-piecewise-npu-graph-decode and --enable-torch-npugraph-ex-compile" ) if args.compilation_config: @@ -77,21 +77,21 @@ def set_default_server_args(args: "ServerArgs"): f"compilation_config.compiler 
'{args.compilation_config.compiler}' is not appropriate for --enable-piecewise-npu-graph-decode" ) - if args.enable_torch_air_compile: + if args.enable_torch_npugraph_ex_compile: raise ValueError( - f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-torch-air-compile" + f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-torch-npugraph-ex-compile" ) if args.compilation_config.compiler == "piecewise": args.enable_piecewise_npu_graph_decode = True - if args.enable_torch_air_compile: + if args.enable_torch_npugraph_ex_compile: raise ValueError( - f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-torch-air-compile" + f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-torch-npugraph-ex-compile" ) - if args.compilation_config.compiler == "torchair": - args.enable_torch_air_compile = True + if args.compilation_config.compiler == "npugraph_ex": + args.enable_torch_npugraph_ex_compile = True if args.enable_piecewise_npu_graph_decode: raise ValueError( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3ecdf3d553da..f2add2e196d7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2511,7 +2511,7 @@ def init_device_graphs(self): if self.server_args.enable_piecewise_npu_graph_decode else ( NPUCompileModelRunner - if self.server_args.enable_torch_air_compile + if self.server_args.enable_torch_npugraph_ex_compile else NPUGraphRunner ) ), diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 7bc1d9b92634..f268f630c5c7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -535,7 +535,7 @@ class ServerArgs: tbo_token_distribution_threshold: float = 0.48 enable_torch_compile: bool = False 
enable_piecewise_cuda_graph: bool = False - enable_torch_air_compile: bool = False + enable_torch_npugraph_ex_compile: bool = False enable_torch_compile_debug_mode: bool = False torch_compile_max_bs: int = 32 piecewise_cuda_graph_max_tokens: int = 4096 @@ -3804,7 +3804,7 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Enable debug mode for torch compile", ) parser.add_argument( - "--enable-torch-air-compile", + "--enable-torch-npugraph-ex-compile", action="store_true", help="Optimize the model with Torch Ascend Intermediate Representation compilation. Experimental feature.", ) diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 0bab47c4da68..6a91478868c9 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -104,6 +104,17 @@ from sglang.srt.server_args import ServerArgs +if hasattr(torch, "npu") and torch.npu.is_available(): + try: + import torchair + import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce + from torchair.configs.compiler_config import CompilerConfig + + torchair_package_installed = True + except ImportError as e: + torchair_package_installed = False + + logger = logging.getLogger(__name__) show_time_cost = False @@ -1955,6 +1966,7 @@ def get_npu_compiler_config(): return config +@lru_cache(maxsize=1) def get_compiler_backend( mode: str = None, model_runner=None, @@ -1966,7 +1978,7 @@ def get_compiler_backend( if hasattr(torch, "npu") and torch.npu.is_available(): if compilation_config is None: - compilation_config = CompilationConfig(compiler="torchair") + compilation_config = CompilationConfig(compiler="npugraph_ex") if compilation_config.compiler == "piecewise": from sglang.srt.hardware_backend.npu.graph_runner.compilation.piecewise_npu_graph_compiler_backend import ( @@ -1984,12 +1996,8 @@ def get_compiler_backend( return NpuGraphCompilerBackend(model_runner) - if compilation_config.compiler == "torchair": - try: - import torchair - 
import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce - from torchair.configs.compiler_config import CompilerConfig - except ImportError as e: + if compilation_config.compiler == "npugraph_ex": + if not torchair_package_installed: raise ImportError( "NPU detected, but torchair package is not installed. " "Please install torchair for torch.compile support on NPU." diff --git a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py index 433e82e4f9a0..e14969d4449a 100644 --- a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py +++ b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py @@ -39,7 +39,7 @@ def setUpClass(cls): "--attention-backend", "ascend", "--disable-radix-cache", - "--enable-torch-air-compile", + "--enable-torch-npugraph-ex-compile", "--watchdog-timeout", 30000, ] From 55a1e069b858ce3b77bc13bda1498d5d5a881d25 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Fri, 12 Dec 2025 20:16:38 +0300 Subject: [PATCH 51/71] Documentation --- docs/platforms/ascend_npu_pass_development.md | 30 +++++++++++++++++++ docs/platforms/ascend_npu_support.rst | 1 + 2 files changed, 31 insertions(+) create mode 100644 docs/platforms/ascend_npu_pass_development.md diff --git a/docs/platforms/ascend_npu_pass_development.md b/docs/platforms/ascend_npu_pass_development.md new file mode 100644 index 000000000000..251d10c85e98 --- /dev/null +++ b/docs/platforms/ascend_npu_pass_development.md @@ -0,0 +1,30 @@ +## How to transform model instances with PyTorch FX Toolkit in SGLang for NPU + +### PassManager +`PassManager` is implemented here: [PassManager](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py) + + +You can explore `PassManager` usage in 
[`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) compiler backend. [`PiecewiseNpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py) compiler backend uses `PassManager` too via `NpuGraphCompilerBackend` inheritance. + +### Pass development +There are two approaches to develop passes for the SGLang NPU PassManager: + +1. Match all possible non-overlapping sets of operators and their data dependencies with the `torch.fx.replace_pattern` API. +Pass example: [NpuAddRmsNormQuantFuse](https://github.com/eshoguli/sglang/blob/3365d711fd5aa0d6191c32769163320fe41e27f2/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py#L82). +You can find details on the official FX toolkit website: https://docs.pytorch.org/docs/stable/fx.html#subgraph-rewriting-with-replace-pattern + +2. Direct Graph Manipulation. +Pass example: [EraseCopy](https://github.com/eshoguli/sglang/blob/3365d711fd5aa0d6191c32769163320fe41e27f2/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py#L28). +You can find details on the official FX toolkit website: https://docs.pytorch.org/docs/stable/fx.html#direct-graph-manipulation + +### Compiler backend update +After pass development you should create a `PassManager` instance, add the pass and call the `apply` method: +``` +def apply_passes(self, graph_module: torch.fx.GraphModule): + passManager = PassManager(graph_module) + passManager.add(NpuAddRmsNormQuantFuse) + passManager.apply() + graph_module.recompile() +``` + +You can explore [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) as an example. 
diff --git a/docs/platforms/ascend_npu_support.rst b/docs/platforms/ascend_npu_support.rst index 1437515f8acc..09a42c07f214 100644 --- a/docs/platforms/ascend_npu_support.rst +++ b/docs/platforms/ascend_npu_support.rst @@ -6,4 +6,5 @@ Ascend NPUs ascend_npu.md ascend_npu_deepseek_example.md + ascend_npu_pass_development.md ascend_npu_qwen3_examples.md From fd28ac6db2c4d797ff02e5d3f31a09f4831027c9 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Mon, 15 Dec 2025 10:13:46 +0300 Subject: [PATCH 52/71] cleanup & fuse quick fix: compilation & piecewise --- .../srt/hardware_backend/npu/attention/ascend_backend.py | 1 - .../npu/graph_runner/compilation/npu_graph_compiler.py | 3 ++- .../graph_runner/compilation/npu_graph_compiler_backend.py | 6 +++--- .../compilation/piecewise_npu_graph_compiler.py | 3 ++- .../hardware_backend/npu/graph_runner/npu_graph_runner.py | 1 + .../npu/graph_runner/piecewise_npu_graph_runner_decode.py | 1 + 6 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index 037f646f7e7c..26a660c6ef77 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -222,7 +222,6 @@ class AscendAttnBackend(AttentionBackend): def __init__(self, model_runner: ModelRunner): super().__init__() - self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.enable_piecewise_npu_graph_decode = ( model_runner.server_args.enable_piecewise_npu_graph_decode ) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler.py index fa8d9b5283dc..f28059f56a82 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler.py +++ 
b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler.py @@ -24,6 +24,7 @@ def __init__( model_runner, model: torch.nn.Module, compilation_config: CompilationConfig, + batch_size: int, ): torch._dynamo.reset() @@ -33,7 +34,7 @@ def __init__( backend = get_compiler_backend( compilation_config=compilation_config, model_runner=model_runner ) - backend.init(model_runner.model_config) + backend.init(model_runner.model_config, batch_size) self.compiled_callable = torch.compile( model, fullgraph=True, dynamic=False, backend=backend diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py index 63e26bd4259a..79da1b9278fd 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py @@ -44,7 +44,7 @@ def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: self.apply_passes(graph) return graph - def init(self, config): + def init(self, config, batch_size: int): config = config.hf_config hidden_size = config.hidden_size @@ -67,8 +67,8 @@ def init(self, config): self.q_size = num_heads * self.head_dim self.kv_size = num_kv_heads * self.head_dim - self.q_shape = (self.head_dim, self.q_size) - self.k_shape = (self.head_dim, self.kv_size) + self.q_shape = (batch_size, self.q_size) + self.k_shape = (batch_size, self.kv_size) def apply_passes(self, graph_module: torch.fx.GraphModule): passManager = PassManager(graph_module) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py index 5477f0c4131b..e46e3c88c573 100644 --- 
a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py @@ -31,6 +31,7 @@ def __init__( model: torch.nn.Module, compilation_config: CompilationConfig, compilation_context: CompilationContext, + batch_size: int, ): if compilation_config is None: compilation_config = CompilationConfig(compiler="piecewise") @@ -40,7 +41,7 @@ def __init__( compilation_config=compilation_config, compilation_context=compilation_context, ) - backend.init(model_runner.model_config) + backend.init(model_runner.model_config, batch_size) torch._dynamo.reset() torch.compiler.allow_in_graph(sys.intern) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py index de1b320622c2..1e18346f4999 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py @@ -123,6 +123,7 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): model_runner=self.model_runner, model=run_once_fn, compilation_config=get_global_server_args().compilation_config, + batch_size=bs, ) patch_dynamo_context_call() diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py index 4e4e2085aff4..ddae98181cf1 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py @@ -411,6 +411,7 @@ def capture_one_batch_size( model=self.model_runner.model, compilation_config=self.compilation_config, compilation_context=self.compilation_context, + batch_size=bs, ) patch_dynamo_context_call() From 
97d654e9735fd64384e7e532a3c5ff49325d46d0 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Mon, 15 Dec 2025 14:41:39 +0300 Subject: [PATCH 53/71] TorchAir support: inference fix & refactoring --- .../graph_runner/npu_compile_model_runner.py | 17 ++++++++++++----- python/sglang/srt/hardware_backend/npu/utils.py | 3 +++ python/sglang/srt/utils/common.py | 6 +----- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py index d3c1352dc4f6..9941a29d5cb2 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py @@ -31,7 +31,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.model_executor.cuda_graph_runner import get_batch_sizes_to_capture +from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, @@ -40,11 +40,9 @@ from sglang.srt.server_args import get_global_server_args -class NPUCompileModelRunner: +class NPUCompileModelRunner(CudaGraphRunner): def __init__(self, model_runner: ModelRunner): - self.model_runner = model_runner - _, self.compile_bs = get_batch_sizes_to_capture(model_runner) - self.capture() + super().__init__(model_runner) def capture(self) -> None: # Reverse the order to enable better memory sharing across cuda graphs. 
@@ -54,6 +52,15 @@ def capture(self) -> None: else reversed(self.compile_bs) ) + # warm up before dynamic shape compilation + bs = 1 + num_tokens = bs * self.num_tokens_per_bs + forward_batch = self.prepare_forward_batch(bs, num_tokens) + forward_batch.attn_backend.init_forward_metadata(forward_batch) + self.model_runner.model.forward( + forward_batch.input_ids, forward_batch.positions, forward_batch + ) + backend = get_compiler_backend( mode="reduce-overhead", compilation_config=get_global_server_args().compilation_config, diff --git a/python/sglang/srt/hardware_backend/npu/utils.py b/python/sglang/srt/hardware_backend/npu/utils.py index 889ea513d7fd..f6471901d3c3 100644 --- a/python/sglang/srt/hardware_backend/npu/utils.py +++ b/python/sglang/srt/hardware_backend/npu/utils.py @@ -98,6 +98,9 @@ def set_default_server_args(args: "ServerArgs"): f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-piecewise-npu-graph-decode" ) + if args.enable_torch_npugraph_ex_compile: + args.enable_torch_compile = True + @_call_once def init_npu_backend(): diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 6a91478868c9..25254ec246a7 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2005,11 +2005,7 @@ def get_compiler_backend( compiler_config = CompilerConfig() # TODO(iforgetmyname): Change this default value once torch_npu version 7.2.0 - # compiler_config.mode = "max-autotune" if mode is None else mode - - predefined_config = get_npu_compiler_config() - for k, v in predefined_config.items(): - setattr(compiler_config.experimental_config, k, v) + compiler_config.mode = "max-autotune" if mode is None else mode npu_backend = torchair.get_npu_backend(compiler_config=compiler_config) return npu_backend From 3ce92e8f9039c59493da6702a2a87a2d251b444a Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Mon, 15 Dec 2025 21:14:44 +0300 Subject: [PATCH 54/71] Comment 
fixes + refactoring --- .../compilation/npu_graph_compiler_backend.py | 6 +- .../{passes/w8a8_int8.py => passes.py} | 113 ++++++++++++++++ .../graph_runner/compilation/passes/fp16.py | 128 ------------------ python/sglang/srt/utils/common.py | 2 - .../test_ascend_npu_graph_compile_tp1_bf16.py | 2 - ...est_ascend_npu_piecewise_graph_tp1_bf16.py | 2 - 6 files changed, 115 insertions(+), 138 deletions(-) rename python/sglang/srt/hardware_backend/npu/graph_runner/compilation/{passes/w8a8_int8.py => passes.py} (56%) delete mode 100644 python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/fp16.py diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py index 79da1b9278fd..9e9c12214c68 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py @@ -20,14 +20,12 @@ from sglang.srt.hardware_backend.npu.graph_runner.compilation.pass_manager import ( PassManager, ) -from sglang.srt.hardware_backend.npu.graph_runner.compilation.passes.fp16 import ( - SplitQkvRmsnormRopeFuse, -) -from sglang.srt.hardware_backend.npu.graph_runner.compilation.passes.w8a8_int8 import ( +from sglang.srt.hardware_backend.npu.graph_runner.compilation.passes import ( DivFuse, EraseCopy, NpuAddRmsNormDynamicQuantFuse, NpuAddRmsNormQuantFuse, + SplitQkvRmsnormRopeFuse, ) from sglang.srt.layers.dp_attention import get_attention_tp_size diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes.py similarity index 56% rename from python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py rename to 
python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes.py index 20c8dbfa6100..3c9efb93a6d0 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes.py @@ -14,6 +14,8 @@ import torch +import sglang.srt.hardware_backend.npu.graph_runner.compilation.custom_ops # noqa + class DivFuse: def pattern(x): @@ -122,3 +124,114 @@ def replacement(rms_norm_input, residual, rms_norm_weight): out2 = output[2] dynamic_scale = output[3] return quantized_output, out2, dynamic_scale + + +class SplitQkvRmsnormRopeFuse: + instance = None + + def __init__( + self, + q_size: int, + kv_size: int, + head_dim: int, + q_shape, + k_shape, + variance_epsilon: float, + ): + self.q_size = q_size + self.kv_size = kv_size + self.head_dim = head_dim + self.q_shape = q_shape + self.k_shape = k_shape + self.variance_epsilon = variance_epsilon + + SplitQkvRmsnormRopeFuse.instance = self + + def pattern( + output_parallel, + q_norm_parameters_weight, + k_norm_parameters_weight, + positions, + cos_sin_cache, + ): + # pattern matching brokes if make static method as class method + self = SplitQkvRmsnormRopeFuse.instance + + split = output_parallel.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = split[0] + k = split[1] + v = split[2] + + q_by_head = q.reshape(-1, self.head_dim) + npu_rms_norm_q = torch.ops.npu.npu_rms_norm( + q_by_head, q_norm_parameters_weight, self.variance_epsilon + ) + q_by_head_1 = npu_rms_norm_q[0] + + k_by_head = k.reshape(-1, self.head_dim) + npu_rms_norm_k = torch.ops.npu.npu_rms_norm( + k_by_head, k_norm_parameters_weight, self.variance_epsilon + ) + k_by_head_1 = npu_rms_norm_k[0] + + q_1 = q_by_head_1.view(self.q_shape) + k_1 = k_by_head_1.view(self.k_shape) + + npu_mrope = torch.ops.npu.npu_mrope( + positions, + q_1, + k_1, + cos_sin_cache, + self.head_dim, + mrope_section=[0, 0, 0], + rotary_mode="half", + ) + query_out = 
npu_mrope[0] + key_out = npu_mrope[1] + + return v, query_out, key_out + + def replacement( + output_parallel, + q_norm_parameters_weight, + k_norm_parameters_weight, + positions, + cos_sin_cache, + ): + # pattern matching brokes if make static method as class method + self = SplitQkvRmsnormRopeFuse.instance + + flatten = positions.flatten() + cos_sin = cos_sin_cache.index_select(0, flatten) + + reshape = cos_sin.reshape(-1, 2, 64) + repeat = reshape.repeat(1, 1, 2) + chunk = repeat.chunk(2, dim=-2) + cos = chunk[0] + sin = chunk[1] + + cos_view = cos.view(-1, 1, 1, self.head_dim) + cos_contiguous = cos_view.contiguous() + + sin_view = sin.view(-1, 1, 1, self.head_dim) + sin_contiguous = sin_view.contiguous() + + split_qkv_rmsnorm_rope_default = torch.ops.sglang.split_qkv_rmsnorm_rope( + output_parallel, + sin_contiguous, + cos_contiguous, + q_norm_parameters_weight, + k_norm_parameters_weight, + self.q_size, + self.kv_size, + self.head_dim, + self.variance_epsilon, + q_bias=None, + k_bias=None, + ) + + q = split_qkv_rmsnorm_rope_default[0] + k = split_qkv_rmsnorm_rope_default[1] + v = split_qkv_rmsnorm_rope_default[2] + + return v, q, k diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/fp16.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/fp16.py deleted file mode 100644 index 1d32ca842e9b..000000000000 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/fp16.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import torch - -import sglang.srt.hardware_backend.npu.graph_runner.compilation.custom_ops # noqa - - -class SplitQkvRmsnormRopeFuse: - instance = None - - def __init__( - self, - q_size: int, - kv_size: int, - head_dim: int, - q_shape, - k_shape, - variance_epsilon: float, - ): - self.q_size = q_size - self.kv_size = kv_size - self.head_dim = head_dim - self.q_shape = q_shape - self.k_shape = k_shape - self.variance_epsilon = variance_epsilon - - SplitQkvRmsnormRopeFuse.instance = self - - def pattern( - output_parallel, - q_norm_parameters_weight, - k_norm_parameters_weight, - positions, - cos_sin_cache, - ): - # pattern matching brokes if make static method as class method - self = SplitQkvRmsnormRopeFuse.instance - - split = output_parallel.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = split[0] - k = split[1] - v = split[2] - - q_by_head = q.reshape(-1, self.head_dim) - npu_rms_norm_q = torch.ops.npu.npu_rms_norm( - q_by_head, q_norm_parameters_weight, self.variance_epsilon - ) - q_by_head_1 = npu_rms_norm_q[0] - - k_by_head = k.reshape(-1, self.head_dim) - npu_rms_norm_k = torch.ops.npu.npu_rms_norm( - k_by_head, k_norm_parameters_weight, self.variance_epsilon - ) - k_by_head_1 = npu_rms_norm_k[0] - - q_1 = q_by_head_1.view(self.q_shape) - k_1 = k_by_head_1.view(self.k_shape) - - npu_mrope = torch.ops.npu.npu_mrope( - positions, - q_1, - k_1, - cos_sin_cache, - self.head_dim, - mrope_section=[0, 0, 0], - rotary_mode="half", - ) - query_out = npu_mrope[0] - key_out = npu_mrope[1] - - return v, query_out, key_out - - def replacement( - output_parallel, - q_norm_parameters_weight, - k_norm_parameters_weight, - positions, - cos_sin_cache, - ): - # pattern matching brokes if make static method as class method - self = SplitQkvRmsnormRopeFuse.instance - - flatten 
= positions.flatten() - cos_sin = cos_sin_cache.index_select(0, flatten) - - reshape = cos_sin.reshape(-1, 2, 64) - repeat = reshape.repeat(1, 1, 2) - chunk = repeat.chunk(2, dim=-2) - cos = chunk[0] - sin = chunk[1] - - cos_view = cos.view(-1, 1, 1, self.head_dim) - cos_contiguous = cos_view.contiguous() - - sin_view = sin.view(-1, 1, 1, self.head_dim) - sin_contiguous = sin_view.contiguous() - - split_qkv_rmsnorm_rope_default = torch.ops.sglang.split_qkv_rmsnorm_rope( - output_parallel, - sin_contiguous, - cos_contiguous, - q_norm_parameters_weight, - k_norm_parameters_weight, - self.q_size, - self.kv_size, - self.head_dim, - self.variance_epsilon, - q_bias=None, - k_bias=None, - ) - - q = split_qkv_rmsnorm_rope_default[0] - k = split_qkv_rmsnorm_rope_default[1] - v = split_qkv_rmsnorm_rope_default[2] - - return v, q, k diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 25254ec246a7..615ccb2341d6 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2003,8 +2003,6 @@ def get_compiler_backend( "Please install torchair for torch.compile support on NPU." 
) compiler_config = CompilerConfig() - - # TODO(iforgetmyname): Change this default value once torch_npu version 7.2.0 compiler_config.mode = "max-autotune" if mode is None else mode npu_backend = torchair.get_npu_backend(compiler_config=compiler_config) diff --git a/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py b/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py index 85d1c2e83fbf..03dd0cdf80f7 100644 --- a/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py +++ b/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py @@ -32,8 +32,6 @@ def test_gsm8k(self): "--enable-torch-compile", "--cuda-graph-bs", "128", - "--cuda-graph-max-bs", - "128", "--tp-size", "1", ], diff --git a/test/srt/ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py index 6a1eb57ea870..e3fdd84969ed 100644 --- a/test/srt/ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py +++ b/test/srt/ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py @@ -32,8 +32,6 @@ def test_gsm8k(self): "--enable-piecewise-npu-graph-decode", "--cuda-graph-bs", "128", - "--cuda-graph-max-bs", - "128", "--tp-size", "1", ], From f4dfef3b50da5e200d90ee7e18bb0f9b668d4ddf Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 16 Dec 2025 09:08:25 +0300 Subject: [PATCH 55/71] Piecewise Graph Runner refactoring --- .../piecewise_npu_graph_runner_decode.py | 192 ++---------------- 1 file changed, 14 insertions(+), 178 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py index ddae98181cf1..3afeb0c6b1ee 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py @@ -40,16 +40,17 @@ PiecewiseNpuGraphCompiler, ) from sglang.srt.layers.logits_processor import 
LogitsProcessorOutput -from sglang.srt.model_executor.cuda_graph_runner import get_batch_sizes_to_capture +from sglang.srt.model_executor.cuda_graph_runner import ( + CudaGraphRunner, +) from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, - ForwardMode, PPProxyTensors, enable_num_token_non_padded, ) from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import get_available_gpu_memory, rank0_log +from sglang.srt.utils import get_available_gpu_memory torch._dynamo.config.skip_nnmodule_hook_guards = True torch._dynamo.config.automatic_dynamic_shapes = False @@ -63,7 +64,6 @@ if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.model_executor.cuda_graph_runner import model_capture_mode torch.cuda.CUDAGraph = torch.npu.NPUGraph torch.cuda.synchronize = torch.npu.synchronize @@ -88,13 +88,12 @@ def __init__( self.callable = callable -class PiecewiseNPUGraphRunnerDecode: +class PiecewiseNPUGraphRunnerDecode(CudaGraphRunner): """A PiecewiseNPUGraphRunnerDecode runs the forward pass of a model with npu graph and torch.compile.""" def __init__(self, model_runner: ModelRunner): model_runner.attn_backend.enable_piecewise_npu_graph_decode = True patch_dynamo_context() - self.inference_counter = 1 self.init_forward_metadata_was_done = True # Parse args @@ -110,143 +109,16 @@ def __init__(self, model_runner: ModelRunner): self.graphs = {} self.output_buffers = {} self.enable_torch_compile = model_runner.server_args.enable_torch_compile - self.disable_padding = model_runner.server_args.disable_cuda_graph_padding - self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder self.enable_dp_attention = model_runner.server_args.enable_dp_attention - # self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm - self.enable_two_batch_overlap = ( - model_runner.server_args.enable_two_batch_overlap - ) - self.speculative_algorithm = 
model_runner.server_args.speculative_algorithm - self.tp_size = model_runner.server_args.tp_size - self.dp_size = model_runner.server_args.dp_size - self.pp_size = model_runner.server_args.pp_size - - # Batch sizes to capture - self.capture_bs, _ = get_batch_sizes_to_capture(model_runner) - rank0_log(f"Capture npu graph bs {self.capture_bs}") - self.capture_forward_mode: int = ForwardMode.DECODE - self.capture_hidden_mode: int = CaptureHiddenMode.NULL - self.num_tokens_per_bs = 1 - if model_runner.spec_algorithm.is_eagle(): - if self.model_runner.is_draft_worker: - raise RuntimeError("This should not happen") - else: - self.capture_forward_mode = ForwardMode.TARGET_VERIFY - self.num_tokens_per_bs = ( - self.model_runner.server_args.speculative_num_draft_tokens - ) - - # Attention backend - self.max_bs = max(self.capture_bs) - self.max_num_token = self.max_bs * self.num_tokens_per_bs - self.model_runner.attn_backend.init_cuda_graph_state( - self.max_bs, self.max_num_token - ) - self.seq_len_fill_value = ( - self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() - ) - # FIXME(lsyin): leave it here for now, I don't know whether it is necessary - self.encoder_len_fill_value = 0 - self.seq_lens_cpu = torch.full( - (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 - ) # Graph inputs with torch.device(self.model_runner.device): - self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) - self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) - self.seq_lens = torch.full( - (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 - ) - self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32) - self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) - self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32) - self.block_tables = torch.full((160, 160), 0, dtype=torch.int32) - # pipeline parallelism - if 
self.pp_size > 1: - self.pp_proxy_tensors = { - "hidden_states": torch.zeros( - (self.max_bs, self.model_runner.model_config.hidden_size), - dtype=torch.bfloat16, - ), - "residual": torch.zeros( - (self.max_bs, self.model_runner.model_config.hidden_size), - dtype=torch.bfloat16, - ), - } - - # Speculative_inference - if model_runner.spec_algorithm.is_eagle3(): - self.model_runner.model.set_eagle3_layers_to_capture() - - if self.is_encoder_decoder: - # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch - self.encoder_lens = torch.full( - (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32 - ) - else: - self.encoder_lens = None - - if self.enable_dp_attention: # or self.enable_sp_layernorm: - # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer - self.gathered_buffer = torch.zeros( - ( - self.max_bs * self.dp_size * self.num_tokens_per_bs, - self.model_runner.model_config.hidden_size, - ), - dtype=self.model_runner.dtype, - ) - self.global_num_tokens_gpu = torch.zeros( - (self.dp_size,), dtype=torch.int32 - ) - - try: - with model_capture_mode(): - self.capture() - except RuntimeError as e: - raise Exception( - f"Graph compilation failed: {e}\n{NPU_GRAPH_CAPTURE_FAILED_MSG}" - ) - - def can_run(self, forward_batch: ForwardBatch): - if self.enable_dp_attention: # or self.enable_sp_layernorm: - total_global_tokens = sum(forward_batch.global_num_tokens_cpu) - - is_bs_supported = forward_batch.can_run_dp_cuda_graph and ( - total_global_tokens in self.graphs - if self.disable_padding - else total_global_tokens <= self.max_bs - ) - else: - is_bs_supported = ( - forward_batch.batch_size in self.graphs - if self.disable_padding - else forward_batch.batch_size <= self.max_bs - ) - - # NOTE: npu graph cannot handle mixed batch (encoder_len = 0) - # If mixed batch cannot be supported, then encoder_lens can be removed in npu graph - # because the full_text_row_masked_out_mask tensor will 
always be ones - is_encoder_lens_supported = ( - torch.all(forward_batch.encoder_lens > 0) - if self.is_encoder_decoder - else True - ) - - is_tbo_supported = ( - forward_batch.can_run_tbo if self.enable_two_batch_overlap else True - ) - - can_run_value = ( - is_bs_supported and is_encoder_lens_supported and is_tbo_supported - ) - return can_run_value + super().__init__(model_runner) - def capture(self, forward_batch_: ForwardBatch = None, bs_: int = None): + def capture(self, forward_batch_: ForwardBatch = None, bs_: int = None) -> None: with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream @@ -450,10 +322,14 @@ def replay_prepare( raw_num_token = raw_bs * self.num_tokens_per_bs # Pad - if self.enable_dp_attention: - index = bisect.bisect_left( - self.capture_bs, sum(forward_batch.global_num_tokens_cpu) + if self.require_mlp_tp_gather: + max_num_tokens = max(forward_batch.global_num_tokens_cpu) + max_batch_size = ( + max_num_tokens / self.num_tokens_per_bs + if self.model_runner.spec_algorithm.is_eagle() + else max_num_tokens ) + index = bisect.bisect_left(self.capture_bs, max_batch_size) else: index = bisect.bisect_left(self.capture_bs, raw_bs) @@ -575,43 +451,3 @@ def init(): ) return result - - def get_spec_info(self, num_tokens: int): - spec_info = None - if self.model_runner.spec_algorithm.is_eagle(): - from sglang.srt.speculative.eagle_utils import EagleVerifyInput - - if self.model_runner.is_draft_worker: - raise RuntimeError("This should not happen.") - else: - spec_info = EagleVerifyInput( - draft_token=None, - custom_mask=torch.ones( - (num_tokens * self.model_runner.model_config.context_len), - dtype=torch.bool, - device=self.model_runner.device, - ), - positions=None, - retrive_index=None, - retrive_next_token=None, - retrive_next_sibling=None, - retrive_cum_len=None, - spec_steps=self.model_runner.server_args.speculative_num_steps, - topk=self.model_runner.server_args.speculative_eagle_topk, - 
draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, - capture_hidden_mode=CaptureHiddenMode.FULL, - seq_lens_sum=None, - seq_lens_cpu=None, - ) - - return spec_info - - -NPU_GRAPH_CAPTURE_FAILED_MSG = ( - "Possible solutions:\n" - "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" - "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n" - "3. disable torch compile by not using --enable-torch-compile\n" - "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n" - "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" -) From 123e36cf743e0a2ae19520c96fe3f99590aa9c91 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 16 Dec 2025 09:32:11 +0300 Subject: [PATCH 56/71] PiecewiseGraph runner quick fix --- .../piecewise_npu_graph_runner_decode.py | 46 +++++-------------- 1 file changed, 11 insertions(+), 35 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py index 3afeb0c6b1ee..a535995b4b4a 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py @@ -40,14 +40,11 @@ PiecewiseNpuGraphCompiler, ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.model_executor.cuda_graph_runner import ( - CudaGraphRunner, -) +from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, PPProxyTensors, - enable_num_token_non_padded, ) from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import get_available_gpu_memory @@ -109,7 +106,6 @@ def __init__(self, model_runner: ModelRunner): self.graphs = {} self.output_buffers = 
{} self.enable_torch_compile = model_runner.server_args.enable_torch_compile - self.enable_dp_attention = model_runner.server_args.enable_dp_attention # Graph inputs with torch.device(self.model_runner.device): @@ -118,6 +114,14 @@ def __init__(self, model_runner: ModelRunner): super().__init__(model_runner) + def can_run(self, forward_batch: ForwardBatch): + return ( + (self.pp_size <= 1) + and (not self.is_encoder_decoder) + and (not self.enable_two_batch_overlap) + and super().can_run(forward_batch) + ) + def capture(self, forward_batch_: ForwardBatch = None, bs_: int = None) -> None: with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream @@ -179,15 +183,6 @@ def init_forward_batch( input_ids = torch.zeros((bs,), dtype=torch.int64) mrope_positions = torch.zeros((3, self.max_num_token), dtype=torch.int64) - assert self.is_encoder_decoder == False - encoder_lens = None - num_token_non_padded = None - - assert self.pp_size <= 1 - assert self.enable_dp_attention == False - global_num_tokens = None - gathered_buffer = None - spec_info = self.get_spec_info(num_tokens) if self.capture_hidden_mode != CaptureHiddenMode.FULL: self.capture_hidden_mode = ( @@ -205,10 +200,10 @@ def init_forward_batch( attn_backend=attn_backend, out_cache_loc=out_cache_loc, seq_lens_sum=seq_lens.sum(), - encoder_lens=encoder_lens, + encoder_lens=None, return_logprob=False, positions=positions, - global_num_tokens_gpu=global_num_tokens, + global_num_tokens_gpu=None, mrope_positions=mrope_positions, spec_algorithm=self.model_runner.spec_algorithm, spec_info=spec_info, @@ -235,13 +230,6 @@ def init_forward_batch( forward_batch.req_pool_indices[i] = 1 forward_batch.seq_lens_sum = sum(forward_batch.seq_lens) - if self.enable_dp_attention: # or self.enable_sp_layernorm: - assert False - assert self.pp_size <= 1 - assert self.enable_dp_attention == False - assert enable_num_token_non_padded(self.model_runner.server_args) == False - assert 
self.enable_two_batch_overlap == False - attn_backend.init_forward_metadata(forward_batch) self.init_forward_metadata_attn_backend(bs, attn_backend, forward_batch) @@ -372,23 +360,11 @@ def replay_prepare( dim = pp_proxy_tensors[key].shape[0] self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key]) - if self.is_encoder_decoder: - assert False - if forward_batch.mrope_positions is not None: compiled_graph.forward_batch.mrope_positions[:, :raw_num_token].copy_( forward_batch.mrope_positions ) - if self.enable_dp_attention: - assert False - - if enable_num_token_non_padded(self.model_runner.server_args): - assert False - - if self.enable_two_batch_overlap: - assert False - # Store fields self.raw_bs = raw_bs self.raw_num_token = raw_num_token From ebcc846d5d3d7cf8da7e0505ab6191de0a33be35 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 16 Dec 2025 17:31:26 +0300 Subject: [PATCH 57/71] linter fixes --- .../hardware_backend/npu/attention/ascend_backend.py | 4 +--- .../npu/graph_runner/npu_compile_model_runner.py | 1 - .../npu/graph_runner/npu_graph_runner.py | 12 +++++++++--- python/sglang/srt/utils/common.py | 3 ++- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index 98797a6cc37d..3cc1f866cfb1 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -267,9 +267,7 @@ def __init__(self, model_runner: ModelRunner): if self.use_mla: self.ringmla_mask = self.ascend_attn_mask_builder.ringmla_mask - self.enable_torchair_compile = ( - model_runner.server_args.enable_torchair_compile - ) + self.enable_torchair_compile = model_runner.server_args.enable_torchair_compile if self.enable_torchair_compile: max_total_tokens = model_runner.max_total_num_tokens self.max_seqlen_pad = max_total_tokens // 
model_runner.server_args.page_size diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py index e627beaf5264..f44ffd21545a 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py @@ -37,7 +37,6 @@ ForwardBatch, PPProxyTensors, ) -from sglang.srt.server_args import get_global_server_args class NPUCompileModelRunner(CudaGraphRunner): diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py index 1b2b7dc6f4c3..28c07f0a6ae9 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py @@ -25,8 +25,8 @@ import numpy as np import torch -import sglang.srt.model_executor.cuda_graph_runner import sglang +import sglang.srt.model_executor.cuda_graph_runner from sglang.srt.configs.model_config import AttentionArch, is_deepseek_nsa from sglang.srt.distributed.parallel_state import GroupCoordinator from sglang.srt.layers.dp_attention import get_attention_tp_size @@ -136,7 +136,11 @@ def _init_dp_gathered_buffer( def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): compilation_config = get_global_server_args().compilation_config - if self.enable_torch_compile and (not self.compile_bs or bs in self.compile_bs) and (compilation_config.compiler != "npugraph_ex"): + if ( + self.enable_torch_compile + and (not self.compile_bs or bs in self.compile_bs) + and (compilation_config.compiler != "npugraph_ex") + ): self.model_runner.attn_backend.enable_torch_compile = True compiler = NpuGraphCompiler( model_runner=self.model_runner, @@ -177,7 +181,9 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): 
self.model_runner.attn_backend.enable_torch_compile = False if self.enable_torch_compile: - skip_guard_context = torch.compiler.set_stance(skip_guard_eval_unsafe=True) + skip_guard_context = torch.compiler.set_stance( + skip_guard_eval_unsafe=True + ) else: skip_guard_context = empty_context() diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index b0f6f47c07d6..c993dd4da693 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1944,6 +1944,7 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]: return major, minor + @lru_cache(maxsize=1) def get_compiler_backend( mode: str = None, @@ -1965,7 +1966,7 @@ def get_compiler_backend( compiler_config.mode = "max-autotune" if mode is None else mode npu_backend = torchair.get_npu_backend(compiler_config=compiler_config) return npu_backend - + if compilation_config.compiler == "npugraph_ex": if not torchair_package_installed: raise ImportError( From 132581a0290e273dfbf971d6e10d25fb68e041aa Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 16 Dec 2025 19:36:15 +0300 Subject: [PATCH 58/71] fix after merge torchair --- .../npu/attention/ascend_backend.py | 3 +- .../graph_runner/npu_compile_model_runner.py | 257 ------------------ .../npu/graph_runner/npu_graph_runner.py | 14 +- .../sglang/srt/hardware_backend/npu/utils.py | 29 ++ .../sglang/srt/model_executor/model_runner.py | 9 +- .../test_ascend_compile_graph_tp1_bf16.py | 2 +- 6 files changed, 40 insertions(+), 274 deletions(-) delete mode 100644 python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index 3cc1f866cfb1..ee83104843cb 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -251,7 +251,6 @@ def 
__init__(self, model_runner: ModelRunner): self.req_to_token = model_runner.req_to_token_pool.req_to_token self.graph_mode = False self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False") - self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.speculative_num_draft_tokens = ( model_runner.server_args.speculative_num_draft_tokens ) @@ -1292,7 +1291,7 @@ def forward_decode( topk_indices, ) - if self.graph_mode and (not self.enable_torch_compile): + if self.graph_mode and (not self.enable_torchair_compile): return self.forward_decode_graph( q, k, diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py deleted file mode 100644 index f44ffd21545a..000000000000 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_compile_model_runner.py +++ /dev/null @@ -1,257 +0,0 @@ -# Copyright 2023-2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Run the model with torch air backend""" - -from __future__ import annotations - -import inspect -import logging -from typing import TYPE_CHECKING, Callable, Optional, Union - -import torch -import tqdm - -from sglang.srt.distributed import get_tensor_model_parallel_rank -from sglang.srt.utils import get_available_gpu_memory, get_compiler_backend - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from sglang.srt.model_executor.model_runner import ModelRunner - -from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner -from sglang.srt.model_executor.forward_batch_info import ( - CaptureHiddenMode, - ForwardBatch, - PPProxyTensors, -) - - -class NPUCompileModelRunner(CudaGraphRunner): - def __init__(self, model_runner: ModelRunner): - super().__init__(model_runner) - - def capture(self) -> None: - # Reverse the order to enable better memory sharing across cuda graphs. 
- compile_range = ( - tqdm.tqdm(list(reversed(self.compile_bs))) - if get_tensor_model_parallel_rank() == 0 - else reversed(self.compile_bs) - ) - - # warm up before dynamic shape compilation - bs = 1 - num_tokens = bs * self.num_tokens_per_bs - forward_batch = self.prepare_forward_batch(bs, num_tokens) - forward_batch.attn_backend.init_forward_metadata(forward_batch) - self.model_runner.model.forward( - forward_batch.input_ids, forward_batch.positions, forward_batch - ) - - backend = get_compiler_backend( - mode="reduce-overhead", - ) - - compile_forward = torch.compile( - torch.no_grad()(self.model_runner.model.forward), - fullgraph=True, - dynamic=True, - backend=backend, - ) - - self.model_runner.model.compile_forward = compile_forward - - @torch.compile(dynamic=True, backend=get_compiler_backend()) - def run_for_init(input): - return input + 1 - - run_for_init(torch.zeros([1]).to(self.model_runner.device)) - - for i, bs in enumerate(compile_range): - if get_tensor_model_parallel_rank() == 0: - avail_mem = get_available_gpu_memory( - self.model_runner.device, - self.model_runner.gpu_id, - empty_cache=False, - ) - compile_range.set_description( - f"Compiling batches ({bs=} {avail_mem=:.2f} GB)" - ) - - self.warm_up(bs, compile_forward) - - def replay( - self, - forward_batch: ForwardBatch, - skip_attn_backend_init: bool = False, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - ) -> Union[LogitsProcessorOutput, PPProxyTensors]: - print(f"NPUCompileModelRunner::replay", flush=True) - if not skip_attn_backend_init: - forward_batch.attn_backend.init_forward_metadata(forward_batch) - - kwargs = {} - if pp_proxy_tensors is not None: - kwargs["pp_proxy_tensors"] = pp_proxy_tensors - - with torch.no_grad(): - return self.model_runner.model.compile_forward( - forward_batch.input_ids, - forward_batch.positions, - forward_batch, - **kwargs, - ) - - def prepare_forward_batch(self, bs: int, num_tokens: int) -> ForwardBatch: - # Graph inputs - with 
torch.device(self.model_runner.device): - input_ids = torch.zeros((num_tokens,), dtype=torch.int64) - req_pool_indices = torch.zeros((bs,), dtype=torch.int64) - seq_lens = torch.full((bs,), self.seq_len_fill_value, dtype=torch.int64) - out_cache_loc = torch.zeros((num_tokens,), dtype=torch.int32) - positions = torch.zeros((num_tokens,), dtype=torch.int64) - num_token_non_padded = torch.tensor(num_tokens, dtype=torch.int32) - - if self.is_encoder_decoder: - encoder_lens = self.encoder_lens[:bs] - else: - encoder_lens = None - mrope_positions = None - - # pipeline parallelism - if self.pp_size > 1: - pp_proxy_tensors = PPProxyTensors( - {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()} - ) - - if self.require_mlp_tp_gather: - global_num_tokens = torch.tensor( - [ - num_tokens // self.dp_size + (i < (num_tokens % self.dp_size)) - for i in range(self.dp_size) - ], - dtype=torch.int64, - device=input_ids.device, - ) - elif self.require_attn_tp_gather: - global_num_tokens = torch.tensor( - [num_tokens], dtype=torch.int64, device=input_ids.device - ) - else: - global_num_tokens = None - gathered_buffer = None - - spec_info = self.get_spec_info(num_tokens) - if self.capture_hidden_mode != CaptureHiddenMode.FULL: - self.capture_hidden_mode = ( - spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL - ) - - forward_batch = ForwardBatch( - forward_mode=self.capture_forward_mode, - batch_size=bs, - input_ids=input_ids, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - seq_lens_cpu=seq_lens.cpu(), - req_to_token_pool=self.model_runner.req_to_token_pool, - token_to_kv_pool=self.model_runner.token_to_kv_pool, - attn_backend=self.model_runner.attn_backend, - out_cache_loc=out_cache_loc, - seq_lens_sum=seq_lens.sum().item(), - encoder_lens=encoder_lens, - return_logprob=False, - positions=positions, - global_num_tokens_gpu=global_num_tokens, - mrope_positions=mrope_positions, - spec_algorithm=self.model_runner.spec_algorithm, - 
spec_info=spec_info, - capture_hidden_mode=self.capture_hidden_mode, - num_token_non_padded=num_token_non_padded, - global_forward_mode=None, - mm_inputs=[None] * bs, - lora_ids=[None] * bs, - global_num_tokens_cpu=[num_tokens], - ) - return forward_batch - - def warm_up(self, bs: int, forward: Callable): - num_tokens = bs * self.num_tokens_per_bs - forward_batch = self.prepare_forward_batch(bs, num_tokens) - forward_batch.attn_backend.init_forward_metadata(forward_batch) - - # Run and compile - def run_once(): - # Clean intermediate result cache for DP attention - forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - - kwargs = {} - if ( - self.pp_size > 1 - and "pp_proxy_tensors" in inspect.signature(forward).parameters - ): - kwargs["pp_proxy_tensors"] = forward_batch.pp_proxy_tensors - self.mark_static(forward_batch, kwargs.get("pp_proxy_tensors")) - - with torch.no_grad(): - logits_output_or_pp_proxy_tensors = forward( - forward_batch.input_ids, - forward_batch.positions, - forward_batch, - **kwargs, - ) - return logits_output_or_pp_proxy_tensors - - torch.npu.synchronize() - self.model_runner.tp_group.barrier() - run_once() - - def mark_static( - self, forward_batch: ForwardBatch, pp_proxy_tensors: PPProxyTensors = None - ): - def mark_tensor_static(model_input, is_cache=False): - if model_input is not None: - if isinstance(model_input, torch.Tensor): - torch._dynamo.mark_static(model_input) - elif is_cache: - for buffer_per_layer in model_input: - torch._dynamo.mark_static(buffer_per_layer) - elif isinstance(model_input, PPProxyTensors): - for pp_out in model_input.tensors.items(): - torch._dynamo.mark_static(pp_out) - elif isinstance(model_input, tuple): - for value in model_input: - torch._dynamo.mark_static(value) - else: - raise ValueError( - f"Unsupported type with mark static: {type(model_input)}" - ) - - mark_tensor_static(pp_proxy_tensors) - mark_tensor_static(forward_batch.input_ids) - 
mark_tensor_static(forward_batch.positions) - mark_tensor_static(forward_batch.input_embeds) - mark_tensor_static(forward_batch.out_cache_loc) - mark_tensor_static(forward_batch.attn_backend.forward_metadata.block_tables) - try: - mark_tensor_static(forward_batch.token_to_kv_pool.k_buffer, is_cache=True) - mark_tensor_static(forward_batch.token_to_kv_pool.v_buffer, is_cache=True) - except AttributeError as e: - mark_tensor_static(forward_batch.token_to_kv_pool.kv_buffer, is_cache=True) - - def can_run(self, forward_batch: ForwardBatch): - return forward_batch.forward_mode.is_decode() and ( - forward_batch.batch_size in self.compile_bs - ) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py index 28c07f0a6ae9..4c5e37e4edd9 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py @@ -77,7 +77,11 @@ def patch_model_npu( tp_group: GroupCoordinator, ): compilation_config = get_global_server_args().compilation_config - if enable_compile and compilation_config.compiler == "npugraph_ex": + if ( + enable_compile + and (compilation_config is not None) + and (compilation_config.compiler == "npugraph_ex") + ): backend = get_compiler_backend(compilation_config=compilation_config) yield torch.compile( torch.no_grad()(model.forward), @@ -100,6 +104,7 @@ def __init__(self, model_runner: ModelRunner): model_runner.attn_backend.enable_torch_compile = ( model_runner.server_args.enable_torch_compile ) + self.enable_torchair_compile = model_runner.server_args.enable_torchair_compile super().__init__(model_runner) self.update_attr_name = None @@ -138,10 +143,9 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): compilation_config = get_global_server_args().compilation_config if ( self.enable_torch_compile + and (not self.enable_torchair_compile) and (not 
self.compile_bs or bs in self.compile_bs) - and (compilation_config.compiler != "npugraph_ex") ): - self.model_runner.attn_backend.enable_torch_compile = True compiler = NpuGraphCompiler( model_runner=self.model_runner, model=run_once_fn, @@ -178,9 +182,7 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): compiled_function(*args) else: - self.model_runner.attn_backend.enable_torch_compile = False - - if self.enable_torch_compile: + if self.enable_torchair_compile: skip_guard_context = torch.compiler.set_stance( skip_guard_eval_unsafe=True ) diff --git a/python/sglang/srt/hardware_backend/npu/utils.py b/python/sglang/srt/hardware_backend/npu/utils.py index 64deddf77024..8250e9bbbad5 100644 --- a/python/sglang/srt/hardware_backend/npu/utils.py +++ b/python/sglang/srt/hardware_backend/npu/utils.py @@ -64,6 +64,25 @@ def set_default_server_args(args: "ServerArgs"): "Cannot enable both --enable-piecewise-npu-graph-decode and --enable-torchair-compile" ) + if args.enable_piecewise_npu_graph_decode and args.enable_torch_compile: + raise ValueError( + "Cannot enable both --enable-piecewise-npu-graph-decode and --enable-torch-compile" + ) + + if args.enable_torchair_compile and args.enable_torch_compile: + raise ValueError( + "Cannot enable both --enable-torchair-compile and --enable-torch-compile" + ) + + if args.disable_cuda_graph and ( + args.enable_piecewise_npu_graph_decode + or args.enable_torch_compile + or args.enable_torchair_compile + ): + raise ValueError( + f"--enable-piecewise-npu-graph-decode or --enable-torch-compile or --enable-torchair-compile is not appropriate for --disable-cuda-graph" + ) + if args.compilation_config: if args.compilation_config.compiler == "npugraph": args.enable_torch_compile = True @@ -86,6 +105,11 @@ def set_default_server_args(args: "ServerArgs"): if args.compilation_config.compiler == "piecewise": args.enable_piecewise_npu_graph_decode = True + if args.disable_cuda_graph: + raise ValueError( + 
f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --disable-cuda-graph" + ) + if args.enable_torchair_compile: raise ValueError( f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-torchair-compile" @@ -94,6 +118,11 @@ def set_default_server_args(args: "ServerArgs"): if args.compilation_config.compiler == "npugraph_ex": args.enable_torchair_compile = True + if args.disable_cuda_graph: + raise ValueError( + f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --disable-cuda-graph" + ) + if args.enable_piecewise_npu_graph_decode: raise ValueError( f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-piecewise-npu-graph-decode" diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8a1597144c23..90c25f0f9c60 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -79,9 +79,6 @@ set_global_expert_location_metadata, ) from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater -from sglang.srt.hardware_backend.npu.graph_runner.npu_compile_model_runner import ( - NPUCompileModelRunner, -) from sglang.srt.hardware_backend.npu.graph_runner.npu_graph_runner import NPUGraphRunner from sglang.srt.hardware_backend.npu.graph_runner.piecewise_npu_graph_runner_decode import ( PiecewiseNPUGraphRunnerDecode, @@ -2488,11 +2485,7 @@ def init_device_graphs(self): "npu": ( PiecewiseNPUGraphRunnerDecode if self.server_args.enable_piecewise_npu_graph_decode - else ( - NPUCompileModelRunner - if self.server_args.enable_torchair_compile - else NPUGraphRunner - ) + else NPUGraphRunner ), }, ) diff --git a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py index deaf4a5e0387..4ff92ce3e101 100644 --- 
a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py +++ b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py @@ -39,7 +39,7 @@ def setUpClass(cls): "--attention-backend", "ascend", "--disable-radix-cache", - "--enable-torch-compile", + "--enable-torchair-compile", "--watchdog-timeout", 30000, ] From cd0770ceefd89cca2074d2ab9b8567df51c6a5a4 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Wed, 17 Dec 2025 18:45:31 +0300 Subject: [PATCH 59/71] Piecewise Graph temporary removal --- docs/platforms/ascend_npu_pass_development.md | 2 +- .../srt/compilation/compilation_config.py | 5 - .../npu/attention/ascend_backend.py | 14 - .../compilation/npu_graph_backend.py | 64 --- .../piecewise_npu_graph_compiler.py | 52 --- .../piecewise_npu_graph_compiler_backend.py | 275 ----------- .../piecewise_npu_graph_runner_decode.py | 429 ------------------ .../sglang/srt/hardware_backend/npu/utils.py | 39 +- .../sglang/srt/model_executor/model_runner.py | 9 +- python/sglang/srt/models/qwen3.py | 5 +- python/sglang/srt/server_args.py | 18 - python/sglang/srt/utils/common.py | 10 - ...est_ascend_npu_piecewise_graph_tp1_bf16.py | 59 --- test/srt/run_suite.py | 1 - 14 files changed, 5 insertions(+), 977 deletions(-) delete mode 100644 python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_backend.py delete mode 100644 python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py delete mode 100644 python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py delete mode 100644 python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py delete mode 100644 test/srt/ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py diff --git a/docs/platforms/ascend_npu_pass_development.md b/docs/platforms/ascend_npu_pass_development.md index 251d10c85e98..355896b0a7d6 100644 --- a/docs/platforms/ascend_npu_pass_development.md +++ 
b/docs/platforms/ascend_npu_pass_development.md @@ -4,7 +4,7 @@ `PassManager` is implemented here: [PassManager](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py) -You can explore `PassManager` usage in [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) compiler backend. [`PiecewiseNpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py) compiler backed uses `PassManager` too via `NpuGraphCompilerBackend` inheritance. +You can explore `PassManager` usage in [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) compiler backend. 
### Pass development There are two approaches to develop passes for SGLang NPU PassManager: diff --git a/python/sglang/srt/compilation/compilation_config.py b/python/sglang/srt/compilation/compilation_config.py index d3558f452d4c..b05aea82435d 100644 --- a/python/sglang/srt/compilation/compilation_config.py +++ b/python/sglang/srt/compilation/compilation_config.py @@ -11,13 +11,11 @@ def __init__( capture_sizes: List[int] = [], compiler: str = "eager", enable_debug_mode: bool = False, - splitting_ops: List[str] = [], ): self.traced_files = set() self.capture_sizes = capture_sizes self.compiler = compiler self.enable_debug_mode = enable_debug_mode - self.splitting_ops = splitting_ops def add_traced_file(self, file_path: str): self.traced_files.add(file_path) @@ -35,6 +33,3 @@ def from_cli(cls, args) -> "CompilationConfig": def get_enable_debug_mode(self): return self.enable_debug_mode - - def get_splitting_ops(self): - return self.splitting_ops diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index ee83104843cb..3889dfbee618 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -231,9 +231,6 @@ class AscendAttnBackend(AttentionBackend): def __init__(self, model_runner: ModelRunner): super().__init__() - self.enable_piecewise_npu_graph_decode = ( - model_runner.server_args.enable_piecewise_npu_graph_decode - ) self.forward_metadata = None self.device = model_runner.device self.page_size = model_runner.page_size @@ -1169,17 +1166,6 @@ def forward_decode_graph( else: actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_int - if ( - self.enable_piecewise_npu_graph_decode - and torch.compiler.is_dynamo_compiling() - ): - # input args for submodule forward - forward_batch.req_to_token_pool.req_to_token.add_( - forward_batch.req_to_token_pool.req_to_token - ) - 
forward_batch.req_pool_indices.add_(forward_batch.req_pool_indices) - forward_batch.seq_lens.add_(forward_batch.seq_lens) - torch_npu._npu_paged_attention( query=query, key_cache=k_cache, diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_backend.py deleted file mode 100644 index 07535210910c..000000000000 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_backend.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -from typing import Any - -import torch -import torch_npu - -from sglang.srt.hardware_backend.npu.graph_runner.compilation.compilation_context import ( - CompilationContext, -) - - -class NPUGraphBackend: - def __init__( - self, - model_runner, - graph: torch.fx.GraphModule, - compilation_context: CompilationContext, - ): - self.model_runner = model_runner - self.graph = graph - self.compilation_context = compilation_context - - self.captured = False - self.output = None - self.npu_graph = None - - def __call__(self, *args) -> Any: - if not self.captured: - if not self.compilation_context.stream: - self.compilation_context.stream = torch_npu.npu.Stream() - - torch.cuda.synchronize() - - self.npu_graph = torch_npu.npu.NPUGraph() - with torch.npu.graph( - self.npu_graph, - stream=self.compilation_context.stream, - pool=self.compilation_context.graph_memory_pool, - ): - - self.output = self.graph.forward(*args) - - if not self.compilation_context.graph_memory_pool: - self.compilation_context.graph_memory_pool = self.npu_graph.pool() - - self.npu_graph.replay() - self.captured = True - else: - self.npu_graph.replay() - - return self.output diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py deleted file mode 100644 index e46e3c88c573..000000000000 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2023-2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import pathlib -import sys - -import torch - -from sglang.srt.compilation.compilation_config import CompilationConfig -from sglang.srt.hardware_backend.npu.graph_runner.compilation.compilation_context import ( - CompilationContext, -) -from sglang.srt.utils.common import get_compiler_backend - - -class PiecewiseNpuGraphCompiler: - def __init__( - self, - model_runner, - model: torch.nn.Module, - compilation_config: CompilationConfig, - compilation_context: CompilationContext, - batch_size: int, - ): - if compilation_config is None: - compilation_config = CompilationConfig(compiler="piecewise") - - backend = get_compiler_backend( - model_runner=model_runner, - compilation_config=compilation_config, - compilation_context=compilation_context, - ) - backend.init(model_runner.model_config, batch_size) - - torch._dynamo.reset() - torch.compiler.allow_in_graph(sys.intern) - torch.compiler.allow_in_graph(pathlib.Path) - - self.compiled_callable = torch.compile( - model, fullgraph=True, dynamic=False, backend=backend - ) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py deleted file mode 100644 index bd0a98215ed8..000000000000 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/piecewise_npu_graph_compiler_backend.py +++ /dev/null @@ -1,275 +0,0 @@ -# Copyright 2023-2025 SGLang Team -# Licensed 
under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import dataclasses -import importlib -import logging -from typing import Any, Callable - -import torch - -from sglang.srt.compilation.compilation_config import CompilationConfig -from sglang.srt.distributed import get_tensor_model_parallel_world_size -from sglang.srt.hardware_backend.npu.graph_runner.compilation.compilation_context import ( - CompilationContext, -) -from sglang.srt.hardware_backend.npu.graph_runner.compilation.npu_graph_compiler_backend import ( - NpuGraphCompilerBackend, -) - -logger = logging.getLogger(__name__) - - -class Submodule(torch.nn.Module): - block_tables = None - - def __init__(self, page_size, model_config): - self.page_size = page_size - self.config = model_config - - tp_size = get_tensor_model_parallel_world_size() - assert self.config.num_attention_heads % tp_size == 0 - self.num_heads = self.config.num_attention_heads // tp_size - - self.total_num_kv_heads = self.config.num_key_value_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. 
- assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - - self.hidden_size = self.config.hidden_size - self.head_dim = getattr( - self.config, "head_dim", self.hidden_size // self.config.num_attention_heads - ) - - self.scaling = self.head_dim**-0.5 - - def forward_with_calculation( - self, - l_args_2_req_to_token_pool_req_to_token, - l_args_2_req_pool_indices, - l_args_2_seq_lens, - query_2, - l_args_2_token_to_kv_pool_k_buffer_0_, - l_args_2_token_to_kv_pool_v_buffer_0_, - l_args_2_attn_backend_forward_metadata_block_tables, - l_args_2_attn_backend_forward_metadata_seq_lens_cpu_int, - output, - ): - Submodule.block_tables = ( - l_args_2_req_to_token_pool_req_to_token[ - l_args_2_req_pool_indices, : l_args_2_seq_lens.max() - ][:, :: self.page_size] - // self.page_size - ) - _npu_paged_attention = torch.ops.atb._npu_paged_attention( - query=query_2, - key_cache=l_args_2_token_to_kv_pool_k_buffer_0_, - value_cache=l_args_2_token_to_kv_pool_v_buffer_0_, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - scale_value=self.scaling, - block_table=Submodule.block_tables, - context_lens=l_args_2_attn_backend_forward_metadata_seq_lens_cpu_int, - out=output, - ) - - def forward( - self, - l_args_2_req_to_token_pool_req_to_token, - l_args_2_req_pool_indices, - l_args_2_seq_lens, - query_2, - l_args_2_token_to_kv_pool_k_buffer_0_, - l_args_2_token_to_kv_pool_v_buffer_0_, - l_args_2_attn_backend_forward_metadata_block_tables, - l_args_2_attn_backend_forward_metadata_seq_lens_cpu_int, - output, - ): - _npu_paged_attention = torch.ops.atb._npu_paged_attention( - query=query_2, - key_cache=l_args_2_token_to_kv_pool_k_buffer_0_, - value_cache=l_args_2_token_to_kv_pool_v_buffer_0_, - num_heads=self.num_heads, - num_kv_heads=self.num_kv_heads, - scale_value=self.scaling, - block_table=Submodule.block_tables, - context_lens=l_args_2_attn_backend_forward_metadata_seq_lens_cpu_int, - out=output, - ) - - -def 
resolve_obj_by_qualname(qualname: str) -> Any: - module_name, obj_name = qualname.rsplit(".", 1) - module = importlib.import_module(module_name) - return getattr(module, obj_name) - - -@dataclasses.dataclass -class SplitItem: - submod_name: str - graph_id: int - is_compiled_only: bool - graph: torch.fx.GraphModule - - -class PiecewiseNpuGraphCompilerBackend(NpuGraphCompilerBackend): - graph: torch.fx.GraphModule - - def __init__( - self, - model_runner, - compilation_config: CompilationConfig, - compilation_context: CompilationContext, - ): - super().__init__(model_runner) - - self.model_runner = model_runner - self.model_config = model_runner.model.config - - self.compilation_config = compilation_config - self.page_size = model_runner.page_size - self.compilation_context = compilation_context - - self.split_gm = None - self.piecewise_graphs = None - - self.callables = {} - self.callables_by_branch = {} - - def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: - example_inputs_len = len(example_inputs) - if example_inputs_len in self.callables: - callable = self.callables[example_inputs_len] - return callable - - super().__call__(graph, example_inputs) - - self.graph = graph - self.split_gm, self.piecewise_graphs = ( - PiecewiseNpuGraphCompilerBackend.split_graph( - self.graph, self.compilation_config.splitting_ops - ) - ) - - npu_graph_backend = resolve_obj_by_qualname( - "sglang.srt.hardware_backend.npu.graph_runner.compilation.npu_graph_backend.NPUGraphBackend" - ) - - self.submod_names_compiled_only = [ - item.submod_name for item in self.piecewise_graphs if item.is_compiled_only - ] - - named_modules = self.split_gm.named_modules() - submod = Submodule(self.page_size, self.model_config) - use_forward = False - for name, graph_module in named_modules: - if not name: - continue - - graph = getattr(self.split_gm, name) - if name in self.submod_names_compiled_only: - if use_forward: - self.split_gm.__dict__[name] = submod.forward - else: - 
self.split_gm.__dict__[name] = submod.forward_with_calculation - use_forward = True - else: - self.split_gm.__dict__[name] = npu_graph_backend( - self.model_runner, graph, self.compilation_context - ) - - self.split_gm(*example_inputs) - self.callables[example_inputs_len] = self.split_gm.forward - return self.split_gm.forward - - def split_graph( - graph: torch.fx.GraphModule, ops: list[str] - ) -> tuple[torch.fx.GraphModule, list[SplitItem]]: - subgraph_id = 0 - node_to_subgraph_id = {} - graphs_for_compilation = [] - - node_index = 0 - node_index_max = len(graph.graph.nodes) - - nodes = list(graph.graph.nodes) - - counter = 1 - ops_count = 3 - ops_step = ops_count + 1 - while node_index < node_index_max: - if ( - (node_index + ops_count) < node_index_max - and nodes[node_index + ops_count].op == "call_function" - and str(nodes[node_index + ops_count].target) in ops - ): - subgraph_id += 1 - graphs_for_compilation.append(subgraph_id) - - for submodule_node_index in range(node_index, node_index + ops_step): - submodule_node = nodes[submodule_node_index] - node_to_subgraph_id[submodule_node] = subgraph_id - counter = counter + 1 - node_index += ops_step - - subgraph_id += 1 - else: - node = nodes[node_index] - if node.op in ("output", "placeholder"): - node_index += 1 - elif node.op == "call_function" and str(node.target) in ops: - subgraph_id += 1 - graphs_for_compilation.append(subgraph_id) - - node_to_subgraph_id[node] = subgraph_id - node_index += 1 - - subgraph_id += 1 - else: - node_to_subgraph_id[node] = subgraph_id - node_index += 1 - counter += 1 - - split_gm = torch.fx.passes.split_module.split_module( - graph, - None, - lambda node: node_to_subgraph_id[node], - keep_original_order=True, - ) - - names = [name for (name, module) in split_gm.named_modules()] - - outputs = [] - for name in names: - if "." 
in name or name == "": - # recursive child module or the root module - continue - - module = getattr(split_gm, name) - - graph_id = int(name.replace("submod_", "")) - outputs.append( - SplitItem(name, graph_id, (graph_id in graphs_for_compilation), module) - ) - - # sort by intetger graph_id, rather than string name - outputs.sort(key=lambda x: x.graph_id) - - return split_gm, outputs diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py b/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py deleted file mode 100644 index a535995b4b4a..000000000000 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/piecewise_npu_graph_runner_decode.py +++ /dev/null @@ -1,429 +0,0 @@ -# Copyright 2025 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Run the model with npu graph and torch.compile.""" - -from __future__ import annotations - -import bisect -import gc -from typing import TYPE_CHECKING, Callable, Optional, Union - -import torch -import torch._dynamo.config -import tqdm -from torch._dynamo.eval_frame import DisableContext - -from sglang.srt.compilation.compilation_config import CompilationConfig -from sglang.srt.distributed import get_tensor_model_parallel_rank -from sglang.srt.distributed.parallel_state import graph_capture -from sglang.srt.hardware_backend.npu.attention.ascend_backend import AscendAttnBackend -from sglang.srt.hardware_backend.npu.graph_runner.compilation.compilation_context import ( - CompilationContext, -) -from sglang.srt.hardware_backend.npu.graph_runner.compilation.patch_dynamo import ( - patch_dynamo_context, - patch_dynamo_context_call, - restore_dynamo_context_call, -) -from sglang.srt.hardware_backend.npu.graph_runner.compilation.piecewise_npu_graph_compiler import ( - PiecewiseNpuGraphCompiler, -) -from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner -from sglang.srt.model_executor.forward_batch_info import ( - CaptureHiddenMode, - ForwardBatch, - PPProxyTensors, -) -from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import get_available_gpu_memory - -torch._dynamo.config.skip_nnmodule_hook_guards = True -torch._dynamo.config.automatic_dynamic_shapes = False -torch._dynamo.config.guard_nn_modules = False - -import logging - -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from sglang.srt.model_executor.model_runner import ModelRunner - - -torch.cuda.CUDAGraph = torch.npu.NPUGraph -torch.cuda.synchronize = torch.npu.synchronize -torch.cuda.graph = torch.npu.graph -torch.cuda.stream = torch.npu.stream -torch.cuda.Stream 
= torch.npu.Stream -torch.cuda.current_stream = torch.npu.current_stream -torch.cuda.graph_pool_handle = torch.npu.graph_pool_handle - - -class CompiledGraph: - def __init__( - self, - bs: int, - forward_batch: ForwardBatch, - attn_backend: AscendAttnBackend, - callable, - ): - self.bs = bs - self.forward_batch = forward_batch - self.attn_backend = attn_backend - self.callable = callable - - -class PiecewiseNPUGraphRunnerDecode(CudaGraphRunner): - """A PiecewiseNPUGraphRunnerDecode runs the forward pass of a model with npu graph and torch.compile.""" - - def __init__(self, model_runner: ModelRunner): - model_runner.attn_backend.enable_piecewise_npu_graph_decode = True - patch_dynamo_context() - self.init_forward_metadata_was_done = True - - # Parse args - self.model_runner = model_runner - compilation_config = get_global_server_args().compilation_config - if compilation_config is None: - compilation_config = CompilationConfig( - compiler="piecewise", splitting_ops=["atb._npu_paged_attention"] - ) - self.compilation_config = compilation_config - self.compilation_context = CompilationContext() - - self.graphs = {} - self.output_buffers = {} - self.enable_torch_compile = model_runner.server_args.enable_torch_compile - - # Graph inputs - with torch.device(self.model_runner.device): - self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32) - self.block_tables = torch.full((160, 160), 0, dtype=torch.int32) - - super().__init__(model_runner) - - def can_run(self, forward_batch: ForwardBatch): - return ( - (self.pp_size <= 1) - and (not self.is_encoder_decoder) - and (not self.enable_two_batch_overlap) - and super().can_run(forward_batch) - ) - - def capture(self, forward_batch_: ForwardBatch = None, bs_: int = None) -> None: - with graph_capture() as graph_capture_context: - self.stream = graph_capture_context.stream - - self.model_runner.tp_group.barrier() - - avail_mem = get_available_gpu_memory( - self.model_runner.device, self.model_runner.gpu_id, 
empty_cache=False - ) - - # Reverse the order to enable better memory sharing across cuda graphs. - capture_range = ( - tqdm.tqdm(list(reversed(self.capture_bs))) - if get_tensor_model_parallel_rank() == 0 - else reversed(self.capture_bs) - ) - - for bs in capture_range: - if get_tensor_model_parallel_rank() == 0: - avail_mem = get_available_gpu_memory( - self.model_runner.device, - self.model_runner.gpu_id, - empty_cache=False, - ) - capture_range.set_description( - f"Capturing batches ({avail_mem=:.2f} GB)" - ) - - (compiled_graph, output_buffers) = self.capture_one_batch_size( - bs, self.model_runner.model.forward, forward_batch_=forward_batch_ - ) - self.graphs[bs] = compiled_graph - self.output_buffers[bs] = output_buffers - - def init_forward_metadata_attn_backend( - self, bs: int, attn_backend: AscendAttnBackend, forward_batch: ForwardBatch - ): - attn_backend.forward_metadata.block_tables = self.block_tables - - seq_lens_cpu_int = forward_batch.seq_lens_cpu_int - seq_lens_cpu_int[ - : attn_backend.forward_metadata.seq_lens_cpu_int.shape[0] - ].copy_(attn_backend.forward_metadata.seq_lens_cpu_int) - attn_backend.forward_metadata.seq_lens_cpu_int = seq_lens_cpu_int - - def init_forward_batch( - self, bs: int, attn_backend: AscendAttnBackend, forward_batch_: ForwardBatch - ) -> ForwardBatch: - if forward_batch_: - return forward_batch_ - - num_tokens = bs * self.num_tokens_per_bs - - with torch.device(self.model_runner.device): - req_pool_indices = torch.zeros((bs,), dtype=torch.int32) - seq_lens = torch.full((bs,), self.seq_len_fill_value, dtype=torch.int32) - out_cache_loc = torch.zeros((bs,), dtype=torch.int32) - positions = torch.zeros((bs,), dtype=torch.int64) - input_ids = torch.zeros((bs,), dtype=torch.int64) - mrope_positions = torch.zeros((3, self.max_num_token), dtype=torch.int64) - - spec_info = self.get_spec_info(num_tokens) - if self.capture_hidden_mode != CaptureHiddenMode.FULL: - self.capture_hidden_mode = ( - spec_info.capture_hidden_mode if 
spec_info else CaptureHiddenMode.NULL - ) - - forward_batch = ForwardBatch( - forward_mode=self.capture_forward_mode, - batch_size=bs, - input_ids=input_ids, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - req_to_token_pool=self.model_runner.req_to_token_pool, - token_to_kv_pool=self.model_runner.token_to_kv_pool, - attn_backend=attn_backend, - out_cache_loc=out_cache_loc, - seq_lens_sum=seq_lens.sum(), - encoder_lens=None, - return_logprob=False, - positions=positions, - global_num_tokens_gpu=None, - mrope_positions=mrope_positions, - spec_algorithm=self.model_runner.spec_algorithm, - spec_info=spec_info, - capture_hidden_mode=self.capture_hidden_mode, - num_token_non_padded=self.num_token_non_padded, - global_forward_mode=self.capture_forward_mode, - ) - - seq_lens_cpu_int = torch.zeros((bs,), dtype=torch.int32, device="cpu") - forward_batch.seq_lens_cpu_int = seq_lens_cpu_int - - seq_lens_cpu = torch.full((bs,), 1, dtype=torch.int32, device="cpu") - forward_batch.seq_lens_cpu = seq_lens_cpu - - for i in range(bs): - forward_batch.global_forward_mode = None - forward_batch.input_ids[i] = 323 - forward_batch.num_token_non_padded = None - forward_batch.out_cache_loc[i] = 134 - forward_batch.positions[i] = 6 - forward_batch.seq_lens[i] = 7 - forward_batch.seq_lens_cpu[i] = 7 - forward_batch.seq_lens_cpu_int[i] = 7 - forward_batch.req_pool_indices[i] = 1 - forward_batch.seq_lens_sum = sum(forward_batch.seq_lens) - - attn_backend.init_forward_metadata(forward_batch) - - self.init_forward_metadata_attn_backend(bs, attn_backend, forward_batch) - - # Clean intermediate result cache for DP attention - forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - return forward_batch - - def capture_one_batch_size( - self, - bs: int, - forward: Callable, - forward_batch_: ForwardBatch = None, - compile: bool = True, - ): - attn_backend = self.model_runner.attn_backend - attn_backend.init_cuda_graph_state(bs, self.max_num_token) - - 
self.model_runner.attn_backend = attn_backend - - for _ in range(2): - forward_batch = self.init_forward_batch(bs, attn_backend, forward_batch_) - - torch.cuda.synchronize() - self.model_runner.tp_group.barrier() - - self.model_runner.attn_backend.graph_mode = True - self.model_runner.model( - forward_batch.input_ids, forward_batch.positions, forward_batch - ) - - forward_batch = self.init_forward_batch(bs, attn_backend, forward_batch_) - - self.compilation_context.stream = self.stream - self.model_runner.attn_backend.graph_mode = True - - compiler = PiecewiseNpuGraphCompiler( - model_runner=self.model_runner, - model=self.model_runner.model, - compilation_config=self.compilation_config, - compilation_context=self.compilation_context, - batch_size=bs, - ) - - patch_dynamo_context_call() - DisableContext.batch_size = bs - - logits_output_or_pp_proxy_tensors = compiler.compiled_callable( - forward_batch.input_ids, forward_batch.positions, forward_batch - ) - - compiled_graph = CompiledGraph( - bs, forward_batch, None, compiler.compiled_callable - ) - - try: - logits_output_or_pp_proxy_tensors = compiler.compiled_callable( - forward_batch.input_ids, forward_batch.positions, forward_batch - ) - finally: - DisableContext.batch_size = None - restore_dynamo_context_call() - - assert DisableContext.compiled_function - assert DisableContext.compiled_function_args - - torch._dynamo.reset() - gc.collect() - - return (compiled_graph, logits_output_or_pp_proxy_tensors) - - def replay_prepare( - self, - forward_batch: ForwardBatch, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - ): - raw_bs = forward_batch.batch_size - raw_num_token = raw_bs * self.num_tokens_per_bs - - # Pad - if self.require_mlp_tp_gather: - max_num_tokens = max(forward_batch.global_num_tokens_cpu) - max_batch_size = ( - max_num_tokens / self.num_tokens_per_bs - if self.model_runner.spec_algorithm.is_eagle() - else max_num_tokens - ) - index = bisect.bisect_left(self.capture_bs, max_batch_size) - else: 
- index = bisect.bisect_left(self.capture_bs, raw_bs) - - bs = self.capture_bs[index] - compiled_graph = self.graphs[bs] - - compiled_graph.forward_batch.input_ids[ - : forward_batch.input_ids.shape[0] - ].copy_(forward_batch.input_ids) - forward_batch.input_ids = compiled_graph.forward_batch.input_ids - - compiled_graph.forward_batch.seq_lens[: forward_batch.seq_lens.shape[0]].copy_( - forward_batch.seq_lens - ) - forward_batch.seq_lens = compiled_graph.forward_batch.seq_lens - - compiled_graph.forward_batch.req_pool_indices[ - : forward_batch.req_pool_indices.shape[0] - ].copy_(forward_batch.req_pool_indices) - forward_batch.req_pool_indices = compiled_graph.forward_batch.req_pool_indices - - compiled_graph.forward_batch.out_cache_loc[ - : forward_batch.out_cache_loc.shape[0] - ].copy_(forward_batch.out_cache_loc) - forward_batch.out_cache_loc = compiled_graph.forward_batch.out_cache_loc - - compiled_graph.forward_batch.positions[ - : forward_batch.positions.shape[0] - ].copy_(forward_batch.positions) - forward_batch.positions = compiled_graph.forward_batch.positions - - if forward_batch.seq_lens_cpu is not None: - compiled_graph.forward_batch.seq_lens_cpu[ - : forward_batch.seq_lens_cpu.shape[0] - ].copy_(forward_batch.seq_lens_cpu) - forward_batch.seq_lens_cpu = compiled_graph.forward_batch.seq_lens_cpu - - if pp_proxy_tensors: - for key in self.pp_proxy_tensors.keys(): - dim = pp_proxy_tensors[key].shape[0] - self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key]) - - if forward_batch.mrope_positions is not None: - compiled_graph.forward_batch.mrope_positions[:, :raw_num_token].copy_( - forward_batch.mrope_positions - ) - - # Store fields - self.raw_bs = raw_bs - self.raw_num_token = raw_num_token - self.bs = bs - - def replay( - self, - forward_batch: ForwardBatch, - skip_attn_backend_init: bool = False, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - ) -> Union[LogitsProcessorOutput, PPProxyTensors]: - self.replay_prepare(forward_batch, 
pp_proxy_tensors) - - def init(): - attn_backend = self.model_runner.attn_backend - forward_batch.attn_backend = attn_backend - - compiled_graph: CompiledGraph = self.graphs[self.bs] - - attn_backend = self.model_runner.attn_backend - if not self.init_forward_metadata_was_done: - attn_backend.init_forward_metadata(forward_batch) - self.init_forward_metadata_was_done = True - else: - if forward_batch.extend_seq_lens is not None: - attn_backend.forward_metadata.extend_seq_lens_cpu_int = ( - forward_batch.extend_seq_lens.cpu().int() - ) - attn_backend.forward_metadata.seq_lens_cpu_int = ( - forward_batch.seq_lens_cpu.int() - ) - - self.init_forward_metadata_attn_backend( - self.bs, attn_backend, compiled_graph.forward_batch - ) - - init() - - self.model_runner.attn_backend.graph_mode = True - - DisableContext.compiled_function[self.bs]( - *DisableContext.compiled_function_args[self.bs] - ) - - output = self.output_buffers[self.bs] - - if isinstance(output, LogitsProcessorOutput): - result = LogitsProcessorOutput( - next_token_logits=output.next_token_logits[: self.raw_num_token], - hidden_states=( - output.hidden_states[: self.raw_num_token] - if output.hidden_states is not None - else None - ), - ) - else: - assert isinstance(output, PPProxyTensors) - result = PPProxyTensors( - {k: v[: self.bs] for k, v in output.tensors.items()} - ) - - return result diff --git a/python/sglang/srt/hardware_backend/npu/utils.py b/python/sglang/srt/hardware_backend/npu/utils.py index 48eebb74e3c3..a9c3036e9ecf 100644 --- a/python/sglang/srt/hardware_backend/npu/utils.py +++ b/python/sglang/srt/hardware_backend/npu/utils.py @@ -82,52 +82,22 @@ def set_default_server_args(args: "ServerArgs"): else: args.hicache_mem_layout = "page_first_direct" - if args.enable_piecewise_npu_graph_decode and args.enable_torchair_compile: - raise ValueError( - "Cannot enable both --enable-piecewise-npu-graph-decode and --enable-torchair-compile" - ) - - if args.enable_piecewise_npu_graph_decode and 
args.enable_torch_compile: - raise ValueError( - "Cannot enable both --enable-piecewise-npu-graph-decode and --enable-torch-compile" - ) - if args.enable_torchair_compile and args.enable_torch_compile: raise ValueError( "Cannot enable both --enable-torchair-compile and --enable-torch-compile" ) if args.disable_cuda_graph and ( - args.enable_piecewise_npu_graph_decode - or args.enable_torch_compile - or args.enable_torchair_compile + args.enable_torch_compile or args.enable_torchair_compile ): raise ValueError( - f"--enable-piecewise-npu-graph-decode or --enable-torch-compile or --enable-torchair-compile is not appropriate for --disable-cuda-graph" + f"--enable-torch-compile or --enable-torchair-compile is not appropriate for --disable-cuda-graph" ) if args.compilation_config: if args.compilation_config.compiler == "npugraph": args.enable_torch_compile = True - if args.disable_cuda_graph: - raise ValueError( - f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --disable-cuda-graph" - ) - - if args.enable_piecewise_npu_graph_decode: - raise ValueError( - f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-piecewise-npu-graph-decode" - ) - - if args.enable_torchair_compile: - raise ValueError( - f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-torchair-compile" - ) - - if args.compilation_config.compiler == "piecewise": - args.enable_piecewise_npu_graph_decode = True - if args.disable_cuda_graph: raise ValueError( f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --disable-cuda-graph" @@ -146,11 +116,6 @@ def set_default_server_args(args: "ServerArgs"): f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --disable-cuda-graph" ) - if args.enable_piecewise_npu_graph_decode: - raise ValueError( - f"compilation_config.compiler 
'{args.compilation_config.compiler}' is not appropriate for --enable-piecewise-npu-graph-decode" - ) - if args.enable_torchair_compile: args.enable_torch_compile = True diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5a7b6c1ad0d6..197de03c5860 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -80,9 +80,6 @@ ) from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater from sglang.srt.hardware_backend.npu.graph_runner.npu_graph_runner import NPUGraphRunner -from sglang.srt.hardware_backend.npu.graph_runner.piecewise_npu_graph_runner_decode import ( - PiecewiseNPUGraphRunnerDecode, -) from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.attention.attention_registry import ( ATTENTION_BACKENDS, @@ -2535,11 +2532,7 @@ def init_device_graphs(self): lambda: CudaGraphRunner, { "cpu": CPUGraphRunner, - "npu": ( - PiecewiseNPUGraphRunnerDecode - if self.server_args.enable_piecewise_npu_graph_decode - else NPUGraphRunner - ), + "npu": NPUGraphRunner, }, ) self.graph_runner = graph_runners[self.device](self) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index c7c5474bc5a0..eb531acbdf6d 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -39,10 +39,7 @@ _is_npu = is_npu() if _is_npu: - if supports_custom_op() and ( - get_global_server_args().enable_torch_compile - or get_global_server_args().enable_piecewise_npu_graph_decode - ): + if supports_custom_op() and get_global_server_args().enable_torch_compile: from sglang.srt.hardware_backend.npu.cmo import get_weight_cache from sglang.srt.hardware_backend.npu.cmo_custom_ops import ( get_cmo_stream, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 77ba3db2e804..cccf44925a3a 100644 --- a/python/sglang/srt/server_args.py +++ 
b/python/sglang/srt/server_args.py @@ -531,7 +531,6 @@ class ServerArgs: cuda_graph_bs: Optional[List[int]] = None disable_cuda_graph: bool = False disable_cuda_graph_padding: bool = False - enable_piecewise_npu_graph_decode: bool = False enable_profile_cuda_graph: bool = False enable_cudagraph_gc: bool = False enable_layerwise_nvtx_marker: bool = False @@ -1535,7 +1534,6 @@ def _handle_attention_backend_compatibility(self): "Cuda graph is disabled because of using torch native attention backend" ) self.disable_cuda_graph = True - self.enable_piecewise_npu_graph_decode = False if self.attention_backend == "flex_attention": logger.warning( @@ -1822,7 +1820,6 @@ def _handle_a2a_moe(self): if self.deepep_mode == "normal": logger.warning("Cuda graph is disabled because deepep_mode=`normal`") self.disable_cuda_graph = True - self.enable_piecewise_npu_graph_decode = False self.ep_size = self.tp_size logger.warning( f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." 
@@ -2174,10 +2171,6 @@ def _handle_pd_disaggregation(self): self.disaggregation_prefill_pp = self.pp_size self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp) - if self.enable_piecewise_npu_graph_decode: - self.enable_piecewise_npu_graph_decode = False - logger.warning("NPU piecewise graph is disabled for decode server") - if not self.enable_piecewise_cuda_graph: self.disable_cuda_graph = True logger.warning( @@ -2390,12 +2383,6 @@ def _handle_other_validations(self): "Torch compile is disabled because custom ops are not supported" ) - if self.enable_piecewise_npu_graph_decode: - self.enable_piecewise_npu_graph_decode = False - logger.warning( - "Piecewise graph decode is disabled because custom ops are not supported" - ) - if self.enable_torchair_compile: self.enable_torchair_compile = False logger.warning( @@ -3889,11 +3876,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Disable cuda graph.", ) - parser.add_argument( - "--enable-piecewise-npu-graph-decode", - action="store_true", - help="Optimize the model with piecewise npu graph for decode.", - ) parser.add_argument( "--disable-cuda-graph-padding", action="store_true", diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index c993dd4da693..308a5a3cf6df 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1950,7 +1950,6 @@ def get_compiler_backend( mode: str = None, model_runner=None, compilation_config: CompilationConfig = None, - compilation_context=None, ) -> str: if hasattr(torch, "hpu") and torch.hpu.is_available(): return "hpu_backend" @@ -1979,15 +1978,6 @@ def get_compiler_backend( npu_backend = torchair.get_npu_backend(compiler_config=compiler_config) return npu_backend - if compilation_config.compiler == "piecewise": - from sglang.srt.hardware_backend.npu.graph_runner.compilation.piecewise_npu_graph_compiler_backend import ( - PiecewiseNpuGraphCompilerBackend, - ) - - return 
PiecewiseNpuGraphCompilerBackend( - model_runner, compilation_config, compilation_context - ) - if compilation_config.compiler == "npugraph": from sglang.srt.hardware_backend.npu.graph_runner.compilation.npu_graph_compiler_backend import ( NpuGraphCompilerBackend, diff --git a/test/srt/ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py deleted file mode 100644 index e3fdd84969ed..000000000000 --- a/test/srt/ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py +++ /dev/null @@ -1,59 +0,0 @@ -import unittest -from types import SimpleNamespace -from urllib.parse import urlparse - -from sglang.srt.utils import kill_process_tree -from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k -from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, -) - -DEFAULT_MODEL_NAME_FOR_TEST = "Qwen/Qwen2.5-7B-Instruct" - - -class TestAscendNpuPiecewiseGraph(CustomTestCase): - def test_gsm8k(self): - model = DEFAULT_MODEL_NAME_FOR_TEST - base_url = DEFAULT_URL_FOR_TEST - url = urlparse(base_url) - process = popen_launch_server( - model, - base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--attention-backend", - "ascend", - "--mem-fraction-static", - 0.7, - "--enable-piecewise-npu-graph-decode", - "--cuda-graph-bs", - "128", - "--tp-size", - "1", - ], - ) - - try: - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1319, - max_new_tokens=512, - parallel=128, - host=f"http://{url.hostname}", - port=int(url.port), - ) - - metrics = run_eval_few_shot_gsm8k(args) - self.assertGreaterEqual(metrics["accuracy"], 0.62) - self.assertLessEqual(metrics["latency"], 150) - finally: - kill_process_tree(process.pid) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index e62de39b847d..81042f689d0c 100644 --- 
a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -370,7 +370,6 @@ TestFile("ascend/test_ascend_tp1_bf16.py", 400), TestFile("ascend/test_ascend_compile_graph_tp1_bf16.py", 400), TestFile("ascend/test_ascend_npu_graph_compile_tp1_bf16.py", 400), - TestFile("ascend/test_ascend_npu_piecewise_graph_tp1_bf16.py", 400), ], "per-commit-2-npu-a2": [ TestFile("ascend/test_ascend_graph_tp2_bf16.py", 400), From bdc4b433bda117020ca85bffc65e3af78b2d4714 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Wed, 17 Dec 2025 19:42:01 +0300 Subject: [PATCH 60/71] fix import comment --- .../npu/graph_runner/compilation/custom_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py index 1c4dfb5473b1..1ade7dae46f2 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py @@ -1,6 +1,6 @@ from typing import List -import sgl_kernel_npu.norm.split_qkv_rmsnorm_rope +import sgl_kernel_npu.norm.split_qkv_rmsnorm_rope as sgl_kernel_npu import torch @@ -18,7 +18,7 @@ def split_qkv_rmsnorm_rope( q_bias: torch.Tensor, k_bias: torch.Tensor, ) -> List[torch.Tensor]: - q, k, v = sgl_kernel_npu.norm.split_qkv_rmsnorm_rope.split_qkv_rmsnorm_rope( + q, k, v = sgl_kernel_npu.split_qkv_rmsnorm_rope( input, sin, cos, From 499c1851184d84466e8661ccf69fcddcf5ada981 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 18 Dec 2025 12:20:52 +0300 Subject: [PATCH 61/71] refactoring: command line arg renaming --- .../npu/attention/ascend_backend.py | 10 +++++---- .../npu/graph_runner/npu_graph_runner.py | 8 ++++--- .../sglang/srt/hardware_backend/npu/utils.py | 18 +++++++-------- python/sglang/srt/server_args.py | 22 ++++++++++++++----- .../test_ascend_compile_graph_tp1_bf16.py | 2 +- 5 files changed, 38 
insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index 3889dfbee618..f37794f382b0 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -263,8 +263,10 @@ def __init__(self, model_runner: ModelRunner): if self.use_mla: self.ringmla_mask = self.ascend_attn_mask_builder.ringmla_mask - self.enable_torchair_compile = model_runner.server_args.enable_torchair_compile - if self.enable_torchair_compile: + self.enable_npu_torchair_compile = ( + model_runner.server_args.enable_npu_torchair_compile + ) + if self.enable_npu_torchair_compile: max_total_tokens = model_runner.max_total_num_tokens self.max_seqlen_pad = max_total_tokens // model_runner.server_args.page_size @@ -296,7 +298,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) if ( - self.enable_torchair_compile + self.enable_npu_torchair_compile and forward_batch.forward_mode.is_decode_or_idle() ): bs = forward_batch.input_ids.size(0) @@ -1277,7 +1279,7 @@ def forward_decode( topk_indices, ) - if self.graph_mode and (not self.enable_torchair_compile): + if self.graph_mode and (not self.enable_npu_torchair_compile): return self.forward_decode_graph( q, k, diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py index 4c5e37e4edd9..d40797e54850 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py @@ -104,7 +104,9 @@ def __init__(self, model_runner: ModelRunner): model_runner.attn_backend.enable_torch_compile = ( model_runner.server_args.enable_torch_compile ) - self.enable_torchair_compile = model_runner.server_args.enable_torchair_compile + 
self.enable_npu_torchair_compile = ( + model_runner.server_args.enable_npu_torchair_compile + ) super().__init__(model_runner) self.update_attr_name = None @@ -143,7 +145,7 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): compilation_config = get_global_server_args().compilation_config if ( self.enable_torch_compile - and (not self.enable_torchair_compile) + and (not self.enable_npu_torchair_compile) and (not self.compile_bs or bs in self.compile_bs) ): compiler = NpuGraphCompiler( @@ -182,7 +184,7 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): compiled_function(*args) else: - if self.enable_torchair_compile: + if self.enable_npu_torchair_compile: skip_guard_context = torch.compiler.set_stance( skip_guard_eval_unsafe=True ) diff --git a/python/sglang/srt/hardware_backend/npu/utils.py b/python/sglang/srt/hardware_backend/npu/utils.py index a9c3036e9ecf..c0913f592552 100644 --- a/python/sglang/srt/hardware_backend/npu/utils.py +++ b/python/sglang/srt/hardware_backend/npu/utils.py @@ -82,16 +82,16 @@ def set_default_server_args(args: "ServerArgs"): else: args.hicache_mem_layout = "page_first_direct" - if args.enable_torchair_compile and args.enable_torch_compile: + if args.enable_npu_torchair_compile and args.enable_torch_compile: raise ValueError( - "Cannot enable both --enable-torchair-compile and --enable-torch-compile" + "Cannot enable both --enable-npu-torchair-compile and --enable-torch-compile" ) if args.disable_cuda_graph and ( - args.enable_torch_compile or args.enable_torchair_compile + args.enable_torch_compile or args.enable_npu_torchair_compile ): raise ValueError( - f"--enable-torch-compile or --enable-torchair-compile is not appropriate for --disable-cuda-graph" + f"--enable-torch-compile or --enable-npu-torchair-compile is not appropriate for --disable-cuda-graph" ) if args.compilation_config: @@ -103,21 +103,21 @@ def set_default_server_args(args: "ServerArgs"): f"compilation_config.compiler 
'{args.compilation_config.compiler}' is not appropriate for --disable-cuda-graph" ) - if args.enable_torchair_compile: + if args.enable_npu_torchair_compile: raise ValueError( - f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-torchair-compile" + f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --enable-npu-torchair-compile" ) if args.compilation_config.compiler == "npugraph_ex": - args.enable_torchair_compile = True + args.enable_npu_torchair_compile = True if args.disable_cuda_graph: raise ValueError( f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --disable-cuda-graph" ) - if args.enable_torchair_compile: - args.enable_torch_compile = True + if args.enable_npu_torchair_compile: + args.enable_torch_compile = True @_call_once diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index cccf44925a3a..78471621f12c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -552,7 +552,7 @@ class ServerArgs: tbo_token_distribution_threshold: float = 0.48 enable_torch_compile: bool = False enable_piecewise_cuda_graph: bool = False - enable_torchair_compile: bool = False + enable_npu_torchair_compile: bool = False enable_torch_compile_debug_mode: bool = False torch_compile_max_bs: int = 32 piecewise_cuda_graph_max_tokens: int = 4096 @@ -2376,6 +2376,18 @@ def _handle_other_validations(self): self.disable_cuda_graph = True self.skip_server_warmup = True + if not is_npu() and ( + self.enable_npu_torchair_compile + or ( + self.compilation_config is not None + and self.compilation_config.compiler == "npugraph_ex" + ) + ): + self.enable_npu_torchair_compile = False + logger.warning( + "NPU TorchAir compile is disabled, the argument is appropriate for NPU only" + ) + if is_npu() and not supports_custom_op(): if self.enable_torch_compile: self.enable_torch_compile = False @@ -2383,10 +2395,10 
@@ def _handle_other_validations(self): "Torch compile is disabled because custom ops are not supported" ) - if self.enable_torchair_compile: - self.enable_torchair_compile = False + if self.enable_npu_torchair_compile: + self.enable_npu_torchair_compile = False logger.warning( - "TorchAir compile is disabled because custom ops are not supported" + "NPU TorchAir compile is disabled because custom ops are not supported" ) def _handle_remote_instance_weight_loader_start_seed_via_transfer_engine(self): @@ -3988,7 +4000,7 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Enable debug mode for torch compile", ) parser.add_argument( - "--enable-torchair-compile", + "--enable-npu-torchair-compile", action="store_true", help="Optimize the model with Torch Ascend Intermediate Representation compilation. Experimental feature.", ) diff --git a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py index 4ff92ce3e101..bec22888c7aa 100644 --- a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py +++ b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py @@ -39,7 +39,7 @@ def setUpClass(cls): "--attention-backend", "ascend", "--disable-radix-cache", - "--enable-torchair-compile", + "--enable-npu-torchair-compile", "--watchdog-timeout", 30000, ] From 4785e12c5ea491b0dc76cd5b72c7d6c70a3ea7b9 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 18 Dec 2025 16:18:45 +0300 Subject: [PATCH 62/71] refactoring: server args text update --- python/sglang/srt/server_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 78471621f12c..8f359b992aeb 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2385,7 +2385,7 @@ def _handle_other_validations(self): ): self.enable_npu_torchair_compile = False logger.warning( - "NPU TorchAir compile is disabled, the argument is appropriate for NPU 
only" + "The option --enable-npu-torchair-compile is ignored, this option is available for Ascend NPU only" ) if is_npu() and not supports_custom_op(): @@ -4002,7 +4002,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--enable-npu-torchair-compile", action="store_true", - help="Optimize the model with Torch Ascend Intermediate Representation compilation. Experimental feature.", + help="Optimize the model with Torch Ascend Intermediate Representation compilation. This is only available for Ascend NPU. Experimental feature.", ) parser.add_argument( "--enable-piecewise-cuda-graph", From 48095c40fff5b78d14967e6c288801047ef33fbf Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 23 Dec 2025 19:16:15 +0300 Subject: [PATCH 63/71] refactoring: custom ops & CompilationConfig loading movements --- python/sglang/srt/compilation/compilation_config.py | 6 ------ .../{compilation => hardware_backend/npu}/custom_ops.py | 0 .../hardware_backend/npu/graph_runner/npu_graph_runner.py | 2 +- python/sglang/srt/server_args.py | 7 ++++++- 4 files changed, 7 insertions(+), 8 deletions(-) rename python/sglang/srt/{compilation => hardware_backend/npu}/custom_ops.py (100%) diff --git a/python/sglang/srt/compilation/compilation_config.py b/python/sglang/srt/compilation/compilation_config.py index 055bc2bb41da..967855ea3e91 100644 --- a/python/sglang/srt/compilation/compilation_config.py +++ b/python/sglang/srt/compilation/compilation_config.py @@ -1,6 +1,5 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.10.0/vllm/compilation/compilation_config.py -import json from typing import List @@ -33,10 +32,5 @@ def get_traced_files(self): def get_capture_sizes(self): return self.capture_sizes - @classmethod - def from_cli(cls, args) -> "CompilationConfig": - args_dict = json.loads(args) - return CompilationConfig(**args_dict) - def get_enable_debug_mode(self): return self.enable_debug_mode diff --git a/python/sglang/srt/compilation/custom_ops.py 
b/python/sglang/srt/hardware_backend/npu/custom_ops.py similarity index 100% rename from python/sglang/srt/compilation/custom_ops.py rename to python/sglang/srt/hardware_backend/npu/custom_ops.py diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py index d40797e54850..67dd2420ac81 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py @@ -53,7 +53,7 @@ from torch._dynamo.eval_frame import DisableContext -from sglang.srt.compilation.custom_ops import ( +from sglang.srt.hardware_backend.npu.custom_ops import ( _set_dp_buffer_len, _set_is_extend_in_batch, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index bf4e2a44525e..3c491347b33d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2423,6 +2423,11 @@ def _handle_other_validations(self): self.preferred_sampling_params ) + if self.compilation_config: + if isinstance(self.compilation_config, str): + args_dict = json.loads(self.compilation_config) + self.compilation_config = CompilationConfig(**args_dict) + def _handle_two_batch_overlap(self): if self.enable_two_batch_overlap and self.moe_a2a_backend == "none": raise ValueError( @@ -3359,7 +3364,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--compilation-config", - type=CompilationConfig.from_cli, + type=str, default=None, help="Compilation config.", ) From 7fa0424d49422eb1df6f0ead3d8df825ca071984 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 23 Dec 2025 19:43:51 +0300 Subject: [PATCH 64/71] server args quick fix for NPU --- python/sglang/srt/server_args.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3c491347b33d..3b9af02523b7 100644 --- 
a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -782,6 +782,11 @@ def _handle_missing_default_values(self): elif self.speculative_draft_model_quantization == "unquant": self.speculative_draft_model_quantization = None + if self.compilation_config: + if isinstance(self.compilation_config, str): + args_dict = json.loads(self.compilation_config) + self.compilation_config = CompilationConfig(**args_dict) + def _handle_hpu_backends(self): if self.device == "hpu": self.attention_backend = "torch_native" @@ -2423,11 +2428,6 @@ def _handle_other_validations(self): self.preferred_sampling_params ) - if self.compilation_config: - if isinstance(self.compilation_config, str): - args_dict = json.loads(self.compilation_config) - self.compilation_config = CompilationConfig(**args_dict) - def _handle_two_batch_overlap(self): if self.enable_two_batch_overlap and self.moe_a2a_backend == "none": raise ValueError( From d91ea445f42a5e386d258fd4ff95da657d71c3d5 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 30 Dec 2025 17:14:25 +0300 Subject: [PATCH 65/71] linear method fix --- .../npu/quantization/linear_method_npu.py | 11 ++++++++++- test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py index fcbdc3d1e49b..75abca9dba6f 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py @@ -154,7 +154,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module): layer.weight_data = layer.weight.data layer.weight_scale.data = torch.flatten(layer.weight_scale.data) + layer.weight_scale_data = layer.weight_scale.data layer.weight_offset.data = torch.flatten(layer.weight_offset.data) + layer.weight_offset_data = 
layer.weight_offset.data expanding_factor = layer.weight.data.shape[0] layer.aclnn_input_scale = torch.nn.Parameter( @@ -220,8 +222,13 @@ def create_weights( ) layer.register_parameter("weight_scale", weight_scale) + weight_offset_data = torch.empty( + (output_size_per_partition, 1), dtype=params_dtype + ) + layer.__dict__["weight_offset_data"] = weight_offset_data + weight_offset = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), dtype=params_dtype), + data=weight_offset_data, output_dim=0, weight_loader=weight_loader, ) @@ -250,4 +257,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module): layer.weight_data = layer.weight.data layer.weight_scale.data = layer.weight_scale.data.flatten() + layer.weight_scale_data = layer.weight_scale.data layer.weight_offset.data = layer.weight_offset.data.flatten() + layer.weight_offset_data = layer.weight_offset.data diff --git a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py index bec22888c7aa..deaf4a5e0387 100644 --- a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py +++ b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py @@ -39,7 +39,7 @@ def setUpClass(cls): "--attention-backend", "ascend", "--disable-radix-cache", - "--enable-npu-torchair-compile", + "--enable-torch-compile", "--watchdog-timeout", 30000, ] From c5e8ba0cbcc2994d6244f04f6cc43643b7b1ceb4 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Mon, 12 Jan 2026 14:47:14 +0300 Subject: [PATCH 66/71] main merge fix --- .../npu/graph_runner/npu_graph_runner.py | 3 +-- python/sglang/srt/models/qwen3.py | 4 ++-- python/sglang/srt/server_args.py | 14 -------------- 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py index 67dd2420ac81..62cd16dc7a17 100644 --- 
a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py @@ -37,7 +37,6 @@ get_bool_env_var, get_compiler_backend, is_npu, - supports_custom_op, ) is_npu = is_npu() @@ -131,7 +130,7 @@ def _create_device_graph(self): def _init_dp_gathered_buffer( self, global_dp_buffer_len: int, local_dp_buffer_len: int, dp_max_padding: bool ): - if supports_custom_op() and get_global_server_args().enable_torch_compile: + if get_global_server_args().enable_torch_compile: _set_dp_buffer_len( global_dp_buffer_len, local_dp_buffer_len, dp_max_padding ) diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 8e05ddf24a02..5a7ec53195aa 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -30,7 +30,7 @@ from sglang.srt.models.qwen2 import Qwen2Model from sglang.srt.models.utils import apply_qk_norm from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import add_prefix, is_cuda, is_npu, supports_custom_op +from sglang.srt.utils import add_prefix, is_cuda, is_npu Qwen3Config = None @@ -39,7 +39,7 @@ _is_npu = is_npu() if _is_npu: - if supports_custom_op() and get_global_server_args().enable_torch_compile: + if get_global_server_args().enable_torch_compile: from sglang.srt.hardware_backend.npu.cmo import get_weight_cache from sglang.srt.hardware_backend.npu.cmo_custom_ops import ( get_cmo_stream, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 616787baf1cd..2ab2f2dbb9c7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -33,7 +33,6 @@ from sglang.srt.layers.attention.fla.chunk_delta_h import CHUNK_SIZE as FLA_CHUNK_SIZE from sglang.srt.lora.lora_registry import LoRARef from sglang.srt.parser.reasoning_parser import ReasoningParser -from sglang.srt.utils import supports_custom_op from sglang.srt.utils.common import ( 
LORA_TARGET_ALL_MODULES, SUPPORTED_LORA_TARGET_MODULES, @@ -2535,19 +2534,6 @@ def _handle_other_validations(self): "The option --enable-npu-torchair-compile is ignored, this option is available for Ascend NPU only" ) - if is_npu() and not supports_custom_op(): - if self.enable_torch_compile: - self.enable_torch_compile = False - logger.warning( - "Torch compile is disabled because custom ops are not supported" - ) - - if self.enable_npu_torchair_compile: - self.enable_npu_torchair_compile = False - logger.warning( - "NPU TorchAir compile is disabled because custom ops are not supported" - ) - # Validate limit_mm_per_prompt modalities if self.limit_mm_data_per_request: if isinstance(self.limit_mm_data_per_request, str): From 145f25261eeefc0e1fc53b63d300e0c779a9c4cf Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Mon, 12 Jan 2026 15:15:57 +0300 Subject: [PATCH 67/71] Comments: compilation config arg documantation was extened --- python/sglang/srt/server_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2ab2f2dbb9c7..20ce884c2e37 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -3531,7 +3531,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--compilation-config", type=str, default=None, - help="Compilation config.", + help="Represents JSON serialized instance of 'CompilationConfig' class to provide compilation details.", ) # Speculative decoding From 8db134aaaf63f71dd86e59745a269db874a40404 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 13 Jan 2026 17:11:56 +0300 Subject: [PATCH 68/71] tests & NPUGraphRunner fix --- .../srt/hardware_backend/npu/graph_runner/npu_graph_runner.py | 1 + test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py 
b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py index 62cd16dc7a17..6640e7c656b3 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py @@ -146,6 +146,7 @@ def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): self.enable_torch_compile and (not self.enable_npu_torchair_compile) and (not self.compile_bs or bs in self.compile_bs) + and (bs >= get_attention_tp_size()) ): compiler = NpuGraphCompiler( model_runner=self.model_runner, diff --git a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py index deaf4a5e0387..bec22888c7aa 100644 --- a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py +++ b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py @@ -39,7 +39,7 @@ def setUpClass(cls): "--attention-backend", "ascend", "--disable-radix-cache", - "--enable-torch-compile", + "--enable-npu-torchair-compile", "--watchdog-timeout", 30000, ] From ac81122d7f41ea5c4168c5fc109ded2d0505fb19 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Wed, 14 Jan 2026 14:28:14 +0300 Subject: [PATCH 69/71] tests improvements: bs is not defined & both options are possible --- python/sglang/srt/hardware_backend/npu/utils.py | 5 ----- test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py | 4 ---- 2 files changed, 9 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/utils.py b/python/sglang/srt/hardware_backend/npu/utils.py index c0913f592552..99b8e12fe9d1 100644 --- a/python/sglang/srt/hardware_backend/npu/utils.py +++ b/python/sglang/srt/hardware_backend/npu/utils.py @@ -82,11 +82,6 @@ def set_default_server_args(args: "ServerArgs"): else: args.hicache_mem_layout = "page_first_direct" - if args.enable_npu_torchair_compile and args.enable_torch_compile: - raise ValueError( - "Cannot enable both --enable-npu-torchair-compile and --enable-torch-compile" - ) - if 
args.disable_cuda_graph and ( args.enable_torch_compile or args.enable_npu_torchair_compile ): diff --git a/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py b/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py index 03dd0cdf80f7..cba52a7d1aa5 100644 --- a/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py +++ b/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py @@ -30,10 +30,6 @@ def test_gsm8k(self): "--mem-fraction-static", 0.7, "--enable-torch-compile", - "--cuda-graph-bs", - "128", - "--tp-size", - "1", ], ) From a61b8d6f478599b05c92bb17ec45608a00e229a3 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Thu, 15 Jan 2026 09:16:03 +0300 Subject: [PATCH 70/71] merge fix --- .../graph_runner/compilation/patch_dynamo.py | 2 +- .../npu/quantization/linear_method_npu.py | 42 +++++++++++++++---- .../modelslim/schemes/modelslim_w8a8_int8.py | 31 ++++++++++---- .../test_ascend_npu_graph_compile_tp1_bf16.py | 2 + test/srt/test_embed_interpolate_unittest.py | 5 ++- 5 files changed, 64 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py index 284582f86011..7fcf14713c54 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py @@ -29,7 +29,7 @@ def patch_dynamo_context(): original_disable = None -def decorators_disable(fn=None, recursive=True): +def decorators_disable(fn=None, recursive=True, **kwargs): if recursive: if fn is not None: fn = innermost_fn(fn) diff --git a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py index 3a99f6ac7c3b..635f606c53a5 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py +++ 
b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py @@ -3,7 +3,9 @@ import torch from sglang.srt.hardware_backend.npu.utils import npu_format_cast +from sglang.srt.layers.linear import MergedColumnParallelLinear, QKVParallelLinear from sglang.srt.layers.quantization.base_config import LinearMethodBase +from sglang.srt.server_args import get_global_server_args if TYPE_CHECKING: from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -23,21 +25,33 @@ class NPUW8A8Int8LinearMethod(_NPULinearMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module): layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = npu_format_cast(layer.weight.data) + layer.weight_data = layer.weight.data layer.weight_scale.data = layer.weight_scale.data.flatten() + layer.weight_scale_data = layer.weight_scale.data # Compressed-tensors format doesn't have this field if hasattr(layer, "weight_offset"): layer.weight_offset.data = layer.weight_offset.data.flatten() + layer.weight_offset_data = layer.weight_offset.data expanding_factor = layer.weight.data.shape[0] layer.aclnn_input_scale = torch.nn.Parameter( layer.input_scale.data.repeat(expanding_factor).to(device="npu"), requires_grad=False, ) - layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter( - layer.input_scale.data.repeat(expanding_factor).to(device="npu"), - requires_grad=False, - ) + prev_layer_fuse_reciprocal = isinstance( + layer, MergedColumnParallelLinear + ) or isinstance(layer, QKVParallelLinear) + if get_global_server_args().enable_torch_compile and prev_layer_fuse_reciprocal: + layer.aclnn_input_scale_reciprocal = torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor).to(device="npu"), + requires_grad=False, + ) + else: + layer.aclnn_input_scale_reciprocal = 1.0 / torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor).to(device="npu"), + requires_grad=False, + ) layer.aclnn_input_offset = 
torch.nn.Parameter( layer.input_offset.data.repeat(expanding_factor).to(device="npu"), requires_grad=False, @@ -53,9 +67,16 @@ def apply( original_dtype = x.dtype if original_dtype != torch.int8: + aclnn_input_scale_reciprocal = layer.aclnn_input_scale_reciprocal + if get_global_server_args().enable_torch_compile and ( + isinstance(layer, MergedColumnParallelLinear) + or isinstance(layer, QKVParallelLinear) + ): + aclnn_input_scale_reciprocal = 1.0 / aclnn_input_scale_reciprocal + x = torch.ops.npu.npu_quantize( x, - layer.aclnn_input_scale_reciprocal, + aclnn_input_scale_reciprocal, layer.aclnn_input_offset, torch.qint8, -1, @@ -66,13 +87,13 @@ def apply( if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0: quant_bias = None else: - quant_bias = layer.quant_bias + quant_bias = layer.quant_bias_data return torch.ops.npu.npu_quant_matmul( x, - layer.weight, - layer.deq_scale, + layer.weight_data, + layer.deq_scale_data, bias=quant_bias, - output_dtype=original_dtype, + output_dtype=layer.params_dtype, ) @@ -81,11 +102,14 @@ class NPUW8A8Int8DynamicLinearMethod(_NPULinearMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module): layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = npu_format_cast(layer.weight.data) + layer.weight_data = layer.weight.data layer.weight_scale.data = layer.weight_scale.data.flatten() + layer.weight_scale_data = layer.weight_scale.data # Compressed-tensors format doesn't have this field if hasattr(layer, "weight_offset"): layer.weight_offset.data = layer.weight_offset.data.flatten() + layer.weight_offset_data = layer.weight_offset.data def apply( self, diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8.py index 16c62d551fa3..f5b7fe698121 100644 --- a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8.py +++ 
b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8.py @@ -46,25 +46,35 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") output_size_per_partition = sum(output_partition_sizes) + weight_data = torch.empty( + (output_size_per_partition, input_size_per_partition), dtype=torch.int8 + ) + layer.__dict__["weight_data"] = weight_data weight = ModelWeightParameter( - data=torch.empty( - (output_size_per_partition, input_size_per_partition), dtype=torch.int8 - ), + data=weight_data, input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) + weight_scale_data = torch.empty( + (output_size_per_partition, 1), dtype=params_dtype + ) + layer.__dict__["weight_scale_data"] = weight_scale_data weight_scale = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), dtype=params_dtype), + data=weight_scale_data, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight_scale", weight_scale) + weight_offset_data = torch.empty( + (output_size_per_partition, 1), dtype=params_dtype + ) + layer.__dict__["weight_offset_data"] = weight_offset_data weight_offset = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), dtype=params_dtype), + data=weight_offset_data, output_dim=0, weight_loader=weight_loader, ) @@ -85,8 +95,10 @@ def create_weights( input_offset.ignore_warning = True layer.register_parameter("input_offset", input_offset) + quant_bias_data = torch.empty(output_size_per_partition, dtype=torch.int32) + layer.__dict__["quant_bias_data"] = quant_bias_data quant_bias = ChannelQuantScaleParameter( - data=torch.empty(output_size_per_partition, dtype=torch.int32), + data=quant_bias_data, output_dim=0, weight_loader=weight_loader, ) @@ -98,8 +110,13 @@ def create_weights( deq_scale_dtype = torch.int64 else: raise ValueError(f"Unsupported params_dtype: {params_dtype}") + + deq_scale_data = torch.empty( + 
output_size_per_partition, dtype=deq_scale_dtype + ) + layer.__dict__["deq_scale_data"] = deq_scale_data deq_scale = ChannelQuantScaleParameter( - data=torch.empty(output_size_per_partition, dtype=deq_scale_dtype), + data=deq_scale_data, output_dim=0, weight_loader=weight_loader, ) diff --git a/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py b/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py index cba52a7d1aa5..c726a31dd58d 100644 --- a/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py +++ b/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py @@ -30,6 +30,8 @@ def test_gsm8k(self): "--mem-fraction-static", 0.7, "--enable-torch-compile", + "--watchdog-timeout", + 30000, ], ) diff --git a/test/srt/test_embed_interpolate_unittest.py b/test/srt/test_embed_interpolate_unittest.py index cb09935bcc7d..b1848ba3ae2f 100644 --- a/test/srt/test_embed_interpolate_unittest.py +++ b/test/srt/test_embed_interpolate_unittest.py @@ -12,7 +12,6 @@ LinearMethodBase, UnquantizedLinearMethod, ) -from sglang.srt.models.qwen3_vl import Qwen3VLMoeVisionModel from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler @@ -65,6 +64,10 @@ def test_embed_interpolate(self): server_args=sarg, model_config=mconf, ) + + # in real pipeline, a model is imported after command line argument parsing + from sglang.srt.models.qwen3_vl import Qwen3VLMoeVisionModel + model = Qwen3VLMoeVisionModel( mconf, quant_config=None, From 467b543e22180c8d6567b8c0612a45e88cf62964 Mon Sep 17 00:00:00 2001 From: Edward Shogulin Date: Tue, 20 Jan 2026 16:46:40 +0300 Subject: [PATCH 71/71] comments: command line arg validation & custom ops --- .../srt/compilation/compilation_config.py | 2 +- .../hardware_backend/npu/cmo_custom_ops.py | 19 +++----- .../graph_runner/compilation/custom_ops.py | 45 +++++++++++-------- .../npu/graph_runner/compilation/passes.py | 18 ++++---- .../sglang/srt/hardware_backend/npu/utils.py | 5 +++ python/sglang/srt/layers/communicator.py | 
4 +- 6 files changed, 48 insertions(+), 45 deletions(-) diff --git a/python/sglang/srt/compilation/compilation_config.py b/python/sglang/srt/compilation/compilation_config.py index b574d19e489e..dd9dfafe8de0 100644 --- a/python/sglang/srt/compilation/compilation_config.py +++ b/python/sglang/srt/compilation/compilation_config.py @@ -18,7 +18,7 @@ def decorator(op_func: Callable): class CompilationConfig: def __init__( self, - capture_sizes: List[int] = [], + capture_sizes: List[int] = None, compiler: str = "eager", enable_debug_mode: bool = False, ): diff --git a/python/sglang/srt/hardware_backend/npu/cmo_custom_ops.py b/python/sglang/srt/hardware_backend/npu/cmo_custom_ops.py index 0331ce4d7b2f..a3f7d0c82c06 100644 --- a/python/sglang/srt/hardware_backend/npu/cmo_custom_ops.py +++ b/python/sglang/srt/hardware_backend/npu/cmo_custom_ops.py @@ -3,7 +3,7 @@ import torch import sglang.srt.hardware_backend.npu.cmo -from sglang.srt.utils import direct_register_custom_op +from sglang.srt.utils.custom_op import register_custom_op @torch.library.custom_op("sglang::wait_cmo_stream", mutates_args=()) @@ -21,19 +21,10 @@ def get_cmo_stream() -> bool: return True -def prepare_weight_cache(handle: torch.Tensor, cache: List[torch.Tensor]) -> None: - sglang.srt.hardware_backend.npu.cmo.prepare_weight_cache(handle, cache) - - -def prepare_weight_cache_register_fake( - handle: torch.Tensor, cache: List[torch.Tensor] -) -> None: +def prepare_weight_cache_fake(handle: torch.Tensor, cache: List[torch.Tensor]) -> None: pass -direct_register_custom_op( - op_name="prepare_weight_cache", - op_func=prepare_weight_cache, - mutates_args=["handle"], - fake_impl=prepare_weight_cache_register_fake, -) +@register_custom_op(fake_impl=prepare_weight_cache_fake) +def prepare_weight_cache(handle: torch.Tensor, cache: List[torch.Tensor]) -> None: + sglang.srt.hardware_backend.npu.cmo.prepare_weight_cache(handle, cache) diff --git 
a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py index 1ade7dae46f2..f85efe08ba87 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py @@ -3,9 +3,10 @@ import sgl_kernel_npu.norm.split_qkv_rmsnorm_rope as sgl_kernel_npu import torch +from sglang.srt.utils.custom_op import register_custom_op -@torch.library.custom_op("sglang::split_qkv_rmsnorm_rope", mutates_args=()) -def split_qkv_rmsnorm_rope( + +def split_qkv_rmsnorm_rope_fake( input: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor, @@ -18,23 +19,20 @@ def split_qkv_rmsnorm_rope( q_bias: torch.Tensor, k_bias: torch.Tensor, ) -> List[torch.Tensor]: - q, k, v = sgl_kernel_npu.split_qkv_rmsnorm_rope( - input, - sin, - cos, - q_weight, - k_weight, - q_hidden_size, - kv_hiddem_size, - head_dim, - eps, - q_bias, - k_bias, + # TODO: generalize shape + q = torch.empty( + (input.shape[0], q_hidden_size), dtype=input.dtype, device=input.device + ) + k = torch.empty( + (input.shape[0], kv_hiddem_size), dtype=input.dtype, device=input.device + ) + v = torch.empty( + (input.shape[0], kv_hiddem_size), dtype=input.dtype, device=input.device ) return [q, k, v] -@split_qkv_rmsnorm_rope.register_fake +@register_custom_op(fake_impl=split_qkv_rmsnorm_rope_fake) def split_qkv_rmsnorm_rope( input: torch.Tensor, sin: torch.Tensor, @@ -48,8 +46,17 @@ def split_qkv_rmsnorm_rope( q_bias: torch.Tensor, k_bias: torch.Tensor, ) -> List[torch.Tensor]: - # TODO: generalize shape - q = torch.empty((128, 2048), dtype=input.dtype, device=input.device) - k = torch.empty((128, 256), dtype=input.dtype, device=input.device) - v = torch.empty((128, 256), dtype=input.dtype, device=input.device) + q, k, v = sgl_kernel_npu.split_qkv_rmsnorm_rope( + input, + sin, + cos, + q_weight, + k_weight, + q_hidden_size, + 
kv_hiddem_size, + head_dim, + eps, + q_bias, + k_bias, + ) return [q, k, v] diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes.py index 3c9efb93a6d0..c033cd2df9d7 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes.py @@ -14,7 +14,9 @@ import torch -import sglang.srt.hardware_backend.npu.graph_runner.compilation.custom_ops # noqa +from sglang.srt.hardware_backend.npu.graph_runner.compilation.custom_ops import ( + split_qkv_rmsnorm_rope, +) class DivFuse: @@ -37,19 +39,19 @@ def __call__(self, graph_module: torch.fx.GraphModule): for node in list(module.graph.nodes): if node.type == torch.nn.parameter.Parameter: continue - if node.target == "copy_": + + node_target_str = str(node.target) + + if node_target_str == "copy_": copy_node = node prepare_weight_cache_default_node = None continue - if ( - copy_node - and node.target == torch.ops.sglang.prepare_weight_cache.default - ): + if copy_node and node_target_str == "sglang.prepare_weight_cache": prepare_weight_cache_default_node = node continue - if copy_node and node.target == torch.ops.npu.npu_add_rms_norm_quant: + if copy_node and node_target_str == "npu.npu_add_rms_norm_quant": arg = copy_node.args[1] if prepare_weight_cache_default_node is not None: @@ -216,7 +218,7 @@ def replacement( sin_view = sin.view(-1, 1, 1, self.head_dim) sin_contiguous = sin_view.contiguous() - split_qkv_rmsnorm_rope_default = torch.ops.sglang.split_qkv_rmsnorm_rope( + split_qkv_rmsnorm_rope_default = split_qkv_rmsnorm_rope( output_parallel, sin_contiguous, cos_contiguous, diff --git a/python/sglang/srt/hardware_backend/npu/utils.py b/python/sglang/srt/hardware_backend/npu/utils.py index 99b8e12fe9d1..2f0479ae2904 100644 --- a/python/sglang/srt/hardware_backend/npu/utils.py +++ 
b/python/sglang/srt/hardware_backend/npu/utils.py @@ -90,6 +90,11 @@ def set_default_server_args(args: "ServerArgs"): ) if args.compilation_config: + if not args.enable_torch_compile and not args.enable_npu_torchair_compile: + raise ValueError( + f"--compilation-config must be used only with --enable-torch-compile or --enable-npu-torchair-compile" + ) + if args.compilation_config.compiler == "npugraph": args.enable_torch_compile = True diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 3543bcbdf573..d032ce31280a 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -811,9 +811,7 @@ def _gather_hidden_states_and_residual( else: hidden_states = tensor_model_parallel_all_reduce(hidden_states) if _is_npu and context.cache is not None: - _ = torch.ops.sglang.prepare_weight_cache( - hidden_states, context.cache - ) + _ = prepare_weight_cache(hidden_states, context.cache) hidden_states, residual = layernorm(hidden_states, residual) return hidden_states, residual