diff --git a/docs/platforms/ascend_npu_pass_development.md b/docs/platforms/ascend_npu_pass_development.md new file mode 100644 index 000000000000..355896b0a7d6 --- /dev/null +++ b/docs/platforms/ascend_npu_pass_development.md @@ -0,0 +1,30 @@ +## How to transform model instances with PyTorch FX Toolkit in SGLang for NPU + +### PassManager +`PassManager` is implemented here: [PassManager](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py) + + +You can explore `PassManager` usage in the [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) compiler backend. + +### Pass development +There are two approaches to developing passes for the SGLang NPU PassManager: + +1. Match all possible non-overlapping sets of operators and their data dependencies with the `torch.fx.replace_pattern` API. +Pass example: [NpuAddRmsNormQuantFuse](https://github.com/eshoguli/sglang/blob/3365d711fd5aa0d6191c32769163320fe41e27f2/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py#L82). +You can find details on the official FX toolkit website: https://docs.pytorch.org/docs/stable/fx.html#subgraph-rewriting-with-replace-pattern + +2. Direct Graph Manipulation. +Pass example: [EraseCopy](https://github.com/eshoguli/sglang/blob/3365d711fd5aa0d6191c32769163320fe41e27f2/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py#L28). 
+You can find details on the official FX toolkit website: https://docs.pytorch.org/docs/stable/fx.html#direct-graph-manipulation + +### Compiler backend update +After pass development you should create a `PassManager` instance, add the pass, and call the `apply` method: +``` +def apply_passes(self, graph_module: torch.fx.GraphModule): + passManager = PassManager(graph_module) + passManager.add(NpuAddRmsNormQuantFuse) + passManager.apply() + graph_module.recompile() +``` + +You can explore [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) as an example. diff --git a/docs/platforms/ascend_npu_support.rst b/docs/platforms/ascend_npu_support.rst index 26306c7c1aad..11da03d85969 100644 --- a/docs/platforms/ascend_npu_support.rst +++ b/docs/platforms/ascend_npu_support.rst @@ -8,6 +8,7 @@ Ascend NPUs ascend_npu_support_models.md ascend_npu_support_features.md ascend_npu_deepseek_example.md + ascend_npu_pass_development.md ascend_npu_qwen3_examples.md ascend_contribution_guide.md ascend_npu_best_practice.md diff --git a/python/sglang/srt/compilation/compilation_config.py b/python/sglang/srt/compilation/compilation_config.py index 0388bbedac06..dd9dfafe8de0 100644 --- a/python/sglang/srt/compilation/compilation_config.py +++ b/python/sglang/srt/compilation/compilation_config.py @@ -18,7 +18,7 @@ def decorator(op_func: Callable): class CompilationConfig: def __init__( self, - capture_sizes: List[int], + capture_sizes: List[int] = None, compiler: str = "eager", enable_debug_mode: bool = False, ): diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py index fa82245ed31e..7eb7b3562016 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_backend.py @@ -222,7 +222,6 @@ def
__init__(self, model_runner: ModelRunner): self.req_to_token = model_runner.req_to_token_pool.req_to_token self.graph_mode = False self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False") - self.enable_torch_compile = model_runner.server_args.enable_torch_compile self.speculative_num_draft_tokens = ( model_runner.server_args.speculative_num_draft_tokens ) @@ -238,6 +237,13 @@ def __init__(self, model_runner: ModelRunner): if self.use_mla: self.ringmla_mask = self.ascend_attn_mask_builder.ringmla_mask + self.enable_npu_torchair_compile = ( + model_runner.server_args.enable_npu_torchair_compile + ) + if self.enable_npu_torchair_compile: + max_total_tokens = model_runner.max_total_num_tokens + self.max_seqlen_pad = max_total_tokens // model_runner.server_args.page_size + def get_verify_buffers_to_fill_after_draft(self): """ Return buffers for verify attention kernels that needs to be filled after draft. @@ -257,12 +263,29 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): seq_lens_max = forward_batch.seq_lens.max() if forward_batch.forward_mode.is_target_verify(): seq_lens_max += self.speculative_num_draft_tokens - self.forward_metadata.block_tables = ( + + block_tables = ( forward_batch.req_to_token_pool.req_to_token[ forward_batch.req_pool_indices, :seq_lens_max ][:, :: self.page_size] // self.page_size ) + + if ( + self.enable_npu_torchair_compile + and forward_batch.forward_mode.is_decode_or_idle() + ): + bs = forward_batch.input_ids.size(0) + device = forward_batch.input_ids.device + self.forward_metadata.block_tables = torch.full( + (bs, self.max_seqlen_pad), -1, dtype=torch.int32, device=device + ) + self.forward_metadata.block_tables[:, : block_tables.size(1)].copy_( + block_tables + ) + else: + self.forward_metadata.block_tables = block_tables + if forward_batch.extend_seq_lens is not None: self.forward_metadata.extend_seq_lens = forward_batch.extend_seq_lens self.forward_metadata.extend_seq_lens_cpu_int = ( @@ -1283,7 +1306,7 @@ def 
forward_decode( topk_indices, ) - if self.graph_mode and (not self.enable_torch_compile): + if self.graph_mode and (not self.enable_npu_torchair_compile): return self.forward_decode_graph( q, k, diff --git a/python/sglang/srt/hardware_backend/npu/cmo.py b/python/sglang/srt/hardware_backend/npu/cmo.py index 40f3b4f1696b..7b8d117735bc 100644 --- a/python/sglang/srt/hardware_backend/npu/cmo.py +++ b/python/sglang/srt/hardware_backend/npu/cmo.py @@ -1,5 +1,7 @@ import torch +from sglang.srt.layers.parameter import ModelWeightParameter + cmo_stream = None @@ -18,6 +20,12 @@ def set_cmo_stream(stream): cmo_stream = stream +def get_weight_cache(layer): + if isinstance(layer.weight, ModelWeightParameter): + return layer.weight_data + return layer.weight + + def prepare_weight_cache(handle, cache, PREFETCH_MAX_SIZE=1000000000): """ PREFETCH_MAX_SIZE: maximum size (bytes) for each prefetch operation. diff --git a/python/sglang/srt/hardware_backend/npu/cmo_custom_ops.py b/python/sglang/srt/hardware_backend/npu/cmo_custom_ops.py new file mode 100644 index 000000000000..a3f7d0c82c06 --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/cmo_custom_ops.py @@ -0,0 +1,30 @@ +from typing import List + +import torch + +import sglang.srt.hardware_backend.npu.cmo +from sglang.srt.utils.custom_op import register_custom_op + + +@torch.library.custom_op("sglang::wait_cmo_stream", mutates_args=()) +def wait_cmo_stream() -> None: + if sglang.srt.hardware_backend.npu.cmo.get_cmo_stream(): + sglang.srt.hardware_backend.npu.cmo.wait_cmo_stream() + + +@wait_cmo_stream.register_fake +def wait_cmo_stream_fake() -> None: + pass + + +def get_cmo_stream() -> bool: + return True + + +def prepare_weight_cache_fake(handle: torch.Tensor, cache: List[torch.Tensor]) -> None: + pass + + +@register_custom_op(fake_impl=prepare_weight_cache_fake) +def prepare_weight_cache(handle: torch.Tensor, cache: List[torch.Tensor]) -> None: + sglang.srt.hardware_backend.npu.cmo.prepare_weight_cache(handle, cache) 
diff --git a/python/sglang/srt/hardware_backend/npu/custom_ops.py b/python/sglang/srt/hardware_backend/npu/custom_ops.py new file mode 100644 index 000000000000..21ac61f22f66 --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/custom_ops.py @@ -0,0 +1,52 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import List, Optional + +import torch + +import sglang.srt.layers.dp_attention + + +@torch.library.custom_op("sglang::_set_dp_buffer_len", mutates_args=()) +def _set_dp_buffer_len( + global_dp_buffer_len: Optional[int], + num_tokens: Optional[int], + is_max_len: bool, + global_num_tokens: Optional[List[int]] = None, +) -> None: + global set_dp_buffer_len_original + sglang.srt.layers.dp_attention.set_dp_buffer_len( + global_dp_buffer_len, num_tokens, is_max_len, global_num_tokens + ) + + +@_set_dp_buffer_len.register_fake +def _set_dp_buffer_len_fake( + global_dp_buffer_len: Optional[int], + num_tokens: Optional[int], + is_max_len: bool, + global_num_tokens: Optional[List[int]] = None, +) -> None: + pass + + +@torch.library.custom_op("sglang::_set_is_extend_in_batch", mutates_args=()) +def _set_is_extend_in_batch(is_extend_in_batch: bool) -> None: + sglang.srt.layers.dp_attention.set_is_extend_in_batch(is_extend_in_batch) + + +@_set_is_extend_in_batch.register_fake +def _set_is_extend_in_batch_fake(is_extend_in_batch: bool) -> None: + pass 
diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/compilation_context.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/compilation_context.py new file mode 100644 index 000000000000..11a01cb5c877 --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/compilation_context.py @@ -0,0 +1,20 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import torch_npu + + +class CompilationContext: + graph_memory_pool = None + stream: torch_npu.npu.Stream = None diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py new file mode 100644 index 000000000000..f85efe08ba87 --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py @@ -0,0 +1,62 @@ +from typing import List + +import sgl_kernel_npu.norm.split_qkv_rmsnorm_rope as sgl_kernel_npu +import torch + +from sglang.srt.utils.custom_op import register_custom_op + + +def split_qkv_rmsnorm_rope_fake( + input: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_hidden_size: int, + kv_hiddem_size: int, + head_dim: int, + eps: float, + q_bias: torch.Tensor, + k_bias: torch.Tensor, +) -> List[torch.Tensor]: + # TODO: generalize shape + q = 
torch.empty( + (input.shape[0], q_hidden_size), dtype=input.dtype, device=input.device + ) + k = torch.empty( + (input.shape[0], kv_hiddem_size), dtype=input.dtype, device=input.device + ) + v = torch.empty( + (input.shape[0], kv_hiddem_size), dtype=input.dtype, device=input.device + ) + return [q, k, v] + + +@register_custom_op(fake_impl=split_qkv_rmsnorm_rope_fake) +def split_qkv_rmsnorm_rope( + input: torch.Tensor, + sin: torch.Tensor, + cos: torch.Tensor, + q_weight: torch.Tensor, + k_weight: torch.Tensor, + q_hidden_size: int, + kv_hiddem_size: int, + head_dim: int, + eps: float, + q_bias: torch.Tensor, + k_bias: torch.Tensor, +) -> List[torch.Tensor]: + q, k, v = sgl_kernel_npu.split_qkv_rmsnorm_rope( + input, + sin, + cos, + q_weight, + k_weight, + q_hidden_size, + kv_hiddem_size, + head_dim, + eps, + q_bias, + k_bias, + ) + return [q, k, v] diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler.py new file mode 100644 index 000000000000..f28059f56a82 --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler.py @@ -0,0 +1,41 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import torch + +from sglang.srt.compilation.compilation_config import CompilationConfig +from sglang.srt.utils.common import get_compiler_backend + + +class NpuGraphCompiler: + def __init__( + self, + model_runner, + model: torch.nn.Module, + compilation_config: CompilationConfig, + batch_size: int, + ): + torch._dynamo.reset() + + if compilation_config is None: + compilation_config = CompilationConfig(compiler="npugraph") + + backend = get_compiler_backend( + compilation_config=compilation_config, model_runner=model_runner + ) + backend.init(model_runner.model_config, batch_size) + + self.compiled_callable = torch.compile( + model, fullgraph=True, dynamic=False, backend=backend + ) diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py new file mode 100644 index 000000000000..9e9c12214c68 --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py @@ -0,0 +1,87 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from typing import Callable + +import torch +from torch._dynamo.eval_frame import DisableContext + +from sglang.srt.hardware_backend.npu.graph_runner.compilation.pass_manager import ( + PassManager, +) +from sglang.srt.hardware_backend.npu.graph_runner.compilation.passes import ( + DivFuse, + EraseCopy, + NpuAddRmsNormDynamicQuantFuse, + NpuAddRmsNormQuantFuse, + SplitQkvRmsnormRopeFuse, +) +from sglang.srt.layers.dp_attention import get_attention_tp_size + + +class NpuGraphCompilerBackend: + def __init__(self, model_runner): + self.model_type = model_runner.model_config.dtype + + def __call__(self, graph: torch.fx.GraphModule, example_inputs) -> Callable: + DisableContext.compiled_function_args[DisableContext.batch_size] = ( + example_inputs + ) + if self.model_type == torch.bfloat16: + self.apply_passes(graph) + return graph + + def init(self, config, batch_size: int): + config = config.hf_config + + hidden_size = config.hidden_size + + num_heads = config.num_attention_heads + num_kv_heads = config.num_key_value_heads + + head_dim = getattr(config, "head_dim", None) + self.rms_norm_eps = config.rms_norm_eps + + total_num_heads = num_heads + attn_tp_size = get_attention_tp_size() + + assert total_num_heads % attn_tp_size == 0 + num_heads = total_num_heads // attn_tp_size + total_num_kv_heads = num_kv_heads + num_kv_heads = max(1, total_num_kv_heads // attn_tp_size) + + self.head_dim = head_dim or hidden_size // total_num_heads + self.q_size = num_heads * self.head_dim + self.kv_size = num_kv_heads * self.head_dim + + self.q_shape = (batch_size, self.q_size) + self.k_shape = (batch_size, self.kv_size) + + def apply_passes(self, graph_module: torch.fx.GraphModule): + passManager = PassManager(graph_module) + passManager.add( + SplitQkvRmsnormRopeFuse, + q_size=self.q_size, + kv_size=self.kv_size, + head_dim=self.head_dim, + q_shape=self.q_shape, + k_shape=self.k_shape, + 
variance_epsilon=self.rms_norm_eps, + ) + passManager.add(NpuAddRmsNormQuantFuse) + passManager.add(NpuAddRmsNormDynamicQuantFuse) + passManager.add(DivFuse) + passManager.add(EraseCopy) + passManager.apply() + graph_module.recompile() diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py new file mode 100644 index 000000000000..ab1ecd806dae --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py @@ -0,0 +1,56 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import logging + +import torch + +logger = logging.getLogger(__name__) + + +class PassManager: + def __init__(self, graph_module: torch.fx.GraphModule): + self.graph_module = graph_module + self.passes = [] + + def add(self, pass_, **kwargs): + self.passes.append((pass_, kwargs)) + + def apply(self): + updated = False + for pass_, kwargs in self.passes: + pass_instance = pass_(**kwargs) + results = [] + try: + if callable(pass_instance): + results = pass_instance(self.graph_module) + else: + results = torch.fx.replace_pattern( + self.graph_module, pass_.pattern, pass_.replacement + ) + + logger.debug( + f"PassManager::apply: pass_instance={type(pass_instance)}: results({len(results)})={results}" + ) + except Exception as e: + # pass was not applied + logger.debug( + f"PassManager::apply: pass_instance={type(pass_instance)}: ignored={e}" + ) + + if not updated: + updated = len(results) != 0 + + if updated: + self.graph_module.recompile() diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes.py new file mode 100644 index 000000000000..c033cd2df9d7 --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes.py @@ -0,0 +1,239 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import torch + +from sglang.srt.hardware_backend.npu.graph_runner.compilation.custom_ops import ( + split_qkv_rmsnorm_rope, +) + + +class DivFuse: + def pattern(x): + y = 1.0 / x + z = 1.0 / y + return z + + def replacement(x): + return x + + +class EraseCopy: + def __call__(self, graph_module: torch.fx.GraphModule): + copy_node = None + prepare_weight_cache_default_node = None + + results = [] + for module in graph_module.modules(): + for node in list(module.graph.nodes): + if node.type == torch.nn.parameter.Parameter: + continue + + node_target_str = str(node.target) + + if node_target_str == "copy_": + copy_node = node + prepare_weight_cache_default_node = None + continue + + if copy_node and node_target_str == "sglang.prepare_weight_cache": + prepare_weight_cache_default_node = node + continue + + if copy_node and node_target_str == "npu.npu_add_rms_norm_quant": + arg = copy_node.args[1] + + if prepare_weight_cache_default_node is not None: + prepare_weight_cache_default_node.args = ( + arg, + prepare_weight_cache_default_node.args[1], + ) + + node.args = ( + node.args[0], + arg, + node.args[2], + node.args[3], + node.args[4], + ) + + module.graph.erase_node(copy_node) + + result = ( + arg, + copy_node, + prepare_weight_cache_default_node, + ) + results.append(result) + + copy_node = None + prepare_weight_cache_default_node = None + + return results + + +class NpuAddRmsNormQuantFuse: + def pattern(rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3): + output = torch.ops.npu.npu_add_rms_norm( + rms_norm_input, residual, rms_norm_weight, 1e-6 + ) + out0 = output[0] + out2 = output[2] + quantized_output = torch.ops.npu.npu_quantize(out0, scale, offset, v1, v2, v3) + return quantized_output, out2 + + def replacement( + rms_norm_input, residual, rms_norm_weight, scale, offset, v1, v2, v3 + ): + output = torch.ops.npu.npu_add_rms_norm_quant( + rms_norm_input, 
residual, rms_norm_weight, 1.0 / scale, offset, epsilon=1e-6 + ) + quantized_output = output[0] + out2 = output[2] + return quantized_output, out2 + + +class NpuAddRmsNormDynamicQuantFuse: + def pattern(rms_norm_input, residual, rms_norm_weight): + output = torch.ops.npu.npu_add_rms_norm( + rms_norm_input, residual, rms_norm_weight, 1e-6 + ) + out0 = output[0] + out2 = output[2] + quantized_output, dynamic_scale = torch.ops.npu.npu_dynamic_quant(out0) + return quantized_output, out2, dynamic_scale + + def replacement(rms_norm_input, residual, rms_norm_weight): + output = torch.ops.npu.npu_add_rms_norm_dynamic_quant( + x1=rms_norm_input, + x2=residual, + gamma=rms_norm_weight, + epsilon=1e-6, + output_mask=[True, True], + ) + quantized_output = output[0] + out2 = output[2] + dynamic_scale = output[3] + return quantized_output, out2, dynamic_scale + + +class SplitQkvRmsnormRopeFuse: + instance = None + + def __init__( + self, + q_size: int, + kv_size: int, + head_dim: int, + q_shape, + k_shape, + variance_epsilon: float, + ): + self.q_size = q_size + self.kv_size = kv_size + self.head_dim = head_dim + self.q_shape = q_shape + self.k_shape = k_shape + self.variance_epsilon = variance_epsilon + + SplitQkvRmsnormRopeFuse.instance = self + + def pattern( + output_parallel, + q_norm_parameters_weight, + k_norm_parameters_weight, + positions, + cos_sin_cache, + ): + # pattern matching breaks if this static method is made a class method + self = SplitQkvRmsnormRopeFuse.instance + + split = output_parallel.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = split[0] + k = split[1] + v = split[2] + + q_by_head = q.reshape(-1, self.head_dim) + npu_rms_norm_q = torch.ops.npu.npu_rms_norm( + q_by_head, q_norm_parameters_weight, self.variance_epsilon + ) + q_by_head_1 = npu_rms_norm_q[0] + + k_by_head = k.reshape(-1, self.head_dim) + npu_rms_norm_k = torch.ops.npu.npu_rms_norm( + k_by_head, k_norm_parameters_weight, self.variance_epsilon + ) + k_by_head_1 = npu_rms_norm_k[0] 
+ + q_1 = q_by_head_1.view(self.q_shape) + k_1 = k_by_head_1.view(self.k_shape) + + npu_mrope = torch.ops.npu.npu_mrope( + positions, + q_1, + k_1, + cos_sin_cache, + self.head_dim, + mrope_section=[0, 0, 0], + rotary_mode="half", + ) + query_out = npu_mrope[0] + key_out = npu_mrope[1] + + return v, query_out, key_out + + def replacement( + output_parallel, + q_norm_parameters_weight, + k_norm_parameters_weight, + positions, + cos_sin_cache, + ): + # pattern matching breaks if this static method is made a class method + self = SplitQkvRmsnormRopeFuse.instance + + flatten = positions.flatten() + cos_sin = cos_sin_cache.index_select(0, flatten) + + reshape = cos_sin.reshape(-1, 2, 64) + repeat = reshape.repeat(1, 1, 2) + chunk = repeat.chunk(2, dim=-2) + cos = chunk[0] + sin = chunk[1] + + cos_view = cos.view(-1, 1, 1, self.head_dim) + cos_contiguous = cos_view.contiguous() + + sin_view = sin.view(-1, 1, 1, self.head_dim) + sin_contiguous = sin_view.contiguous() + + split_qkv_rmsnorm_rope_default = split_qkv_rmsnorm_rope( + output_parallel, + sin_contiguous, + cos_contiguous, + q_norm_parameters_weight, + k_norm_parameters_weight, + self.q_size, + self.kv_size, + self.head_dim, + self.variance_epsilon, + q_bias=None, + k_bias=None, + ) + + q = split_qkv_rmsnorm_rope_default[0] + k = split_qkv_rmsnorm_rope_default[1] + v = split_qkv_rmsnorm_rope_default[2] + + return v, q, k diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py new file mode 100644 index 000000000000..7fcf14713c54 --- /dev/null +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/patch_dynamo.py @@ -0,0 +1,54 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import annotations + +import torch +from torch._dynamo.decorators import skip +from torch._dynamo.eval_frame import DisableContext, innermost_fn + + +def patch_dynamo_context(): + setattr(torch._dynamo.eval_frame.DisableContext, "compiled_function_args", {}) + setattr(torch._dynamo.eval_frame.DisableContext, "compiled_function", {}) + setattr(torch._dynamo.eval_frame.DisableContext, "batch_size", None) + + +original_disable_context_call = None +original_disable = None + + +def decorators_disable(fn=None, recursive=True, **kwargs): + if recursive: + if fn is not None: + fn = innermost_fn(fn) + assert callable(fn) + + DisableContext.compiled_function[DisableContext.batch_size] = fn + return DisableContext()(fn) + return DisableContext() + else: + return skip(fn) + + +def patch_dynamo_context_call(): + global original_disable + original_disable = torch._dynamo.decorators.disable + torch._dynamo.decorators.disable = decorators_disable + + +def restore_dynamo_context_call(): + global original_disable + torch._dynamo.decorators.disable = original_disable + original_disable = None diff --git a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py index d089fcc635d5..6640e7c656b3 100644 --- a/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py +++ b/python/sglang/srt/hardware_backend/npu/graph_runner/npu_graph_runner.py @@ -26,10 +26,12 @@ import torch 
import sglang +import sglang.srt.model_executor.cuda_graph_runner from sglang.srt.configs.model_config import AttentionArch, is_deepseek_nsa from sglang.srt.distributed.parallel_state import GroupCoordinator from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner +from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( empty_context, get_bool_env_var, @@ -48,6 +50,20 @@ if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner +from torch._dynamo.eval_frame import DisableContext + +from sglang.srt.hardware_backend.npu.custom_ops import ( + _set_dp_buffer_len, + _set_is_extend_in_batch, +) +from sglang.srt.hardware_backend.npu.graph_runner.compilation.npu_graph_compiler import ( + NpuGraphCompiler, +) +from sglang.srt.hardware_backend.npu.graph_runner.compilation.patch_dynamo import ( + patch_dynamo_context, + patch_dynamo_context_call, + restore_dynamo_context_call, +) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors @@ -59,8 +75,13 @@ def patch_model_npu( num_tokens: int, tp_group: GroupCoordinator, ): - if enable_compile: - backend = get_compiler_backend("npugraph_ex") + compilation_config = get_global_server_args().compilation_config + if ( + enable_compile + and (compilation_config is not None) + and (compilation_config.compiler == "npugraph_ex") + ): + backend = get_compiler_backend(compilation_config=compilation_config) yield torch.compile( torch.no_grad()(model.forward), fullgraph=True, @@ -75,7 +96,17 @@ class NPUGraphRunner(CudaGraphRunner): """A NPUGraphRunner runs the forward pass of a model with npu graph and torch.compile.""" def __init__(self, model_runner: ModelRunner): + if model_runner.server_args.enable_torch_compile: + patch_dynamo_context() + sglang.srt.model_executor.cuda_graph_runner.patch_model = 
patch_model_npu + model_runner.attn_backend.enable_torch_compile = ( + model_runner.server_args.enable_torch_compile + ) + self.enable_npu_torchair_compile = ( + model_runner.server_args.enable_npu_torchair_compile + ) + super().__init__(model_runner) self.update_attr_name = None self.update_attr_type = None @@ -96,19 +127,77 @@ def _init_arch_map(self): def _create_device_graph(self): return torch.npu.NPUGraph() - def _capture_graph(self, graph, pool, stream, run_once_fn): - if self.enable_torch_compile: - skip_guard_context = torch.compiler.set_stance(skip_guard_eval_unsafe=True) + def _init_dp_gathered_buffer( + self, global_dp_buffer_len: int, local_dp_buffer_len: int, dp_max_padding: bool + ): + if get_global_server_args().enable_torch_compile: + _set_dp_buffer_len( + global_dp_buffer_len, local_dp_buffer_len, dp_max_padding + ) + _set_is_extend_in_batch(False) else: - skip_guard_context = empty_context() + super()._init_dp_gathered_buffer( + global_dp_buffer_len, local_dp_buffer_len, dp_max_padding + ) - with skip_guard_context, torch.npu.graph( - graph, - pool=pool, - stream=stream, - auto_dispatch_capture=True, + def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): + compilation_config = get_global_server_args().compilation_config + if ( + self.enable_torch_compile + and (not self.enable_npu_torchair_compile) + and (not self.compile_bs or bs in self.compile_bs) + and (bs >= get_attention_tp_size()) ): - out = run_once_fn() + compiler = NpuGraphCompiler( + model_runner=self.model_runner, + model=run_once_fn, + compilation_config=compilation_config, + batch_size=bs, + ) + + patch_dynamo_context_call() + DisableContext.batch_size = bs + try: + # compilation + out = compiler.compiled_callable() + + # capture function and args + out = compiler.compiled_callable() + finally: + DisableContext.batch_size = None + restore_dynamo_context_call() + + assert bs in DisableContext.compiled_function + assert DisableContext.compiled_function[bs] + assert bs 
in DisableContext.compiled_function_args + assert DisableContext.compiled_function_args[bs] + + compiled_function = DisableContext.compiled_function[bs] + args = DisableContext.compiled_function_args[bs] + + with torch.npu.graph( + graph, + pool=pool, + stream=stream, + auto_dispatch_capture=True, + ): + compiled_function(*args) + + else: + if self.enable_npu_torchair_compile: + skip_guard_context = torch.compiler.set_stance( + skip_guard_eval_unsafe=True + ) + else: + skip_guard_context = empty_context() + + with skip_guard_context, torch.npu.graph( + graph, + pool=pool, + stream=stream, + auto_dispatch_capture=True, + ): + out = run_once_fn() return out def _get_update_attr_name(self, model_runner, forward_batch): diff --git a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py index 3a99f6ac7c3b..635f606c53a5 100644 --- a/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py +++ b/python/sglang/srt/hardware_backend/npu/quantization/linear_method_npu.py @@ -3,7 +3,9 @@ import torch from sglang.srt.hardware_backend.npu.utils import npu_format_cast +from sglang.srt.layers.linear import MergedColumnParallelLinear, QKVParallelLinear from sglang.srt.layers.quantization.base_config import LinearMethodBase +from sglang.srt.server_args import get_global_server_args if TYPE_CHECKING: from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -23,21 +25,33 @@ class NPUW8A8Int8LinearMethod(_NPULinearMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module): layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = npu_format_cast(layer.weight.data) + layer.weight_data = layer.weight.data layer.weight_scale.data = layer.weight_scale.data.flatten() + layer.weight_scale_data = layer.weight_scale.data # Compressed-tensors format doesn't have this field if hasattr(layer, "weight_offset"): 
layer.weight_offset.data = layer.weight_offset.data.flatten() + layer.weight_offset_data = layer.weight_offset.data expanding_factor = layer.weight.data.shape[0] layer.aclnn_input_scale = torch.nn.Parameter( layer.input_scale.data.repeat(expanding_factor).to(device="npu"), requires_grad=False, ) - layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter( - layer.input_scale.data.repeat(expanding_factor).to(device="npu"), - requires_grad=False, - ) + prev_layer_fuse_reciprocal = isinstance( + layer, MergedColumnParallelLinear + ) or isinstance(layer, QKVParallelLinear) + if get_global_server_args().enable_torch_compile and prev_layer_fuse_reciprocal: + layer.aclnn_input_scale_reciprocal = torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor).to(device="npu"), + requires_grad=False, + ) + else: + layer.aclnn_input_scale_reciprocal = 1.0 / torch.nn.Parameter( + layer.input_scale.data.repeat(expanding_factor).to(device="npu"), + requires_grad=False, + ) layer.aclnn_input_offset = torch.nn.Parameter( layer.input_offset.data.repeat(expanding_factor).to(device="npu"), requires_grad=False, @@ -53,9 +67,16 @@ def apply( original_dtype = x.dtype if original_dtype != torch.int8: + aclnn_input_scale_reciprocal = layer.aclnn_input_scale_reciprocal + if get_global_server_args().enable_torch_compile and ( + isinstance(layer, MergedColumnParallelLinear) + or isinstance(layer, QKVParallelLinear) + ): + aclnn_input_scale_reciprocal = 1.0 / aclnn_input_scale_reciprocal + x = torch.ops.npu.npu_quantize( x, - layer.aclnn_input_scale_reciprocal, + aclnn_input_scale_reciprocal, layer.aclnn_input_offset, torch.qint8, -1, @@ -66,13 +87,13 @@ def apply( if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0: quant_bias = None else: - quant_bias = layer.quant_bias + quant_bias = layer.quant_bias_data return torch.ops.npu.npu_quant_matmul( x, - layer.weight, - layer.deq_scale, + layer.weight_data, + layer.deq_scale_data, bias=quant_bias, - 
output_dtype=original_dtype, + output_dtype=layer.params_dtype, ) @@ -81,11 +102,14 @@ class NPUW8A8Int8DynamicLinearMethod(_NPULinearMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module): layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = npu_format_cast(layer.weight.data) + layer.weight_data = layer.weight.data layer.weight_scale.data = layer.weight_scale.data.flatten() + layer.weight_scale_data = layer.weight_scale.data # Compressed-tensors format doesn't have this field if hasattr(layer, "weight_offset"): layer.weight_offset.data = layer.weight_offset.data.flatten() + layer.weight_offset_data = layer.weight_offset.data def apply( self, diff --git a/python/sglang/srt/hardware_backend/npu/utils.py b/python/sglang/srt/hardware_backend/npu/utils.py index 478c73d24429..2f0479ae2904 100644 --- a/python/sglang/srt/hardware_backend/npu/utils.py +++ b/python/sglang/srt/hardware_backend/npu/utils.py @@ -82,6 +82,43 @@ def set_default_server_args(args: "ServerArgs"): else: args.hicache_mem_layout = "page_first_direct" + if args.disable_cuda_graph and ( + args.enable_torch_compile or args.enable_npu_torchair_compile + ): + raise ValueError( + f"--enable-torch-compile or --enable-npu-torchair-compile is not appropriate for --disable-cuda-graph" + ) + + if args.compilation_config: + if not args.enable_torch_compile and not args.enable_npu_torchair_compile: + raise ValueError( + f"--compilation-config must be used only with --enable-torch-compile or --enable-npu-torchair-compile" + ) + + if args.compilation_config.compiler == "npugraph": + args.enable_torch_compile = True + + if args.disable_cuda_graph: + raise ValueError( + f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --disable-cuda-graph" + ) + + if args.enable_npu_torchair_compile: + raise ValueError( + f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for 
--enable-npu-torchair-compile" + ) + + if args.compilation_config.compiler == "npugraph_ex": + args.enable_npu_torchair_compile = True + + if args.disable_cuda_graph: + raise ValueError( + f"compilation_config.compiler '{args.compilation_config.compiler}' is not appropriate for --disable-cuda-graph" + ) + + if args.enable_npu_torchair_compile: + args.enable_torch_compile = True + @_call_once def init_npu_backend(): diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 1636ed706474..80ffd63fbb24 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -77,8 +77,11 @@ from aiter.ops.triton.fused_fp8_quant import fused_rms_fp8_group_quant from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant -elif _is_npu: - from sglang.srt.hardware_backend.npu.cmo import prepare_weight_cache + +if _is_npu: + from sglang.srt.hardware_backend.npu.cmo_custom_ops import ( # noqa + prepare_weight_cache, + ) FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048 diff --git a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8.py b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8.py index 16c62d551fa3..f5b7fe698121 100644 --- a/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/modelslim/schemes/modelslim_w8a8_int8.py @@ -46,25 +46,35 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") output_size_per_partition = sum(output_partition_sizes) + weight_data = torch.empty( + (output_size_per_partition, input_size_per_partition), dtype=torch.int8 + ) + layer.__dict__["weight_data"] = weight_data weight = ModelWeightParameter( - data=torch.empty( - (output_size_per_partition, input_size_per_partition), dtype=torch.int8 - ), + data=weight_data, input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) + 
weight_scale_data = torch.empty( + (output_size_per_partition, 1), dtype=params_dtype + ) + layer.__dict__["weight_scale_data"] = weight_scale_data weight_scale = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), dtype=params_dtype), + data=weight_scale_data, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight_scale", weight_scale) + weight_offset_data = torch.empty( + (output_size_per_partition, 1), dtype=params_dtype + ) + layer.__dict__["weight_offset_data"] = weight_offset_data weight_offset = ChannelQuantScaleParameter( - data=torch.empty((output_size_per_partition, 1), dtype=params_dtype), + data=weight_offset_data, output_dim=0, weight_loader=weight_loader, ) @@ -85,8 +95,10 @@ def create_weights( input_offset.ignore_warning = True layer.register_parameter("input_offset", input_offset) + quant_bias_data = torch.empty(output_size_per_partition, dtype=torch.int32) + layer.__dict__["quant_bias_data"] = quant_bias_data quant_bias = ChannelQuantScaleParameter( - data=torch.empty(output_size_per_partition, dtype=torch.int32), + data=quant_bias_data, output_dim=0, weight_loader=weight_loader, ) @@ -98,8 +110,13 @@ def create_weights( deq_scale_dtype = torch.int64 else: raise ValueError(f"Unsupported params_dtype: {params_dtype}") + + deq_scale_data = torch.empty( + output_size_per_partition, dtype=deq_scale_dtype + ) + layer.__dict__["deq_scale_data"] = deq_scale_data deq_scale = ChannelQuantScaleParameter( - data=torch.empty(output_size_per_partition, dtype=deq_scale_dtype), + data=deq_scale_data, output_dim=0, weight_loader=weight_loader, ) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 8fbdf3160dce..8cf203b144da 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -291,6 +291,7 @@ def forward_npu( fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, 
torch.Tensor]: """A PyTorch-npu implementation of forward().""" + assert ( fused_set_kv_buffer_arg is None ), "fused_set_kv_buffer_arg is not supported for npu implementation" diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index b0b2ede6dbde..ec441a10bb5e 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 SGLang Team +# Copyright 2023-2025 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -532,7 +532,7 @@ def _capture_one_stream(stream_idx: Optional[int] = None): if self.enable_profile_cuda_graph: self._post_process_after_profile(prof) - def _capture_graph(self, graph, pool, stream, run_once_fn): + def _capture_graph(self, graph, pool, stream, run_once_fn, bs: int): memory_saver_adapter = TorchMemorySaverAdapter.create( enable=self.model_runner.server_args.enable_memory_saver and get_bool_env_var("SGLANG_MEMORY_SAVER_CUDA_GRAPH") @@ -549,6 +549,12 @@ def _capture_graph(self, graph, pool, stream, run_once_fn): def _create_device_graph(self): return torch.cuda.CUDAGraph() + def _init_dp_gathered_buffer( + self, global_dp_buffer_len: int, local_dp_buffer_len: int, dp_max_padding: bool + ): + set_dp_buffer_len(global_dp_buffer_len, local_dp_buffer_len, dp_max_padding) + set_is_extend_in_batch(False) + def capture_one_batch_size( self, bs: int, forward: Callable, stream_idx: Optional[int] = None ): @@ -696,12 +702,12 @@ def capture_one_batch_size( def run_once(): # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - set_dp_buffer_len( + + self._init_dp_gathered_buffer( global_dp_buffer_len, num_tokens, forward_batch.dp_padding_mode.is_max_len(), ) - 
set_is_extend_in_batch(False) kwargs = {} if ( @@ -732,7 +738,7 @@ def run_once(): # Set graph pool id globally to be able to use symmetric memory set_graph_pool_id(get_global_graph_memory_pool()) out = self._capture_graph( - graph, get_global_graph_memory_pool(), stream, run_once + graph, get_global_graph_memory_pool(), stream, run_once, bs ) return graph, out diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d15a2f61125d..d096e2a57185 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -573,6 +573,7 @@ def initialize(self, min_per_gpu_memory: float): self.init_attention_backend() self.kernel_warmup() self.init_device_graphs() + elif self.device in ["npu", "cpu"]: self.init_attention_backend() self.init_device_graphs() @@ -1985,10 +1986,17 @@ def init_device_graphs(self): # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models return - if self.server_args.model_impl.lower() == ModelImpl.MINDSPORE: + if self.device not in ["cpu", "npu"] and self.server_args.disable_cuda_graph: + return + + if ( + self.device == "npu" + and self.server_args.disable_cuda_graph + and not self.server_args.enable_torch_compile + ): return - if self.device != "cpu" and self.server_args.disable_cuda_graph: + if self.server_args.model_impl.lower() == ModelImpl.MINDSPORE: return if self.device == "cpu" and not self.server_args.enable_torch_compile: @@ -1999,6 +2007,7 @@ def init_device_graphs(self): logger.info( f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. 
avail mem={before_mem:.2f} GB" ) + graph_runners = defaultdict( lambda: CudaGraphRunner, { diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 89871ad57db4..b90f1cf7664c 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -39,9 +39,18 @@ _is_npu = is_npu() if _is_npu: - from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope - - from sglang.srt.hardware_backend.npu.cmo import get_cmo_stream, wait_cmo_stream + if get_global_server_args().enable_torch_compile: + from sglang.srt.hardware_backend.npu.cmo import get_weight_cache + from sglang.srt.hardware_backend.npu.cmo_custom_ops import ( + get_cmo_stream, + wait_cmo_stream, + ) + else: + from sglang.srt.hardware_backend.npu.cmo import ( + get_cmo_stream, + get_weight_cache, + wait_cmo_stream, + ) class Qwen3Attention(nn.Module): @@ -138,7 +147,15 @@ def __init__( ) self.alt_stream = alt_stream - def forward_prepare_native(self, positions, hidden_states): + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + if get_global_server_args().rl_on_policy_target is not None: + hidden_states = hidden_states.bfloat16() + qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = apply_qk_norm( @@ -150,48 +167,6 @@ def forward_prepare_native(self, positions, hidden_states): alt_stream=self.alt_stream, ) q, k = self.rotary_emb(positions, q, k) - return q, k, v - - def forward_prepare_npu(self, positions, hidden_states, forward_batch): - qkv, _ = self.qkv_proj(hidden_states) - - if self.attn.layer_id == forward_batch.token_to_kv_pool.start_layer: - self.rotary_emb.get_cos_sin_with_position(positions) - q, k, v = split_qkv_rmsnorm_rope( - qkv, - self.rotary_emb.position_sin, - self.rotary_emb.position_cos, - self.q_size, - self.kv_size, - self.head_dim, - eps=self.q_norm.variance_epsilon, - 
q_weight=self.q_norm.weight, - k_weight=self.k_norm.weight, - q_bias=getattr(self.q_norm, "bias", None), - k_bias=getattr(self.k_norm, "bias", None), - ) - return q, k, v - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - forward_batch: ForwardBatch, - ) -> torch.Tensor: - if get_global_server_args().rl_on_policy_target is not None: - hidden_states = hidden_states.bfloat16() - - if not _is_npu: - q, k, v = self.forward_prepare_native( - positions=positions, - hidden_states=hidden_states, - ) - else: - q, k, v = self.forward_prepare_npu( - positions=positions, - hidden_states=hidden_states, - forward_batch=forward_batch, - ) if get_global_server_args().rl_on_policy_target is not None: q = q.to(torch.bfloat16) @@ -298,7 +273,10 @@ def forward( residual, forward_batch, cache=( - [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight] + [ + get_weight_cache(self.mlp.gate_up_proj), + get_weight_cache(self.mlp.down_proj), + ] if _is_npu else None ), diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index cef7baedeec5..9cf9765ca589 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -26,6 +26,7 @@ import tempfile from typing import Any, Callable, Dict, List, Literal, Optional, Union +from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.connector import ConnectorType from sglang.srt.environ import ToolStrictLevel, envs from sglang.srt.function_call.function_call_parser import FunctionCallParser @@ -577,6 +578,7 @@ class ServerArgs: tbo_token_distribution_threshold: float = 0.48 enable_torch_compile: bool = False enable_piecewise_cuda_graph: bool = False + enable_npu_torchair_compile: bool = False enable_torch_compile_debug_mode: bool = False torch_compile_max_bs: int = 32 piecewise_cuda_graph_max_tokens: Optional[int] = None @@ -641,6 +643,7 @@ class ServerArgs: # FIXME: hack to reduce ITL when decode bs is small 
disaggregation_decode_polling_interval: int = 1 + compilation_config: Optional[CompilationConfig] = None # Encode prefill disaggregation encoder_only: bool = False language_only: bool = False @@ -852,6 +855,11 @@ def _handle_missing_default_values(self): elif self.speculative_draft_model_quantization == "unquant": self.speculative_draft_model_quantization = None + if self.compilation_config: + if isinstance(self.compilation_config, str): + args_dict = json.loads(self.compilation_config) + self.compilation_config = CompilationConfig(**args_dict) + def _handle_hpu_backends(self): if self.device == "hpu": self.attention_backend = "torch_native" @@ -2542,6 +2550,18 @@ def _handle_other_validations(self): self.disable_cuda_graph = True self.skip_server_warmup = True + if not is_npu() and ( + self.enable_npu_torchair_compile + or ( + self.compilation_config is not None + and self.compilation_config.compiler == "npugraph_ex" + ) + ): + self.enable_npu_torchair_compile = False + logger.warning( + "The option --enable-npu-torchair-compile is ignored, this option is available for Ascend NPU only" + ) + # Validate limit_mm_per_prompt modalities if self.limit_mm_data_per_request: if isinstance(self.limit_mm_data_per_request, str): @@ -3580,6 +3600,13 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Disable FlashInfer autotuning.", ) + parser.add_argument( + "--compilation-config", + type=str, + default=None, + help="Represents JSON serialized instance of 'CompilationConfig' class to provide compilation details.", + ) + # Speculative decoding parser.add_argument( "--speculative-algorithm", @@ -4253,6 +4280,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable debug mode for torch compile", ) + parser.add_argument( + "--enable-npu-torchair-compile", + action="store_true", + help="Optimize the model with Torch Ascend Intermediate Representation compilation. This is only available for Ascend NPU. 
Experimental feature.", + ) parser.add_argument( "--enable-piecewise-cuda-graph", action="store_true", diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index fed2a33955c0..814bcc8cee44 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -91,6 +91,7 @@ from torch.utils._contextlib import _DecoratorContextManager from typing_extensions import Literal +from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.environ import envs from sglang.srt.metrics.func_timer import enable_func_timer @@ -100,6 +101,17 @@ from sglang.srt.server_args import ServerArgs +if hasattr(torch, "npu") and torch.npu.is_available(): + try: + import torchair + import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce + from torchair.configs.compiler_config import CompilerConfig + + torchair_package_installed = True + except ImportError as e: + torchair_package_installed = False + + logger = logging.getLogger(__name__) @@ -1991,27 +2003,49 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]: return major, minor -def get_compiler_backend(mode=None) -> str: +@lru_cache(maxsize=1) +def get_compiler_backend( + mode: str = None, + model_runner=None, + compilation_config: CompilationConfig = None, +) -> str: if hasattr(torch, "hpu") and torch.hpu.is_available(): return "hpu_backend" if hasattr(torch, "npu") and torch.npu.is_available(): - try: - import torchair - import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce - from torchair.configs.compiler_config import CompilerConfig - except ImportError as e: - raise ImportError( - "NPU detected, but torchair package is not installed. " - "Please install torchair for torch.compile support on NPU." 
- ) - compiler_config = CompilerConfig() - compiler_config.mode = "max-autotune" - if mode == "npugraph_ex": + if compilation_config is None: + if not torchair_package_installed: + raise ImportError( + "NPU detected, but torchair package is not installed. " + "Please install torchair for torch.compile support on NPU." + ) + compiler_config = CompilerConfig() + compiler_config.mode = "max-autotune" if mode is None else mode + npu_backend = torchair.get_npu_backend(compiler_config=compiler_config) + return npu_backend + + if compilation_config.compiler == "npugraph_ex": + if not torchair_package_installed: + raise ImportError( + "NPU detected, but torchair package is not installed. " + "Please install torchair for torch.compile support on NPU." + ) + compiler_config = CompilerConfig() compiler_config.mode = "reduce-overhead" compiler_config.debug.run_eagerly = True - npu_backend = torchair.get_npu_backend(compiler_config=compiler_config) - return npu_backend + npu_backend = torchair.get_npu_backend(compiler_config=compiler_config) + return npu_backend + + if compilation_config.compiler == "npugraph": + from sglang.srt.hardware_backend.npu.graph_runner.compilation.npu_graph_compiler_backend import ( + NpuGraphCompilerBackend, + ) + + return NpuGraphCompilerBackend(model_runner) + + raise ValueError( + f"unrecognized compiler backend '{compilation_config.compiler}'" + ) return "inductor" diff --git a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py index deaf4a5e0387..bec22888c7aa 100644 --- a/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py +++ b/test/srt/ascend/test_ascend_compile_graph_tp1_bf16.py @@ -39,7 +39,7 @@ def setUpClass(cls): "--attention-backend", "ascend", "--disable-radix-cache", - "--enable-torch-compile", + "--enable-npu-torchair-compile", "--watchdog-timeout", 30000, ] diff --git a/test/srt/ascend/test_ascend_npu_graph_compile_tp1_bf16.py 
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)

# This test pins its own model instead of the suite-wide default. The
# original file shadowed the imported DEFAULT_MODEL_NAME_FOR_TEST with a
# module-level reassignment; use a distinct constant to avoid masking the
# imported name.
MODEL_NAME_FOR_TEST = "Qwen/Qwen2.5-7B-Instruct"


class TestAscendNpuGraphCompile(CustomTestCase):
    """End-to-end GSM8K accuracy/latency check for the Ascend NPU graph
    compile path (``--enable-torch-compile`` with the ascend attention
    backend)."""

    def test_gsm8k(self):
        model = MODEL_NAME_FOR_TEST
        base_url = DEFAULT_URL_FOR_TEST
        url = urlparse(base_url)
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--attention-backend",
                "ascend",
                "--mem-fraction-static",
                0.7,
                "--enable-torch-compile",
                "--watchdog-timeout",
                30000,
            ],
        )

        try:
            args = SimpleNamespace(
                num_shots=5,
                data_path=None,
                num_questions=1319,
                max_new_tokens=512,
                parallel=128,
                host=f"http://{url.hostname}",
                port=int(url.port),
            )

            metrics = run_eval_few_shot_gsm8k(args)
            # Thresholds chosen for Qwen2.5-7B-Instruct on GSM8K 5-shot.
            self.assertGreaterEqual(metrics["accuracy"], 0.62)
            self.assertLessEqual(metrics["latency"], 150)
        finally:
            # Always tear down the launched server, even on assertion failure.
            kill_process_tree(process.pid)


if __name__ == "__main__":
    unittest.main()
TestFile("test_embed_interpolate_unittest.py", 400), ], diff --git a/test/srt/test_embed_interpolate_unittest.py b/test/srt/test_embed_interpolate_unittest.py index cb09935bcc7d..b1848ba3ae2f 100644 --- a/test/srt/test_embed_interpolate_unittest.py +++ b/test/srt/test_embed_interpolate_unittest.py @@ -12,7 +12,6 @@ LinearMethodBase, UnquantizedLinearMethod, ) -from sglang.srt.models.qwen3_vl import Qwen3VLMoeVisionModel from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler @@ -65,6 +64,10 @@ def test_embed_interpolate(self): server_args=sarg, model_config=mconf, ) + + # in real pipeline, a model is imported after command line argument parsing + from sglang.srt.models.qwen3_vl import Qwen3VLMoeVisionModel + model = Qwen3VLMoeVisionModel( mconf, quant_config=None,