diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index 890934a93c..abab9cfc48 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -254,9 +254,6 @@ def fused_moe_( ) if metadata.run_1stage: - assert ( - doweight_stage1 == False - ), "doweight_stage1 not support in fused_moe_1stage" return metadata.stage1( hidden_states, w1, @@ -278,6 +275,9 @@ def fused_moe_( a1_scale=a1_scale, a2_scale=a2_scale, num_local_tokens=num_local_tokens, + M=M, + device=topk_ids.device, + doweight_stage1=doweight_stage1, ) else: return fused_moe_2stages( @@ -333,6 +333,9 @@ def fused_moe_1stage( a1_scale=None, # [expert(local_expert:EP), 1, model_dim] a2_scale=None, # [expert(local_expert:EP), 1, inter_dim] num_local_tokens: Optional[torch.tensor] = None, + M: int = None, + device=None, + doweight_stage1: bool = None, ): if quant_type == QuantType.No and activation == ActivationType.Silu and not isG1U1: # pure bf16 @@ -347,7 +350,31 @@ def fused_moe_1stage( num_valid_ids, topk, ) + elif quant_type == QuantType.per_Token and doweight_stage1 and isG1U1: + a8_type = w1.dtype + _, model_dim, _ = w2.shape + a8 = torch.empty((M, model_dim), dtype=a8_type, device=device) + a8_scale = torch.empty(M, dtype=dtypes.fp32, device=device) + aiter.dynamic_per_token_scaled_quant(a8, hidden_states, a8_scale) + + aiter.fmoe_g1u1_tkw1( + moe_buf, + a8, + w1, + w2, + sorted_ids, + sorted_weights, + sorted_expert_ids, + num_valid_ids, + topk, + a8_scale, + w1_scale, + w2_scale, + kernelName, + a2_scale, + activation, + ) else: quant_func = get_quant(quant_type) if hidden_states.dtype != q_dtype_a: @@ -451,23 +478,25 @@ def get_block_size_M(token, topk, expert, inter_dim): fused_moe_1stage_dict = { "gfx942": { - # activation, quant_type, dtype, q_dtype_a, q_dtype_w, isG1U1, API - (ActivationType.Silu, QuantType.No, dtypes.bf16, dtypes.bf16, dtypes.bf16, False) : aiter.fmoe, - (ActivationType.Silu, QuantType.No, dtypes.fp16, dtypes.fp16, dtypes.fp16, False) : aiter.fmoe, - (ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.i4x2, True) : aiter.fmoe_g1u1, - (ActivationType.Silu, QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2, dtypes.fp4x2, True) : aiter.fmoe_g1u1, - (ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, True) : aiter.fmoe_g1u1, - (ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, True) : aiter.fmoe_g1u1, - (ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.fp8, True) : aiter.fmoe_g1u1, - (ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.fp8, True) : aiter.fmoe_g1u1, - (ActivationType.Silu, QuantType.per_1x128, dtypes.bf16, dtypes.fp8, dtypes.fp8, True) : aiter.fmoe_g1u1, - (ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, False) : aiter.fmoe_int8_g1u0, - (ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, False) : aiter.fmoe_int8_g1u0, + # activation, quant_type, dtype, q_dtype_a, q_dtype_w, isG1U1, doweight_stage1, API + (ActivationType.Silu, QuantType.No, dtypes.bf16, dtypes.bf16, dtypes.bf16, False, False) : aiter.fmoe, + (ActivationType.Silu, QuantType.No, dtypes.fp16, dtypes.fp16, dtypes.fp16, False, False) : aiter.fmoe, + (ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.i4x2, True, False) : aiter.fmoe_g1u1, + (ActivationType.Silu, QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2, dtypes.fp4x2, True, False) : aiter.fmoe_g1u1, + (ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, True, False) : aiter.fmoe_g1u1, + (ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, True, False) : aiter.fmoe_g1u1, + (ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.fp8, True, False) : aiter.fmoe_g1u1, + (ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.fp8, True, False) : aiter.fmoe_g1u1, + (ActivationType.Silu, QuantType.per_1x128, dtypes.bf16, dtypes.fp8, dtypes.fp8, True, False) : aiter.fmoe_g1u1, + (ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, False, False) : aiter.fmoe_int8_g1u0, + (ActivationType.Gelu, QuantType.per_Token, dtypes.bf16, dtypes.i8, dtypes.i8, False, False) : aiter.fmoe_int8_g1u0, }, "gfx950": { - (ActivationType.Silu, QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2, dtypes.fp4x2, True) : aiter.fmoe_g1u1, - (ActivationType.Silu, QuantType.per_1x128, dtypes.bf16, dtypes.fp8, dtypes.fp8, True) : aiter.fmoe_fp8_blockscale_g1u1, + (ActivationType.Silu, QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2, dtypes.fp4x2, True, False) : aiter.fmoe_g1u1, + (ActivationType.Silu, QuantType.per_1x128, dtypes.bf16, dtypes.fp8, dtypes.fp8, True, False) : aiter.fmoe_fp8_blockscale_g1u1, + (ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.bf16, dtypes.bf16, False, False) : aiter.fmoe, + (ActivationType.Silu, QuantType.per_Token, dtypes.bf16, dtypes.fp8, dtypes.fp8, True, True) : aiter.fmoe_g1u1_tkw1, } } # fmt: on @@ -601,21 +630,20 @@ def FinalFunc(): kernelName2 = "" run_1stage = False if ( - not doweight_stage1 - and ( - activation, - q_type, - dtype, - q_dtype_a, - q_dtype_w, - use_g1u1, - ) - in fused_moe_1stage_dict[get_gfx()] - ): + activation, + q_type, + dtype, + q_dtype_a, + q_dtype_w, + use_g1u1, + doweight_stage1, + ) in fused_moe_1stage_dict[get_gfx()]: if q_type == QuantType.per_1x128: run_1stage = True and (inter_dim % 256 == 0) - elif q_type == QuantType.per_Token and q_dtype_w in [dtypes.i8, dtypes.fp8]: + elif q_type == QuantType.per_Token and q_dtype_w == dtypes.i8: run_1stage = token > 32 + elif q_type == QuantType.per_Token and q_dtype_w == dtypes.fp8: + run_1stage = token > 16 elif q_type != QuantType.per_1x32: run_1stage = token < 256 diff --git a/aiter/fused_moe_bf16_asm.py b/aiter/fused_moe_bf16_asm.py index 81df5ea592..87a9ccbc43 100755 --- a/aiter/fused_moe_bf16_asm.py +++ b/aiter/fused_moe_bf16_asm.py @@ -8,6 +8,7 @@ from aiter import logger from aiter import pertoken_quant, get_hip_quant from aiter import ActivationType, QuantType, dtypes +from aiter.fused_moe import fused_moe BLOCK_SIZE_M = 32 @@ -280,143 +281,22 @@ def asm_moe_tkw1( expert_mask=None, activation=ActivationType.Silu, ): - E, model_dim, inter_dim = w2.shape - global_E = E - if expert_mask is not None: - global_E = expert_mask.numel() - M, topk = topk_ids.shape - dtype = hidden_states.dtype - device = topk_ids.device - lastdim_mul = 8 if w1.dtype in {dtypes.i32, torch.uint32} else 1 - sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = ( - moe_sorting_ck( - topk_ids, topk_weight, global_E, model_dim, dtype, BLOCK_SIZE_M, expert_mask - ) + return fused_moe( + hidden_states, + w1, + w2, + topk_weight, + topk_ids, + expert_mask=expert_mask, + activation=activation, + quant_type=QuantType.per_Token, + doweight_stage1=True, + w1_scale=fc1_scale, + w2_scale=fc2_scale, + a1_scale=fc1_smooth_scale, + a2_scale=fc2_smooth_scale, ) - if fc1_scale is None: - # pure bf16 - aiter.fmoe( - moe_buf, - hidden_states, - w1, - w2, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - topk, - ) - elif a16: - # a16w8 smooth quant fmoe - if w1.dtype == dtypes.fp8 and inter_dim * 2 == w1.shape[1]: - aiter.fmoe_fp8_g1u1_a16( - moe_buf, - hidden_states, - w1, - w2, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - topk, - fc1_scale, - fc2_scale, - fc1_smooth_scale, - fc2_smooth_scale, - ) - elif w1.dtype == dtypes.i8 and inter_dim == w1.shape[1]: - aiter.fmoe_int8_g1u0_a16( - moe_buf, - hidden_states, - w1, - w2, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - topk, - fc1_scale, - fc2_scale, - fc1_smooth_scale, - fc2_smooth_scale, - ) - else: - raise ValueError(f"Invalid args: {w1.dtype} {w1.shape=} {w2.shape=}") - - else: - # a8w8 fmoe, opt: smooth quant - a8_type = ( - w1.dtype - if w1.dtype != dtypes.i32 and w1.dtype != torch.uint32 - else dtypes.fp8 - ) - if fc1_smooth_scale is not None: - a8 = torch.empty((topk * M, model_dim), dtype=a8_type, device=device) - a8_scale = torch.empty((topk * M), dtype=dtypes.fp32, device=device) - - # moe_smoothquant_fwd need topk_ids which contains local_expert_id - if expert_mask is not None: - local_expert_hash = expert_mask.cumsum(0, dtype=dtypes.i32) - local_expert_hash[local_expert_hash > 0] -= 1 - topk_ids = local_expert_hash[topk_ids] - - aiter.moe_smoothquant_fwd( - a8, hidden_states, fc1_smooth_scale, topk_ids, a8_scale - ) - else: - if ( - w1.dtype == dtypes.fp8 - or w1.dtype == dtypes.i32 - and w1.dtype == torch.uint32 - ): - a8 = torch.empty((M, model_dim), dtype=a8_type, device=device) - a8_scale = torch.empty(M, dtype=dtypes.fp32, device=device) - if per_tensor_quant_scale is None: - aiter.dynamic_per_token_scaled_quant(a8, hidden_states, a8_scale) - else: - aiter.static_per_tensor_quant( - a8, hidden_states, per_tensor_quant_scale - ) - a8_scale.fill_(per_tensor_quant_scale) - elif w1.dtype == dtypes.i8: - a8 = torch.empty((M, model_dim), dtype=w1.dtype, device=device) - a8_scale = torch.empty(M, dtype=dtypes.fp32, device=device) - fc1_smooth_scale = torch.ones( - model_dim, dtype=dtypes.fp32, device=device - ) - aiter.smoothquant_fwd(a8, hidden_states, fc1_smooth_scale, a8_scale) - else: - logger.warning("FMOE fall into pure torch quant...") - a8, a8_scale = aiter.pertoken_quant(hidden_states, quant_dtype=w1.dtype) - if w2.shape[2] * 2 * lastdim_mul == w1.shape[1]: - fmoe_func = aiter.fmoe_g1u1_tkw1 - - else: - raise ValueError( - f"Invalid MoE weight: {w1.shape=} {w2.shape=} {lastdim_mul}" - ) - - fmoe_func( - moe_buf, - a8, - w1, - w2, - sorted_ids, - sorted_weights, - sorted_expert_ids, - num_valid_ids, - topk, - a8_scale, - fc1_scale, - fc2_scale, - "", - fc2_smooth_scale, - activation, - ) - # fc2_smooth_scale) - return moe_buf - def get_block_size(token, topk, expert): token_per_expert = token * topk / expert