diff --git a/python/sglang/srt/compilation/backend.py b/python/sglang/srt/compilation/backend.py index 9892e4725898..08c1814157c4 100644 --- a/python/sglang/srt/compilation/backend.py +++ b/python/sglang/srt/compilation/backend.py @@ -19,8 +19,9 @@ from sglang.srt.compilation.compilation_counter import compilation_counter from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend +from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend from sglang.srt.compilation.pass_manager import PostGradPassManager -from sglang.srt.utils.common import rank0_log +from sglang.srt.utils.common import is_npu, rank0_log logger = logging.getLogger(__name__) @@ -44,6 +45,32 @@ def make_compiler(config: CompilationConfig): raise ValueError(f"Unknown compiler: {config.compiler}") +def make_backend( + graph: fx.GraphModule, + compile_config: CompilationConfig, + inductor_config: dict[str, Any], + graph_pool: Any, + piecewise_compile_index: int, + total_piecewise_compiles: int, + sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + sglang_backend, +): + + backend_cls = CUDAPiecewiseBackend if not is_npu() else NPUPiecewiseBackend + return backend_cls( + graph, + compile_config, + inductor_config, + graph_pool, + piecewise_compile_index, + total_piecewise_compiles, + sym_shape_indices, + compiled_graph_for_general_shape, + sglang_backend, + ) + + class CompilerManager: def __init__( self, @@ -302,7 +329,7 @@ def call_module( ) ) - self.module.__dict__[target] = CUDAPiecewiseBackend( + self.module.__dict__[target] = make_backend( submod, self.compile_config, self.inductor_config, diff --git a/python/sglang/srt/compilation/cuda_piecewise_backend.py b/python/sglang/srt/compilation/cuda_piecewise_backend.py index 2e45d34d3e64..dd8b358b7b30 100644 --- a/python/sglang/srt/compilation/cuda_piecewise_backend.py +++ b/python/sglang/srt/compilation/cuda_piecewise_backend.py @@ -3,35 +3,19 @@ import dataclasses import logging from contextlib import ExitStack -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional from unittest.mock import patch import torch import torch.fx as fx -from sgl_kernel import weak_ref_tensor from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.compilation.compilation_counter import compilation_counter +from sglang.srt.compilation.weak_ref_tensor import weak_ref_tensors logger = logging.getLogger(__name__) -def weak_ref_tensors( - tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] -) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: - """ - Convenience function to create weak references to tensors, - for single tensor, list of tensors or tuple of tensors. - """ - if isinstance(tensors, torch.Tensor): - return weak_ref_tensor(tensors) - if isinstance(tensors, list): - return [weak_ref_tensor(t) for t in tensors] - if isinstance(tensors, tuple): - return tuple(weak_ref_tensor(t) for t in tensors) - raise ValueError("Invalid type for tensors") - - @dataclasses.dataclass class ConcreteSizeEntry: runtime_shape: int diff --git a/python/sglang/srt/compilation/npu_piecewise_backend.py b/python/sglang/srt/compilation/npu_piecewise_backend.py new file mode 100644 index 000000000000..dc97bd5c3f74 --- /dev/null +++ b/python/sglang/srt/compilation/npu_piecewise_backend.py @@ -0,0 +1,109 @@ +from contextlib import ExitStack +from typing import Any, Callable +from unittest.mock import patch + +import torch +import torch.fx as fx + +from sglang.srt.compilation.compilation_config import CompilationConfig +from sglang.srt.compilation.compilation_counter import compilation_counter +from sglang.srt.compilation.cuda_piecewise_backend import ( + CUDAPiecewiseBackend, + weak_ref_tensors, +) + + +class NPUPiecewiseBackend(CUDAPiecewiseBackend): + def __init__( + self, + graph: fx.GraphModule, + compile_config: CompilationConfig, + inductor_config: dict[str, Any], + graph_pool: Any, + piecewise_compile_index: int, + total_piecewise_compiles: int, + sym_shape_indices: list[int], + compiled_graph_for_general_shape: Callable, + sglang_backend, + ): + super().__init__( + graph, + compile_config, + inductor_config, + graph_pool, + piecewise_compile_index, + total_piecewise_compiles, + sym_shape_indices, + compiled_graph_for_general_shape, + sglang_backend, + ) + + def __call__(self, *args): + runtime_shape = args[self.sym_shape_indices[0]] + if runtime_shape not in self.concrete_size_entries: + # we don't need to do anything for this shape + return self.compiled_graph_for_general_shape(*args) + + entry = self.concrete_size_entries[runtime_shape] + + if entry.runnable is None: + entry.runnable = self.compiled_graph_for_general_shape + + if entry.cudagraph is None: + if entry.num_finished_warmup < 1: # noqa + entry.num_finished_warmup += 1 + return entry.runnable(*args) + + if self.compile_config.get_enable_debug_mode(): + input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + entry.input_addresses = input_addresses + npugraph = torch.npu.NPUGraph() + + with ExitStack() as stack: + if not self.is_first_graph: + # during every model forward, we will capture + # many pieces of cudagraphs (roughly one per layer). + # running gc again and again across layers will + # make the cudagraph capture very slow. + # therefore, we only run gc for the first graph, + # and disable gc for the rest of the graphs. + stack.enter_context(patch("gc.collect", lambda: None)) + stack.enter_context(patch("torch.npu.empty_cache", lambda: None)) + + # mind-exploding: carefully manage the reference and memory. + with torch.npu.graph(npugraph, pool=self.graph_pool): + # `output` is managed by pytorch's cudagraph pool + output = entry.runnable(*args) + if self.is_last_graph: + # by converting it to weak ref, + # the original `output` will immediately be released + # to save memory. It is only safe to do this for + # the last graph, because the output of the last graph + # will not be used by any other cuda graph. + output = weak_ref_tensors(output) + + # here we always use weak ref for the output + # to save memory + entry.output = weak_ref_tensors(output) + entry.cudagraph = npugraph + + compilation_counter.num_cudagraph_captured += 1 + + # important: we need to return the output, rather than + # the weak ref of the output, so that pytorch can correctly + # manage the memory during cuda graph capture + return output + + if self.compile_config.get_enable_debug_mode(): + # check if the input addresses are the same + new_input_addresses = [ + x.data_ptr() for x in args if isinstance(x, torch.Tensor) + ] + assert new_input_addresses == entry.input_addresses, ( + "Input addresses for cudagraphs are different during replay." + f" Expected {entry.input_addresses}, got {new_input_addresses}" + ) + entry.cudagraph.replay() + return entry.output diff --git a/python/sglang/srt/compilation/weak_ref_tensor.py b/python/sglang/srt/compilation/weak_ref_tensor.py new file mode 100644 index 000000000000..83cb7e64b0b9 --- /dev/null +++ b/python/sglang/srt/compilation/weak_ref_tensor.py @@ -0,0 +1,28 @@ +from typing import Any, Union + +import torch + +from sglang.srt.utils.common import is_cuda, is_npu + +if is_cuda(): + from sgl_kernel import weak_ref_tensor +elif is_npu(): + from torch_npu._C import _weak_ref_tensor as weak_ref_tensor +else: + raise NotImplementedError("weak_ref_tensor is implemented only for CUDA and NPU.") + + +def weak_ref_tensors( + tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] +) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: + """ + Convenience function to create weak references to tensors, + for single tensor, list of tensors or tuple of tensors. + """ + if isinstance(tensors, torch.Tensor): + return weak_ref_tensor(tensors) + if isinstance(tensors, list): + return [weak_ref_tensor(t) for t in tensors] + if isinstance(tensors, tuple): + return tuple(weak_ref_tensor(t) for t in tensors) + raise ValueError("Invalid type for tensors") diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index 689c21cabcea..3d54b0c7fd4a 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -51,7 +51,7 @@ ForwardMode, PPProxyTensors, ) -from sglang.srt.utils import get_available_gpu_memory, log_info_on_rank0 +from sglang.srt.utils import get_available_gpu_memory, is_npu, log_info_on_rank0 logger = logging.getLogger(__name__) @@ -303,7 +303,7 @@ def warmup_torch_compile(self): seq_lens=torch.tensor([num_tokens], device=self.device), next_token_logits_buffer=None, orig_seq_lens=torch.tensor([num_tokens], device=self.device), - seq_lens_cpu=torch.tensor([num_tokens]), + seq_lens_cpu=torch.tensor([num_tokens], device="cpu"), req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool, attn_backend=self.model_runner.attn_backend, @@ -316,9 +316,9 @@ def warmup_torch_compile(self): extend_seq_lens=torch.tensor([num_tokens], device=self.device), extend_prefix_lens=torch.tensor([num_tokens], device=self.device), extend_start_loc=torch.tensor([0], device=self.device), - extend_prefix_lens_cpu=torch.tensor([num_tokens]), - extend_seq_lens_cpu=torch.tensor([num_tokens]), - extend_logprob_start_lens_cpu=torch.tensor([num_tokens]), + extend_prefix_lens_cpu=torch.tensor([num_tokens], device="cpu"), + extend_seq_lens_cpu=torch.tensor([num_tokens], device="cpu"), + extend_logprob_start_lens_cpu=torch.tensor([num_tokens], device="cpu"), positions=torch.arange(num_tokens, device=self.device), global_num_tokens_gpu=None, global_num_tokens_for_logprob_gpu=None, @@ -347,7 +347,7 @@ def warmup_torch_compile(self): ) def _cache_loc_dtype(self): - return torch.int64 + return torch.int64 if not is_npu() else torch.int32 def can_run(self, forward_batch: ForwardBatch): num_tokens = len(forward_batch.input_ids) @@ -432,7 +432,7 @@ def capture_one_batch_size(self, num_tokens: int): seq_lens=torch.tensor([num_tokens], device=self.device), next_token_logits_buffer=None, orig_seq_lens=torch.tensor([num_tokens], device=self.device), - seq_lens_cpu=torch.tensor([num_tokens]), + seq_lens_cpu=torch.tensor([num_tokens], device="cpu"), req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool, attn_backend=self.model_runner.attn_backend, @@ -445,9 +445,9 @@ def capture_one_batch_size(self, num_tokens: int): extend_seq_lens=torch.tensor([num_tokens], device=self.device), extend_prefix_lens=torch.tensor([num_tokens], device=self.device), extend_start_loc=torch.tensor([0], device=self.device), - extend_prefix_lens_cpu=torch.tensor([num_tokens]), - extend_seq_lens_cpu=torch.tensor([num_tokens]), - extend_logprob_start_lens_cpu=torch.tensor([num_tokens]), + extend_prefix_lens_cpu=torch.tensor([num_tokens], device="cpu"), + extend_seq_lens_cpu=torch.tensor([num_tokens], device="cpu"), + extend_logprob_start_lens_cpu=torch.tensor([num_tokens], device="cpu"), positions=positions, global_num_tokens_gpu=None, global_num_tokens_for_logprob_gpu=None, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8fc9062ca6ea..25b8cbf15e3d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -639,6 +639,9 @@ def __post_init__(self): self._handle_cpu_backends() self._handle_npu_backends() + # Handle compilation config + self._handle_compilation_cfg() + # Apply model-specific adjustments. self._handle_model_specific_adjustments() @@ -951,6 +954,15 @@ def _handle_cpu_backends(self): self.attention_backend = "intel_amx" self.sampling_backend = "pytorch" + def _handle_compilation_cfg(self): + # NPU platform + if is_npu() and self.piecewise_cuda_graph_compiler != "eager": + logger.warning( + "At this moment Ascend platform only support prefill graph compilation with " + "piecewise_cuda_graph_compiler='eager', change piecewise_cuda_graph_compiler to 'eager'." + ) + self.piecewise_cuda_graph_compiler = "eager" + def _handle_npu_backends(self): if self.device == "npu": from sglang.srt.hardware_backend.npu.utils import set_default_server_args diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 3aa7140cb9f1..d9bd01a1fb27 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1996,7 +1996,7 @@ def direct_register_custom_op( try: my_lib.define(op_name + schema_str) - my_lib.impl(op_name, op_func, "CUDA") + my_lib.impl(op_name, op_func, "CUDA" if not is_npu() else "PrivateUse1") if fake_impl is not None: my_lib._register_fake(op_name, fake_impl) except RuntimeError as error: diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 2c3e821be7e4..a20d192b5436 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1201,21 +1201,20 @@ def run_bench_one_batch(model, other_args): try: stdout, stderr = process.communicate() - output = stdout.decode() - error = stderr.decode() + output = stdout.decode(errors="backslashreplace") + error = stderr.decode(errors="backslashreplace") print(f"Output: {output}", flush=True) print(f"Error: {error}", flush=True) # Return prefill_latency, decode_throughput, decode_latency - prefill_line = output.split("\n")[-9] - decode_line = output.split("\n")[-3] - pattern = ( - r"latency: (?P\d+\.\d+).*?throughput:\s*(?P\d+\.\d+)" - ) - match = re.search(pattern, prefill_line) + pattern = r"Benchmark[\s\S]*Total" + match = re.search(pattern, output) + bench_output = match[0] if match else "" + pattern = r".*?latency: (?P\d+\.\d+).*?throughput:\s*(?P\d+\.\d+)" + match = re.search(r"Prefill." + pattern, bench_output) if match: prefill_latency = float(match.group("latency")) - match = re.search(pattern, decode_line) + match = re.search(r"Decode." + pattern, bench_output) if match: decode_latency = float(match.group("latency")) decode_throughput = float(match.group("throughput")) diff --git a/test/srt/ascend/test_ascend_piecewise_graph_prefill.py b/test/srt/ascend/test_ascend_piecewise_graph_prefill.py new file mode 100644 index 000000000000..9e43ca60f74b --- /dev/null +++ b/test/srt/ascend/test_ascend_piecewise_graph_prefill.py @@ -0,0 +1,89 @@ +import unittest +from urllib.parse import urlparse + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + SimpleNamespace, + popen_launch_server, + run_bench_one_batch, +) + +MODEL = "Qwen/Qwen2.5-7B-Instruct" +GSM8K_EXP_ACCURACY = 0.84 +EXP_PREFILL_LATENCY = 0.045 +TOKENS_TO_CAPTURE = [i for i in range(128, 4096, 128)] + + +class TestPiecewiseGraphPrefillCorrectness(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = MODEL + cls.base_url = DEFAULT_URL_FOR_TEST + cls.url = urlparse(DEFAULT_URL_FOR_TEST) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--mem-fraction-static", + 0.8, + "--attention-backend", + "ascend", + "--cuda-graph-bs", + 128, + "--enable-piecewise-cuda-graph", + "--piecewise-cuda-graph-tokens", + TOKENS_TO_CAPTURE, + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + print(f"##=== Testing accuracy: {self.model} ===##") + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=1319, + max_new_tokens=512, + parallel=128, + host=f"http://{self.url.hostname}", + port=int(self.url.port), + ) + + metrics = run_eval_few_shot_gsm8k(args) + self.assertGreaterEqual( + metrics["accuracy"], + GSM8K_EXP_ACCURACY, + ) + + +class TestPiecewiseGraphPrefillBenchmark(CustomTestCase): + + def test_latency(self): + print(f"##=== Testing prefill latency: {MODEL} ===##") + prefill_latency, _, _ = run_bench_one_batch( + MODEL, + other_args=[ + "--trust-remote-code", + "--mem-fraction-static", + 0.8, + "--attention-backend", + "ascend", + "--enable-piecewise-cuda-graph", + "--piecewise-cuda-graph-tokens", + TOKENS_TO_CAPTURE, + ], + ) + self.assertLess(prefill_latency, EXP_PREFILL_LATENCY) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 7812054e5c48..52389e8c21f8 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -361,6 +361,7 @@ suite_ascend = { "per-commit-1-npu-a2": [ TestFile("ascend/test_ascend_graph_tp1_bf16.py", 400), + TestFile("ascend/test_ascend_piecewise_graph_prefill.py", 400), TestFile("ascend/test_ascend_hicache_mha.py", 400), TestFile("ascend/test_ascend_sampling_backend.py", 400), TestFile("ascend/test_ascend_tp1_bf16.py", 400),