Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,7 @@ struct MoeSortingMultiPhaseKernel_P0
void* p_expert_mesh; // [expert, tokens]
index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
// used for ws/LDS calculation
index_t num_experts;
index_t mesh_stride; // mesh_stride for p_expert_mesh
mdiv topk_mdiv;
};
Expand All @@ -1597,6 +1598,7 @@ struct MoeSortingMultiPhaseKernel_P0
k.p_local_tokens = h.p_local_tokens;
k.p_expert_mesh = h.p_ws;
k.tokens = h.tokens;
k.num_experts = h.num_experts;
k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
return k;
Expand Down Expand Up @@ -1655,14 +1657,18 @@ struct MoeSortingMultiPhaseKernel_P0
IndexType eid = x[j.value]; // ext_vector_type must use int to []
uint32_t curr_token_id, curr_topk_id;
kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
if constexpr(Problem::LocalToken)
if(eid < kargs.num_experts)
{
if(static_cast<index_t>(curr_token_id) < tokens)
if constexpr(Problem::LocalToken)
{
if(static_cast<index_t>(curr_token_id) < tokens)
p_expert_mesh[eid * mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
}
else
p_expert_mesh[eid * mesh_stride + curr_token_id] =
(curr_topk_id + 1) & 0xffff;
}
else
p_expert_mesh[eid * mesh_stride + curr_token_id] = (curr_topk_id + 1) & 0xffff;
});
}
}
Expand Down
Loading