77 changes: 49 additions & 28 deletions vllm/_aiter_ops.py
@@ -120,34 +120,46 @@ def _rocm_aiter_fused_moe_impl(
    intermediate_pad: int = 0,
    bias1: torch.Tensor | None = None,
    bias2: torch.Tensor | None = None,
) -> torch.Tensor:
    moe_buf: torch.Tensor | None = None,
) -> None:
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe

    activation = ActivationType(activation_method)
    quant_type = QuantType(quant_method)

    return fused_moe(
        hidden_states,
        w1,
        w2,
        topk_weight,
        topk_ids,
        expert_mask,
        activation,
        quant_type,
        doweight_stage1,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
        num_local_tokens=num_local_tokens,
        dtype=output_dtype,
        hidden_pad=hidden_pad,
        intermediate_pad=intermediate_pad,
        bias1=bias1,
        bias2=bias2,
    )
    try:
        from aiter.fused_moe import output_buffer_override
        ctx = output_buffer_override(moe_buf) if moe_buf is not None else None
    except ImportError:
        ctx = None
Contributor comment on lines +131 to +135 (severity: high):
Importing and checking for output_buffer_override inside the custom op implementation will incur significant Python overhead on every call if the user has an older version of aiter where this function is missing. Since this is in the hot path of MoE execution, this check should be cached at the module level or within rocm_aiter_ops to avoid repeated ImportError exceptions.


    from contextlib import nullcontext
Contributor comment (severity: high):
Importing nullcontext inside the function adds unnecessary overhead to the hot path. Please move this import to the top of the file.
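
A minimal sketch of how both points could be addressed together (assuming the probe can live at module scope in vllm/_aiter_ops.py; the helper names _get_output_buffer_override and _moe_output_ctx are hypothetical, not existing vllm functions):

```python
from contextlib import nullcontext

# Cache the one-time capability probe so older aiter builds do not pay a
# repeated ImportError on every MoE call.
# None = not probed yet, False = probed and unavailable, otherwise the callable.
_cached_output_buffer_override = None


def _get_output_buffer_override():
    global _cached_output_buffer_override
    if _cached_output_buffer_override is None:
        try:
            from aiter.fused_moe import output_buffer_override
            _cached_output_buffer_override = output_buffer_override
        except ImportError:
            _cached_output_buffer_override = False
    return _cached_output_buffer_override or None


def _moe_output_ctx(moe_buf):
    """Return a context manager that redirects aiter's fused_moe output into
    moe_buf when the installed aiter supports it, else a no-op context."""
    override = _get_output_buffer_override()
    if moe_buf is not None and override is not None:
        return override(moe_buf)
    return nullcontext()
```

With something like this, the custom op body would only need a `with _moe_output_ctx(moe_buf):` around the fused_moe call and no per-call imports.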

    with ctx if ctx is not None else nullcontext():
        result = fused_moe(
            hidden_states,
            w1,
            w2,
            topk_weight,
            topk_ids,
            expert_mask,
            activation,
            quant_type,
            doweight_stage1,
            w1_scale,
            w2_scale,
            a1_scale,
            a2_scale,
            num_local_tokens=num_local_tokens,
            dtype=output_dtype,
            hidden_pad=hidden_pad,
            intermediate_pad=intermediate_pad,
            bias1=bias1,
            bias2=bias2,
        )

    if moe_buf is not None and result is not moe_buf:
        moe_buf.copy_(result)


def _rocm_aiter_fused_moe_fake(
@@ -170,10 +182,9 @@ def _rocm_aiter_fused_moe_fake(
    intermediate_pad: int = 0,
    bias1: torch.Tensor | None = None,
    bias2: torch.Tensor | None = None,
) -> torch.Tensor:
    if output_dtype is not None:
        return torch.empty_like(hidden_states, dtype=output_dtype)
    return torch.empty_like(hidden_states)
    moe_buf: torch.Tensor | None = None,
) -> None:
    pass


def _rocm_aiter_asm_moe_tkw1_impl(
@@ -1372,7 +1383,7 @@ def register_ops_once() -> None:
        direct_register_custom_op(
            op_name="rocm_aiter_fused_moe",
            op_func=_rocm_aiter_fused_moe_impl,
            mutates_args=[],
            mutates_args=["moe_buf"],
            fake_impl=_rocm_aiter_fused_moe_fake,
            dispatch_key=current_platform.dispatch_key,
        )
@@ -1690,8 +1701,16 @@ def fused_moe(
        intermediate_pad: int = 0,
        bias1: torch.Tensor | None = None,
        bias2: torch.Tensor | None = None,
        moe_buf: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return torch.ops.vllm.rocm_aiter_fused_moe(
        if moe_buf is None:
            M = topk_ids.shape[0]
            model_dim = w2.shape[1]
            dtype = output_dtype if output_dtype is not None else hidden_states.dtype
            moe_buf = torch.empty(
                (M, model_dim), dtype=dtype, device=hidden_states.device
            )
        torch.ops.vllm.rocm_aiter_fused_moe(
            hidden_states,
            w1,
            w2,
@@ -1711,7 +1730,9 @@
            intermediate_pad,
            bias1,
            bias2,
            moe_buf,
        )
        return moe_buf

    @staticmethod
    def asm_moe_tkw1(
12 changes: 12 additions & 0 deletions vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -1215,6 +1215,7 @@ def _fused_experts(
        expert_map: torch.Tensor | None,
        apply_router_weight_on_input: bool,
        expert_tokens_meta: ExpertTokensMetadata | None,
        final_output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        _, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
            a1q, w1, w2, topk_ids
@@ -1243,6 +1244,16 @@
            activation,
        )

        # When expert workspaces are both empty (e.g. AiterExperts manages
        # its own buffers), write directly into the caller's output tensor
        # to avoid a redundant copy in the finalize step.
        if (
            final_output is not None
            and prod(workspace13.shape) == 0
            and prod(workspace2.shape) == 0
        ):
            fused_out = final_output
Contributor comment on lines +1250 to +1255 (severity: high):
This optimization assumes it is safe to write directly into final_output for any expert implementation that reports empty workspaces. While this holds for AiterExperts (which uses TopKWeightAndReduceNoOP), it might be unsafe for other modular experts if they don't handle the case where output aliases fused_expert_output. Additionally, _allocate_buffers was already called at line 1233 and might have reserved workspace memory for fused_out that is now bypassed. Consider making this optimization more explicit, or ensuring that _allocate_buffers is aware of the final_output override to avoid redundant workspace reservations.
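
One possible way to make the opt-in explicit rather than inferred, as a rough sketch (the supports_inplace_final_output attribute is hypothetical, not an existing vllm API):

```python
from math import prod


def _can_alias_final_output(final_output, workspace13, workspace2, experts) -> bool:
    """Alias the caller's output tensor only when the expert implementation
    explicitly opts in (hypothetical flag), instead of treating empty
    workspaces alone as proof that the aliasing is safe."""
    return (
        final_output is not None
        and getattr(experts, "supports_inplace_final_output", False)
        and prod(workspace13.shape) == 0
        and prod(workspace2.shape) == 0
    )
```

AiterExperts could then set the flag alongside its TopKWeightAndReduceNoOP usage, and _allocate_buffers could consult the same predicate so it does not reserve a fused_out workspace that is immediately bypassed.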


        self.fused_experts.apply(
            output=fused_out,
            hidden_states=a1q,
@@ -1403,6 +1414,7 @@ def apply(
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_tokens_meta=expert_tokens_meta,
            final_output=output,
        )

        return self._finalize(
@@ -195,6 +195,7 @@ def rocm_aiter_fused_experts(
    a1q_scale: torch.Tensor | None = None,
    num_local_tokens: torch.Tensor | None = None,
    output_dtype: torch.dtype | None = None,
    moe_buf: torch.Tensor | None = None,
) -> torch.Tensor:
    """ROCm AITER fused MoE expert computation."""
    if quant_config is None:
@@ -309,6 +310,7 @@ def rocm_aiter_fused_experts(
        intermediate_pad=intermediate_pad,
        bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
        bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
        moe_buf=moe_buf,
    )


@@ -436,5 +438,7 @@ def apply(
            a1q_scale=a1q_scale,
            num_local_tokens=num_local_tokens,
            output_dtype=output.dtype,
            moe_buf=output,
        )
        output.copy_(result)
        if result is not output:
            output.copy_(result)
@@ -62,6 +62,9 @@ def apply(
        if output is None:
            return fused_expert_output

        if output is fused_expert_output:
            return output

        # MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output`
        # tensor.
        assert output.size() == fused_expert_output.size(), (