Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
7d417a5
redirect asm_moe_tkw1 call to fused_moe in order to force kernel tuning
antsaukk Nov 7, 2025
81dad6c
add required keys to fused_moe_1stage_dict
antsaukk Nov 7, 2025
cac3dc2
add kernel descriptors and code object files
antsaukk Nov 7, 2025
75fa008
add 32x128 file descriptors and code objects for tuning
antsaukk Nov 7, 2025
e5e5787
move code objects and kernel descriptors to correct csv
antsaukk Nov 7, 2025
5e37fc9
remove unnecessary import, add quant type argument
antsaukk Nov 11, 2025
2d0c5a1
move fused_moe_stage1_tkw1 into fused_moe.py
antsaukk Nov 11, 2025
702b73c
remove unnecessary kernel code object files
antsaukk Nov 11, 2025
5359944
merging latest main
antsaukk Nov 12, 2025
52002fc
Merge branch 'main' into asm_moe_tkw1_refactoring
antsaukk Nov 12, 2025
8e65339
Add missing comma
antsaukk Nov 12, 2025
5152563
saved modified tuned fmoe config for testing purposes
antsaukk Nov 12, 2025
c2faf4d
apply black required formatting
antsaukk Nov 12, 2025
42fe584
Merge branch 'main' into asm_moe_tkw1_refactoring
antsaukk Nov 12, 2025
7a36ba6
remove fused_moe_stage1_tkw1 and place aiter.fmoe_g1u1_tkw1 under fus…
antsaukk Nov 18, 2025
2a95fc9
Merge branch 'main' into asm_moe_tkw1_refactoring
antsaukk Nov 18, 2025
67c74ad
remove unnecessary arguments
antsaukk Nov 18, 2025
1cfe55f
apply black formatting
antsaukk Nov 18, 2025
4902128
simplify aiter.fmoe_g1u1_tkw1 call
antsaukk Nov 18, 2025
0a66435
add doweight_stage1 column to fused_moe_1stage_dict map and remove el…
antsaukk Nov 18, 2025
35ed2b2
Merge branch 'main' into asm_moe_tkw1_refactoring
antsaukk Nov 18, 2025
4c5ebf6
add doweight_stage1 to query key
antsaukk Nov 18, 2025
0430e19
modify elif to select run_1stage=True for tokens > 16
antsaukk Nov 18, 2025
5145cdc
apply black formatting
antsaukk Nov 18, 2025
8d6b209
Merge branch 'main' into asm_moe_tkw1_refactoring
antsaukk Nov 19, 2025
e2df6c6
removing csv and .co files as they will come in separate commit
antsaukk Nov 19, 2025
efc4dcf
removing log logger.info(f[get_2stage_cfgs] run_1stage)
anugodavar Nov 19, 2025
12acca5
Merge branch 'main' into asm_moe_tkw1_refactoring
anugodavar Nov 19, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 57 additions & 29 deletions aiter/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,6 @@ def fused_moe_(
)

if metadata.run_1stage:
assert (
doweight_stage1 == False
), "doweight_stage1 not support in fused_moe_1stage"
return metadata.stage1(
hidden_states,
w1,
Expand All @@ -278,6 +275,9 @@ def fused_moe_(
a1_scale=a1_scale,
a2_scale=a2_scale,
num_local_tokens=num_local_tokens,
M=M,
device=topk_ids.device,
doweight_stage1=doweight_stage1,
)
else:
return fused_moe_2stages(
Expand Down Expand Up @@ -333,6 +333,9 @@ def fused_moe_1stage(
a1_scale=None, # [expert(local_expert:EP), 1, model_dim]
a2_scale=None, # [expert(local_expert:EP), 1, inter_dim]
num_local_tokens: Optional[torch.tensor] = None,
M: int = None,
device=None,
doweight_stage1: bool = None,
):
if quant_type == QuantType.No and activation == ActivationType.Silu and not isG1U1:
# pure bf16
Expand All @@ -347,7 +350,31 @@ def fused_moe_1stage(
num_valid_ids,
topk,
)
elif quant_type == QuantType.per_Token and doweight_stage1 and isG1U1:
a8_type = w1.dtype
_, model_dim, _ = w2.shape

a8 = torch.empty((M, model_dim), dtype=a8_type, device=device)
a8_scale = torch.empty(M, dtype=dtypes.fp32, device=device)
aiter.dynamic_per_token_scaled_quant(a8, hidden_states, a8_scale)

aiter.fmoe_g1u1_tkw1(
moe_buf,
a8,
w1,
w2,
sorted_ids,
sorted_weights,
sorted_expert_ids,
num_valid_ids,
topk,
a8_scale,
w1_scale,
w2_scale,
kernelName,
a2_scale,
activation,
)
else:
quant_func = get_quant(quant_type)
if hidden_states.dtype != q_dtype_a:
Expand Down Expand Up @@ -451,23 +478,25 @@ def get_block_size_M(token, topk, expert, inter_dim):
fused_moe_1stage_dict = {
"gfx942":
{
# activation, quant_type, dtype, q_dtype_a, q_dtype_w, isG1U1, API
(ActivationType.Silu, QuantType.No, dtypes.bf16, dtypes.bf16, dtypes.bf16, False) : aiter.fmoe,
(ActivationType.Silu, QuantType.No, dtypes.fp16, dtypes.fp16, dtypes.fp16, False) : aiter.fmoe,
(ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.i4x2, True) : aiter.fmoe_g1u1,
(ActivationType.Silu, QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2, dtypes.fp4x2, True) : aiter.fmoe_g1u1,
(ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, True) : aiter.fmoe_g1u1,
(ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, True) : aiter.fmoe_g1u1,
(ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.fp8, True) : aiter.fmoe_g1u1,
(ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.fp8, True) : aiter.fmoe_g1u1,
(ActivationType.Silu, QuantType.per_1x128, dtypes.bf16, dtypes.fp8, dtypes.fp8, True) : aiter.fmoe_g1u1,
(ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, False) : aiter.fmoe_int8_g1u0,
(ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, False) : aiter.fmoe_int8_g1u0,
# activation, quant_type, dtype, q_dtype_a, q_dtype_w, isG1U1, doweight_stage1, API
(ActivationType.Silu, QuantType.No, dtypes.bf16, dtypes.bf16, dtypes.bf16, False, False) : aiter.fmoe,
(ActivationType.Silu, QuantType.No, dtypes.fp16, dtypes.fp16, dtypes.fp16, False, False) : aiter.fmoe,
(ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.i4x2, True, False) : aiter.fmoe_g1u1,
(ActivationType.Silu, QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2, dtypes.fp4x2, True, False) : aiter.fmoe_g1u1,
(ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, True, False) : aiter.fmoe_g1u1,
(ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, True, False) : aiter.fmoe_g1u1,
(ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.fp8, True, False) : aiter.fmoe_g1u1,
(ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.fp8, True, False) : aiter.fmoe_g1u1,
(ActivationType.Silu, QuantType.per_1x128, dtypes.bf16, dtypes.fp8, dtypes.fp8, True, False) : aiter.fmoe_g1u1,
(ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, False, False) : aiter.fmoe_int8_g1u0,
(ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, False, False) : aiter.fmoe_int8_g1u0,
},
"gfx950":
{
(ActivationType.Silu, QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2, dtypes.fp4x2, True) : aiter.fmoe_g1u1,
(ActivationType.Silu, QuantType.per_1x128, dtypes.bf16, dtypes.fp8, dtypes.fp8, True) : aiter.fmoe_fp8_blockscale_g1u1,
(ActivationType.Silu, QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2, dtypes.fp4x2, True, False) : aiter.fmoe_g1u1,
(ActivationType.Silu, QuantType.per_1x128, dtypes.bf16, dtypes.fp8, dtypes.fp8, True, False) : aiter.fmoe_fp8_blockscale_g1u1,
(ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.bf16, dtypes.bf16, False, False) : aiter.fmoe,
(ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.fp8, True, True) : aiter.fmoe_g1u1_tkw1,
}
}
# fmt: on
Expand Down Expand Up @@ -601,21 +630,20 @@ def FinalFunc():
kernelName2 = ""
run_1stage = False
if (
not doweight_stage1
and (
activation,
q_type,
dtype,
q_dtype_a,
q_dtype_w,
use_g1u1,
)
in fused_moe_1stage_dict[get_gfx()]
):
activation,
q_type,
dtype,
q_dtype_a,
q_dtype_w,
use_g1u1,
doweight_stage1,
) in fused_moe_1stage_dict[get_gfx()]:
if q_type == QuantType.per_1x128:
run_1stage = True and (inter_dim % 256 == 0)
elif q_type == QuantType.per_Token and q_dtype_w in [dtypes.i8, dtypes.fp8]:
elif q_type == QuantType.per_Token and q_dtype_w == dtypes.i8:
run_1stage = token > 32
elif q_type == QuantType.per_Token and q_dtype_w == dtypes.fp8:
run_1stage = token > 16
elif q_type != QuantType.per_1x32:
run_1stage = token < 256

Expand Down
150 changes: 15 additions & 135 deletions aiter/fused_moe_bf16_asm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from aiter import logger
from aiter import pertoken_quant, get_hip_quant
from aiter import ActivationType, QuantType, dtypes
from aiter.fused_moe import fused_moe

BLOCK_SIZE_M = 32

Expand Down Expand Up @@ -280,143 +281,22 @@ def asm_moe_tkw1(
expert_mask=None,
activation=ActivationType.Silu,
):
E, model_dim, inter_dim = w2.shape
global_E = E
if expert_mask is not None:
global_E = expert_mask.numel()
M, topk = topk_ids.shape
dtype = hidden_states.dtype
device = topk_ids.device
lastdim_mul = 8 if w1.dtype in {dtypes.i32, torch.uint32} else 1
sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = (
moe_sorting_ck(
topk_ids, topk_weight, global_E, model_dim, dtype, BLOCK_SIZE_M, expert_mask
)
return fused_moe(
hidden_states,
w1,
w2,
topk_weight,
topk_ids,
expert_mask=expert_mask,
activation=activation,
quant_type=QuantType.per_Token,
doweight_stage1=True,
w1_scale=fc1_scale,
w2_scale=fc2_scale,
a1_scale=fc1_smooth_scale,
a2_scale=fc2_smooth_scale,
)

if fc1_scale is None:
# pure bf16
aiter.fmoe(
moe_buf,
hidden_states,
w1,
w2,
sorted_ids,
sorted_weights,
sorted_expert_ids,
num_valid_ids,
topk,
)
elif a16:
# a16w8 smooth quant fmoe
if w1.dtype == dtypes.fp8 and inter_dim * 2 == w1.shape[1]:
aiter.fmoe_fp8_g1u1_a16(
moe_buf,
hidden_states,
w1,
w2,
sorted_ids,
sorted_weights,
sorted_expert_ids,
num_valid_ids,
topk,
fc1_scale,
fc2_scale,
fc1_smooth_scale,
fc2_smooth_scale,
)
elif w1.dtype == dtypes.i8 and inter_dim == w1.shape[1]:
aiter.fmoe_int8_g1u0_a16(
moe_buf,
hidden_states,
w1,
w2,
sorted_ids,
sorted_weights,
sorted_expert_ids,
num_valid_ids,
topk,
fc1_scale,
fc2_scale,
fc1_smooth_scale,
fc2_smooth_scale,
)
else:
raise ValueError(f"Invalid args: {w1.dtype} {w1.shape=} {w2.shape=}")

else:
# a8w8 fmoe, opt: smooth quant
a8_type = (
w1.dtype
if w1.dtype != dtypes.i32 and w1.dtype != torch.uint32
else dtypes.fp8
)
if fc1_smooth_scale is not None:
a8 = torch.empty((topk * M, model_dim), dtype=a8_type, device=device)
a8_scale = torch.empty((topk * M), dtype=dtypes.fp32, device=device)

# moe_smoothquant_fwd need topk_ids which contains local_expert_id
if expert_mask is not None:
local_expert_hash = expert_mask.cumsum(0, dtype=dtypes.i32)
local_expert_hash[local_expert_hash > 0] -= 1
topk_ids = local_expert_hash[topk_ids]

aiter.moe_smoothquant_fwd(
a8, hidden_states, fc1_smooth_scale, topk_ids, a8_scale
)
else:
if (
w1.dtype == dtypes.fp8
or w1.dtype == dtypes.i32
and w1.dtype == torch.uint32
):
a8 = torch.empty((M, model_dim), dtype=a8_type, device=device)
a8_scale = torch.empty(M, dtype=dtypes.fp32, device=device)
if per_tensor_quant_scale is None:
aiter.dynamic_per_token_scaled_quant(a8, hidden_states, a8_scale)
else:
aiter.static_per_tensor_quant(
a8, hidden_states, per_tensor_quant_scale
)
a8_scale.fill_(per_tensor_quant_scale)
elif w1.dtype == dtypes.i8:
a8 = torch.empty((M, model_dim), dtype=w1.dtype, device=device)
a8_scale = torch.empty(M, dtype=dtypes.fp32, device=device)
fc1_smooth_scale = torch.ones(
model_dim, dtype=dtypes.fp32, device=device
)
aiter.smoothquant_fwd(a8, hidden_states, fc1_smooth_scale, a8_scale)
else:
logger.warning("FMOE fall into pure torch quant...")
a8, a8_scale = aiter.pertoken_quant(hidden_states, quant_dtype=w1.dtype)
if w2.shape[2] * 2 * lastdim_mul == w1.shape[1]:
fmoe_func = aiter.fmoe_g1u1_tkw1

else:
raise ValueError(
f"Invalid MoE weight: {w1.shape=} {w2.shape=} {lastdim_mul}"
)

fmoe_func(
moe_buf,
a8,
w1,
w2,
sorted_ids,
sorted_weights,
sorted_expert_ids,
num_valid_ids,
topk,
a8_scale,
fc1_scale,
fc2_scale,
"",
fc2_smooth_scale,
activation,
)
# fc2_smooth_scale)
return moe_buf


def get_block_size(token, topk, expert):
token_per_expert = token * topk / expert
Expand Down
Loading