diff --git a/aiter/fused_moe.py b/aiter/fused_moe.py index b356e0b235..b9f6c80161 100644 --- a/aiter/fused_moe.py +++ b/aiter/fused_moe.py @@ -107,7 +107,12 @@ def fused_moe( num_local_tokens: Optional[torch.tensor] = None, moe_sorting_dispatch_policy=0, dtype=None, -): + # following for cktile support + hidden_pad=0, + intermediate_pad=0, + bias1=None, + bias2=None, +): if not block_size_M: block_size_M = -1 return fused_moe_( @@ -128,6 +133,10 @@ def fused_moe( num_local_tokens=num_local_tokens, moe_sorting_dispatch_policy=moe_sorting_dispatch_policy, dtype=dtype, + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + bias1=bias1, + bias2=bias2, ) @@ -181,6 +190,10 @@ def fused_moe_( num_local_tokens: Optional[torch.Tensor] = None, moe_sorting_dispatch_policy: bool = 0, dtype: Optional[torch.dtype] = None, + hidden_pad: int = 0, + intermediate_pad: int = 0, + bias1: Optional[torch.Tensor] = None, + bias2: Optional[torch.Tensor] = None, ) -> torch.Tensor: # We do such convert since custom_op schema restriction on block_size_M, and Enum type activation = ActivationType(activation) @@ -223,6 +236,10 @@ def fused_moe_( isG1U1, activation, doweight_stage1, + hidden_pad, + intermediate_pad, + bias1, + bias2, ) block_size_M = metadata.block_m if block_size_M is None else block_size_M @@ -255,6 +272,8 @@ def fused_moe_( moe_buf, isG1U1, block_size_M, + # activation=activation, + # quant_type=quant_type, q_dtype_a=q_dtype_a, q_dtype_w=q_dtype_w, w1_scale=w1_scale, @@ -286,6 +305,11 @@ def fused_moe_( a1_scale=a1_scale, a2_scale=a2_scale, num_local_tokens=num_local_tokens, + # following for cktile support + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + bias1=bias1, + bias2=bias2, ) @@ -494,6 +518,10 @@ def get_2stage_cfgs( use_g1u1, activation, doweight_stage1, + hidden_pad, + intermediate_pad, + bias1, + bias2, ): def get_cfg_2stages(tune_file): import pandas as pd @@ -501,7 +529,6 @@ def get_cfg_2stages(tune_file): cfg_2stages = pd.read_csv(tune_file) cfg_2stages = cfg_2stages.set_index( [ - "cu_num", "token", "model_dim", "inter_dim", @@ -548,7 +575,6 @@ def MainFunc(): f.write( "token,model_dim,inter_dim,expert,topk,act_type,dtype,q_dtype_a,q_dtype_w,q_type,use_g1u1,doweight_stage1" ) - q_dtype_ws = q_dtype_w if q_dtype_w != torch.uint32 else "torch.int4" f.write( f"\n{token},{model_dim},{inter_dim},{expert},{topk},{activation},{dtype},{q_dtype_a},{q_dtype_ws},{q_type},{int(use_g1u1)},{int(doweight_stage1)}" @@ -627,6 +653,24 @@ def FinalFunc(): ksplit, run_1stage, ) + if dtype in [dtypes.bf16, dtypes.fp16] and q_type == QuantType.per_1x32 and activation == ActivationType.Swiglu: + return MOEMetadata( + functools.partial( + cktile_moe_stage1, + n_pad_zeros=intermediate_pad // 64 * 64 * (2 if use_g1u1 else 1), + k_pad_zeros=hidden_pad // 128 * 128, + bias1=bias1, + ), + functools.partial( + cktile_moe_stage2, + n_pad_zeros=hidden_pad // 64 * 64, + k_pad_zeros=intermediate_pad // 128 * 128, + bias2=bias2, + ), + 16 if token < 2048 else 32, + ksplit, + False, + ) if ( "ck2stages" in kernelName1 or (q_type == QuantType.per_1x128 and doweight_stage1) @@ -704,6 +748,11 @@ def fused_moe_2stages( a1_scale=None, # [expert(local_expert:EP), 1, model_dim] a2_scale=None, # [expert(local_expert:EP), 1, inter_dim] num_local_tokens: Optional[torch.tensor] = None, + # following for cktile support + hidden_pad=0, + intermediate_pad=0, + bias1=None, + bias2=None, ): quant_func = get_quant(quant_type) @@ -725,9 +774,18 @@ def fused_moe_2stages( isG1U1, activation, doweight_stage1, + hidden_pad, + intermediate_pad, + bias1, + bias2, ) - - if quant_type == QuantType.per_1x32: + if quant_type == QuantType.per_1x32 \ + and dtype in [dtypes.bf16, dtypes.fp16] \ + and w1.dtype == dtypes.fp4x2 \ + and activation == ActivationType.Swiglu: + a1 = hidden_states.to(dtype) + a1_scale = None + elif quant_type == QuantType.per_1x32: a1, a1_scale = quant_func( hidden_states, scale=a1_scale, @@ -768,7 +826,7 @@ def fused_moe_2stages( dtype=dtype, device=device, ) - + a2 = metadata.stage1( a1, w1, @@ -784,7 +842,11 @@ def fused_moe_2stages( sorted_weights=sorted_weights if doweight_stage1 else None, ) - if quant_type == QuantType.per_1x32: + if quant_type == QuantType.per_1x32 \ + and dtype in [dtypes.bf16, dtypes.fp16] \ + and w1.dtype == dtypes.fp4x2: + a2_scale = None + elif quant_type == QuantType.per_1x32: a2 = a2.view(-1, inter_dim) a2, a2_scale = quant_func( a2, @@ -975,7 +1037,7 @@ def torch_moe( return (out * topk_weight.view(B, -1, 1)).sum(dim=1).to(dtype) -# temp workaround for swiglu +#temp workaround for swiglu def swiglu(x_glu, x_linear, alpha: float = 1.702, limit: float = 7.0): # Clamp the input values x_glu = x_glu.clamp(min=None, max=limit) @@ -1103,7 +1165,6 @@ def torch_moe_stage2( w2_bias=None, doweight=True, ): - quant_type = quant_remap.get(quant_type, quant_type) ctype = dtypes.fp32 # compute type E, model_dim, inter_dim = get_inter_dim(w1.shape, w2.shape) if quant_type == QuantType.per_1x32: @@ -1172,6 +1233,101 @@ def torch_moe_stage2( return out.sum(1).to(dtype) +def cktile_moe_stage1( + hidden_states, + w1, + w2, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + out, + topk, + block_m, + a1_scale, + w1_scale, + sorted_weights=None, + n_pad_zeros=0, + k_pad_zeros=0, + bias1=None, +): + token_num = hidden_states.shape[0] + _, n1, k1 = w1.shape + _, k2, n2 = w2.shape + D = n2 if k2 == k1 else n2*2 #bit4 format + # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size + + if w1.dtype is torch.uint32: + D = D * 8 + out = torch.empty((token_num, topk, D), dtype=hidden_states.dtype, device=hidden_states.device) + # print("Run cktile_moe_stage1: M=%d, N(N*2)=%d, K=%d, topk=%d, expert=%d"%(token_num, w1.shape[1], hidden_states.shape[1], topk, w1.shape[0])) + aiter.moe_cktile2stages_gemm1( + hidden_states, + w1, + out, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + topk, + n_pad_zeros, + k_pad_zeros, + sorted_weights, + a1_scale, + w1_scale, + bias1, + block_m, + ) + return out + + +def cktile_moe_stage2( + a2, + w1, + w2, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + out, + topk, + w2_scale, + a2_scale, + block_m, + sorted_weights=None, + zeros_out=False, + n_pad_zeros=0, + k_pad_zeros=0, + bias2=None, +): + token_num = a2.shape[0] + D = w2.shape[1] + # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size + + # out = torch.empty( + # (token_num, D), + # dtype=a2.dtype, + # device=a2.device, + # ) + # if zeros_out: + # out.fill_(0) + # print("Run cktile_moe_stage2: M=%d, N=%d, K=%d, topk=%d, expert=%d"%(a2.shape[0]*a2.shape[1], w2.shape[1], a2.shape[2], topk, w2.shape[0])) + aiter.moe_cktile2stages_gemm2( + a2, + w2, + out, + sorted_token_ids, + sorted_expert_ids, + num_valid_ids, + topk, + n_pad_zeros, + k_pad_zeros, + sorted_weights, + a2_scale, + w2_scale, + bias2, + block_m, + ) + return out + + def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -1233,4 +1389,4 @@ def fused_topk( # if renormalize: # topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids + return topk_weights, topk_ids \ No newline at end of file diff --git a/csrc/include/aiter_enum.h b/csrc/include/aiter_enum.h index 0c35e8158f..15126c8cf6 100644 --- a/csrc/include/aiter_enum.h +++ b/csrc/include/aiter_enum.h @@ -6,7 +6,8 @@ enum class ActivationType : int { No = -1, Silu = 0, - Gelu + Gelu = 1, + Swiglu = 2, }; enum class QuantType : int { diff --git a/csrc/include/rocm_ops.hpp b/csrc/include/rocm_ops.hpp index 45a1b441a1..6b2863d338 100644 --- a/csrc/include/rocm_ops.hpp +++ b/csrc/include/rocm_ops.hpp @@ -1320,6 +1320,7 @@ namespace py = pybind11; .value("No", ActivationType::No) \ .value("Silu", ActivationType::Silu) \ .value("Gelu", ActivationType::Gelu) \ + .value("Swiglu", ActivationType::Swiglu) \ .export_values(); \ pybind11::implicitly_convertible(); \ pybind11::implicitly_convertible(); diff --git a/op_tests/test_moe_2stage.py b/op_tests/test_moe_2stage.py index d66133b56b..cbe2a80fa6 100644 --- a/op_tests/test_moe_2stage.py +++ b/op_tests/test_moe_2stage.py @@ -25,8 +25,8 @@ from aiter.ops.shuffle import ( shuffle_weight, - shuffle_mxfp4_weight, shuffle_mxfp4_scale, + shuffle_mxfp4_weight, shuffle_weight_NK, ) from aiter import ActivationType @@ -80,7 +80,6 @@ def ck_moe_stage1( return out - def ck_moe_stage2( hidden_states, w1, # [E, inter_dim*2, model_dim] @@ -125,7 +124,6 @@ def ck_moe_stage2( ) return out - def cktile_moe_stage1( hidden_states, w1, # [E, inter_dim*2, model_dim] @@ -138,8 +136,8 @@ def cktile_moe_stage1( exp_bias1, dtype, topk, - n_pad_zeros=0, - k_pad_zeros=0, + n_pad_zeros = 0, + k_pad_zeros = 0, block_size=32, Activation=ActivationType.Silu, quant_type=aiter.QuantType.No, @@ -148,7 +146,7 @@ def cktile_moe_stage1( token_num = hidden_states.shape[0] _, n1, k1 = w1.shape _, k2, n2 = w2.shape - D = n2 if k2 == k1 else n2 * 2 # bit4 format + D = n2 if k2 == k1 else n2 * 2 #bit4 format # max_num_tokens_padded = sorted_expert_ids.shape[0]*block_size if w1.dtype is torch.uint32: @@ -173,7 +171,6 @@ def cktile_moe_stage1( ) return out - def cktile_moe_stage2( hidden_states, w1, # [E, inter_dim*2, model_dim] @@ -186,13 +183,13 @@ def cktile_moe_stage2( exp_bias2, dtype, topk, - n_pad_zeros=0, - k_pad_zeros=0, + n_pad_zeros = 0, + k_pad_zeros = 0, block_size=32, Activation=ActivationType.Silu, quant_type=aiter.QuantType.No, sorted_weights=None, # [max_num_tokens_padded] - zeros_out=False, + zeros_out = False ): token_num = hidden_states.shape[0] D = w2.shape[1] @@ -239,32 +236,30 @@ def test_fmoe( WQDType, use_g1u1=False, doweight_stage1=False, + hidden_pad=0, + intermediate_pad=0, ): if get_gfx() not in ["gfx950"] and qType == aiter.QuantType.per_1x32: return torch_quant = aiter.get_torch_quant(qType) torch_act = aiter.get_torch_act(actType) input = torch.randn((token, model_dim), dtype=dtype) - need_pad = qType == aiter.QuantType.per_1x32 - npad0 = 192 - kpad0 = 128 if use_g1u1: w1 = torch.randn((E, inter_dim * 2, model_dim), dtype=dtype) - if need_pad: - w1[:, :, -kpad0:] = 0 - w1[:, -npad0:, :] = 0 - w1[:, inter_dim - npad0 : inter_dim, :] = 0 + if (hidden_pad != 0 and intermediate_pad != 0): + w1[:,:,-hidden_pad:] = 0 + w1[:,-intermediate_pad:,:] = 0 + w1[:,inter_dim-intermediate_pad:inter_dim,:] = 0 exp_bias1 = torch.clamp(torch.randn((E, inter_dim * 2), dtype=dtype), -1.0, 1.0) else: w1 = torch.randn((E, inter_dim, model_dim), dtype=dtype) exp_bias1 = torch.clamp(torch.randn((E * inter_dim), dtype=dtype), -1.0, 1.0) w2 = torch.randn((E, model_dim, inter_dim), dtype=dtype) - if need_pad: - w2[:, :, -kpad0:] = 0 - w2[:, -npad0:, :] = 0 + if (hidden_pad != 0 and intermediate_pad != 0): + w2[:,:,-intermediate_pad:] = 0 + w2[:,-hidden_pad:,:] = 0 exp_bias2 = torch.clamp(torch.randn((E, model_dim), dtype=dtype), -1.0, 1.0) score = torch.randn((token, E), dtype=dtype) - # rand topk_weights, topk_ids = fused_topk(input, score, topk, True) # sequence # topk_ids_list = [[((i * topk) + j)% E for j in range(topk)] for i in range(token)] @@ -272,10 +267,10 @@ def test_fmoe( M, _ = topk_ids.shape - # BLOCK_SIZE_M = get_block_size_M(M, topk, E, inter_dim) - BLOCK_SIZE_M = 32 if M > 1024 else 16 - if qType == aiter.QuantType.per_128x128: - BLOCK_SIZE_M = 64 if M > 64 else 16 + BLOCK_SIZE_M = get_block_size_M(M, topk, E, inter_dim) + # BLOCK_SIZE_M = 32 if M > 1024 else 16 + # if qType == aiter.QuantType.per_128x128: + # BLOCK_SIZE_M = 64 if M > 64 else 16 sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, moe_buf = moe_sorting( topk_ids, topk_weights, E, model_dim, dtype, BLOCK_SIZE_M ) @@ -325,31 +320,25 @@ def weight_per_128x128_quant(weight, quant_dtype): if qType != aiter.QuantType.per_1x32: w1_qt = w1_qt_aiter = w1_qt.view(w1.shape) w2_qt = w2_qt_aiter = w2_qt.view(w2.shape) - else: w1_qt = w1_qt_aiter = w1_qt.view(w1.shape[0], w1.shape[1], w1.shape[2] // 2) w2_qt = w2_qt_aiter = w2_qt.view(w2.shape[0], w2.shape[1], w2.shape[2] // 2) + # Quant-ing a if qType == aiter.QuantType.per_128x128: a1_qt, a1_scale = aiter.pertoken_quant( input.view(token, -1, 128), quant_dtype=AQDType ) a1_qt = a1_qt.view(token, model_dim) a1_scale = a1_scale.squeeze(-1) - elif qType == aiter.QuantType.per_1x32 and ( - AQDType in [dtypes.bf16, dtypes.fp16] - ): # a16w4 + elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and WQDType == dtypes.fp4x2: #a16w4 a1_qt = input.to(AQDType) a1_scale = None else: a1_qt, a1_scale = torch_quant(input, quant_dtype=AQDType) # bias dtype convert - if ( - qType == aiter.QuantType.per_1x32 - and (AQDType in [dtypes.bf16, dtypes.fp16]) - and (WQDType == dtypes.fp4x2) - ): # a16w4 + if qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 exp_bias1_aiter = exp_bias1.to(dtypes.fp32) exp_bias2_aiter = exp_bias2.to(dtypes.fp32) else: @@ -370,26 +359,22 @@ def weight_per_128x128_quant(weight, quant_dtype): shuffle_weight(w2_qt_aiter, (16, 16), use_int4=True) ) ) - elif ( - qType == aiter.QuantType.per_1x32 - and (AQDType in [dtypes.bf16, dtypes.fp16]) - and (WQDType == dtypes.fp4x2) - ): # a16w4 + w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) + w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) + elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 w1_qt_aiter = shuffle_mxfp4_weight(w1_qt_aiter, 16, True) w1_scale_aiter = shuffle_mxfp4_scale(w1_scale, E, True) w2_qt_aiter = shuffle_mxfp4_weight(w2_qt_aiter, 16, False) w2_scale_aiter = shuffle_mxfp4_scale(w2_scale, E, False) - elif ( - WQDType != dtypes.fp4x2 - and (get_gfx() in ["gfx950"]) - and (qType != aiter.QuantType.per_128x128) - ): - inst_K = 128 // w1_qt_aiter.element_size() - w1_qt_aiter = shuffle_weight_NK(w1_qt_aiter, 16, inst_K) - w2_qt_aiter = shuffle_weight_NK(w2_qt_aiter, 16, inst_K) + # elif WQDType != dtypes.fp4x2 and (get_gfx() in ["gfx950"]): + # inst_K = 128 // w1_qt_aiter.element_size() + # w1_qt_aiter = shuffle_weight_NK(w1_qt_aiter, 16, inst_K) + # w2_qt_aiter = shuffle_weight_NK(w2_qt_aiter, 16, inst_K) elif WQDType != dtypes.fp4x2: w1_qt_aiter = shuffle_weight(w1_qt_aiter, layout=(16, 16)) w2_qt_aiter = shuffle_weight(w2_qt_aiter, layout=(16, 16)) + w1_scale_aiter = fp4_utils.e8m0_shuffle(w1_scale) + w2_scale_aiter = fp4_utils.e8m0_shuffle(w2_scale) # # ######################## stage 1 start ########### out1_ref = torch_moe_stage1( @@ -408,56 +393,62 @@ def weight_per_128x128_quant(weight, quant_dtype): ) # # ######################## ck stage 1 start ########### - out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) - - # out1_ck, us1 = run_perftest( - # ck_moe_stage1, - # a1_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w1_scale, - # a1_scale, - # dtype, - # topk, - # BLOCK_SIZE_M, - # actType, - # quant_type=qType, - # sorted_weights=sorted_weights if doweight_stage1 else None, - # needTrace=True, - # ) - - # cktile_2stage - # out1_ck, us1 = run_perftest( - # cktile_moe_stage1, - # a1_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w1_scale_aiter, - # a1_scale, - # exp_bias1_aiter, - # dtype, - # topk, - # npad0 * 2, - # kpad0, - # BLOCK_SIZE_M, - # actType, - # quant_type=qType, - # sorted_weights=sorted_weights if doweight_stage1 else None, - # # needTrace=True, - # # num_iters=2, - # # num_warmup=0, - # ) - # checkAllclose( - # out1_ref[:,:-npad0] if need_pad else out1_ref, - # out1_ck[:,:-npad0] if need_pad else out1_ck, - # msg=f"[perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) + if WQDType == dtypes.fp4x2 or AQDType == dtypes.fp4x2: + out1_ck = torch.zeros((token, topk, inter_dim), dtype=dtype) + else: + out1_ck = torch.empty((token, topk, inter_dim), dtype=dtype) + if qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4: + npad0 = intermediate_pad // 64 * 64 + kpad0 = hidden_pad // 128 * 128 + out1_ck, us1 = run_perftest( + cktile_moe_stage1, + a1_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w1_scale_aiter, + a1_scale, + exp_bias1_aiter, + dtype, + topk, + npad0 * 2, + kpad0, + BLOCK_SIZE_M, + actType, + quant_type=qType, + sorted_weights=sorted_weights if doweight_stage1 else None, + # needTrace=True, + # num_iters=2, + # num_warmup=0, + ) + else: + out1_ck, us1 = run_perftest( + ck_moe_stage1, + a1_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w1_scale, + a1_scale, + dtype, + topk, + BLOCK_SIZE_M, + actType, + quant_type=qType, + sorted_weights=sorted_weights if doweight_stage1 else None, + needTrace=True, + ) + + checkAllclose( + out1_ref, + out1_ck, + msg=f"[perf] ck_moe_stage1:{us1:>8.2f} us, {token*model_dim*inter_dim*2*topk*2/us1/1000/1000:>8.2f} tflops......(quant:{AQDType})", + ) + # diff = torch.abs(out1_ref - out1_ck) # max_value= diff.max() # multi_index = np.unravel_index(torch.argmax(diff).item(), diff.shape) @@ -500,7 +491,7 @@ def weight_per_128x128_quant(weight, quant_dtype): out1_ref.view(token, -1, 128), quant_dtype=AQDType ) a2_scale = a2_scale.view(token, topk, -1) - elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]): + elif qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 a2_qt = out1_ref a2_scale = None else: @@ -520,108 +511,110 @@ def weight_per_128x128_quant(weight, quant_dtype): w2_bias=exp_bias2, doweight=not doweight_stage1, ) - # out_ref = torch_moe( - # input, - # w1_qt, - # w2_qt, - # topk_weights, - # topk_ids, - # fc1_scale=w1_scale, - # fc2_scale=w2_scale, - # ) - # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") + # # out_ref = torch_moe( + # # input, + # # w1_qt, + # # w2_qt, + # # topk_weights, + # # topk_ids, + # # fc1_scale=w1_scale, + # # fc2_scale=w2_scale, + # # ) + # # checkAllclose(out_ref, out2_ref, msg="[torch] 1_stage vs 2_stage") out2_ck = torch.empty((token, model_dim), dtype=dtype) - # out2_ck, us2 = run_perftest( - # ck_moe_stage2, - # a2_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w2_scale, - # a2_scale, - # dtype, - # topk, - # BLOCK_SIZE_M, - # actType, - # quant_type, - # sorted_weights if not doweight_stage1 else None, - # ) - - # # cktil2stage - # _, us2 = run_perftest( - # cktile_moe_stage2, - # a2_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w2_scale_aiter, - # a2_scale, - # exp_bias2_aiter, - # dtype, - # topk, - # npad0, - # kpad0, - # BLOCK_SIZE_M, - # actType, - # quant_type, - # sorted_weights if not doweight_stage1 else None, - # # needTrace=True, - # # num_iters=2, - # # num_warmup=0, - # ) - # out2_ck = cktile_moe_stage2( - # a2_qt, - # w1_qt_aiter, - # w2_qt_aiter, - # sorted_ids, - # sorted_expert_ids, - # num_valid_ids, - # w2_scale_aiter, - # a2_scale, - # exp_bias2_aiter, - # dtype, - # topk, - # npad0, - # kpad0, - # BLOCK_SIZE_M, - # actType, - # quant_type, - # sorted_weights if not doweight_stage1 else None, - # True - # ) - - # checkAllclose( - # out2_ref, - # out2_ck, - # msg=f"[perf] ck_moe_stage2:{us2:>8.2f} us, {token*model_dim*inter_dim*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", - # ) + if qType == aiter.QuantType.per_1x32 and (AQDType in [dtypes.bf16, dtypes.fp16]) and (WQDType == dtypes.fp4x2): #a16w4 + npad0 = hidden_pad // 64 * 64 + kpad0 = intermediate_pad // 128 * 128 + _, us2 = run_perftest( + cktile_moe_stage2, + a2_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w2_scale_aiter, + a2_scale, + exp_bias2_aiter, + dtype, + topk, + npad0, + kpad0, + BLOCK_SIZE_M, + actType, + quant_type, + sorted_weights if not doweight_stage1 else None, + # needTrace=True, + # num_iters=2, + # num_warmup=0, + ) + out2_ck = cktile_moe_stage2( + a2_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w2_scale_aiter, + a2_scale, + exp_bias2_aiter, + dtype, + topk, + npad0, + kpad0, + BLOCK_SIZE_M, + actType, + quant_type, + sorted_weights if not doweight_stage1 else None, + True + ) + else: + out2_ck, us2 = run_perftest( + ck_moe_stage2, + a2_qt, + w1_qt_aiter, + w2_qt_aiter, + sorted_ids, + sorted_expert_ids, + num_valid_ids, + w2_scale, + a2_scale, + dtype, + topk, + BLOCK_SIZE_M, + actType, + quant_type, + sorted_weights if not doweight_stage1 else None, + ) + + checkAllclose( + out2_ref, + out2_ck, + msg=f"[perf] ck_moe_stage2:{us2:>8.2f} us, {token*model_dim*inter_dim*topk*2/us2/1000/1000:>8.2f} tflops......(quant:{AQDType})", + ) + # diff = torch.abs(out2_ref - out2_ck) # max_value= diff.max() # multi_index = np.unravel_index(torch.argmax(diff).item(), diff.shape) # print("max_diff", max_value.item(), ",ref=", out2_ref[multi_index].item(), ",ck=", out2_ck[multi_index].item()) # ######################## stage 2 end ########### - # # ######################## fused 2 stage ######### - us1 = 0 - out2_ck, us2 = run_perftest( - fused_moe, + out2_ck = fused_moe( input, w1_qt_aiter, w2_qt_aiter, topk_weights, topk_ids, - w1_scale=fp4_utils.e8m0_shuffle( - w1_scale - ), # e8m0_shuffle will do nothing if it's a fp32 - w2_scale=fp4_utils.e8m0_shuffle(w2_scale), + w1_scale=w1_scale_aiter, + w2_scale=w2_scale_aiter, quant_type=qType, activation=actType, doweight_stage1=doweight_stage1, + intermediate_pad=intermediate_pad, + hidden_pad=hidden_pad, + bias1=exp_bias1_aiter, + bias2=exp_bias2_aiter, ) err = checkAllclose( out2_ref, @@ -645,10 +638,11 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): l_dtype = ["bf16", "fp16"][:1] # l_dim = [(6144, 4096)] l_dim = [(7168, 256)] +# l_dim = [(3072, 3072)] l_tokenNum = [ # 1, - # 3, - # 5, + # 2, + # 4, 8, # 16, # 32, @@ -656,20 +650,26 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): # 128, # 256, # 1024, + # 2048, + # 3072, # 4096, + # 8192, # 163840, ] +l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu] l_quant = [ - # (aiter.QuantType.No, None, None), # a16w16 + # (aiter.QuantType.No, None, None), # a16w16 # (aiter.QuantType.per_Tensor, dtypes.fp8, dtypes.fp8), # a8w8 # (aiter.QuantType.per_Token, dtypes.fp8, dtypes.fp8), # a8w8 # (aiter.QuantType.per_Token, dtypes.fp8, torch.int4), # a8w4 # (aiter.QuantType.per_1x32, dtypes.fp4x2, dtypes.fp4x2), # a4w4 - (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 - # (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 + # (aiter.QuantType.per_128x128, dtypes.fp8, dtypes.fp8), # a8w8 + (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2), # a16w4 + ] -l_act = [aiter.ActivationType.Silu, aiter.ActivationType.Gelu][:1] l_doweight_stage1 = [False, True][:1] +l_hidden_intermediate_pad = [(0, 0), (65, 65), (129, 191)] + parser = argparse.ArgumentParser( formatter_class=argparse.RawTextHelpFormatter, @@ -748,7 +748,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): "-e", "--expert", type=int, - default=8, + default=256, help="""Number of experts. e.g.: -e 8""", ) @@ -757,7 +757,7 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): "-k", "--topk", type=int, - default=2, + default=8, help="""Number of top experts. e.g.: -k 2""", ) @@ -781,30 +781,52 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): if args.doweight_stage1 is not None: l_doweight_stage1 = [args.doweight_stage1] - + +df = [] for ( dtype, - act_type, (quant_type, aq_dtype, wq_dtype), (model_dim, inter_dim), doweight_stage1, -) in itertools.product(l_dtype, l_act, l_quant, l_dim, l_doweight_stage1): - df = [] - for m in l_tokenNum: - ret = test_fmoe( - dtype, - m, - model_dim, - inter_dim, - args.expert, - args.topk, - act_type, - quant_type, - aq_dtype, - wq_dtype, - use_g1u1=True, - doweight_stage1=doweight_stage1, - ) - df.append(ret) - df = pd.DataFrame(df) - aiter.logger.info(f"summary:\n{df}") +) in itertools.product(l_dtype, l_quant, l_dim, l_doweight_stage1): + if (quant_type, aq_dtype, wq_dtype) == (aiter.QuantType.per_1x32, dtypes.bf16, dtypes.fp4x2): + for (hidden_pad, intermediate_pad) in l_hidden_intermediate_pad: + print(f"hidden_pad={hidden_pad}, intermediate_pad={intermediate_pad}") + for m in l_tokenNum: + ret = test_fmoe( + dtype, + m, + model_dim, + inter_dim, + args.expert, + args.topk, + aiter.ActivationType.Swiglu, + quant_type, + aq_dtype, + wq_dtype, + use_g1u1=True, + doweight_stage1=doweight_stage1, + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + ) + df.append(ret) + else: + for act_type in l_act: + for m in l_tokenNum: + ret = test_fmoe( + dtype, + m, + model_dim, + inter_dim, + args.expert, + args.topk, + act_type, + quant_type, + aq_dtype, + wq_dtype, + use_g1u1=True, + doweight_stage1=doweight_stage1, + ) + df.append(ret) +df = pd.DataFrame(df) +aiter.logger.info(f"summary:\n{df}") \ No newline at end of file