78 | 78 | ) |
79 | 79 | else: |
80 | 80 | fused_experts = None # type: ignore |
| 81 | + |
81 | 82 | FusedMoEPermuteExpertsUnpermute = object # type: ignore |
82 | 83 | FusedMoEPrepareAndFinalize = object # type: ignore |
| 84 | + if envs.VLLM_XPU_MOE_USE_TRITON: |
| 85 | + from .fused_moe import ( |
| 86 | + TritonExperts, |
| 87 | + eplb_map_to_physical_and_record, |
| 88 | + fused_experts, |
| 89 | + ) |
83 | 90 | from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe |
84 | 91 |
85 | 92 | def _eplb_map_to_physical_and_record( |
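The hunk above extends the module-level import fallback: on builds where the guard at the top of the file does not take the CUDA-alike branch (the condition itself is outside this diff), the MoE symbols are stubbed out, and the new `if envs.VLLM_XPU_MOE_USE_TRITON:` block pulls the Triton experts back in so the shared path can still use them on XPU. A self-contained sketch of that guarded-import pattern, with placeholder names for everything not visible in the diff:

```python
import os

# Assumed: both conditions are plain booleans in the real file
# (a platform check and envs.VLLM_XPU_MOE_USE_TRITON read from the environment).
IS_CUDA_ALIKE = False
USE_TRITON = os.environ.get("VLLM_XPU_MOE_USE_TRITON", "0") == "1"


def _load_triton_symbols():
    """Stand-in for `from .fused_moe import TritonExperts, fused_experts`."""

    class TritonExperts:
        pass

    def fused_experts(*args, **kwargs):
        return "triton fused_experts"

    return TritonExperts, fused_experts


if IS_CUDA_ALIKE:
    TritonExperts, fused_experts = _load_triton_symbols()
else:
    # Non-CUDA default: stub the symbols so the module still imports...
    fused_experts = None    # type: ignore
    TritonExperts = object  # type: ignore
    # ...unless the XPU build explicitly opts into the Triton path.
    if USE_TRITON:
        TritonExperts, fused_experts = _load_triton_symbols()

print(TritonExperts, fused_experts)
```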
@@ -908,6 +915,26 @@ def forward_xpu( |
908 | 915 | or logical_replica_count is not None |
909 | 916 | ): |
910 | 917 | raise NotImplementedError("Expert load balancing is not supported for XPU.") |
| 918 | + if envs.VLLM_XPU_MOE_USE_TRITON: |
| 919 | + return self.forward_cuda( |
| 920 | + layer, |
| 921 | + x, |
| 922 | + use_grouped_topk, |
| 923 | + top_k, |
| 924 | + router_logits, |
| 925 | + renormalize, |
| 926 | + topk_group, |
| 927 | + num_expert_group, |
| 928 | + global_num_experts, |
| 929 | + expert_map, |
| 930 | + custom_routing_function, |
| 931 | + scoring_func, |
| 932 | + routed_scaling_factor, |
| 933 | + e_score_correction_bias, |
| 934 | + apply_router_weight_on_input, |
| 935 | + activation, |
| 936 | + ) |
| 937 | + |
911 | 938 | M, _ = x.size() |
912 | 939 | routing_weights = torch.empty(M, top_k, dtype=torch.float32, device=x.device) |
913 | 940 | selected_experts = torch.empty(M, top_k, dtype=torch.int32, device=x.device) |
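The branch added inside `forward_xpu` is an early-return delegation: when the flag is set, the method hands its argument list, unchanged, to the shared `forward_cuda` implementation instead of running the native routing and `xpu_fused_moe` code that follows. A minimal sketch of that pattern; the class name, placeholder return strings, and the way the flag is read from the environment are illustrative, not vLLM's actual code:

```python
import os


class MoEMethodSketch:
    """Toy stand-in for the fused-MoE method class, not vLLM's real class."""

    # Mirrors envs.VLLM_XPU_MOE_USE_TRITON; assumed to behave like a boolean.
    use_triton = os.environ.get("VLLM_XPU_MOE_USE_TRITON", "0") == "1"

    def forward_cuda(self, x, top_k):
        # Placeholder for the shared Triton-backed implementation.
        return f"triton_fused_moe({x}, top_k={top_k})"

    def forward_xpu(self, x, top_k):
        # Early-return delegation: same signature, arguments passed through
        # unchanged, as in the added branch of the diff.
        if self.use_triton:
            return self.forward_cuda(x, top_k)
        # Otherwise fall through to the native XPU kernel path.
        return f"xpu_fused_moe({x}, top_k={top_k})"


print(MoEMethodSketch().forward_xpu("hidden_states", top_k=2))
```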
@@ -1009,7 +1036,9 @@ def forward_tpu( |
1009 | 1036 | elif current_platform.is_cpu(): |
1010 | 1037 | forward_native = forward_cpu |
1011 | 1038 | elif current_platform.is_xpu(): |
1012 | | - forward_native = forward_xpu |
| 1039 | + forward_native = ( |
| 1040 | + forward_xpu if not envs.VLLM_XPU_MOE_USE_TRITON else forward_cuda |
| 1041 | + ) |
1013 | 1042 | else: |
1014 | 1043 | forward_native = forward_cuda |
1015 | 1044 |
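The final hunk resolves the same flag once more, this time while the class body executes, so `forward_native` on XPU is bound to either `forward_xpu` or the Triton-backed `forward_cuda` before any request is served. A sketch of that import-time selection, assuming the flag is ultimately driven by an environment variable of the same name:

```python
import os

# Assumed: exporting VLLM_XPU_MOE_USE_TRITON=1 before the process starts is
# what makes the flag truthy; the real code reads it through vLLM's envs module.
VLLM_XPU_MOE_USE_TRITON = os.environ.get("VLLM_XPU_MOE_USE_TRITON", "0") == "1"


def forward_xpu(x):
    return f"native_xpu({x})"  # placeholder for the native XPU path


def forward_cuda(x):
    return f"triton({x})"  # placeholder for the shared Triton path


class FusedMoESketch:
    # Evaluated once while the class body runs, mirroring the conditional
    # assignment in the diff; later calls never re-read the environment here.
    forward_native = staticmethod(
        forward_xpu if not VLLM_XPU_MOE_USE_TRITON else forward_cuda
    )


print(FusedMoESketch.forward_native("hidden_states"))
```

Note that the two call sites resolve the flag at different moments: the check added inside `forward_xpu` runs on every forward call, whereas this class-level alias is fixed as soon as the class body executes.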