diff --git a/CMakeLists.txt b/CMakeLists.txt index c4a7a2c31..070862634 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -216,6 +216,7 @@ define_gpu_extension_target( set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" + "csrc/moe/grouped_topk.cpp" "csrc/moe/moe_align_sum_kernels.cpp") message(STATUS "Enabling moe extension.") diff --git a/benchmark/benchmark_grouped_topk.py b/benchmark/benchmark_grouped_topk.py new file mode 100644 index 000000000..7d7a071e2 --- /dev/null +++ b/benchmark/benchmark_grouped_topk.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import itertools +from argparse import ArgumentParser +from typing import Optional + +import torch +import triton + +from tests.ops.grouped_topk import grouped_topk, grouped_topk_native + +dpcpp_device = torch.device("xpu") + + +@torch.compile +def grouped_topk_compile( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + if scoring_func == "softmax": + gating_output = gating_output.to(torch.float32) + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + e_score_correction_bias = e_score_correction_bias.to(torch.float32) + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = (scores.view(num_token, num_expert_group, + -1).max(dim=-1).values) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=True)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = (group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1)) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=True)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=True) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +num_tokens_range = [1, 2, 64, 256] +num_experts_range = [64, 128, 256] +topk_range = [4, 6, 8] +renormalize_range = [True, False] +num_expert_group_range = [8] +topk_group_range = [4, 6, 8] +scoring_func_range = ["sigmoid", "softmax"] +has_bias_range = [True, False] +configs = list( + itertools.product( + num_tokens_range, + num_experts_range, + topk_range, + renormalize_range, + num_expert_group_range, + topk_group_range, + scoring_func_range, + has_bias_range, + )) + + +def get_benchmark(): + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=[ + "num_tokens", "num_experts", "topk", "renormalize", + "num_expert_group", "topk_group", "scoring_func", "has_bias" + ], + x_vals=[tuple(_) for _ in configs], + line_arg="provider", + line_vals=["vllm", "native", "compile"], + line_names=["vllm", "native", "compile"], + styles=[("blue", "-"), ("green", "-"), ("orange", "-"), + ("red", "-")], + ylabel="us", + plot_name="grouped_topk-perf", + args={}, + )) + def benchmark( + num_tokens: int, + num_experts: int, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + has_bias: bool = False, + provider: str = "vllm", + ): + dtype = torch.float16 + torch.set_default_device("xpu") + + gating_output = torch.randn(num_tokens, + num_experts, + device=dpcpp_device).to(dtype) + hidden_states = torch.zeros(num_tokens, + num_experts, + device=dpcpp_device).to(dtype) + + bias = None + if has_bias: + if has_bias and scoring_func == "sigmoid" \ + and dtype is not torch.float32: + # using a bias of bigger number to avoid Low-precision + bias = torch.arange(1, + num_experts + 1).to(dpcpp_device).to(dtype) + else: + bias = torch.randn(num_experts, device=dpcpp_device).to(dtype) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "vllm": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: grouped_topk( + hidden_states, + gating_output, + topk, + renormalize, + num_expert_group, + topk_group, + scoring_func=scoring_func, + e_score_correction_bias=bias, + ), + quantiles=quantiles, + ) + elif provider == "native": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: grouped_topk_native( + hidden_states, + gating_output, + topk, + renormalize, + num_expert_group, + topk_group, + scoring_func=scoring_func, + e_score_correction_bias=bias, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: grouped_topk_compile( + hidden_states, + gating_output, + topk, + renormalize, + num_expert_group, + topk_group, + scoring_func=scoring_func, + e_score_correction_bias=bias, + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + parser = ArgumentParser(description="Benchmark the grouped topk kernel.") + parser.add_argument( + "--save-path", + type=str, + default="./configs/grouped_topk/", + help="Path to save grouped_topk benchmark results", + ) + + args = parser.parse_args() + + benchmark = get_benchmark() + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/csrc/moe/grouped_topk.cpp b/csrc/moe/grouped_topk.cpp new file mode 100644 index 000000000..f9c4abd02 --- /dev/null +++ b/csrc/moe/grouped_topk.cpp @@ -0,0 +1,434 @@ +#include + +#include "../utils.h" +#include "../dispatch_utils.h" + +namespace vllm { +namespace GroupedTopKImpl { + +enum class ScoringFunc { + DEFAULT = 0, + SOFTMAX = 1, + SIGMOID = 2, +}; + +template +struct Fused_Grouped_Topk { + static constexpr int sub_group_size = 32; + static constexpr int max_group_size = 1024; + static constexpr int malloc_per_item = MAX_EXPERT_GROUPS; + static constexpr float kNegInfinity = INFINITY * -1; + + Fused_Grouped_Topk(float* topk_weights, int* topk_ids, const T* gating_output, + const T* e_score_correction_bias, + const ScoringFunc scoring_mode, const bool renormalize, + const int tokens, const int experts, const int top_k, + const int num_expert_group, const int topk_group) + : topk_weights(topk_weights), + topk_ids(topk_ids), + gating_output(gating_output), + e_score_correction_bias(e_score_correction_bias), + scoring_mode(scoring_mode), + renormalize(renormalize), + tokens(tokens), + experts(experts), + top_k(top_k), + num_expert_group(num_expert_group), + topk_group(topk_group) {} + + static inline sycl::nd_range<3> get_nd_range(const int tokens, + const int experts) { + int calc_per_item = (experts + sub_group_size - 1) / sub_group_size; + int group_size = (experts + calc_per_item - 1) / calc_per_item; + group_size = group_size < sub_group_size ? sub_group_size : group_size; + group_size = group_size < max_group_size ? group_size : max_group_size; + int sub_groups_per_group = + (group_size + sub_group_size - 1) / sub_group_size; + group_size = sub_groups_per_group * sub_group_size; + int global_size = + (tokens + sub_groups_per_group - 1) / sub_groups_per_group; + + sycl::range<3> local(1, 1, group_size); + sycl::range<3> global(1, 1, global_size); + return sycl::nd_range<3>(global * local, local); + } + + static inline float Sigmoid(float x) { + return 1.0f / (1.0f + sycl::native::exp(-x)); + } + + [[sycl::reqd_sub_group_size(sub_group_size)]] void operator()( + sycl::nd_item<3> item) const { + int group_id = item.get_group_linear_id(); + int local_range = item.get_local_range(2); + int sub_groups_per_group = local_range / sub_group_size; + int calc_per_item = (experts + sub_group_size - 1) / sub_group_size; + + int experts_per_group = experts / num_expert_group; + + sycl::sub_group sg = item.get_sub_group(); + int sg_id = sg.get_group_id(); + int sg_local_id = sg.get_local_id(); + + int tid = group_id * sub_groups_per_group + sg_id; + + if (tid >= tokens) { + return; // Out of bounds + } + + T load_elems[malloc_per_item]; + int local_idx[malloc_per_item]; + T bias[malloc_per_item]; + + int start_offset = sg_local_id * calc_per_item; + int local_num = calc_per_item; + + if (start_offset + local_num >= experts) { + local_num = experts - start_offset; + if (local_num < 0) { + local_num = 0; // No elements to process + } + } + + for (int e = 0; e < calc_per_item; ++e) { + load_elems[e] = kNegInfinity; + local_idx[e] = -1; + bias[e] = 0.0f; // Initialize bias to zero + } + + for (int e = 0; e < local_num; ++e) { + load_elems[e] = gating_output[tid * experts + start_offset + e]; + } + + float local_elems[malloc_per_item]; + + for (int e = 0; e < local_num; ++e) { + local_elems[e] = load_elems[e]; + local_idx[e] = start_offset + e; + } + + if (scoring_mode == ScoringFunc::SOFTMAX) { + float softmax_max = kNegInfinity; + for (int e = 0; e < local_num; ++e) { + softmax_max = + (softmax_max > local_elems[e]) ? softmax_max : local_elems[e]; + } + for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { + float other_val = sycl::permute_group_by_xor(sg, softmax_max, offset); + softmax_max = (softmax_max > other_val) ? softmax_max : other_val; + } + float softmax_sum = 0.0f; + for (int e = 0; e < local_num; ++e) { + float s = local_elems[e]; + softmax_sum += sycl::native::exp(s - softmax_max); + } + for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { + float other_val = sycl::permute_group_by_xor(sg, softmax_sum, offset); + softmax_sum += other_val; + } + for (int e = 0; e < local_num; ++e) { + float s = local_elems[e]; + local_elems[e] = sycl::native::exp(s - softmax_max) / softmax_sum; + } + } else if (scoring_mode == ScoringFunc::SIGMOID) { + for (int e = 0; e < local_num; ++e) { + float s = load_elems[e]; + load_elems[e] = Sigmoid(s); + } + for (int e = 0; e < local_num; ++e) { + local_elems[e] = load_elems[e]; + } + } + + bool has_bias = e_score_correction_bias != nullptr; + if (has_bias) { + for (int e = 0; e < local_num; ++e) { + bias[e] = e_score_correction_bias[start_offset + e]; + } + } + + // perform topk_group groups + // 1 calculate each group scores + float group_scores[malloc_per_item * 2]; + for (int i = 0; i < num_expert_group * 2; ++i) { + group_scores[i] = kNegInfinity; + } + for (int i = 0; i < local_num; ++i) { + float b = bias[i]; + float score = local_elems[i] + b; + int i_group = (calc_per_item * sg_local_id + i) / experts_per_group; + float group_max = group_scores[i_group]; + float group_next_max = group_scores[num_expert_group + i_group]; + if (score > group_max) { + group_next_max = group_max; + group_max = score; + } else if (score > group_next_max) { + group_next_max = score; + } + group_scores[i_group] = group_max; + group_scores[num_expert_group + i_group] = group_next_max; + } + for (int i = 0; i < num_expert_group; ++i) { + float group_max = group_scores[i]; + float group_next_max = group_scores[num_expert_group + i]; + + float max1 = sycl::reduce_over_group( + sg, sycl::max(group_max, group_next_max), sycl::maximum<>()); + float local_second = + (group_max < max1 && group_max > -INFINITY) ? group_max : -INFINITY; + local_second = (group_next_max < max1 && group_next_max > local_second) + ? group_next_max + : local_second; + float max2 = sycl::reduce_over_group(sg, local_second, sycl::maximum<>()); + group_scores[i] = max1 + (has_bias ? max2 : 0.0f); + } + + // 2 find topk_group groups as kNegInfinity + int group_topk_idx[malloc_per_item]; + for (int k = 0; k < topk_group; ++k) { + float k_max = group_scores[0]; + int k_max_idx = 0; + for (int e = 1; e < num_expert_group; ++e) { + float score = group_scores[e]; + + if (score > k_max) { + k_max = score; + k_max_idx = e; + } + } + group_scores[k_max_idx] = kNegInfinity; + group_topk_idx[k] = k_max_idx; + } + + // 3 mask no-topk_group groups + for (int i = 0; i < calc_per_item; ++i) { + bool is_masked = true; + for (int k = 0; k < topk_group; ++k) { + if ((local_idx[i] / experts_per_group) == group_topk_idx[k]) { + is_masked = false; + break; + } + } + if (is_masked) { + local_elems[i] = kNegInfinity; + } + } + + // Perform top-k selection + float topk_weights_local[malloc_per_item]; + int topk_ids_local[malloc_per_item]; + + for (int k = 0; k < top_k; ++k) { + float k_max = kNegInfinity; + int k_max_idx = -1; + int remove_ix = -1; + for (int e = 0; e < calc_per_item; ++e) { + float le = local_elems[e]; + float b = bias[e]; + float my_val = le + b; + int my_idx = local_idx[e]; + for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { + float other_val = sycl::permute_group_by_xor(sg, my_val, offset); + int other_idx = sycl::permute_group_by_xor(sg, my_idx, offset); + if (other_val > my_val || + (other_val == my_val && other_idx < my_idx)) { + my_val = other_val; + my_idx = other_idx; + } + } + if (my_val > k_max || (my_val == k_max && my_idx < k_max_idx)) { + k_max = my_val; + k_max_idx = my_idx; + + if (k_max_idx == local_idx[e]) { + remove_ix = e; // Mark this index for removal + } else + remove_ix = -1; + } + } + + int select_item = k_max_idx / calc_per_item; + int select_elem = k_max_idx % calc_per_item; + k_max = local_elems[select_elem]; + k_max = sycl::group_broadcast(sg, k_max, select_item); + if (remove_ix != -1) { + local_elems[remove_ix] = + kNegInfinity; // Reset the score to avoid re-selection + local_idx[remove_ix] = -1; + remove_ix = -1; + } + + topk_weights_local[k] = k_max; + topk_ids_local[k] = k_max_idx < 0 ? k : k_max_idx; + } + + if (renormalize) { + // Renormalize the top-k weights + float sum = 0; + for (int i = 0; i < top_k; ++i) { + sum += topk_weights_local[i]; + } + if (sum > 0) { + for (int i = 0; i < top_k; ++i) { + topk_weights_local[i] /= sum; + } + } + } + + if (sg_local_id == 0) { + int offset = tid * top_k; + for (int i = 0; i < top_k; ++i) { + topk_weights[offset + i] = topk_weights_local[i]; + if (!(topk_ids_local[i] >= 0 && topk_ids_local[i] < experts)) { + // Ensure valid index + topk_ids[offset + i] = 0; + continue; + } + topk_ids[offset + i] = topk_ids_local[i]; + } + } + } + float* topk_weights; + int* topk_ids; + const T* gating_output; + const T* e_score_correction_bias; + const ScoringFunc scoring_mode; + const bool renormalize; + const int tokens; + const int experts; + const int top_k; + const int num_expert_group; + const int topk_group; +}; + +template +void launch_fused_grouped_topk(sycl::queue& queue, float* topk_weights, + int* topk_ids, const T* gating_output, + const T* e_score_correction_bias, + const ScoringFunc scoring_mode, + const bool renormalize, const int tokens, + const int experts, const int top_k, + const int num_expert_group, + const int topk_group) { + using Kernel = Fused_Grouped_Topk; + auto range = Kernel::get_nd_range(tokens, experts); + + queue.submit([&](sycl::handler& cgh) { + Kernel task(topk_weights, topk_ids, gating_output, e_score_correction_bias, + scoring_mode, renormalize, tokens, experts, top_k, + num_expert_group, topk_group); + cgh.parallel_for(range, task); + }); +} + +template +void fused_grouped_topk(float* topk_weights, int* topk_ids, + const T* gating_output, + const T* e_score_correction_bias, + const ScoringFunc scoring_mode, const bool renormalize, + const int tokens, const int experts, const int top_k, + const int num_expert_group, const int topk_group) { + auto& queue = vllm::xpu::vllmGetQueue(); + + TORCH_CHECK(topk_group <= num_expert_group, + "topk_group must be less than or equal to num_expert_group"); + TORCH_CHECK(experts % num_expert_group == 0, + "The number of experts (experts=", experts, + ") must be divisible by num_expert_group (", num_expert_group, + ")."); + + int max_expert_group = ((num_expert_group + 7) / 8) * 8; +#define CASE_TOPK(K) \ + case K: \ + launch_fused_grouped_topk( \ + queue, topk_weights, topk_ids, gating_output, e_score_correction_bias, \ + scoring_mode, renormalize, tokens, experts, top_k, num_expert_group, \ + topk_group); \ + break; + switch (max_expert_group) { + CASE_TOPK(8) + CASE_TOPK(16) + default: + TORCH_CHECK(false, "error: not support num_expert_group=%d,\n", + num_expert_group); + } +#undef CASE_TOPK +} + +}; // namespace GroupedTopKImpl +} // namespace vllm + +/** + * @brief Perform grouped topk after sigmoid/addbias on gating_output. + * @param gating_output The gating output tensor of shape [n_tokens, n_experts]. + * @param n_topk The number of top experts to select. + * @param n_topk_group The number of top experts to select in the group. + * @return A tuple of tensors (topk_weights, topk_indices). + */ +std::tuple grouped_topk( + const torch::Tensor& hidden_states, const torch::Tensor& gating_output, + const int64_t n_topk, const bool renormalize, const int64_t n_expert_group, + const int64_t n_topk_group, const c10::string_view scoring_func, + const c10::optional& bias) { + auto shape = gating_output.sizes().vec(); + TORCH_CHECK(hidden_states.sizes()[0] == gating_output.sizes()[0], + "Number of tokens mismatch") + TORCH_CHECK(shape.size() == 2, "gating_output must be 2D tensor, but got ", + shape.size(), "D"); + if (bias.has_value()) { + auto shape_bias = bias->sizes().vec(); + TORCH_CHECK( + shape_bias[0] == shape[1], + "gating_output and bias must has same innermost dimension, but got ", + shape, " and ", shape_bias); + } + int n_tokens = shape[0]; + int n_experts = shape[1]; + + vllm::GroupedTopKImpl::ScoringFunc scoring_mode; + if (scoring_func == "sigmoid") { + scoring_mode = vllm::GroupedTopKImpl::ScoringFunc::SIGMOID; + } else if (scoring_func == "softmax") { + scoring_mode = vllm::GroupedTopKImpl::ScoringFunc::SOFTMAX; + } else { + scoring_mode = vllm::GroupedTopKImpl::ScoringFunc::DEFAULT; + } + + auto topk_weights = + torch::empty({n_tokens, n_topk}, at::dtype(at::kFloat).device(at::kXPU)); + auto topk_indices = + torch::empty({n_tokens, n_topk}, at::dtype(at::kInt).device(at::kXPU)); + + if (gating_output.scalar_type() == at::kBFloat16) { + using scalar_t = sycl::ext::oneapi::bfloat16; + vllm::GroupedTopKImpl::fused_grouped_topk( + reinterpret_cast(topk_weights.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), + reinterpret_cast(gating_output.data_ptr()), + bias.has_value() ? reinterpret_cast(bias->data_ptr()) + : nullptr, + scoring_mode, renormalize, n_tokens, n_experts, n_topk, n_expert_group, + n_topk_group); + } else if (gating_output.scalar_type() == at::kHalf) { + using scalar_t = sycl::half; + vllm::GroupedTopKImpl::fused_grouped_topk( + reinterpret_cast(topk_weights.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), + reinterpret_cast(gating_output.data_ptr()), + bias.has_value() ? reinterpret_cast(bias->data_ptr()) + : nullptr, + scoring_mode, renormalize, n_tokens, n_experts, n_topk, n_expert_group, + n_topk_group); + } else { + using scalar_t = float; + vllm::GroupedTopKImpl::fused_grouped_topk( + reinterpret_cast(topk_weights.data_ptr()), + reinterpret_cast(topk_indices.data_ptr()), + reinterpret_cast(gating_output.data_ptr()), + bias.has_value() ? reinterpret_cast(bias->data_ptr()) + : nullptr, + scoring_mode, renormalize, n_tokens, n_experts, n_topk, n_expert_group, + n_topk_group); + } + return std::make_tuple(topk_weights, topk_indices); +} \ No newline at end of file diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 6c5f0be14..b8d83d4f4 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -2,4 +2,10 @@ #include +std::tuple grouped_topk( + const torch::Tensor& hidden_states, const torch::Tensor& gating_output, + const int64_t n_topk, const bool renormalize, const int64_t n_expert_group, + const int64_t n_topk_group, const c10::string_view scoring_func, + const c10::optional& bias); + void moe_sum(torch::Tensor& input, torch::Tensor& output); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 5e0051aa0..8acd45ce7 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -6,6 +6,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // from all selected experts. m.def("moe_sum(Tensor input, Tensor! output) -> ()"); m.impl("moe_sum", torch::kXPU, &moe_sum); + + // Grouped TopK + m.def( + "grouped_topk(Tensor hidden_states, Tensor gating_output, int n_topk, " + "bool renormalize, int n_expert_group, int n_topk_group, str " + "scoring_func, Tensor? bias=None) -> (Tensor, Tensor)"); + m.impl("grouped_topk", torch::kXPU, &grouped_topk); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/tests/ops/grouped_topk.py b/tests/ops/grouped_topk.py new file mode 100644 index 000000000..fadb25082 --- /dev/null +++ b/tests/ops/grouped_topk.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + + +def grouped_topk_native( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert hidden_states.size(0) == gating_output.size(0), ( + "Number of tokens mismatch") + if scoring_func == "softmax": + gating_output = gating_output.to(torch.float32) + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + num_token = scores.shape[0] + if e_score_correction_bias is not None: + # Store original scores before applying correction bias. We use biased + # scores for expert selection but original scores for routing weights + e_score_correction_bias = e_score_correction_bias.to(torch.float32) + original_scores = scores + scores = scores + e_score_correction_bias.unsqueeze(0) + group_scores = (scores.view(num_token, num_expert_group, + -1).topk(2, dim=-1)[0].sum(dim=-1)) + else: + group_scores = (scores.view(num_token, num_expert_group, + -1).max(dim=-1).values) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=True)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = (group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1)) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), + float("-inf")) # [n, e] + + if e_score_correction_bias is not None: + topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=True)[1] + # Use original unbiased scores for the routing weights + topk_weights = original_scores.gather(1, topk_ids) + else: + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=True) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + import tests.register_ops as ops + return ops.grouped_topk(hidden_states, gating_output, topk, renormalize, + num_expert_group, topk_group, scoring_func, + e_score_correction_bias) diff --git a/tests/register_ops.py b/tests/register_ops.py index 7f186adf9..3e41c35c7 100644 --- a/tests/register_ops.py +++ b/tests/register_ops.py @@ -1,11 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import TYPE_CHECKING, Optional + import torch -from typing import Optional import vllm_xpu_kernels._C # noqa: F401 import vllm_xpu_kernels._moe_C # noqa: F401 +if TYPE_CHECKING: + + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + # layer norm ops def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, @@ -116,3 +127,41 @@ def dynamic_per_token_scaled_fp8_quant( # moe def moe_sum(input: torch.Tensor, output: torch.Tensor) -> None: torch.ops._moe_C.moe_sum(input, output) + + +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops._moe_C.grouped_topk(hidden_states, gating_output, topk, + renormalize, num_expert_group, + topk_group, scoring_func, + e_score_correction_bias) + + +if hasattr(torch.ops._moe_C, "grouped_topk"): + + @register_fake("_moe_C::grouped_topk") + def _grouped_topk_fake( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int, + topk_group: int, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + topk_weights = torch.empty((gating_output.size(0), topk), + dtype=torch.float32, + device=hidden_states.device) + topk_indices = torch.empty((gating_output.size(0), topk), + dtype=torch.int32, + device=hidden_states.device) + return topk_weights, topk_indices diff --git a/tests/test_grouped_topk.py b/tests/test_grouped_topk.py new file mode 100644 index 000000000..748eab70f --- /dev/null +++ b/tests/test_grouped_topk.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from tests.ops.grouped_topk import grouped_topk, grouped_topk_native +from tests.utils import opcheck + +dpcpp_device = torch.device("xpu") + + +class TestTorchMethod: + + @pytest.mark.parametrize("seed", [123, 356, 478]) + @pytest.mark.parametrize("dtype", + [torch.float16, torch.bfloat16, torch.float32]) + @pytest.mark.parametrize("n_token", [1, 2, 64, 256]) + @pytest.mark.parametrize("n_expert", [64, 128, 256]) + @pytest.mark.parametrize("n_topk", [4, 6, 8]) + @pytest.mark.parametrize("n_topk_group", [4, 6, 8]) + @pytest.mark.parametrize("n_expert_group", [8]) + @pytest.mark.parametrize("renormalize", [True, False]) + @pytest.mark.parametrize("scoring_func", ["sigmoid", "softmax"]) + @pytest.mark.parametrize("has_bias", [True, False]) + def test_grouped_topk( + self, + seed, + dtype, + n_token, + n_expert, + n_topk, + n_expert_group, + n_topk_group, + renormalize, + scoring_func, + has_bias, + ): + + torch.manual_seed(seed) + torch.set_default_device("xpu") + gating_output = torch.randn(n_token, n_expert, + device=dpcpp_device).to(dtype) + hidden_states = torch.zeros(n_token, n_expert, + device=dpcpp_device).to(dtype) + bias = None + if has_bias: + if has_bias and scoring_func == "sigmoid" \ + and dtype is not torch.float32: + # using a bias of bigger number to avoid Low-precision + bias = torch.arange(1, n_expert + 1).to(dpcpp_device).to(dtype) + else: + bias = torch.randn(n_expert, device=dpcpp_device).to(dtype) + + ref_topk_weights, ref_topk_indices = grouped_topk_native( + hidden_states, + gating_output, + n_topk, + renormalize, + n_expert_group, + n_topk_group, + scoring_func=scoring_func, + e_score_correction_bias=bias, + ) + + topk_weights, topk_indices = grouped_topk( + hidden_states, + gating_output, + n_topk, + renormalize, + n_expert_group, + n_topk_group, + scoring_func=scoring_func, + e_score_correction_bias=bias, + ) + + # Compare the results + torch.testing.assert_close(ref_topk_weights, + topk_weights, + atol=2e-2, + rtol=1e-2) + assert torch.equal(ref_topk_indices, topk_indices) + + opcheck( + torch.ops._moe_C.grouped_topk, + (hidden_states, gating_output, n_topk, renormalize, n_expert_group, + n_topk_group, scoring_func, bias), + ) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + @pytest.mark.parametrize("renormalize", [True]) + @pytest.mark.parametrize("full_nan", [True, False]) + def test_grouped_topk_sigmoid_nan( + self, + dtype, + renormalize, + full_nan, + ): + n_token = 512 + n_expert = 256 + n_topk = 8 + n_expert_group = 8 + n_topk_group = 4 + + gating_output = torch.randn(n_token, n_expert, + device=dpcpp_device).to(dtype) + hidden_states = torch.zeros(n_token, n_expert, + device=dpcpp_device).to(dtype) + bias = torch.randn(n_expert, device=dpcpp_device).to(dtype) + + if full_nan: + gating_output = torch.full(gating_output.size(), + float("nan"), + device=dpcpp_device, + dtype=dtype).contiguous() + else: + gating_output[0][0] = float("nan") + + topk_weights, topk_indices = grouped_topk( + hidden_states, + gating_output, + n_topk, + renormalize, + n_expert_group, + n_topk_group, + e_score_correction_bias=bias, + ) + + assert torch.all(topk_indices < n_expert)