diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index 0250fbfac709..505b6f52dfde 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -120,34 +120,46 @@ def _rocm_aiter_fused_moe_impl(
     intermediate_pad: int = 0,
     bias1: torch.Tensor | None = None,
     bias2: torch.Tensor | None = None,
-) -> torch.Tensor:
+    moe_buf: torch.Tensor | None = None,
+) -> None:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe

     activation = ActivationType(activation_method)
     quant_type = QuantType(quant_method)

-    return fused_moe(
-        hidden_states,
-        w1,
-        w2,
-        topk_weight,
-        topk_ids,
-        expert_mask,
-        activation,
-        quant_type,
-        doweight_stage1,
-        w1_scale,
-        w2_scale,
-        a1_scale,
-        a2_scale,
-        num_local_tokens=num_local_tokens,
-        dtype=output_dtype,
-        hidden_pad=hidden_pad,
-        intermediate_pad=intermediate_pad,
-        bias1=bias1,
-        bias2=bias2,
-    )
+    try:
+        from aiter.fused_moe import output_buffer_override
+        ctx = output_buffer_override(moe_buf) if moe_buf is not None else None
+    except ImportError:
+        ctx = None
+
+    from contextlib import nullcontext
+    with ctx if ctx is not None else nullcontext():
+        result = fused_moe(
+            hidden_states,
+            w1,
+            w2,
+            topk_weight,
+            topk_ids,
+            expert_mask,
+            activation,
+            quant_type,
+            doweight_stage1,
+            w1_scale,
+            w2_scale,
+            a1_scale,
+            a2_scale,
+            num_local_tokens=num_local_tokens,
+            dtype=output_dtype,
+            hidden_pad=hidden_pad,
+            intermediate_pad=intermediate_pad,
+            bias1=bias1,
+            bias2=bias2,
+        )
+
+    if moe_buf is not None and result is not moe_buf:
+        moe_buf.copy_(result)


 def _rocm_aiter_fused_moe_fake(
@@ -170,10 +182,9 @@ def _rocm_aiter_fused_moe_fake(
     intermediate_pad: int = 0,
     bias1: torch.Tensor | None = None,
     bias2: torch.Tensor | None = None,
-) -> torch.Tensor:
-    if output_dtype is not None:
-        return torch.empty_like(hidden_states, dtype=output_dtype)
-    return torch.empty_like(hidden_states)
+    moe_buf: torch.Tensor | None = None,
+) -> None:
+    pass


 def _rocm_aiter_asm_moe_tkw1_impl(
@@ -1372,7 +1383,7 @@ def register_ops_once() -> None:
         direct_register_custom_op(
             op_name="rocm_aiter_fused_moe",
             op_func=_rocm_aiter_fused_moe_impl,
-            mutates_args=[],
+            mutates_args=["moe_buf"],
             fake_impl=_rocm_aiter_fused_moe_fake,
             dispatch_key=current_platform.dispatch_key,
         )
@@ -1690,8 +1701,16 @@ def fused_moe(
         intermediate_pad: int = 0,
         bias1: torch.Tensor | None = None,
         bias2: torch.Tensor | None = None,
+        moe_buf: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        return torch.ops.vllm.rocm_aiter_fused_moe(
+        if moe_buf is None:
+            M = topk_ids.shape[0]
+            model_dim = w2.shape[1]
+            dtype = output_dtype if output_dtype is not None else hidden_states.dtype
+            moe_buf = torch.empty(
+                (M, model_dim), dtype=dtype, device=hidden_states.device
+            )
+        torch.ops.vllm.rocm_aiter_fused_moe(
             hidden_states,
             w1,
             w2,
@@ -1711,7 +1730,9 @@
             intermediate_pad,
             bias1,
             bias2,
+            moe_buf,
         )
+        return moe_buf

     @staticmethod
     def asm_moe_tkw1(
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index b0f967085ae4..9a245bf8ae5d 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -1215,6 +1215,7 @@ def _fused_experts(
         expert_map: torch.Tensor | None,
         apply_router_weight_on_input: bool,
         expert_tokens_meta: ExpertTokensMetadata | None,
+        final_output: torch.Tensor | None = None,
     ) -> torch.Tensor:
         _, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
             a1q, w1, w2, topk_ids
         )
@@ -1243,6 +1244,16 @@ def _fused_experts(
             activation,
         )

+        # When expert workspaces are both empty (e.g. AiterExperts manages
+        # its own buffers), write directly into the caller's output tensor
+        # to avoid a redundant copy in the finalize step.
+        if (
+            final_output is not None
+            and prod(workspace13.shape) == 0
+            and prod(workspace2.shape) == 0
+        ):
+            fused_out = final_output
+
         self.fused_experts.apply(
             output=fused_out,
             hidden_states=a1q,
@@ -1403,6 +1414,7 @@ def apply(
             expert_map=expert_map,
             apply_router_weight_on_input=apply_router_weight_on_input,
             expert_tokens_meta=expert_tokens_meta,
+            final_output=output,
         )

         return self._finalize(
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 495b9daaff45..639a61230405 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -195,6 +195,7 @@ def rocm_aiter_fused_experts(
     a1q_scale: torch.Tensor | None = None,
     num_local_tokens: torch.Tensor | None = None,
     output_dtype: torch.dtype | None = None,
+    moe_buf: torch.Tensor | None = None,
 ) -> torch.Tensor:
     """ROCm AITER fused MoE expert computation."""
     if quant_config is None:
@@ -309,6 +310,7 @@
         intermediate_pad=intermediate_pad,
         bias1=quant_config.w1_bias if quant_config.use_mxfp4_w4a16 else None,
         bias2=quant_config.w2_bias if quant_config.use_mxfp4_w4a16 else None,
+        moe_buf=moe_buf,
     )


@@ -436,5 +438,7 @@ def apply(
             a1q_scale=a1q_scale,
             num_local_tokens=num_local_tokens,
             output_dtype=output.dtype,
+            moe_buf=output,
         )
-        output.copy_(result)
+        if result is not output:
+            output.copy_(result)
diff --git a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
index 4cebe608a6b4..b73b86411f87 100644
--- a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
+++ b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
@@ -62,6 +62,9 @@ def apply(
         if output is None:
             return fused_expert_output

+        if output is fused_expert_output:
+            return output
+
         # MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output`
         # tensor.
         assert output.size() == fused_expert_output.size(), (
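For reviewers tracing the buffer plumbing above: the pattern throughout the diff is that the caller pre-allocates the destination tensor (`moe_buf` / `final_output`), the kernel writes into it when it can, and every downstream `copy_` is guarded by an `is` identity check so the defensive copy only runs when the kernel fell back to its own allocation. A minimal sketch of that contract in plain PyTorch, where `moe_kernel` is a hypothetical stand-in for the AITER `fused_moe` call, not the real API:

```python
import torch


def moe_kernel(x: torch.Tensor, out: torch.Tensor | None = None) -> torch.Tensor:
    # Honor the caller's buffer when one is supplied; otherwise allocate,
    # mirroring the moe_buf handling in _rocm_aiter_fused_moe_impl.
    if out is None:
        out = torch.empty_like(x)
    torch.mul(x, 2.0, out=out)  # in-place write into `out`
    return out


x = torch.randn(16, 128)
output = torch.empty_like(x)  # caller-owned destination: the `moe_buf` role
result = moe_kernel(x, out=output)

# Same guard as AiterExperts.apply() above: copy only if the kernel could
# not write into the caller's buffer and returned its own allocation.
if result is not output:
    output.copy_(result)
```

The `output is fused_expert_output` early return added in topk_weight_and_reduce.py is the same identity test applied one level up: when finalize is handed the very tensor the experts already wrote, there is nothing left to move.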