-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
[ROCm][Perf] Add Fused Shared Expert (FSE) support for Qwen3-Next #39280
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
105a3b1
02f8429
fa543be
fd1cc69
f148c9f
8e5d3fd
6fa2868
2205cd9
30c7e57
3f7875d
95f11c1
8c7cf2f
6349f10
e5c041d
096c8fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -109,6 +109,55 @@ def init_aiter_topK_meta_data( | |
| aiter_topK_meta_data = (total_topk_weights, total_topk_ids) | ||
|
|
||
|
|
||
| def inject_shared_expert_weights( | ||
| topk_weights: torch.Tensor, | ||
| topk_ids: torch.Tensor, | ||
| topk: int, | ||
| num_fused_shared_experts: int, | ||
| shared_expert_weights: torch.Tensor | None = None, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Merge routed topk results with the shared expert buffer and inject | ||
| dynamic per-token shared expert gate values for AITER fusion. | ||
|
|
||
| For routers that already return the combined buffer (e.g. GroupedTopKRouter | ||
| via rocm_aiter_grouped_topk), only the dynamic weight injection is needed. | ||
| For routers that return only routed slots (e.g. FusedTopKRouter), this also | ||
| copies the routed results into the pre-allocated combined buffer. | ||
| """ | ||
| if num_fused_shared_experts == 0: | ||
| return topk_weights, topk_ids | ||
|
|
||
| assert aiter_topK_meta_data is not None, ( | ||
| "aiter_topK_meta_data is not initialized but " | ||
| "num_fused_shared_experts > 0. Ensure init_aiter_topK_meta_data " | ||
| "is called before routing." | ||
| ) | ||
|
|
||
| total_topk_weights, total_topk_ids = aiter_topK_meta_data | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would this cause a problem if
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes it is. I think we should ideally update this soon to have managed access of aiter_topK_meta_data to not rely on initialization from callers before using it. For now, I've added an assert to prevent this incorrect situation. |
||
| token = topk_weights.shape[0] | ||
|
|
||
| assert total_topk_weights.shape[0] >= token, ( | ||
| f"AITER topK meta data supports {total_topk_weights.shape[0]} " | ||
| f"tokens, but got {token} tokens." | ||
| ) | ||
|
|
||
| total_topk_weights_slice = total_topk_weights[:token] | ||
| total_topk_ids_slice = total_topk_ids[:token] | ||
|
|
||
| if topk_weights.shape[1] == topk: | ||
| total_topk_weights_slice[:, :topk] = topk_weights | ||
| total_topk_ids_slice[:, :topk] = topk_ids | ||
| topk_weights = total_topk_weights_slice | ||
| topk_ids = total_topk_ids_slice | ||
|
|
||
| if shared_expert_weights is not None: | ||
| topk_weights[:, topk : topk + num_fused_shared_experts] = shared_expert_weights[ | ||
| :token | ||
| ] | ||
|
|
||
| return topk_weights, topk_ids | ||
|
|
||
|
|
||
| def rocm_aiter_grouped_topk( | ||
| hidden_states: torch.Tensor, | ||
| gating_output: torch.Tensor, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,143 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| from collections.abc import Callable | ||
|
|
||
| import torch | ||
|
|
||
| from vllm._aiter_ops import rocm_aiter_ops | ||
| from vllm.distributed.eplb.eplb_state import EplbLayerState | ||
| from vllm.model_executor.layers.fused_moe.config import ( | ||
| RoutingMethodType, | ||
| get_routing_method_type, | ||
| ) | ||
| from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter | ||
| from vllm.model_executor.layers.fused_moe.router.fused_topk_router import ( | ||
| dispatch_topk_softmax_func, | ||
| ) | ||
|
|
||
|
|
||
| class AiterSharedRoutedFusedMoERouter(BaseRouter): | ||
| """ | ||
| ROCm AITER router for models with fused shared experts (e.g. Qwen3-MoE). | ||
|
|
||
| When the AITER topk_softmax kernel supports sigmoid fusion, the routing | ||
| softmax and shared-expert sigmoid are computed in a single kernel launch. | ||
| Otherwise the shared-expert weights are injected into the pre-allocated | ||
| AITER buffer via a fallback path. | ||
|
|
||
| Only instantiated when rocm_aiter fused-MoE is active and | ||
| num_fused_shared_experts > 0. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| top_k: int, | ||
| global_num_experts: int, | ||
| eplb_state: EplbLayerState, | ||
| num_fused_shared_experts: int, | ||
| scoring_func: str = "softmax", | ||
| renormalize: bool = True, | ||
| enable_eplb: bool = False, | ||
| indices_type_getter: Callable[[], torch.dtype | None] | None = None, | ||
| ): | ||
| super().__init__( | ||
| top_k=top_k, | ||
| global_num_experts=global_num_experts, | ||
| eplb_state=eplb_state, | ||
| enable_eplb=enable_eplb, | ||
| indices_type_getter=indices_type_getter, | ||
| ) | ||
| self.renormalize = renormalize | ||
| self.scoring_func = scoring_func | ||
| self.num_fused_shared_experts = num_fused_shared_experts | ||
|
|
||
| @property | ||
| def routing_method_type(self) -> RoutingMethodType: | ||
| return get_routing_method_type( | ||
| scoring_func=self.scoring_func, | ||
| top_k=self.top_k, | ||
| renormalize=self.renormalize, | ||
| num_expert_group=None, | ||
| has_e_score_bias=False, | ||
| ) | ||
|
|
||
| def _compute_routing( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| router_logits: torch.Tensor, | ||
| indices_type: torch.dtype | None, | ||
| *, | ||
| input_ids: torch.Tensor | None = None, | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| assert hidden_states.size(0) == router_logits.size(0), ( | ||
| "Number of tokens mismatch" | ||
| ) | ||
|
|
||
| from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( | ||
| aiter_topK_meta_data, | ||
| ) | ||
|
|
||
| M = hidden_states.size(0) | ||
| topk = self.top_k | ||
| num_fse = self.num_fused_shared_experts | ||
|
|
||
| token_expert_indices = torch.empty( | ||
| M, topk, dtype=torch.int32, device=hidden_states.device | ||
| ) | ||
|
|
||
| if rocm_aiter_ops.fuse_sigmoid_in_kernel(aiter_topK_meta_data): | ||
| total_topk_weights, total_topk_ids = aiter_topK_meta_data # type: ignore[misc] | ||
| total_topk_weights_slice = total_topk_weights[:M] | ||
| topk_ids_slice = total_topk_ids[:M, :topk] | ||
|
|
||
| topk_func = dispatch_topk_softmax_func(use_rocm_aiter=True) | ||
| topk_func( | ||
| total_topk_weights_slice, | ||
| topk_ids_slice, | ||
| token_expert_indices, | ||
| router_logits, | ||
| self.renormalize, | ||
| num_fse, | ||
| "sigmoid", | ||
| ) | ||
| return total_topk_weights_slice, total_topk_ids[:M] | ||
|
|
||
| routing_logits = router_logits[:, :-num_fse] | ||
| shared_logits = router_logits[:, -num_fse:] | ||
|
|
||
| topk_weights = torch.empty( | ||
| M, topk, dtype=torch.float32, device=hidden_states.device | ||
| ) | ||
| topk_ids = torch.empty( | ||
| M, | ||
| topk, | ||
| dtype=torch.int32 if indices_type is None else indices_type, | ||
| device=hidden_states.device, | ||
| ) | ||
|
|
||
| topk_func = dispatch_topk_softmax_func( | ||
| use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled() | ||
| ) | ||
| topk_weights, topk_ids = topk_func( | ||
| topk_weights, | ||
| topk_ids, | ||
| token_expert_indices, | ||
| routing_logits, | ||
| self.renormalize, | ||
| ) | ||
|
|
||
| if aiter_topK_meta_data is not None: | ||
| from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( | ||
| inject_shared_expert_weights, | ||
| ) | ||
|
|
||
| shared_weights = torch.sigmoid(shared_logits) | ||
| topk_weights, topk_ids = inject_shared_expert_weights( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. seems ot me this |
||
| topk_weights, | ||
| topk_ids, | ||
| topk=topk, | ||
| num_fused_shared_experts=num_fse, | ||
| shared_expert_weights=shared_weights, | ||
| ) | ||
|
|
||
| return topk_weights, topk_ids | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,8 +5,14 @@ | |
| import torch | ||
|
|
||
| import vllm.envs as envs | ||
| from vllm._aiter_ops import rocm_aiter_ops | ||
| from vllm.distributed.eplb.eplb_state import EplbLayerState | ||
| from vllm.model_executor.layers.fused_moe.config import RoutingMethodType | ||
| from vllm.model_executor.layers.fused_moe.config import ( | ||
| RoutingMethodType, | ||
| ) | ||
| from vllm.model_executor.layers.fused_moe.router.aiter_shared_routed_fused_moe_router import ( # noqa: E501 | ||
| AiterSharedRoutedFusedMoERouter, | ||
| ) | ||
| from vllm.model_executor.layers.fused_moe.router.custom_routing_router import ( | ||
| CustomRoutingRouter, | ||
| ) | ||
|
|
@@ -67,7 +73,8 @@ def create_fused_moe_router( | |
| 3. GroupedTopKRouter - if use_grouped_topk is True | ||
| 4. CustomRoutingRouter - if custom_routing_function is not None | ||
| 5. FusedTopKBiasRouter - if e_score_correction_bias is not None | ||
| 6. FusedTopKRouter - default fallback | ||
| 6. AiterSharedRoutedFusedMoERouter - if num_fused_shared_experts > 0 | ||
| 7. FusedTopKRouter - default fallback | ||
|
|
||
| Common arguments: | ||
| top_k: Number of experts to select per token | ||
|
|
@@ -199,6 +206,22 @@ def create_fused_moe_router( | |
| hash_indices_table=hash_indices_table, | ||
| ) | ||
|
|
||
| if ( | ||
| num_fused_shared_experts > 0 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happens if should we just reject?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. currently we take FusedTopKRouter. Which is what happened prior as well. So I think we're okay on that front. It's not a change in behavior in the router unless the specific 3 conditions here are met
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please open a github issue to audit and guard this for future so we have a clear view of what does and does not work |
||
| and scoring_func == "softmax" | ||
| and rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() | ||
| ): | ||
| return AiterSharedRoutedFusedMoERouter( | ||
| top_k=top_k, | ||
| global_num_experts=global_num_experts, | ||
| eplb_state=eplb_state, | ||
| num_fused_shared_experts=num_fused_shared_experts, | ||
| renormalize=renormalize, | ||
| scoring_func=scoring_func, | ||
| enable_eplb=enable_eplb, | ||
| indices_type_getter=indices_type_getter, | ||
| ) | ||
|
|
||
| return FusedTopKRouter( | ||
| top_k=top_k, | ||
| global_num_experts=global_num_experts, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems unnessrary to have this attribute?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm going to wait for CI to finish before pushing anything else. I'm happy to remove it. This is consistent with some other attributes that aren't used elsewhere and that was the reason for this. I thought there might be debugging or other reasons that most construction args are saved as attributes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i know, i hate all those old attrs since it makes it hard to tell what "owns" the object