Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
get_compiler_backend,
is_cpu,
is_cuda,
Expand All @@ -38,6 +39,7 @@

_is_cuda = is_cuda()
_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()

Expand All @@ -46,6 +48,11 @@

if _is_cuda or _is_hip:
from sgl_kernel import topk_softmax
if _use_aiter:
try:
from aiter import biased_grouped_topk as aiter_biased_grouped_topk
except ImportError:
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")


def fused_topk_torch_native(
Expand Down Expand Up @@ -347,6 +354,25 @@ def biased_grouped_topk_gpu(
topk_ids, expert_location_dispatch_info, num_token_non_padded
)
return topk_weights, topk_ids
elif _use_aiter:
token = gating_output.shape[0]
device = gating_output.device
assert (
hidden_states.shape[0] == gating_output.shape[0]
), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
aiter_biased_grouped_topk(
gating_output,
correction_bias,
topk_weights,
topk_ids,
num_expert_group,
topk_group,
renormalize,
routed_scaling_factor,
)
return topk_weights, topk_ids
else:
biased_grouped_topk_fn = (
torch.compile(
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def capture(self) -> None:
empty_cache=False,
)
capture_range.set_description(
f"Capturing batches ({avail_mem=:.2f} GB)"
f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
)

with patch_model(
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,8 @@ def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if not _is_cuda:
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
Expand Down
Loading