[NPU]ACLGraph Compilation support and PassManager with AddRmsNorm & Quantize fuse. TorchAir compiler backend support. #11104
Closed
98 commits
4ce70e6
NPU Graph Compilation support and PassManager with AddRmsNorm & Quantize
eshoguli b974460
pre-commit & refactoring
eshoguli 7a7bde7
pre-commit
qyqc731 d8e2dc3
Merge branch 'main' into eshogulin/pass_manager
eshoguli 9bb7751
Merge branch 'main' into eshogulin/pass_manager: fix - custom_ops.py
eshoguli 7048005
cleanup & refactoring
eshoguli eb240d9
Pass Manager fix
eshoguli 29c1d89
Compilation: refactoring
eshoguli 3e98d17
NPU Piecewise Graph
eshoguli 3d9516a
rollback
eshoguli 2c1b6fe
linter
eshoguli 55016b0
refactoring
eshoguli fbff08d
refactoring
eshoguli 3e5db77
Compilation: refactoring
eshoguli 99d4497
Merge branch 'main' into eshogulin/pass_manager
eshoguli 30da7fe
model_type check
eshoguli 36ef7e7
Merge branch 'main' into eshogulin/pass_manager
eshoguli 1808479
PiecewiseNpuGraphCompilerBackend quick fix
bcfc2c5
CompilationConfig reusage
a6a159d
--torch-compile-max-bs support
c08d076
TorchAir compilation support
XDaoHong 73f2ee9
runner selection fix: model forward usage
eshoguli 2f97641
add test for torchair
XDaoHong 7154cf4
TorchAir compilation support: refactoring
eshoguli dfaee00
NPU Piecewise Graph: refactoring
eshoguli bec1b28
Merge branch 'main' into eshogulin/pass_manager
eshoguli 253c14d
linter fix after merge commit
eshoguli 85d808e
NPUGraph compilation (fp16) & NPU Piecewise Graph tests
eshoguli 11074d9
TorchAir compilation support: refactoring 2
eshoguli 51ac4b4
Merge branch 'main' into eshogulin/pass_manager
ping1jing2 e06675b
CompilationConfig comments fix + linter fix
eshoguli 0c09c24
backend instantiation in get_compiler_backend
eshoguli 00a0b9b
Merge branch 'main' into eshogulin/pass_manager
eshoguli 0b31746
Merge remote-tracking branch 'sglang/main' into eshogulin/pass_manage…
eshoguli 3b5c83b
Merge branch 'main' into eshogulin/pass_manager
eshoguli 7eefeee
linter fix
eshoguli 8c63980
dynamo patch removing
eshoguli 2e02568
fix on main branch: compilation
eshoguli 966bbf4
Merge branch 'main' into eshogulin/pass_manager
eshoguli 14092b3
auto merge fix
eshoguli f989147
tests suit update
eshoguli bf1251d
Add npu_add_rms_norm_dynamic_quant fuse
OrangeRedeng 317174b
Merge branch 'eshogulin/pass_manager' of https://github.com/eshoguli/…
OrangeRedeng e6eb29c
NPU Graph compilation: attention architecture check
eshoguli caba95e
Add npu_add_rms_norm_dynamic_quant fuse: quick fix
eshoguli 3f87879
Qwen3 MoE compilation support for NPU
eshoguli a2046c3
Merge branch 'main' into eshogulin/pass_manager
eshoguli faea888
linter quick fix
eshoguli 85720d6
SlitQkvRmsnormRopeFuse fuse
eshoguli fd0e1e8
headers quick fix
eshoguli 58966d6
Merge branch 'main' into eshogulin/pass_manager
ping1jing2 a85ab1f
Merge branch 'main' into eshogulin/pass_manager
eshoguli 4a61b7e
lint after merge + Piecewise Graph fix
eshoguli 17f0af5
enable_torch_compile update rollback
eshoguli ad76e3c
Merge branch 'main' into eshogulin/pass_manager
eshoguli 752657c
Merge fixes: moving in accordance with refactoring & cleanup
eshoguli 3ac87be
Merge fixes: 1) updated cache prefetch support 2) ModelWeightParamete…
eshoguli b79fc0b
Merge branch 'main' into eshogulin/pass_manager
ping1jing2 105050f
cleanup
eshoguli 90caee2
Cleanup & ComilationConfig usage update & master merge refactoring
eshoguli a5e87f6
Compilation backends: model type quick fix
eshoguli daf81b2
TorchAir compilation backend: Ascend attention backend quick fix
eshoguli ea25b3f
torchair compilation test fix
eshoguli 3365d71
Merge branch 'main' into eshogulin/pass_manager
eshoguli 40389dd
Capturing compiled code issue: fix - dynamo patching
eshoguli f5424d8
comments fix
eshoguli 55a1e06
Documentation
eshoguli fd28ac6
cleanup & fuse quick fix: compilation & piecewise
eshoguli 97d654e
TorchAir support: inference fix & refactoring
eshoguli 3ce92e8
Comment fixes + refactoring
eshoguli f4dfef3
Piecewise Graph Runner refactoring
eshoguli 123e36c
PiecewiseGraph runner quick fix
eshoguli b3e2fe8
enable_torch_npugraph_ex_compile
XDaoHong ebcc846
linter fixes
eshoguli 132581a
fix after merge torchair
eshoguli 81a392c
Merge branch 'main' into eshogulin/pass_manager
eshoguli cd0770c
Piecewise Graph temporary removal
eshoguli bdc4b43
fix import comment
eshoguli 499c185
refactoring: command line arg renaming
eshoguli 4785e12
refactoring: server args text update
eshoguli 8be580b
Merge remote-tracking branch 'sglang/main' into eshogulin/pass_manager
eshoguli 5f1d3bc
Merge remote-tracking branch 'sglang' into eshogulin/pass_manager
eshoguli 48095c4
refactoring: custom ops & CompilationConfig loading movements
eshoguli 7fa0424
server args quick fix for NPU
eshoguli 8adb35d
Merge remote-tracking branch 'sglang/main' into eshogulin/pass_manage…
eshoguli 038fc19
Merge remote-tracking branch 'sglang/main' into eshogulin/pass_manage…
eshoguli ef10cec
Merge remote-tracking branch 'sglang/main' into eshogulin/pass_manager
eshoguli d91ea44
linear method fix
eshoguli dde39ad
Merge remote-tracking branch 'sglang/main' into eshogulin/pass_manage…
eshoguli c5e8ba0
main merge fix
eshoguli 145f252
Comments: compilation config arg documantation was extened
eshoguli 8db134a
tests & NPUGraphRunner fix
eshoguli ac81122
tests improvements: bs is not defined & both options are possible
eshoguli 4727ead
Merge remote-tracking branch 'sglang/main' into eshogulin/pass_manage…
eshoguli a61b8d6
merge fix
eshoguli 01d892d
Merge branch 'main' into eshogulin/pass_manager
ping1jing2 467b543
comments: command line arg validation & custom ops
eshoguli 0849202
Merge branch 'main' into eshogulin/pass_manager
ping1jing2
## How to transform model instances with the PyTorch FX Toolkit in SGLang for NPU

### PassManager

`PassManager` is implemented here: [PassManager](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/pass_manager.py)

You can explore `PassManager` usage in the [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) compiler backend.

### Pass development

There are two approaches to developing passes for the SGLang NPU `PassManager`:

1. Subgraph rewriting with the `torch.fx.replace_pattern` API, which matches all non-overlapping occurrences of a pattern of operators and their data dependencies and replaces each match.
   Pass example: [NpuAddRmsNormQuantFuse](https://github.com/eshoguli/sglang/blob/3365d711fd5aa0d6191c32769163320fe41e27f2/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py#L82).
   Details are in the official FX toolkit documentation: https://docs.pytorch.org/docs/stable/fx.html#subgraph-rewriting-with-replace-pattern

2. Direct graph manipulation.
   Pass example: [EraseCopy](https://github.com/eshoguli/sglang/blob/3365d711fd5aa0d6191c32769163320fe41e27f2/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/passes/w8a8_int8.py#L28).
   Details are in the official FX toolkit documentation: https://docs.pytorch.org/docs/stable/fx.html#direct-graph-manipulation

### Compiler backend update

After developing a pass, create a `PassManager` instance, add the pass, and call its `apply` method:

```python
def apply_passes(self, graph_module: torch.fx.GraphModule):
    pass_manager = PassManager(graph_module)
    pass_manager.add(NpuAddRmsNormQuantFuse)
    pass_manager.apply()
    graph_module.recompile()
```

You can explore [`NpuGraphCompilerBackend`](https://github.com/eshoguli/sglang/blob/eshogulin/pass_manager/python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler_backend.py) as an example.
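The first approach above can be sketched with a toy fusion. This is a hedged, self-contained illustration only: the model, pattern, and replacement are stand-ins (a real NPU pass would substitute a fused op such as `npu_add_rms_norm_quant` rather than `torch.clamp`):

```python
import torch
import torch.fx as fx


# Toy module whose add + relu sequence we want to "fuse".
class M(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y)


def pattern(x, y):
    # Subgraph to match: add followed by relu.
    return torch.relu(x + y)


def replacement(x, y):
    # Numerically equivalent stand-in for a fused kernel call.
    return torch.clamp(x + y, min=0.0)


gm = fx.symbolic_trace(M())
matches = fx.replace_pattern(gm, pattern, replacement)  # rewrites gm in place
gm.recompile()
```

After the rewrite, `gm` no longer contains a `relu` node; every match of the pattern has been replaced by the single `clamp` call.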
```python
from typing import List

import torch

import sglang.srt.hardware_backend.npu.cmo
from sglang.srt.utils.custom_op import register_custom_op


@torch.library.custom_op("sglang::wait_cmo_stream", mutates_args=())
def wait_cmo_stream() -> None:
    if sglang.srt.hardware_backend.npu.cmo.get_cmo_stream():
        sglang.srt.hardware_backend.npu.cmo.wait_cmo_stream()


@wait_cmo_stream.register_fake
def wait_cmo_stream_fake() -> None:
    pass


def get_cmo_stream() -> bool:
    return True


def prepare_weight_cache_fake(handle: torch.Tensor, cache: List[torch.Tensor]) -> None:
    pass


@register_custom_op(fake_impl=prepare_weight_cache_fake)
def prepare_weight_cache(handle: torch.Tensor, cache: List[torch.Tensor]) -> None:
    sglang.srt.hardware_backend.npu.cmo.prepare_weight_cache(handle, cache)
```
```python
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import List, Optional

import torch

import sglang.srt.layers.dp_attention


@torch.library.custom_op("sglang::_set_dp_buffer_len", mutates_args=())
def _set_dp_buffer_len(
    global_dp_buffer_len: Optional[int],
    num_tokens: Optional[int],
    is_max_len: bool,
    global_num_tokens: Optional[List[int]] = None,
) -> None:
    sglang.srt.layers.dp_attention.set_dp_buffer_len(
        global_dp_buffer_len, num_tokens, is_max_len, global_num_tokens
    )


@_set_dp_buffer_len.register_fake
def _set_dp_buffer_len_fake(
    global_dp_buffer_len: Optional[int],
    num_tokens: Optional[int],
    is_max_len: bool,
    global_num_tokens: Optional[List[int]] = None,
) -> None:
    pass


@torch.library.custom_op("sglang::_set_is_extend_in_batch", mutates_args=())
def _set_is_extend_in_batch(is_extend_in_batch: bool) -> None:
    sglang.srt.layers.dp_attention.set_is_extend_in_batch(is_extend_in_batch)


@_set_is_extend_in_batch.register_fake
def _set_is_extend_in_batch_fake(is_extend_in_batch: bool) -> None:
    pass
```
python/sglang/srt/hardware_backend/npu/graph_runner/compilation/compilation_context.py
```python
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch_npu


class CompilationContext:
    graph_memory_pool = None
    stream: torch_npu.npu.Stream = None
```
python/sglang/srt/hardware_backend/npu/graph_runner/compilation/custom_ops.py
```python
from typing import List

import sgl_kernel_npu.norm.split_qkv_rmsnorm_rope as sgl_kernel_npu
import torch

from sglang.srt.utils.custom_op import register_custom_op


def split_qkv_rmsnorm_rope_fake(
    input: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_hidden_size: int,
    kv_hidden_size: int,
    head_dim: int,
    eps: float,
    q_bias: torch.Tensor,
    k_bias: torch.Tensor,
) -> List[torch.Tensor]:
    # TODO: generalize shape
    q = torch.empty(
        (input.shape[0], q_hidden_size), dtype=input.dtype, device=input.device
    )
    k = torch.empty(
        (input.shape[0], kv_hidden_size), dtype=input.dtype, device=input.device
    )
    v = torch.empty(
        (input.shape[0], kv_hidden_size), dtype=input.dtype, device=input.device
    )
    return [q, k, v]


@register_custom_op(fake_impl=split_qkv_rmsnorm_rope_fake)
def split_qkv_rmsnorm_rope(
    input: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    q_weight: torch.Tensor,
    k_weight: torch.Tensor,
    q_hidden_size: int,
    kv_hidden_size: int,
    head_dim: int,
    eps: float,
    q_bias: torch.Tensor,
    k_bias: torch.Tensor,
) -> List[torch.Tensor]:
    q, k, v = sgl_kernel_npu.split_qkv_rmsnorm_rope(
        input,
        sin,
        cos,
        q_weight,
        k_weight,
        q_hidden_size,
        kv_hidden_size,
        head_dim,
        eps,
        q_bias,
        k_bias,
    )
    return [q, k, v]
```
python/sglang/srt/hardware_backend/npu/graph_runner/compilation/npu_graph_compiler.py
```python
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import torch

from sglang.srt.compilation.compilation_config import CompilationConfig
from sglang.srt.utils.common import get_compiler_backend


class NpuGraphCompiler:
    def __init__(
        self,
        model_runner,
        model: torch.nn.Module,
        compilation_config: CompilationConfig,
        batch_size: int,
    ):
        torch._dynamo.reset()

        if compilation_config is None:
            compilation_config = CompilationConfig(compiler="npugraph")

        backend = get_compiler_backend(
            compilation_config=compilation_config, model_runner=model_runner
        )
        backend.init(model_runner.model_config, batch_size)

        self.compiled_callable = torch.compile(
            model, fullgraph=True, dynamic=False, backend=backend
        )
```
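The `backend` object handed to `torch.compile` above follows the standard custom-backend contract: a callable that receives the traced `torch.fx.GraphModule` plus example inputs and returns a compiled callable. A minimal sketch under that assumption (the backend and function names here are illustrative, not SGLang code):

```python
from typing import List

import torch


def my_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # A real backend (e.g. NpuGraphCompilerBackend) would run a
    # PassManager over gm.graph here before recompiling.
    gm.recompile()
    return gm.forward  # return the callable torch.compile should invoke


@torch.compile(backend=my_backend, fullgraph=True, dynamic=False)
def f(x):
    return torch.relu(x) + 1


y = f(torch.tensor([-1.0, 2.0]))
```

`fullgraph=True` mirrors the `NpuGraphCompiler` call above: Dynamo must capture the whole model as one graph instead of falling back to eager execution on graph breaks.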