-
Notifications
You must be signed in to change notification settings - Fork 5.4k
[Ascend]Support of piecewise graph compilation for prefill on NPU #12287
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bbd5eee
e860194
b455f9f
5046dea
13aa142
418c789
0495e07
edcc306
fcca834
9bd648a
cb6df05
acee70e
bddb0a5
8bf2689
891187c
b8c7395
9a788bb
876dd7e
2a90301
52c5103
97a3271
de22de8
f18ef26
932ad83
7c7a3d3
274fc4e
b16e645
9422f3d
19019b1
6345a15
437839a
d4a1600
d3ece7b
d5bd011
cd12434
2241346
5bbc9bd
3d2a5ec
a9cdfd5
088d290
f57573f
4b4007a
9f0ede7
de6b526
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| from contextlib import ExitStack | ||
| from typing import Any, Callable | ||
| from unittest.mock import patch | ||
|
|
||
| import torch | ||
| import torch.fx as fx | ||
|
|
||
| from sglang.srt.compilation.compilation_config import CompilationConfig | ||
| from sglang.srt.compilation.compilation_counter import compilation_counter | ||
| from sglang.srt.compilation.cuda_piecewise_backend import ( | ||
| CUDAPiecewiseBackend, | ||
| weak_ref_tensors, | ||
| ) | ||
|
|
||
|
|
||
class NPUPiecewiseBackend(CUDAPiecewiseBackend):
    """Piecewise compilation backend for Ascend NPU.

    Reuses all of :class:`CUDAPiecewiseBackend`'s initialization and
    bookkeeping (``concrete_size_entries``, warmup counters, etc.) and only
    overrides ``__call__`` so that graph capture/replay goes through
    ``torch.npu.NPUGraph`` instead of CUDA graphs.

    NOTE(review): the previous ``__init__`` override forwarded every argument
    unchanged to ``CUDAPiecewiseBackend.__init__`` and has been removed as
    redundant — the inherited constructor is used directly.
    """

    def __call__(self, *args):
        """Dispatch one piecewise graph call.

        For shapes not selected for capture, run the general-shape compiled
        graph. For selected shapes: warm up once, then capture an NPU graph,
        then replay the captured graph on subsequent calls.
        """
        runtime_shape = args[self.sym_shape_indices[0]]
        if runtime_shape not in self.concrete_size_entries:
            # we don't need to do anything for this shape
            return self.compiled_graph_for_general_shape(*args)

        entry = self.concrete_size_entries[runtime_shape]

        if entry.runnable is None:
            entry.runnable = self.compiled_graph_for_general_shape

        if entry.cudagraph is None:
            if entry.num_finished_warmup < 1:  # noqa
                # Run eagerly once before capture so lazy one-time work
                # (e.g. workspace allocation) is not baked into the graph.
                entry.num_finished_warmup += 1
                return entry.runnable(*args)

            if self.compile_config.get_enable_debug_mode():
                # Record tensor input addresses so replay can verify the
                # captured graph is fed the same buffers.
                input_addresses = [
                    x.data_ptr() for x in args if isinstance(x, torch.Tensor)
                ]
                entry.input_addresses = input_addresses
            npugraph = torch.npu.NPUGraph()

            with ExitStack() as stack:
                if not self.is_first_graph:
                    # during every model forward, we will capture
                    # many pieces of graphs (roughly one per layer).
                    # running gc again and again across layers will
                    # make the graph capture very slow.
                    # therefore, we only run gc for the first graph,
                    # and disable gc for the rest of the graphs.
                    stack.enter_context(patch("gc.collect", lambda: None))
                    stack.enter_context(
                        patch("torch.npu.empty_cache", lambda: None)
                    )

                # mind-exploding: carefully manage the reference and memory.
                with torch.npu.graph(npugraph, pool=self.graph_pool):
                    # `output` is managed by pytorch's graph pool
                    output = entry.runnable(*args)
                    if self.is_last_graph:
                        # by converting it to weak ref,
                        # the original `output` will immediately be released
                        # to save memory. It is only safe to do this for
                        # the last graph, because the output of the last graph
                        # will not be used by any other NPU graph.
                        output = weak_ref_tensors(output)

            # here we always use weak ref for the output
            # to save memory
            entry.output = weak_ref_tensors(output)
            entry.cudagraph = npugraph

            compilation_counter.num_cudagraph_captured += 1

            # important: we need to return the output, rather than
            # the weak ref of the output, so that pytorch can correctly
            # manage the memory during graph capture
            return output

        if self.compile_config.get_enable_debug_mode():
            # check if the input addresses are the same
            new_input_addresses = [
                x.data_ptr() for x in args if isinstance(x, torch.Tensor)
            ]
            assert new_input_addresses == entry.input_addresses, (
                "Input addresses for cudagraphs are different during replay."
                f" Expected {entry.input_addresses}, got {new_input_addresses}"
            )
        entry.cudagraph.replay()
        return entry.output
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| from typing import Any, Union | ||
|
|
||
| import torch | ||
|
|
||
| from sglang.srt.utils.common import is_cuda, is_npu | ||
|
|
||
| if is_cuda(): | ||
| from sgl_kernel import weak_ref_tensor | ||
| elif is_npu(): | ||
| from torch_npu._C import _weak_ref_tensor as weak_ref_tensor | ||
| else: | ||
| raise NotImplementedError("weak_ref_tensor is implemented only for CUDA and NPU.") | ||
|
|
||
|
|
||
def weak_ref_tensors(
    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
    """Create weak references to tensor(s).

    Accepts a single tensor, a list of tensors, or a tuple of tensors and
    returns weak references with the same container shape.

    Raises:
        ValueError: if *tensors* is not a tensor, list, or tuple.
    """
    if isinstance(tensors, torch.Tensor):
        return weak_ref_tensor(tensors)
    if isinstance(tensors, (list, tuple)):
        # Build the refs once, then restore the original container type.
        refs = [weak_ref_tensor(item) for item in tensors]
        return refs if isinstance(tensors, list) else tuple(refs)
    raise ValueError("Invalid type for tensors")
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1996,7 +1996,7 @@ def direct_register_custom_op( | |
|
|
||
| try: | ||
| my_lib.define(op_name + schema_str) | ||
| my_lib.impl(op_name, op_func, "CUDA") | ||
| my_lib.impl(op_name, op_func, "CUDA" if not is_npu() else "PrivateUse1") | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is this
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. hi @Vladimir221, will a CUDA device and an NPU device exist in the same node? If not, you can register for CUDA and NPU at the same time.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Do you suggest registering implementations of the custom op functions for both dispatch keys and removing the if statement? my_lib.impl(op_name, op_func, "CUDA")
my_lib.impl(op_name, op_func, "PrivateUse1")
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
From my view, yes. It might save an if-else cost |
||
| if fake_impl is not None: | ||
| my_lib._register_fake(op_name, fake_impl) | ||
| except RuntimeError as error: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can it unified with https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/compilation/cuda_piecewise_backend.py#L19
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I can align it with this, but in that case we will get code duplication for the `weak_ref_tensors` function. Moreover, `NPUPiecewiseBackend` is based on `CUDAPiecewiseBackend` to avoid duplicating the backend class initialization, so if I import the `CUDAPiecewiseBackend` class in the `npu_piecewise_backend.py` file and the host machine doesn't have the `sgl_kernel` package (only the `sgl_kernel_npu` package), an import error will occur. So, to make it unified, I would need to remove the `NPUPiecewiseBackend` inheritance from `CUDAPiecewiseBackend` and duplicate the code from the `CUDAPiecewiseBackend.__init__()` method. If you think that is the more proper way, I can align the code with your suggestion.
weak_ref_tensorsfunction. MoreoverNPUPiecewiseBackendis based onCUDAPiecewiseBackendto not duplicate backend class initialization, so in this case I importCUDAPiecewiseBackendclass innpu_piecewise_backend.pyfile if the host machine doesn't havesgl_kernelpackage (onlysgl_kernel_npupackage) the import error will occur. So to make it unified I'll need to removeNPUPiecewiseBackendinheritance fromCUDAPiecewiseBackendand to duplicate code fromCUDAPiecewiseBackend.__init__()method. If you guess this is more proper way I can align the code with your suggestion