
[Ascend]Support of piecewise graph compilation for prefill on NPU#12287

Merged
ispobock merged 44 commits into sgl-project:main from Vladimir221:vkh/piecewise_graph_npu_support
Dec 11, 2025

Conversation

@Vladimir221
Contributor

@Vladimir221 Vladimir221 commented Oct 28, 2025

Motivation

Compiling the model forward pass at prefill speeds up inference, as already shown in PR #10062, which enabled this feature for CUDA devices. The current PR enables the same feature for NPU devices.

Modifications

Added:

  • platform-based selection of the backend for PiecewiseCompileInterpreter
  • a piecewise prefill compilation backend for NPU
  • an implementation of weak_ref_tensor for NPU
  • platform-based selection of the weak_ref_tensor implementation
  • a check of the piecewise_cuda_graph_compiler option, since NPU only supports prefill compilation with the 'eager' backend
  • a test for the piecewise graph on NPU

Changed:

  • the device argument in the direct_register_custom_op function to PrivateUse1 when the platform is NPU
  • the device type of seq_lens_cpu, extend_seq_lens_cpu, extend_prefix_lens_cpu, and extend_logprob_start_lens_cpu in the warmup_and_capture and capture_one_batch_size methods of the PiecewiseCudaGraphRunner class, since these tensors must be allocated on the CPU
  • the _cache_loc_dtype method of the PiecewiseCudaGraphRunner class, since NPU supports the int32 type for out_cache_loc
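The platform-based selection described above can be sketched as follows. This is a minimal stand-in: the names make_backend, is_npu, CUDAPiecewiseBackend, and NPUPiecewiseBackend follow the PR text, but the classes here are stubs, not the real sglang implementations.

```python
# Minimal sketch of platform-based backend selection (names assumed from the
# PR; the classes below are stubs, not the real sglang implementations).

class CUDAPiecewiseBackend:
    """Stand-in for the existing CUDA piecewise compilation backend."""
    name = "cuda"

class NPUPiecewiseBackend(CUDAPiecewiseBackend):
    """Stand-in for the NPU backend; the PR derives it from the CUDA one
    to avoid duplicating the backend initialization."""
    name = "npu"

def is_npu() -> bool:
    # Stand-in for the real platform probe.
    return False

def make_backend():
    """Pick the piecewise-compilation backend for the current platform."""
    return NPUPiecewiseBackend if is_npu() else CUDAPiecewiseBackend
```

The same one-line dispatch pattern covers the weak_ref_tensor selection as well: probe the platform once, then hand back the matching implementation.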

Accuracy Tests

GSM 8K Llama-3.1-8B
Ascend 910B, tp-size=1, concurrency=128

# Without Piecewise Cuda Graph
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [01:07<00:00, 19.47it/s]
Accuracy: 0.760
Invalid: 0.001
Latency: 67.817 s
Output throughput: 1769.183 token/s
# With Piecewise Cuda Graph
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1319/1319 [00:56<00:00, 23.21it/s]
Accuracy: 0.759
Invalid: 0.001
Latency: 56.924 s
Output throughput: 2079.687 token/s
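For reference, the logs above imply roughly a 1.19x latency and 1.18x throughput improvement:

```python
# Speedup implied by the GSM 8K run above (values copied from the logs).
latency_base, latency_piecewise = 67.817, 56.924
thpt_base, thpt_piecewise = 1769.183, 2079.687

latency_speedup = latency_base / latency_piecewise   # ~1.19x
thpt_gain = thpt_piecewise / thpt_base               # ~1.18x
print(f"{latency_speedup:.2f}x latency, {thpt_gain:.2f}x throughput")
```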

Benchmarking and Profiling

Profiling

[profiling trace screenshot]

Checklist

Vladimir221 and others added 6 commits October 28, 2025 18:47
Added choosing of backend for PiecewiseCompileInterpreter based on platform
Added backend for piecewise prefill compilation for NPU
Added implementation of weak_ref_tensor for npu
Added choosing of weak_ref_tensor implementation based on platform
seq_lens_cpu and extend_seq_lens_cpu should be allocated on cpu
…king of piecewise_cuda_graph_compiler option, changed device arg to PrivateUse1 in direct_register_custom_op function if platform is NPU
@gemini-code-assist
Contributor

Summary of Changes

Hello @Vladimir221, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the model inference capabilities by extending piecewise graph compilation to NPU devices. The core objective is to accelerate the prefill phase of inference on NPU hardware, mirroring the performance gains observed on CUDA. This is achieved through the introduction of NPU-specific backend logic, optimized memory management using NPU-compatible weak tensor references, and precise device placement for CPU-bound tensors. Additionally, the changes include a mechanism to enforce the 'eager' compilation backend for NPU, ensuring compatibility with current NPU capabilities.

Highlights

  • NPU Piecewise Compilation Support: Introduces comprehensive support for piecewise graph compilation on NPU (Neural Processing Unit) devices, extending the performance benefits previously available only for CUDA.
  • Dynamic Backend Selection: Adds a new make_backend function that intelligently selects between CUDAPiecewiseBackend and the newly implemented NPUPiecewiseBackend based on the detected platform.
  • NPU-Specific Weak Reference Tensors: Implements weak_ref_tensor functionality tailored for NPU devices in both Python and C++, crucial for efficient memory management during graph capture on NPU.
  • CPU Tensor Allocation Correction: Ensures that specific tensors, seq_lens_cpu and extend_seq_lens_cpu, are explicitly allocated on the CPU within the graph runner to prevent device mismatches and ensure correct operation.
  • NPU Compilation Backend Enforcement: Enforces the use of the 'eager' backend for prefill graph compilation when operating on NPU devices, aligning with the currently supported compilation modes for this platform.
  • Custom Operation Device Handling: Modifies the direct_register_custom_op function to correctly register custom operations for 'PrivateUse1' (NPU) devices when the NPU platform is detected.
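The 'eager' enforcement mentioned in the highlights could look roughly like this; a hedged sketch in which the function name and error message are assumed, not the actual sglang code:

```python
# Sketch of the NPU 'eager'-only check (function name and message assumed).
def validate_piecewise_compiler(platform: str, compiler: str) -> str:
    """Reject unsupported piecewise compilers on NPU; pass everything else through."""
    if platform == "npu" and compiler != "eager":
        raise ValueError(
            "NPU only supports piecewise prefill compilation with the 'eager' backend"
        )
    return compiler
```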

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request extends piecewise graph compilation support to NPU devices, mirroring the existing functionality for CUDA to improve prefill inference performance. The implementation introduces an NPU-specific backend and a JIT-compiled weak_ref_tensor operator. The changes are well-structured, but I've identified a couple of issues in the new NPU backend where CUDA-specific code was left behind, and a minor improvement for the JIT compilation script. Addressing these points will enhance the correctness and maintainability of the NPU support.

@ping1jing2 ping1jing2 changed the title Support of piecewise graph compilation for prefill on NPU [Ascend]Support of piecewise graph compilation for prefill on NPU Oct 29, 2025
Contributor

@ssshinigami ssshinigami left a comment


LGTM

  try:
      my_lib.define(op_name + schema_str)
-     my_lib.impl(op_name, op_func, "CUDA")
+     my_lib.impl(op_name, op_func, "CUDA" if not is_npu() else "PrivateUse1")
Collaborator


What is this PrivateUse1 used for?

Contributor Author


PrivateUse1 is a reserved dispatch key provided by PyTorch for integrating a new backend that lives outside pytorch/pytorch and for dispatching PyTorch functionality to custom backend kernels. The backend for NPU operators is registered via this key (https://docs.pytorch.org/tutorials/advanced/privateuseone.html)
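To illustrate the idea, here is a toy dispatch table, not PyTorch's actual dispatcher: the essence is a mapping from (op, dispatch key) to a kernel, with NPU kernels living under the PrivateUse1 key.

```python
# Toy model of dispatch-key routing; PyTorch's real dispatcher is far more
# involved, but the (op, key) -> kernel mapping is the same idea.
registry = {}

def impl(op_name, fn, dispatch_key):
    """Register fn as the kernel for op_name under the given dispatch key."""
    registry[(op_name, dispatch_key)] = fn

def dispatch(op_name, dispatch_key, *args):
    """Look up and invoke the kernel registered for this (op, key) pair."""
    return registry[(op_name, dispatch_key)](*args)

impl("weak_ref_tensor", lambda t: t, "CUDA")
impl("weak_ref_tensor", lambda t: t, "PrivateUse1")  # NPU kernels go under PrivateUse1
```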

Collaborator


hi @Vladimir221, will CUDA and NPU devices exist in the same node? If not, you can register for CUDA/NPU at the same time

Contributor Author


hi @Vladimir221, will CUDA and NPU devices exist in the same node? If not, you can register for CUDA/NPU at the same time

Do you suggest registering implementations of the custom op functions for both dispatch keys and removing the if statement?

my_lib.impl(op_name, op_func, "CUDA")
my_lib.impl(op_name, op_func, "PrivateUse1")

Collaborator


Do you suggest registering implementations of the custom op functions for both dispatch keys and removing the if statement?

From my view, yes. It might save an if-else cost

@Oasis-Git
Collaborator

Oasis-Git commented Nov 3, 2025

@Vladimir221 Thanks for your contribution. LGTM. I am wondering whether we should add a related unit test?

@Vladimir221
Contributor Author

Vladimir221 commented Nov 6, 2025

@Vladimir221 Thanks for your contribution. LGTM. I am wondering whether we should add a related unit test?

@Oasis-Git Added a new test in the ascend directory

@ping1jing2 ping1jing2 marked this pull request as ready for review November 27, 2025 02:37
@ispobock
Collaborator

/tag-and-rerun-ci

Contributor

@ssshinigami ssshinigami left a comment


LGTM

@ping1jing2 ping1jing2 self-assigned this Nov 28, 2025
raise NotImplementedError("weak_ref_tensor is implemented only for CUDA and NPU.")


def weak_ref_tensors(
Collaborator


Contributor Author

@Vladimir221 Vladimir221 Dec 1, 2025


Yes, I can align it with this, but in that case we will get code duplication for the weak_ref_tensors function. Moreover, NPUPiecewiseBackend is based on CUDAPiecewiseBackend to avoid duplicating the backend class initialization, so if I import the CUDAPiecewiseBackend class in the npu_piecewise_backend.py file on a host machine that doesn't have the sgl_kernel package (only the sgl_kernel_npu package), an import error will occur. To unify it, I would need to remove the NPUPiecewiseBackend inheritance from CUDAPiecewiseBackend and duplicate the code from the CUDAPiecewiseBackend.__init__() method. If you think that is the more proper way, I can align the code with your suggestion

@ping1jing2
Collaborator

/tag-and-rerun-ci

@ping1jing2
Collaborator

/rerun-failed-ci

@ping1jing2
Collaborator

/rerun-failed-ci

@ispobock ispobock merged commit 27032ce into sgl-project:main Dec 11, 2025
307 of 331 checks passed
Prozac614 pushed a commit to Prozac614/sglang that referenced this pull request Dec 17, 2025
YChange01 pushed a commit to YChange01/sglang that referenced this pull request Jan 13, 2026