diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 24d04632504..348758c57b1 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -30,6 +30,7 @@
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
+    get_bool_env_var,
     get_compiler_backend,
     is_cpu,
     is_cuda,
@@ -38,6 +39,7 @@

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()

@@ -46,6 +48,11 @@

 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax

+if _use_aiter:
+    try:
+        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
+    except ImportError:
+        raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")

 def fused_topk_torch_native(
@@ -347,6 +354,25 @@ def biased_grouped_topk_gpu(
             topk_ids, expert_location_dispatch_info, num_token_non_padded
         )
         return topk_weights, topk_ids
+    elif _use_aiter:
+        token = gating_output.shape[0]
+        device = gating_output.device
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
+        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
+        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+        aiter_biased_grouped_topk(
+            gating_output,
+            correction_bias,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index a51a06f09a7..6726b93f81d 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -419,7 +419,7 @@ def capture(self) -> None:
                         empty_cache=False,
                     )
                     capture_range.set_description(
-                        f"Capturing batches ({avail_mem=:.2f} GB)"
+                        f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
                     )

                 with patch_model(
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index c73200400e8..14b553a003b 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -388,7 +388,8 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = self.experts(
             hidden_states=hidden_states, router_logits=router_logits
         )
-        if not _is_cuda:
+        if not _is_cuda and not _use_aiter:
+            # routed_scaling_factor is already fused into biased_grouped_topk, so skip it here
            final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
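
For review context, below is a minimal pure-PyTorch sketch of the biased grouped top-k routing that the new `_use_aiter` branch offloads to `aiter_biased_grouped_topk`. It mirrors the argument list the diff passes to the kernel (gating output, correction bias, group counts, renormalization, and the fused `routed_scaling_factor`); the function name `biased_grouped_topk_reference` and the top-2-per-group scoring rule are illustrative assumptions taken from the existing non-aiter reference path, not a guarantee of the aiter kernel's exact semantics.

```python
import torch


def biased_grouped_topk_reference(
    gating_output: torch.Tensor,    # [num_tokens, num_experts] router logits
    correction_bias: torch.Tensor,  # [num_experts] bias used only for expert selection
    topk: int,
    num_expert_group: int,
    topk_group: int,
    renormalize: bool,
    routed_scaling_factor: float,
):
    # Sigmoid scores; the bias influences which experts are chosen,
    # but the returned weights come from the unbiased scores.
    scores = gating_output.sigmoid()
    scores_for_choice = scores + correction_bias.unsqueeze(0)

    num_tokens, num_experts = scores.shape
    # Assumes num_experts is divisible by num_expert_group.
    group_scores = (
        scores_for_choice.view(num_tokens, num_expert_group, -1)
        .topk(2, dim=-1)
        .values.sum(dim=-1)
    )  # [num_tokens, num_expert_group]

    # Keep only the top `topk_group` groups, mask out the rest.
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_tokens, num_expert_group, num_experts // num_expert_group)
        .reshape(num_tokens, num_experts)
    )
    masked_scores = scores_for_choice.masked_fill(score_mask == 0, float("-inf"))

    # Pick experts from the surviving groups; weights from the unbiased scores.
    topk_ids = masked_scores.topk(topk, dim=-1).indices
    topk_weights = scores.gather(1, topk_ids)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    # routed_scaling_factor is fused here, matching the deepseek_v2.py change
    # that skips the post-MoE multiply when the aiter path is active.
    return topk_weights * routed_scaling_factor, topk_ids.to(torch.int32)
```

On ROCm builds the fused path is enabled by setting `SGLANG_USE_AITER=1` (the flag is ignored off HIP, since `_use_aiter` also requires `_is_hip`); other platforms continue through the existing `torch.compile` path unchanged.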