-
Notifications
You must be signed in to change notification settings - Fork 5.7k
[Feature] Enable return routed experts #12162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+646
−10
Merged
Changes from all commits
Commits
Show all changes
70 commits
Select commit
Hold shift + click to select a range
cf8d0fd
init
ocss884 a720fae
Merge branch 'sgl-project:main' into return_routed_expert
ocss884 542d490
more
ocss884 cfc330b
small fix
ocss884 1b9e9fa
small fix
ocss884 5a70eb7
refactor
ocss884 6924725
add layer_id to select_experts
ocss884 1f40d43
more
ocss884 794af1f
more
ocss884 e15db98
rm
ocss884 f92a997
copy to req when finished
ocss884 41bebeb
misc
ocss884 4ec04be
lint
ocss884 fa916bc
small fix
ocss884 98c8fe0
nit
ocss884 8da8de1
Merge branch 'main' into return_routed_expert
ocss884 6c92475
use key word args
ocss884 812421a
more
ocss884 ded85ca
adapt for cuda graph
yizhang2077 ce7463d
enable for engine
ocss884 bf63df9
Merge branch 'main' into return_routed_expert
ocss884 22e3910
fix for schedule overlap
ocss884 b2af1cb
Merge branch 'main' into return_routed_expert
zhaochenyang20 911bd1a
Merge branch 'main' into return_routed_expert
ocss884 0a0645f
bugfix
ocss884 b041ddc
add arg for bench_serving
ocss884 c70bded
more
ocss884 3d54694
add out_cache_loc_cpu to improve performance
ocss884 1b701a8
Merge branch 'main' into return_routed_expert
ocss884 31d9509
more
ocss884 b72a6f2
move topk_ids tolist into detokenizer
yizhang2077 fdb5c47
rollback
ocss884 b77a5fb
add ut
ocss884 83192fb
dp attn
ocss884 aa4df5e
fix dp attn
ocss884 e7b1fc9
Merge branch 'main' into return_routed_expert
ocss884 65471d5
Merge branch 'main' into return_routed_expert
ocss884 4c23438
fix a bug when dp attn + r3 + cuda graph
yizhang2077 aaca6e3
Merge branch 'main' into return_routed_expert
ocss884 823ea6b
small fix & support dpsk
ocss884 54ba922
rollback
ocss884 7d9a69d
add to test suite
ocss884 d49fe1a
Merge branch 'main' into return_routed_expert
ocss884 d8e2e34
Merge branch 'main' into return_routed_expert
zhaochenyang20 e8c276a
Merge branch 'main' into return_routed_expert
Kangyan-Zhou 4457751
update test
ocss884 b4f4aaf
Merge branch 'main' into return_routed_expert
ocss884 5a3e735
fix test
ocss884 f9a442b
fix test
ocss884 676612f
Merge branch 'main' into return_routed_expert
ocss884 041dfcd
tiny fix for retract
yizhang2077 e6af67d
cleanup code
ocss884 067f185
bugfix
ocss884 be8ce97
bugfix
ocss884 09b203b
Merge branch 'main' into return_routed_expert
ocss884 017d0b9
Merge branch 'main' into return_routed_expert
ocss884 808d29f
Merge branch 'main' into return_routed_expert
hnyls2002 dbe9775
fix merge conflicts
hnyls2002 33fe8ee
fix glm4 moe
hnyls2002 5770aa2
remove out cache loc cpu
hnyls2002 2f580f5
clean up
hnyls2002 b32e693
tiny fix
hnyls2002 7da8cfa
split function out
hnyls2002 37a6c76
remove duplicate code
hnyls2002 0189d33
use b64 return
ocss884 ee2a780
Merge branch 'main' into return_routed_expert
ocss884 f486811
small bugfix
ocss884 ca4b3a4
fix test
ocss884 efee10b
Merge branch 'main' into return_routed_expert
ocss884 67210c2
lint
ocss884 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
289 changes: 289 additions & 0 deletions
289
python/sglang/srt/layers/moe/routed_experts_capturer.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,289 @@ | ||
| import logging | ||
| from abc import ABC | ||
| from typing import Optional | ||
|
|
||
| import numpy as np | ||
| import pybase64 | ||
| import torch | ||
|
|
||
| from sglang.srt.configs.model_config import ModelConfig | ||
| from sglang.srt.layers.dp_attention import ( | ||
| get_attention_dp_rank, | ||
| get_dp_local_info, | ||
| is_dp_attention_enabled, | ||
| ) | ||
| from sglang.srt.mem_cache.memory_pool import ReqToTokenPool | ||
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch | ||
| from sglang.srt.server_args import get_global_server_args | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| _GB = 1024 * 1024 * 1024 | ||
| _MB = 1024 * 1024 | ||
|
|
||
|
|
||
| def get_tensor_size_bytes(t: torch.Tensor): | ||
| return np.prod(t.shape) * t.dtype.itemsize | ||
|
|
||
|
|
||
| class _RoutedExpertsDeviceCache: | ||
| def __init__( | ||
| self, | ||
| max_running_requests: int, | ||
| num_hidden_layers: int, | ||
| num_experts_per_tok: int, | ||
| num_fused_shared_experts: int, | ||
| device: str, | ||
| ) -> None: | ||
| self.buffer = torch.zeros( | ||
| ( | ||
| max( | ||
| get_global_server_args().chunked_prefill_size | ||
| * get_global_server_args().dp_size, | ||
| max_running_requests, | ||
| ), | ||
| num_hidden_layers, | ||
| num_experts_per_tok + num_fused_shared_experts, | ||
| ), | ||
| dtype=torch.int32, | ||
| device=device, | ||
| ) | ||
| self._finalize_allocation_log() | ||
|
|
||
| def get_buffer_size_bytes(self): | ||
| assert hasattr(self, "buffer") | ||
| return get_tensor_size_bytes(self.buffer) | ||
|
|
||
| def capture_fwd_routed_experts(self, layer_id: int, topk_ids: torch.Tensor): | ||
| assert layer_id is not None, "capturing routing experts but get layer_id None" | ||
| batch, _ = topk_ids.shape | ||
| self.buffer[:batch, layer_id, :] = topk_ids | ||
|
|
||
| def _finalize_allocation_log(self): | ||
| """Common logging and memory usage computation for captured experts buffers.""" | ||
| buffer_size_MB = self.get_buffer_size_bytes() / _MB | ||
| logger.info( | ||
| f"Routing experts device buffer allocated. #shape: {tuple(self.buffer.shape)}, size: {buffer_size_MB:.2f} MB" | ||
| ) | ||
|
|
||
|
|
||
| class _RoutedExpertsHostCache: | ||
| def __init__( | ||
| self, | ||
| num_tokens: int, | ||
| num_hidden_layers: int, | ||
| num_experts_per_tok: int, | ||
| ) -> None: | ||
| self.num_tokens = num_tokens | ||
| self.buffer = torch.zeros( | ||
| ( | ||
| num_tokens, | ||
| num_hidden_layers, | ||
| num_experts_per_tok, | ||
| ), | ||
| dtype=torch.int32, | ||
| device="cpu", | ||
| pin_memory=True, | ||
| ) | ||
| self._finalize_allocation_log() | ||
|
|
||
| def get_buffer_size_bytes(self): | ||
| assert hasattr(self, "buffer") | ||
| return get_tensor_size_bytes(self.buffer) | ||
|
|
||
| def set_experts_buffer(self, layer_id: int, loc: torch.Tensor, top_k: torch.Tensor): | ||
| self.buffer[layer_id, loc, :] = top_k.to(device="cpu", non_blocking=True) | ||
|
|
||
| def _finalize_allocation_log(self): | ||
| """Common logging and memory usage computation for captured experts buffers.""" | ||
| buffer_size_GB = self.get_buffer_size_bytes() / _GB | ||
| logger.info( | ||
| f"Routing experts host buffer allocated. #tokens: {self.num_tokens}, size: {buffer_size_GB:.2f} GB" | ||
| ) | ||
|
|
||
|
|
||
| class RoutedExpertsCapturer(ABC): | ||
| @staticmethod | ||
| def create( | ||
| enable: bool, | ||
| model_config: ModelConfig, | ||
| num_fused_shared_experts: int, | ||
| num_tokens: int, | ||
| max_running_requests: int, | ||
| device: str, | ||
| ): | ||
| if enable: | ||
| return _RoutedExpertsCapturerReal( | ||
| model_config, | ||
| num_tokens=num_tokens, | ||
| max_running_requests=max_running_requests, | ||
| num_fused_shared_experts=num_fused_shared_experts, | ||
| device=device, | ||
| ) | ||
| else: | ||
| return _RoutedExpertsCapturerNoop() | ||
|
|
||
| def _sync_fwd_experts_buffer_DtoH( | ||
| self, | ||
| forward_batch: ForwardBatch, | ||
| can_run_graph: bool, | ||
| cuda_graph_batch: int, | ||
| ): | ||
| raise NotImplementedError | ||
|
|
||
| def capture(self, layer_id: int, topk_ids: torch.Tensor): | ||
| raise NotImplementedError | ||
|
|
||
| def get_routed_experts( | ||
| self, | ||
| req_pool_idx: int, | ||
| seqlen: int, | ||
| req_to_token_pool: ReqToTokenPool, | ||
| ): | ||
| raise NotImplementedError | ||
|
|
||
| def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch): | ||
| raise NotImplementedError | ||
|
|
||
| def get_host_cache(self): | ||
| raise NotImplementedError | ||
|
|
||
| def get_device_cache(self): | ||
| raise NotImplementedError | ||
|
|
||
|
|
||
| class _RoutedExpertsCapturerReal(RoutedExpertsCapturer): | ||
| """Capturer for routed experts with host buffer""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| model_config: ModelConfig, | ||
| num_tokens: int, | ||
| max_running_requests: int, | ||
| num_fused_shared_experts: int, | ||
| device: str, | ||
| ): | ||
| self.num_fused_shared_experts = num_fused_shared_experts | ||
| self.num_hidden_layers = model_config.hf_text_config.num_hidden_layers | ||
| self.num_experts_per_tok = model_config.hf_text_config.num_experts_per_tok | ||
|
|
||
| self.host_cache = _RoutedExpertsHostCache( | ||
| num_tokens=num_tokens, | ||
| num_hidden_layers=self.num_hidden_layers, | ||
| num_experts_per_tok=self.num_experts_per_tok, | ||
| ) | ||
|
|
||
| self.device_cache = _RoutedExpertsDeviceCache( | ||
| max_running_requests=max_running_requests, | ||
| num_hidden_layers=self.num_hidden_layers, | ||
| num_experts_per_tok=self.num_experts_per_tok, | ||
| num_fused_shared_experts=self.num_fused_shared_experts, | ||
| device=device, | ||
| ) | ||
|
|
||
| def _sync_fwd_experts_buffer_DtoH( | ||
| self, | ||
| forward_batch: ForwardBatch, | ||
| can_run_graph: bool, | ||
| cuda_graph_batch: int, | ||
| ): | ||
| if is_dp_attention_enabled(): | ||
| local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) | ||
| # handle with cuda graph padding | ||
| if can_run_graph: | ||
| local_start_pos = get_attention_dp_rank() * cuda_graph_batch | ||
| local_end_pos = local_start_pos + local_num_tokens | ||
| else: | ||
| local_end_pos = local_start_pos + local_num_tokens | ||
| else: | ||
| local_start_pos = 0 | ||
| local_end_pos = forward_batch.out_cache_loc.shape[0] | ||
|
|
||
| # FIXME: sync explicitly here, overlap scheduler breaks here. | ||
| out_cache_loc_cpu = forward_batch.out_cache_loc.cpu() | ||
| self.host_cache.buffer[out_cache_loc_cpu] = self.device_cache.buffer[ | ||
| local_start_pos:local_end_pos, :, : self.num_experts_per_tok | ||
| ].cpu() | ||
|
|
||
| def capture(self, layer_id: int, topk_ids: torch.Tensor): | ||
| self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) | ||
|
|
||
| def get_routed_experts( | ||
| self, | ||
| req_pool_idx: int, | ||
| seqlen: int, | ||
| req_to_token_pool: ReqToTokenPool, | ||
| ): | ||
| cache_pool_idx = ( | ||
| req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone() | ||
| ) | ||
| return self.get_host_cache().buffer[cache_pool_idx] | ||
|
|
||
| def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch): | ||
| self._sync_fwd_experts_buffer_DtoH( | ||
| forward_batch=forward_batch, | ||
| can_run_graph=can_run_graph, | ||
| cuda_graph_batch=cuda_graph_batch, | ||
| ) | ||
|
|
||
| def get_host_cache(self): | ||
| return self.host_cache | ||
|
|
||
| def get_device_cache(self): | ||
| return self.device_cache | ||
|
|
||
|
|
||
| class _RoutedExpertsCapturerNoop(RoutedExpertsCapturer): | ||
| def __init__(self): | ||
| pass | ||
|
|
||
| def _sync_fwd_experts_buffer_DtoH( | ||
| self, | ||
| forward_batch: ForwardBatch, | ||
| can_run_graph: bool, | ||
| cuda_graph_batch: int, | ||
| ): | ||
| pass | ||
|
|
||
| def capture(self, layer_id: int, topk_ids: torch.Tensor): | ||
| pass | ||
|
|
||
| def get_routed_experts( | ||
| self, | ||
| req_pool_idx: int, | ||
| seqlen: int, | ||
| req_to_token_pool: ReqToTokenPool, | ||
| ): | ||
| pass | ||
|
|
||
| def on_forward_end(self, forward_batch, can_run_graph, cuda_graph_batch): | ||
| pass | ||
|
|
||
| def get_host_cache(self): | ||
| pass | ||
|
|
||
| def get_device_cache(self): | ||
| pass | ||
|
|
||
|
|
||
| _global_expert_capturer: Optional[RoutedExpertsCapturer] = _RoutedExpertsCapturerNoop() | ||
|
|
||
|
|
||
| def get_global_experts_capturer(): | ||
| return _global_expert_capturer | ||
|
|
||
|
|
||
| def set_global_experts_capturer(capturer: RoutedExpertsCapturer): | ||
| global _global_expert_capturer | ||
| _global_expert_capturer = capturer | ||
|
|
||
|
|
||
| def extract_routed_experts_from_meta_info(data): | ||
| # To solve the performance issue, we return the experts_ids in base64 | ||
| # We left this function for user to change it back to normal int32 | ||
| # See detokenizer_manager::_extract_routed_experts | ||
| routed_experts_base64 = data["meta_info"].get("routed_experts", None) | ||
| routed_experts = np.frombuffer( | ||
| pybase64.b64decode(routed_experts_base64.encode("utf-8")), dtype=np.int32 | ||
| ) | ||
| return routed_experts |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
about how to pass in layer_id: ExpertDistributionRecorder has a with_current_layer. what about refactoring into sth like
and let both ExpertDistributionRecorder and the new RoutedExpertsCapture to use current_layer_mgr.get_the_current_layer_id_value