# [ROCm] Eliminate redundant MoE buffer copies in AITER fused experts #41020
**File 1** — the AITER fused-MoE custom op: `_rocm_aiter_fused_moe_impl` gains an out-parameter `moe_buf` and returns `None` instead of allocating a fresh tensor.
```diff
@@ -120,34 +120,46 @@ def _rocm_aiter_fused_moe_impl(
     intermediate_pad: int = 0,
     bias1: torch.Tensor | None = None,
     bias2: torch.Tensor | None = None,
-) -> torch.Tensor:
+    moe_buf: torch.Tensor | None = None,
+) -> None:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe

     activation = ActivationType(activation_method)
     quant_type = QuantType(quant_method)

-    return fused_moe(
-        hidden_states,
-        w1,
-        w2,
-        topk_weight,
-        topk_ids,
-        expert_mask,
-        activation,
-        quant_type,
-        doweight_stage1,
-        w1_scale,
-        w2_scale,
-        a1_scale,
-        a2_scale,
-        num_local_tokens=num_local_tokens,
-        dtype=output_dtype,
-        hidden_pad=hidden_pad,
-        intermediate_pad=intermediate_pad,
-        bias1=bias1,
-        bias2=bias2,
-    )
+    try:
+        from aiter.fused_moe import output_buffer_override
+        ctx = output_buffer_override(moe_buf) if moe_buf is not None else None
+    except ImportError:
+        ctx = None
+
+    from contextlib import nullcontext
+
+    with ctx if ctx is not None else nullcontext():
+        result = fused_moe(
+            hidden_states,
+            w1,
+            w2,
+            topk_weight,
+            topk_ids,
+            expert_mask,
+            activation,
+            quant_type,
+            doweight_stage1,
+            w1_scale,
+            w2_scale,
+            a1_scale,
+            a2_scale,
+            num_local_tokens=num_local_tokens,
+            dtype=output_dtype,
+            hidden_pad=hidden_pad,
+            intermediate_pad=intermediate_pad,
+            bias1=bias1,
+            bias2=bias2,
+        )
+
+    if moe_buf is not None and result is not moe_buf:
+        moe_buf.copy_(result)
```
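If `output_buffer_override` is missing (older `aiter`) or the kernel still allocates its own output, the op preserves the out-parameter contract with a single trailing copy. A minimal sketch of that contract; `run_with_out_param` is a hypothetical stand-in for the op body:

```python
import torch

def run_with_out_param(compute, moe_buf: torch.Tensor) -> None:
    # Mirrors the tail of _rocm_aiter_fused_moe_impl: if the kernel could
    # not write in place, fall back to one copy into the caller's buffer.
    result = compute()
    if result is not moe_buf:
        moe_buf.copy_(result)

buf = torch.zeros(2, 3)
run_with_out_param(lambda: torch.ones(2, 3), buf)
assert buf.eq(1).all()  # the caller's buffer is filled either way
```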
```diff
@@ -170,10 +182,9 @@ def _rocm_aiter_fused_moe_fake(
     intermediate_pad: int = 0,
     bias1: torch.Tensor | None = None,
     bias2: torch.Tensor | None = None,
-) -> torch.Tensor:
-    if output_dtype is not None:
-        return torch.empty_like(hidden_states, dtype=output_dtype)
-    return torch.empty_like(hidden_states)
+    moe_buf: torch.Tensor | None = None,
+) -> None:
+    pass


 def _rocm_aiter_asm_moe_tkw1_impl(
```
```diff
@@ -1372,7 +1383,7 @@ def register_ops_once() -> None:
     direct_register_custom_op(
         op_name="rocm_aiter_fused_moe",
         op_func=_rocm_aiter_fused_moe_impl,
-        mutates_args=[],
+        mutates_args=["moe_buf"],
         fake_impl=_rocm_aiter_fused_moe_fake,
         dispatch_key=current_platform.dispatch_key,
     )
```
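Declaring `moe_buf` in `mutates_args` is load-bearing: the custom-op machinery treats undeclared ops as functional, and `torch.compile` may reorder or eliminate calls whose mutations it cannot see. A standalone sketch of the same pattern using plain `torch.library` (the `demo::scale_into` op is illustrative, not part of vLLM):

```python
import torch

# An out-parameter op must declare its mutation, exactly as the PR does
# with mutates_args=["moe_buf"]; otherwise the compiler may assume the
# call is pure and optimize it away.
@torch.library.custom_op("demo::scale_into", mutates_args=("out",))
def scale_into(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x * 2)

@scale_into.register_fake
def _(x: torch.Tensor, out: torch.Tensor) -> None:
    pass  # like _rocm_aiter_fused_moe_fake: nothing to allocate or return

x = torch.ones(4)
out = torch.empty(4)
scale_into(x, out)
assert out.eq(2).all()
```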
```diff
@@ -1690,8 +1701,16 @@ def fused_moe(
     intermediate_pad: int = 0,
     bias1: torch.Tensor | None = None,
     bias2: torch.Tensor | None = None,
+    moe_buf: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    return torch.ops.vllm.rocm_aiter_fused_moe(
+    if moe_buf is None:
+        M = topk_ids.shape[0]
+        model_dim = w2.shape[1]
+        dtype = output_dtype if output_dtype is not None else hidden_states.dtype
+        moe_buf = torch.empty(
+            (M, model_dim), dtype=dtype, device=hidden_states.device
+        )
+    torch.ops.vllm.rocm_aiter_fused_moe(
         hidden_states,
         w1,
         w2,
@@ -1711,7 +1730,9 @@ def fused_moe(
         intermediate_pad,
         bias1,
         bias2,
+        moe_buf,
     )
+    return moe_buf

 @staticmethod
 def asm_moe_tkw1(
```
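The wrapper's allocation mirrors the op's output shape: one row per routed token (`topk_ids.shape[0]`) by the model dimension (`w2.shape[1]`). Callers that already own a correctly shaped tensor can pass it as `moe_buf` and skip the allocation entirely; a usage sketch with placeholder shapes (the tensors and the commented call site are illustrative, not taken from the PR):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
M, model_dim = 8, 4096  # placeholder: 8 routed tokens, hidden size 4096

# Preallocate once (e.g. alongside other persistent decode buffers)...
moe_buf = torch.empty((M, model_dim), dtype=torch.bfloat16, device=device)

# ...then reuse it every step; the op writes in place and the wrapper
# returns the same tensor, so downstream code needs no changes:
# out = rocm_aiter_ops.fused_moe(hidden_states, w1, w2, topk_weight,
#                                topk_ids, moe_buf=moe_buf)
# assert out is moe_buf
```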
**File 2** — the modular-kernel `_fused_experts` path: thread the caller's output tensor down so the fused result can land in place.
```diff
@@ -1215,6 +1215,7 @@ def _fused_experts(
     expert_map: torch.Tensor | None,
     apply_router_weight_on_input: bool,
     expert_tokens_meta: ExpertTokensMetadata | None,
+    final_output: torch.Tensor | None = None,
 ) -> torch.Tensor:
     _, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
         a1q, w1, w2, topk_ids
     )
```
```diff
@@ -1243,6 +1244,16 @@ def _fused_experts(
         activation,
     )

+    # When expert workspaces are both empty (e.g. AiterExperts manages
+    # its own buffers), write directly into the caller's output tensor
+    # to avoid a redundant copy in the finalize step.
+    if (
+        final_output is not None
+        and prod(workspace13.shape) == 0
+        and prod(workspace2.shape) == 0
+    ):
+        fused_out = final_output
+
     self.fused_experts.apply(
         output=fused_out,
         hidden_states=a1q,
```

**Contributor comment on lines +1250 to +1255:**

> This optimization assumes that any expert implementation with empty workspaces is safe to write directly into `final_output`.
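One way to tighten that contract, per the comment above, would be an explicit opt-in on the expert implementation rather than inferring intent from workspace shapes. This is a sketch with a hypothetical `supports_inplace_final_output` property, not what the PR implements:

```python
class AiterExperts:
    # Hypothetical: AiterExperts manages its own intermediate buffers,
    # so it can declare outright that writing into the caller's output
    # tensor is safe, instead of signaling it via empty workspaces.
    @property
    def supports_inplace_final_output(self) -> bool:
        return True

# The check in _fused_experts would then read intent, not shapes:
# if final_output is not None and getattr(
#     self.fused_experts, "supports_inplace_final_output", False
# ):
#     fused_out = final_output
```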
```diff
@@ -1403,6 +1414,7 @@ def apply(
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input,
         expert_tokens_meta=expert_tokens_meta,
+        final_output=output,
     )

     return self._finalize(
```
**Contributor comment on the `output_buffer_override` import in `_rocm_aiter_fused_moe_impl`:**

> Importing and checking for `output_buffer_override` inside the custom op implementation will incur significant Python overhead on every call if the user has an older version of `aiter` where this function is missing. Since this is in the hot path of MoE execution, this check should be cached at the module level or within `rocm_aiter_ops` to avoid repeated `ImportError` exceptions.
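A sketch of the suggested fix, resolving the optional API once at import time (the `_moe_out_ctx` helper name is an assumption, not part of the PR):

```python
from contextlib import nullcontext

# Resolve the optional aiter API once, not on every op call.
try:
    from aiter.fused_moe import output_buffer_override
except ImportError:  # older aiter without the override API
    output_buffer_override = None

def _moe_out_ctx(moe_buf):
    """Return the override context when available; otherwise a no-op."""
    if output_buffer_override is not None and moe_buf is not None:
        return output_buffer_override(moe_buf)
    return nullcontext()

# In _rocm_aiter_fused_moe_impl the hot path then shrinks to:
#     with _moe_out_ctx(moe_buf):
#         result = fused_moe(...)
```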