aiter/fused_moe.py (13 changes: 12 additions & 1 deletion)

@@ -35,6 +35,7 @@ def _moe_sorting_impl(
     num_local_tokens,
     dispatch_policy,
     use_opus,
+    moe_buf=None,
 ):
     device = topk_ids.device
     M, topk = topk_ids.shape
@@ -47,7 +48,8 @@ def _moe_sorting_impl(
     )
     sorted_expert_ids = torch.empty(max_num_m_blocks, dtype=dtypes.i32, device=device)
     num_valid_ids = torch.empty(2, dtype=dtypes.i32, device=device)
-    moe_buf = torch.empty((M, model_dim), dtype=moebuf_dtype, device=device)
+    if moe_buf is None:
+        moe_buf = torch.empty((M, model_dim), dtype=moebuf_dtype, device=device)

     fwd_fn = aiter.moe_sorting_opus_fwd if use_opus else aiter.moe_sorting_fwd
     fwd_fn(
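Note: the hunk above is the core change. _moe_sorting_impl now follows the optional out-buffer idiom, allocating moe_buf only when the caller did not supply one. As written, the diff trusts the caller-provided buffer; a minimal sketch of a defensive variant (the helper name and the checks are illustrative, not part of this PR):

    import torch

    def _get_or_alloc_moe_buf(moe_buf, M, model_dim, dtype, device):
        # Allocate a fresh output buffer only when none was provided.
        if moe_buf is None:
            return torch.empty((M, model_dim), dtype=dtype, device=device)
        # Checks a caller-supplied buffer would need to pass; the diff
        # itself performs no validation and trusts the caller.
        assert moe_buf.shape == (M, model_dim), "moe_buf shape mismatch"
        assert moe_buf.dtype == dtype and moe_buf.device == device
        return moe_buf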
@@ -77,6 +79,7 @@ def moe_sorting(
     expert_mask=None,
     num_local_tokens=None,
     dispatch_policy=0,
+    moe_buf=None,
 ):
     try:
         return _moe_sorting_impl(
@@ -90,6 +93,7 @@ def moe_sorting(
             num_local_tokens,
             dispatch_policy,
             use_opus=_USE_OPUS_MOE_SORTING,
+            moe_buf=moe_buf,
         )
     except Exception as e:
         logger.error(f"Error in moe_sorting: {e}")
@@ -141,6 +145,7 @@ def fused_moe(
     bias1=None,
     bias2=None,
     splitk=0,
+    moe_buf=None,
 ):
     if not block_size_M:
         block_size_M = -1
@@ -166,6 +171,7 @@ def fused_moe(
         intermediate_pad=intermediate_pad,
         bias1=bias1,
         bias2=bias2,
+        moe_buf=moe_buf,
     )


@@ -193,7 +199,10 @@ def fused_moe_fake(
     intermediate_pad: int = 0,
     bias1: Optional[torch.Tensor] = None,
     bias2: Optional[torch.Tensor] = None,
+    moe_buf: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
+    if moe_buf is not None:
+        return moe_buf
     device = topk_ids.device
     M, topk = topk_ids.shape
     dtype = hidden_states.dtype if dtype is None else dtype
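Note: fused_moe_fake appears to be the fake (meta) implementation registered for the custom op, so it must produce an output tensor with the right shape, dtype, and device without running the real kernel. A caller-supplied moe_buf already carries exactly that metadata, which is why the hunk above can return it directly. A standalone sketch of that contract (the function name and arguments are illustrative, not aiter's API):

    import torch
    from typing import Optional

    def some_op_fake(hidden_states: torch.Tensor,
                     moe_buf: Optional[torch.Tensor] = None) -> torch.Tensor:
        if moe_buf is not None:
            # Pre-allocated output: its metadata is already correct.
            return moe_buf
        # Otherwise fabricate matching metadata without doing real work.
        return torch.empty_like(hidden_states)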
@@ -227,6 +236,7 @@ def fused_moe_(
     intermediate_pad: int = 0,
     bias1: Optional[torch.Tensor] = None,
     bias2: Optional[torch.Tensor] = None,
+    moe_buf: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     # We do such convert since custom_op schema restriction on block_size_M, and Enum type
     activation = ActivationType(activation)
@@ -305,6 +315,7 @@ def fused_moe_(
         expert_mask,
         num_local_tokens,
         moe_sorting_dispatch_policy,
+        moe_buf=moe_buf,
     )

     if metadata.run_1stage:
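Taken together, the aiter/fused_moe.py changes thread an optional pre-allocated output buffer from fused_moe down through fused_moe_ and moe_sorting, so a serving loop can allocate once and reuse the buffer every step. A hedged usage sketch (the shapes, weight tensors, and routing inputs are placeholders; only the moe_buf keyword comes from this diff, and the positional arguments are abbreviated):

    import torch
    from aiter.fused_moe import fused_moe

    # Hypothetical setup: M tokens, hidden size model_dim; w1/w2 are the
    # expert weights and topk_weights/topk_ids come from the router.
    persistent_buf = torch.empty((M, model_dim), dtype=torch.bfloat16, device="cuda")

    for _ in range(num_decode_steps):
        out = fused_moe(
            hidden_states, w1, w2, topk_weights, topk_ids,
            moe_buf=persistent_buf,  # reused each step instead of a fresh allocation
        )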
op_tests/test_moe_sorting.py (54 changes: 54 additions & 0 deletions)

@@ -173,6 +173,60 @@ def test_moe_sorting(
         sorted_expert_ids_b[expert_mask],
         msg="sorted_expert_ids",
     )
+
+    # Verify moe_buf pass-through: pre-allocated buffer should be reused
+    pre_buf = torch.empty((token, model_dim), dtype=dtype, device="cuda")
+    pre_buf_ptr = pre_buf.data_ptr()
+    (
+        (
+            sorted_ids_c,
+            sorted_weights_c,
+            sorted_expert_ids_c,
+            num_tokens_post_padded_c,
+            moe_buf_c,
+        ),
+        _,
+    ) = run_perftest(
+        moe_sorting,
+        topk_ids,
+        topk_weights,
+        E,
+        model_dim,
+        dtype,
+        BLOCK_SIZE_M,
+        expert_mask,
+        num_local_tokens,
+        dispatch_policy,
+        moe_buf=pre_buf,
+        num_warmup=1,
+        num_iters=2,
+    )
+    assert (
+        moe_buf_c.data_ptr() == pre_buf_ptr
+    ), "moe_buf pass-through: buffer not reused"
+    checkAllclose(
+        num_tokens_post_padded_a,
+        num_tokens_post_padded_c,
+        atol=0,
+        msg="moe_buf pass-through: num_tokens_post_padded",
+    )
+    checkAllclose(
+        sorted_ids_a[:num_tokens_post_pad],
+        sorted_ids_c[:num_tokens_post_pad],
+        atol=0,
+        msg="moe_buf pass-through: sorted_ids",
+    )
+    checkAllclose(
+        sorted_weights_a[mask],
+        sorted_weights_c[mask],
+        msg="moe_buf pass-through: sorted_weights",
+    )
+    checkAllclose(
+        sorted_expert_ids_a[expert_mask],
+        sorted_expert_ids_c[expert_mask],
+        msg="moe_buf pass-through: sorted_expert_ids",
+    )
+
     return {"us": avg_b}
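The assert in the new test leans on torch.Tensor.data_ptr(), which returns the address of a tensor's first element; pointer equality between the returned moe_buf and the pre-allocated buffer proves the op wrote into the caller's storage instead of allocating a new one. A self-contained illustration of the same identity check (buffer size is arbitrary):

    import torch

    buf = torch.empty(8)
    ptr = buf.data_ptr()
    out = buf  # stands in for an op that returns the caller-provided buffer
    assert out.data_ptr() == ptr, "buffer was not reused"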

