Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
bbd5eee
Update backend.py
Vladimir221 Oct 28, 2025
e860194
Create npu_piecewise_backend.py
Vladimir221 Oct 28, 2025
b455f9f
Create weak_ref_tensor_npu.cpp
Vladimir221 Oct 28, 2025
5046dea
Update weak_ref_tensor_jit.py
Vladimir221 Oct 28, 2025
13aa142
Update piecewise_cuda_graph_runner.py
Vladimir221 Oct 28, 2025
418c789
NPU only support prefill compilation with 'eager' backend, added chec…
Vladimir221 Oct 28, 2025
0495e07
Update npu_piecewise_backend.py
Vladimir221 Oct 29, 2025
edcc306
Update weak_ref_tensor_jit.py
Vladimir221 Oct 29, 2025
fcca834
Merge branch 'main' into vkh/piecewise_graph_npu_support
ping1jing2 Oct 29, 2025
9bd648a
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Oct 30, 2025
cb6df05
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Nov 5, 2025
acee70e
Update npu_piecewise_backend.py
Vladimir221 Nov 6, 2025
bddb0a5
Update piecewise_cuda_graph_runner.py
Vladimir221 Nov 6, 2025
8bf2689
Create test_piecewise_graph_prefill.py
Vladimir221 Nov 6, 2025
891187c
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Nov 6, 2025
b8c7395
Rename test_piecewise_graph_prefill.py to test_ascend_piecewise_graph…
Vladimir221 Nov 7, 2025
9a788bb
Update run_suite.py
Vladimir221 Nov 7, 2025
876dd7e
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Nov 7, 2025
2a90301
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Nov 10, 2025
52c5103
Merge branch 'main' into vkh/piecewise_graph_npu_support
ping1jing2 Nov 11, 2025
97a3271
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Nov 11, 2025
de22de8
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Nov 13, 2025
f18ef26
Delete python/sglang/srt/compilation/weak_ref_tensor_jit.py
Vladimir221 Nov 26, 2025
932ad83
Update backend.py
Vladimir221 Nov 26, 2025
7c7a3d3
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Nov 26, 2025
274fc4e
Update cuda_piecewise_backend.py
Vladimir221 Nov 26, 2025
b16e645
Update npu_piecewise_backend.py
Vladimir221 Nov 26, 2025
9422f3d
Create weak_ref_tensor.py
Vladimir221 Nov 26, 2025
19019b1
Delete python/sglang/srt/compilation/weak_ref_tensor_npu.cpp
Vladimir221 Nov 26, 2025
6345a15
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Nov 26, 2025
437839a
Merge branch 'main' into vkh/piecewise_graph_npu_support
ping1jing2 Nov 27, 2025
d4a1600
Update weak_ref_tensor.py
Vladimir221 Nov 27, 2025
d3ece7b
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Nov 27, 2025
d5bd011
Update test_utils.py
Vladimir221 Nov 28, 2025
cd12434
Update test_ascend_piecewise_graph_prefill.py
Vladimir221 Nov 28, 2025
2241346
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Nov 28, 2025
5bbc9bd
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Dec 2, 2025
3d2a5ec
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Dec 3, 2025
a9cdfd5
Merge branch 'main' into vkh/piecewise_graph_npu_support
ping1jing2 Dec 3, 2025
088d290
Update test_utils.py
Vladimir221 Dec 4, 2025
f57573f
Merge branch 'main' into vkh/piecewise_graph_npu_support
Vladimir221 Dec 4, 2025
4b4007a
Merge branch 'main' into vkh/piecewise_graph_npu_support
ping1jing2 Dec 9, 2025
9f0ede7
Update test_ascend_piecewise_graph_prefill.py
Vladimir221 Dec 9, 2025
de6b526
Merge branch 'main' into vkh/piecewise_graph_npu_support
ping1jing2 Dec 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions python/sglang/srt/compilation/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
from sglang.srt.compilation.compilation_counter import compilation_counter
from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor
from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend
from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend
from sglang.srt.compilation.pass_manager import PostGradPassManager
from sglang.srt.utils.common import rank0_log
from sglang.srt.utils.common import is_npu, rank0_log

logger = logging.getLogger(__name__)

Expand All @@ -44,6 +45,32 @@ def make_compiler(config: CompilationConfig):
raise ValueError(f"Unknown compiler: {config.compiler}")


def make_backend(
    graph: fx.GraphModule,
    compile_config: CompilationConfig,
    inductor_config: dict[str, Any],
    graph_pool: Any,
    piecewise_compile_index: int,
    total_piecewise_compiles: int,
    sym_shape_indices: list[int],
    compiled_graph_for_general_shape: Callable,
    sglang_backend,
):
    """Instantiate the platform-appropriate piecewise graph backend.

    Selects ``NPUPiecewiseBackend`` on Ascend NPU hosts and
    ``CUDAPiecewiseBackend`` everywhere else. Both classes share the same
    constructor signature, so the arguments are forwarded unchanged.
    """
    backend_cls = NPUPiecewiseBackend if is_npu() else CUDAPiecewiseBackend
    return backend_cls(
        graph,
        compile_config,
        inductor_config,
        graph_pool,
        piecewise_compile_index,
        total_piecewise_compiles,
        sym_shape_indices,
        compiled_graph_for_general_shape,
        sglang_backend,
    )


class CompilerManager:
def __init__(
self,
Expand Down Expand Up @@ -302,7 +329,7 @@ def call_module(
)
)

self.module.__dict__[target] = CUDAPiecewiseBackend(
self.module.__dict__[target] = make_backend(
submod,
self.compile_config,
self.inductor_config,
Expand Down
20 changes: 2 additions & 18 deletions python/sglang/srt/compilation/cuda_piecewise_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,19 @@
import dataclasses
import logging
from contextlib import ExitStack
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional
from unittest.mock import patch

import torch
import torch.fx as fx
from sgl_kernel import weak_ref_tensor

from sglang.srt.compilation.compilation_config import CompilationConfig
from sglang.srt.compilation.compilation_counter import compilation_counter
from sglang.srt.compilation.weak_ref_tensor import weak_ref_tensors

logger = logging.getLogger(__name__)


def weak_ref_tensors(
    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
    """Create weak references for a single tensor, or for every tensor in
    a list or tuple, preserving the container type of the input.
    """
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    if isinstance(tensors, (list, tuple)):
        refs = [weak_ref_tensor(t) for t in tensors]
        return refs if isinstance(tensors, list) else tuple(refs)
    raise ValueError("Invalid type for tensors")


@dataclasses.dataclass
class ConcreteSizeEntry:
runtime_shape: int
Expand Down
109 changes: 109 additions & 0 deletions python/sglang/srt/compilation/npu_piecewise_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from contextlib import ExitStack
from typing import Any, Callable
from unittest.mock import patch

import torch
import torch.fx as fx

from sglang.srt.compilation.compilation_config import CompilationConfig
from sglang.srt.compilation.compilation_counter import compilation_counter
from sglang.srt.compilation.cuda_piecewise_backend import (
CUDAPiecewiseBackend,
weak_ref_tensors,
)


class NPUPiecewiseBackend(CUDAPiecewiseBackend):
    """Piecewise compilation backend for Ascend NPU.

    All state handling (concrete-size entries, warmup counters, debug
    bookkeeping) is inherited unchanged from ``CUDAPiecewiseBackend``;
    only ``__call__`` is overridden so that capture and replay go through
    ``torch.npu.NPUGraph`` instead of CUDA graphs. The previously present
    ``__init__`` override only forwarded every argument verbatim to
    ``super().__init__`` and has been removed as redundant — the inherited
    constructor has the identical signature.
    """

    def __call__(self, *args):
        """Run this graph piece for the given args.

        For a runtime shape without a concrete-size entry, fall back to the
        general-shape compiled graph. Otherwise: run one warmup pass,
        capture an NPU graph on the next call, and replay the captured
        graph on all subsequent calls.
        """
        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.cudagraph is None:
            if entry.num_finished_warmup < 1:  # noqa
                entry.num_finished_warmup += 1
                return entry.runnable(*args)

            if self.compile_config.get_enable_debug_mode():
                # record the input buffer addresses so that replay can
                # verify the captured graph still sees the same buffers
                input_addresses = [
                    x.data_ptr() for x in args if isinstance(x, torch.Tensor)
                ]
                entry.input_addresses = input_addresses
            npugraph = torch.npu.NPUGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of graphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the graph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.npu.empty_cache", lambda: None)
                    )

                # mind-exploding: carefully manage the reference and memory.
                with torch.npu.graph(npugraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's graph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last
                        # graph will not be used by any other graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = npugraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during graph capture
            return output

        if self.compile_config.get_enable_debug_mode():
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )
        entry.cudagraph.replay()
        return entry.output
28 changes: 28 additions & 0 deletions python/sglang/srt/compilation/weak_ref_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any, Union

import torch

from sglang.srt.utils.common import is_cuda, is_npu

if is_cuda():
from sgl_kernel import weak_ref_tensor
elif is_npu():
from torch_npu._C import _weak_ref_tensor as weak_ref_tensor
else:
raise NotImplementedError("weak_ref_tensor is implemented only for CUDA and NPU.")


def weak_ref_tensors(
    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
    """Build weak references to one tensor, or to each tensor in a list or
    tuple, returning the same container shape as was passed in.
    """
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    for container in (list, tuple):
        if isinstance(tensors, container):
            return container(weak_ref_tensor(t) for t in tensors)
    raise ValueError("Invalid type for tensors")
20 changes: 10 additions & 10 deletions python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
ForwardMode,
PPProxyTensors,
)
from sglang.srt.utils import get_available_gpu_memory, log_info_on_rank0
from sglang.srt.utils import get_available_gpu_memory, is_npu, log_info_on_rank0

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -303,7 +303,7 @@ def warmup_torch_compile(self):
seq_lens=torch.tensor([num_tokens], device=self.device),
next_token_logits_buffer=None,
orig_seq_lens=torch.tensor([num_tokens], device=self.device),
seq_lens_cpu=torch.tensor([num_tokens]),
seq_lens_cpu=torch.tensor([num_tokens], device="cpu"),
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,
Expand All @@ -316,9 +316,9 @@ def warmup_torch_compile(self):
extend_seq_lens=torch.tensor([num_tokens], device=self.device),
extend_prefix_lens=torch.tensor([num_tokens], device=self.device),
extend_start_loc=torch.tensor([0], device=self.device),
extend_prefix_lens_cpu=torch.tensor([num_tokens]),
extend_seq_lens_cpu=torch.tensor([num_tokens]),
extend_logprob_start_lens_cpu=torch.tensor([num_tokens]),
extend_prefix_lens_cpu=torch.tensor([num_tokens], device="cpu"),
extend_seq_lens_cpu=torch.tensor([num_tokens], device="cpu"),
extend_logprob_start_lens_cpu=torch.tensor([num_tokens], device="cpu"),
positions=torch.arange(num_tokens, device=self.device),
global_num_tokens_gpu=None,
global_num_tokens_for_logprob_gpu=None,
Expand Down Expand Up @@ -347,7 +347,7 @@ def warmup_torch_compile(self):
)

def _cache_loc_dtype(self):
    """Return the dtype used for cache locations: int32 on NPU, int64
    elsewhere."""
    return torch.int32 if is_npu() else torch.int64

def can_run(self, forward_batch: ForwardBatch):
num_tokens = len(forward_batch.input_ids)
Expand Down Expand Up @@ -432,7 +432,7 @@ def capture_one_batch_size(self, num_tokens: int):
seq_lens=torch.tensor([num_tokens], device=self.device),
next_token_logits_buffer=None,
orig_seq_lens=torch.tensor([num_tokens], device=self.device),
seq_lens_cpu=torch.tensor([num_tokens]),
seq_lens_cpu=torch.tensor([num_tokens], device="cpu"),
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,
Expand All @@ -445,9 +445,9 @@ def capture_one_batch_size(self, num_tokens: int):
extend_seq_lens=torch.tensor([num_tokens], device=self.device),
extend_prefix_lens=torch.tensor([num_tokens], device=self.device),
extend_start_loc=torch.tensor([0], device=self.device),
extend_prefix_lens_cpu=torch.tensor([num_tokens]),
extend_seq_lens_cpu=torch.tensor([num_tokens]),
extend_logprob_start_lens_cpu=torch.tensor([num_tokens]),
extend_prefix_lens_cpu=torch.tensor([num_tokens], device="cpu"),
extend_seq_lens_cpu=torch.tensor([num_tokens], device="cpu"),
extend_logprob_start_lens_cpu=torch.tensor([num_tokens], device="cpu"),
positions=positions,
global_num_tokens_gpu=None,
global_num_tokens_for_logprob_gpu=None,
Expand Down
12 changes: 12 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,9 @@ def __post_init__(self):
self._handle_cpu_backends()
self._handle_npu_backends()

# Handle compilation config
self._handle_compilation_cfg()

# Apply model-specific adjustments.
self._handle_model_specific_adjustments()

Expand Down Expand Up @@ -951,6 +954,15 @@ def _handle_cpu_backends(self):
self.attention_backend = "intel_amx"
self.sampling_backend = "pytorch"

def _handle_compilation_cfg(self):
    """Adjust compilation-related server args for the current platform.

    On Ascend NPU, piecewise graph compilation currently supports only the
    'eager' compiler, so any other setting is overridden with a warning.
    """
    if is_npu() and self.piecewise_cuda_graph_compiler != "eager":
        # Grammar of the user-facing warning fixed ("only supports",
        # "changing ... to 'eager'"); behavior is unchanged.
        logger.warning(
            "At this moment the Ascend platform only supports prefill graph "
            "compilation with piecewise_cuda_graph_compiler='eager'; changing "
            "piecewise_cuda_graph_compiler to 'eager'."
        )
        self.piecewise_cuda_graph_compiler = "eager"

def _handle_npu_backends(self):
if self.device == "npu":
from sglang.srt.hardware_backend.npu.utils import set_default_server_args
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1996,7 +1996,7 @@ def direct_register_custom_op(

try:
my_lib.define(op_name + schema_str)
my_lib.impl(op_name, op_func, "CUDA")
my_lib.impl(op_name, op_func, "CUDA" if not is_npu() else "PrivateUse1")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this PrivateUse1 used for?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PrivateUse1 is PyTorch provided reserved dispatch key to integrate a new backend living outside pytorch/pytorch and to dispatch PyTorch functionality to custom backend kernels. Backend for NPU operators is registered via this key (https://docs.pytorch.org/tutorials/advanced/privateuseone.html)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Vladimir221, will a CUDA device and an NPU device ever exist on the same node? If not, you can register for both CUDA and NPU at the same time.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @Vladimir221 will there CUDA device and NPU device exist in the same node? if not you can register for CUDA/NPU at the same time

Do you suggest to register implementations of custom op functions for both dispatch keys and remove if statement?

my_lib.impl(op_name, op_func, "CUDA")
my_lib.impl(op_name, op_func, "PrivateUse1")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you suggest to register implementations of custom op functions for both dispatch keys and remove if statement?

From my view, yes. It might save an if-else cost

if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
except RuntimeError as error:
Expand Down
17 changes: 8 additions & 9 deletions python/sglang/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,21 +1201,20 @@ def run_bench_one_batch(model, other_args):

try:
stdout, stderr = process.communicate()
output = stdout.decode()
error = stderr.decode()
output = stdout.decode(errors="backslashreplace")
error = stderr.decode(errors="backslashreplace")
print(f"Output: {output}", flush=True)
print(f"Error: {error}", flush=True)

# Return prefill_latency, decode_throughput, decode_latency
prefill_line = output.split("\n")[-9]
decode_line = output.split("\n")[-3]
pattern = (
r"latency: (?P<latency>\d+\.\d+).*?throughput:\s*(?P<throughput>\d+\.\d+)"
)
match = re.search(pattern, prefill_line)
pattern = r"Benchmark[\s\S]*Total"
match = re.search(pattern, output)
bench_output = match[0] if match else ""
pattern = r".*?latency: (?P<latency>\d+\.\d+).*?throughput:\s*(?P<throughput>\d+\.\d+)"
match = re.search(r"Prefill." + pattern, bench_output)
if match:
prefill_latency = float(match.group("latency"))
match = re.search(pattern, decode_line)
match = re.search(r"Decode." + pattern, bench_output)
if match:
decode_latency = float(match.group("latency"))
decode_throughput = float(match.group("throughput"))
Expand Down
Loading
Loading