
Commit 36a97a5

alexsun07 authored and whybeyoung committed
[AMD] add aiter fused moe in DeepEP path (sgl-project#7268)
1 parent 2fcdf6d commit 36a97a5

2 files changed (+98, -14 lines)

python/sglang/srt/layers/moe/ep_moe/layer.py

Lines changed: 79 additions & 12 deletions
@@ -54,10 +54,16 @@
 
 _is_hip = is_hip()
 _is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_hip:
     from vllm._custom_ops import scaled_fp8_quant
 
+if _use_aiter:
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe import fused_moe
+    from aiter.ops.shuffle import shuffle_weight
+
 logger = logging.getLogger(__name__)
 
 
@@ -1046,6 +1052,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
                     w2_weight_scale, requires_grad=False
                 )
                 layer.w2_input_scale = None
+            if _use_aiter:
+                layer.w13_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                layer.w2_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
             return
 
     def apply(
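
For illustration (this sketch is not part of the commit): when the aiter path is enabled (SGLANG_USE_AITER set on a ROCm/HIP build), the loaded expert weights are re-wrapped through aiter.ops.shuffle.shuffle_weight with a (16, 16) layout so the fused kernel can read them in its preferred tiling. The exact ordering is internal to aiter; the torch-only sketch below, with made-up sizes, only illustrates the general idea of regrouping a weight matrix into 16x16 tiles and is not aiter's actual layout.

import torch

def tile_16x16_illustration(w: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in for aiter.ops.shuffle.shuffle_weight(w, (16, 16)):
    # regroup the trailing (rows, cols) dims so each 16x16 tile becomes contiguous.
    *lead, rows, cols = w.shape
    assert rows % 16 == 0 and cols % 16 == 0, "dims must be tileable by 16"
    blocks = w.reshape(*lead, rows // 16, 16, cols // 16, 16)
    return blocks.transpose(-3, -2).contiguous().reshape(*lead, rows, cols)

# made-up sizes: (num_local_experts, 2 * intermediate_size, hidden_size)
w13 = torch.randn(8, 256, 128)
w13_tiled = tile_16x16_illustration(w13)  # same shape, tile-major ordering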
@@ -1117,18 +1132,36 @@ def __init__(
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
+        if _use_aiter:
+            # expert_mask is of size (self.num_experts_per_partition + 1),
+            # the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
+            # for instance, if we have 4 experts on this rank, we would have an expert_mask like:
+            #     self.expert_mask = [1, 1, 1, 1, 0]
+            # idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
+            self.expert_mask = torch.zeros(
+                (self.num_experts_per_partition + 1),
+                device=torch.cuda.current_device(),
+                dtype=torch.int,
+            )
+            # the last one is invalid rank_id
+            self.expert_mask[:-1] = 1
+        else:
+            self.w13_weight_fp8 = (
+                self.w13_weight,
+                (
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+            )
+            self.w2_weight_fp8 = (
+                self.w2_weight,
+                (
+                    self.w2_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w2_weight_scale
+                ),
+            )
 
     def forward(
         self,
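
For illustration (not part of the diff): the expert_mask construction above is easy to check in isolation. A CPU-only sketch with a made-up partition of 4 local experts, mirroring the comment in the hunk:

import torch

num_experts_per_partition = 4  # made-up example size

# one extra slot so that index == num_experts_per_partition can absorb DeepEP's -1 (invalid) ids
expert_mask = torch.zeros(num_experts_per_partition + 1, dtype=torch.int)
expert_mask[:-1] = 1
print(expert_mask)  # tensor([1, 1, 1, 1, 0], dtype=torch.int32): ids 0-3 valid, id 4 masked out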
@@ -1142,6 +1175,9 @@ def forward(
         num_recv_tokens_per_expert: List[int],
         forward_mode: ForwardMode,
     ):
+        if _use_aiter:
+            # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
+            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
@@ -1274,6 +1310,37 @@ def forward_normal(
         )
         return down_output
 
+    def forward_aiter(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states
+        # in original deepep, idx == -1 meaning invalid and will not be processed.
+        # aiter does not accept -1, we use a expert mask to make these idx invalid
+        # (idx == num_experts_per_partition) meaning not used in aiter fused_moe
+        topk_idx_copy = topk_idx.to(torch.int32)
+        topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition
+
+        return fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            topk_weights,
+            topk_idx_copy,
+            w1_scale=self.w13_weight_scale_inv,
+            w2_scale=self.w2_weight_scale_inv,
+            quant_type=QuantType.per_128x128,
+            activation=(
+                ActivationType.Silu
+                if self.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            expert_mask=self.expert_mask,
+        )
+
     def forward_deepgemm_contiguous(
         self,
         hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
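
For illustration (not part of the diff): the only reshaping forward_aiter does before handing everything to aiter's fused_moe is remapping DeepEP's sentinel index -1 onto the masked dummy slot. A standalone sketch of just that remapping, with made-up routing values:

import torch

num_experts_per_partition = 4
# -1 marks "token not routed to an expert on this rank" in DeepEP's output
topk_idx = torch.tensor([[0, 2, -1], [3, -1, -1]])

topk_idx_aiter = topk_idx.to(torch.int32)
topk_idx_aiter[topk_idx_aiter == -1] = num_experts_per_partition  # now points at the masked slot
print(topk_idx_aiter)  # tensor([[0, 2, 4], [3, 4, 4]], dtype=torch.int32)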

python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py

Lines changed: 19 additions & 2 deletions
@@ -6,7 +6,13 @@
     get_global_expert_distribution_recorder,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import DeepEPMode, get_int_env_var, load_json_config
+from sglang.srt.utils import (
+    DeepEPMode,
+    get_bool_env_var,
+    get_int_env_var,
+    is_hip,
+    load_json_config,
+)
 
 try:
     from deep_ep import Buffer, Config
@@ -32,6 +38,8 @@
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+
 logger = logging.getLogger(__name__)
 
 
@@ -376,6 +384,15 @@ def _deepep_permute(
         Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
         https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
         """
+        if _use_aiter:
+            # skip permutation here as aiter fused_moe has fused inside
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+            return reorder_topk_ids, seg_indptr, hidden_states
 
         reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
             topk_idx, self.num_experts
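
For illustration (not part of the diff): in the non-aiter path, seg_indptr produced by deepep_run_moe_deep_preprocess is a length num_experts + 1 prefix sum of per-expert token counts over the permuted tokens, while the aiter path returns an all-zero placeholder because the fused kernel handles routing internally. A sketch of such a prefix-sum indptr with made-up counts (not the actual preprocessing kernel):

import torch

num_experts = 4
tokens_per_expert = torch.tensor([3, 0, 2, 1])  # made-up routing counts

seg_indptr = torch.zeros(num_experts + 1, dtype=torch.int64)
seg_indptr[1:] = torch.cumsum(tokens_per_expert, dim=0)
print(seg_indptr)  # tensor([0, 3, 3, 5, 6]): expert i owns permuted tokens seg_indptr[i]:seg_indptr[i+1]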
@@ -409,7 +426,7 @@ def combine_a(
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
             output = hidden_states
         else:
             if hidden_states.shape[0] > 0:
