Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions vllm/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,19 @@ def _rocm_aiter_topk_softmax_impl(
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
num_shared_experts: int = 0,
shared_expert_scoring_func: str = "",
) -> None:
from aiter import topk_softmax

topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
num_shared_experts,
shared_expert_scoring_func,
)


Expand All @@ -275,6 +283,8 @@ def _rocm_aiter_topk_softmax_fake(
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
num_shared_experts: int = 0,
shared_expert_scoring_func: str = "",
) -> None:
pass

Expand Down Expand Up @@ -1206,6 +1216,9 @@ class rocm_aiter_ops:
_MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
# TODO: Consolidate under _LINEAR_ENABLED
_TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
# Lazily probed: whether aiter.topk_softmax supports the
# num_shared_experts / shared_expert_scoring_func args (7-arg form).
_TOPK_SOFTMAX_FUSED_SIGMOID: bool | None = None

_ALL_REDUCE_MAX_SIZE: int = 8192 * 1024 * 8 * 2
_CUSTOM_ALL_REDUCE: AiterCustomAllreduceProto | None = None
Expand Down Expand Up @@ -1322,6 +1335,52 @@ def is_fused_moe_enabled(cls) -> bool:
def is_fusion_moe_shared_experts_enabled(cls) -> bool:
return cls.is_fused_moe_enabled() and cls._MOE_SHARED_EXPERTS_ENABLED

@classmethod
@if_aiter_supported
def topk_softmax_supports_fused_sigmoid(cls) -> bool:
"""Check if topk_softmax supports fused shared expert activation."""
if cls._TOPK_SOFTMAX_FUSED_SIGMOID is None:
try:
import inspect

from aiter import topk_softmax

params = inspect.signature(topk_softmax).parameters
if "num_shared_experts" in params:
cls._TOPK_SOFTMAX_FUSED_SIGMOID = True
else:
# @compile_ops wrapper loses the original signature.
# Fall back to the torch custom op schema.
import torch

schema = getattr(
getattr(torch.ops.aiter, "topk_softmax", None), "default", None
)
schema_str = str(getattr(schema, "_schema", ""))
cls._TOPK_SOFTMAX_FUSED_SIGMOID = "num_shared_experts" in schema_str
except (ImportError, ValueError):
cls._TOPK_SOFTMAX_FUSED_SIGMOID = False
return cls._TOPK_SOFTMAX_FUSED_SIGMOID

@classmethod
@if_aiter_supported
def fuse_sigmoid_in_kernel(cls, aiter_topK_meta_data: object) -> bool:
"""Whether fused shared-expert sigmoid in the topk kernel is usable.

Combines the cached static capability checks (FSE enabled, fused-moe
enabled, topk_softmax supports fused sigmoid) with the runtime
readiness check (topK meta-data buffer initialized).

``aiter_topK_meta_data`` is accepted as a parameter rather than
imported internally so callers cannot hit initialization-order
issues where the module-level global has not been set yet.
"""
return (
cls.is_fusion_moe_shared_experts_enabled()
and cls.topk_softmax_supports_fused_sigmoid()
and aiter_topK_meta_data is not None
)

@classmethod
@if_aiter_supported
def is_mla_enabled(cls) -> bool:
Expand Down Expand Up @@ -1795,9 +1854,17 @@ def topk_softmax(
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool,
num_shared_experts: int = 0,
shared_expert_scoring_func: str = "",
) -> tuple[torch.Tensor, ...]:
torch.ops.vllm.rocm_aiter_topk_softmax(
topk_weights, topk_indices, token_expert_indices, gating_output, renormalize
topk_weights,
topk_indices,
token_expert_indices,
gating_output,
renormalize,
num_shared_experts,
shared_expert_scoring_func,
)
return topk_weights, topk_indices

Expand Down
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def __init__(
router_logits_dtype: torch.dtype | None = None,
gate: torch.nn.Module | None = None,
shared_experts: torch.nn.Module | None = None,
shared_expert_gate: torch.nn.Module | None = None,
routed_input_transform: torch.nn.Module | None = None,
routed_output_transform: torch.nn.Module | None = None,
apply_routed_scale_to_output: bool = False,
Expand Down Expand Up @@ -370,6 +371,8 @@ def __init__(
if n_shared_experts is not None and self.aiter_fmoe_shared_expert_enabled
else 0
)
self.shared_expert_gate = shared_expert_gate

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems unnessrary to have this attribute?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to wait for CI to finish before pushing anything else. I'm happy to remove it. This is consistent with some other attributes that aren't used elsewhere and that was the reason for this. I thought there might be debugging or other reasons that most construction args are saved as attributes.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i know, i hate all those old attrs since it makes it hard to tell what "owns" the object


if (
not self.aiter_fmoe_shared_expert_enabled
and self.num_fused_shared_experts != 0
Expand Down Expand Up @@ -608,6 +611,7 @@ def _get_quant_method() -> FusedMoEMethodBase:
router=self.router,
gate=gate,
shared_experts=shared_experts,
shared_expert_gate=self.shared_expert_gate,
quant_method=self.quant_method,
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
routed_input_transform=routed_input_transform,
Expand Down
49 changes: 49 additions & 0 deletions vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,55 @@ def init_aiter_topK_meta_data(
aiter_topK_meta_data = (total_topk_weights, total_topk_ids)


def inject_shared_expert_weights(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
topk: int,
num_fused_shared_experts: int,
shared_expert_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Merge routed topk results with the shared expert buffer and inject
dynamic per-token shared expert gate values for AITER fusion.

For routers that already return the combined buffer (e.g. GroupedTopKRouter
via rocm_aiter_grouped_topk), only the dynamic weight injection is needed.
For routers that return only routed slots (e.g. FusedTopKRouter), this also
copies the routed results into the pre-allocated combined buffer.
"""
if num_fused_shared_experts == 0:
return topk_weights, topk_ids

assert aiter_topK_meta_data is not None, (
"aiter_topK_meta_data is not initialized but "
"num_fused_shared_experts > 0. Ensure init_aiter_topK_meta_data "
"is called before routing."
)

total_topk_weights, total_topk_ids = aiter_topK_meta_data

@dllehr-amd dllehr-amd May 7, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this cause a problem if aiter_topK_meta_data is None and num_fused_shared_experts > 0 ?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it is. I think we should ideally update this soon to have managed access of aiter_topK_meta_data to not rely on initialization from callers before using it.

For now, I've added an assert to prevent this incorrect situation.

token = topk_weights.shape[0]

assert total_topk_weights.shape[0] >= token, (
f"AITER topK meta data supports {total_topk_weights.shape[0]} "
f"tokens, but got {token} tokens."
)

total_topk_weights_slice = total_topk_weights[:token]
total_topk_ids_slice = total_topk_ids[:token]

if topk_weights.shape[1] == topk:
total_topk_weights_slice[:, :topk] = topk_weights
total_topk_ids_slice[:, :topk] = topk_ids
topk_weights = total_topk_weights_slice
topk_ids = total_topk_ids_slice

if shared_expert_weights is not None:
topk_weights[:, topk : topk + num_fused_shared_experts] = shared_expert_weights[
:token
]

return topk_weights, topk_ids


def rocm_aiter_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType,
get_routing_method_type,
)
from vllm.model_executor.layers.fused_moe.router.base_router import BaseRouter
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
dispatch_topk_softmax_func,
)


class AiterSharedRoutedFusedMoERouter(BaseRouter):
"""
ROCm AITER router for models with fused shared experts (e.g. Qwen3-MoE).

When the AITER topk_softmax kernel supports sigmoid fusion, the routing
softmax and shared-expert sigmoid are computed in a single kernel launch.
Otherwise the shared-expert weights are injected into the pre-allocated
AITER buffer via a fallback path.

Only instantiated when rocm_aiter fused-MoE is active and
num_fused_shared_experts > 0.
"""

def __init__(
self,
top_k: int,
global_num_experts: int,
eplb_state: EplbLayerState,
num_fused_shared_experts: int,
scoring_func: str = "softmax",
renormalize: bool = True,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
):
super().__init__(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)
self.renormalize = renormalize
self.scoring_func = scoring_func
self.num_fused_shared_experts = num_fused_shared_experts

@property
def routing_method_type(self) -> RoutingMethodType:
return get_routing_method_type(
scoring_func=self.scoring_func,
top_k=self.top_k,
renormalize=self.renormalize,
num_expert_group=None,
has_e_score_bias=False,
)

def _compute_routing(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
indices_type: torch.dtype | None,
*,
input_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == router_logits.size(0), (
"Number of tokens mismatch"
)

from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
aiter_topK_meta_data,
)

M = hidden_states.size(0)
topk = self.top_k
num_fse = self.num_fused_shared_experts

token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)

if rocm_aiter_ops.fuse_sigmoid_in_kernel(aiter_topK_meta_data):
total_topk_weights, total_topk_ids = aiter_topK_meta_data # type: ignore[misc]
total_topk_weights_slice = total_topk_weights[:M]
topk_ids_slice = total_topk_ids[:M, :topk]

topk_func = dispatch_topk_softmax_func(use_rocm_aiter=True)
topk_func(
total_topk_weights_slice,
topk_ids_slice,
token_expert_indices,
router_logits,
self.renormalize,
num_fse,
"sigmoid",
)
return total_topk_weights_slice, total_topk_ids[:M]

routing_logits = router_logits[:, :-num_fse]
shared_logits = router_logits[:, -num_fse:]

topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(
M,
topk,
dtype=torch.int32 if indices_type is None else indices_type,
device=hidden_states.device,
)

topk_func = dispatch_topk_softmax_func(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
topk_weights, topk_ids = topk_func(
topk_weights,
topk_ids,
token_expert_indices,
routing_logits,
self.renormalize,
)

if aiter_topK_meta_data is not None:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
inject_shared_expert_weights,
)

shared_weights = torch.sigmoid(shared_logits)
topk_weights, topk_ids = inject_shared_expert_weights(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems ot me this inject_shared_experts_weight function should be defined in this file

topk_weights,
topk_ids,
topk=topk,
num_fused_shared_experts=num_fse,
shared_expert_weights=shared_weights,
)

return topk_weights, topk_ids
27 changes: 25 additions & 2 deletions vllm/model_executor/layers/fused_moe/router/router_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@
import torch

import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.model_executor.layers.fused_moe.config import RoutingMethodType
from vllm.model_executor.layers.fused_moe.config import (
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.router.aiter_shared_routed_fused_moe_router import ( # noqa: E501
AiterSharedRoutedFusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.router.custom_routing_router import (
CustomRoutingRouter,
)
Expand Down Expand Up @@ -67,7 +73,8 @@ def create_fused_moe_router(
3. GroupedTopKRouter - if use_grouped_topk is True
4. CustomRoutingRouter - if custom_routing_function is not None
5. FusedTopKBiasRouter - if e_score_correction_bias is not None
6. FusedTopKRouter - default fallback
6. AiterSharedRoutedFusedMoERouter - if num_fused_shared_experts > 0
7. FusedTopKRouter - default fallback

Common arguments:
top_k: Number of experts to select per token
Expand Down Expand Up @@ -199,6 +206,22 @@ def create_fused_moe_router(
hash_indices_table=hash_indices_table,
)

if (
num_fused_shared_experts > 0

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if num_fused_shared_experts > 0 and either scoring_func != softmax or is not aiter?

should we just reject?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently we take FusedTopKRouter. Which is what happened prior as well. So I think we're okay on that front. It's not a change in behavior in the router unless the specific 3 conditions here are met

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please open a github issue to audit and guard this for future so we have a clear view of what does and does not work

and scoring_func == "softmax"
and rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
):
return AiterSharedRoutedFusedMoERouter(
top_k=top_k,
global_num_experts=global_num_experts,
eplb_state=eplb_state,
num_fused_shared_experts=num_fused_shared_experts,
renormalize=renormalize,
scoring_func=scoring_func,
enable_eplb=enable_eplb,
indices_type_getter=indices_type_getter,
)

return FusedTopKRouter(
top_k=top_k,
global_num_experts=global_num_experts,
Expand Down
Loading
Loading