Multi platform Plugin #21388
```diff
@@ -150,6 +150,7 @@
 )
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.platforms import current_platform
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import (
     ServerArgs,
@@ -207,6 +208,8 @@
     from sglang.srt.hardware_backend.npu.utils import init_npu_backend

     init_npu_backend()
+elif current_platform.is_out_of_tree():
+    current_platform.init_backend()

 MLA_ATTENTION_BACKENDS = [
     "aiter",
@@ -702,6 +705,7 @@ def initialize(self, pre_model_load_memory: float):
         # Init routed experts capturer
         self.init_routed_experts_capturer()

+        # TODO: Refactor device-specific init branches into platform interface (separate PR).
         # Must be called BEFORE init_device_graphs() so CUDA graph capture
         # runs with aux hidden state capture enabled.
         self.init_aux_hidden_state_capture()
@@ -714,6 +718,13 @@ def initialize(self, pre_model_load_memory: float):
         elif self.device in ["npu", "cpu"]:
             self.init_attention_backend()
             self.init_device_graphs()
+        elif current_platform.is_out_of_tree():
+            self.init_attention_backend()
+            if current_platform.support_cuda_graph():
```
Collaborator: kind of on the same note as another comment, but by "cuda graph" we just mean FULL-style CUDA graph capture, right? might be worth renaming... hmmm

Contributor (Author): Totally understand the naming confusion. However, renaming touches 85+ sites across the codebase, which is a pretty big change on its own.

Collaborator: yeah, I am thinking that that is its own PR. please leave a TODO.

Contributor (Author): Added a TODO comment at the OOT branch in initialize().
```diff
+                self.init_device_graphs()
+            else:
+                self.graph_runner = None
+                self.graph_mem_usage = 0
         else:
             self.graph_runner = None
             self.graph_mem_usage = 0
```
```diff
@@ -1483,7 +1494,14 @@ def model_load_weights(model, iter):
         self.server_args.load_format = load_format
         self.load_config = load_config

-        if recapture_cuda_graph and (self.device == "cuda" or self.device == "musa"):
+        if recapture_cuda_graph and (
+            self.device == "cuda"
+            or self.device == "musa"
+            or (
+                current_platform.is_out_of_tree()
+                and current_platform.support_cuda_graph()
+            )
+        ):
             self.init_device_graphs()

         logger.info("Update weights end.")
```
```diff
@@ -2532,23 +2550,29 @@ def init_device_graphs(self):
         tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         graph_backend = defaultdict(
-            lambda: "cuda graph",
+            lambda: f"{current_platform.device_name} graph",
             {
                 "cuda": "cuda graph",
                 "musa": "cuda graph",
                 "cpu": "cpu graph",
                 "npu": "npu graph",
             },
         )
         logger.info(
             f"Capture {graph_backend[self.device]} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-        graph_runners = defaultdict(
-            lambda: CudaGraphRunner,
-            {
-                "cpu": CPUGraphRunner,
-                "npu": NPUGraphRunner,
-            },
-        )
-        self.graph_runner = graph_runners[self.device](self)
+        if current_platform.is_out_of_tree():
```
Contributor: same as Alex said, we need to clean this up in the next PR.
```diff
+            GraphRunnerCls = current_platform.get_graph_runner_cls()
+            self.graph_runner = GraphRunnerCls(self)
+        else:
+            graph_runners = defaultdict(
+                lambda: CudaGraphRunner,
+                {
+                    "cpu": CPUGraphRunner,
+                    "npu": NPUGraphRunner,
+                },
+            )
+            self.graph_runner = graph_runners[self.device](self)

         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         self.graph_mem_usage = before_mem - after_mem
```
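To make the new `current_platform` branches above easier to follow, here is a rough sketch of the kind of out-of-tree platform object they expect. The hook names (`is_out_of_tree`, `init_backend`, `support_cuda_graph`, `get_graph_runner_cls`, `device_name`) are taken from the diff; the class names and method bodies below are illustrative assumptions, not the actual plugin API.

```python
# Illustrative only: a toy out-of-tree platform exposing the hooks that the new
# branches in initialize() / init_device_graphs() call. A real plugin would
# derive from sglang's platform base class; the names here are made up.


class MyAccelGraphRunner:
    # Placeholder graph runner; init_device_graphs() constructs it as cls(model_runner)
    # when support_cuda_graph() returns True.
    def __init__(self, model_runner):
        self.model_runner = model_runner


class MyAccelPlatform:
    device_name = "myaccel"  # feeds the f"{current_platform.device_name} graph" log line

    def is_out_of_tree(self) -> bool:
        return True

    def init_backend(self) -> None:
        # One-time runtime initialization, analogous to init_npu_backend().
        pass

    def support_cuda_graph(self) -> bool:
        # Gates init_device_graphs() in initialize() and on weight reload.
        return True

    def get_graph_runner_cls(self):
        return MyAccelGraphRunner
```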
```diff
@@ -282,7 +282,63 @@ def _init_pools(self: ModelRunner):
     # Initialize token_to_kv_pool
     is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
-    if self.server_args.attention_backend == "ascend" and not self.mambaish_config:
+
+    # Check out-of-tree platform (plugin system) first
+    from sglang.srt.platforms import current_platform
+
+    if current_platform.is_out_of_tree() and not self.mambaish_config:
+        if self.use_mla_backend and is_nsa_model:
+            PoolCls = current_platform.get_nsa_kv_pool_cls()
+            self.token_to_kv_pool = PoolCls(
+                self.max_total_num_tokens,
+                page_size=self.page_size,
+                dtype=self.kv_cache_dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.num_effective_layers,
+                device=self.device,
+                kv_cache_dim=self.calculate_mla_kv_cache_dim(),
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
+                index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
+            )
+        elif self.use_mla_backend:
+            PoolCls = current_platform.get_mla_kv_pool_cls()
+            self.token_to_kv_pool = PoolCls(
+                self.max_total_num_tokens,
+                page_size=self.page_size,
+                dtype=self.kv_cache_dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                index_head_dim=(
+                    self.model_config.index_head_dim if is_nsa_model else None
+                ),
+                layer_num=self.num_effective_layers,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
+            )
+        else:
+            PoolCls = current_platform.get_mha_kv_pool_cls()
+            self.token_to_kv_pool = PoolCls(
+                self.max_total_num_tokens,
+                page_size=self.page_size,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(
+                    get_attention_tp_size()
+                ),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.num_effective_layers,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
+            )
```
Comment on lines +289 to +338

Contributor: The logic for initializing the different KV pool types for out-of-tree platforms involves significant code duplication, especially in the constructor arguments. This can be refactored to improve readability and maintainability by extracting the common arguments into a dictionary:

```python
if current_platform.is_out_of_tree() and not self.mambaish_config:
    pool_args = {
        "max_total_num_tokens": self.max_total_num_tokens,
        "page_size": self.page_size,
        "dtype": self.kv_cache_dtype,
        "layer_num": self.num_effective_layers,
        "device": self.device,
        "enable_memory_saver": self.server_args.enable_memory_saver,
        "start_layer": self.start_layer,
        "end_layer": self.end_layer,
    }
    if self.use_mla_backend and is_nsa_model:
        PoolCls = current_platform.get_nsa_kv_pool_cls()
        pool_args.update({
            "kv_lora_rank": self.model_config.kv_lora_rank,
            "qk_rope_head_dim": self.model_config.qk_rope_head_dim,
            "kv_cache_dim": self.calculate_mla_kv_cache_dim(),
            "index_head_dim": get_nsa_index_head_dim(
                self.model_config.hf_config
            ),
        })
    elif self.use_mla_backend:
        PoolCls = current_platform.get_mla_kv_pool_cls()
        pool_args.update({
            "kv_lora_rank": self.model_config.kv_lora_rank,
            "qk_rope_head_dim": self.model_config.qk_rope_head_dim,
            "index_head_dim": (
                self.model_config.index_head_dim if is_nsa_model else None
            ),
        })
    else:
        PoolCls = current_platform.get_mha_kv_pool_cls()
        pool_args.update({
            "head_num": self.model_config.get_num_kv_heads(
                get_attention_tp_size()
            ),
            "head_dim": self.model_config.head_dim,
        })
    self.token_to_kv_pool = PoolCls(**pool_args)
```
```diff
+    elif (
+        self.server_args.attention_backend == "ascend" and not self.mambaish_config
+    ):
         if self.is_hybrid_swa:
             from sglang.srt.hardware_backend.npu.memory_pool_npu import (
                 NPUMHATokenToKVPool,
@@ -513,7 +569,17 @@ def _init_pools(self: ModelRunner):
     # Initialize token_to_kv_pool_allocator
     need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
     if self.token_to_kv_pool_allocator is None:
-        if _is_npu and (
+        if current_platform.is_out_of_tree():
+            AllocatorCls = current_platform.get_paged_allocator_cls()
+            self.token_to_kv_pool_allocator = AllocatorCls(
+                self.max_total_num_tokens,
+                page_size=self.page_size,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+                kvcache=self.token_to_kv_pool,
+                need_sort=need_sort,
+            )
+        elif _is_npu and (
             self.server_args.attention_backend == "ascend"
             or self.hybrid_gdn_config is not None
         ):
```
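For reference, the call sites in `_init_pools()` above pin down the constructor signatures that the classes returned by a plugin's `get_mha_kv_pool_cls()` / `get_paged_allocator_cls()` must accept. The sketch below only restates those keyword arguments; the class names are hypothetical, and a real plugin would likely subclass the in-tree pool/allocator types instead.

```python
# Illustrative only: signatures implied by the _init_pools() call sites above.
# Class names are hypothetical; parameter names match the keywords in the diff.


class MyAccelMHATokenToKVPool:
    def __init__(self, max_total_num_tokens, page_size, dtype, head_num,
                 head_dim, layer_num, device, enable_memory_saver,
                 start_layer, end_layer):
        # A real pool would allocate per-layer K/V buffers on the plugin device here.
        self.size = max_total_num_tokens


class MyAccelPagedAllocator:
    def __init__(self, max_total_num_tokens, page_size, dtype, device,
                 kvcache, need_sort):
        # A real allocator would hand out page-aligned token slots backed by `kvcache`.
        self.kvcache = kvcache
        self.need_sort = need_sort
```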
Reviewer: If we are using `getattr` with `dispatch_key` to auto-find forward methods defined in sub-classes, it seems we no longer need all the `forward_xxx` methods in `MultiPlatformOp`? Maybe `forward_native` can be kept as an escape hatch. The current `MultiPlatformOp` `forward_xxx` methods provide some fallback logic, e.g., hip -> cuda, and we could have `fallback_dispatch_keys` for each platform to cover this.

Author: You're right. Once the platform interface stabilizes, we'd love to clean these up; the `fallback_dispatch_keys` idea is a nice approach too. For now we're trying to keep the in-tree changes minimal in this PR, so I'd prefer to address it in a follow-up if that's okay with you. Really appreciate the input!
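As a footnote to the thread above, here is a minimal sketch of the `getattr`-with-`dispatch_key` dispatch and the proposed per-platform `fallback_dispatch_keys`, with `forward_native` kept as the escape hatch. It only illustrates the idea from the comments; it is not the actual `MultiPlatformOp` implementation, and the op/platform names in the usage example are made up.

```python
class MultiPlatformOp:
    """Sketch of the dispatch idea discussed above (not sglang's actual code)."""

    def forward_native(self, *args, **kwargs):
        # Portable implementation kept as an escape hatch.
        raise NotImplementedError

    def dispatch_forward(self, platform):
        # Try the platform's own dispatch key first, then its declared fallbacks
        # (e.g. "hip" falling back to "cuda"), and finally forward_native.
        keys = [platform.dispatch_key, *getattr(platform, "fallback_dispatch_keys", ())]
        for key in keys:
            fn = getattr(self, f"forward_{key}", None)
            if fn is not None:
                return fn
        return self.forward_native


# Hypothetical usage: an op that only defines forward_cuda is still found by a
# platform whose dispatch_key is "hip" but which declares "cuda" as a fallback.
class MyOp(MultiPlatformOp):
    def forward_cuda(self, x):
        return x


class HipPlatform:
    dispatch_key = "hip"
    fallback_dispatch_keys = ("cuda",)


op = MyOp()
assert op.dispatch_forward(HipPlatform()) == op.forward_cuda
```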