From bc497593e4ba2982f4b3ffd8b830c87050c497b5 Mon Sep 17 00:00:00 2001
From: Tres Popp
Date: Thu, 9 Apr 2026 02:14:43 -0500
Subject: [PATCH 1/2] Allow callers to pass pre-allocated moe_buf to avoid
 output copy

Add an optional `moe_buf` parameter through the moe_sorting and
fused_moe call chain. When provided, the sorting kernel writes directly
into the caller's buffer instead of allocating a new one, eliminating a
redundant copy on the output path.
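
A minimal usage sketch (argument order follows the call exercised in
op_tests/test_moe_sorting.py; M, topk_ids, topk_weights, E, model_dim,
dtype, and block_size_M are assumed to be set up by the caller):

    # Caller owns the (M, model_dim) output buffer.
    out = torch.empty((M, model_dim), dtype=dtype, device="cuda")
    *_, moe_buf = moe_sorting(
        topk_ids, topk_weights, E, model_dim, dtype, block_size_M,
        moe_buf=out,
    )
    # The returned moe_buf is the caller's tensor, not a new allocation.
    assert moe_buf.data_ptr() == out.data_ptr()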

Made-with: Cursor
---
 aiter/fused_moe.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py
index d8fca4b4c9..de92f03e38 100755
--- a/aiter/fused_moe.py
+++ b/aiter/fused_moe.py
@@ -35,6 +35,7 @@ def _moe_sorting_impl(
     num_local_tokens,
     dispatch_policy,
     use_opus,
+    moe_buf=None,
 ):
     device = topk_ids.device
     M, topk = topk_ids.shape
@@ -47,7 +48,8 @@
     )
     sorted_expert_ids = torch.empty(max_num_m_blocks, dtype=dtypes.i32, device=device)
     num_valid_ids = torch.empty(2, dtype=dtypes.i32, device=device)
-    moe_buf = torch.empty((M, model_dim), dtype=moebuf_dtype, device=device)
+    if moe_buf is None:
+        moe_buf = torch.empty((M, model_dim), dtype=moebuf_dtype, device=device)
 
     fwd_fn = aiter.moe_sorting_opus_fwd if use_opus else aiter.moe_sorting_fwd
     fwd_fn(
@@ -77,6 +79,7 @@
     expert_mask=None,
     num_local_tokens=None,
     dispatch_policy=0,
+    moe_buf=None,
 ):
     try:
         return _moe_sorting_impl(
@@ -90,6 +93,7 @@
             num_local_tokens,
             dispatch_policy,
             use_opus=_USE_OPUS_MOE_SORTING,
+            moe_buf=moe_buf,
         )
     except Exception as e:
         logger.error(f"Error in moe_sorting: {e}")
@@ -141,6 +145,7 @@
     bias1=None,
     bias2=None,
     splitk=0,
+    moe_buf=None,
 ):
     if not block_size_M:
         block_size_M = -1
@@ -166,6 +171,7 @@
         intermediate_pad=intermediate_pad,
         bias1=bias1,
         bias2=bias2,
+        moe_buf=moe_buf,
     )
 
 
@@ -193,7 +199,10 @@
     intermediate_pad: int = 0,
     bias1: Optional[torch.Tensor] = None,
     bias2: Optional[torch.Tensor] = None,
+    moe_buf: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+    if moe_buf is not None:
+        return moe_buf
     device = topk_ids.device
     M, topk = topk_ids.shape
     dtype = hidden_states.dtype if dtype is None else dtype
@@ -227,6 +236,7 @@
     intermediate_pad: int = 0,
     bias1: Optional[torch.Tensor] = None,
     bias2: Optional[torch.Tensor] = None,
+    moe_buf: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     # We do such convert since custom_op schema restriction on block_size_M, and Enum type
     activation = ActivationType(activation)
@@ -305,6 +315,7 @@
         expert_mask,
         num_local_tokens,
         moe_sorting_dispatch_policy,
+        moe_buf=moe_buf,
     )
 
     if metadata.run_1stage:

From 60a459c3d1c6046f4f1db0bfb161e1c5d6147dd2 Mon Sep 17 00:00:00 2001
From: Tres Popp
Date: Fri, 10 Apr 2026 04:04:52 -0500
Subject: [PATCH 2/2] Add moe_buf pass-through test to existing
 test_moe_sorting

Made-with: Cursor
---
 op_tests/test_moe_sorting.py | 54 ++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)

diff --git a/op_tests/test_moe_sorting.py b/op_tests/test_moe_sorting.py
index 38b8b9e1cc..3c754f4e73 100644
--- a/op_tests/test_moe_sorting.py
+++ b/op_tests/test_moe_sorting.py
@@ -173,6 +173,60 @@ def test_moe_sorting(
         sorted_expert_ids_b[expert_mask],
         msg="sorted_expert_ids",
     )
+
+    # Verify moe_buf pass-through: pre-allocated buffer should be reused
+    pre_buf = torch.empty((token, model_dim), dtype=dtype, device="cuda")
+    pre_buf_ptr = pre_buf.data_ptr()
+    (
+        (
+            sorted_ids_c,
+            sorted_weights_c,
+            sorted_expert_ids_c,
+            num_tokens_post_padded_c,
+            moe_buf_c,
+        ),
+        _,
+    ) = run_perftest(
+        moe_sorting,
+        topk_ids,
+        topk_weights,
+        E,
+        model_dim,
+        dtype,
+        BLOCK_SIZE_M,
+        expert_mask,
+        num_local_tokens,
+        dispatch_policy,
+        moe_buf=pre_buf,
+        num_warmup=1,
+        num_iters=2,
+    )
+
+    assert (
+        moe_buf_c.data_ptr() == pre_buf_ptr
+    ), "moe_buf pass-through: buffer not reused"
+    checkAllclose(
+        num_tokens_post_padded_a,
+        num_tokens_post_padded_c,
+        atol=0,
+        msg="moe_buf pass-through: num_tokens_post_padded",
+    )
+    checkAllclose(
+        sorted_ids_a[:num_tokens_post_pad],
+        sorted_ids_c[:num_tokens_post_pad],
+        atol=0,
+        msg="moe_buf pass-through: sorted_ids",
+    )
+    checkAllclose(
+        sorted_weights_a[mask],
+        sorted_weights_c[mask],
+        msg="moe_buf pass-through: sorted_weights",
+    )
+    checkAllclose(
+        sorted_expert_ids_a[expert_mask],
+        sorted_expert_ids_c[expert_mask],
+        msg="moe_buf pass-through: sorted_expert_ids",
+    )
     return {"us": avg_b}
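
To exercise the new pass-through check locally, the test file can be
run as a script (assuming it keeps the repo's usual run-as-script
harness for op_tests):

    python op_tests/test_moe_sorting.py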