Commit c8c6268

add triton moe fall back by env var (vllm-project#20)
Signed-off-by: Kunshang Ji <[email protected]>
1 parent a3e529c commit c8c6268

2 files changed (+35 lines, -1 line)

vllm/envs.py

Lines changed: 5 additions & 0 deletions
@@ -225,6 +225,7 @@
     VLLM_FLATTEN_LOGPROBS: bool = False
     VLLM_XPU_USE_W8A8_GEMM: bool = False
     VLLM_XPU_ATTN_HEAD_SIZE_PAD: bool = False
+    VLLM_XPU_MOE_USE_TRITON: bool = False
 
 
 def get_default_cache_root():
@@ -1490,6 +1491,9 @@ def get_vllm_port() -> int | None:
     "VLLM_XPU_ATTN_HEAD_SIZE_PAD": lambda: bool(
         int(os.getenv("VLLM_XPU_ATTN_HEAD_SIZE_PAD", "0"))
     ),
+    "VLLM_XPU_MOE_USE_TRITON": lambda: bool(
+        int(os.getenv("VLLM_XPU_MOE_USE_TRITON", "0"))
+    ),
 }
 
 # --8<-- [end:env-vars-definition]
@@ -1618,6 +1622,7 @@ def compute_hash() -> str:
         "VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL",
         "VLLM_XPU_USE_W8A8_GEMM",
         "VLLM_XPU_ATTN_HEAD_SIZE_PAD",
+        "VLLM_XPU_MOE_USE_TRITON",
     ]
     for key in environment_variables_to_hash:
         # if this goes out of sync with environment_variables,
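
A minimal usage sketch (not part of the commit), assuming an XPU build of vLLM that includes this change: the lambda registered above parses the flag from os.environ on access, so setting it before vLLM evaluates its environment variables is enough to opt in.

# Illustrative only: enable the Triton MoE fallback before vLLM reads the flag.
import os

os.environ["VLLM_XPU_MOE_USE_TRITON"] = "1"

import vllm.envs as envs

# The lambda in environment_variables parses the value on access.
assert envs.VLLM_XPU_MOE_USE_TRITON is True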

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 30 additions & 1 deletion
@@ -78,8 +78,15 @@
     )
 else:
     fused_experts = None  # type: ignore
+
     FusedMoEPermuteExpertsUnpermute = object  # type: ignore
     FusedMoEPrepareAndFinalize = object  # type: ignore
+if envs.VLLM_XPU_MOE_USE_TRITON:
+    from .fused_moe import (
+        TritonExperts,
+        eplb_map_to_physical_and_record,
+        fused_experts,
+    )
 from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
 
 def _eplb_map_to_physical_and_record(
@@ -908,6 +915,26 @@ def forward_xpu(
             or logical_replica_count is not None
         ):
             raise NotImplementedError("Expert load balancing is not supported for XPU.")
+        if envs.VLLM_XPU_MOE_USE_TRITON:
+            return self.forward_cuda(
+                layer,
+                x,
+                use_grouped_topk,
+                top_k,
+                router_logits,
+                renormalize,
+                topk_group,
+                num_expert_group,
+                global_num_experts,
+                expert_map,
+                custom_routing_function,
+                scoring_func,
+                routed_scaling_factor,
+                e_score_correction_bias,
+                apply_router_weight_on_input,
+                activation,
+            )
+
         M, _ = x.size()
         routing_weights = torch.empty(M, top_k, dtype=torch.float32, device=x.device)
         selected_experts = torch.empty(M, top_k, dtype=torch.int32, device=x.device)
@@ -1009,7 +1036,9 @@ def forward_tpu(
     elif current_platform.is_cpu():
         forward_native = forward_cpu
     elif current_platform.is_xpu():
-        forward_native = forward_xpu
+        forward_native = (
+            forward_xpu if not envs.VLLM_XPU_MOE_USE_TRITON else forward_cuda
+        )
     else:
         forward_native = forward_cuda
 
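
The dispatch at the bottom of the layer.py diff selects the MoE forward implementation once, when the class body is evaluated. A minimal sketch of the same env-gated pattern, using a stand-in class and stand-in methods rather than vLLM's real forward_cuda/forward_xpu:

# Sketch of the env-gated dispatch; the class and method bodies are stand-ins.
import os

_USE_TRITON = bool(int(os.getenv("VLLM_XPU_MOE_USE_TRITON", "0")))


class FusedMoEMethodSketch:
    def forward_cuda(self, x):
        # Stand-in for the Triton-based fused-MoE path.
        return ("triton", x)

    def forward_xpu(self, x):
        # Stand-in for the native XPU fused-MoE path.
        return ("xpu", x)

    # Same shape as the diff: pick the backend once, at class-definition time.
    forward_native = forward_xpu if not _USE_TRITON else forward_cuda


moe = FusedMoEMethodSketch()
print(moe.forward_native("hidden_states"))  # ("xpu", ...) unless the env var is "1"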