diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py
index 53ed0e0d63..4700c6fe71 100755
--- a/aiter/fused_moe.py
+++ b/aiter/fused_moe.py
@@ -17,8 +17,7 @@
 from aiter.jit.utils.chip_info import get_cu_num, get_gfx
 from aiter.jit.utils.torch_guard import torch_compile_guard
 from aiter.ops.flydsl.utils import is_flydsl_available
-from aiter.ops.triton.quant.fused_mxfp4_quant import fused_dynamic_mxfp4_quant_moe_sort
-from aiter.utility import fp4_utils
+from aiter import fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd
 
 
 BLOCK_SIZE_M = 32
@@ -454,12 +453,12 @@ def fused_moe_1stage(
     token_num = hidden_states.shape[0]
     E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape)
     if quant_type == QuantType.per_1x32:
-        a1_scale = fp4_utils.moe_mxfp4_sort(
+        a1_scale = mxfp4_moe_sort_fwd(
             a1_scale,
-            sorted_ids,
-            num_valid_ids,
-            token_num,
-            block_size_M,
+            sorted_ids=sorted_ids,
+            num_valid_ids=num_valid_ids,
+            token_num=token_num,
+            cols=model_dim,
         )
         w1_scale = w1_scale.view(E, -1)
         w2_scale = w2_scale.view(E, -1)
@@ -1144,7 +1143,6 @@ def fused_moe_2stages(
     bias2=None,
 ):
     quant_func = get_quant(quant_type)
-    token_num_quant_moe_sort_switch = 1024
     token_num, _ = hidden_states.shape
     E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape)
     dtype = moe_out.dtype
@@ -1196,36 +1194,23 @@ def fused_moe_2stages(
             # Input is already quantized to fp4x2 (e.g., from FP4 dispatch),
             # skip re-quantization, only sort the scale
             a1 = hidden_states
-            a1_scale = fp4_utils.moe_mxfp4_sort(
+            a1_scale = mxfp4_moe_sort_fwd(
                 a1_scale,
                 sorted_ids=sorted_ids,
                 num_valid_ids=num_valid_ids,
                 token_num=token_num,
-                block_size=block_size_M,
+                cols=model_dim,
             )
-        elif token_num <= token_num_quant_moe_sort_switch:
+        else:
             a1, a1_scale = fused_dynamic_mxfp4_quant_moe_sort(
                 hidden_states,
                 sorted_ids=sorted_ids,
                 num_valid_ids=num_valid_ids,
                 token_num=token_num,
-                topk=1,
+                topk=topk,
                 block_size=block_size_M,
-            )
-        else:
-            a1, a1_scale = quant_func(
-                hidden_states,
-                scale=a1_scale,
-                quant_dtype=q_dtype_a,
                 num_rows=num_local_tokens,
             )
-            a1_scale = fp4_utils.moe_mxfp4_sort(
-                a1_scale,
-                sorted_ids=sorted_ids,
-                num_valid_ids=num_valid_ids,
-                token_num=token_num,
-                block_size=block_size_M,
-            )
     elif hidden_states.dtype != q_dtype_a:
         if quant_type == QuantType.per_1x128 and metadata.stage1.func is asm_stage1:
             quant_func = functools.partial(quant_func, transpose_scale=True)
@@ -1312,30 +1297,15 @@ def fused_moe_2stages(
         a2_scale = a1_scale
     elif quant_type == QuantType.per_1x32:
         a2 = a2.view(-1, inter_dim)
-        if token_num <= token_num_quant_moe_sort_switch:
-            a2, a2_scale = fused_dynamic_mxfp4_quant_moe_sort(
-                a2,
-                sorted_ids=sorted_ids,
-                num_valid_ids=num_valid_ids,
-                token_num=token_num,
-                topk=topk,
-                block_size=block_size_M,
-            )
-        else:
-            a2, a2_scale = quant_func(
-                a2,
-                scale=a2_scale,
-                quant_dtype=q_dtype_a,
-                num_rows=num_local_tokens,
-                num_rows_factor=topk,
-            )
-            a2_scale = fp4_utils.moe_mxfp4_sort(
-                a2_scale[: token_num * topk, :].view(token_num, topk, -1),
-                sorted_ids=sorted_ids,
-                num_valid_ids=num_valid_ids,
-                token_num=token_num,
-                block_size=block_size_M,
-            )
+        a2, a2_scale = fused_dynamic_mxfp4_quant_moe_sort(
+            a2,
+            sorted_ids=sorted_ids,
+            num_valid_ids=num_valid_ids,
+            token_num=token_num,
+            topk=topk,
+            block_size=block_size_M,
+            num_rows=num_local_tokens,
+        )
         a2 = a2.view(token_num, topk, -1)
     elif quant_type == QuantType.per_1x128 and metadata.stage1.func is asm_stage1:
         a2_v = a2[:token_num, :, :]
diff --git a/aiter/ops/quant.py b/aiter/ops/quant.py
index 36b0ba62cf..696118e2f6 100644
--- a/aiter/ops/quant.py
+++ b/aiter/ops/quant.py
@@ -310,7 +310,7 @@ def per_1x32_f4_quant_hip(
             torch.empty(
                 (
                     (m + 255) // 256 * 256,
-                    (n // 32 + 7) // 8 * 8,
+                    ((n + 31) // 32 + 7) // 8 * 8,
                 ),
                 dtype=torch.uint8,
                 device=device,
@@ -321,7 +321,7 @@ def per_1x32_f4_quant_hip(
     else:
         scale = (
             torch.empty(
-                (m, n // 32),
+                (m, (n + 31) // 32),
                 dtype=torch.uint8,
                 device=device,
             )
@@ -548,6 +548,102 @@ def moe_smooth_per_token_scaled_quant_v2(
     ...
 
 
+@compile_ops("module_quant")
+def mxfp4_moe_sort_hip(
+    out_scale: torch.Tensor,
+    scale: torch.Tensor,
+    sorted_ids: torch.Tensor,
+    num_valid_ids: torch.Tensor,
+    token_num: int,
+    cols: int,
+) -> None:
+    """
+    MoE scale sorting with MXFP4 shuffle layout.
+    """
+    ...
+
+
+def mxfp4_moe_sort_fwd(
+    scale: torch.Tensor,
+    sorted_ids: torch.Tensor,
+    num_valid_ids: torch.Tensor,
+    token_num: int,
+    cols: int,
+):
+    out_scale = torch.empty(
+        (sorted_ids.shape[0] + 31) // 32 * 32,
+        (cols + 31) // 32,
+        dtype=dtypes.fp8_e8m0,
+        device=scale.device,
+    )
+    mxfp4_moe_sort_hip(out_scale, scale, sorted_ids, num_valid_ids, token_num, cols)
+    return out_scale
+
+
+@compile_ops("module_quant")
+def fused_dynamic_mxfp4_quant_moe_sort_hip(
+    out: torch.Tensor,
+    scales: torch.Tensor,
+    input: torch.Tensor,
+    sorted_ids: torch.Tensor,
+    num_valid_ids: torch.Tensor,
+    token_num: int,
+    block_m: int,
+    group_size: int = 32,
+) -> None:
+    """
+    HIP path for fused dynamic MXFP4 quantization and MoE scale sorting.
+    """
+    ...
+
+
+def fused_dynamic_mxfp4_quant_moe_sort(
+    input: torch.Tensor,
+    sorted_ids: torch.Tensor,
+    num_valid_ids: torch.Tensor,
+    token_num: int,
+    topk: int,  # MoE top-k; stage1 inputs (M == token_num) are handled with topk=1 internally
+    block_size: int,
+    num_rows: Optional[torch.Tensor] = None,
+    group_size: int = 32,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    token_num_quant_moe_sort_switch = [
+        8 * 64 // topk,  # stage1
+        8 * 1024 // topk,  # stage2
+    ]
+    M, N = input.view(-1, input.shape[-1]).shape
+    is_stage1 = M == token_num
+    topk = 1 if is_stage1 else topk
+    scale = torch.empty(
+        (sorted_ids.shape[0] + 31) // 32 * 32,
+        (N + 31) // 32,
+        dtype=dtypes.fp8_e8m0,
+        device=input.device,
+    )
+    if (
+        (is_stage1 and M <= token_num_quant_moe_sort_switch[0])
+        or (not is_stage1 and M <= token_num_quant_moe_sort_switch[1])
+        or group_size != 32
+    ):
+        out = torch.empty(M, N // 2, dtype=dtypes.fp4x2, device=input.device)
+        fused_dynamic_mxfp4_quant_moe_sort_hip(
+            out,
+            scale,
+            input,
+            sorted_ids,
+            num_valid_ids,
+            token_num,
+            block_size,
+            group_size,
+        )
+    else:
+        out, scale_ = per_1x32_f4_quant_hip(
+            input, None, dtypes.fp4x2, num_rows=num_rows, num_rows_factor=topk
+        )
+        mxfp4_moe_sort_hip(scale, scale_, sorted_ids, num_valid_ids, token_num, N)
+    return out, scale
+
+
 @compile_ops("module_quant")
 def partial_transpose(
     out: Tensor,
diff --git a/csrc/include/quant.h b/csrc/include/quant.h
index c3a0f1e3d9..4d7d425261 100644
--- a/csrc/include/quant.h
+++ b/csrc/include/quant.h
@@ -66,4 +66,20 @@ void moe_smooth_per_token_scaled_quant_v2(torch::Tensor& out, // [..., d
                                           int block_m,
                                           bool shuffle_scale = false,
                                           bool transpose_out = false);
+
+void fused_dynamic_mxfp4_quant_moe_sort_hip(torch::Tensor& out,    // [token_num * topk, d / 2]
+                                            torch::Tensor& scales, // swizzled e8m0 bytes
+                                            torch::Tensor const& input, // [token_num * topk, d]
+                                            torch::Tensor const& sorted_ids,
+                                            torch::Tensor const& num_valid_ids,
+                                            int token_num,
+                                            int block_m,
+                                            int group_size = 32);
+
+void mxfp4_moe_sort_hip(torch::Tensor& out_scale,
+                        torch::Tensor const& scale,
+                        torch::Tensor const& sorted_ids,
torch::Tensor const& num_valid_ids, + int token_num, + int cols); } // namespace aiter diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 1cf7445028..69066fcd58 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -2,6 +2,7 @@ // Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "aiter_tensor.h" #include #include @@ -1322,6 +1323,24 @@ namespace py = pybind11; py::arg("block_m"), \ py::arg("shuffle_scale") = false, \ py::arg("transpose_out") = false); \ + m.def("fused_dynamic_mxfp4_quant_moe_sort_hip", \ + &aiter::fused_dynamic_mxfp4_quant_moe_sort_hip, \ + py::arg("out"), \ + py::arg("scales"), \ + py::arg("input"), \ + py::arg("sorted_ids"), \ + py::arg("num_valid_ids"), \ + py::arg("token_num"), \ + py::arg("block_m"), \ + py::arg("group_size") = 32); \ + m.def("mxfp4_moe_sort_hip", \ + &aiter::mxfp4_moe_sort_hip, \ + py::arg("out_scale"), \ + py::arg("scale"), \ + py::arg("sorted_ids"), \ + py::arg("num_valid_ids"), \ + py::arg("token_num"), \ + py::arg("cols")); \ m.def("partial_transpose", \ &aiter::partial_transpose, \ py::arg("out"), \ @@ -1413,37 +1432,39 @@ namespace py = pybind11; py::arg("epsilon"), \ py::arg("use_model_sensitive_rmsnorm") = 0); -#define ROPE_1C_UNCACHED_FWD_PYBIND m.def("rope_fwd_impl", &rope_fwd_impl); -#define ROPE_2C_UNCACHED_FWD_PYBIND m.def("rope_2c_fwd_impl", &rope_2c_fwd_impl); -#define ROPE_1C_CACHED_FWD_PYBIND m.def("rope_cached_fwd_impl", &rope_cached_fwd_impl); -#define ROPE_2C_CACHED_FWD_PYBIND m.def("rope_cached_2c_fwd_impl", &rope_cached_2c_fwd_impl); -#define ROPE_1C_THD_FWD_PYBIND m.def("rope_thd_fwd_impl", &rope_thd_fwd_impl); -#define ROPE_1C_2D_FWD_PYBIND m.def("rope_2d_fwd_impl", &rope_2d_fwd_impl); - -#define ROPE_1C_UNCACHED_BWD_PYBIND m.def("rope_bwd_impl", &rope_bwd_impl); -#define ROPE_2C_UNCACHED_BWD_PYBIND m.def("rope_2c_bwd_impl", &rope_2c_bwd_impl); -#define ROPE_1C_CACHED_BWD_PYBIND m.def("rope_cached_bwd_impl", &rope_cached_bwd_impl); -#define ROPE_2C_CACHED_BWD_PYBIND m.def("rope_cached_2c_bwd_impl", &rope_cached_2c_bwd_impl); -#define ROPE_1C_THD_BWD_PYBIND m.def("rope_thd_bwd_impl", &rope_thd_bwd_impl); -#define ROPE_1C_2D_BWD_PYBIND m.def("rope_2d_bwd_impl", &rope_2d_bwd_impl); - - -#define ROPE_1C_CACHED_POSITIONS_FWD_PYBIND m.def("rope_cached_positions_fwd_impl", &rope_cached_positions_fwd_impl) -#define ROPE_2C_CACHED_POSITIONS_FWD_PYBIND \ - m.def("rope_cached_positions_2c_fwd_impl", \ - &rope_cached_positions_2c_fwd_impl, \ - py::arg("output_x"), \ - py::arg("output_y"), \ - py::arg("input_x"), \ - py::arg("input_y"), \ - py::arg("cos"), \ - py::arg("sin"), \ - py::arg("positions"), \ - py::arg("rotate_style"), \ - py::arg("reuse_freqs_front_part"), \ +#define ROPE_1C_UNCACHED_FWD_PYBIND m.def("rope_fwd_impl", &rope_fwd_impl); +#define ROPE_2C_UNCACHED_FWD_PYBIND m.def("rope_2c_fwd_impl", &rope_2c_fwd_impl); +#define ROPE_1C_CACHED_FWD_PYBIND m.def("rope_cached_fwd_impl", &rope_cached_fwd_impl); +#define ROPE_2C_CACHED_FWD_PYBIND m.def("rope_cached_2c_fwd_impl", &rope_cached_2c_fwd_impl); +#define ROPE_1C_THD_FWD_PYBIND m.def("rope_thd_fwd_impl", &rope_thd_fwd_impl); +#define ROPE_1C_2D_FWD_PYBIND m.def("rope_2d_fwd_impl", &rope_2d_fwd_impl); + +#define ROPE_1C_UNCACHED_BWD_PYBIND m.def("rope_bwd_impl", &rope_bwd_impl); +#define ROPE_2C_UNCACHED_BWD_PYBIND m.def("rope_2c_bwd_impl", &rope_2c_bwd_impl); +#define ROPE_1C_CACHED_BWD_PYBIND m.def("rope_cached_bwd_impl", &rope_cached_bwd_impl); +#define ROPE_2C_CACHED_BWD_PYBIND 
m.def("rope_cached_2c_bwd_impl", &rope_cached_2c_bwd_impl); +#define ROPE_1C_THD_BWD_PYBIND m.def("rope_thd_bwd_impl", &rope_thd_bwd_impl); +#define ROPE_1C_2D_BWD_PYBIND m.def("rope_2d_bwd_impl", &rope_2d_bwd_impl); + +#define ROPE_1C_CACHED_POSITIONS_FWD_PYBIND \ + m.def("rope_cached_positions_fwd_impl", &rope_cached_positions_fwd_impl) +#define ROPE_2C_CACHED_POSITIONS_FWD_PYBIND \ + m.def("rope_cached_positions_2c_fwd_impl", \ + &rope_cached_positions_2c_fwd_impl, \ + py::arg("output_x"), \ + py::arg("output_y"), \ + py::arg("input_x"), \ + py::arg("input_y"), \ + py::arg("cos"), \ + py::arg("sin"), \ + py::arg("positions"), \ + py::arg("rotate_style"), \ + py::arg("reuse_freqs_front_part"), \ py::arg("nope_first")) -#define ROPE_1C_CACHED_POSITIONS_OFFSETS_FWD_PYBIND m.def("rope_cached_positions_offsets_fwd_impl", &rope_cached_positions_offsets_fwd_impl) -#define ROPE_2C_CACHED_POSITIONS_OFFSETS_FWD_PYBIND m.def("rope_cached_positions_offsets_2c_fwd_impl", &rope_cached_positions_offsets_2c_fwd_impl) +#define ROPE_1C_CACHED_POSITIONS_OFFSETS_FWD_PYBIND \ + m.def("rope_cached_positions_offsets_fwd_impl", &rope_cached_positions_offsets_fwd_impl) +#define ROPE_2C_CACHED_POSITIONS_OFFSETS_FWD_PYBIND \ + m.def("rope_cached_positions_offsets_2c_fwd_impl", &rope_cached_positions_offsets_2c_fwd_impl) #define FUSED_QKNORM_MROPE_CACHE_QUANT_PYBIND \ m.def("fused_qk_norm_mrope_3d_cache_pts_quant_shuffle", \ diff --git a/csrc/kernels/quant_kernels.cu b/csrc/kernels/quant_kernels.cu index a0b5e6c588..a62243f5a2 100644 --- a/csrc/kernels/quant_kernels.cu +++ b/csrc/kernels/quant_kernels.cu @@ -1508,9 +1508,9 @@ __global__ void moe_smooth_per_token_scaled_quant_kernel_v2(DTYPE_O* __restrict_ #define MOE_SMOOTH_PER_TOKEN_SCALED_QUANT_KERNEL_V2_IMPL(quant_kernel, DTYPE_O, THREAD_DATA, BLOCK_SIZE) \ AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "quant_kernel", [&] { \ - using input_dtype = typename t2opus::type; \ - int warps_per_cu = 8 * BLOCK_SIZE / WARP_SIZE; \ - int num_tg = persistent_mode? num_cu * warps_per_cu : num_blocks; \ + using input_dtype = typename t2opus::type; \ + int blocks_per_cu = 8 * 4 / (BLOCK_SIZE / WARP_SIZE); \ + int num_tg = persistent_mode ? num_cu * blocks_per_cu : num_blocks; \ dim3 const grid(num_tg); \ aiter::quant_kernel \ <<>>( \ @@ -1614,4 +1614,360 @@ void moe_smooth_per_token_scaled_quant_v2( } +template +__global__ void mxfp4_quant_moe_sort_kernel( + DTYPE_O* __restrict__ out, + uint8_t* __restrict__ scale, + DTYPE_I const* __restrict__ input, + int32_t const* __restrict__ sorted_ids, + int32_t const* __restrict__ num_valid_ids, + const int32_t num_tokens, + const int32_t cols, + const int32_t group_size, + const int32_t block_m, + const int32_t sub_block_m, + const int32_t num_blocks, + const int32_t num_tg, + const int32_t topk, + const int32_t input_stride) +{ + int num_thread_per_group = group_size / thread_data_size; + int num_valid_ids_value = num_valid_ids[0]; + int block_idx = blockIdx.x; + int lane_idx = threadIdx.x % WARP_SIZE; + const int scale_k = threadIdx.x / num_thread_per_group; + static constexpr int32_t vec_size_i = + thread_data_size == 0 ? 16 / sizeof(DTYPE_I) : thread_data_size; + static constexpr int32_t load_chunk_bytes = + (sizeof(DTYPE_I) * vec_size_i % 16 == 0 ? 16 + : (sizeof(DTYPE_I) * vec_size_i % 8 == 0 ? 8 : 4)); + using vec_i = opus::vector_t; + using vec_f = opus::vector_t; + const float inverted_DTYPE_MAX = + std::is_same_v + ? 0.25 + : (1. 
/ static_cast(opus::finfo::max())); + const int32_t scaleN_valid = (cols + group_size - 1) / group_size; + const int32_t scaleN_pad = ((scaleN_valid + 7) / 8) * 8; + + auto fp4_scale = [](float tmp) { + uint32_t u32 = __builtin_bit_cast(uint32_t, tmp); + uint32_t exponent = (u32 >> 23) & 0b11111111; + if(exponent == 0b11111111) + { + return __builtin_bit_cast(float, exponent << 23); + } + if(((u32 & 0x400000)) && (((u32 & 0x200000)) || ((u32 & 0x1FFFFF)) || (exponent))) + exponent += 1; + return __builtin_bit_cast(float, exponent << 23); + }; + auto fp4_scale_shuffle_id = [](int32_t scaleN_pad_, int32_t x, int32_t y) { + return (x / 32 * scaleN_pad_) * 32 + (y / 8) * 256 + (y % 4) * 64 + + (x % 16) * 4 + (y % 8) / 4 * 2 + (x % 32) / 16; + }; + + for(; block_idx < num_blocks; block_idx += num_tg) + { + int sorted_ids_offset = block_idx * sub_block_m; + if(sorted_ids_offset >= num_valid_ids_value) + { + return; + } + int token_id_info_list; + if (lane_idx < sub_block_m) + { + token_id_info_list = sorted_ids[sorted_ids_offset + lane_idx]; + } + int token_id_list = token_id_info_list & 0xFFFFFF; + int topk_id_list = token_id_info_list >> 24; + for(int i = 0; i < sub_block_m; i++) + { + int token_idx = __builtin_amdgcn_readlane(token_id_list, i); + int topk_id = __builtin_amdgcn_readlane(topk_id_list, i); + if(token_idx >= num_tokens) + { + break; + } + + int64_t input_offset; + if (topk == 1) + { + input_offset = (int64_t)(token_idx) * input_stride; + } + else + { + input_offset = (int64_t)(token_idx * topk + topk_id) * input_stride; + } + auto buffer_input = + opus::make_gmem(input + input_offset, cols * sizeof(DTYPE_I)); + vec_i vec_input = load_vector_nbytes( + buffer_input, threadIdx.x * vec_size_i); + vec_f vec_input_f; + float* input_f_ptr = reinterpret_cast(&vec_input_f); + float absMax = 1e-10f; + #pragma unroll + for(int j = 0; j < vec_size_i; j++) + { + vec_input_f[j] = static_cast(vec_input[j]); + absMax = max(absMax, abs(vec_input_f[j])); + } + absMax = multithread_reduce(absMax, hipcub::Max(), num_thread_per_group); + + float row_scale = std::is_same_v + ? fp4_scale(absMax) * inverted_DTYPE_MAX + : absMax * inverted_DTYPE_MAX; + + const int sorted_row = sorted_ids_offset + i; + if(threadIdx.x % num_thread_per_group == 0 && scale_k < scaleN_valid) + { + uint8_t bs_e8m0 = (__builtin_bit_cast(uint32_t, row_scale) >> 23) & 0xFF; + int addr = fp4_scale_shuffle_id(scaleN_pad, sorted_row, scale_k); + scale[addr] = bs_e8m0; + } + + if(topk_id < topk) + { + int64_t out_offset = (int64_t)(token_idx * topk + topk_id) * cols; + scaled_quant_vgpr_impl( + out, input_f_ptr, &row_scale, cols, out_offset); + } + } + } +} + + +#define MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, THREAD_DATA, BLOCK_SIZE) \ + AITER_DISPATCH_FLOATING16_TYPES(input.scalar_type(), "mxfp4_quant_moe_sort_kernel", [&] { \ + AITER_CHECK(group_size % THREAD_DATA == 0, __func__, " group_size is not divisible by THREAD_DATA"); \ + using input_dtype = typename t2opus::type; \ + int blocks_per_cu = 8 * 4 / (BLOCK_SIZE / WARP_SIZE); \ + int num_tg = persistent_mode ? 
num_cu * blocks_per_cu : num_blocks; \ + dim3 const grid(num_tg); \ + mxfp4_quant_moe_sort_kernel \ + <<>>( \ + reinterpret_cast(output.data_ptr()), \ + reinterpret_cast(scale.data_ptr()), \ + reinterpret_cast(input.data_ptr()), \ + sorted_ids.data_ptr(), \ + num_valid_ids.data_ptr(), \ + token_num, \ + cols, \ + group_size, \ + block_m, \ + sub_block_m, \ + num_blocks, \ + num_tg, \ + topk, \ + input_stride); \ + }); + + +#define MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(DTYPE_O, cols_) \ + if(cols_ <= 2 * BlockSize) \ + { \ + MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, 8, BlockSize / 4) \ + } \ + else if(cols_ <= 4 * BlockSize) \ + { \ + MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, 8, BlockSize / 2) \ + } \ + else if(cols_ <= 8 * BlockSize) \ + { \ + MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, 8, BlockSize) \ + } \ + else if(cols_ <= 16 * BlockSize) \ + { \ + MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, 16, BlockSize) \ + } \ + else if(cols_ <= 16 * BlockSize * 2) \ + { \ + MXFP4_QUANT_MOE_SORT_KERNEL_IMPL(DTYPE_O, 32, BlockSize) \ + } \ + else \ + { \ + TORCH_CHECK(false, "input last dim has exceeded the maximum value ", 32 * BlockSize) \ + } + +void fused_dynamic_mxfp4_quant_moe_sort_hip( + torch::Tensor& output, + torch::Tensor& scale, + torch::Tensor const& input, + torch::Tensor const& sorted_ids, + torch::Tensor const& num_valid_ids, + int token_num, + int block_m, + int group_size = 32 +) +{ + int cols = input.size(-1); + int topk = input.numel() / (cols * token_num); + int num_experts = (sorted_ids.size(0) + topk - topk * token_num) / block_m; + + const int num_cu = get_num_cu_func(); + int sub_block_m = (token_num * topk) > (num_cu * 8) || num_experts < 64 ? 2 : 4; + TORCH_CHECK(block_m % sub_block_m == 0, __func__, " block_m is not divisible by sub_block_m"); + int num_blocks = (sorted_ids.size(0) + sub_block_m - 1) / sub_block_m; + const bool persistent_mode = false; + const int input_stride = input.stride(-2); + + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); + +#if defined(__Float4_e2m1fn_x2) + if(output.dtype() == torch_fp4x2 || output.dtype() == torch::kUInt8) + { + MXFP4_QUANT_MOE_SORT_KERNEL_DISPATCH(opus::fp4_t, cols); + } + else + { + TORCH_CHECK(false, __func__, ": not support output type: ", output.dtype()); + } +#else + TORCH_CHECK(false, __func__, ": not support fp4x2 on this device"); +#endif +} + +template +__global__ void mxfp4_moe_sort_kernel( + uint8_t* __restrict__ out_scale, + uint8_t* __restrict__ scale, + int32_t const* __restrict__ sorted_ids, + int32_t const* __restrict__ num_valid_ids, + const int32_t num_tokens, + const int32_t cols, + const int32_t num_blocks, + const int32_t num_tg, + const int32_t topk) +{ + constexpr int threads_per_row = block_size / num_rows; + int num_valid_ids_value = num_valid_ids[0]; + int block_idx = blockIdx.x; + int row_i = threadIdx.x / threads_per_row; + int scale_k = threadIdx.x % threads_per_row * thread_data_size; + const int scale_per_row = (cols + group_size - 1) / group_size; + static constexpr int32_t vec_size_i = thread_data_size; + static constexpr int32_t load_chunk_bytes = + (sizeof(uint8_t) * vec_size_i % 16 == 0 ? 16 + : (sizeof(uint8_t) * vec_size_i % 8 == 0 ? 8 + : (sizeof(uint8_t) * vec_size_i % 4 == 0 ? 
4 : 2))); + using vec_i = opus::vector_t; + const int32_t scaleN_valid = (cols + group_size - 1) / group_size; + const int32_t scaleN_pad = ((scaleN_valid + 7) / 8) * 8; + auto fp4_scale_shuffle_id = [](int32_t scaleN_pad_, int32_t x, int32_t y) { + return (x / 32 * scaleN_pad_) * 32 + (y / 8) * 256 + (y % 4) * 64 + + (x % 16) * 4 + (y % 8) / 4 * 2 + (x % 32) / 16; + }; + auto buffer_scale = + opus::make_gmem(scale, scale_per_row * num_tokens * topk * sizeof(uint8_t)); + for(; block_idx < num_blocks; block_idx += num_tg) + { + int sorted_row = block_idx * num_rows + row_i; + int token_id_info = num_tokens; + if (sorted_row < num_valid_ids_value) + { + token_id_info = sorted_ids[sorted_row]; + } + int token_idx = token_id_info & 0xFFFFFF; + int topk_id = token_id_info >> 24; + if(token_idx < num_tokens && (topk == 1 || topk_id < topk)) + { + int64_t scale_offset; + if (topk == 1) + { + scale_offset = (int64_t)(token_idx) * scale_per_row; + } + else + { + scale_offset = (int64_t)(token_idx * topk + topk_id) * scale_per_row; + } + vec_i vec_scale = load_vector_nbytes( + buffer_scale, scale_offset + scale_k); + + for(int j = 0; j < vec_size_i; j++) + { + if((scale_k + j) < scaleN_valid) + { + int addr = fp4_scale_shuffle_id(scaleN_pad, sorted_row, scale_k + j); + out_scale[addr] = vec_scale[j]; + } + } + } + } +} + + +#define MXFP4_MOE_SORT_KERNEL_IMPL(MAX_COL, THREAD_DATA, BLOCK_SIZE) \ + constexpr int GROUP_SIZE = 32; \ + constexpr int NUM_ROWS = BLOCK_SIZE / (MAX_COL /(GROUP_SIZE * THREAD_DATA)); \ + TORCH_CHECK(BLOCK_SIZE % (MAX_COL /(GROUP_SIZE * THREAD_DATA)) == 0); \ + int num_blocks = (sorted_ids.size(0) + NUM_ROWS - 1) / NUM_ROWS; \ + int blocks_per_cu = 8 * 4 / (BLOCK_SIZE / WARP_SIZE); \ + int num_tg = persistent_mode ? num_cu * blocks_per_cu : num_blocks; \ + dim3 const grid(num_tg); \ + mxfp4_moe_sort_kernel \ + <<>>( \ + reinterpret_cast(out_scale.data_ptr()), \ + reinterpret_cast(scale.data_ptr()), \ + sorted_ids.data_ptr(), \ + num_valid_ids.data_ptr(), \ + token_num, cols, num_blocks, num_tg, topk); + + +#define MXFP4_MOE_SORT_KERNEL_DISPATCH(cols_) \ + if(cols_ <= 256) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(256, 4, 256) \ + } \ + else if(cols_ <= 512) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(512, 4, 256) \ + } \ + else if(cols_ <= 1024) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(1024, 4, 256) \ + } \ + else if(cols_ <= 2048) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(2048, 8, 256) \ + } \ + else if(cols_ <= 4096) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(4096, 16, 256) \ + } \ + else if(cols_ <= 6144) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(6144, 24, 256) \ + } \ + else if(cols_ <= 8192) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(8192, 32, 256) \ + } \ + else if(cols_ <= 16384) \ + { \ + MXFP4_MOE_SORT_KERNEL_IMPL(16384, 32, 256) \ + } \ + else \ + { \ + TORCH_CHECK(false, "input last dim has exceeded the maximum value ", 16384) \ + } + +void mxfp4_moe_sort_hip( + torch::Tensor& out_scale, + torch::Tensor const& scale, + torch::Tensor const& sorted_ids, + torch::Tensor const& num_valid_ids, + int token_num, + int cols +) +{ + const int num_cu = get_num_cu_func(); + const bool persistent_mode = false; + int topk = scale.numel() / ((cols + 31) / 32 * token_num); + + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(scale)); + const hipStream_t stream = at::hip::getCurrentHIPStream(); + + MXFP4_MOE_SORT_KERNEL_DISPATCH(cols); +} + } // namespace aiter \ No newline at end of file diff --git a/op_tests/test_moe_sorting_mxfp4.py b/op_tests/test_moe_sorting_mxfp4.py index 
9a9571c8e0..2beb79a674 100644 --- a/op_tests/test_moe_sorting_mxfp4.py +++ b/op_tests/test_moe_sorting_mxfp4.py @@ -2,8 +2,10 @@ # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. import torch import aiter -from aiter.test_common import checkAllclose, benchmark, run_perftest, perftest +from aiter.test_common import checkAllclose, benchmark, run_perftest from aiter.fused_moe import moe_sorting, fused_topk +from aiter.ops.triton.quant.fused_mxfp4_quant import fused_dynamic_mxfp4_quant_moe_sort +from aiter.jit.utils.chip_info import get_gfx from aiter import get_torch_quant, dtypes from aiter.utility import fp4_utils import pandas as pd @@ -15,7 +17,6 @@ torch.set_printoptions(threshold=float("inf")) -@perftest() def run_torch(scale, sorted_ids, num_valid_ids, token_num): topk = 1 if len(scale.shape) == 3: @@ -29,7 +30,7 @@ def run_torch(scale, sorted_ids, num_valid_ids, token_num): sorted_ids = sorted_ids * topk + topk_ids sorted_ids[mask] = 0 # set to 0 to avoid overflow scale = scale[sorted_ids] - scale[mask] = 0 + scale.view(torch.uint8)[mask] = 0 sm, sn = scale.shape tmp = torch.zeros( ((sm + 31) // 32 * 32, sn), dtype=scale.dtype, device=scale.device @@ -43,6 +44,20 @@ def run_torch(scale, sorted_ids, num_valid_ids, token_num): return ref +def run_split_quant_sort(scale, input, sorted_ids, num_valid_ids, token_num): + model_dim = input.shape[-1] + out, scale_ = aiter.per_1x32_f4_quant_hip(input, None, dtypes.fp4x2) + aiter.mxfp4_moe_sort_hip( + scale, + scale_, + sorted_ids, + num_valid_ids, + token_num, + model_dim, + ) + return out, scale + + @benchmark() def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage): input = torch.randn((token_num, model_dim), dtype=dtype) @@ -60,11 +75,12 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) if stage == "stage1": scale = torch.arange(token_num * model_dim // 32, dtype=torch.uint8) scale = scale.view(token_num, model_dim // 32) + topk = 1 else: scale = torch.arange(token_num * topk * model_dim // 32, dtype=torch.uint8) scale = scale.view(token_num, topk, model_dim // 32) - ref, us_ref = run_torch(scale.clone(), sorted_ids.clone(), num_valid_ids, token_num) - sorted_mxfp4_scale, us = run_perftest( + ref = run_torch(scale.clone(), sorted_ids.clone(), num_valid_ids, token_num) + triton_scale, triton_us = run_perftest( fp4_utils.moe_mxfp4_sort, scale, sorted_ids, @@ -73,15 +89,147 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) block_size, ) + hip_scale = torch.zeros( + ((sorted_ids.shape[0] + 31) // 32 * 32, model_dim // 32), + dtype=torch.uint8, + device=input.device, + ) + _, hip_us = run_perftest( + aiter.mxfp4_moe_sort_hip, + hip_scale, + scale, + sorted_ids, + num_valid_ids, + token_num, + model_dim, + ) + num_valid_ids = num_valid_ids.item() num_valid_ids = (num_valid_ids + block_size - 1) // block_size * block_size - err = checkAllclose( + triton_err = checkAllclose( ref[:num_valid_ids], - sorted_mxfp4_scale[:num_valid_ids].view(torch.uint8), + triton_scale[:num_valid_ids].view(torch.uint8), msg="sorted_mxfp4_scale", ) - return {"us_ref": us_ref, "us": us, "err": err} + + hip_err = checkAllclose( + ref[:num_valid_ids].view(torch.uint8), + hip_scale[:num_valid_ids].view(torch.uint8), + msg="hip sorted_mxfp4_scale", + ) + return { + "triton_us": triton_us, + "triton_err": triton_err, + "hip_us": hip_us, + "hip_err": hip_err, + } + + +@benchmark() +def test_moe_mxfp4_quant_sort(dtype, token_num, model_dim, E, topk, 
block_size, stage): + if get_gfx().startswith("gfx94"): + return {} + input = torch.randn((token_num, model_dim), dtype=dtype) + score = torch.randn((token_num, E), dtype=dtype) + + topk_weights, topk_ids = fused_topk(input, score, topk, True) + sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = moe_sorting( + topk_ids, + topk_weights, + E, + model_dim, + dtype, + ) + num_valid_ids = num_valid_ids[0] + if stage != "stage1": + input = torch.randn((token_num * topk, model_dim), dtype=dtype) + else: + topk = 1 + ref_out, scale = get_torch_quant(aiter.QuantType.per_1x32)( + input, quant_dtype=dtypes.fp4x2 + ) + ref_scale = run_torch(scale.clone(), sorted_ids.clone(), num_valid_ids, token_num) + + split_scale = torch.zeros( + ((sorted_ids.shape[0] + 31) // 32 * 32, model_dim // 32), + dtype=torch.uint8, + device=input.device, + ) + (split_out, split_scale), split_us = run_perftest( + run_split_quant_sort, + split_scale, + input, + sorted_ids, + num_valid_ids, + token_num, + ) + + hip_scale = torch.zeros( + ((sorted_ids.shape[0] + 31) // 32 * 32, model_dim // 32), + dtype=torch.uint8, + device=input.device, + ) + hip_out = torch.empty( + (token_num * topk, model_dim // 2), + dtype=dtypes.fp4x2, + device=input.device, + ) + _, hip_us = run_perftest( + aiter.fused_dynamic_mxfp4_quant_moe_sort_hip, + hip_out, + hip_scale, + input, + sorted_ids, + num_valid_ids, + token_num, + block_size, + ) + + (triton_out, triton_scale), triton_us = run_perftest( + fused_dynamic_mxfp4_quant_moe_sort, + input, + sorted_ids=sorted_ids, + num_valid_ids=num_valid_ids, + token_num=token_num, + topk=topk, + block_size=block_size, + ) + + mask = ref_scale == 0 + triton_scale = triton_scale[: ref_scale.shape[0]] + triton_scale.view(torch.uint8)[mask] = 0 + num_valid_ids = num_valid_ids.item() + num_valid_ids = (num_valid_ids + block_size - 1) // block_size * block_size + + checkAllclose(ref_out.view(torch.uint8), hip_out.view(torch.uint8), msg="hip out") + hip_err = checkAllclose( + ref_scale[:num_valid_ids].view(torch.uint8), + hip_scale[:num_valid_ids].view(torch.uint8), + msg="hip sorted_mxfp4_scale", + ) + + # checkAllclose(ref_out.view(torch.uint8), triton_out.view(torch.uint8), msg="triton out") + triton_err = checkAllclose( + ref_scale[:num_valid_ids].view(torch.uint8), + triton_scale[:num_valid_ids].view(torch.uint8), + msg="triton sorted_mxfp4_scale", + ) + + split_err = checkAllclose( + ref_scale[:num_valid_ids].view(torch.uint8), + split_scale[:num_valid_ids].view(torch.uint8), + msg="split sorted_mxfp4_scale", + ) + + return { + "triton_us": triton_us, + "triton_err": triton_err, + "hip_us": hip_us, + "hip_err": hip_err, + "split_us": split_us, + "split_err": split_err, + } parser = argparse.ArgumentParser( @@ -100,39 +248,47 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) e.g.: -d bf16""", ) parser.add_argument( - "-dim", + "-dim1", type=int, nargs="*", - default=[4096, 6144, 8192], - help="""Model dimension. - e.g.: -dim 4096""", + default=[4096, 7168], + help="""Model dimension for stage1. + e.g.: -dim1 4096""", ) parser.add_argument( - "-e", - "--expert", + "-dim2", type=int, nargs="*", - default=[32, 256, 257, 512], - help="""Number of experts. - e.g.: -e 32""", + default=[256, 2048], + help="""Inter dimension for stage2. + e.g.: -dim2 256""", ) parser.add_argument( - "-t", - "--topk", - type=int, + "-ek", + "--expert_topk", + type=dtypes.str2tuple, nargs="*", - default=[5, 8], - help="""Number of top experts. 
- e.g.: -t 5""", + default=[[32, 5], [256, 8], [512, 8]], + help="""Number of experts. + e.g.: -ek 32,5""", ) parser.add_argument( "-m", type=int, nargs="*", - default=[1, 31, 64, 128, 256, 10000, 163840], + default=[1, 64, 128, 256, 1024, 2050, 4200, 10000, 163840], help="""M of mnk. e.g.: -m 64""", ) +parser.add_argument( + "-bm", + "--block_m", + type=int, + default=32, + choices=[16, 32, 64, 80, 128], + help="""Block M. + e.g.: -bm 64""", +) args = parser.parse_args() @@ -140,11 +296,10 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) for dtype in args.dtype: for ( dim, - E, - topk, + (E, topk), m, - ) in itertools.product(args.dim, args.expert, args.topk, args.m): - ret = test_moe_mxfp4_sort(dtype, m, dim, E, topk, 32, "stage1") + ) in itertools.product(args.dim1, args.expert_topk, args.m): + ret = test_moe_mxfp4_sort(dtype, m, dim, E, topk, args.block_m, "stage1") df.append(ret) df = pd.DataFrame(df) df_md = df.to_markdown(index=False) @@ -154,12 +309,37 @@ def test_moe_mxfp4_sort(dtype, token_num, model_dim, E, topk, block_size, stage) for dtype in args.dtype: for ( dim, - E, - topk, + (E, topk), m, - ) in itertools.product(args.dim, args.expert, args.topk, args.m): - ret = test_moe_mxfp4_sort(dtype, m, dim, E, topk, 32, "stage2") + ) in itertools.product(args.dim2, args.expert_topk, args.m): + ret = test_moe_mxfp4_sort(dtype, m, dim, E, topk, args.block_m, "stage2") df.append(ret) df = pd.DataFrame(df) df_md = df.to_markdown(index=False) aiter.logger.info("moe_sorting_mxfp4_stage2 summary (markdown):\n%s", df_md) + +df = [] +for dtype in args.dtype: + for ( + dim, + (E, topk), + m, + ) in itertools.product(args.dim1, args.expert_topk, args.m): + ret = test_moe_mxfp4_quant_sort(dtype, m, dim, E, topk, args.block_m, "stage1") + df.append(ret) +df = pd.DataFrame(df) +df_md = df.to_markdown(index=False) +aiter.logger.info("moe_mxfp4_quant_sort_stage1 summary (markdown):\n%s", df_md) + +df = [] +for dtype in args.dtype: + for ( + dim, + (E, topk), + m, + ) in itertools.product(args.dim2, args.expert_topk, args.m): + ret = test_moe_mxfp4_quant_sort(dtype, m, dim, E, topk, args.block_m, "stage2") + df.append(ret) +df = pd.DataFrame(df) +df_md = df.to_markdown(index=False) +aiter.logger.info("moe_mxfp4_quant_sort_stage2 summary (markdown):\n%s", df_md)
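For reference, both new kernels write e8m0 scale bytes through the same swizzle: a (sorted_row, scale_col) pair is mapped to a flat byte offset by the fp4_scale_shuffle_id lambda in csrc/kernels/quant_kernels.cu. Below is a direct Python transcription of that index math, useful when building a torch-side reference like run_torch in the test; the function and argument names here are illustrative, not part of the PR.

def fp4_scale_shuffle_id(scale_n_pad: int, x: int, y: int) -> int:
    # x: row index in sorted-token order
    # y: scale column index (one e8m0 byte per 32 input columns)
    # scale_n_pad: number of scale columns padded up to a multiple of 8
    #              (matches scaleN_pad in the kernels)
    return (
        (x // 32 * scale_n_pad) * 32
        + (y // 8) * 256
        + (y % 4) * 64
        + (x % 16) * 4
        + (y % 8) // 4 * 2
        + (x % 32) // 16
    )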
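The Python entry points added in aiter/ops/quant.py can also be exercised outside fused_moe_2stages. The dispatcher picks the fused HIP quant+sort kernel when the row count is at or below 8 * 64 / topk for stage1 inputs (64 rows at topk=8) or 8 * 1024 / topk for stage2 inputs (1024 rows at topk=8), and otherwise falls back to per_1x32_f4_quant_hip followed by mxfp4_moe_sort_hip. The sketch below is a rough usage example only, assuming a ROCm build of aiter with these ops compiled, bf16 activations on the default GPU, and the same fused_topk/moe_sorting setup as the test file; the sizes are illustrative.

import torch
import aiter
from aiter import dtypes, fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd
from aiter.fused_moe import fused_topk, moe_sorting

# Illustrative problem sizes (not taken from the PR).
token_num, model_dim, E, topk, block_m = 128, 4096, 32, 5, 32
x = torch.randn(token_num, model_dim, dtype=torch.bfloat16, device="cuda")
score = torch.randn(token_num, E, dtype=torch.bfloat16, device="cuda")

topk_weights, topk_ids = fused_topk(x, score, topk, True)
sorted_ids, _, _, num_valid_ids, _ = moe_sorting(
    topk_ids, topk_weights, E, model_dim, torch.bfloat16
)

# Stage1-style call: M == token_num, so the wrapper treats topk as 1 internally
# and returns fp4x2 activations plus scales already sorted into the shuffle layout.
a1, a1_scale = fused_dynamic_mxfp4_quant_moe_sort(
    x,
    sorted_ids=sorted_ids,
    num_valid_ids=num_valid_ids,
    token_num=token_num,
    topk=topk,
    block_size=block_m,
)

# If the activations arrive already quantized to fp4x2 (FP4 dispatch),
# only the scales need to be sorted:
# a1_scale = mxfp4_moe_sort_fwd(
#     a1_scale,
#     sorted_ids=sorted_ids,
#     num_valid_ids=num_valid_ids,
#     token_num=token_num,
#     cols=model_dim,
# )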