Commit 92edf35

[ROCM] enable aiter fused moe kernel for llama4 bf16 checkpoints (#16674)
1 parent eb5819b commit 92edf35

1 file changed: +13 −0 lines changed

vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py

Lines changed: 13 additions & 0 deletions
@@ -26,6 +26,7 @@ def rocm_aiter_fused_experts(
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         use_fp8_w8a8: bool = False,
+        apply_router_weight_on_input: bool = False,
         w1_scale: Optional[torch.Tensor] = None,
         w2_scale: Optional[torch.Tensor] = None,
         block_shape: Optional[List[int]] = None,
@@ -39,6 +40,18 @@ def rocm_aiter_fused_experts(
     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)

+    if apply_router_weight_on_input:
+        assert (topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+        _, topk = topk_weights.shape
+        assert (
+            topk == 1
+        ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+
+        hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
+
     if envs.VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE and use_fp8_w8a8:
         assert w1_scale is not None
         assert w2_scale is not None
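
Note (not part of the commit): a minimal standalone sketch of what the new `apply_router_weight_on_input` branch does. The tensor shapes and the toy setup below are illustrative assumptions, not vLLM code. With top-1 routing, the router weight is folded into the activations once up front, `topk_ids` is cast to int32 for the AITER kernel, and `topk_weights` is replaced with ones so the kernel does not apply the weight a second time on the output.

import torch

# Illustrative shapes only; not taken from the commit.
num_tokens, hidden_dim, topk = 4, 16, 1
hidden_states = torch.randn(num_tokens, hidden_dim, dtype=torch.bfloat16)
topk_weights = torch.rand(num_tokens, topk, dtype=torch.float32)
topk_ids = torch.zeros(num_tokens, topk, dtype=torch.int64)

apply_router_weight_on_input = True
if apply_router_weight_on_input:
    # Same guards as the diff: only 2-D weights with topk == 1 are supported.
    assert topk_weights.dim() == 2, "`topk_weights` should be (num_tokens, topk)"
    _, topk = topk_weights.shape
    assert topk == 1, "Only topk=1 is supported when applying the weight on input"

    # Fold the router weight into the activations once, up front ...
    hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
    topk_ids = topk_ids.to(torch.int32)
    # ... and hand the kernel all-ones weights so it does not weight the output again.
    topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)

print(hidden_states.dtype, topk_ids.dtype, topk_weights)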
