@@ -26,6 +26,7 @@ def rocm_aiter_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     use_fp8_w8a8: bool = False,
+    apply_router_weight_on_input: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -39,6 +40,18 @@ def rocm_aiter_fused_experts(
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
 
+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
     if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None
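
For context, here is a minimal standalone sketch (not part of the patch, and not the vLLM API) of what the new branch does: with `topk=1`, the router weight is folded into `hidden_states` before expert dispatch, and the `topk_weights` passed downstream are replaced with ones so the weight is not applied twice. The toy `expert` below is a linear stand-in, for which input-side and output-side weighting happen to coincide; real MoE experts are nonlinear, which is why models that expect input-side weighting need this flag rather than an output rescale.

```python
# Minimal sketch of `apply_router_weight_on_input` semantics (toy code,
# not the vLLM implementation).
import torch

num_tokens, hidden_dim = 4, 8
hidden_states = torch.randn(num_tokens, hidden_dim)
topk_weights = torch.rand(num_tokens, 1)  # shape (num_tokens, topk=1)

def expert(x: torch.Tensor) -> torch.Tensor:
    # Linear stand-in for a fused expert; chosen so the two weighting
    # orders can be compared exactly.
    return 2.0 * x

# Default path: router weight applied to the expert output.
out_after = expert(hidden_states) * topk_weights

# New path: router weight folded into the input, downstream weights set
# to ones, mirroring the branch added in rocm_aiter_fused_experts.
scaled_in = hidden_states * topk_weights.to(hidden_states.dtype)
out_before = expert(scaled_in) * torch.ones_like(topk_weights)

# For a linear expert the two orders agree.
torch.testing.assert_close(out_after, out_before)
```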