Skip to content
Merged
68 changes: 19 additions & 49 deletions aiter/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from aiter.jit.utils.chip_info import get_cu_num, get_gfx
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.ops.flydsl.utils import is_flydsl_available
from aiter.ops.triton.quant.fused_mxfp4_quant import fused_dynamic_mxfp4_quant_moe_sort
from aiter.utility import fp4_utils
from aiter import fused_dynamic_mxfp4_quant_moe_sort, mxfp4_moe_sort_fwd

BLOCK_SIZE_M = 32

Expand Down Expand Up @@ -454,12 +453,12 @@ def fused_moe_1stage(
token_num = hidden_states.shape[0]
E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape)
if quant_type == QuantType.per_1x32:
a1_scale = fp4_utils.moe_mxfp4_sort(
a1_scale = mxfp4_moe_sort_fwd(
a1_scale,
sorted_ids,
num_valid_ids,
token_num,
block_size_M,
sorted_ids=sorted_ids,
num_valid_ids=num_valid_ids,
token_num=token_num,
cols=model_dim,
)
w1_scale = w1_scale.view(E, -1)
w2_scale = w2_scale.view(E, -1)
Expand Down Expand Up @@ -1144,7 +1143,6 @@ def fused_moe_2stages(
bias2=None,
):
quant_func = get_quant(quant_type)
token_num_quant_moe_sort_switch = 1024
token_num, _ = hidden_states.shape
E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape)
dtype = moe_out.dtype
Expand Down Expand Up @@ -1196,36 +1194,23 @@ def fused_moe_2stages(
# Input is already quantized to fp4x2 (e.g., from FP4 dispatch),
# skip re-quantization, only sort the scale
a1 = hidden_states
a1_scale = fp4_utils.moe_mxfp4_sort(
a1_scale = mxfp4_moe_sort_fwd(
a1_scale,
sorted_ids=sorted_ids,
num_valid_ids=num_valid_ids,
token_num=token_num,
block_size=block_size_M,
cols=model_dim,
)
elif token_num <= token_num_quant_moe_sort_switch:
else:
a1, a1_scale = fused_dynamic_mxfp4_quant_moe_sort(
hidden_states,
sorted_ids=sorted_ids,
num_valid_ids=num_valid_ids,
token_num=token_num,
topk=1,
topk=topk,
block_size=block_size_M,
)
else:
a1, a1_scale = quant_func(
hidden_states,
scale=a1_scale,
quant_dtype=q_dtype_a,
num_rows=num_local_tokens,
)
a1_scale = fp4_utils.moe_mxfp4_sort(
a1_scale,
sorted_ids=sorted_ids,
num_valid_ids=num_valid_ids,
token_num=token_num,
block_size=block_size_M,
)
elif hidden_states.dtype != q_dtype_a:
if quant_type == QuantType.per_1x128 and metadata.stage1.func is asm_stage1:
quant_func = functools.partial(quant_func, transpose_scale=True)
Expand Down Expand Up @@ -1312,30 +1297,15 @@ def fused_moe_2stages(
a2_scale = a1_scale
elif quant_type == QuantType.per_1x32:
a2 = a2.view(-1, inter_dim)
if token_num <= token_num_quant_moe_sort_switch:
a2, a2_scale = fused_dynamic_mxfp4_quant_moe_sort(
a2,
sorted_ids=sorted_ids,
num_valid_ids=num_valid_ids,
token_num=token_num,
topk=topk,
block_size=block_size_M,
)
else:
a2, a2_scale = quant_func(
a2,
scale=a2_scale,
quant_dtype=q_dtype_a,
num_rows=num_local_tokens,
num_rows_factor=topk,
)
a2_scale = fp4_utils.moe_mxfp4_sort(
a2_scale[: token_num * topk, :].view(token_num, topk, -1),
sorted_ids=sorted_ids,
num_valid_ids=num_valid_ids,
token_num=token_num,
block_size=block_size_M,
)
a2, a2_scale = fused_dynamic_mxfp4_quant_moe_sort(
a2,
sorted_ids=sorted_ids,
num_valid_ids=num_valid_ids,
token_num=token_num,
topk=topk,
block_size=block_size_M,
num_rows=num_local_tokens,
)
a2 = a2.view(token_num, topk, -1)
elif quant_type == QuantType.per_1x128 and metadata.stage1.func is asm_stage1:
a2_v = a2[:token_num, :, :]
Expand Down
100 changes: 98 additions & 2 deletions aiter/ops/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def per_1x32_f4_quant_hip(
torch.empty(
(
(m + 255) // 256 * 256,
(n // 32 + 7) // 8 * 8,
((n + 31) // 32 + 7) // 8 * 8,
),
dtype=torch.uint8,
device=device,
Expand All @@ -321,7 +321,7 @@ def per_1x32_f4_quant_hip(
else:
scale = (
torch.empty(
(m, n // 32),
(m, (n + 31) // 32),
dtype=torch.uint8,
device=device,
)
Expand Down Expand Up @@ -548,6 +548,102 @@ def moe_smooth_per_token_scaled_quant_v2(
...


@compile_ops("module_quant")
def mxfp4_moe_sort_hip(
    out_scale: torch.Tensor,
    scale: torch.Tensor,
    sorted_ids: torch.Tensor,
    num_valid_ids: torch.Tensor,
    token_num: int,
    cols: int,
) -> None:
    """
    MoE scale sorting with MXFP4 shuffle layout.

    Python-side stub: the real implementation is the compiled HIP kernel
    registered under ``module_quant`` by ``@compile_ops``; this body is
    never executed.

    Args:
        out_scale: Pre-allocated destination tensor; the reordered scales
            are written here in place.
        scale: Source e8m0 scale tensor to be sorted.
        sorted_ids: Token ids sorted by expert assignment (MoE routing).
        num_valid_ids: Tensor holding the number of valid entries in
            ``sorted_ids``.
        token_num: Number of tokens in the batch.
        cols: Column count of the unquantized activation (callers pass the
            model dimension / last-dim size).

    Returns:
        None. Output is delivered through ``out_scale``.
    """
    ...


def mxfp4_moe_sort_fwd(
    scale: torch.Tensor,
    sorted_ids: torch.Tensor,
    num_valid_ids: torch.Tensor,
    token_num: int,
    cols: int,
):
    """
    Allocate the shuffled-scale buffer and run the HIP MoE scale sort.

    The output row count is the sorted-id count rounded up to a multiple
    of 32; the column count is one e8m0 scale per 32 input columns.

    Args:
        scale: Source e8m0 scale tensor to be reordered.
        sorted_ids: Token ids sorted by expert assignment.
        num_valid_ids: Tensor holding the number of valid sorted ids.
        token_num: Number of tokens in the batch.
        cols: Column count of the unquantized activation.

    Returns:
        Newly allocated tensor holding the sorted scales.
    """
    # Round rows up to the 32-row tile the kernel operates on.
    padded_rows = (sorted_ids.shape[0] + 31) // 32 * 32
    # One scale byte per group of 32 input columns, rounded up.
    scale_cols = (cols + 31) // 32
    out_scale = torch.empty(
        padded_rows,
        scale_cols,
        dtype=dtypes.fp8_e8m0,
        device=scale.device,
    )
    mxfp4_moe_sort_hip(out_scale, scale, sorted_ids, num_valid_ids, token_num, cols)
    return out_scale


@compile_ops("module_quant")
def fused_dynamic_mxfp4_quant_moe_sort_hip(
    out: torch.Tensor,
    scales: torch.Tensor,
    input: torch.Tensor,
    sorted_ids: torch.Tensor,
    num_valid_ids: torch.Tensor,
    token_num: int,
    block_m: int,
    group_size: int = 32,
) -> None:
    """
    HIP path for fused dynamic MXFP4 quantization and MoE scale sorting.

    Python-side stub: the real implementation is the compiled HIP kernel
    registered under ``module_quant`` by ``@compile_ops``; this body is
    never executed.

    Args:
        out: Pre-allocated packed fp4x2 output tensor (written in place).
        scales: Pre-allocated e8m0 scale tensor, filled in sorted
            (shuffled) order.
        input: Unquantized activation tensor to quantize.
        sorted_ids: Token ids sorted by expert assignment.
        num_valid_ids: Tensor holding the number of valid sorted ids.
        token_num: Number of tokens in the batch.
        block_m: MoE sort block size along the M dimension.
        group_size: Quantization group width; defaults to 32
            (one shared e8m0 scale per 32 elements).

    Returns:
        None. Outputs are delivered through ``out`` and ``scales``.
    """
    ...


def fused_dynamic_mxfp4_quant_moe_sort(
    input: torch.Tensor,
    sorted_ids: torch.Tensor,
    num_valid_ids: torch.Tensor,
    token_num: int,
    topk: int,  # stage1 and stage2: same topk value
    block_size: int,
    num_rows: Optional[torch.Tensor] = None,
    group_size: int = 32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Dynamically quantize ``input`` to MXFP4 and sort the e8m0 scales into
    the MoE shuffle layout.

    Two execution paths are selected by the row count M:
      * Small M (or a non-default ``group_size``): one fused HIP kernel
        quantizes and sorts in a single pass.
      * Large M: quantization (``per_1x32_f4_quant_hip``) and scale
        sorting (``mxfp4_moe_sort_hip``) run as two separate kernels.

    Args:
        input: Activation tensor; flattened to 2-D as (M, N) where M is
            ``token_num`` for stage 1 or ``token_num * topk`` for stage 2.
        sorted_ids: Token ids sorted by expert assignment.
        num_valid_ids: Tensor holding the number of valid sorted ids.
        token_num: Number of tokens in the batch.
        topk: Experts selected per token; used both for the path cutoffs
            and as the row factor in the split path.
        block_size: MoE sort block size along M.
        num_rows: Optional tensor bounding the live row count for the
            split-path quantizer.
        group_size: Quantization group width; only 32 is supported by the
            split path, so any other value forces the fused kernel.

    Returns:
        Tuple ``(out, scale)``: packed fp4x2 tensor of shape (M, N // 2)
        and the sorted e8m0 scale tensor padded to a multiple of 32 rows.
    """
    # Fused-kernel cutoffs in rows. Floor division is equivalent to the
    # former float thresholds here because M is an integer.
    stage1_cutoff = 8 * 64 // topk
    stage2_cutoff = 8 * 1024 // topk
    M, N = input.view(-1, input.shape[-1]).shape
    # Stage-1 inputs have one row per token; stage-2 inputs have topk rows
    # per token.
    is_stage1 = M == token_num
    # In stage 1 each token occupies a single row, so the row factor is 1.
    rows_factor = 1 if is_stage1 else topk
    scale = torch.empty(
        (sorted_ids.shape[0] + 31) // 32 * 32,
        (N + 31) // 32,
        dtype=dtypes.fp8_e8m0,
        device=input.device,
    )
    use_fused = (
        (is_stage1 and M <= stage1_cutoff)
        or (not is_stage1 and M <= stage2_cutoff)
        or group_size != 32
    )
    if use_fused:
        # fp4x2 packs two 4-bit values per byte, hence N // 2 columns.
        out = torch.empty(M, N // 2, dtype=dtypes.fp4x2, device=input.device)
        fused_dynamic_mxfp4_quant_moe_sort_hip(
            out,
            scale,
            input,
            sorted_ids,
            num_valid_ids,
            token_num,
            block_size,
            group_size,
        )
    else:
        # Quantize first, then reorder the raw scales into shuffle layout.
        out, raw_scale = per_1x32_f4_quant_hip(
            input, None, dtypes.fp4x2, num_rows=num_rows, num_rows_factor=rows_factor
        )
        mxfp4_moe_sort_hip(scale, raw_scale, sorted_ids, num_valid_ids, token_num, N)
    return out, scale


@compile_ops("module_quant")
def partial_transpose(
out: Tensor,
Expand Down
16 changes: 16 additions & 0 deletions csrc/include/quant.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,20 @@ void moe_smooth_per_token_scaled_quant_v2(torch::Tensor& out, // [..., d
int block_m,
bool shuffle_scale = false,
bool transpose_out = false);

// Fused dynamic MXFP4 quantization + MoE scale sort (single HIP kernel).
// Quantizes `input` into packed fp4x2 `out` and writes the e8m0 scales,
// reordered by `sorted_ids`, into `scales`.
void fused_dynamic_mxfp4_quant_moe_sort_hip(torch::Tensor& out, // [token_num * topk, d / 2]
                                            torch::Tensor& scales, // swizzled e8m0 bytes
                                            torch::Tensor const& input, // [token_num * topk, d]
                                            torch::Tensor const& sorted_ids,
                                            torch::Tensor const& num_valid_ids,
                                            int token_num,
                                            int block_m,
                                            int group_size = 32);

// Sort pre-computed MXFP4 (e8m0) scales into the MoE shuffle layout,
// writing the reordered scales into `out_scale`.
void mxfp4_moe_sort_hip(torch::Tensor& out_scale,
                        torch::Tensor const& scale,
                        torch::Tensor const& sorted_ids,
                        torch::Tensor const& num_valid_ids,
                        int token_num,
                        int cols);
} // namespace aiter
81 changes: 51 additions & 30 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include "aiter_tensor.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

Expand Down Expand Up @@ -1322,6 +1323,24 @@ namespace py = pybind11;
py::arg("block_m"), \
py::arg("shuffle_scale") = false, \
py::arg("transpose_out") = false); \
m.def("fused_dynamic_mxfp4_quant_moe_sort_hip", \
&aiter::fused_dynamic_mxfp4_quant_moe_sort_hip, \
py::arg("out"), \
py::arg("scales"), \
py::arg("input"), \
py::arg("sorted_ids"), \
py::arg("num_valid_ids"), \
py::arg("token_num"), \
py::arg("block_m"), \
py::arg("group_size") = 32); \
m.def("mxfp4_moe_sort_hip", \
&aiter::mxfp4_moe_sort_hip, \
py::arg("out_scale"), \
py::arg("scale"), \
py::arg("sorted_ids"), \
py::arg("num_valid_ids"), \
py::arg("token_num"), \
py::arg("cols")); \
m.def("partial_transpose", \
&aiter::partial_transpose, \
py::arg("out"), \
Expand Down Expand Up @@ -1413,37 +1432,39 @@ namespace py = pybind11;
py::arg("epsilon"), \
py::arg("use_model_sensitive_rmsnorm") = 0);

#define ROPE_1C_UNCACHED_FWD_PYBIND m.def("rope_fwd_impl", &rope_fwd_impl);
#define ROPE_2C_UNCACHED_FWD_PYBIND m.def("rope_2c_fwd_impl", &rope_2c_fwd_impl);
#define ROPE_1C_CACHED_FWD_PYBIND m.def("rope_cached_fwd_impl", &rope_cached_fwd_impl);
#define ROPE_2C_CACHED_FWD_PYBIND m.def("rope_cached_2c_fwd_impl", &rope_cached_2c_fwd_impl);
#define ROPE_1C_THD_FWD_PYBIND m.def("rope_thd_fwd_impl", &rope_thd_fwd_impl);
#define ROPE_1C_2D_FWD_PYBIND m.def("rope_2d_fwd_impl", &rope_2d_fwd_impl);

#define ROPE_1C_UNCACHED_BWD_PYBIND m.def("rope_bwd_impl", &rope_bwd_impl);
#define ROPE_2C_UNCACHED_BWD_PYBIND m.def("rope_2c_bwd_impl", &rope_2c_bwd_impl);
#define ROPE_1C_CACHED_BWD_PYBIND m.def("rope_cached_bwd_impl", &rope_cached_bwd_impl);
#define ROPE_2C_CACHED_BWD_PYBIND m.def("rope_cached_2c_bwd_impl", &rope_cached_2c_bwd_impl);
#define ROPE_1C_THD_BWD_PYBIND m.def("rope_thd_bwd_impl", &rope_thd_bwd_impl);
#define ROPE_1C_2D_BWD_PYBIND m.def("rope_2d_bwd_impl", &rope_2d_bwd_impl);


#define ROPE_1C_CACHED_POSITIONS_FWD_PYBIND m.def("rope_cached_positions_fwd_impl", &rope_cached_positions_fwd_impl)
#define ROPE_2C_CACHED_POSITIONS_FWD_PYBIND \
m.def("rope_cached_positions_2c_fwd_impl", \
&rope_cached_positions_2c_fwd_impl, \
py::arg("output_x"), \
py::arg("output_y"), \
py::arg("input_x"), \
py::arg("input_y"), \
py::arg("cos"), \
py::arg("sin"), \
py::arg("positions"), \
py::arg("rotate_style"), \
py::arg("reuse_freqs_front_part"), \
#define ROPE_1C_UNCACHED_FWD_PYBIND m.def("rope_fwd_impl", &rope_fwd_impl);
#define ROPE_2C_UNCACHED_FWD_PYBIND m.def("rope_2c_fwd_impl", &rope_2c_fwd_impl);
#define ROPE_1C_CACHED_FWD_PYBIND m.def("rope_cached_fwd_impl", &rope_cached_fwd_impl);
#define ROPE_2C_CACHED_FWD_PYBIND m.def("rope_cached_2c_fwd_impl", &rope_cached_2c_fwd_impl);
#define ROPE_1C_THD_FWD_PYBIND m.def("rope_thd_fwd_impl", &rope_thd_fwd_impl);
#define ROPE_1C_2D_FWD_PYBIND m.def("rope_2d_fwd_impl", &rope_2d_fwd_impl);

#define ROPE_1C_UNCACHED_BWD_PYBIND m.def("rope_bwd_impl", &rope_bwd_impl);
#define ROPE_2C_UNCACHED_BWD_PYBIND m.def("rope_2c_bwd_impl", &rope_2c_bwd_impl);
#define ROPE_1C_CACHED_BWD_PYBIND m.def("rope_cached_bwd_impl", &rope_cached_bwd_impl);
#define ROPE_2C_CACHED_BWD_PYBIND m.def("rope_cached_2c_bwd_impl", &rope_cached_2c_bwd_impl);
#define ROPE_1C_THD_BWD_PYBIND m.def("rope_thd_bwd_impl", &rope_thd_bwd_impl);
#define ROPE_1C_2D_BWD_PYBIND m.def("rope_2d_bwd_impl", &rope_2d_bwd_impl);

#define ROPE_1C_CACHED_POSITIONS_FWD_PYBIND \
m.def("rope_cached_positions_fwd_impl", &rope_cached_positions_fwd_impl)
#define ROPE_2C_CACHED_POSITIONS_FWD_PYBIND \
m.def("rope_cached_positions_2c_fwd_impl", \
&rope_cached_positions_2c_fwd_impl, \
py::arg("output_x"), \
py::arg("output_y"), \
py::arg("input_x"), \
py::arg("input_y"), \
py::arg("cos"), \
py::arg("sin"), \
py::arg("positions"), \
py::arg("rotate_style"), \
py::arg("reuse_freqs_front_part"), \
py::arg("nope_first"))
#define ROPE_1C_CACHED_POSITIONS_OFFSETS_FWD_PYBIND m.def("rope_cached_positions_offsets_fwd_impl", &rope_cached_positions_offsets_fwd_impl)
#define ROPE_2C_CACHED_POSITIONS_OFFSETS_FWD_PYBIND m.def("rope_cached_positions_offsets_2c_fwd_impl", &rope_cached_positions_offsets_2c_fwd_impl)
#define ROPE_1C_CACHED_POSITIONS_OFFSETS_FWD_PYBIND \
m.def("rope_cached_positions_offsets_fwd_impl", &rope_cached_positions_offsets_fwd_impl)
#define ROPE_2C_CACHED_POSITIONS_OFFSETS_FWD_PYBIND \
m.def("rope_cached_positions_offsets_2c_fwd_impl", &rope_cached_positions_offsets_2c_fwd_impl)

#define FUSED_QKNORM_MROPE_CACHE_QUANT_PYBIND \
m.def("fused_qk_norm_mrope_3d_cache_pts_quant_shuffle", \
Expand Down
Loading
Loading