[FEAT] [ROCm]: Add AITER CK 2 Stages MoE support #17110
@@ -20,7 +20,7 @@ def rocm_aiter_asm_moe_tkw1_impl(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
-        topk_weight: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         fc1_scale: Optional[torch.Tensor] = None,
         fc2_scale: Optional[torch.Tensor] = None,

@@ -40,7 +40,7 @@ def rocm_aiter_asm_moe_tkw1_impl(
     return asm_moe_tkw1(hidden_states,
                         w1,
                         w2,
-                        topk_weight,
+                        topk_weights,
                         topk_ids,
                         fc1_scale=fc1_scale,
                         fc2_scale=fc2_scale,

@@ -56,7 +56,7 @@ def rocm_aiter_asm_moe_tkw1_fake(
         hidden_states: torch.Tensor,
         w1: torch.Tensor,
         w2: torch.Tensor,
-        topk_weight: torch.Tensor,
+        topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         fc1_scale: Optional[torch.Tensor] = None,
         fc2_scale: Optional[torch.Tensor] = None,

@@ -151,7 +151,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
 def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
                             w1: torch.Tensor,
                             w2: torch.Tensor,
-                            topk_weight: torch.Tensor,
+                            topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor,
                             fc1_scale: Optional[torch.Tensor] = None,
                             fc2_scale: Optional[torch.Tensor] = None,

@@ -174,7 +174,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
     return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
                                        w1=w1,
                                        w2=w2,
-                                       topk_weight=topk_weight,
+                                       topk_weight=topk_weights,
                                        topk_ids=topk_ids,
                                        fc1_scale=fc1_scale,
                                        fc2_scale=fc2_scale,

@@ -187,7 +187,7 @@ def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
 def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
                             w1: torch.Tensor,
                             w2: torch.Tensor,
-                            topk_weight: torch.Tensor,
+                            topk_weights: torch.Tensor,
                             topk_ids: torch.Tensor,
                             fc1_scale: Optional[torch.Tensor] = None,
                             fc2_scale: Optional[torch.Tensor] = None,
@@ -198,6 +198,49 @@ def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
     return torch.empty_like(hidden_states)


+def rocm_aiter_ck_moe_2stages_impl(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        fc1_scale: Optional[torch.Tensor] = None,
+        fc2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_size: Optional[List[int]] = None,
+        expert_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    from aiter.fused_moe_bf16_asm import ck_moe_2stages
+    return ck_moe_2stages(a1=hidden_states,
+                          w1=w1,
+                          w2=w2,
+                          topk_weight=topk_weights,
+                          topk_ids=topk_ids,
+                          fc1_scale=fc1_scale,
+                          fc2_scale=fc2_scale,
+                          a1_scale=a1_scale,
+                          a2_scale=a2_scale,
+                          block_size=block_size,
+                          expert_mask=expert_mask)
+
+
+def rocm_aiter_ck_moe_2stages_fake(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        fc1_scale: Optional[torch.Tensor] = None,
+        fc2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_size: Optional[List[int]] = None,
+        expert_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
 def rocm_aiter_topk_softmax_impl(topk_weights: torch.Tensor,
                                  topk_indices: torch.Tensor,
                                  token_expert_indices: torch.Tensor,

@@ -250,6 +293,14 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
     dispatch_key=current_platform.dispatch_key,
 )

+direct_register_custom_op(
+    op_name="rocm_aiter_ck_moe_2stages",
+    op_func=rocm_aiter_ck_moe_2stages_impl,
+    mutates_args=[],
+    fake_impl=rocm_aiter_ck_moe_2stages_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
+
 direct_register_custom_op(
     op_name="rocm_aiter_topk_softmax",
     op_func=rocm_aiter_topk_softmax_impl,
@@ -259,29 +310,23 @@ def rocm_aiter_topk_softmax_fake(topk_weights: torch.Tensor,
 )


-def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
-                             w1: torch.Tensor,
-                             w2: torch.Tensor,
-                             topk_weights: torch.Tensor,
-                             topk_ids: torch.Tensor,
-                             inplace: bool = False,
-                             activation: str = "silu",
-                             apply_router_weight_on_input: bool = False,
-                             use_fp8_w8a8: bool = False,
-                             use_int8_w8a8: bool = False,
-                             use_int8_w8a16: bool = False,
-                             use_int4_w4a16: bool = False,
-                             per_channel_quant: bool = False,
-                             global_num_experts: int = -1,
-                             expert_map: Optional[torch.Tensor] = None,
-                             w1_scale: Optional[torch.Tensor] = None,
-                             w2_scale: Optional[torch.Tensor] = None,
-                             w1_zp: Optional[torch.Tensor] = None,
-                             w2_zp: Optional[torch.Tensor] = None,
-                             a1_scale: Optional[torch.Tensor] = None,
-                             a2_scale: Optional[torch.Tensor] = None,
-                             block_shape: Optional[List[int]] = None,
-                             allow_deep_gemm: bool = False) -> torch.Tensor:
+def rocm_aiter_fused_experts(
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        use_fp8_w8a8: bool = False,
+        per_channel_quant: bool = False,
+        w1_scale: Optional[torch.Tensor] = None,
+        w2_scale: Optional[torch.Tensor] = None,
+        a1_scale: Optional[torch.Tensor] = None,
+        a2_scale: Optional[torch.Tensor] = None,
+        block_shape: Optional[List[int]] = None,
+        use_ck_moe_2stages: bool = False,
+) -> torch.Tensor:

     from vllm.model_executor.layers.quantization.utils.fp8_utils import (
         per_token_group_quant_fp8)
@@ -304,8 +349,8 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
         a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1])

         return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1(
-            topk_ids, topk_weights, hidden_states.dtype, expert_map, a1, w1,
-            w2, w1_scale, w2_scale, a1_scale, block_shape, None)
+            topk_ids, topk_weights, hidden_states.dtype, None, a1, w1, w2,
+            w1_scale, w2_scale, a1_scale, block_shape, None)

     # w8a8 per-channel quantization
     elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:

@@ -330,17 +375,36 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
             fc2_smooth_scale=None,
             a16=False,
             per_tensor_quant_scale=None,
-            expert_mask=expert_map,
+            expert_mask=None,
             activation_str=activation)

     # w8a8 per-tensor activation per-tensor weight
     elif use_fp8_w8a8:
         assert not apply_router_weight_on_input, (
             "apply_router_weight_on_input is not supported for fp8_w8a8")

+        # - faster static per-tensor-activation static per-tensor-weight
+        # fp8 quantization w8a8
+        if use_ck_moe_2stages and a1_scale is not None and a2_scale is not None:
+            return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
+                hidden_states=hidden_states,
+                w1=w1,
+                w2=w2,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                fc1_scale=w1_scale,
+                fc2_scale=w2_scale,
+                a1_scale=a1_scale,
+                a2_scale=a2_scale)
+
+        # - fallback static per-tensor-activation static per-tensor-weight
+        # fp8 quantization w8a8
+        # - dynamic per-tensor activation static per-tensor-weight
+        # fp8 quantization w8a8
         return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
                                                  w1=w1,
                                                  w2=w2,
-                                                 topk_weight=topk_weights,
+                                                 topk_weights=topk_weights,
                                                  topk_ids=topk_ids,
                                                  fc1_scale=w1_scale,
                                                  fc2_scale=w2_scale,
@@ -360,6 +424,15 @@ def rocm_aiter_fused_experts(hidden_states: torch.Tensor,
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)

+    # faster w16a16
+    if use_ck_moe_2stages:
+        return torch.ops.vllm.rocm_aiter_ck_moe_2stages(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids)
+
     # w16a16 fallback to rocm_aiter_ck_moe w16a16
     return torch.ops.vllm.rocm_aiter_ck_moe(hidden_states=hidden_states,
                                             w1=w1,
@@ -379,7 +452,8 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
     return topk_weights, topk_indices


-def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+def shuffle_weights(*tensors: torch.Tensor,
+                    layout: tuple[int, int]) -> tuple[torch.Tensor, ...]:
     """
     Applies shuffle_weight function from AITER to each
     input tensor and returns them.

@@ -391,11 +465,12 @@ def shuffle_weights(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
         A Tuple of shuffled tensors.
     """
     from aiter.ops.shuffle import shuffle_weight
-    return tuple(shuffle_weight(tensor) for tensor in tensors)
+
+    return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)


 def expand_weights(*tensors: torch.Tensor,
-                   expansion_dims: list[int]) -> Tuple[torch.Tensor, ...]:
+                   expansion_dims: List[int]) -> Tuple[torch.Tensor, ...]:
     """
     Expands the dimensions of input tensors.

@@ -413,4 +488,4 @@ def expand_weights(*tensors: torch.Tensor,
     return tuple(
         tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
-        for tensor, dim in zip(tensors, expansion_dims))
+        for tensor, dim in zip(tensors, expansion_dims))
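For context when reviewing, here is a minimal w16a16 usage sketch of the new `use_ck_moe_2stages` flag. The tensor shapes and the fused gate+up layout of `w1` are illustrative assumptions (real callers also pre-shuffle the weights via `shuffle_weights`), and a ROCm GPU with AITER installed is assumed:

```python
import torch
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    rocm_aiter_fused_experts)

# Illustrative sizes: 8 tokens, 8 experts, top-2 routing.
T, E, K, hidden, inter = 8, 8, 2, 4096, 14336
hidden_states = torch.randn(T, hidden, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(E, 2 * inter, hidden, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(E, hidden, inter, dtype=torch.bfloat16, device="cuda")
topk_weights = torch.rand(T, K, dtype=torch.float32, device="cuda")
topk_ids = torch.randint(0, E, (T, K), dtype=torch.int32, device="cuda")

# With use_ck_moe_2stages=True the w16a16 path dispatches to the new
# torch.ops.vllm.rocm_aiter_ck_moe_2stages op; otherwise it falls back
# to rocm_aiter_ck_moe.
out = rocm_aiter_fused_experts(hidden_states, w1, w2, topk_weights,
                               topk_ids, use_ck_moe_2stages=True)
```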
Nit: It doesn't look like you need this?
Agreed, this is not needed.
I have removed the argument `allow_deep_gemm`. `use_ck_moe_2stages` is kept because an RFC highlights that accessing `envs.ENV` is very costly (RFC issue #17067). Thus, all the envs are read only once and stored as properties of the class during the initialization stage.
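As a rough illustration of that init-time caching pattern (the class name and the env flag name below are hypothetical, not the PR's actual code):

```python
import torch
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    rocm_aiter_fused_experts)


class AiterMoEMethod:  # hypothetical class, for illustration only

    def __init__(self) -> None:
        # Read the env exactly once (RFC #17067: envs.* access is costly
        # per call) and cache it as a property. The flag name here is a
        # hypothetical placeholder.
        self.use_ck_moe_2stages = bool(
            getattr(envs, "VLLM_ROCM_USE_AITER_CK_MOE_2STAGES", False))

    def forward(self, hidden_states: torch.Tensor, w1, w2, topk_weights,
                topk_ids) -> torch.Tensor:
        # Hot path: only the cached attribute is consulted, no env lookup.
        return rocm_aiter_fused_experts(
            hidden_states, w1, w2, topk_weights, topk_ids,
            use_ck_moe_2stages=self.use_ck_moe_2stages)
```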
To pass the pre-commit checks on vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py, we have adjusted the assignment of the fused-experts function: `fused_experts_func` becomes `rocm_aiter_fused_experts_func`, following the approach in vllm/attention/backends/rocm_flash_attn.py, where the attention functions are assigned to distinct property names: `self.fa_attn_func`, `self.sdpa_attn_func`, and `self.triton_attn_func`. This also lets us clean up the unused arguments of `rocm_aiter_fused_experts` (vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py). Note that `expert_map` has been removed as a bugfix: `expert_map` in vLLM (integer IDs of the experts) and `expert_mask` in the AITER ops (a boolean mask of the experts active on the current GPU) are different things. The current `rocm_aiter_fused_experts` therefore drops the `expert_map` argument; it should be added back when enabling expert parallelism (EP) with AITER in a future PR.
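A minimal sketch of that binding pattern, assuming a simplified stand-in class (not the exact code in compressed_tensors_moe.py):

```python
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    rocm_aiter_fused_experts)


class W8A8Fp8MoEMethod:  # hypothetical stand-in class

    def __init__(self) -> None:
        # Bind the kernel entry point once at init, mirroring how
        # rocm_flash_attn.py assigns self.fa_attn_func,
        # self.sdpa_attn_func and self.triton_attn_func. The gating
        # condition is shown with VLLM_ROCM_USE_AITER; adjust to the
        # actual condition used by the quant method.
        if envs.VLLM_ROCM_USE_AITER:
            self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
        else:
            self.fused_experts_func = fused_experts

    def apply(self, hidden_states, w1, w2, topk_weights, topk_ids):
        # Each branch calls a function whose signature matches what it
        # was bound to, which keeps mypy and pre-commit happy.
        if hasattr(self, "rocm_aiter_fused_experts_func"):
            return self.rocm_aiter_fused_experts_func(
                hidden_states, w1, w2, topk_weights, topk_ids)
        return self.fused_experts_func(hidden_states, w1, w2,
                                       topk_weights, topk_ids)
```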