From 5abd9a9ef15003a8c334accd1d8e8d5b24d33860 Mon Sep 17 00:00:00 2001 From: chenjun Date: Sun, 5 Apr 2026 14:20:36 +0000 Subject: [PATCH 1/8] add fused_dynamic_mxfp4_quant_moe_sort_hip --- aiter/ops/quant.py | 40 ++++ csrc/include/quant.h | 9 + csrc/include/rocm_ops.hpp | 353 +++++++++++++++-------------- csrc/kernels/quant_kernels.cu | 208 +++++++++++++++++ op_tests/test_moe_sorting_mxfp4.py | 171 +++++++++++--- 5 files changed, 583 insertions(+), 198 deletions(-) diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index 36b0ba62cf..848d402efc 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -548,6 +548,46 @@ def moe_smooth_per_token_scaled_quant_v2( ... +@compile_ops("module_quant") +def fused_dynamic_mxfp4_quant_moe_sort_hip( + out: torch.Tensor, + scales: torch.Tensor, + input: torch.Tensor, + sorted_ids: torch.Tensor, + num_valid_ids: torch.Tensor, + topk: int, + block_m: int, + group_size: int = 32, +) -> None: + """ + HIP path for fused dynamic MXFP4 quantization and MoE scale sorting. + """ + ... + + +def fused_dynamic_mxfp4_quant_moe_sort( + input: torch.Tensor, + sorted_ids: torch.Tensor, + num_valid_ids: torch.Tensor, + token_num: int, + topk: int, + block_size: int, + group_size: int = 32, +) -> Tuple[torch.Tensor, torch.Tensor]: + M, N = input.view(-1, input.shape[-1]).shape + out = torch.empty(M, N // 2, dtype=dtypes.fp4x2, device=input.device) + scales = torch.empty( + sorted_ids.shape[0], + (N + 31) // 32, + dtype=dtypes.fp8_e8m0, + device=input.device, + ) + fused_dynamic_mxfp4_quant_moe_sort_hip( + out, scales, input, sorted_ids, num_valid_ids, topk, block_size, group_size + ) + return out, scales + + @compile_ops("module_quant") def partial_transpose( out: Tensor, diff --git a/csrc/include/quant.h b/csrc/include/quant.h index c3a0f1e3d9..eb658927b7 100644 --- a/csrc/include/quant.h +++ b/csrc/include/quant.h @@ -66,4 +66,13 @@ void moe_smooth_per_token_scaled_quant_v2(torch::Tensor& out, // [..., d int block_m, bool shuffle_scale = false, bool transpose_out = false); + +void fused_dynamic_mxfp4_quant_moe_sort_hip(torch::Tensor& out, // [token_num * topk, d / 2] + torch::Tensor& scales, // swizzled e8m0 bytes + torch::Tensor const& input, // [token_num * topk, d] + torch::Tensor const& sorted_ids, + torch::Tensor const& num_valid_ids, + int topk, + int block_m, + int group_size = 32); } // namespace aiter diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 46e9f50106..3fcfd7e836 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -2,9 +2,9 @@ // Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "aiter_tensor.h" #include #include -#include "aiter_tensor.h" namespace py = pybind11; @@ -326,144 +326,147 @@ namespace py = pybind11; py::arg("is_neox"), \ py::arg("is_nope_first")); -#define AITER_TENSOR_PYBIND \ - pybind11::class_(m, "aiter_tensor_t") \ - .def(pybind11::init<>()) \ - .def_readwrite("numel_", &aiter_tensor_t::numel_) \ - .def_readwrite("ndim", &aiter_tensor_t::ndim) \ - .def_readwrite("device_id", &aiter_tensor_t::device_id); \ - m.def("make_aiter_tensor", \ - [](int64_t data_ptr, size_t numel, int ndim, \ - const std::vector& shape, \ - const std::vector& strides, \ - int dtype, int device_id) { \ - aiter_tensor_t at{}; \ - at.ptr = (void*)data_ptr; \ - at.numel_ = numel; \ - at.ndim = ndim; \ - for(int i = 0; i < ndim && i < 8; i++) { \ - at.shape[i] = shape[i]; \ - at.strides[i] = strides[i]; \ - } \ - at.dtype_ = (AiterDtype)dtype; \ - at.device_id = device_id; \ - return at; \ - }, \ - pybind11::arg("data_ptr"), \ - pybind11::arg("numel"), \ - pybind11::arg("ndim"), \ - pybind11::arg("shape"), \ - pybind11::arg("strides"), \ - pybind11::arg("dtype"), \ - pybind11::arg("device_id")); - -#define CUSTOM_ALL_REDUCE_PYBIND \ - m.def("init_custom_ar", \ - &aiter::init_custom_ar, \ - py::arg("meta_ptr"), \ - py::arg("rank_data_ptr"), \ - py::arg("rank_data_sz"), \ - py::arg("ipc_handle_ptrs"), \ - py::arg("offsets"), \ - py::arg("rank"), \ - py::arg("fully_connected")); \ - m.def("all_reduce", \ - &aiter::all_reduce, \ - py::arg("_fa"), \ - py::arg("inp"), \ - py::arg("out"), \ - py::arg("use_new"), \ - py::arg("open_fp8_quant"), \ - py::arg("reg_inp_ptr"), \ - py::arg("reg_inp_bytes"), \ - py::arg("stream")); \ - m.def("reduce_scatter", \ - &aiter::reduce_scatter, \ - py::arg("_fa"), \ - py::arg("inp"), \ - py::arg("out"), \ - py::arg("reg_ptr"), \ - py::arg("reg_bytes"), \ - py::arg("stream")); \ - m.def("all_gather_reg", \ - &aiter::all_gather_reg, \ - py::arg("_fa"), \ - py::arg("inp"), \ - py::arg("out"), \ - py::arg("dim"), \ - py::arg("stream")); \ - m.def("all_gather_unreg", \ - &aiter::all_gather_unreg, \ - py::arg("_fa"), \ - py::arg("inp"), \ - py::arg("reg_buffer"), \ - py::arg("out"), \ - py::arg("reg_bytes"), \ - py::arg("dim"), \ - py::arg("stream")); \ - m.def("fused_allreduce_rmsnorm", \ - &aiter::fused_allreduce_rmsnorm, \ - py::arg("_fa"), \ - py::arg("inp"), \ - py::arg("res_inp"), \ - py::arg("res_out"), \ - py::arg("out"), \ - py::arg("w"), \ - py::arg("eps"), \ - py::arg("reg_ptr"), \ - py::arg("reg_bytes"), \ - py::arg("use_1stage"), \ - py::arg("stream")); \ - m.def("fused_allreduce_rmsnorm_quant", \ - &aiter::fused_allreduce_rmsnorm_quant, \ - py::arg("_fa"), \ - py::arg("inp"), \ - py::arg("res_inp"), \ - py::arg("res_out"), \ - py::arg("out"), \ - py::arg("scale_out"), \ - py::arg("w"), \ - py::arg("eps"), \ - py::arg("reg_ptr"), \ - py::arg("reg_bytes"), \ - py::arg("use_1stage"), \ - py::arg("stream")); \ - m.def("all_reduce_asm_", &all_reduce_asm, ""); \ - m.def("all_reduce_rmsnorm_", &all_reduce_rmsnorm, "all_reduce_rmsnorm"); \ - m.def("all_reduce_rmsnorm_quant_", &all_reduce_rmsnorm_quant, "all_reduce_rmsnorm_quant"); \ - m.def("dispose", &aiter::dispose, py::arg("_fa")); \ - m.def("meta_size", &aiter::meta_size); \ - m.def("register_input_buffer", \ - &aiter::register_input_buffer, \ - py::arg("_fa"), \ - py::arg("self_ptr"), \ - py::arg("ipc_handle_ptrs"), \ - py::arg("offsets")); \ - m.def("register_output_buffer", \ - &aiter::register_output_buffer, \ - py::arg("_fa"), \ - py::arg("self_ptr"), \ - py::arg("ipc_handle_ptrs"), \ - py::arg("offsets")); \ - m.def("get_graph_buffer_count", &aiter::get_graph_buffer_count, py::arg("_fa")); \ - m.def("get_graph_buffer_ipc_meta", \ - &aiter::get_graph_buffer_ipc_meta, \ - py::arg("_fa"), \ - py::arg("handle_out"), \ - py::arg("offset_out")); \ - m.def("register_graph_buffers", \ - &aiter::register_graph_buffers, \ - py::arg("_fa"), \ - py::arg("handle_ptrs"), \ - py::arg("offset_ptrs")); \ - m.def("allocate_meta_buffer", \ - &aiter::allocate_meta_buffer, \ - py::arg("size"), \ - py::arg("stream")); \ - m.def("free_meta_buffer", &aiter::free_meta_buffer, py::arg("ptr")); \ - m.def("get_meta_buffer_ipc_handle", \ - &aiter::get_meta_buffer_ipc_handle, \ - py::arg("inp_ptr"), \ +#define AITER_TENSOR_PYBIND \ + pybind11::class_(m, "aiter_tensor_t") \ + .def(pybind11::init<>()) \ + .def_readwrite("numel_", &aiter_tensor_t::numel_) \ + .def_readwrite("ndim", &aiter_tensor_t::ndim) \ + .def_readwrite("device_id", &aiter_tensor_t::device_id); \ + m.def( \ + "make_aiter_tensor", \ + [](int64_t data_ptr, \ + size_t numel, \ + int ndim, \ + const std::vector& shape, \ + const std::vector& strides, \ + int dtype, \ + int device_id) { \ + aiter_tensor_t at{}; \ + at.ptr = (void*)data_ptr; \ + at.numel_ = numel; \ + at.ndim = ndim; \ + for(int i = 0; i < ndim && i < 8; i++) \ + { \ + at.shape[i] = shape[i]; \ + at.strides[i] = strides[i]; \ + } \ + at.dtype_ = (AiterDtype)dtype; \ + at.device_id = device_id; \ + return at; \ + }, \ + pybind11::arg("data_ptr"), \ + pybind11::arg("numel"), \ + pybind11::arg("ndim"), \ + pybind11::arg("shape"), \ + pybind11::arg("strides"), \ + pybind11::arg("dtype"), \ + pybind11::arg("device_id")); + +#define CUSTOM_ALL_REDUCE_PYBIND \ + m.def("init_custom_ar", \ + &aiter::init_custom_ar, \ + py::arg("meta_ptr"), \ + py::arg("rank_data_ptr"), \ + py::arg("rank_data_sz"), \ + py::arg("ipc_handle_ptrs"), \ + py::arg("offsets"), \ + py::arg("rank"), \ + py::arg("fully_connected")); \ + m.def("all_reduce", \ + &aiter::all_reduce, \ + py::arg("_fa"), \ + py::arg("inp"), \ + py::arg("out"), \ + py::arg("use_new"), \ + py::arg("open_fp8_quant"), \ + py::arg("reg_inp_ptr"), \ + py::arg("reg_inp_bytes"), \ + py::arg("stream")); \ + m.def("reduce_scatter", \ + &aiter::reduce_scatter, \ + py::arg("_fa"), \ + py::arg("inp"), \ + py::arg("out"), \ + py::arg("reg_ptr"), \ + py::arg("reg_bytes"), \ + py::arg("stream")); \ + m.def("all_gather_reg", \ + &aiter::all_gather_reg, \ + py::arg("_fa"), \ + py::arg("inp"), \ + py::arg("out"), \ + py::arg("dim"), \ + py::arg("stream")); \ + m.def("all_gather_unreg", \ + &aiter::all_gather_unreg, \ + py::arg("_fa"), \ + py::arg("inp"), \ + py::arg("reg_buffer"), \ + py::arg("out"), \ + py::arg("reg_bytes"), \ + py::arg("dim"), \ + py::arg("stream")); \ + m.def("fused_allreduce_rmsnorm", \ + &aiter::fused_allreduce_rmsnorm, \ + py::arg("_fa"), \ + py::arg("inp"), \ + py::arg("res_inp"), \ + py::arg("res_out"), \ + py::arg("out"), \ + py::arg("w"), \ + py::arg("eps"), \ + py::arg("reg_ptr"), \ + py::arg("reg_bytes"), \ + py::arg("use_1stage"), \ + py::arg("stream")); \ + m.def("fused_allreduce_rmsnorm_quant", \ + &aiter::fused_allreduce_rmsnorm_quant, \ + py::arg("_fa"), \ + py::arg("inp"), \ + py::arg("res_inp"), \ + py::arg("res_out"), \ + py::arg("out"), \ + py::arg("scale_out"), \ + py::arg("w"), \ + py::arg("eps"), \ + py::arg("reg_ptr"), \ + py::arg("reg_bytes"), \ + py::arg("use_1stage"), \ + py::arg("stream")); \ + m.def("all_reduce_asm_", &all_reduce_asm, ""); \ + m.def("all_reduce_rmsnorm_", &all_reduce_rmsnorm, "all_reduce_rmsnorm"); \ + m.def("all_reduce_rmsnorm_quant_", &all_reduce_rmsnorm_quant, "all_reduce_rmsnorm_quant"); \ + m.def("dispose", &aiter::dispose, py::arg("_fa")); \ + m.def("meta_size", &aiter::meta_size); \ + m.def("register_input_buffer", \ + &aiter::register_input_buffer, \ + py::arg("_fa"), \ + py::arg("self_ptr"), \ + py::arg("ipc_handle_ptrs"), \ + py::arg("offsets")); \ + m.def("register_output_buffer", \ + &aiter::register_output_buffer, \ + py::arg("_fa"), \ + py::arg("self_ptr"), \ + py::arg("ipc_handle_ptrs"), \ + py::arg("offsets")); \ + m.def("get_graph_buffer_count", &aiter::get_graph_buffer_count, py::arg("_fa")); \ + m.def("get_graph_buffer_ipc_meta", \ + &aiter::get_graph_buffer_ipc_meta, \ + py::arg("_fa"), \ + py::arg("handle_out"), \ + py::arg("offset_out")); \ + m.def("register_graph_buffers", \ + &aiter::register_graph_buffers, \ + py::arg("_fa"), \ + py::arg("handle_ptrs"), \ + py::arg("offset_ptrs")); \ + m.def( \ + "allocate_meta_buffer", &aiter::allocate_meta_buffer, py::arg("size"), py::arg("stream")); \ + m.def("free_meta_buffer", &aiter::free_meta_buffer, py::arg("ptr")); \ + m.def("get_meta_buffer_ipc_handle", \ + &aiter::get_meta_buffer_ipc_handle, \ + py::arg("inp_ptr"), \ py::arg("out_handle_ptr")); #define CUSTOM_PYBIND \ @@ -1319,6 +1322,16 @@ namespace py = pybind11; py::arg("block_m"), \ py::arg("shuffle_scale") = false, \ py::arg("transpose_out") = false); \ + m.def("fused_dynamic_mxfp4_quant_moe_sort_hip", \ + &aiter::fused_dynamic_mxfp4_quant_moe_sort_hip, \ + py::arg("out"), \ + py::arg("scales"), \ + py::arg("input"), \ + py::arg("sorted_ids"), \ + py::arg("num_valid_ids"), \ + py::arg("topk"), \ + py::arg("block_m"), \ + py::arg("group_size") = 32); \ m.def("partial_transpose", \ &aiter::partial_transpose, \ py::arg("out"), \ @@ -1410,37 +1423,39 @@ namespace py = pybind11; py::arg("epsilon"), \ py::arg("use_model_sensitive_rmsnorm") = 0); -#define ROPE_1C_UNCACHED_FWD_PYBIND m.def("rope_fwd_impl", &rope_fwd_impl); -#define ROPE_2C_UNCACHED_FWD_PYBIND m.def("rope_2c_fwd_impl", &rope_2c_fwd_impl); -#define ROPE_1C_CACHED_FWD_PYBIND m.def("rope_cached_fwd_impl", &rope_cached_fwd_impl); -#define ROPE_2C_CACHED_FWD_PYBIND m.def("rope_cached_2c_fwd_impl", &rope_cached_2c_fwd_impl); -#define ROPE_1C_THD_FWD_PYBIND m.def("rope_thd_fwd_impl", &rope_thd_fwd_impl); -#define ROPE_1C_2D_FWD_PYBIND m.def("rope_2d_fwd_impl", &rope_2d_fwd_impl); - -#define ROPE_1C_UNCACHED_BWD_PYBIND m.def("rope_bwd_impl", &rope_bwd_impl); -#define ROPE_2C_UNCACHED_BWD_PYBIND m.def("rope_2c_bwd_impl", &rope_2c_bwd_impl); -#define ROPE_1C_CACHED_BWD_PYBIND m.def("rope_cached_bwd_impl", &rope_cached_bwd_impl); -#define ROPE_2C_CACHED_BWD_PYBIND m.def("rope_cached_2c_bwd_impl", &rope_cached_2c_bwd_impl); -#define ROPE_1C_THD_BWD_PYBIND m.def("rope_thd_bwd_impl", &rope_thd_bwd_impl); -#define ROPE_1C_2D_BWD_PYBIND m.def("rope_2d_bwd_impl", &rope_2d_bwd_impl); - - -#define ROPE_1C_CACHED_POSITIONS_FWD_PYBIND m.def("rope_cached_positions_fwd_impl", &rope_cached_positions_fwd_impl) -#define ROPE_2C_CACHED_POSITIONS_FWD_PYBIND \ - m.def("rope_cached_positions_2c_fwd_impl", \ - &rope_cached_positions_2c_fwd_impl, \ - py::arg("output_x"), \ - py::arg("output_y"), \ - py::arg("input_x"), \ - py::arg("input_y"), \ - py::arg("cos"), \ - py::arg("sin"), \ - py::arg("positions"), \ - py::arg("rotate_style"), \ - py::arg("reuse_freqs_front_part"), \ +#define ROPE_1C_UNCACHED_FWD_PYBIND m.def("rope_fwd_impl", &rope_fwd_impl); +#define ROPE_2C_UNCACHED_FWD_PYBIND m.def("rope_2c_fwd_impl", &rope_2c_fwd_impl); +#define ROPE_1C_CACHED_FWD_PYBIND m.def("rope_cached_fwd_impl", &rope_cached_fwd_impl); +#define ROPE_2C_CACHED_FWD_PYBIND m.def("rope_cached_2c_fwd_impl", &rope_cached_2c_fwd_impl); +#define ROPE_1C_THD_FWD_PYBIND m.def("rope_thd_fwd_impl", &rope_thd_fwd_impl); +#define ROPE_1C_2D_FWD_PYBIND m.def("rope_2d_fwd_impl", &rope_2d_fwd_impl); + +#define ROPE_1C_UNCACHED_BWD_PYBIND m.def("rope_bwd_impl", &rope_bwd_impl); +#define ROPE_2C_UNCACHED_BWD_PYBIND m.def("rope_2c_bwd_impl", &rope_2c_bwd_impl); +#define ROPE_1C_CACHED_BWD_PYBIND m.def("rope_cached_bwd_impl", &rope_cached_bwd_impl); +#define ROPE_2C_CACHED_BWD_PYBIND m.def("rope_cached_2c_bwd_impl", &rope_cached_2c_bwd_impl); +#define ROPE_1C_THD_BWD_PYBIND m.def("rope_thd_bwd_impl", &rope_thd_bwd_impl); +#define ROPE_1C_2D_BWD_PYBIND m.def("rope_2d_bwd_impl", &rope_2d_bwd_impl); + +#define ROPE_1C_CACHED_POSITIONS_FWD_PYBIND \ + m.def("rope_cached_positions_fwd_impl", &rope_cached_positions_fwd_impl) +#define ROPE_2C_CACHED_POSITIONS_FWD_PYBIND \ + m.def("rope_cached_positions_2c_fwd_impl", \ + &rope_cached_positions_2c_fwd_impl, \ + py::arg("output_x"), \ + py::arg("output_y"), \ + py::arg("input_x"), \ + py::arg("input_y"), \ + py::arg("cos"), \ + py::arg("sin"), \ + py::arg("positions"), \ + py::arg("rotate_style"), \ + py::arg("reuse_freqs_front_part"), \ py::arg("nope_first")) -#define ROPE_1C_CACHED_POSITIONS_OFFSETS_FWD_PYBIND m.def("rope_cached_positions_offsets_fwd_impl", &rope_cached_positions_offsets_fwd_impl) -#define ROPE_2C_CACHED_POSITIONS_OFFSETS_FWD_PYBIND m.def("rope_cached_positions_offsets_2c_fwd_impl", &rope_cached_positions_offsets_2c_fwd_impl) +#define ROPE_1C_CACHED_POSITIONS_OFFSETS_FWD_PYBIND \ + m.def("rope_cached_positions_offsets_fwd_impl", &rope_cached_positions_offsets_fwd_impl) +#define ROPE_2C_CACHED_POSITIONS_OFFSETS_FWD_PYBIND \ + m.def("rope_cached_positions_offsets_2c_fwd_impl", &rope_cached_positions_offsets_2c_fwd_impl) #define FUSED_QKNORM_MROPE_CACHE_QUANT_PYBIND \ m.def("fused_qk_norm_mrope_3d_cache_pts_quant_shuffle", \ diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index a0b5e6c588..50d58bf6b9 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -1614,4 +1614,212 @@ void moe_smooth_per_token_scaled_quant_v2( } +template +__global__ void mxfp4_quant_moe_sort_kernel( + DTYPE_O* __restrict__ out, + uint8_t* __restrict__ scale, + DTYPE_I const* __restrict__ input, + int32_t const* __restrict__ sorted_ids, + int32_t const* __restrict__ num_valid_ids, + const int32_t num_tokens, + const int32_t cols, + const int32_t group_size, + const int32_t block_m, + const int32_t sub_block_m, + const int32_t num_blocks, + const int32_t num_tg, + const int32_t topk, + const int32_t input_stride) +{ + int num_thread_per_group = group_size / thread_data_size; + int num_valid_ids_value = num_valid_ids[0]; + int block_idx = blockIdx.x; + int lane_idx = threadIdx.x % WARP_SIZE; + const int scale_k = threadIdx.x / num_thread_per_group; + static constexpr int32_t vec_size_i = + thread_data_size == 0 ? 16 / sizeof(DTYPE_I) : thread_data_size; + static constexpr int32_t load_chunk_bytes = + (sizeof(DTYPE_I) * vec_size_i % 16 == 0 ? 16 + : (sizeof(DTYPE_I) * vec_size_i % 8 == 0 ? 8 : 4)); + using vec_i = opus::vector_t; + using vec_f = opus::vector_t; + const float inverted_DTYPE_MAX = + std::is_same_v + ? 0.25 + : (1. / static_cast(opus::finfo::max())); + const int32_t scaleN_valid = (cols + group_size - 1) / group_size; + const int32_t scaleN_pad = ((scaleN_valid + 7) / 8) * 8; + + auto fp4_scale = [](float tmp) { + uint32_t u32 = __builtin_bit_cast(uint32_t, tmp); + uint32_t exponent = (u32 >> 23) & 0b11111111; + if(exponent == 0b11111111) + { + return __builtin_bit_cast(float, exponent << 23); + } + if(((u32 & 0x400000)) && (((u32 & 0x200000)) || ((u32 & 0x1FFFFF)) || (exponent))) + exponent += 1; + return __builtin_bit_cast(float, exponent << 23); + }; + auto fp4_scale_shuffle_id = [](int32_t scaleN_pad_, int32_t x, int32_t y) { + return (x / 32 * scaleN_pad_) * 32 + (y / 8) * 256 + (y % 4) * 64 + + (x % 16) * 4 + (y % 8) / 4 * 2 + (x % 32) / 16; + }; + + for(; block_idx < num_blocks; block_idx += num_tg) + { + int sorted_ids_offset = block_idx * sub_block_m; + if(sorted_ids_offset >= num_valid_ids_value) + { + return; + } + int token_id_info_list; + if (lane_idx < sub_block_m) + { + token_id_info_list = sorted_ids[sorted_ids_offset + lane_idx]; + } + int token_id_list = token_id_info_list & 0xFFFFFF; + int topk_id_list = token_id_info_list >> 24; + for(int i = 0; i < sub_block_m; i++) + { + int token_idx = __builtin_amdgcn_readlane(token_id_list, i); + int topk_id = __builtin_amdgcn_readlane(topk_id_list, i); + if(token_idx >= num_tokens) + { + break; + } + + int64_t input_offset; + if (topk == 1) + { + input_offset = (int64_t)(token_idx) * input_stride; + } + else + { + input_offset = (int64_t)(token_idx * topk + topk_id) * input_stride; + } + auto buffer_input = + opus::make_gmem(input + input_offset, cols * sizeof(DTYPE_I)); + vec_i vec_input = load_vector_nbytes( + buffer_input, threadIdx.x * vec_size_i); + vec_f vec_input_f; + float* input_f_ptr = reinterpret_cast(&vec_input_f); + float absMax = 1e-10f; + #pragma unroll + for(int j = 0; j < vec_size_i; j++) + { + vec_input_f[j] = static_cast(vec_input[j]); + absMax = max(absMax, abs(vec_input_f[j])); + } + absMax = multithread_reduce(absMax, hipcub::Max(), num_thread_per_group); + + float row_scale = std::is_same_v + ? fp4_scale(absMax) * inverted_DTYPE_MAX + : absMax * inverted_DTYPE_MAX; + + const int sorted_row = sorted_ids_offset + i; + if(threadIdx.x % num_thread_per_group == 0 && scale_k < scaleN_valid) + { + uint8_t bs_e8m0 = (__builtin_bit_cast(uint32_t, row_scale) >> 23) & 0xFF; + int addr = fp4_scale_shuffle_id(scaleN_pad, sorted_row, scale_k); + scale[addr] = bs_e8m0; + } + + if(topk_id < topk) + { + int64_t out_offset = (int64_t)(token_idx * topk + topk_id) * cols; + scaled_quant_vgpr_impl( + out, input_f_ptr, &row_scale, cols, out_offset); + } + } + } +} + + +#define MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, THREAD_DATA, BLOCK_SIZE) \ + AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "mxfp4_quant_moe_sort_kernel", [&] { \ + AITER_CHECK(group_size % THREAD_DATA == 0, __func__, " group_size is not divisible by THREAD_DATA"); \ + using input_dtype = typename t2opus::type; \ + int warps_per_cu = 8 * BLOCK_SIZE / WARP_SIZE; \ + int num_tg = persistent_mode ? num_cu * warps_per_cu : num_blocks; \ + dim3 const grid(num_tg); \ + mxfp4_quant_moe_sort_kernel \ + <<>>( \ + reinterpret_cast(output.data_ptr()), \ + reinterpret_cast(scale.data_ptr()), \ + reinterpret_cast(input.data_ptr()), \ + sorted_ids.data_ptr(), \ + num_valid_ids.data_ptr(), \ + token_num, \ + cols, \ + group_size, \ + block_m, \ + sub_block_m, \ + num_blocks, \ + num_tg, \ + topk, \ + input_stride); \ + }); + + +#define MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(DTYPE_O, cols_) \ + if(cols_ <= 4 * BlockSize) \ + { \ + MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, 8, BlockSize / 2) \ + } \ + else if(cols_ <= 8 * BlockSize) \ + { \ + MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, 8, BlockSize) \ + } \ + else if(cols_ <= 16 * BlockSize) \ + { \ + MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, 16, BlockSize) \ + } \ + else if(cols_ <= 16 * BlockSize * 2) \ + { \ + MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, 32, BlockSize) \ + } \ + else \ + { \ + TORCH_CHECK(false, "input last dim has exceeded the maximum value ", 32 * BlockSize) \ + } + +void fused_dynamic_mxfp4_quant_moe_sort_hip( + torch::Tensor& output, + torch::Tensor& scale, + torch::Tensor const& input, + torch::Tensor const& sorted_ids, + torch::Tensor const& num_valid_ids, + int topk, + int block_m, + int group_size = 32 +) +{ + int cols = input.size(-1); + int token_num = input.numel() / (cols * topk); + + const int num_cu = get_num_cu_func(); + int sub_block_m = (token_num * topk) > (num_cu * 8) ? 2 : 8; + TORCH_CHECK(block_m % sub_block_m == 0, __func__, " block_m is not divisible by sub_block_m"); + int num_blocks = (sorted_ids.size(0) + sub_block_m - 1) / sub_block_m; + const bool persistent_mode = false; + const int input_stride = input.stride(-2); + + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); + +#if defined(__Float4_e2m1fn_x2) + if(output.dtype() == torch_fp4x2 || output.dtype() == torch::kUInt8) + { + MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(opus::fp4_t, cols); + } + else + { + TORCH_CHECK(false, __func__, ": not support output type: ", output.dtype()); + } +#else + TORCH_CHECK(false, __func__, ": not support fp4x2 on this device"); +#endif +} + } // namespace aiter \ No newline at end of file diff --git a/op_tests/test_moe_sorting_mxfp4.py b/op_tests/test_moe_sorting_mxfp4.py index 9a9571c8e0..2bad7489be 100644 --- a/op_tests/test_moe_sorting_mxfp4.py +++ b/op_tests/test_moe_sorting_mxfp4.py @@ -2,8 +2,9 @@ # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch import aiter -from aiter.test_common import checkAllclose, benchmark, run_perftest, perftest +from aiter.test_common import checkAllclose, benchmark, run_perftest from aiter.fused_moe import moe_sorting, fused_topk +from aiter.ops.triton.quant.fused_mxfp4_quant import fused_dynamic_mxfp4_quant_moe_sort from aiter import get_torch_quant, dtypes from aiter.utility import fp4_utils import pandas as pd @@ -15,7 +16,6 @@ torch.set_printoptions(threshold=float("inf")) -@perftest() def run_torch(scale, sorted_ids, num_valid_ids, token_num): topk = 1 if len(scale.shape) == 3: @@ -29,7 +29,7 @@ def run_torch(scale, sorted_ids, num_valid_ids, token_num): sorted_ids = sorted_ids * topk + topk_ids sorted_ids[mask] = 0 # set to 0 to avoid overflow scale = scale[sorted_ids] - scale[mask] = 0 + scale.view(torch.uint8)[mask] = 0 sm, sn = scale.shape tmp = torch.zeros( ((sm + 31) // 32 * 32, sn), dtype=scale.dtype, device=scale.device @@ -43,6 +43,19 @@ def run_torch(scale, sorted_ids, num_valid_ids, token_num): return ref +def run_hip_quant_sort(scale, input, sorted_ids, num_valid_ids, topk, block_size): + M, N = input.shape + out = torch.empty( + (M, N // 2), + dtype=dtypes.fp4x2, + device=input.device, + ) + aiter.fused_dynamic_mxfp4_quant_moe_sort_hip( + out, scale, input, sorted_ids, num_valid_ids, topk, block_size + ) + return out, scale + + @benchmark() def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage): input = torch.randn((token_num, model_dim), dtype=dtype) @@ -63,7 +76,7 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) else: scale = torch.arange(token_num * topk * model_dim // 32, dtype=torch.uint8) scale = scale.view(token_num, topk, model_dim // 32) - ref, us_ref = run_torch(scale.clone(), sorted_ids.clone(), num_valid_ids, token_num) + ref = run_torch(scale.clone(), sorted_ids.clone(), num_valid_ids, token_num) sorted_mxfp4_scale, us = run_perftest( fp4_utils.moe_mxfp4_sort, scale, @@ -81,7 +94,83 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) sorted_mxfp4_scale[:num_valid_ids].view(torch.uint8), msg="sorted_mxfp4_scale", ) - return {"us_ref": us_ref, "us": us, "err": err} + return {"us": us, "err": err} + + +@benchmark() +def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, stage): + input = torch.randn((token_num, model_dim), dtype=dtype) + score = torch.randn((token_num, E), dtype=dtype) + + topk_weights, topk_ids = fused_topk(input, score, topk, True) + sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = moe_sorting( + topk_ids, + topk_weights, + E, + model_dim, + dtype, + ) + num_valid_ids = num_valid_ids[0] + if stage != "stage1": + input = torch.randn((token_num * topk, model_dim), dtype=dtype) + else: + topk = 1 + ref_out, scale = get_torch_quant(aiter.QuantType.per_1x32)( + input, quant_dtype=dtypes.fp4x2 + ) + ref_scale = run_torch(scale.clone(), sorted_ids.clone(), num_valid_ids, token_num) + + hip_scale = torch.zeros( + (sorted_ids.shape[0], model_dim // 32), + dtype=torch.uint8, + device=input.device, + ) + (hip_out, hip_scale), hip_us = run_perftest( + run_hip_quant_sort, + hip_scale, + input, + sorted_ids, + num_valid_ids, + topk, + block_size, + ) + + (triton_out, triton_scale), triton_us = run_perftest( + fused_dynamic_mxfp4_quant_moe_sort, + input, + sorted_ids=sorted_ids, + num_valid_ids=num_valid_ids, + token_num=token_num, + topk=topk, + block_size=block_size, + ) + + mask = ref_scale == 0 + triton_scale = triton_scale[: ref_scale.shape[0]] + triton_scale.view(torch.uint8)[mask] = 0 + num_valid_ids = num_valid_ids.item() + num_valid_ids = (num_valid_ids + block_size - 1) // block_size * block_size + + checkAllclose(ref_out.view(torch.uint8), hip_out.view(torch.uint8), msg="hip out") + hip_err = checkAllclose( + ref_scale[:num_valid_ids].view(torch.uint8), + hip_scale[:num_valid_ids].view(torch.uint8), + msg="hip sorted_mxfp4_scale", + ) + + # checkAllclose(ref_out.view(torch.uint8), triton_out.view(torch.uint8), msg="triton out") + triton_err = checkAllclose( + ref_scale[:num_valid_ids].view(torch.uint8), + triton_scale[:num_valid_ids].view(torch.uint8), + msg="triton sorted_mxfp4_scale", + ) + + return { + "us_triton": triton_us, + "err_triton": triton_err, + "us_hip": hip_us, + "err_hip": hip_err, + } parser = argparse.ArgumentParser( @@ -103,36 +192,36 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) "-dim", type=int, nargs="*", - default=[4096, 6144, 8192], + default=[4096, 7168, 8192], help="""Model dimension. e.g.: -dim 4096""", ) parser.add_argument( - "-e", - "--expert", - type=int, + "-ek", + "--expert_topk", + type=dtypes.str2tuple, nargs="*", - default=[32, 256, 257, 512], + default=[[32, 5], [256, 8], [512, 8]], help="""Number of experts. - e.g.: -e 32""", -) -parser.add_argument( - "-t", - "--topk", - type=int, - nargs="*", - default=[5, 8], - help="""Number of top experts. - e.g.: -t 5""", + e.g.: -e 32,5""", ) parser.add_argument( "-m", type=int, nargs="*", - default=[1, 31, 64, 128, 256, 10000, 163840], + default=[1, 31, 64, 128, 256, 4200, 10000, 163840], help="""M of mnk. e.g.: -m 64""", ) +parser.add_argument( + "-bm", + "--block_m", + type=int, + default=32, + choices=[16, 32, 64, 80, 128], + help="""Block M. + e.g.: -bm 64""", +) args = parser.parse_args() @@ -140,11 +229,10 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) for dtype in args.dtype: for ( dim, - E, - topk, + (E, topk), m, - ) in itertools.product(args.dim, args.expert, args.topk, args.m): - ret = test_moe_mxfp4_sort(dtype, m, dim, E, topk, 32, "stage1") + ) in itertools.product(args.dim, args.expert_topk, args.m): + ret = test_moe_mxfp4_sort(dtype, m, dim, E, topk, args.block_m, "stage1") df.append(ret) df = pd.DataFrame(df) df_md = df.to_markdown(index=False) @@ -154,12 +242,37 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) for dtype in args.dtype: for ( dim, - E, - topk, + (E, topk), m, - ) in itertools.product(args.dim, args.expert, args.topk, args.m): - ret = test_moe_mxfp4_sort(dtype, m, dim, E, topk, 32, "stage2") + ) in itertools.product(args.dim, args.expert_topk, args.m): + ret = test_moe_mxfp4_sort(dtype, m, dim, E, topk, args.block_m, "stage2") df.append(ret) df = pd.DataFrame(df) df_md = df.to_markdown(index=False) aiter.logger.info("moe_sorting_mxfp4_stage2 summary (markdown):\n%s", df_md) + +df = [] +for dtype in args.dtype: + for ( + dim, + (E, topk), + m, + ) in itertools.product(args.dim, args.expert_topk, args.m): + ret = test_moe_mxfp4_quant_sort(dtype, m, dim, E, topk, args.block_m, "stage1") + df.append(ret) +df = pd.DataFrame(df) +df_md = df.to_markdown(index=False) +aiter.logger.info("moe_mxfp4_quant_sort_stage1 summary (markdown):\n%s", df_md) + +df = [] +for dtype in args.dtype: + for ( + dim, + (E, topk), + m, + ) in itertools.product(args.dim, args.expert_topk, args.m): + ret = test_moe_mxfp4_quant_sort(dtype, m, dim, E, topk, args.block_m, "stage2") + df.append(ret) +df = pd.DataFrame(df) +df_md = df.to_markdown(index=False) +aiter.logger.info("moe_mxfp4_quant_sort_stage2 summary (markdown):\n%s", df_md) From e735ecd4c787bcc726b2a4c808c78e15a9c1ae69 Mon Sep 17 00:00:00 2001 From: chenjun Date: Sun, 5 Apr 2026 09:21:53 -0500 Subject: [PATCH 2/8] use hip fused_dynamic_mxfp4_quant_moe_sort in fuse_moe --- aiter/fused_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 53ed0e0d63..923e6149b1 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -17,7 +17,7 @@ from aiter.jit.utils.chip_info import get_cu_num, get_gfx from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.flydsl.utils import is_flydsl_available -from aiter.ops.triton.quant.fused_mxfp4_quant import fused_dynamic_mxfp4_quant_moe_sort +from aiter import fused_dynamic_mxfp4_quant_moe_sort from aiter.utility import fp4_utils BLOCK_SIZE_M = 32 From e7961a0750a0b063f26f3cfb95b2d1a51bf71e6d Mon Sep 17 00:00:00 2001 From: chenjun Date: Sun, 5 Apr 2026 10:22:22 -0500 Subject: [PATCH 3/8] update --- csrc/kernels/quant_kernels.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index 50d58bf6b9..68fa219f6a 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -1797,9 +1797,10 @@ void fused_dynamic_mxfp4_quant_moe_sort_hip( { int cols = input.size(-1); int token_num = input.numel() / (cols * topk); + int num_experts = (sorted_ids.size(0) + topk - topk * token_num) / block_m; const int num_cu = get_num_cu_func(); - int sub_block_m = (token_num * topk) > (num_cu * 8) ? 2 : 8; + int sub_block_m = (token_num * topk) > (num_cu * 8) || num_experts < 64 ? 2 : 8; TORCH_CHECK(block_m % sub_block_m == 0, __func__, " block_m is not divisible by sub_block_m"); int num_blocks = (sorted_ids.size(0) + sub_block_m - 1) / sub_block_m; const bool persistent_mode = false; From a5da769ddc2400b5d1bf87d1b5ba52f6a4de0a85 Mon Sep 17 00:00:00 2001 From: chenjun Date: Tue, 7 Apr 2026 05:44:48 -0500 Subject: [PATCH 4/8] add mxfp4_moe_sort_hip --- aiter/ops/quant.py | 19 +++- csrc/include/quant.h | 9 ++ csrc/include/rocm_ops.hpp | 10 ++ csrc/kernels/quant_kernels.cu | 156 ++++++++++++++++++++++++++++- op_tests/test_moe_sorting_mxfp4.py | 108 ++++++++++++++++---- 5 files changed, 275 insertions(+), 27 deletions(-) diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index 848d402efc..d0170bc7ef 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -577,7 +577,7 @@ def fused_dynamic_mxfp4_quant_moe_sort( M, N = input.view(-1, input.shape[-1]).shape out = torch.empty(M, N // 2, dtype=dtypes.fp4x2, device=input.device) scales = torch.empty( - sorted_ids.shape[0], + ((sorted_ids.shape[0] + 31) // 32 * 32, N // 32), (N + 31) // 32, dtype=dtypes.fp8_e8m0, device=input.device, @@ -588,6 +588,23 @@ def fused_dynamic_mxfp4_quant_moe_sort( return out, scales +@compile_ops("module_quant") +def mxfp4_moe_sort_hip( + out_scale: torch.Tensor, + scale: torch.Tensor, + sorted_ids: torch.Tensor, + num_valid_ids: torch.Tensor, + token_num: int, + cols: int, + topk: int, + block_m: int, +) -> None: + """ + MoE scale sorting with MXFP4 shuffle layout. + """ + ... + + @compile_ops("module_quant") def partial_transpose( out: Tensor, diff --git a/csrc/include/quant.h b/csrc/include/quant.h index eb658927b7..92de857397 100644 --- a/csrc/include/quant.h +++ b/csrc/include/quant.h @@ -75,4 +75,13 @@ void fused_dynamic_mxfp4_quant_moe_sort_hip(torch::Tensor& out, // [toke int topk, int block_m, int group_size = 32); + +void mxfp4_moe_sort_hip(torch::Tensor& out_scale, + torch::Tensor const& scale, + torch::Tensor const& sorted_ids, + torch::Tensor const& num_valid_ids, + int token_num, + int cols, + int topk, + int block_m); } // namespace aiter diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 3fcfd7e836..24f0621eed 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1332,6 +1332,16 @@ namespace py = pybind11; py::arg("topk"), \ py::arg("block_m"), \ py::arg("group_size") = 32); \ + m.def("mxfp4_moe_sort_hip", \ + &aiter::mxfp4_moe_sort_hip, \ + py::arg("out_scale"), \ + py::arg("scale"), \ + py::arg("sorted_ids"), \ + py::arg("num_valid_ids"), \ + py::arg("token_num"), \ + py::arg("cols"), \ + py::arg("topk"), \ + py::arg("block_m")); \ m.def("partial_transpose", \ &aiter::partial_transpose, \ py::arg("out"), \ diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index 68fa219f6a..0efe7f5969 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -1740,8 +1740,8 @@ __global__ void mxfp4_quant_moe_sort_kernel( AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "mxfp4_quant_moe_sort_kernel", [&] { \ AITER_CHECK(group_size % THREAD_DATA == 0, __func__, " group_size is not divisible by THREAD_DATA"); \ using input_dtype = typename t2opus::type; \ - int warps_per_cu = 8 * BLOCK_SIZE / WARP_SIZE; \ - int num_tg = persistent_mode ? num_cu * warps_per_cu : num_blocks; \ + int blocks_per_cu = 8 * (BLOCK_SIZE / WARP_SIZE); \ + int num_tg = persistent_mode ? num_cu * blocks_per_cu : num_blocks; \ dim3 const grid(num_tg); \ mxfp4_quant_moe_sort_kernel \ <<>>( \ @@ -1763,7 +1763,11 @@ __global__ void mxfp4_quant_moe_sort_kernel( #define MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(DTYPE_O, cols_) \ - if(cols_ <= 4 * BlockSize) \ + if(cols_ <= 2 * BlockSize) \ + { \ + MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, 8, BlockSize / 4) \ + } \ + else if(cols_ <= 4 * BlockSize) \ { \ MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, 8, BlockSize / 2) \ } \ @@ -1800,7 +1804,7 @@ void fused_dynamic_mxfp4_quant_moe_sort_hip( int num_experts = (sorted_ids.size(0) + topk - topk * token_num) / block_m; const int num_cu = get_num_cu_func(); - int sub_block_m = (token_num * topk) > (num_cu * 8) || num_experts < 64 ? 2 : 8; + int sub_block_m = (token_num * topk) > (num_cu * 8) || num_experts < 64 ? 2 : 4; TORCH_CHECK(block_m % sub_block_m == 0, __func__, " block_m is not divisible by sub_block_m"); int num_blocks = (sorted_ids.size(0) + sub_block_m - 1) / sub_block_m; const bool persistent_mode = false; @@ -1823,4 +1827,148 @@ void fused_dynamic_mxfp4_quant_moe_sort_hip( #endif } +template +__global__ void mxfp4_moe_sort_kernel( + uint8_t* __restrict__ out_scale, + uint8_t* __restrict__ scale, + int32_t const* __restrict__ sorted_ids, + int32_t const* __restrict__ num_valid_ids, + const int32_t num_tokens, + const int32_t cols, + const int32_t num_blocks, + const int32_t num_tg, + const int32_t topk) +{ + constexpr int threads_per_row = block_size / num_rows; + int num_valid_ids_value = num_valid_ids[0]; + int block_idx = blockIdx.x; + int row_i = threadIdx.x / threads_per_row; + int scale_k = threadIdx.x % threads_per_row * thread_data_size; + const int scale_per_row = (cols + group_size - 1) / group_size; + static constexpr int32_t vec_size_i = thread_data_size; + static constexpr int32_t load_chunk_bytes = + (sizeof(uint8_t) * vec_size_i % 16 == 0 ? 16 + : (sizeof(uint8_t) * vec_size_i % 8 == 0 ? 8 + : (sizeof(uint8_t) * vec_size_i % 4 == 0 ? 4 : 2))); + using vec_i = opus::vector_t; + const int32_t scaleN_valid = (cols + group_size - 1) / group_size; + const int32_t scaleN_pad = ((scaleN_valid + 7) / 8) * 8; + auto fp4_scale_shuffle_id = [](int32_t scaleN_pad_, int32_t x, int32_t y) { + return (x / 32 * scaleN_pad_) * 32 + (y / 8) * 256 + (y % 4) * 64 + + (x % 16) * 4 + (y % 8) / 4 * 2 + (x % 32) / 16; + }; + auto buffer_scale = + opus::make_gmem(scale, scale_per_row * num_tokens * topk * sizeof(uint8_t)); + for(; block_idx < num_blocks; block_idx += num_tg) + { + int sorted_row = block_idx * num_rows + row_i; + int token_id_info = num_tokens; + if (sorted_row < num_valid_ids_value) + { + token_id_info = sorted_ids[sorted_row]; + } + int token_idx = token_id_info & 0xFFFFFF; + int topk_id = token_id_info >> 24; + if(token_idx < num_tokens && (topk == 1 || topk_id < topk)) + { + int64_t scale_offset; + if (topk == 1) + { + scale_offset = (int64_t)(token_idx) * scale_per_row; + } + else + { + scale_offset = (int64_t)(token_idx * topk + topk_id) * scale_per_row; + } + vec_i vec_scale = load_vector_nbytes( + buffer_scale, scale_offset + scale_k); + + for(int j = 0; j < vec_size_i; j++) + { + if((scale_k + j) < scaleN_valid) + { + int addr = fp4_scale_shuffle_id(scaleN_pad, sorted_row, scale_k + j); + out_scale[addr] = vec_scale[j]; + } + } + } + } +} + + +#define MXFP4_MOE_SORT_KERNEL_IMPL(MAX_COL, THREAD_DATA, BLOCK_SIZE) \ + constexpr int GROUP_SIZE = 32; \ + constexpr int NUM_ROWS = BLOCK_SIZE / (MAX_COL /(GROUP_SIZE * THREAD_DATA)); \ + TORCH_CHECK(BLOCK_SIZE % (MAX_COL /(GROUP_SIZE * THREAD_DATA)) == 0); \ + int num_blocks = (sorted_ids.size(0) + NUM_ROWS - 1) / NUM_ROWS; \ + int blocks_per_cu = 8 * (BLOCK_SIZE / WARP_SIZE); \ + int num_tg = persistent_mode ? num_cu * blocks_per_cu : num_blocks; \ + dim3 const grid(num_tg); \ + mxfp4_moe_sort_kernel \ + <<>>( \ + reinterpret_cast(out_scale.data_ptr()), \ + reinterpret_cast(scale.data_ptr()), \ + sorted_ids.data_ptr(), \ + num_valid_ids.data_ptr(), \ + token_num, cols, num_blocks, num_tg, topk); + + +#define MXFP4_MOE_SORT_KERNEL_DISPATCH(cols_) \ + if(cols_ <= 256) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(256, 4, 256) \ + } \ + else if(cols_ <= 512) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(512, 4, 256) \ + } \ + else if(cols_ <= 1024) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(1024, 4, 256) \ + } \ + else if(cols_ <= 2048) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(2048, 8, 256) \ + } \ + else if(cols_ <= 4096) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(4096, 16, 256) \ + } \ + else if(cols_ <= 6144) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(6144, 24, 256) \ + } \ + else if(cols_ <= 8192) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(8192, 32, 256) \ + } \ + else if(cols_ <= 16384) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(16384, 32, 256) \ + } \ + else \ + { \ + TORCH_CHECK(false, "input last dim has exceeded the maximum value ", 16384) \ + } + +void mxfp4_moe_sort_hip( + torch::Tensor& out_scale, + torch::Tensor const& scale, + torch::Tensor const& sorted_ids, + torch::Tensor const& num_valid_ids, + int token_num, + int cols, + int topk, + int block_m +) +{ + const int num_cu = get_num_cu_func(); + const bool persistent_mode = false; + + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(scale)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); + + MXFP4_MOE_SORT_KERNEL_DISPATCH(cols); +} + } // namespace aiter \ No newline at end of file diff --git a/op_tests/test_moe_sorting_mxfp4.py b/op_tests/test_moe_sorting_mxfp4.py index 2bad7489be..9bf7167218 100644 --- a/op_tests/test_moe_sorting_mxfp4.py +++ b/op_tests/test_moe_sorting_mxfp4.py @@ -5,6 +5,7 @@ from aiter.test_common import checkAllclose, benchmark, run_perftest from aiter.fused_moe import moe_sorting, fused_topk from aiter.ops.triton.quant.fused_mxfp4_quant import fused_dynamic_mxfp4_quant_moe_sort +from aiter.jit.utils.chip_info import get_gfx from aiter import get_torch_quant, dtypes from aiter.utility import fp4_utils import pandas as pd @@ -43,15 +44,18 @@ def run_torch(scale, sorted_ids, num_valid_ids, token_num): return ref -def run_hip_quant_sort(scale, input, sorted_ids, num_valid_ids, topk, block_size): - M, N = input.shape - out = torch.empty( - (M, N // 2), - dtype=dtypes.fp4x2, - device=input.device, - ) - aiter.fused_dynamic_mxfp4_quant_moe_sort_hip( - out, scale, input, sorted_ids, num_valid_ids, topk, block_size +def run_split_quant_sort(scale, input, sorted_ids, num_valid_ids, topk, block_size): + out, scale_ = aiter.per_1x32_f4_quant_hip(input, None, dtypes.fp4x2) + token_num = input.numel() // (input.shape[-1] * topk) + aiter.mxfp4_moe_sort_hip( + scale, + scale_, + sorted_ids, + num_valid_ids, + token_num, + input.shape[-1], + topk, + block_size, ) return out, scale @@ -73,11 +77,12 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) if stage == "stage1": scale = torch.arange(token_num * model_dim // 32, dtype=torch.uint8) scale = scale.view(token_num, model_dim // 32) + topk = 1 else: scale = torch.arange(token_num * topk * model_dim // 32, dtype=torch.uint8) scale = scale.view(token_num, topk, model_dim // 32) ref = run_torch(scale.clone(), sorted_ids.clone(), num_valid_ids, token_num) - sorted_mxfp4_scale, us = run_perftest( + triton_scale, triton_us = run_perftest( fp4_utils.moe_mxfp4_sort, scale, sorted_ids, @@ -86,19 +91,49 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) block_size, ) + hip_scale = torch.zeros( + ((sorted_ids.shape[0] + 31) // 32 * 32, model_dim // 32), + dtype=torch.uint8, + device=input.device, + ) + _, hip_us = run_perftest( + aiter.mxfp4_moe_sort_hip, + hip_scale, + scale, + sorted_ids, + num_valid_ids, + token_num, + model_dim, + topk, + block_size, + ) + num_valid_ids = num_valid_ids.item() num_valid_ids = (num_valid_ids + block_size - 1) // block_size * block_size - err = checkAllclose( + triton_err = checkAllclose( ref[:num_valid_ids], - sorted_mxfp4_scale[:num_valid_ids].view(torch.uint8), + triton_scale[:num_valid_ids].view(torch.uint8), msg="sorted_mxfp4_scale", ) - return {"us": us, "err": err} + + hip_err = checkAllclose( + ref[:num_valid_ids].view(torch.uint8), + hip_scale[:num_valid_ids].view(torch.uint8), + msg="hip sorted_mxfp4_scale", + ) + return { + "triton_us": triton_us, + "triton_err": triton_err, + "hip_us": hip_us, + "hip_err": hip_err, + } @benchmark() def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, stage): + if get_gfx().startswith("gfx94"): + return {} input = torch.randn((token_num, model_dim), dtype=dtype) score = torch.randn((token_num, E), dtype=dtype) @@ -120,13 +155,34 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, ) ref_scale = run_torch(scale.clone(), sorted_ids.clone(), num_valid_ids, token_num) + split_scale = torch.zeros( + ((sorted_ids.shape[0] + 31) // 32 * 32, model_dim // 32), + dtype=torch.uint8, + device=input.device, + ) + (split_out, split_scale), split_us = run_perftest( + run_split_quant_sort, + split_scale, + input, + sorted_ids, + num_valid_ids, + topk, + block_size, + ) + hip_scale = torch.zeros( - (sorted_ids.shape[0], model_dim // 32), + ((sorted_ids.shape[0] + 31) // 32 * 32, model_dim // 32), dtype=torch.uint8, device=input.device, ) - (hip_out, hip_scale), hip_us = run_perftest( - run_hip_quant_sort, + hip_out = torch.empty( + (token_num * topk, model_dim // 2), + dtype=dtypes.fp4x2, + device=input.device, + ) + _, hip_us = run_perftest( + aiter.fused_dynamic_mxfp4_quant_moe_sort_hip, + hip_out, hip_scale, input, sorted_ids, @@ -165,11 +221,19 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, msg="triton sorted_mxfp4_scale", ) + split_err = checkAllclose( + ref_scale[:num_valid_ids].view(torch.uint8), + split_scale[:num_valid_ids].view(torch.uint8), + msg="split sorted_mxfp4_scale", + ) + return { - "us_triton": triton_us, - "err_triton": triton_err, - "us_hip": hip_us, - "err_hip": hip_err, + "triton_us": triton_us, + "triton_err": triton_err, + "hip_us": hip_us, + "hip_err": hip_err, + "split_us": split_us, + "split_err": split_err, } @@ -192,7 +256,7 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, "-dim", type=int, nargs="*", - default=[4096, 7168, 8192], + default=[4096, 7168], help="""Model dimension. e.g.: -dim 4096""", ) @@ -209,7 +273,7 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, "-m", type=int, nargs="*", - default=[1, 31, 64, 128, 256, 4200, 10000, 163840], + default=[1, 64, 128, 256, 1024, 2050, 4200, 10000, 163840], help="""M of mnk. e.g.: -m 64""", ) From b17d3996f58280a7b4ce86c9bf160ae05e67cf2d Mon Sep 17 00:00:00 2001 From: chenjun Date: Tue, 7 Apr 2026 11:07:11 -0500 Subject: [PATCH 5/8] add dispatch to choose the fused kernel or not in aiter.fused_dynamic_mxfp4_quant_moe_sort --- aiter/fused_moe.py | 51 ++++++------------------------ aiter/ops/quant.py | 36 +++++++++++++++------ op_tests/test_moe_sorting_mxfp4.py | 22 +++++++++---- 3 files changed, 51 insertions(+), 58 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 923e6149b1..4fb5538944 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -1144,7 +1144,6 @@ def fused_moe_2stages( bias2=None, ): quant_func = get_quant(quant_type) - token_num_quant_moe_sort_switch = 1024 token_num, _ = hidden_states.shape E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape) dtype = moe_out.dtype @@ -1203,27 +1202,13 @@ def fused_moe_2stages( token_num=token_num, block_size=block_size_M, ) - elif token_num <= token_num_quant_moe_sort_switch: + else: a1, a1_scale = fused_dynamic_mxfp4_quant_moe_sort( hidden_states, sorted_ids=sorted_ids, num_valid_ids=num_valid_ids, token_num=token_num, - topk=1, - block_size=block_size_M, - ) - else: - a1, a1_scale = quant_func( - hidden_states, - scale=a1_scale, - quant_dtype=q_dtype_a, - num_rows=num_local_tokens, - ) - a1_scale = fp4_utils.moe_mxfp4_sort( - a1_scale, - sorted_ids=sorted_ids, - num_valid_ids=num_valid_ids, - token_num=token_num, + topk=topk, block_size=block_size_M, ) elif hidden_states.dtype != q_dtype_a: @@ -1312,30 +1297,14 @@ def fused_moe_2stages( a2_scale = a1_scale elif quant_type == QuantType.per_1x32: a2 = a2.view(-1, inter_dim) - if token_num <= token_num_quant_moe_sort_switch: - a2, a2_scale = fused_dynamic_mxfp4_quant_moe_sort( - a2, - sorted_ids=sorted_ids, - num_valid_ids=num_valid_ids, - token_num=token_num, - topk=topk, - block_size=block_size_M, - ) - else: - a2, a2_scale = quant_func( - a2, - scale=a2_scale, - quant_dtype=q_dtype_a, - num_rows=num_local_tokens, - num_rows_factor=topk, - ) - a2_scale = fp4_utils.moe_mxfp4_sort( - a2_scale[: token_num * topk, :].view(token_num, topk, -1), - sorted_ids=sorted_ids, - num_valid_ids=num_valid_ids, - token_num=token_num, - block_size=block_size_M, - ) + a2, a2_scale = fused_dynamic_mxfp4_quant_moe_sort( + a2, + sorted_ids=sorted_ids, + num_valid_ids=num_valid_ids, + token_num=token_num, + topk=topk, + block_size=block_size_M, + ) a2 = a2.view(token_num, topk, -1) elif quant_type == QuantType.per_1x128 and metadata.stage1.func is asm_stage1: a2_v = a2[:token_num, :, :] diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index d0170bc7ef..2a1ba26466 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -555,7 +555,7 @@ def fused_dynamic_mxfp4_quant_moe_sort_hip( input: torch.Tensor, sorted_ids: torch.Tensor, num_valid_ids: torch.Tensor, - topk: int, + topk: int, # stage1 : 1, stage2 : topk block_m: int, group_size: int = 32, ) -> None: @@ -570,22 +570,38 @@ def fused_dynamic_mxfp4_quant_moe_sort( sorted_ids: torch.Tensor, num_valid_ids: torch.Tensor, token_num: int, - topk: int, + topk: int, # stage1 and stage2: same topk value block_size: int, group_size: int = 32, ) -> Tuple[torch.Tensor, torch.Tensor]: + token_num_quant_moe_sort_switch = [ + 8 * 64 / topk, + 8 * 1024 / topk, + ] # [stage1, stage2] M, N = input.view(-1, input.shape[-1]).shape - out = torch.empty(M, N // 2, dtype=dtypes.fp4x2, device=input.device) - scales = torch.empty( - ((sorted_ids.shape[0] + 31) // 32 * 32, N // 32), + is_stage1 = M == token_num + topk = 1 if is_stage1 else topk + scale = torch.empty( + (sorted_ids.shape[0] + 31) // 32 * 32, (N + 31) // 32, dtype=dtypes.fp8_e8m0, device=input.device, ) - fused_dynamic_mxfp4_quant_moe_sort_hip( - out, scales, input, sorted_ids, num_valid_ids, topk, block_size, group_size - ) - return out, scales + if ( + (is_stage1 and M <= token_num_quant_moe_sort_switch[0]) + or (not is_stage1 and M <= token_num_quant_moe_sort_switch[1]) + or group_size != 32 + ): + out = torch.empty(M, N // 2, dtype=dtypes.fp4x2, device=input.device) + fused_dynamic_mxfp4_quant_moe_sort_hip( + out, scale, input, sorted_ids, num_valid_ids, topk, block_size, group_size + ) + else: + out, scale_ = per_1x32_f4_quant_hip(input, None, dtypes.fp4x2) + mxfp4_moe_sort_hip( + scale, scale_, sorted_ids, num_valid_ids, token_num, N, topk, block_size + ) + return out, scale @compile_ops("module_quant") @@ -596,7 +612,7 @@ def mxfp4_moe_sort_hip( num_valid_ids: torch.Tensor, token_num: int, cols: int, - topk: int, + topk: int, # stage1 : 1, stage2 : topk block_m: int, ) -> None: """ diff --git a/op_tests/test_moe_sorting_mxfp4.py b/op_tests/test_moe_sorting_mxfp4.py index 9bf7167218..cb3773183e 100644 --- a/op_tests/test_moe_sorting_mxfp4.py +++ b/op_tests/test_moe_sorting_mxfp4.py @@ -253,12 +253,20 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, e.g.: -d bf16""", ) parser.add_argument( - "-dim", + "-dim1", type=int, nargs="*", default=[4096, 7168], - help="""Model dimension. - e.g.: -dim 4096""", + help="""Model dimension for stage1. + e.g.: -dim1 4096""", +) +parser.add_argument( + "-dim2", + type=int, + nargs="*", + default=[256, 2048], + help="""Inter dimension for stage2. + e.g.: -dim2 256""", ) parser.add_argument( "-ek", @@ -295,7 +303,7 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, dim, (E, topk), m, - ) in itertools.product(args.dim, args.expert_topk, args.m): + ) in itertools.product(args.dim1, args.expert_topk, args.m): ret = test_moe_mxfp4_sort(dtype, m, dim, E, topk, args.block_m, "stage1") df.append(ret) df = pd.DataFrame(df) @@ -308,7 +316,7 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, dim, (E, topk), m, - ) in itertools.product(args.dim, args.expert_topk, args.m): + ) in itertools.product(args.dim2, args.expert_topk, args.m): ret = test_moe_mxfp4_sort(dtype, m, dim, E, topk, args.block_m, "stage2") df.append(ret) df = pd.DataFrame(df) @@ -321,7 +329,7 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, dim, (E, topk), m, - ) in itertools.product(args.dim, args.expert_topk, args.m): + ) in itertools.product(args.dim1, args.expert_topk, args.m): ret = test_moe_mxfp4_quant_sort(dtype, m, dim, E, topk, args.block_m, "stage1") df.append(ret) df = pd.DataFrame(df) @@ -334,7 +342,7 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, dim, (E, topk), m, - ) in itertools.product(args.dim, args.expert_topk, args.m): + ) in itertools.product(args.dim2, args.expert_topk, args.m): ret = test_moe_mxfp4_quant_sort(dtype, m, dim, E, topk, args.block_m, "stage2") df.append(ret) df = pd.DataFrame(df) From 3b3257f78cac2dc3efed3f1b5c8c0b76512e4b8a Mon Sep 17 00:00:00 2001 From: chenjun Date: Tue, 7 Apr 2026 22:16:29 -0500 Subject: [PATCH 6/8] rm topk in api and use mxfp4_moe_sort_fwd instead of fp4_utils.moe_mxfp4_sort in fused_moe --- aiter/fused_moe.py | 18 +++---- aiter/ops/quant.py | 77 +++++++++++++++++++----------- csrc/include/quant.h | 6 +-- csrc/include/rocm_ops.hpp | 6 +-- csrc/kernels/quant_kernels.cu | 11 ++--- op_tests/test_moe_sorting_mxfp4.py | 17 +++---- 6 files changed, 75 insertions(+), 60 deletions(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 4fb5538944..055c8f4382 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -17,7 +17,7 @@ from aiter.jit.utils.chip_info import get_cu_num, get_gfx from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.flydsl.utils import is_flydsl_available -from aiter import fused_dynamic_mxfp4_quant_moe_sort +from aiter import fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd from aiter.utility import fp4_utils BLOCK_SIZE_M = 32 @@ -454,12 +454,12 @@ def fused_moe_1stage( token_num = hidden_states.shape[0] E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape) if quant_type == QuantType.per_1x32: - a1_scale = fp4_utils.moe_mxfp4_sort( + a1_scale = mxfp4_moe_sort_fwd( a1_scale, - sorted_ids, - num_valid_ids, - token_num, - block_size_M, + sorted_ids=sorted_ids, + num_valid_ids=num_valid_ids, + token_num=token_num, + cols=model_dim, ) w1_scale = w1_scale.view(E, -1) w2_scale = w2_scale.view(E, -1) @@ -1195,12 +1195,12 @@ def fused_moe_2stages( # Input is already quantized to fp4x2 (e.g., from FP4 dispatch), # skip re-quantization, only sort the scale a1 = hidden_states - a1_scale = fp4_utils.moe_mxfp4_sort( + a1_scale = mxfp4_moe_sort_fwd( a1_scale, sorted_ids=sorted_ids, num_valid_ids=num_valid_ids, token_num=token_num, - block_size=block_size_M, + cols=model_dim, ) else: a1, a1_scale = fused_dynamic_mxfp4_quant_moe_sort( @@ -1210,6 +1210,7 @@ def fused_moe_2stages( token_num=token_num, topk=topk, block_size=block_size_M, + num_rows=num_local_tokens, ) elif hidden_states.dtype != q_dtype_a: if quant_type == QuantType.per_1x128 and metadata.stage1.func is asm_stage1: @@ -1304,6 +1305,7 @@ def fused_moe_2stages( token_num=token_num, topk=topk, block_size=block_size_M, + num_rows=num_local_tokens, ) a2 = a2.view(token_num, topk, -1) elif quant_type == QuantType.per_1x128 and metadata.stage1.func is asm_stage1: diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py index 2a1ba26466..696118e2f6 100644 --- a/aiter/ops/quant.py +++ b/aiter/ops/quant.py @@ -310,7 +310,7 @@ def per_1x32_f4_quant_hip( torch.empty( ( (m + 255) // 256 * 256, - (n // 32 + 7) // 8 * 8, + ((n + 31) // 32 + 7) // 8 * 8, ), dtype=torch.uint8, device=device, @@ -321,7 +321,7 @@ def per_1x32_f4_quant_hip( else: scale = ( torch.empty( - (m, n // 32), + (m, (n + 31) // 32), dtype=torch.uint8, device=device, ) @@ -548,6 +548,38 @@ def moe_smooth_per_token_scaled_quant_v2( ... +@compile_ops("module_quant") +def mxfp4_moe_sort_hip( + out_scale: torch.Tensor, + scale: torch.Tensor, + sorted_ids: torch.Tensor, + num_valid_ids: torch.Tensor, + token_num: int, + cols: int, +) -> None: + """ + MoE scale sorting with MXFP4 shuffle layout. + """ + ... + + +def mxfp4_moe_sort_fwd( + scale: torch.Tensor, + sorted_ids: torch.Tensor, + num_valid_ids: torch.Tensor, + token_num: int, + cols: int, +): + out_scale = torch.empty( + (sorted_ids.shape[0] + 31) // 32 * 32, + (cols + 31) // 32, + dtype=dtypes.fp8_e8m0, + device=scale.device, + ) + mxfp4_moe_sort_hip(out_scale, scale, sorted_ids, num_valid_ids, token_num, cols) + return out_scale + + @compile_ops("module_quant") def fused_dynamic_mxfp4_quant_moe_sort_hip( out: torch.Tensor, @@ -555,7 +587,7 @@ def fused_dynamic_mxfp4_quant_moe_sort_hip( input: torch.Tensor, sorted_ids: torch.Tensor, num_valid_ids: torch.Tensor, - topk: int, # stage1 : 1, stage2 : topk + token_num: int, block_m: int, group_size: int = 32, ) -> None: @@ -572,12 +604,13 @@ def fused_dynamic_mxfp4_quant_moe_sort( token_num: int, topk: int, # stage1 and stage2: same topk value block_size: int, + num_rows: Optional[torch.Tensor] = None, group_size: int = 32, ) -> Tuple[torch.Tensor, torch.Tensor]: token_num_quant_moe_sort_switch = [ - 8 * 64 / topk, - 8 * 1024 / topk, - ] # [stage1, stage2] + 8 * 64 / topk, # stage1 + 8 * 1024 / topk, # stage2 + ] M, N = input.view(-1, input.shape[-1]).shape is_stage1 = M == token_num topk = 1 if is_stage1 else topk @@ -594,33 +627,23 @@ def fused_dynamic_mxfp4_quant_moe_sort( ): out = torch.empty(M, N // 2, dtype=dtypes.fp4x2, device=input.device) fused_dynamic_mxfp4_quant_moe_sort_hip( - out, scale, input, sorted_ids, num_valid_ids, topk, block_size, group_size + out, + scale, + input, + sorted_ids, + num_valid_ids, + token_num, + block_size, + group_size, ) else: - out, scale_ = per_1x32_f4_quant_hip(input, None, dtypes.fp4x2) - mxfp4_moe_sort_hip( - scale, scale_, sorted_ids, num_valid_ids, token_num, N, topk, block_size + out, scale_ = per_1x32_f4_quant_hip( + input, None, dtypes.fp4x2, num_rows=num_rows, num_rows_factor=topk ) + mxfp4_moe_sort_hip(scale, scale_, sorted_ids, num_valid_ids, token_num, N) return out, scale -@compile_ops("module_quant") -def mxfp4_moe_sort_hip( - out_scale: torch.Tensor, - scale: torch.Tensor, - sorted_ids: torch.Tensor, - num_valid_ids: torch.Tensor, - token_num: int, - cols: int, - topk: int, # stage1 : 1, stage2 : topk - block_m: int, -) -> None: - """ - MoE scale sorting with MXFP4 shuffle layout. - """ - ... - - @compile_ops("module_quant") def partial_transpose( out: Tensor, diff --git a/csrc/include/quant.h b/csrc/include/quant.h index 92de857397..4d7d425261 100644 --- a/csrc/include/quant.h +++ b/csrc/include/quant.h @@ -72,7 +72,7 @@ void fused_dynamic_mxfp4_quant_moe_sort_hip(torch::Tensor& out, // [toke torch::Tensor const& input, // [token_num * topk, d] torch::Tensor const& sorted_ids, torch::Tensor const& num_valid_ids, - int topk, + int token_num, int block_m, int group_size = 32); @@ -81,7 +81,5 @@ void mxfp4_moe_sort_hip(torch::Tensor& out_scale, torch::Tensor const& sorted_ids, torch::Tensor const& num_valid_ids, int token_num, - int cols, - int topk, - int block_m); + int cols); } // namespace aiter diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 24f0621eed..8bb7df5f25 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1329,7 +1329,7 @@ namespace py = pybind11; py::arg("input"), \ py::arg("sorted_ids"), \ py::arg("num_valid_ids"), \ - py::arg("topk"), \ + py::arg("token_num"), \ py::arg("block_m"), \ py::arg("group_size") = 32); \ m.def("mxfp4_moe_sort_hip", \ @@ -1339,9 +1339,7 @@ namespace py = pybind11; py::arg("sorted_ids"), \ py::arg("num_valid_ids"), \ py::arg("token_num"), \ - py::arg("cols"), \ - py::arg("topk"), \ - py::arg("block_m")); \ + py::arg("cols")); \ m.def("partial_transpose", \ &aiter::partial_transpose, \ py::arg("out"), \ diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index 0efe7f5969..41ac55ef75 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -1794,13 +1794,13 @@ void fused_dynamic_mxfp4_quant_moe_sort_hip( torch::Tensor const& input, torch::Tensor const& sorted_ids, torch::Tensor const& num_valid_ids, - int topk, + int token_num, int block_m, int group_size = 32 ) { int cols = input.size(-1); - int token_num = input.numel() / (cols * topk); + int topk = input.numel() / (cols * token_num); int num_experts = (sorted_ids.size(0) + topk - topk * token_num) / block_m; const int num_cu = get_num_cu_func(); @@ -1957,14 +1957,13 @@ void mxfp4_moe_sort_hip( torch::Tensor const& sorted_ids, torch::Tensor const& num_valid_ids, int token_num, - int cols, - int topk, - int block_m + int cols ) { const int num_cu = get_num_cu_func(); const bool persistent_mode = false; - + int topk = scale.numel() / ((cols + 31) / 32 * token_num); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(scale)); const hipStream_t stream = at::hip::getCurrentHIPStream(); diff --git a/op_tests/test_moe_sorting_mxfp4.py b/op_tests/test_moe_sorting_mxfp4.py index cb3773183e..2beb79a674 100644 --- a/op_tests/test_moe_sorting_mxfp4.py +++ b/op_tests/test_moe_sorting_mxfp4.py @@ -44,18 +44,16 @@ def run_torch(scale, sorted_ids, num_valid_ids, token_num): return ref -def run_split_quant_sort(scale, input, sorted_ids, num_valid_ids, topk, block_size): +def run_split_quant_sort(scale, input, sorted_ids, num_valid_ids, token_num): + model_dim = input.shape[-1] out, scale_ = aiter.per_1x32_f4_quant_hip(input, None, dtypes.fp4x2) - token_num = input.numel() // (input.shape[-1] * topk) aiter.mxfp4_moe_sort_hip( scale, scale_, sorted_ids, num_valid_ids, token_num, - input.shape[-1], - topk, - block_size, + model_dim, ) return out, scale @@ -104,8 +102,6 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) num_valid_ids, token_num, model_dim, - topk, - block_size, ) num_valid_ids = num_valid_ids.item() @@ -166,8 +162,7 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, input, sorted_ids, num_valid_ids, - topk, - block_size, + token_num, ) hip_scale = torch.zeros( @@ -187,7 +182,7 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, input, sorted_ids, num_valid_ids, - topk, + token_num, block_size, ) @@ -275,7 +270,7 @@ def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, block_size, nargs="*", default=[[32, 5], [256, 8], [512, 8]], help="""Number of experts. - e.g.: -e 32,5""", + e.g.: -ek 32,5""", ) parser.add_argument( "-m", From 5684160972c35a00753662e98a67b41183719295 Mon Sep 17 00:00:00 2001 From: chenjun Date: Tue, 7 Apr 2026 22:23:18 -0500 Subject: [PATCH 7/8] format --- aiter/fused_moe.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 055c8f4382..4700c6fe71 100755 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -18,7 +18,6 @@ from aiter.jit.utils.torch_guard import torch_compile_guard from aiter.ops.flydsl.utils import is_flydsl_available from aiter import fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd -from aiter.utility import fp4_utils BLOCK_SIZE_M = 32 From b3581aa622cd7cd17b3abacb9daef49e2bb5de62 Mon Sep 17 00:00:00 2001 From: chenjun Date: Wed, 8 Apr 2026 01:22:22 -0500 Subject: [PATCH 8/8] update --- csrc/kernels/quant_kernels.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index 41ac55ef75..a62243f5a2 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -1508,9 +1508,9 @@ __global__ void moe_smooth_per_token_scaled_quant_kernel_v2(DTYPE_O* __restrict_ #define MOE_SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_V2_IMPL(quant_kernel, DTYPE_O, THREAD_DATA, BLOCK_SIZE) \ AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "quant_kernel", [&] { \ - using input_dtype = typename t2opus::type; \ - int warps_per_cu = 8 * BLOCK_SIZE / WARP_SIZE; \ - int num_tg = persistent_mode? num_cu * warps_per_cu : num_blocks; \ + using input_dtype = typename t2opus::type; \ + int blocks_per_cu = 8 * 4 / (BLOCK_SIZE / WARP_SIZE); \ + int num_tg = persistent_mode ? num_cu * blocks_per_cu : num_blocks; \ dim3 const grid(num_tg); \ aiter::quant_kernel \ <<>>( \ @@ -1740,7 +1740,7 @@ __global__ void mxfp4_quant_moe_sort_kernel( AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "mxfp4_quant_moe_sort_kernel", [&] { \ AITER_CHECK(group_size % THREAD_DATA == 0, __func__, " group_size is not divisible by THREAD_DATA"); \ using input_dtype = typename t2opus::type; \ - int blocks_per_cu = 8 * (BLOCK_SIZE / WARP_SIZE); \ + int blocks_per_cu = 8 * 4 / (BLOCK_SIZE / WARP_SIZE); \ int num_tg = persistent_mode ? num_cu * blocks_per_cu : num_blocks; \ dim3 const grid(num_tg); \ mxfp4_quant_moe_sort_kernel \ @@ -1901,7 +1901,7 @@ __global__ void mxfp4_moe_sort_kernel( constexpr int NUM_ROWS = BLOCK_SIZE / (MAX_COL /(GROUP_SIZE * THREAD_DATA)); \ TORCH_CHECK(BLOCK_SIZE % (MAX_COL /(GROUP_SIZE * THREAD_DATA)) == 0); \ int num_blocks = (sorted_ids.size(0) + NUM_ROWS - 1) / NUM_ROWS; \ - int blocks_per_cu = 8 * (BLOCK_SIZE / WARP_SIZE); \ + int blocks_per_cu = 8 * 4 / (BLOCK_SIZE / WARP_SIZE); \ int num_tg = persistent_mode ? num_cu * blocks_per_cu : num_blocks; \ dim3 const grid(num_tg); \ mxfp4_moe_sort_kernel \